mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-06 18:53:51 +00:00
cuda : graceful fallback for Mamba-1 models with weird embd size
This commit is contained in:
@ -3329,12 +3329,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
|
||||
} else {
|
||||
// Mamba
|
||||
// (kernel only supports d_state == 16, n_group == 1, d_head == 1)
|
||||
return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
|
||||
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
|
||||
return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
|
||||
}
|
||||
}
|
||||
case GGML_OP_SSM_CONV:
|
||||
return true;
|
||||
case GGML_OP_SSM_CONV: {
|
||||
// assumes d_inner % threads == 0
|
||||
return op->src[0]->ne[1] % 128 == 0;
|
||||
}
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
|
@ -204,7 +204,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
const int threads = 128;
|
||||
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
|
||||
if (src3_nb1 == sizeof(float)) {
|
||||
// Mamba2
|
||||
// Mamba-2
|
||||
if (d_state == 128) {
|
||||
GGML_ASSERT(d_state % threads == 0);
|
||||
// NOTE: can be any power of two between 4 and 64
|
||||
@ -219,8 +219,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
|
||||
GGML_ABORT("doesn't support d_state!=128.");
|
||||
}
|
||||
} else {
|
||||
// Mamba1
|
||||
// todo: consider n_head cannot be divided, does this situation exist?
|
||||
// Mamba-1
|
||||
GGML_ASSERT(n_head % threads == 0);
|
||||
GGML_ASSERT(head_dim == 1);
|
||||
GGML_ASSERT(n_group == 1);
|
||||
|
Reference in New Issue
Block a user