From a57d1bcb3c0165ac87b1f0dbb429839b0da69689 Mon Sep 17 00:00:00 2001 From: compilade Date: Wed, 9 Jul 2025 23:54:38 -0400 Subject: [PATCH] cuda : support Falcon-H1 state size for SSM_SCAN (#14602) --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/ssm-scan.cu | 15 +++++++++++++-- tests/test-backend-ops.cpp | 1 + 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index da1e8f8f4..72406f0af 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 - // (kernel only supports d_state == 128 && d_head % 16 == 0) - return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0; + // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0) + return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0; } else { // Mamba // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index dc3b1a9a8..c9184398b 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { + const int threads = 128; GGML_ASSERT(d_state % threads == 0); // NOTE: can be any power of two between 4 and 64 const int splitH = 16; @@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else if (d_state == 256) { // Falcon-H1 + const int threads = 256; + // NOTE: can be any power of two between 8 and 64 + const int splitH = 16; + GGML_ASSERT(head_dim % splitH == 0); + const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); + ssm_scan_f32_group<16, 256><<>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else { - GGML_ABORT("doesn't support d_state!=128."); + GGML_ABORT("doesn't support d_state!=(128 or 256)."); } } else { + const int threads = 128; // Mamba-1 GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1d837b432..4eeeb6e43 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5069,6 +5069,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1 test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));