ggml : implement GLU for split up/gate (#14181)

* implement GLU for split up/gate

* add tests for ggml_glu_split

* Vulkan: Implement glu_split logic and shader support

* add split to logging [no ci]

* SYCL: refactor element_size ops and add split up and gate support to gated kernels

* SYCL: switch GEGLU to use tanh approximation

---------

Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Akarshan <akarshan@menlo.ai>
This commit is contained in:
Sigbjørn Skjæret
2025-06-18 16:11:07 +02:00
committed by Akarshan
parent a9aedf46b4
commit 35dacd1a93
14 changed files with 985 additions and 1453 deletions

View File

@ -1132,6 +1132,29 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: n columns, r rows,
// B: n columns, r rows,
GGML_API struct ggml_tensor * ggml_glu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_glu_op op);
GGML_API struct ggml_tensor * ggml_reglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_geglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_swiglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,

View File

@ -3201,14 +3201,24 @@ static void ggml_compute_forward_reglu_f32(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3224,10 +3234,15 @@ static void ggml_compute_forward_reglu_f32(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_reglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3245,14 +3260,24 @@ static void ggml_compute_forward_reglu_f16(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3268,10 +3293,15 @@ static void ggml_compute_forward_reglu_f16(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_reglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3314,14 +3344,24 @@ static void ggml_compute_forward_geglu_f32(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3337,10 +3377,15 @@ static void ggml_compute_forward_geglu_f32(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_geglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3358,14 +3403,24 @@ static void ggml_compute_forward_geglu_f16(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3381,10 +3436,15 @@ static void ggml_compute_forward_geglu_f16(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_geglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3427,14 +3487,24 @@ static void ggml_compute_forward_swiglu_f32(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3450,10 +3520,15 @@ static void ggml_compute_forward_swiglu_f32(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_swiglu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(float *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -3471,14 +3546,24 @@ static void ggml_compute_forward_swiglu_f16(
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src0->ne[0] / 2;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
@ -3494,10 +3579,15 @@ static void ggml_compute_forward_swiglu_f16(
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_swiglu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? nc : 0),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])) + (swapped ? 0 : nc));
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {

View File

@ -199,30 +199,36 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
/* gated ops */
template <float (*op)(float), typename T>
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o) {
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
// perform base op on half of the row and multiply with gate in other half
const int64_t j = (i / n) * o + (i % n);
dst[i] = (T)(op((float)x[j]) * (float)g[j]);
// perform base op and multiply with gate (either offset in same tensor or a separate one)
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
}
template <float (*op)(float), typename T>
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o, cudaStream_t stream) {
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o);
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
}
template <float (*op)(float)>
void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const void * src0_d = src0->data;
const ggml_tensor * src1 = dst->src[1];
void * src0_d = src0->data;
void * src1_d = src1 ? src1->data : src0->data;
const int64_t src0_o = src0->nb[1];
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
void * dst_d = dst->data;
const int64_t nc = src0->ne[0] / 2;
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous_1(src0));
@ -235,26 +241,35 @@ void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
GGML_ASSERT(src1->ne[0] == nc);
GGML_ASSERT(src0->type == src1->type);
}
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
if (src0->type == GGML_TYPE_F16) {
unary_gated_cuda<op>(
(const half *)src0_d + (swapped ? nc : 0),
(const half *)src0_d + (swapped ? 0 : nc),
(half *)dst_d,
ggml_nelements(dst),
nc,
src0->nb[1] / sizeof(half),
stream);
half * src0_p = (half *) src0_d;
half * src1_p = (half *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
} else {
unary_gated_cuda<op>(
(const float *)src0_d + (swapped ? nc : 0),
(const float *)src0_d + (swapped ? 0 : nc),
(float *)dst_d,
ggml_nelements(dst),
nc,
src0->nb[1] / sizeof(float),
stream);
float * src0_p = (float *) src0_d;
float * src1_p = (float *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -3,24 +3,24 @@
#include "common.hpp"
#include "ggml.h"
#include <limits.h>
#include <limits> // For std::numeric_limits
template <typename T>
T neg_infinity() {
return -std::numeric_limits<T>::infinity();
}
template<typename T>
template<typename T_Dst, typename T_Src = T_Dst>
struct typed_data {
const T * src;
T * dst;
const T_Src * src;
T_Dst * dst;
};
template<typename T>
typed_data<T> cast_data(ggml_tensor * dst) {
template<typename T_Dst, typename T_Src = T_Dst>
typed_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {
return {
/* .src = */ static_cast<const T *>(dst->src[0]->data),
/* .dst = */ static_cast<T *>(dst->data)
/* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
/* .dst = */ static_cast<T_Dst *>(dst->data)
};
}
@ -82,4 +82,3 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ELEMENTWISE_HPP

View File

@ -664,6 +664,11 @@ struct vk_op_push_constants {
float param2;
};
struct vk_op_glu_push_constants {
uint32_t ne00;
uint32_t mode; // 0: default, 1: swapped, 2: split
};
struct vk_op_unary_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@ -2756,8 +2761,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
#undef CREATE_UNARY
#define CREATE_GLU(name) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
CREATE_GLU(geglu)
CREATE_GLU(reglu)
@ -6987,7 +6992,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
if (op == GGML_OP_SOFT_MAX) {
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
// Empty src1 is possible in soft_max, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
@ -7579,12 +7584,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const bool swapped = (bool)dst->op_params[1];
const bool split = src1 != nullptr;
GGML_ASSERT(ggml_is_contiguous(src0));
if (!split) {
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
} else {
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
GGML_ASSERT(src0->ne[0] == dst->ne[0]);
GGML_ASSERT(src0->type == src1->type);
}
const uint32_t swapped = (uint32_t)dst->op_params[1];
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
}
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@ -9043,7 +9059,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break;
default:
return false;
@ -10794,7 +10810,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
GGML_ABORT("fatal error");
}
} else if (tensor->op == GGML_OP_GLU) {
if (src_clone[1] == nullptr) {
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
} else {
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
}
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
if (src1 == nullptr) {
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);

View File

@ -1,43 +1,13 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#include "glu_head.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const uint offset = p.KX / 2;
const bool swapped = p.KY > 0;
if (!swapped) {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
const float xi = float(data_a[idx]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset]));
float op(float a, float b) {
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
}
} else {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
const float xi = float(data_a[idx + offset]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx]));
}
}
}
#include "glu_main.comp"

View File

@ -0,0 +1,15 @@
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint ne00;
uint mode;
} p;

View File

@ -0,0 +1,31 @@
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.ne00 + i;
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
}
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.ne00 + i;
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
}
} else {
// Split
for (uint i = col; i < p.ne00; i += BLOCK_SIZE) {
const uint idx = row * p.ne00 + i;
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}
}

View File

@ -1,36 +1,9 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#include "glu_head.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const uint offset = p.KX / 2;
const bool swapped = p.KY > 0;
if (!swapped) {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset]));
float op(float a, float b) {
return max(a, 0.0f) * b;
}
} else {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx]));
}
}
}
#include "glu_main.comp"

View File

@ -1,38 +1,9 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#include "glu_head.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const uint offset = p.KX / 2;
const bool swapped = p.KY > 0;
if (!swapped) {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
const float xi = float(data_a[idx]);
data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset]));
float op(float a, float b) {
return a / (1.0f + exp(-a)) * b;
}
} else {
for (uint i = col; i < offset; i += BLOCK_SIZE) {
const uint idx = row * p.KX + i;
const float xi = float(data_a[idx + offset]);
data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx]));
}
}
}
#include "glu_main.comp"

View File

@ -2640,37 +2640,68 @@ struct ggml_tensor * ggml_exp_inplace(
// ggml_glu
struct ggml_tensor * ggml_glu(
static struct ggml_tensor * ggml_glu_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_glu_op op,
bool swapped) {
GGML_ASSERT(ggml_is_contiguous_1(a));
if (b) {
GGML_ASSERT(ggml_is_contiguous_1(b));
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(a->type == b->type);
}
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
ggml_set_op_params_i32(result, 0, (int32_t) op);
ggml_set_op_params_i32(result, 1, (int32_t) swapped);
result->op = GGML_OP_GLU;
result->src[0] = a;
result->src[1] = b;
return result;
}
struct ggml_tensor * ggml_glu(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_glu_op op,
bool swapped) {
return ggml_glu_impl(ctx, a, NULL, op, swapped);
}
struct ggml_tensor * ggml_glu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
enum ggml_glu_op op) {
return ggml_glu_impl(ctx, a, b, op, false);
}
// ggml_reglu
struct ggml_tensor * ggml_reglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, false);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
}
struct ggml_tensor * ggml_reglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_REGLU, true);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
}
struct ggml_tensor * ggml_reglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
}
// ggml_geglu
@ -2678,13 +2709,20 @@ struct ggml_tensor * ggml_reglu_swapped(
struct ggml_tensor * ggml_geglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, false);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
}
struct ggml_tensor * ggml_geglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU, true);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
}
struct ggml_tensor * ggml_geglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
}
// ggml_swiglu
@ -2692,13 +2730,20 @@ struct ggml_tensor * ggml_geglu_swapped(
struct ggml_tensor * ggml_swiglu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, false);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
}
struct ggml_tensor * ggml_swiglu_swapped(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU, true);
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
}
struct ggml_tensor * ggml_swiglu_split(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
}
// ggml_norm

View File

@ -554,12 +554,20 @@ ggml_tensor * llm_graph_context::build_ffn(
switch (type_op) {
case LLM_FFN_SILU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_swiglu_split(ctx0, cur, tmp);
cb(cur, "ffn_swiglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_silu", il);
} break;
case LLM_FFN_GELU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_geglu_split(ctx0, cur, tmp);
cb(cur, "ffn_geglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_gelu", il);
if (act_scales != NULL) {
@ -568,7 +576,11 @@ ggml_tensor * llm_graph_context::build_ffn(
}
} break;
case LLM_FFN_RELU:
{
if (gate && type_gate == LLM_FFN_PAR) {
cur = ggml_reglu_split(ctx0, cur, tmp);
cb(cur, "ffn_reglu", il);
type_gate = LLM_FFN_SEQ;
} else {
cur = ggml_relu(ctx0, cur);
cb(cur, "ffn_relu", il);
} break;
@ -724,12 +736,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
switch (type_op) {
case LLM_FFN_SILU:
{
if (gate_exps) {
cur = ggml_swiglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_swiglu", il);
} else {
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
if (gate_exps) {
cur = ggml_geglu_split(ctx0, cur, up);
cb(cur, "ffn_moe_geglu", il);
} else {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_moe_gelu", il);
} break;
@ -737,11 +755,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
GGML_ABORT("fatal error");
}
if (gate_exps) {
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate_par", il);
}
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);

View File

@ -1151,6 +1151,60 @@ struct test_glu : public test_case {
}
};
struct test_glu_split : public test_case {
const ggml_glu_op op;
const ggml_type type;
const std::array<int64_t, 4> ne_a;
int v; // view (1 : non-contiguous a)
std::string vars() override {
return VARS_TO_STR3(type, ne_a, v) + ",split";
}
test_glu_split(ggml_glu_op op,
ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
int v = 0)
: op(op), type(type), ne_a(ne_a), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
ggml_tensor * b;
if (v & 1) {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a");
b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
ggml_set_name(a, "view_of_b");
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(b, "b");
}
ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// test extended range of values to check for NaNs in GELU
init_tensor_uniform(t, -150.f, 150.f);
}
}
};
// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
@ -4015,6 +4069,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
}
}
}