|
|
@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
|
@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
|
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
|
|
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
|
|
@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
|
|
|
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
|
|
|
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
|
|
|
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
|
|
|
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
|
|
|
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
|
|
|
|
|
|
|
GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
|
|
|
|
|
|
|
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
|
|
|
|
|
|
|
(int) kernel->pipeline.threadExecutionWidth); \
|
|
|
|
[metal_function release]; \
|
|
|
|
[metal_function release]; \
|
|
|
|
if (error) { \
|
|
|
|
if (error) { \
|
|
|
|
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
|
|
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
|
@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
|
@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
|
|
|
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
|
@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
|
|
|
|
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
|
|
|
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
|
|
|
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
|
|
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
|
|
|
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
|
|
|
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
|
|
|
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
|
|
|
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
|
|
|
|
|
|
|
|
|
|
|
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
|
|
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
|
|
|
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
|
|
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
|
@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} break;
|
|
|
|
} break;
|
|
|
|
|
|
|
|
case GGML_TYPE_BF16:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
switch (ne00) {
|
|
|
|
|
|
|
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
|
|
|
|
|
|
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
|
|
|
|
|
|
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
|
|
|
|
|
|
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
|
|
|
|
|
|
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
|
|
|
|
|
|
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
|
|
|
|
|
|
|
GGML_LOG_ERROR("add template specialization for this size\n");
|
|
|
|
|
|
|
|
GGML_ABORT("add template specialization for this size");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} break;
|
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
|
{
|
|
|
|
{
|
|
|
|
switch (ne00) {
|
|
|
|
switch (ne00) {
|
|
|
@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
|
|
|
|
{
|
|
|
|
{
|
|
|
|
switch (src1->type) {
|
|
|
|
switch (src1->type) {
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
|
|
|
|
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
|
|
@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
|
|
|
|
{
|
|
|
|
{
|
|
|
|
switch (src1->type) {
|
|
|
|
switch (src1->type) {
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
|
|
|
|
|
|
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
|
|
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
|
|
@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
|
|
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
|
|
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
|
|
|
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
|
|
|
[encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
|
|
|
|
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
|
|
|
[encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
|
|
|
[encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
|
|
|
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
|
|
|
|
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
|
|
|
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
|
|
|
|
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
|
|
|
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
|
|
|
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
|
|
|
[encoder setBytes:&scale length:sizeof( float) atIndex:23];
|
|
|
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
|
|
|
[encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
|
|
|
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
|
|
|
[encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
|
|
|
|
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
|
|
|
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
|
|
|
|
|
|
|
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
|
|
|
|
|
|
|
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!use_vec_kernel) {
|
|
|
|
if (!use_vec_kernel) {
|
|
|
|
// half8x8 kernel
|
|
|
|
// half8x8 kernel
|
|
|
@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
|
|
|
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
|
|
GGML_ASSERT(nqptg % 8 == 0);
|
|
|
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
|
|
GGML_ASSERT(ncpsg % 32 == 0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 2*(2*ncpsg + nqptg)*(nsg)
|
|
|
|
|
|
|
|
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
|
|
|
|
|
|
|
//
|
|
|
|
// 16*32*(nsg)
|
|
|
|
// 16*32*(nsg)
|
|
|
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
|
|
// the shared memory needed for the simdgroups to load the KV cache
|
|
|
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
|
|
|
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
|
|
|
//
|
|
|
|
//
|
|
|
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
|
|
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
|
|
|
|
|
|
|
|
|
|
|
int64_t nsgmax = 2;
|
|
|
|
int64_t nsgmax = 2;
|
|
|
|
|
|
|
|
|
|
|
@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
|
|
|
|
|
|
// ne00 + 2*ncpsg*(nsg)
|
|
|
|
// ne00 + 2*ncpsg*(nsg)
|
|
|
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
|
|
// for each query, we load it as f16 in shared memory (ne00)
|
|
|
|
// and store the attention scores (nqptg x ncpsg) as f32
|
|
|
|
// and store the soft_max values and the mask
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// 2*ne00*(nsg)
|
|
|
|
// ne00*(nsg)
|
|
|
|
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
|
|
|
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
|
|
|
//
|
|
|
|
//
|
|
|
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
|
|
|
|
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
|
|
|
|
|
|
|
|
|
|
|
int64_t nsgmax = 2;
|
|
|
|
int64_t nsgmax = 2;
|
|
|
|
|
|
|
|
|
|
|
|