From 4aaf8114e7953715ef9880463d0de46dd294e13d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 3 Jul 2025 17:05:18 +0200 Subject: [PATCH] ggml: backward pass for split swiglu (llama/14483) --- ggml/src/ggml.c | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b89a68db..68768842 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6042,13 +6042,28 @@ static void ggml_compute_backward( } GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); } break; + case GGML_OP_GLU: { + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_SWIGLU: { + if (src0_needs_grads) { + GGML_ASSERT(src1 && "backward pass only implemented for split swiglu"); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0)); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad)); + } + } break; + default: { + GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor))); + } //break; + } + } break; case GGML_OP_NONE: { // noop } break; case GGML_OP_COUNT: default: { - fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); - GGML_ABORT("fatal error"); + GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); } //break; }