ggml: backward pass for split swiglu (#14483)

This commit is contained in:
Johannes Gäßler
2025-07-03 17:05:18 +02:00
committed by GitHub
parent 7b63a71a6b
commit c8c4495b8d
2 changed files with 21 additions and 2 deletions

View File

@ -6050,13 +6050,28 @@ static void ggml_compute_backward(
}
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
} break;
case GGML_OP_GLU: {
switch (ggml_get_glu_op(tensor)) {
case GGML_GLU_OP_SWIGLU: {
if (src0_needs_grads) {
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
}
if (src1_needs_grads) {
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
}
} break;
default: {
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
} //break;
}
} break;
case GGML_OP_NONE: {
// noop
} break;
case GGML_OP_COUNT:
default: {
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
GGML_ABORT("fatal error");
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
} //break;
}

View File

@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case {
if (v & 1) {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a");
b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(b);
ggml_set_name(b, "b");
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
ggml_set_name(a, "view_of_b");
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(a);
ggml_set_name(a, "a");
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_param(b);
ggml_set_name(b, "b");
}