slaren commited on
Commit
cb91db5
·
unverified ·
1 Parent(s): 362430b

backend_sched : fix assignments

Browse files
Files changed (1) hide show
  1. ggml-backend.c +20 -0
ggml-backend.c CHANGED
@@ -1087,6 +1087,24 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
1087
  }
1088
  }
1089
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1090
  #ifdef DEBUG_PASS2
1091
  fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1092
  #endif
@@ -1146,6 +1164,8 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
1146
 
1147
  ggml_tallocr_t node_allocr = node_allocr(node);
1148
 
 
 
1149
  if (node_allocr != cur_allocr) {
1150
  sched->splits[cur_split].i_end = i;
1151
  cur_split++;
 
1087
  }
1088
  }
1089
  }
1090
+
1091
+ // pass 2.4 expand rest down
1092
+ {
1093
+ ggml_tallocr_t cur_allocr = NULL;
1094
+ for (int i = 0; i < graph->n_nodes; i++) {
1095
+ struct ggml_tensor * node = graph->nodes[i];
1096
+ if (ggml_is_view_op(node->op)) {
1097
+ continue;
1098
+ }
1099
+ ggml_tallocr_t node_allocr = node_allocr(node);
1100
+ if (node_allocr != NULL) {
1101
+ cur_allocr = node_allocr;
1102
+ } else {
1103
+ node_allocr(node) = cur_allocr;
1104
+ SET_CAUSE(node, "2.4");
1105
+ }
1106
+ }
1107
+ }
1108
  #ifdef DEBUG_PASS2
1109
  fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
1110
  #endif
 
1164
 
1165
  ggml_tallocr_t node_allocr = node_allocr(node);
1166
 
1167
+ GGML_ASSERT(node_allocr != NULL); // all nodes should be assigned by now
1168
+
1169
  if (node_allocr != cur_allocr) {
1170
  sched->splits[cur_split].i_end = i;
1171
  cur_split++;