diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f1eab136c..248fa378e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00); ggml_metal_kargs_sum_rows args = { @@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_rms_norm args = { @@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_l2_norm args = { @@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_norm args = { @@ -4986,6 +4990,8 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); + // when rows are small, we can batch them together in a single threadgroup int nrptg = 1;