Vulkan: Add DP4A MMQ and Q8_1 quantization shader (#12135)

* Vulkan: Add DP4A MMQ and Q8_1 quantization shader

* Add q4_0 x q8_1 matrix matrix multiplication support

* Vulkan: Add int8 coopmat MMQ support

* Vulkan: Add q4_1, q5_0 and q5_1 quants, improve integer dot code

* Add GL_EXT_integer_dot_product check

* Remove ggml changes, fix mmq pipeline picker

* Remove ggml changes, restore Intel coopmat behaviour

* Fix glsl compile attempt when integer vec dot is not supported

* Remove redundant code, use non-saturating integer dot, enable all matmul sizes for mmq

* Remove redundant comment

* Fix integer dot check

* Fix compile issue with unsupported int dot glslc

* Update Windows build Vulkan SDK version
Author: 0cc4m
Committed: 2025-03-31 14:37:01 +02:00 (by GitHub)
Parent: 1790e73157
Commit: a8a1f33567
10 changed files with 1146 additions and 95 deletions
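For orientation before the diff: the new path relies on VK_KHR_shader_integer_dot_product (GL_EXT_integer_dot_product on the GLSL side), whose core primitive is a dot product of four packed signed 8-bit values with 32-bit accumulation — the same operation CUDA exposes as DP4A. A minimal C++ emulation of that primitive, for reference only (the real work happens in the shaders):

```cpp
#include <cstdint>

// Reference semantics of the 4x8-bit packed signed dot product with 32-bit
// accumulation (the operation integerDotProduct4x8BitPackedSignedAccelerated
// reports hardware support for). Each uint32_t packs four int8 lanes.
static int32_t dot4x8_packed_signed(uint32_t a, uint32_t b, int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t)((a >> (8 * i)) & 0xff);
        const int8_t bi = (int8_t)((b >> (8 * i)) & 0xff);
        acc += (int32_t)ai * (int32_t)bi; // plain wraparound add: non-saturating
    }
    return acc;
}
```

Per the "use non-saturating integer dot" bullet above, the MMQ shaders accumulate without saturation, which is what the plain 32-bit add models.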


@@ -234,6 +234,8 @@ struct vk_device_struct {
bool float_controls_rte_fp16;
bool subgroup_add;
bool integer_dot_product;
bool subgroup_size_control;
uint32_t subgroup_min_size;
uint32_t subgroup_max_size;
@@ -245,6 +247,12 @@ struct vk_device_struct {
uint32_t coopmat_m;
uint32_t coopmat_n;
uint32_t coopmat_k;
bool coopmat_int_support;
uint32_t coopmat_int_m;
uint32_t coopmat_int_n;
uint32_t coopmat_int_k;
bool coopmat2;
size_t idx;
@@ -263,10 +271,10 @@ struct vk_device_struct {
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
vk_matmul_pipeline2 pipeline_matmul_f16;
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
vk_pipeline pipeline_matmul_split_k_reduce;
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
vk_matmul_pipeline pipeline_matmul_id_f32 {};
vk_matmul_pipeline2 pipeline_matmul_id_f16;
@@ -274,6 +282,9 @@ struct vk_device_struct {
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
vk_pipeline pipeline_matmul_split_k_reduce;
vk_pipeline pipeline_quantize_q8_1;
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -640,6 +651,13 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t H;
};
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
};
// Allow pre-recording command buffers
struct vk_staging_memcpy {
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -649,13 +667,6 @@ struct vk_staging_memcpy {
size_t n;
};
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
};
struct vk_context_struct {
vk_submission * s;
std::vector<vk_sequence> seqs;
@@ -1598,6 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
// mulmat
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1662,6 +1674,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
const uint32_t tm_int_l = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t tm_int_m = device->coopmat_int_support ? device->coopmat_int_m : 4;
const uint32_t tm_int_s = device->coopmat_int_support ? device->coopmat_int_m : 2;
const uint32_t tn_int_l = device->coopmat_int_support ? device->coopmat_int_n : 4;
const uint32_t tn_int_m = device->coopmat_int_support ? device->coopmat_int_n : 2;
const uint32_t tn_int_s = device->coopmat_int_support ? device->coopmat_int_n : 2;
const uint32_t tk_int_l = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t tk_int_m = device->coopmat_int_support ? device->coopmat_int_k : 1;
const uint32_t tk_int_s = device->coopmat_int_support ? device->coopmat_int_k : 1;
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_int_l, tn_int_l, tk_int_l, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_int_m, tn_int_m, tk_int_m, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_int_s, tn_int_s, tk_int_s, subgroup_size_8 };
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -2000,6 +2026,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2031,6 +2065,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2056,6 +2100,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2073,6 +2118,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2099,6 +2152,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
}
#endif
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2132,7 +2195,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t rm_stdq = 1;
uint32_t rm_kq = 2;
if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
if (device->architecture == AMD_GCN) {
rm_stdq = 2;
rm_kq = 4;
}
@@ -2266,6 +2329,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
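The quantize_q8_1 pipeline registered here runs the new GPU-side shader that converts f32 activations into ggml's Q8_1 blocks. For reference, the block layout, simplified from the commented-out test further down (ggml wraps d and s in a union there):

```cpp
#include <cstdint>

typedef uint16_t ggml_half; // raw fp16 bits, as in the commented-out test below

#define QK8_1 32
typedef struct {
    ggml_half d;         // delta (scale)
    ggml_half s;         // d * sum(qs[i]); lets q4_0/q4_1-style kernels fold
                         // their constant offset in without re-summing
    int8_t    qs[QK8_1]; // 32 signed 8-bit quants
} block_q8_1;            // 36 bytes per 32 source floats
```

The workgroup size of 32 * subgroup_size / 8 invocations, with the subgroup size passed as a specialization constant, comes straight from the pipeline creation line above.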
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -2452,6 +2516,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool pipeline_robustness = false;
bool coopmat2_support = false;
device->coopmat_support = false;
device->integer_dot_product = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2477,6 +2542,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
device->integer_dot_product = true;
#endif
}
}
@@ -2490,6 +2560,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceVulkan11Properties vk11_props;
vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
@@ -2524,6 +2595,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
#endif
if (device->integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
}
device->physical_device.getProperties2(&props2);
device->properties = props2.properties;
device->vendor_id = device->properties.vendorID;
@@ -2570,6 +2646,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_support = false;
}
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
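Why this line checks a specific property rather than just extension presence: the note below is an inference from the Vulkan spec's distinction between support and acceleration, not something the commit states.

```cpp
// The extension being present only guarantees the dot-product instructions
// are valid; the *Accelerated properties report an optimized (typically
// hardware DP4A) path. MMQ only pays off on the fast path, hence the
// additional property requirement above.
static bool mmq_usable(bool ext_present, bool packed_signed_accelerated) {
    return ext_present && packed_signed_accelerated;
}
```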
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
// Try to find a non-graphics compute queue and transfer-focused queues
@@ -2662,6 +2740,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_maintenance4");
}
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
if (device->integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2831,6 +2917,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->coopmat_acc_f16_support = true;
}
}
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
device->coopmat_int_m == 0
) {
device->coopmat_int_support = true;
device->coopmat_int_m = prop.MSize;
device->coopmat_int_n = prop.NSize;
device->coopmat_int_k = prop.KSize;
}
}
@@ -2935,25 +3032,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
vk::PhysicalDevice physical_device = devices[dev_num];
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceDriverProperties driver_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
physical_device.getProperties2(&props2);
vk_device_architecture arch = get_device_architecture(physical_device);
uint32_t default_subgroup_size = get_subgroup_size("", arch);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
bool fp16_storage = false;
bool fp16_compute = false;
bool coopmat_support = false;
bool coopmat2_support = false;
bool integer_dot_product = false;
for (auto properties : ext_props) {
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2969,27 +3052,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
integer_dot_product = true;
#endif
}
}
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
coopmat_support = false;
}
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceMaintenance3Properties props3;
vk::PhysicalDeviceSubgroupProperties subgroup_props;
vk::PhysicalDeviceDriverProperties driver_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
// Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
if (integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
}
physical_device.getProperties2(&props2);
VkPhysicalDeviceFeatures2 device_features2;
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
device_features2.pNext = nullptr;
device_features2.features = (VkPhysicalDeviceFeatures)device_features;
VkPhysicalDeviceVulkan11Features vk11_features;
vk11_features.pNext = nullptr;
@@ -3002,7 +3102,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
vk11_features.pNext = &vk12_features;
// Pointer to the last chain element
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
last_struct = (VkBaseOutStructure *)&vk12_features;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -3014,20 +3114,37 @@ static void ggml_vk_print_gpu_info(size_t idx) {
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
last_struct = (VkBaseOutStructure *)&coopmat_features;
}
#endif
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
if (integer_dot_product) {
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
}
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
fp16 = fp16 && vk12_features.shaderFloat16;
coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
#endif
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
integer_dot_product = integer_dot_product
&& shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
&& shader_integer_dot_product_features.shaderIntegerDotProduct;
coopmat_support = coopmat_support
&& coopmat_features.cooperativeMatrix
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3293,6 +3410,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
}
}
// MMQ
if (src1_type == GGML_TYPE_Q8_1) {
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
return nullptr;
}
return pipelines;
}
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
return nullptr;
}
@@ -3585,8 +3713,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
return s;
}
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -4016,8 +4142,8 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
return split_k;
}
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
if (ctx->device->coopmat2) {
// Use large shader when the N dimension is greater than the medium shader's tile size
@@ -4042,9 +4168,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
}
static void ggml_vk_matmul(
@@ -4054,7 +4180,7 @@ static void ggml_vk_matmul(
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
uint32_t padded_n) {
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
ggml_vk_sync_buffers(subctx);
if (split_k == 1) {
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4072,7 +4198,7 @@ static void ggml_vk_matmul(
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
if (ctx->device->coopmat2) {
@@ -4214,6 +4340,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
}
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
switch(type) {
case GGML_TYPE_Q8_1:
return ctx->device->pipeline_quantize_q8_1;
default:
std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
GGML_ABORT("fatal error");
}
}
static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
}
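How the { ne, 1, 1 } element count above becomes a dispatch size: ggml_vk_dispatch_pipeline divides each dimension by the pipeline's wg_denoms (the CEIL_DIV lines in the -3585 hunk above). A worked sketch for the quantize dispatch, assuming a subgroup size of 32:

```cpp
#include <cstdint>

// ggml-vulkan's CEIL_DIV macro, written as a function for readability.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Workgroup count for the quantize dispatch, assuming subgroup_size = 32:
// the pipeline's wg_denoms[0] is 32 * subgroup_size / 8 = 128.
static uint32_t quantize_workgroups(uint32_t ne) {
    const uint32_t wg_denom = 32 * 32 / 8;  // 128
    return ceil_div(ne, wg_denom);          // e.g. ne = 4096*512 -> 16384 groups
}
```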
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4265,10 +4410,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
// Check for mmq first
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
if (mmp == nullptr) {
// Fall back to f16 dequant mul mat
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
quantize_y = false;
}
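Spelling out the quantize_y gate above; the per-condition rationales here are inferred from the shader's requirements, not stated in the commit:

```cpp
#include <cstdint>

// Inferred reading of the quantize_y gate (editor's inference, hedged):
static bool can_quantize_y(bool device_int_dot, bool src1_is_f32,
                           bool src1_contig, int64_t ne10, int64_t ne11) {
    return device_int_dot          // accelerated 4x8 signed dot is available
        && src1_is_f32             // only f32 activations are quantized on the fly
        && src1_contig             // the quantize shader reads its input linearly
        && (ne11 * ne10) % 4 == 0; // int8 values travel four per 32-bit word
}
```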
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
if (qx_needs_dequant) {
// Fall back to dequant + f16 mulmat
@@ -4278,13 +4432,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
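padded_n rounds the N dimension up to the pipeline's tile granularity so the shader can skip bounds checks, per the comment above. A sketch of the round-up, matching how ROUNDUP_POW2 behaves for power-of-two tile sizes:

```cpp
#include <cstdint>

// Round m up to the next multiple of n, where n is a power of two.
// Example: ne11 = 49, wg_denoms[1] = 64 -> padded_n = 64.
static uint32_t roundup_pow2(uint32_t m, uint32_t n) {
    return (m + n - 1) & ~(n - 1);
}
```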
const int x_ne = ne01 * ne00;
const int y_ne = padded_n * ne10;
const int d_ne = ne11 * ne01;
@@ -4294,11 +4448,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
const uint64_t d_sz = sizeof(float) * d_ne;
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
vk_pipeline to_q8_1 = nullptr;
if (x_non_contig) {
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
@@ -4313,6 +4468,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
if (dryrun) {
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4326,7 +4485,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
ctx->prealloc_size_x = x_sz_upd;
}
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
ctx->prealloc_size_y = y_sz_upd;
}
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4341,6 +4500,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
}
if (split_k > 1) {
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
}
@@ -4376,6 +4538,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (qy_needs_dequant) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
} else if (quantize_y) {
d_Y = ctx->prealloc_y;
GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
} else {
d_Y = d_Qy;
y_buf_offset = qy_buf_offset;
@@ -4392,6 +4557,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (y_non_contig) {
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
}
if (quantize_y) {
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
}
uint32_t stride_batch_x = ne00*ne01;
uint32_t stride_batch_y = ne10*ne11;
@@ -4400,7 +4568,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
}
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
}
@@ -6929,6 +7097,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
}
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device);
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7177,6 +7349,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device);
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +7412,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
free(x_chk);
}
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
// This does not work without ggml q8_1 quantization support
//
// typedef uint16_t ggml_half;
// typedef uint32_t ggml_half2;
//
// #define QK8_1 32
// typedef struct {
// union {
// struct {
// ggml_half d; // delta
// ggml_half s; // d * sum(qs[i])
// } GGML_COMMON_AGGR_S;
// ggml_half2 ds;
// } GGML_COMMON_AGGR_U;
// int8_t qs[QK8_1]; // quants
// } block_q8_1;
//
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
//
// const size_t x_sz = sizeof(float) * ne;
// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
// float * x = (float *) malloc(x_sz);
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
//
// for (size_t i = 0; i < ne; i++) {
// x[i] = rand() / (float)RAND_MAX;
// }
//
// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
//
// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
//
// if (ctx->device->need_compiles) {
// ggml_vk_load_shaders(ctx->device);
// }
//
// ggml_pipeline_allocate_descriptor_sets(ctx->device);
//
// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
//
// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
// ggml_vk_ctx_begin(ctx->device, subctx);
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
// ggml_vk_ctx_end(subctx);
//
// auto begin = std::chrono::high_resolution_clock::now();
//
// ggml_vk_submit(subctx, ctx->fence);
// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
// ctx->device->device.resetFences({ ctx->fence });
//
// auto end = std::chrono::high_resolution_clock::now();
//
// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
//
// ggml_vk_quantize_data(x, qx_res, ne, quant);
//
// int first_err = -1;
//
// for (size_t i = 0; i < ne / 32; i++) {
// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
//
// if (first_err < 0 && error > 0.1) {
// first_err = i;
// }
//
// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
//
// if (first_err < 0 && error > 0.1) {
// first_err = i;
// }
//
// for (size_t j = 0; j < 32; j++) {
// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
//
// if (first_err < 0 && error > 1) {
// first_err = i;
// }
// }
// }
//
// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
//
// if (first_err != -1) {
// std::cerr << "first_error = " << first_err << std::endl;
// std::cerr << "Actual result: " << std::endl << std::endl;
// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
// for (size_t j = 0; j < 32; j++) {
// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
// }
// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
// for (size_t j = 0; j < 32; j++) {
// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
// }
// std::cerr << std::endl;
// }
//
// ggml_vk_destroy_buffer(x_buf);
// ggml_vk_destroy_buffer(qx_buf);
//
// free(x);
// free(qx);
// free(qx_res);
// }
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
const size_t x_ne = m * k * batch;
const size_t y_ne = k * n * batch;
const size_t d_ne = m * n * batch;
vk_matmul_pipeline2 * pipelines;
if (mmq) {
pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
} else {
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
}
const bool fp16acc = ctx->device->fp16;
vk_pipeline p;
std::string shname;
if (shader_size == 0) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
} else if (shader_size == 1) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
} else if (shader_size == 2) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
} else {
GGML_ASSERT(0);
}
const size_t kpad = ggml_vk_align_size(k, p->align);
const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
if (k != kpad) {
if (mmq || k != kpad) {
if (shader_size == 0) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
shname = std::string(ggml_type_name(quant)) + "_S";
} else if (shader_size == 1) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
shname = std::string(ggml_type_name(quant)) + "_M";
} else if (shader_size == 2) {
p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
shname = std::string(ggml_type_name(quant)) + "_L";
} else {
GGML_ASSERT(0);
}
}
if (p == nullptr) {
std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
return;
}
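A note on the mmq || k != kpad rework above: the CREATE_MMQ macros earlier in this commit create only the plain ->s/->m/->l pipelines, never the aligned ->a_* variants, so the MMQ test must always land on the unaligned branch (and kpad is pinned to 0 accordingly). A self-contained restatement of the guard:

```cpp
#include <cstddef>

// MMQ pipelines have no aligned variants, so the condition is always true
// when mmq is set; otherwise the usual k-padding check applies.
static bool use_unaligned_pipeline(bool mmq, size_t k, size_t kpad) {
    return mmq || k != kpad;
}
```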
const size_t x_sz = sizeof(float) * x_ne;
const size_t y_sz = sizeof(float) * y_ne;
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
const size_t d_sz = sizeof(float) * d_ne;
float * x = (float *) malloc(x_sz);
float * y = (float *) malloc(y_sz);
void * qx = malloc(qx_sz);
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
float * d = (float *) malloc(d_sz);
float * d_chk = (float *) malloc(d_sz);
for (size_t i = 0; i < x_ne; i++) {
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
// x[i] = i % k;
}
ggml_vk_quantize_data(x, qx, x_ne, quant);
for (size_t i = 0; i < y_ne; i++) {
// y[i] = rand() / (float)RAND_MAX;
y[i] = (i % k == i / k) ? 1.0f : 0.0f;
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
// y[i] = i % k;
}
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@@ -7310,6 +7618,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
}
}
if (mmq) {
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
}
ggml_pipeline_allocate_descriptor_sets(ctx->device);
@@ -7318,13 +7633,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_matmul(
ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
);
if (mmq) {
for (size_t i = 0; i < num_it; i++) {
ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
ggml_vk_matmul(
ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
);
}
} else {
for (size_t i = 0; i < num_it; i++) {
ggml_vk_matmul(
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
);
}
}
ggml_vk_ctx_end(subctx);
@@ -7382,7 +7709,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
std::cerr << "TEST dequant matmul " << shname;
if (mmq) {
std::cerr << " mmq";
}
std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
if (avg_err > 0.01 || std::isnan(avg_err)) {
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7392,6 +7723,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
std::cerr << "Expected result: " << std::endl << std::endl;
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
std::cerr << "src0: " << std::endl << std::endl;
ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
std::cerr << std::endl;
std::cerr << "src1: " << std::endl << std::endl;
ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
if (split_k > 1) {
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7414,6 +7751,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
ggml_vk_destroy_buffer(qx_buf);
ggml_vk_destroy_buffer(y_buf);
ggml_vk_destroy_buffer(qy_buf);
ggml_vk_destroy_buffer(d_buf);
free(x);
@@ -7446,7 +7784,25 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
128, 49, 49,
4096, 49, 4096,
};
const size_t num_it = 100;
const size_t num_it = 1;
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
abort();
for (size_t i = 0; i < vals.size(); i += 3) {
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
@@ -9258,7 +9614,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
}
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
const float *params = (const float *)tensor->op_params;
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
} else if (tensor->op == GGML_OP_MUL_MAT) {
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9275,7 +9631,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_UPSCALE) {
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
} else if (tensor->op == GGML_OP_SCALE) {
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
} else if (tensor->op == GGML_OP_SQR) {
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_SIN) {
@@ -9283,7 +9640,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_COS) {
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
} else if (tensor->op == GGML_OP_PAD) {
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
} else if (tensor->op == GGML_OP_REPEAT) {
@@ -9297,7 +9655,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
const float * float_params = (const float *)tensor->op_params;
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
} else if (tensor->op == GGML_OP_RMS_NORM) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9310,14 +9669,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
} else if (tensor->op == GGML_OP_SOFT_MAX) {
if (src1 != nullptr) {
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
} else {
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
}
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
const int n_dims = ((int32_t *) tensor->op_params)[1];
const int mode = ((int32_t *) tensor->op_params)[2];