cuda : graceful fallback for Mamba-1 models with weird embd size

Author: Francis Couture-Harpin
Date:   2025-07-02 02:56:42 -04:00
Parent: 73de1fd170
Commit: 71bef66591

2 changed files with 8 additions and 7 deletions
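
The "graceful fallback" comes from ggml's backend scheduler: when a backend's supports_op callback rejects an op, the scheduler assigns that op to another backend (typically the CPU) instead of letting the CUDA kernel hit an assertion at run time. A minimal sketch of that mechanism, using the public ggml-backend API (the helper and variable names are illustrative, not from this commit):

    #include "ggml-backend.h"

    // Illustrative helper: route an op to CUDA only when the device accepts it.
    // ggml_backend_dev_supports_op() dispatches to the backend's supports_op
    // callback -- the function patched in ggml-cuda.cu below.
    static ggml_backend_dev_t pick_device(ggml_backend_dev_t cuda_dev,
                                          ggml_backend_dev_t cpu_dev,
                                          const struct ggml_tensor * op) {
        if (ggml_backend_dev_supports_op(cuda_dev, op)) {
            return cuda_dev; // e.g. Mamba-1 with n_head % 128 == 0
        }
        return cpu_dev;      // "weird" embd size: fall back instead of aborting
    }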

--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3329,12 +3329,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
             } else {
                 // Mamba
-                // (kernel only supports d_state == 16, n_group == 1, d_head == 1)
-                return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
+                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
             }
         }
-        case GGML_OP_SSM_CONV:
-            return true;
+        case GGML_OP_SSM_CONV: {
+            // assumes d_inner % threads == 0
+            return op->src[0]->ne[1] % 128 == 0;
+        }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
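
For readers decoding the ne[] indices: they follow how ggml_ssm_scan and ggml_ssm_conv lay out their source tensors. A sketch of the mapping the new conditions rely on (the comments are interpretation, not part of the patch):

    // GGML_OP_SSM_SCAN: src[0] is the SSM state, src[4] is B.
    const int64_t d_state  = op->src[0]->ne[0]; // must be 16 for the Mamba-1 kernel
    const int64_t head_dim = op->src[0]->ne[1]; // must be 1 (d_head == 1)
    const int64_t n_head   = op->src[0]->ne[2]; // must be a multiple of 128 (the block size)
    const int64_t n_group  = op->src[4]->ne[1]; // must be 1

    // GGML_OP_SSM_CONV: src[0] is the conv input, whose second dimension is d_inner.
    const int64_t d_inner  = op->src[0]->ne[1]; // must likewise be a multiple of 128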

--- a/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ggml/src/ggml-cuda/ssm-scan.cu

@@ -204,7 +204,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
     const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
-        // Mamba2
+        // Mamba-2
         if (d_state == 128) {
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
@@ -219,8 +219,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=128.");
         }
     } else {
-        // Mamba1
-        // todo: consider n_head cannot be divided, does this situation exist?
+        // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(n_group == 1);
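
The 128 in both files is the kernel's CUDA block size (const int threads = 128), so with the todo removed, the Mamba-1 launcher simply asserts what supports_op now guarantees: a graph whose n_head is not a multiple of 128 never reaches this code path on CUDA. A condensed sketch of the resulting invariant pair (restating the two files above, not new code):

    // ggml-cuda.cu (scheduling time): reject rather than abort.
    //     Mamba-1 scan supported <=> d_state == 16 && head_dim == 1
    //                                && n_head % 128 == 0 && n_group == 1
    // ssm-scan.cu (launch time): assert the same conditions.
    const int threads = 128;
    GGML_ASSERT(n_head % threads == 0); // guaranteed by supports_op above
    GGML_ASSERT(head_dim == 1);
    GGML_ASSERT(n_group == 1);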