mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 11:45:21 +00:00
metal : handle some edge cases when threadgroup size is not a power of 2
ggml-ci
This commit is contained in:
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_rms_norm args = {
|
||||
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_norm args = {
|
||||
@@ -4986,6 +4990,8 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
|
Reference in New Issue
Block a user