diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index dd60529fa..04a8d80e1 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
 #endif // AMD_MFMA_AVAILABLE
 
 #if defined(GGML_USE_HIP)
-static int mmq_get_nwarps_host(const int cc) {
-    return amd_mfma_available(cc) ? 8 : 4;
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+    return amd_mfma_available(cc) ? 8 : 256/warp_size;
 }
 #else
-static int mmq_get_nwarps_host(const int /*cc*/) {
-    return 8;
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+    return 256/warp_size;
 }
 #endif // (GGML_USE_HIP)
 
 static constexpr __device__ int mmq_get_nwarps_device() {
-#if defined(GGML_USE_HIP)
 #if defined(AMD_MFMA_AVAILABLE)
     return 8;
 #else
-    return 4;
+    return 256/ggml_cuda_get_physical_warp_size();
 #endif // AMD_MFMA_AVAILABLE
-#else
-    return 8;
-#endif // defined(GGML_USE_HIP)
 }
 
 // ------------------------------------------------------------
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int cc = ggml_cuda_info().devices[id].cc;
     const int nsm = ggml_cuda_info().devices[id].nsm;
     const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const int nwarps = mmq_get_nwarps_host(cc);
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);
 
     const dim3 block_dims(warp_size, nwarps, 1);
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int cc = ggml_cuda_info().devices[id].cc;
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
     const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const int nwarps = mmq_get_nwarps_host(cc);
+    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
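
For reference, the patched rule keeps the MMQ thread block at 256 threads regardless of the physical warp size: nwarps is derived as 256/warp_size, except on the AMD MFMA path, which keeps 8 warps as before. The host-side launch then builds dim3 block_dims(warp_size, nwarps, 1), so warp_size * nwarps stays 256. The following standalone sketch is illustrative only; the helper name and the example warp sizes (32 for NVIDIA-style devices, 64 for AMD wave64 devices) are assumptions for the example, not part of the patch.

#include <cstdio>

// Illustrative stand-in for the patched mmq_get_nwarps_host(); the mfma flag
// and the warp sizes used below are assumptions for this example.
static int example_nwarps(bool mfma, int warp_size) {
    return mfma ? 8 : 256 / warp_size;
}

int main() {
    // warp_size == 32 (NVIDIA-style): 256/32 == 8 warps, 8*32 == 256 threads per block.
    printf("warp 32, no MFMA: nwarps = %d\n", example_nwarps(false, 32));
    // warp_size == 64 (AMD wave64, non-MFMA path): 256/64 == 4 warps, 4*64 == 256 threads.
    printf("warp 64, no MFMA: nwarps = %d\n", example_nwarps(false, 64));
    // AMD MFMA path keeps 8 warps, as in the original code.
    printf("MFMA:             nwarps = %d\n", example_nwarps(true, 64));
    return 0;
}

Under those assumptions the new formula reproduces the previous hard-coded values (8 on CUDA, 4 on HIP without MFMA) while expressing both through a single 256-thread-block invariant.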