refactor into GGML_GLU_OP

Author: Sigbjørn Skjæret, 2025-06-13 10:14:32 +02:00
Committed by: Akarshan
Parent: f8c20809de
Commit: a341aa3c2b
7 changed files with 170 additions and 53 deletions

View File

@@ -519,6 +519,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
 
+        GGML_OP_GLU,
+
         GGML_OP_COUNT,
     };
@@ -538,13 +540,18 @@ extern "C" {
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
-        GGML_UNARY_OP_REGLU,
-        GGML_UNARY_OP_GEGLU,
-        GGML_UNARY_OP_SWIGLU,
 
         GGML_UNARY_OP_COUNT,
     };
 
+    enum ggml_glu_op {
+        GGML_GLU_OP_REGLU,
+        GGML_GLU_OP_GEGLU,
+        GGML_GLU_OP_SWIGLU,
+
+        GGML_GLU_OP_COUNT,
+    };
+
     enum ggml_object_type {
         GGML_OBJECT_TYPE_TENSOR,
         GGML_OBJECT_TYPE_GRAPH,
@@ -660,6 +667,7 @@ extern "C" {
     GGML_API const char * ggml_op_symbol(enum ggml_op op);
 
     GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
+    GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
     GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
 
     GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@@ -761,6 +769,7 @@ extern "C" {
     GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
 
     GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+    GGML_API enum ggml_glu_op   ggml_get_glu_op(const struct ggml_tensor * tensor);
 
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
@@ -1089,6 +1098,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // gated linear unit ops
+    // A: n columns, r rows,
+    // result is n / 2 columns, r rows,
+    GGML_API struct ggml_tensor * ggml_glu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_glu_op      op);
+
     GGML_API struct ggml_tensor * ggml_reglu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
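The new entry point and the existing wrappers compose as follows — a minimal usage sketch, not part of the diff (context size and tensor shapes are illustrative):

#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // n = 64 columns in, n / 2 = 32 columns out, per the header comment above
    struct ggml_tensor * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);
    struct ggml_tensor * out = ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU); // same node as ggml_swiglu(ctx, a)

    GGML_ASSERT(out->ne[0] == 32 && ggml_get_glu_op(out) == GGML_GLU_OP_SWIGLU);

    ggml_free(ctx);
    return 0;
}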

View File

@@ -1941,6 +1941,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_unary(params, tensor);
             } break;
+        case GGML_OP_GLU:
+            {
+                ggml_compute_forward_glu(params, tensor);
+            } break;
         case GGML_OP_GET_REL_POS:
             {
                 ggml_compute_forward_get_rel_pos(params, tensor);
@@ -2144,9 +2148,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
-                case GGML_UNARY_OP_REGLU:
-                case GGML_UNARY_OP_GEGLU:
-                case GGML_UNARY_OP_SWIGLU:
+                    {
+                        n_tasks = n_threads;
+                    } break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(node)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
                     {
                         n_tasks = n_threads;
                     } break;
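Note that the GLU cases get n_tasks = n_threads, i.e. the CPU kernels are expected to split work by rows. A hedged sketch of the usual ggml partitioning this implies (not code from this commit; names are illustrative):

#include <stdint.h>

// thread `ith` of `nth` processes rows [ir0, ir1)
static void row_range(int64_t nr, int ith, int nth, int64_t * ir0, int64_t * ir1) {
    const int64_t dr = (nr + nth - 1) / nth; // rows per thread, rounded up
    *ir0 = dr * ith;
    *ir1 = *ir0 + dr < nr ? *ir0 + dr : nr;
}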

View File

@@ -8308,15 +8308,31 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
-        case GGML_UNARY_OP_REGLU:
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_glu
+
+void ggml_compute_forward_glu(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    const ggml_glu_op op = ggml_get_glu_op(dst);
+
+    switch (op) {
+        case GGML_GLU_OP_REGLU:
             {
                 ggml_compute_forward_reglu(params, dst);
             } break;
-        case GGML_UNARY_OP_GEGLU:
+        case GGML_GLU_OP_GEGLU:
             {
                 ggml_compute_forward_geglu(params, dst);
             } break;
-        case GGML_UNARY_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU:
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
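The per-row kernels themselves are unchanged by this hunk and not shown. As a reference sketch of the semantics (an assumption consistent with the header comment that n columns shrink to n / 2): the first half of each row is passed through the activation and gates the second half.

#include <math.h>

// Hedged reference for one row of a fused GLU: y[i] = act(x[i]) * x[i + n/2],
// where act is relu for REGLU, gelu for GEGLU, silu for SWIGLU.
static float silu_f32(float v) { return v / (1.0f + expf(-v)); }

static void glu_row_f32(const float * x, float * y, int n, float (*act)(float)) {
    const int nh = n / 2;
    for (int i = 0; i < nh; i++) {
        y[i] = act(x[i]) * x[i + nh];
    }
}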

View File

@@ -93,6 +93,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@@ -2246,13 +2246,19 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
                 case GGML_UNARY_OP_EXP:
                     ggml_cuda_op_exp(ctx, dst);
                     break;
-                case GGML_UNARY_OP_REGLU:
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
                     ggml_cuda_op_reglu(ctx, dst);
                     break;
-                case GGML_UNARY_OP_GEGLU:
+                case GGML_GLU_OP_GEGLU:
                     ggml_cuda_op_geglu(ctx, dst);
                     break;
-                case GGML_UNARY_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
                 default:
@@ -3048,9 +3054,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_EXP:
                     return ggml_is_contiguous(op->src[0]);
-                case GGML_UNARY_OP_REGLU:
-                case GGML_UNARY_OP_GEGLU:
-                case GGML_UNARY_OP_SWIGLU:
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
                     return ggml_is_contiguous_1(op->src[0]);
                 default:
                     return false;

View File

@@ -984,6 +984,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "GLU",
 };
 
 static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1080,6 +1082,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "glu(x)",
 };
 
 static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
@@ -1103,12 +1107,18 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
     "EXP",
     "GELU_ERF",
+};
+
+static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
+
+static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 18, "GGML_UNARY_OP_COUNT != 18");
+static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1213,11 +1223,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
     return GGML_UNARY_OP_NAME[op];
 }
 
+const char * ggml_glu_op_name(enum ggml_glu_op op) {
+    return GGML_GLU_OP_NAME[op];
+}
+
 const char * ggml_op_desc(const struct ggml_tensor * t) {
     if (t->op == GGML_OP_UNARY) {
         enum ggml_unary_op uop = ggml_get_unary_op(t);
         return ggml_unary_op_name(uop);
     }
+    if (t->op == GGML_OP_GLU) {
+        enum ggml_glu_op gop = ggml_get_glu_op(t);
+        return ggml_glu_op_name(gop);
+    }
     return ggml_op_name(t->op);
 }
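A small hedged example of the effect (assumes a context `ctx` and an input `a` as in the earlier sketch): ggml_op_desc() on a GLU tensor now reports the specific gate op rather than the generic op name.

#include <stdio.h>

struct ggml_tensor * t = ggml_swiglu(ctx, a);
printf("%s\n", ggml_op_desc(t)); // prints "SWIGLU" via GGML_GLU_OP_NAME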
@@ -1736,6 +1754,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
     return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
 }
 
+enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_GLU);
+    return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
+}
+
 const char * ggml_get_name(const struct ggml_tensor * tensor) {
     return tensor->name;
 }
@@ -2615,40 +2638,39 @@ struct ggml_tensor * ggml_exp_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
 }
 
-// ggml_reglu
+// ggml_glu
 
-struct ggml_tensor * ggml_reglu(
+struct ggml_tensor * ggml_glu(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
+        struct ggml_tensor  * a,
+        enum ggml_glu_op      op) {
     GGML_ASSERT(ggml_is_contiguous_1(a));
 
     int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
 
-    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_REGLU);
+    ggml_set_op_params_i32(result, 0, (int32_t) op);
 
-    result->op     = GGML_OP_UNARY;
+    result->op     = GGML_OP_GLU;
     result->src[0] = a;
 
     return result;
 }
 
+// ggml_reglu
+
+struct ggml_tensor * ggml_reglu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_glu(ctx, a, GGML_GLU_OP_REGLU);
+}
+
 // ggml_geglu
 
 struct ggml_tensor * ggml_geglu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    GGML_ASSERT(ggml_is_contiguous_1(a));
-
-    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
-
-    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_GEGLU);
-
-    result->op     = GGML_OP_UNARY;
-    result->src[0] = a;
-
-    return result;
+    return ggml_glu(ctx, a, GGML_GLU_OP_GEGLU);
 }
 
 // ggml_swiglu
@@ -2656,17 +2678,7 @@ struct ggml_tensor * ggml_geglu(
 struct ggml_tensor * ggml_swiglu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    GGML_ASSERT(ggml_is_contiguous_1(a));
-
-    int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, ne, NULL, 0);
-
-    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU);
-
-    result->op     = GGML_OP_UNARY;
-    result->src[0] = a;
-
-    return result;
+    return ggml_glu(ctx, a, GGML_GLU_OP_SWIGLU);
 }
 
 // ggml_norm

View File

@@ -1072,16 +1072,7 @@ struct test_unary : public test_case {
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * out;
-        if (op == GGML_UNARY_OP_REGLU) {
-            out = ggml_reglu(ctx, a);
-        } else if (op == GGML_UNARY_OP_GEGLU) {
-            out = ggml_geglu(ctx, a);
-        } else if (op == GGML_UNARY_OP_SWIGLU) {
-            out = ggml_swiglu(ctx, a);
-        } else {
-            out = ggml_unary(ctx, a, op);
-        }
+        ggml_tensor * out = ggml_unary(ctx, a, op);
         ggml_set_name(out, "out");
 
         return out;
@@ -1113,6 +1104,51 @@ struct test_unary : public test_case {
 };
 
+// GGML_OP_GLU
+struct test_glu : public test_case {
+    const ggml_glu_op op;
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, v);
+    }
+
+    test_glu(ggml_glu_op op,
+            ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+            int v = 0)
+        : op(op), type(type), ne_a(ne_a), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_name(a, "a");
+        }
+
+        ggml_tensor * out = ggml_glu(ctx, a, op);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -3969,6 +4005,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // glu ops
+    for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
+        for (int v : {0, 1}) {
+            for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
+                test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
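Since the new cases report their op via ggml_op_desc(), they can be filtered by name when running the test binary, e.g. test-backend-ops test -o REGLU (assuming the usual test-backend-ops CLI; exact flags may differ by revision).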