diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 2a6be2585..11d4819c8 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32(
                     const float x_dt = x[ii] * dt_soft_plus;
                     float sumf = 0.0f;
 #if defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+                    const int ggml_f32_epr = svcntw();
+                    const int ggml_f32_step = 1 * ggml_f32_epr;
+
+                    const int np = (nc & ~(ggml_f32_step - 1));
+
+                    GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
+
+                    GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
+                    GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
+
+                    for (int i = 0; i < np; i += ggml_f32_step) {
+                        // TODO: maybe unroll more?
+                        for (int j = 0; j < 1; j++) {
+                            GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
+                            GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+                            GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc);
+
+                            t0 = GGML_F32_VEC_MUL(t0, adA);
+                            t1 = GGML_F32_VEC_MUL(t1, axdt);
+
+                            t0 = GGML_F32_VEC_ADD(t0, t1);
+
+                            sum = GGML_F32_VEC_FMA(sum, t0, t2);
+
+                            GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
+                        }
+                    }
+
+                    sumf = GGML_F32xt_REDUCE_ONE(sum);
+    #else
                     const int np = (nc & ~(GGML_F32_STEP - 1));
 
                     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32(
 
                     // reduce sum0..sum3 to sum0
                     GGML_F32_VEC_REDUCE(sumf, sum);
+    #endif
 #else
                     const int np = 0;
 #endif
@@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             for (int i1 = 0; i1 < nr; ++i1) {
                 const int ii = i1 + h*nr;
                 const float x_dt = x[ii] * dt_soft_plus;
-#ifdef __ARM_FEATURE_SVE
+#if defined(__ARM_FEATURE_SVE)
                 svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
                 svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
                 svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
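
Note on the new SVE path: `svcntw()` returns the number of 32-bit lanes per SVE vector at runtime, so `np = nc & ~(ggml_f32_step - 1)` rounds the column count down to a whole number of vectors (the scalar tail loop after the SIMD block covers `np..nc`), and `(h & (ng - 1))` maps head `h` to its B/C group assuming `ng` is a power of two. As a minimal sketch of what the vector loop computes per row (the helper name `ssm_scan_row_ref` and its standalone signature are illustrative, not part of ggml):

// Illustrative scalar reference for one row of the SSM scan step above (sketch, not ggml code).
// Mirrors the vector loop: t0 = s0*adA + B*axdt; sum += t0*C; store t0 into s.
static float ssm_scan_row_ref(float * s, const float * s0,
                              const float * B, const float * C,
                              float dA, float x_dt, int nc) {
    float sumf = 0.0f;
    for (int i = 0; i < nc; ++i) {
        const float t0 = s0[i]*dA + B[i]*x_dt; // decay previous state, mix in new input
        sumf += t0 * C[i];                     // sum = GGML_F32_VEC_FMA(sum, t0, t2)
        s[i] = t0;                             // persist the updated state
    }
    return sumf;
}

`GGML_F32xt_REDUCE_ONE(sum)` then collapses the vector accumulator into the single `sumf` that this scalar reference returns directly.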