From 7b7ecc0109a31149ef2f18386444c5104870b298 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Jun 2025 10:20:45 +0300 Subject: [PATCH] metal : handle some edge cases when threadgroup size is not a power of 2 ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f1eab136c..248fa378e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00); ggml_metal_kargs_sum_rows args = { @@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_rms_norm args = { @@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_l2_norm args = { @@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_norm args = { @@ -4986,6 +4990,8 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); + // when rows are small, we can batch them together in a single threadgroup int nrptg = 1;