refactor into GGML_GLU_OP

This commit is contained in:
Sigbjørn Skjæret
2025-06-13 10:14:32 +02:00
committed by Akarshan
parent f8c20809de
commit a341aa3c2b
7 changed files with 170 additions and 53 deletions

View File

@ -1072,16 +1072,7 @@ struct test_unary : public test_case {
ggml_set_name(a, "a");
}
ggml_tensor * out;
if (op == GGML_UNARY_OP_REGLU) {
out = ggml_reglu(ctx, a);
} else if (op == GGML_UNARY_OP_GEGLU) {
out = ggml_geglu(ctx, a);
} else if (op == GGML_UNARY_OP_SWIGLU) {
out = ggml_swiglu(ctx, a);
} else {
out = ggml_unary(ctx, a, op);
}
ggml_tensor * out = ggml_unary(ctx, a, op);
ggml_set_name(out, "out");
return out;
@ -1113,6 +1104,51 @@ struct test_unary : public test_case {
};
// GGML_OP_GLU
struct test_glu : public test_case {
    const ggml_glu_op op;
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    int v; // view (1 : non-contiguous a)

    std::string vars() override {
        return VARS_TO_STR3(type, ne_a, v);
    }

    test_glu(ggml_glu_op op,
            ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
            int v = 0)
        : op(op), type(type), ne_a(ne_a), v(v) {}

    // Builds: a (optionally a non-contiguous view) -> glu(op) -> out
    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a;
        if ((v & 1) == 0) {
            // plain contiguous input tensor
            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
            ggml_set_name(a, "a");
        } else {
            // allocate rows 3x wider, then view only the leading ne_a[0]
            // elements of each row so the op sees non-contiguous input
            std::array<int64_t, 4> ne_full = ne_a;
            ne_full[0] *= 3;
            a = ggml_new_tensor(ctx, type, 4, ne_full.data());
            ggml_set_name(a, "a");
            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
            ggml_set_name(a, "view_of_a");
        }
        ggml_tensor * out = ggml_glu(ctx, a, op);
        ggml_set_name(out, "out");
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        ggml_tensor * t = ggml_get_first_tensor(ctx);
        // use a wide value range to check for NaNs in GELU
        for (; t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            init_tensor_uniform(t, -150.f, 150.f);
        }
    }
};
// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
@ -3969,6 +4005,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// glu ops
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
}
}
}
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {