mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-28 03:55:06 -04:00
kleidiai: add support for get_rows (#14676)
* kleidiai: add support for get_rows
* apply fixes based on code review
* apply more fixes based on code review
This commit is contained in:
@ -494,9 +494,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||||||
|
|
||||||
# Fetch KleidiAI sources:
|
# Fetch KleidiAI sources:
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
set(KLEIDIAI_COMMIT_TAG "v1.9.0")
|
set(KLEIDIAI_COMMIT_TAG "v1.11.0")
|
||||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||||
set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017")
|
set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2")
|
||||||
|
|
||||||
if (POLICY CMP0135)
|
if (POLICY CMP0135)
|
||||||
cmake_policy(SET CMP0135 NEW)
|
cmake_policy(SET CMP0135 NEW)
|
||||||
|
@ -22,9 +22,94 @@
|
|||||||
|
|
||||||
#include "kai_common.h"
|
#include "kai_common.h"
|
||||||
|
|
||||||
|
#include "simd-mappings.h"
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
#define NELEMS(x) sizeof(x) / sizeof(*x)
|
// Number of elements in a statically sized array.
// The expansion and the argument are fully parenthesized so the macro behaves as a single
// operand inside larger expressions (e.g. `total / NELEMS(a)`). Must be given an array, not a pointer.
#define NELEMS(x) (sizeof(x) / sizeof(*(x)))
|
||||||
|
|
||||||
|
// Packing constants shared by the 4-bit row dequantizers in this file.
static const size_t INT4_PER_BYTE   = 2;  // 4-bit values stored in one byte
static const size_t INT4_BITS       = 4;  // width of one quantized value, in bits
static const int    Q4_0_ZERO_POINT = 8;  // offset subtracted from an unsigned nibble to get its signed value
static const size_t INT4_PER_UINT16 = 4;  // 4-bit values stored in one uint16_t
|
||||||
|
|
||||||
|
static void dequantize_row_qsi4c32pscalef16(
|
||||||
|
const void *packed_data,
|
||||||
|
int32_t row_idx,
|
||||||
|
int64_t nc,
|
||||||
|
float *out,
|
||||||
|
size_t nr_pack,
|
||||||
|
size_t packed_row_stride,
|
||||||
|
size_t kr,
|
||||||
|
size_t bl,
|
||||||
|
size_t num_bytes_multiplier
|
||||||
|
) {
|
||||||
|
size_t group_idx = row_idx / nr_pack;
|
||||||
|
size_t row_in_group = row_idx % nr_pack;
|
||||||
|
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
|
||||||
|
size_t num_blocks = nc / bl;
|
||||||
|
const uint8_t *block_ptr = packed_group;
|
||||||
|
|
||||||
|
for (size_t b = 0; b < num_blocks; ++b) {
|
||||||
|
uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
|
||||||
|
float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
|
||||||
|
|
||||||
|
const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
|
||||||
|
size_t num_segments = bl / kr;
|
||||||
|
size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
|
||||||
|
|
||||||
|
for (size_t s = 0; s < num_segments; ++s) {
|
||||||
|
const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
|
||||||
|
const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
|
||||||
|
for (size_t k = 0; k < num_bytes_per_segment; ++k) {
|
||||||
|
uint8_t byte = qbytes[k] ^ 0x88;
|
||||||
|
int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
|
||||||
|
int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
|
||||||
|
out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
|
||||||
|
out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dequantize_row_qsi4c32ps1s0scalef16(
|
||||||
|
const void *packed_data,
|
||||||
|
int32_t row_idx,
|
||||||
|
int64_t k,
|
||||||
|
float *out,
|
||||||
|
size_t nr,
|
||||||
|
size_t packed_row_stride,
|
||||||
|
size_t kr,
|
||||||
|
size_t bl,
|
||||||
|
size_t num_bytes_multiplier
|
||||||
|
) {
|
||||||
|
const size_t num_blocks = k / bl;
|
||||||
|
const size_t bl4 = bl / INT4_PER_UINT16;
|
||||||
|
|
||||||
|
size_t group_idx = row_idx / nr;
|
||||||
|
size_t row_in_group = row_idx % nr;
|
||||||
|
|
||||||
|
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
|
||||||
|
const uint16_t *qdata = (const uint16_t *)packed_group;
|
||||||
|
const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
|
||||||
|
|
||||||
|
for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
|
||||||
|
uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
|
||||||
|
float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
|
||||||
|
|
||||||
|
for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
|
||||||
|
uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
|
||||||
|
|
||||||
|
for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
|
||||||
|
int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
|
||||||
|
out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_UNUSED(kr);
|
||||||
|
}
|
||||||
|
|
||||||
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
||||||
#if defined(__ARM_FEATURE_SME)
|
#if defined(__ARM_FEATURE_SME)
|
||||||
{
|
{
|
||||||
@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||||
|
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
/* .packed_stride = */ NULL,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||||
|
/* .to_float = */ NULL,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||||
},
|
},
|
||||||
/* .rhs_info = */ {
|
/* .rhs_info = */ {
|
||||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||||
|
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||||
},
|
},
|
||||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||||
/* .lhs_type = */ GGML_TYPE_F32,
|
/* .lhs_type = */ GGML_TYPE_F32,
|
||||||
|
@ -71,12 +71,15 @@ struct rhs_packing_info {
|
|||||||
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
|
||||||
std::function<size_t(size_t n, size_t k)>
|
std::function<size_t(size_t n, size_t k)>
|
||||||
> packed_size;
|
> packed_size;
|
||||||
|
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
|
||||||
std::variant<
|
std::variant<
|
||||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
|
||||||
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
|
||||||
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
|
||||||
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
|
||||||
> pack_func;
|
> pack_func;
|
||||||
|
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
|
||||||
|
size_t kr, size_t bl, size_t num_bytes_multiplier);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_kleidiai_kernels {
|
struct ggml_kleidiai_kernels {
|
||||||
|
@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
|
|||||||
ggml_kleidiai_kernels * kernels;
|
ggml_kleidiai_kernels * kernels;
|
||||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
} static ctx = { CPU_FEATURE_NONE, NULL };
|
||||||
|
|
||||||
|
// Map a single cpu_feature value to a printable name.
// Any value that is not one of the enumerators listed below (including a bitwise
// combination of several features) yields "UNKNOWN".
static const char* cpu_feature_to_string(cpu_feature f) {
    const char * label = "UNKNOWN";

    switch (f) {
        case CPU_FEATURE_NONE:
            label = "NONE";
            break;
        case CPU_FEATURE_DOTPROD:
            label = "DOTPROD";
            break;
        case CPU_FEATURE_I8MM:
            label = "I8MM";
            break;
        case CPU_FEATURE_SVE:
            label = "SVE";
            break;
        case CPU_FEATURE_SME:
            label = "SME";
            break;
        default:
            break;
    }

    return label;
}
|
||||||
|
|
||||||
static void init_kleidiai_context(void) {
|
static void init_kleidiai_context(void) {
|
||||||
|
|
||||||
ggml_critical_section_start();
|
ggml_critical_section_start();
|
||||||
@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
|
|||||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||||
}
|
}
|
||||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||||
|
#ifndef NDEBUG
|
||||||
|
if (ctx.kernels) {
|
||||||
|
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
ggml_critical_section_end();
|
ggml_critical_section_end();
|
||||||
}
|
}
|
||||||
@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
|
|||||||
|
|
||||||
class tensor_traits : public ggml::cpu::tensor_traits {
|
class tensor_traits : public ggml::cpu::tensor_traits {
|
||||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||||
|
if (op->op != GGML_OP_MUL_MAT) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||||
GGML_ASSERT(kernels);
|
GGML_ASSERT(kernels);
|
||||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||||
@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||||
return compute_forward_kv_cache(params, dst);
|
return compute_forward_kv_cache(params, dst);
|
||||||
}
|
}
|
||||||
|
} else if (dst->op == GGML_OP_GET_ROWS) {
|
||||||
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
|
return compute_forward_get_rows(params, dst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||||
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const ggml_tensor * src1 = dst->src[1];
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
|
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||||
|
GGML_ASSERT(ctx.kernels);
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
|
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
||||||
|
kernel_info * kernel = &ctx.kernels->gemm;
|
||||||
|
|
||||||
|
const int64_t nc = ne00;
|
||||||
|
const int64_t nr = ggml_nelements(src1);
|
||||||
|
|
||||||
|
const size_t block_rows = kernel->get_nr();
|
||||||
|
const size_t kr = kernel->get_kr();
|
||||||
|
|
||||||
|
const size_t num_bytes_multiplier = sizeof(uint16_t);
|
||||||
|
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int dr = (nr + nth - 1) / nth;
|
||||||
|
const int ir0 = dr * ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int64_t i = ir0; i < ir1; ++i) {
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||||
|
int64_t row_idx = ((const int32_t *)src1->data)[i];
|
||||||
|
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
||||||
|
|
||||||
|
float *out = (float *)((char *)dst->data + i * nb1);
|
||||||
|
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
||||||
|
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
||||||
GGML_ASSERT(ctx.kernels);
|
GGML_ASSERT(ctx.kernels);
|
||||||
const size_t n = tensor->ne[1];
|
const size_t n = tensor->ne[1];
|
||||||
const size_t k = tensor->ne[0];
|
const size_t k = tensor->ne[0];
|
||||||
@ -351,17 +417,12 @@ public:
|
|||||||
size_t kr = ctx.kernels->gemm.get_kr();
|
size_t kr = ctx.kernels->gemm.get_kr();
|
||||||
size_t sr = ctx.kernels->gemm.get_sr();
|
size_t sr = ctx.kernels->gemm.get_sr();
|
||||||
|
|
||||||
#ifndef NDEBUG
|
|
||||||
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
|
||||||
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
|
||||||
#endif
|
|
||||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||||
params.lhs_zero_point = 1;
|
params.lhs_zero_point = 1;
|
||||||
params.rhs_zero_point = 8;
|
params.rhs_zero_point = 8;
|
||||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
GGML_UNUSED(data_size);
|
GGML_UNUSED(data_size);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
|||||||
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||||
|
|
||||||
GGML_UNUSED(buffer);
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
|
GGML_UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
||||||
@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
|||||||
GGML_UNUSED(buft);
|
GGML_UNUSED(buft);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Report how many bytes must be allocated for `tensor` in the KleidiAI buffer type.
// The size is obtained from the selected kernel's RHS `packed_size` callback, using the tensor's
// first two dimensions and the GEMM kernel's nr/kr parameters with a block length of QK4_0.
// Only Q4_0 tensors are accepted, and a kernel set must already have been selected.
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(ctx.kernels);

    const size_t rows      = tensor->ne[1];
    const size_t cols      = tensor->ne[0];
    const size_t pack_rows = ctx.kernels->gemm.get_nr();
    const size_t pack_k    = ctx.kernels->gemm.get_kr();

    const size_t packed_bytes = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, rows, cols, pack_rows, pack_k, QK4_0);
    return packed_bytes;
}
|
||||||
|
|
||||||
namespace ggml::cpu::kleidiai {
|
namespace ggml::cpu::kleidiai {
|
||||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||||
if (op->op == GGML_OP_MUL_MAT &&
|
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
||||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||||
op->src[0]->buffer &&
|
op->src[0]->buffer &&
|
||||||
(ggml_n_dims(op->src[0]) == 2) &&
|
(ggml_n_dims(op->src[0]) == 2) &&
|
||||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||||
|
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (op->src[1]->type == GGML_TYPE_F32 &&
|
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
|
||||||
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||||
if (op->op == GGML_OP_MUL_MAT) {
|
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
||||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||||
}
|
}
|
||||||
@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
|
|||||||
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
|
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
|
||||||
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
|
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
|
||||||
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
||||||
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
|
/* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
|
||||||
/* .is_host = */ nullptr,
|
/* .is_host = */ nullptr,
|
||||||
},
|
},
|
||||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|
||||||
|
Reference in New Issue
Block a user