SYCL: Implement fused GEGLU, SWIGLU and REGLU kernels for a single up+gate tensor
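
Each GLU op takes a single src0 tensor whose rows pack the up and gate
halves side by side: a row of width 2*nc yields an output row of width nc,
with op_params[1] selecting which half is activated (quick GELU, SiLU or
ReLU) and which half gates it. A minimal host-side sketch of the indexing
the fused kernels implement (illustration only; glu_reference and its
parameters are hypothetical, shown for the SwiGLU case):

    #include <cmath>
    #include <cstdint>

    // src has nrows rows with row stride o (in elements); each row packs nc
    // "up" values next to nc "gate" values. dst receives nrows * nc values.
    static void glu_reference(const float * src, float * dst, int64_t nrows,
                              int64_t nc, int64_t o, bool swapped) {
        const float * x = src + (swapped ? nc : 0);  // half that is activated
        const float * g = src + (swapped ? 0 : nc);  // half that gates
        for (int64_t r = 0; r < nrows; ++r) {
            for (int64_t c = 0; c < nc; ++c) {
                const int64_t j = r * o + c;                   // source index
                const float v = x[j];
                const float silu = v / (1.0f + std::exp(-v));  // SiLU(v)
                dst[r * nc + c] = silu * g[j];                 // dst index i = r * nc + c
            }
        }
    }

The device kernels recover the same j from a flat output index i via
j = (i / n) * o + (i % n).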

This commit is contained in:
Akarshan
2025-06-14 18:34:21 +05:30
parent 34d1aedafb
commit a9aedf46b4
3 changed files with 254 additions and 0 deletions

ggml/src/ggml-sycl/element_wise.cpp

@@ -1,6 +1,9 @@
#include "common.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml.h"
#include "element_wise.hpp"
#include <cstddef>
#include <cstdint>
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
                    const int ne10, const int ne11, const int ne12,
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
    dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
}
// Fused GLU kernels: x points at the half that is activated and g at the
// gating half of a single packed up+gate tensor. k is the number of output
// elements, n the output row width and o the source row stride in elements,
// so output index i maps to source index j = (i / n) * o + (i % n).
template<typename T>
static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
        const int64_t j = ((i / n) * o) + (i % n);
        const T x_val = x[j];
        // quick GELU approximation: x * sigmoid(1.702 * x), with GELU_QUICK_COEF = -1.702f
        const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
        dst[i] = gelu_val * g[j];
    }
}

template<typename T>
static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
        const int64_t j = ((i / n) * o) + (i % n);
        // ReLU(x) * g
        dst[i] = sycl::max(x[j], static_cast<T>(0)) * g[j];
    }
}

template<typename T>
static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
    for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
        const int64_t j = ((i / n) * o) + (i % n);
        // SiLU(x) * g = x * sigmoid(x) * g
        dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
    }
}
static void acc_f32_sycl(const float *x, const float *y, float *dst,
                         const int n_elements, const int ne10, const int ne11,
                         const int ne12, const int nb1, const int nb2,
@@ -589,6 +620,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
        [=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
}
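// Launchers for the fused GLU kernels: one work-item per output element,
// rounded up to whole work-groups of the respective op's block size.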
template<typename T>
static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
    const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
    main_stream->parallel_for(
        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
            gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
        });
}

template<typename T>
static void reglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
    const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
    main_stream->parallel_for(
        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
            gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
        });
}

template<typename T>
static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
    const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
    main_stream->parallel_for(
        sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
            gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
        });
}
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1384,6 +1442,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
}
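// The three GLU ops below share one layout: src0 packs the up and gate halves
// into each row of width 2*nc, op_params[1] says whether the halves are
// swapped, and nb[1]/sizeof(T) gives the source row stride in elements.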
inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int64_t nc = dst->src[0]->ne[0] / 2;
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
    GGML_ASSERT(ggml_is_contiguous(dst));
    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
    const void * src0_d = dst->src[0]->data;
    void * dst_d = dst->data;

    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
                           (sycl::half *) dst_d,
                           ggml_nelements(dst),
                           nc,
                           dst->src[0]->nb[1] / sizeof(sycl::half),
                           main_stream);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
                           (const float *) src0_d + (swapped ? 0 : nc),
                           (float *) dst_d,
                           ggml_nelements(dst),
                           nc,
                           dst->src[0]->nb[1] / sizeof(float),
                           main_stream);
                break;
            }
        default:
            GGML_ABORT("GGML tensor type not supported!\n");
    }
}
inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int64_t nc = dst->src[0]->ne[0] / 2;
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
    GGML_ASSERT(ggml_is_contiguous(dst));
    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
    const void * src0_d = dst->src[0]->data;
    void * dst_d = dst->data;

    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
                           (const sycl::half *) src0_d + (swapped ? 0 : nc),
                           (sycl::half *) dst_d,
                           ggml_nelements(dst),
                           nc,
                           dst->src[0]->nb[1] / sizeof(sycl::half),
                           main_stream);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
                           (const float *) src0_d + (swapped ? 0 : nc),
                           (float *) dst_d,
                           ggml_nelements(dst),
                           nc,
                           dst->src[0]->nb[1] / sizeof(float),
                           main_stream);
                break;
            }
        default:
            GGML_ABORT("GGML tensor type not supported!\n");
    }
}
inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int64_t nc = dst->src[0]->ne[0] / 2;
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
    GGML_ASSERT(ggml_is_contiguous(dst));
    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
    const void * src0_d = dst->src[0]->data;
    void * dst_d = dst->data;

    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
                            (const sycl::half *) src0_d + (swapped ? 0 : nc),
                            (sycl::half *) dst_d,
                            ggml_nelements(dst),
                            nc,
                            dst->src[0]->nb[1] / sizeof(sycl::half),
                            main_stream);
                break;
            }
#endif
        case GGML_TYPE_F32:
            {
                swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
                            (const float *) src0_d + (swapped ? 0 : nc),
                            (float *) dst_d,
                            ggml_nelements(dst),
                            nc,
                            dst->src[0]->nb[1] / sizeof(float),
                            main_stream);
                break;
            }
        default:
            GGML_ABORT("GGML tensor type not supported!\n");
    }
}
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
@@ -1509,3 +1713,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_elu(ctx, dst);
}
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_geglu(ctx, dst);
}

void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_reglu(ctx, dst);
}

void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
    ggml_sycl_op_swiglu(ctx, dst);
}

ggml/src/ggml-sycl/element_wise.hpp

@@ -24,6 +24,9 @@ typed_data<T> cast_data(ggml_tensor * dst) {
    };
}
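// Slope of the sigmoid in the quick GELU approximation:
// gelu_quick(x) = x * sigmoid(1.702 * x) = x / (1 + exp(GELU_QUICK_COEF * x)).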
const float GELU_QUICK_COEF = -1.702f;
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
@@ -73,5 +76,10 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

@@ -3678,6 +3678,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
                    return false;
            }
            break;
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(dst)) {
                case GGML_GLU_OP_REGLU:
                    ggml_sycl_reglu(ctx, dst);
                    break;
                case GGML_GLU_OP_GEGLU:
                    ggml_sycl_geglu(ctx, dst);
                    break;
                case GGML_GLU_OP_SWIGLU:
                    ggml_sycl_swiglu(ctx, dst);
                    break;
                default:
                    return false;
            }
            break;
        case GGML_OP_NORM:
            ggml_sycl_norm(ctx, dst);
            break;
@@ -4214,6 +4229,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                default:
                    return false;
            }
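        // The fused GLU kernels take an explicit row stride, so only per-row
        // contiguity of src0 is required here.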
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(op)) {
                case GGML_GLU_OP_REGLU:
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_SWIGLU:
                    return ggml_is_contiguous_1(op->src[0]);
                default:
                    return false;
            }
            break;
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            {