diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index c663b53f9..d3d009cd6 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1101,23 +1101,37 @@ extern "C" {
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
+    // expects gate in second half of row, unless swapped is true
     GGML_API struct ggml_tensor * ggml_glu(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
-            enum ggml_glu_op op);
+            enum ggml_glu_op op,
+            bool swapped);
 
     GGML_API struct ggml_tensor * ggml_reglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_reglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_geglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_geglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_swiglu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_swapped(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 5ce11915d..53ad20a20 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3224,7 +3226,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3265,7 +3270,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_reglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3331,7 +3339,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3372,7 +3383,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_geglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3438,7 +3452,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
@@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == nr);
 
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -3479,7 +3496,8 @@
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_swiglu_f16(nc,
                 (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
+                (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index bfb2d5d36..1956f78e4 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -254,27 +254,27 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
     }
 }
 
-void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
     int i = 0;
 #if defined(__AVX512F__) && defined(__AVX512DQ__)
     for (; i + 15 < n; i += 16) {
-        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
+        _mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
     }
 #elif defined(__AVX2__) && defined(__FMA__)
     for (; i + 7 < n; i += 8) {
-        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
+        _mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
     }
 #elif defined(__SSE2__)
     for (; i + 3 < n; i += 4) {
-        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
+        _mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
     }
 #elif defined(__ARM_NEON) && defined(__aarch64__)
     for (; i + 3 < n; i += 4) {
-        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
+        vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
     }
 #endif
     for (; i < n; ++i) {
-        y[i] = ggml_silu_f32(x[i]) * x[i + n];
+        y[i] = ggml_silu_f32(x[i]) * g[i];
     }
 }
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index 178629e99..f9113a0b1 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -905,57 +905,57 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
     }
 }
 
-inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) {
+inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f;
+        y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
     }
 }
 
-inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f);
+        y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
     }
 }
 
 #ifdef GGML_GELU_FP16
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     uint16_t t;
     for (int i = 0; i < n; ++i) {
         if (x[i] <= -10.0f) {
             y[i] = 0.0f;
         } else if (x[i] >= 10.0f) {
-            y[i] = x[i] * x[i + n];
+            y[i] = x[i] * g[i];
         } else {
             ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
             memcpy(&t, &fp16, sizeof(uint16_t));
-            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n];
+            y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
         }
     }
 }
 #else
-inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
+inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
     for (int i = 0; i < n; ++i) {
-        y[i] = ggml_gelu_f32(x[i]) * x[i + n];
+        y[i] = ggml_gelu_f32(x[i]) * g[i];
     }
 }
 #endif
 
-inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     const uint16_t * i16 = (const uint16_t *) x;
     for (int i = 0; i < n; ++i) {
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g);
+        float v = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
     }
 }
 
-void ggml_vec_swiglu_f32(const int n, float * y, const float * x);
+void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
 
-inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
         float v = GGML_FP16_TO_FP32(x[i]);
-        float g = GGML_FP16_TO_FP32(x[i + n]);
-        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g);
+        float w = GGML_FP16_TO_FP32(g[i]);
+        y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
     }
 }
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 31177a099..caab84d52 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -199,7 +199,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 /* gated ops */
 
 template <float (*op)(float), typename T>
-static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) {
     const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
@@ -208,13 +208,13 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t
     // perform base op on first half of row and multiply with gate in second half
     const int64_t j = (i / n) * o + (i % n);
-    dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
+    dst[i] = (T)(op((float)x[j]) * (float)g[j]);
 }
 
 template <float (*op)(float), typename T>
-static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
     const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
-    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
+    unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o);
 }
 
 template <float (*op)(float)>
@@ -235,10 +235,26 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(dst->ne[0] == nc);
     GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
 
+    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
     if (src0->type == GGML_TYPE_F16) {
-        unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
+        unary_gated_cuda<op>(
+            (const half *)src0_d + (swapped ? nc : 0),
+            (const half *)src0_d + (swapped ? 0 : nc),
+            (half *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(half),
+            stream);
     } else {
-        unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
+        unary_gated_cuda<op>(
+            (const float *)src0_d + (swapped ? nc : 0),
+            (const float *)src0_d + (swapped ? 0 : nc),
+            (float *)dst_d,
+            ggml_nelements(dst),
+            nc,
+            src0->nb[1] / sizeof(float),
+            stream);
     }
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c34f07217..9b30ac4cd 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2643,13 +2643,15 @@ struct ggml_tensor * ggml_exp_inplace(
 struct ggml_tensor * ggml_glu(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
-        enum ggml_glu_op op) {
+        enum ggml_glu_op op,
+        bool swapped) {
     GGML_ASSERT(ggml_is_contiguous_1(a));
 
     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
 
     ggml_set_op_params_i32(result, 0, (int32_t) op);
+    ggml_set_op_params_i32(result, 1, (int32_t) swapped);
 
     result->op = GGML_OP_GLU;
     result->src[0] = a;
@@ -2662,7 +2664,13 @@ struct ggml_tensor * ggml_glu(
 struct ggml_tensor * ggml_reglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
+}
+
+struct ggml_tensor * ggml_reglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
 }
 
 // ggml_geglu
@@ -2670,7 +2678,13 @@ struct ggml_tensor * ggml_reglu(
 struct ggml_tensor * ggml_geglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
+}
+
+struct ggml_tensor * ggml_geglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
 }
 
 // ggml_swiglu
@@ -2678,7 +2692,13 @@ struct ggml_tensor * ggml_geglu(
 struct ggml_tensor * ggml_swiglu(
         struct ggml_context * ctx,
         struct ggml_tensor * a) {
-    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
+}
+
+struct ggml_tensor * ggml_swiglu_swapped(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
 }
 
 // ggml_norm
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index bdfa0d3e5..ef3842388 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1110,16 +1110,18 @@ struct test_glu : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne_a;
     int v; // view (1 : non-contiguous a)
+    bool swapped;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne_a, v);
+        return VARS_TO_STR4(type, ne_a, v, swapped);
     }
 
     test_glu(ggml_glu_op op,
            ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
-           int v = 0)
-        : op(op), type(type), ne_a(ne_a), v(v) {}
+           int v = 0,
+           bool swapped = false)
+        : op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
@@ -1135,7 +1137,7 @@ struct test_glu : public test_case {
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * out = ggml_glu(ctx, a, op);
+        ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
         ggml_set_name(out, "out");
 
         return out;
@@ -4009,8 +4011,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (int v : {0, 1}) {
             for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
-                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
-                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
+                for (bool swapped : {false, true}) {
+                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
+                    test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
+                }
             }
         }
     }
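
Not part of the patch: a minimal usage sketch in C against the public API added above, assuming a valid ggml_context and a packed up/gate activation tensor. The default variants expect the gate in the second half of each row; the *_swapped variants expect it in the first half. The helper name apply_swiglu_packed and its gate_first flag are hypothetical, for illustration only.

#include "ggml.h"

// Apply SwiGLU to a tensor whose rows pack both halves ([up | gate] or [gate | up]).
// Per the header comment above, the result has n / 2 columns for an n-column input.
static struct ggml_tensor * apply_swiglu_packed(
        struct ggml_context * ctx,
        struct ggml_tensor  * up_gate,      // n columns, r rows
        bool                  gate_first) { // true if rows are laid out [gate | up]
    return gate_first
        ? ggml_swiglu_swapped(ctx, up_gate)  // gate taken from the first half of each row
        : ggml_swiglu        (ctx, up_gate); // gate taken from the second half of each row
}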