mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-11 11:05:39 -04:00
SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
#include "common.hpp"
|
||||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml.h"
|
||||
#include "element_wise.hpp"
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
static void acc_f32(const float * x, const float * y, float * dst, const int ne,
|
||||
const int ne10, const int ne11, const int ne12,
|
||||
@@ -324,6 +327,34 @@ static void clamp(const T * x, T * dst, const float min, const float max, const
|
||||
dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
|
||||
}
|
||||
|
||||
// Fused GLU kernels
|
||||
template<typename T>
|
||||
static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
|
||||
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
|
||||
const int64_t j = ((i / n) * o) + (i % n);
|
||||
const T x_val = x[j];
|
||||
const T gelu_val = x_val * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x_val)));
|
||||
|
||||
dst[i] = gelu_val * g[j];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
|
||||
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
|
||||
const int64_t j = ((i / n) * o) + (i % n);
|
||||
dst[i] = sycl::max((x[j]), static_cast<T>(0)) * g[j];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, const sycl::nd_item<1> &item_ct1) {
|
||||
for (auto i = item_ct1.get_global_id(0); i < k; i += item_ct1.get_global_range(0)) {
|
||||
const int64_t j = ((i / n) * o) + (i % n);
|
||||
dst[i] = (x[j] / (static_cast<T>(1) + sycl::native::exp(-x[j]))) * g[j];
|
||||
}
|
||||
}
|
||||
|
||||
static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int n_elements, const int ne10, const int ne11,
|
||||
const int ne12, const int nb1, const int nb2,
|
||||
@@ -589,6 +620,33 @@ static void clamp_sycl(const T *x, T *dst, const float min,
|
||||
[=](sycl::nd_item<3> item_ct1) { clamp(x, dst, min, max, k, item_ct1); });
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void geglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
gated_op_fused_geglu(x, g, dst, k, n, o, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void reglu_sycl(const T * x, const T* g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_RELU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
gated_op_fused_reglu(x, g, dst, k, n, o, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void swiglu_sycl(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o, queue_ptr main_stream) {
|
||||
const uint32_t num_blocks = ceil_div(k, SYCL_SILU_BLOCK_SIZE);
|
||||
main_stream->parallel_for(
|
||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||
gated_op_fused_swiglu(x, g, dst, k, n, o, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
@@ -1384,6 +1442,152 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
||||
acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), dst->src[1]->ne[0], dst->src[1]->ne[1], dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const int64_t nc = dst->src[0]->ne[0] / 2;
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||
const void * src0_d = dst->src[0]->data;
|
||||
void * dst_d = dst->data;
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
geglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
|
||||
(const sycl::half *)src0_d + (swapped ? 0 : nc),
|
||||
(sycl::half *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(sycl::half),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
geglu_sycl((const float *) src0_d + (swapped ? nc : 0),
|
||||
(const float *)src0_d + (swapped ? 0 : nc),
|
||||
(float *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(float),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const int64_t nc = dst->src[0]->ne[0] / 2;
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||
const void * src0_d = dst->src[0]->data;
|
||||
void * dst_d = dst->data;
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
reglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
|
||||
(const sycl::half *)src0_d + (swapped ? 0 : nc),
|
||||
(sycl::half *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(sycl::half),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
reglu_sycl((const float *) src0_d + (swapped ? nc : 0),
|
||||
(const float *)src0_d + (swapped ? 0 : nc),
|
||||
(float *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(float),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||
|
||||
#else
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
#endif
|
||||
GGML_ASSERT(dst->src[0]->type == dst->type);
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
const int64_t nc = dst->src[0]->ne[0] / 2;
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||
const void * src0_d = dst->src[0]->data;
|
||||
void * dst_d = dst->data;
|
||||
switch (dst->type) {
|
||||
#if defined (GGML_SYCL_F16)
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
swiglu_sycl((const sycl::half *) src0_d + (swapped ? nc : 0),
|
||||
(const sycl::half *)src0_d + (swapped ? 0 : nc),
|
||||
(sycl::half *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(sycl::half),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
swiglu_sycl((const float *) src0_d + (swapped ? nc : 0),
|
||||
(const float *)src0_d + (swapped ? 0 : nc),
|
||||
(float *) dst_d,
|
||||
ggml_nelements(dst),
|
||||
nc,
|
||||
dst->src[0]->nb[1] / sizeof(float),
|
||||
main_stream);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("GGML tensor type not supported!\n");
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
@@ -1509,3 +1713,20 @@ void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_elu(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_geglu(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_reglu(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_swiglu(ctx, dst);
|
||||
}
|
||||
|
||||
|
||||
|
@@ -24,6 +24,9 @@ typed_data<T> cast_data(ggml_tensor * dst) {
|
||||
};
|
||||
}
|
||||
|
||||
const float GELU_QUICK_COEF = -1.702f;
|
||||
|
||||
|
||||
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
@@ -73,5 +76,10 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
||||
|
||||
|
@@ -3678,6 +3678,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_GLU:
|
||||
switch (ggml_get_glu_op(dst)) {
|
||||
case GGML_GLU_OP_REGLU:
|
||||
ggml_sycl_reglu(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
ggml_sycl_geglu(ctx, dst);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
ggml_sycl_swiglu(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_NORM:
|
||||
ggml_sycl_norm(ctx, dst);
|
||||
break;
|
||||
@@ -4214,6 +4229,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_OP_GLU:
|
||||
switch (ggml_get_glu_op(op)) {
|
||||
case GGML_GLU_OP_REGLU:
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
return ggml_is_contiguous_1(op->src[0]);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
|
Reference in New Issue
Block a user