mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-14 12:19:48 -04:00
CUDA: fix MMQ nwarps for AMD with warp_size==32 (#15014)
This commit is contained in:
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
|
||||
#if defined(GGML_USE_HIP)
|
||||
static int mmq_get_nwarps_host(const int cc) {
|
||||
return amd_mfma_available(cc) ? 8 : 4;
|
||||
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
|
||||
return amd_mfma_available(cc) ? 8 : 256/warp_size;
|
||||
}
|
||||
#else
|
||||
static int mmq_get_nwarps_host(const int /*cc*/) {
|
||||
return 8;
|
||||
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
|
||||
return 256/warp_size;
|
||||
}
|
||||
#endif // (GGML_USE_HIP)
|
||||
|
||||
static constexpr __device__ int mmq_get_nwarps_device() {
|
||||
#if defined(GGML_USE_HIP)
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
return 8;
|
||||
#else
|
||||
return 4;
|
||||
return 256/ggml_cuda_get_physical_warp_size();
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
#else
|
||||
return 8;
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------
|
||||
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const int nwarps = mmq_get_nwarps_host(cc);
|
||||
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
|
||||
const dim3 block_dims(warp_size, nwarps, 1);
|
||||
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||
const int nwarps = mmq_get_nwarps_host(cc);
|
||||
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
||||
|
||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
|
Reference in New Issue
Block a user