From c8c4495b8d3a8799e2d46778f993965b0ac1ae43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 3 Jul 2025 17:05:18 +0200 Subject: [PATCH] ggml: backward pass for split swiglu (#14483) --- ggml/src/ggml.c | 19 +++++++++++++++++-- tests/test-backend-ops.cpp | 4 ++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index fdb57e178..b481193da 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6050,13 +6050,28 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_GLU: { + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_SWIGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split swiglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0)); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); + } + } break; + default: { + GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); + } //break; + } + } break; case GGML_OP_NONE: { // noop } break; case GGML_OP_COUNT: default: { - fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); - GGML_ABORT("fatal error"); + GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); } //break; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2ab6dd06f..c76635793 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1175,21 +1175,25 @@ struct test_glu_split : public test_case { if (v & 1) { auto ne = ne_a; ne[0] *= 3; a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(a); ggml_set_name(a, "a"); a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view_of_a"); b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(b); ggml_set_name(b, "b"); b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0); ggml_set_name(a, "view_of_b"); } else { a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_param(a); ggml_set_name(a, "a"); b = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_set_param(b); ggml_set_name(b, "b"); }