mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 05:33:37 -04:00
ggml : implement REGLU/GEGLU/SWIGLU ops (#14158)
* implement unary REGLU/GEGLU/SWIGLU cpu ops * relax constraints * duplicate shape of source * fix ggml_vec_geglu_f16 * special case gated ops * implement unary REGLU/GEGLU/SWIGLU cuda ops * tighten constraints again * refactor into GGML_GLU_OP * metal : add glu kernels ggml-ci * add CUDA_GLU_BLOCK_SIZE [no ci] * more constraints and use 64bit ints ggml-ci * 64bit multiplication [no ci] * implement swapped variants (cpu/cuda) * update comment [no ci] ggml-ci * Vulkan: Add GLU ops and shaders * SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate * ggml : implement GLU for split up/gate (#14181) * implement GLU for split up/gate * add tests for ggml_glu_split * Vulkan: Implement glu_split logic and shader support * add split to logging [no ci] * SYCL: refactor element_size ops and add split up and gate support to gated kernels * SYCL: switch GEGLU to use tanh approximation --------- Co-authored-by: 0cc4m <picard12@live.de> Co-authored-by: Akarshan <akarshan@menlo.ai> * GGML: increase OP count in assertion * Refactor: Optimize SYCL element-wise operations with unary function inlining This commit refactors the SYCL element-wise operations to improve performance by: - Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead. - Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic. - Replacing direct kernel calls with calls to these inlined functions. - Using `__dpct_inline__` to encourage compiler inlining. - Minor code cleanup and consistency improvements. The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices. * vulkan: Increase workgroup size for GLU, for performance (#14345) * vulkan: Increase workgroup size for GLU, for performance * vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup * merge fix * metal : add support for split and swap ggml-ci --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: 0cc4m <picard12@live.de> Co-authored-by: Akarshan <akarshan@menlo.ai> Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
@@ -560,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
{
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_swiglu", il);
|
||||
type_gate = LLM_FFN_SEQ;
|
||||
} else {
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
cb(cur, "ffn_silu", il);
|
||||
} break;
|
||||
case LLM_FFN_GELU:
|
||||
{
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_geglu_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_geglu", il);
|
||||
type_gate = LLM_FFN_SEQ;
|
||||
} else {
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_gelu", il);
|
||||
if (act_scales != NULL) {
|
||||
@@ -574,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
}
|
||||
} break;
|
||||
case LLM_FFN_RELU:
|
||||
{
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_reglu_split(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_reglu", il);
|
||||
type_gate = LLM_FFN_SEQ;
|
||||
} else {
|
||||
cur = ggml_relu(ctx0, cur);
|
||||
cb(cur, "ffn_relu", il);
|
||||
} break;
|
||||
@@ -588,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
} break;
|
||||
case LLM_FFN_SWIGLU:
|
||||
{
|
||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
x0 = ggml_silu(ctx0, x0);
|
||||
cb(cur, "ffn_silu", il);
|
||||
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
cb(cur, "ffn_mul", il);
|
||||
cur = ggml_swiglu(ctx0, cur);
|
||||
cb(cur, "ffn_swiglu", il);
|
||||
} break;
|
||||
case LLM_FFN_GEGLU:
|
||||
{
|
||||
// Split into two equal parts
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
x0 = ggml_gelu(ctx0, x0);
|
||||
cb(x0, "ffn_gelu", il);
|
||||
|
||||
cur = ggml_mul(ctx0, x0, x1);
|
||||
cur = ggml_geglu(ctx0, cur);
|
||||
cb(cur, "ffn_geglu", il);
|
||||
} break;
|
||||
case LLM_FFN_REGLU:
|
||||
{
|
||||
cur = ggml_reglu(ctx0, cur);
|
||||
cb(cur, "ffn_reglu", il);
|
||||
} break;
|
||||
}
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
@@ -743,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
{
|
||||
if (gate_exps) {
|
||||
cur = ggml_swiglu_split(ctx0, cur, up);
|
||||
cb(cur, "ffn_moe_swiglu", il);
|
||||
} else {
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_silu", il);
|
||||
} break;
|
||||
case LLM_FFN_GELU:
|
||||
{
|
||||
if (gate_exps) {
|
||||
cur = ggml_geglu_split(ctx0, cur, up);
|
||||
cb(cur, "ffn_moe_geglu", il);
|
||||
} else {
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_gelu", il);
|
||||
} break;
|
||||
@@ -756,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
if (gate_exps) {
|
||||
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(cur, "ffn_moe_gate_par", il);
|
||||
}
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
|
Reference in New Issue
Block a user