implement swapped variants (cpu/cuda)

This commit is contained in:
Sigbjørn Skjæret
2025-06-13 22:48:53 +02:00
committed by Akarshan
parent f8705a2399
commit 0b2703fc57
7 changed files with 117 additions and 45 deletions

View File

@ -1101,23 +1101,37 @@ extern "C" {
// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
// expects gate in second half of row, unless swapped is true
GGML_API struct ggml_tensor * ggml_glu(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_glu_op op);
enum ggml_glu_op op,
bool swapped);
GGML_API struct ggml_tensor * ggml_reglu(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_reglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_geglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_swiglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,

View File

@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_reglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_reglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_geglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_geglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_swiglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
// rows per thread
const int dr = (nr + nth - 1)/nth;
@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16(
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_swiglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {

View File

@ -254,27 +254,27 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
}
}
void ggml_vec_swiglu_f32(const int n, float * y, const float * x) {
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
for (; i + 15 < n; i += 16) {
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(x + i + n)));
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
}
#elif defined(__AVX2__) && defined(__FMA__)
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(x + i + n)));
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
}
#elif defined(__SSE2__)
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(x + i + n)));
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(x + i + n)));
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
}
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]) * x[i + n];
y[i] = ggml_silu_f32(x[i]) * g[i];
}
}

View File

@ -905,57 +905,57 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
}
}
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
y[i] = (x[i] > 0.f) ? x[i] * x[i + n] : 0.f;
y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
}
}
inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float v = GGML_FP16_TO_FP32(x[i]);
y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(x[i + n]) : 0.f);
y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
}
}
#ifdef GGML_GELU_FP16
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
uint16_t t;
for (int i = 0; i < n; ++i) {
if (x[i] <= -10.0f) {
y[i] = 0.0f;
} else if (x[i] >= 10.0f) {
y[i] = x[i] * x[i + n];
y[i] = x[i] * g[i];
} else {
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
memcpy(&t, &fp16, sizeof(uint16_t));
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * x[i + n];
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
}
}
}
#else
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x) {
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
for (int i = 0; i < n; ++i) {
y[i] = ggml_gelu_f32(x[i]) * x[i + n];
y[i] = ggml_gelu_f32(x[i]) * g[i];
}
}
#endif
inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
const uint16_t * i16 = (const uint16_t *) x;
for (int i = 0; i < n; ++i) {
float g = GGML_FP16_TO_FP32(x[i + n]);
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * g);
float v = GGML_FP16_TO_FP32(g[i]);
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
}
}
void ggml_vec_swiglu_f32(const int n, float * y, const float * x);
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
float v = GGML_FP16_TO_FP32(x[i]);
float g = GGML_FP16_TO_FP32(x[i + n]);
y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * g);
float w = GGML_FP16_TO_FP32(g[i]);
y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
}
}

View File

@ -199,7 +199,7 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
/* gated ops */
template <float (*op)(float), typename T>
static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o) {
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -208,13 +208,13 @@ static __global__ void unary_gated_op_kernel(const T * x, T * dst, const int64_t
// perform base op on first half of row and multiply with gate in second half
const int64_t j = (i / n) * o + (i % n);
dst[i] = (T)(op((float)x[j]) * (float)x[j + n]);
dst[i] = (T)(op((float)x[j]) * (float)g[j]);
}
template <float (*op)(float), typename T>
static void unary_gated_cuda(const T * x, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, dst, k, n, o);
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o);
}
template <float (*op)(float)>
@ -235,10 +235,26 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
if (src0->type == GGML_TYPE_F16) {
unary_gated_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(half), stream);
unary_gated_cuda<op>(
(const half *)src0_d + (swapped ? nc : 0),
(const half *)src0_d + (swapped ? 0 : nc),
(half *)dst_d,
ggml_nelements(dst),
nc,
src0->nb[1] / sizeof(half),
stream);
} else {
unary_gated_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(dst), nc, src0->nb[1] / sizeof(float), stream);
unary_gated_cuda<op>(
(const float *)src0_d + (swapped ? nc : 0),
(const float *)src0_d + (swapped ? 0 : nc),
(float *)dst_d,
ggml_nelements(dst),
nc,
src0->nb[1] / sizeof(float),
stream);
}
}

View File

@ -2643,13 +2643,15 @@ struct ggml_tensor * ggml_exp_inplace(
struct ggml_tensor * ggml_glu(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_glu_op op) {
enum ggml_glu_op op,
bool swapped) {
GGML_ASSERT(ggml_is_contiguous_1(a));
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
ggml_set_op_params_i32(result, 0, (int32_t) op);
ggml_set_op_params_i32(result, 1, (int32_t) swapped);
result->op = GGML_OP_GLU;
result->src[0] = a;
@ -2662,7 +2664,13 @@ struct ggml_tensor * ggml_glu(
struct ggml_tensor * ggml_reglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
}
struct ggml_tensor * ggml_reglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
}
// ggml_geglu
@ -2670,7 +2678,13 @@ struct ggml_tensor * ggml_reglu(
struct ggml_tensor * ggml_geglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
}
struct ggml_tensor * ggml_geglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
}
// ggml_swiglu
@ -2678,7 +2692,13 @@ struct ggml_tensor * ggml_geglu(
struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
}
struct ggml_tensor * ggml_swiglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
}
// ggml_norm

View File

@ -1110,16 +1110,18 @@ struct test_glu : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
int v; // view (1 : non-contiguous a)
bool swapped;
std::string vars() override {
return VARS_TO_STR3(type, ne_a, v);
return VARS_TO_STR4(type, ne_a, v, swapped);
}
test_glu(ggml_glu_op op,
ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
int v = 0)
: op(op), type(type), ne_a(ne_a), v(v) {}
int v = 0,
bool swapped = false)
: op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
@ -1135,7 +1137,7 @@ struct test_glu : public test_case {
ggml_set_name(a, "a");
}
ggml_tensor * out = ggml_glu(ctx, a, op);
ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
ggml_set_name(out, "out");
return out;
@ -4009,8 +4011,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
for (bool swapped : {false, true}) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
}
}
}