From 1b8fb8152d0f3357a9c1d9b4c02f6cc5b9cb7232 Mon Sep 17 00:00:00 2001 From: Vineel Abhinav <131174187+vineelabhinav@users.noreply.github.com> Date: Thu, 29 May 2025 11:31:33 +0530 Subject: [PATCH] ggml: aarch64: Implement SVE F32 kernels for vector functions (#13843) * F32-Mamba-SVE * F32-Mamba-SVE * Resolve test errors-1 * Resolve test errors-2 * F32-vec-SVE * F32-vec-SVE * F32-vec-SVE --- ggml/src/ggml-cpu/ops.cpp | 219 ++++++++++++++++++---------- ggml/src/ggml-cpu/simd-mappings.h | 118 ++++++++++++++- ggml/src/ggml-cpu/vec.cpp | 101 ++++++++++--- ggml/src/ggml-cpu/vec.h | 231 ++++++++++++++++++++++-------- 4 files changed, 522 insertions(+), 147 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 26501b711..f4692f3cd 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7641,8 +7641,8 @@ static void ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} // use the output as the source for the next token-wise iterations if (i2 > 0) { s0 = s; } @@ -8070,6 +8070,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define WKV_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -8080,8 +8088,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define WKV_VECTOR_SIZE 4 #endif + int wkv_vector_size; #ifdef WKV_VECTOR_SIZE - const int64_t vec_count = head_size / WKV_VECTOR_SIZE; + #if defined(__ARM_FEATURE_SVE) + wkv_vector_size = svcntw(); + #else + wkv_vector_size = WKV_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / wkv_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -8111,7 +8125,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * WKV_VECTOR_SIZE; + size_t base_j = j * wkv_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -8136,7 +8150,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( } // Handle remaining elements, this will not be used. - for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8272,6 +8286,14 @@ static void ggml_compute_forward_gla_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define GLA_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define GLA_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -8282,8 +8304,14 @@ static void ggml_compute_forward_gla_f32( #define GLA_VECTOR_SIZE 4 #endif + int gla_vector_size; #ifdef GLA_VECTOR_SIZE - const int64_t vec_count = head_size / GLA_VECTOR_SIZE; + #if defined(__ARM_FEATURE_SVE) + gla_vector_size = svcntw(); + #else + gla_vector_size = GLA_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / gla_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -8310,7 +8338,7 @@ static void ggml_compute_forward_gla_f32( GGML_F32X g_vec = GGML_F32X_SET1(g_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * GLA_VECTOR_SIZE; + size_t base_j = j * gla_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -8334,7 +8362,7 @@ static void ggml_compute_forward_gla_f32( } // Handle remaining elements, this will not be used. - for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8443,83 +8471,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - for (int64_t t = 0; t < T; t++) { - int64_t t_offset = t * t_stride; - int64_t state_offset = head_size * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; - for (int64_t h = h_start; h < h_end; h++) { - int64_t h_offset = h * h_stride; - int64_t t_h_offset = t_offset + h_offset; - int64_t h_2d_offset = h * h_stride_2d; + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; - for (int64_t ii = 0; ii < head_size; ii++) { - int64_t t_h_i_offset = t_h_offset + ii; - int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + for (int64_t i = 0; i < head_size; i++) { + int64_t t_h_i_offset = t_h_offset + i; + int64_t h_2d_i_offset = h_2d_offset + i * h_stride; - GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + float v_val = v[t_h_i_offset]; - float sa = 0; - { - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); - ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); - sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); - } + float sa = 0, result = 0; + for (int64_t j = 0; j < head_size; j++) { + sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j]; } - GGML_F32_VEC_REDUCE(sa, sum); - } - GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + for (int64_t j = 0; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; - int64_t j = 0; - GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - for (; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; - int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; - - GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); - GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); - GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); - GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); - - k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); - - GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); - // kv + s * decay + sa * b - state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); - state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); - GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); - - result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + result += state_cur[h_2d_i_j_offset] * r_val; } - } - GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); - - // There shouldn't be left-overs though. - for (; j < head_size; j++) { - int64_t t_h_j_offset = t_h_offset + j; - int64_t h_2d_i_j_offset = h_2d_i_offset + j; - - float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; - float k_val = k[t_h_j_offset]; - float b_val = b[t_h_j_offset]; - float kv_val = v[t_h_i_offset] * k_val; - - float prev_state_val = state_prev[h_2d_i_j_offset]; - state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; - dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + dst_data[t_h_i_offset] = result; } } } - } + #else + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; + + for (int64_t ii = 0; ii < head_size; ii++) { + int64_t t_h_i_offset = t_h_offset + ii; + int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + + GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + + float sa = 0; + { + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); + ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); + sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); + } + } + GGML_F32_VEC_REDUCE(sa, sum); + } + + GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + + int64_t j = 0; + GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + for (; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; + int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; + + GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); + GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); + GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); + GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); + + GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); + // kv + s * decay + sa * b + state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); + state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); + GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); + + result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + } + } + GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); + + // There shouldn't be left-overs though. + for (; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; + + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v[t_h_i_offset] * k_val; + + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + } + } + } + } + #endif #else for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 45c31cf1f..2e3669c01 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -17,7 +17,123 @@ // number of elements to fit in a single register // -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 SVE +#define GGML_F32_EPR 8 +#define DEFAULT_PG svptrue_b32() + +#define GGML_F32xt svfloat32_t +#define GGML_F32xt_ZERO svdup_n_f32(0.0f) +#define GGML_F32xt_SET1(x) svdup_n_f32(x) +#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) +#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) +#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) +#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) +#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) +#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ +{ \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ + sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6); \ + sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ + (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ +} +#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) + +#define GGML_F32_VEC GGML_F32xt +#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO +#define GGML_F32_VEC_SET1 GGML_F32xt_SET1 +#define GGML_F32_VEC_LOAD GGML_F32xt_LOAD +#define GGML_F32_VEC_STORE GGML_F32xt_STORE +#define GGML_F32_VEC_FMA GGML_F32xt_FMA +#define GGML_F32_VEC_ADD GGML_F32xt_ADD +#define GGML_F32_VEC_MUL GGML_F32xt_MUL +#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x)) + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x))) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) #define GGML_SIMD diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 02d406182..f7614568e 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -17,29 +17,98 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G #if defined(GGML_SIMD) float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t sum1 = svdup_n_f32(0.0f); + svfloat32_t sum2 = svdup_n_f32(0.0f); + svfloat32_t sum3 = svdup_n_f32(0.0f); + svfloat32_t sum4 = svdup_n_f32(0.0f); + svfloat32_t sum5 = svdup_n_f32(0.0f); + svfloat32_t sum6 = svdup_n_f32(0.0f); + svfloat32_t sum7 = svdup_n_f32(0.0f); + svfloat32_t sum8 = svdup_n_f32(0.0f); + svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8; + svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8; + for (int i = 0; i < np; i += ggml_f32_step) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + + ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg = svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + sum1 = svmad_f32_m(pg, ax1, ay1, sum1); + } + // reduce sum1,sum2 to sum1 + GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } + #endif #else // scalar ggml_float sumf = 0.0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index c77349ebe..163a27435 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -5,6 +5,7 @@ #include "ggml-impl.h" #include "simd-mappings.h" #include "ggml.h" +#include "ggml-cpu.h" #if defined(GGML_USE_ACCELERATE) #include @@ -148,27 +149,108 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f32_step) { - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + GGML_F32_VEC_STORE(y + i, ay1); + + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); + + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + + GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + + GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + + GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); + + ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + + GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + + GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + + GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } + GGML_F32_VEC_STORE(y + i, ay1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg =svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + ay1 = svmad_f32_m(pg, ax1, vx, ay1); + + svst1_f32(pg, y + np2, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -220,36 +302,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int } #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = GGML_F32_VEC_SET1(v[k][0]); - } - - GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; } - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } - } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); } - } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } + #endif #else // scalar for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { @@ -265,25 +356,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) vDSP_vsmul(y, 1, &v, y, 1, n); #elif defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 2 * ggml_f32_epr; - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ay1; + svfloat32_t ay2; + for (int i = 0; i < np; i += ggml_f32_step) { + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_MUL(ay1, vx); + GGML_F32_VEC_STORE(y + i, ay1); - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_MUL(ay2, vx); + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); } - } + // leftovers + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b32(np, n); + ay1 = svld1_f32(pg, y + np); + ay1 = svmul_f32_m(pg, ay1, vx); + svst1_f32(pg, y + np, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) {