metal : handle some edge cases when threadgroup size is not a power of 2

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-26 10:20:45 +03:00
parent 97819a0ba4
commit 7b7ecc0109

View File

@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00);
ggml_metal_kargs_sum_rows args = {
@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_rms_norm args = {
@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_l2_norm args = {
@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
nth *= 2;
}
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
nth = MIN(nth, ne00/4);
ggml_metal_kargs_norm args = {
@ -4986,6 +4990,8 @@ static bool ggml_metal_encode_node(
nth *= 2;
}
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;