JohannesGaessler committed on
Commit
45c8df6
·
1 Parent(s): 1f15602

ggml: backward pass for split swiglu (llama/14483)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml.c +17 -2
ggml/src/ggml.c CHANGED
@@ -6042,13 +6042,28 @@ static void ggml_compute_backward(
6042
  }
6043
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
6044
  } break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6045
  case GGML_OP_NONE: {
6046
  // noop
6047
  } break;
6048
  case GGML_OP_COUNT:
6049
  default: {
6050
- fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
6051
- GGML_ABORT("fatal error");
6052
  } //break;
6053
  }
6054
 
 
6042
  }
6043
  GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
6044
  } break;
6045
+ case GGML_OP_GLU: {
6046
+ switch (ggml_get_glu_op(tensor)) {
6047
+ case GGML_GLU_OP_SWIGLU: {
6048
+ if (src0_needs_grads) {
6049
+ GGML_ASSERT(src1 && "backward pass only implemented for split swiglu");
6050
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, ggml_mul(ctx, grad, src1), src0));
6051
+ }
6052
+ if (src1_needs_grads) {
6053
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, ggml_silu(ctx, src0), grad));
6054
+ }
6055
+ } break;
6056
+ default: {
6057
+ GGML_ABORT("unsupported glu op for backward pass: %s", ggml_glu_op_name(ggml_get_glu_op(tensor)));
6058
+ } //break;
6059
+ }
6060
+ } break;
6061
  case GGML_OP_NONE: {
6062
  // noop
6063
  } break;
6064
  case GGML_OP_COUNT:
6065
  default: {
6066
+ GGML_ABORT("%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
 
6067
  } //break;
6068
  }
6069