implement swapped variants (cpu/cuda)

This commit is contained in:
Sigbjørn Skjæret
2025-06-13 22:48:53 +02:00
committed by Akarshan
parent f8705a2399
commit 0b2703fc57
7 changed files with 117 additions and 45 deletions

View File

@ -1110,16 +1110,18 @@ struct test_glu : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne_a;
int v; // view (1 : non-contiguous a)
bool swapped;
std::string vars() override {
return VARS_TO_STR3(type, ne_a, v);
return VARS_TO_STR4(type, ne_a, v, swapped);
}
test_glu(ggml_glu_op op,
ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
int v = 0)
: op(op), type(type), ne_a(ne_a), v(v) {}
int v = 0,
bool swapped = false)
: op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
@ -1135,7 +1137,7 @@ struct test_glu : public test_case {
ggml_set_name(a, "a");
}
ggml_tensor * out = ggml_glu(ctx, a, op);
ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
ggml_set_name(out, "out");
return out;
@ -4009,8 +4011,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
for (bool swapped : {false, true}) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
}
}
}