mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-14 12:19:48 -04:00
CUDA: fix MMQ nwarps for AMD with warp_size==32 (#15014)
This commit is contained in:
@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
|
|||||||
#endif // AMD_MFMA_AVAILABLE
|
#endif // AMD_MFMA_AVAILABLE
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP)
// Number of warps per MMQ thread block, host side.
// MFMA-capable AMD devices use a fixed 8 warps; otherwise the target is
// 256 threads per block, so the warp count depends on the device's warp
// size (which on AMD may be 32 or 64 -- hence the warp_size parameter,
// added to fix AMD devices with warp_size==32).
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
    return amd_mfma_available(cc) ? 8 : 256/warp_size;
}
#else
// Non-HIP (CUDA) build: still computed from warp_size for consistency;
// with warp_size==32 this yields 8 warps (256 threads per block).
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
    return 256/warp_size;
}
#endif // (GGML_USE_HIP)
|
||||||
|
|
||||||
// Device-side counterpart of mmq_get_nwarps_host(); the two must agree so
// that the launch configuration chosen on the host matches the kernel's
// compile-time assumptions.
static constexpr __device__ int mmq_get_nwarps_device() {
#if defined(GGML_USE_HIP)
#if defined(AMD_MFMA_AVAILABLE)
    return 8;
#else
    // Target 256 threads per block: nwarps = 256 / physical warp size
    // (handles AMD devices where the warp size is 32 rather than 64).
    return 256/ggml_cuda_get_physical_warp_size();
#endif // AMD_MFMA_AVAILABLE
#else
    return 8;
#endif // defined(GGML_USE_HIP)
}
|
||||||
|
|
||||||
// ------------------------------------------------------------
|
// ------------------------------------------------------------
|
||||||
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
|||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const int nsm = ggml_cuda_info().devices[id].nsm;
|
const int nsm = ggml_cuda_info().devices[id].nsm;
|
||||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||||
const int nwarps = mmq_get_nwarps_host(cc);
|
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
||||||
const int mmq_y = get_mmq_y_host(cc);
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
|
|
||||||
const dim3 block_dims(warp_size, nwarps, 1);
|
const dim3 block_dims(warp_size, nwarps, 1);
|
||||||
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
|||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||||
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
const int warp_size = ggml_cuda_info().devices[id].warp_size;
|
||||||
const int nwarps = mmq_get_nwarps_host(cc);
|
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
||||||
|
|
||||||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||||
const int mmq_y = get_mmq_y_host(cc);
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
|
Reference in New Issue
Block a user