cuda : graceful fallback for Mamba-1 models with weird embd size

This commit is contained in:
Francis Couture-Harpin
2025-07-02 02:56:42 -04:00
parent 73de1fd170
commit 71bef66591
2 changed files with 8 additions and 7 deletions

View File

@ -3329,12 +3329,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
} else {
// Mamba
// (kernel only supports d_state == 16, n_group == 1, d_head == 1)
return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
}
}
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_SSM_CONV: {
// assumes d_inner % threads == 0
return op->src[0]->ne[1] % 128 == 0;
}
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
case GGML_OP_DIAG_MASK_INF:

View File

@ -204,7 +204,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) {
// Mamba2
// Mamba-2
if (d_state == 128) {
GGML_ASSERT(d_state % threads == 0);
// NOTE: can be any power of two between 4 and 64
@ -219,8 +219,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
GGML_ABORT("doesn't support d_state!=128.");
}
} else {
// Mamba1
// todo: consider n_head cannot be divided, does this situation exist?
// Mamba-1
GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1);