Spaces:
Running
Running
Commit
·
45c8df6
1
Parent(s):
1f15602
ggml: backward pass for split swiglu (llama/14483)
Browse files- ggml/src/ggml.c +17 -2
ggml/src/ggml.c
CHANGED
|
@@ -6042,13 +6042,28 @@ static void ggml_compute_backward(
|
|
| 6042 |
}
|
| 6043 |
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
|
| 6044 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6045 |
case GGML_OP_NONE: {
|
| 6046 |
// noop
|
| 6047 |
} break;
|
| 6048 |
case GGML_OP_COUNT:
|
| 6049 |
default: {
|
| 6050 |
-
|
| 6051 |
-
GGML_ABORT("fatal error");
|
| 6052 |
} //break;
|
| 6053 |
}
|
| 6054 |
|
|
|
|
| 6042 |
}
|
| 6043 |
GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
|
| 6044 |
} break;
|
| 6045 |
+
case GGML_OP_GLU: {
|
| 6046 |
+
switch (ggml_get_glu_op(tensor)) {
|
| 6047 |
+
case GGML_GLU_OP_SWIGLU: {
|
| 6048 |
+
if (src0_needs_grads) {
|
| 6049 |
+
GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
|
| 6050 |
+
ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
|
| 6051 |
+
}
|
| 6052 |
+
if (src1_needs_grads) {
|
| 6053 |
+
ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
|
| 6054 |
+
}
|
| 6055 |
+
} break;
|
| 6056 |
+
default: {
|
| 6057 |
+
GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
|
| 6058 |
+
} //break;
|
| 6059 |
+
}
|
| 6060 |
+
} break;
|
| 6061 |
case GGML_OP_NONE: {
|
| 6062 |
// noop
|
| 6063 |
} break;
|
| 6064 |
case GGML_OP_COUNT:
|
| 6065 |
default: {
|
| 6066 |
+
GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
|
|
|
|
| 6067 |
} //break;
|
| 6068 |
}
|
| 6069 |
|