mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-15 15:17:44 +00:00
ggml: backward pass for split swiglu (#14483)
This commit is contained in:
@@ -6050,13 +6050,28 @@ static void ggml_compute_backward(
|
||||
}
|
||||
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
|
||||
} break;
|
||||
case GGML_OP_GLU: {
|
||||
switch (ggml_get_glu_op(tensor)) {
|
||||
case GGML_GLU_OP_SWIGLU: {
|
||||
if (src0_needs_grads) {
|
||||
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
|
||||
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
|
||||
}
|
||||
if (src1_needs_grads) {
|
||||
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
|
||||
}
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
|
||||
} //break;
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_NONE: {
|
||||
// noop
|
||||
} break;
|
||||
case GGML_OP_COUNT:
|
||||
default: {
|
||||
fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
|
||||
GGML_ABORT("fatal error");
|
||||
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
|
||||
} //break;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user