From 002cb1bb3345967fbe0fa7766c2d94c2da31ef45 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 11 Aug 2025 09:59:26 +0200 Subject: [PATCH] kleidiai: fix unsigned overflow bug (#15150) * kleidiai: fix unsigned overflow bug * address review comments --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 3a513a55d..dff8fa244 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_start = 0; const int64_t n_step = static_cast(kernel->get_n_step()); - const int64_t num_threads = KAI_MIN(n / n_step, nth); + int64_t num_threads = KAI_MIN(n / n_step, nth); + if (num_threads <= 0) { + num_threads = 1; + } if (ith < num_threads) { const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); @@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_ASSERT(kernel); const int ith = params->ith; - const int nth = params->nth; + const int nth_raw = params->nth; + const int nth = nth_raw > 0 ? nth_raw : 1; const size_t k = ne00; const size_t m = ne11; @@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); const size_t n_start = ith * num_n_per_thread; - size_t n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + size_t n_to_process = 0; + if (n_start < n) { + n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } } // Calculate number of columns to be processed per thread @@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); - variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); + if (n_to_process > 0) { + variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + sizeof(float), -FLT_MAX, FLT_MAX); + } return true; }