diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b6adbd922..d13f5212b 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3329,12 +3329,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
             } else {
                 // Mamba
-                // (kernel only supports d_state == 16, n_group == 1, d_head == 1)
-                return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
+                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
             }
         }
-        case GGML_OP_SSM_CONV:
-            return true;
+        case GGML_OP_SSM_CONV: {
+            // assumes d_inner % threads == 0
+            return op->src[0]->ne[1] % 128 == 0;
+        }
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
         case GGML_OP_DIAG_MASK_INF:
diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu
index 61f35f859..dc3b1a9a8 100644
--- a/ggml/src/ggml-cuda/ssm-scan.cu
+++ b/ggml/src/ggml-cuda/ssm-scan.cu
@@ -204,7 +204,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
     const int threads = 128;
     // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
     if (src3_nb1 == sizeof(float)) {
-        // Mamba2
+        // Mamba-2
         if (d_state == 128) {
             GGML_ASSERT(d_state % threads == 0);
             // NOTE: can be any power of two between 4 and 64
@@ -219,8 +219,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
             GGML_ABORT("doesn't support d_state!=128.");
         }
     } else {
-        // Mamba1
-        // todo: consider n_head cannot be divided, does this situation exist?
+        // Mamba-1
         GGML_ASSERT(n_head % threads == 0);
         GGML_ASSERT(head_dim == 1);
         GGML_ASSERT(n_group == 1);
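
For reference, a minimal sketch of how the new supports_op condition maps onto the SSM_SCAN operand dimensions. The helper name below is hypothetical; the ne[*]-to-dimension mapping follows the patch's own comment ("d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1") and the asserts in ssm_scan_f32_cuda:

// Sketch only, not part of the patch: restates the Mamba-1 SSM_SCAN check from
// ggml_backend_cuda_device_supports_op with dimensions named as in ssm-scan.cu.
// The helper name is hypothetical.
static bool cuda_supports_mamba1_ssm_scan(const ggml_tensor * op) {
    const int64_t d_state  = op->src[0]->ne[0]; // per-head state size
    const int64_t head_dim = op->src[0]->ne[1]; // d_head
    const int64_t n_head   = op->src[0]->ne[2];
    const int64_t n_group  = op->src[4]->ne[1];
    const int     threads  = 128;               // block size used by ssm_scan_f32_cuda
    // one thread per head, so n_head must be a multiple of the block size;
    // this mirrors GGML_ASSERT(n_head % threads == 0) in the kernel launcher
    return d_state == 16 && head_dim == 1 && n_head % threads == 0 && n_group == 1;
}

The tightened SSM_CONV check follows the same pattern: src0->ne[1] is d_inner, and per the added comment the conv kernel assumes d_inner % threads == 0 with the same 128-thread blocks, so inputs whose d_inner is not a multiple of 128 now fall back instead of being claimed by the CUDA backend.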