ggerganov committed
Commit d41ba35 (unverified) · Parent(s): cf2a7c6

ggml : sync (ggml-alloc, GPU, eps, etc.) (#1220)

* ggml : sync (ggml-alloc, GPU, eps, etc.)

* ggml : fix build

* wasm : fix build

bindings/javascript/libwhisper.worker.js CHANGED
@@ -1 +1 @@
1
- "use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:function(f){(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f)},postMessage:function(msg){parentPort.postMessage(msg)},performance:global.performance||{now:function(){return Date.now()}}})}var initializedJS=false;var pendingNotifiedProxyingQueues=[];function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var instance=new WebAssembly.Instance(Module["wasmModule"],info);receiveInstance(instance);Module["wasmModule"]=null;return instance.exports};self.onunhandledrejection=e=>{throw e.reason??e};self.onmessage=e=>{try{if(e.data.cmd==="load"){Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=function(){postMessage({cmd:"callHandler",handler:handler,args:[...arguments]})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module).then(function(instance){Module=instance})}else if(e.data.cmd==="run"){Module["__performance_now_clock_drift"]=performance.now()-e.data.time;Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();pendingNotifiedProxyingQueues.forEach(queue=>{Module["executeNotifiedProxyingQueue"](queue)});pendingNotifiedProxyingQueues=[];initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){if(ex instanceof Module["ExitStatus"]){if(Module["keepRuntimeAlive"]()){}else{Module["__emscripten_thread_exit"](ex.status)}}else{throw ex}}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="processProxyingQueue"){if(initializedJS){Module["executeNotifiedProxyingQueue"](e.data.queue)}else{pendingNotifiedProxyingQueues.push(e.data.queue)}}else if(e.data.cmd){err("worker.js received unknown command "+e.data.cmd);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}};
 
1
+ "use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:f=>(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f),postMessage:msg=>parentPort.postMessage(msg),performance:global.performance||{now:Date.now}})}var initializedJS=false;function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var module=Module["wasmModule"];Module["wasmModule"]=null;var instance=new WebAssembly.Instance(module,info);return receiveInstance(instance)};self.onunhandledrejection=e=>{throw e.reason||e};function handleMessage(e){try{if(e.data.cmd==="load"){let messageQueue=[];self.onmessage=e=>messageQueue.push(e);self.startWorker=instance=>{Module=instance;postMessage({"cmd":"loaded"});for(let msg of messageQueue){handleMessage(msg)}self.onmessage=handleMessage};Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=(...args)=>{postMessage({cmd:"callHandler",handler:handler,args:args})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module)}else if(e.data.cmd==="run"){Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["__emscripten_thread_mailbox_await"](e.data.pthread_ptr);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){throw ex}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="checkMailbox"){if(initializedJS){Module["checkMailbox"]()}}else if(e.data.cmd){err(`worker.js received unknown command ${e.data.cmd}`);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}}self.onmessage=handleMessage;
bindings/javascript/whisper.js CHANGED
The diff for this file is too large to render. See raw diff
 
examples/common.cpp CHANGED
@@ -1,3 +1,5 @@
 
 
1
  #include "common.h"
2
 
3
  // third-party utilities
@@ -13,53 +15,59 @@
13
  #include <codecvt>
14
  #include <sstream>
15
 
16
- #ifndef M_PI
17
- #define M_PI 3.14159265358979323846
18
- #endif
19
-
20
  #if defined(_MSC_VER)
21
  #pragma warning(disable: 4244 4267) // possible loss of data
22
  #endif
23
 
24
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
25
  for (int i = 1; i < argc; i++) {
26
  std::string arg = argv[i];
27
 
28
  if (arg == "-s" || arg == "--seed") {
29
- params.seed = std::stoi(argv[++i]);
30
  } else if (arg == "-t" || arg == "--threads") {
31
- params.n_threads = std::stoi(argv[++i]);
 
 
32
  } else if (arg == "-p" || arg == "--prompt") {
33
- params.prompt = argv[++i];
34
  } else if (arg == "-n" || arg == "--n_predict") {
35
- params.n_predict = std::stoi(argv[++i]);
36
  } else if (arg == "--top_k") {
37
- params.top_k = std::max(1, std::stoi(argv[++i]));
38
  } else if (arg == "--top_p") {
39
- params.top_p = std::stof(argv[++i]);
40
  } else if (arg == "--temp") {
41
- params.temp = std::stof(argv[++i]);
42
  } else if (arg == "--repeat-last-n") {
43
- params.repeat_last_n = std::stof(argv[++i]);
44
  } else if (arg == "--repeat-penalty") {
45
- params.repeat_penalty = std::stof(argv[++i]);
46
  } else if (arg == "-b" || arg == "--batch_size") {
47
- params.n_batch = std::stoi(argv[++i]);
48
  } else if (arg == "-m" || arg == "--model") {
49
- params.model = argv[++i];
50
  } else if (arg == "-i" || arg == "--interactive") {
51
  params.interactive = true;
52
  } else if (arg == "-ip" || arg == "--interactive-port") {
53
  params.interactive = true;
54
- params.interactive_port = std::stoi(argv[++i]);
55
  } else if (arg == "-h" || arg == "--help") {
56
  gpt_print_usage(argc, argv, params);
57
  exit(0);
58
  } else if (arg == "-f" || arg == "--file") {
59
- if (++i > argc) {
60
- fprintf(stderr, "Invalid file param");
61
- break;
62
- }
63
  std::ifstream file(argv[i]);
64
  if (!file) {
65
  fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -70,7 +78,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
70
  params.prompt.pop_back();
71
  }
72
  } else if (arg == "-tt" || arg == "--token_test") {
73
- params.token_test = argv[++i];
74
  }
75
  else {
76
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -89,6 +97,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
89
  fprintf(stderr, " -h, --help show this help message and exit\n");
90
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
91
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
 
92
  fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
93
  fprintf(stderr, " prompt to start generation with (default: random)\n");
94
  fprintf(stderr, " -f FNAME, --file FNAME\n");
@@ -755,3 +764,46 @@ float similarity(const std::string & s0, const std::string & s1) {
755
 
756
  return 1.0f - (dist / std::max(s0.size(), s1.size()));
757
  }
1
+ #define _USE_MATH_DEFINES // for M_PI
2
+
3
  #include "common.h"
4
 
5
  // third-party utilities
 
15
  #include <codecvt>
16
  #include <sstream>
17
 
 
 
 
 
18
  #if defined(_MSC_VER)
19
  #pragma warning(disable: 4244 4267) // possible loss of data
20
  #endif
21
 
22
+ // Function to check if the next argument exists
23
+ std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
24
+ if (i + 1 < argc && argv[i + 1][0] != '-') {
25
+ return argv[++i];
26
+ } else {
27
+ fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
28
+ gpt_print_usage(argc, argv, params);
29
+ exit(0);
30
+ }
31
+ }
32
+
33
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
34
  for (int i = 1; i < argc; i++) {
35
  std::string arg = argv[i];
36
 
37
  if (arg == "-s" || arg == "--seed") {
38
+ params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
39
  } else if (arg == "-t" || arg == "--threads") {
40
+ params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
41
+ } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
42
+ params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
43
  } else if (arg == "-p" || arg == "--prompt") {
44
+ params.prompt = get_next_arg(i, argc, argv, arg, params);
45
  } else if (arg == "-n" || arg == "--n_predict") {
46
+ params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
47
  } else if (arg == "--top_k") {
48
+ params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
49
  } else if (arg == "--top_p") {
50
+ params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
51
  } else if (arg == "--temp") {
52
+ params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
53
  } else if (arg == "--repeat-last-n") {
54
+ params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
55
  } else if (arg == "--repeat-penalty") {
56
+ params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
57
  } else if (arg == "-b" || arg == "--batch_size") {
58
+ params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
59
  } else if (arg == "-m" || arg == "--model") {
60
+ params.model = get_next_arg(i, argc, argv, arg, params);
61
  } else if (arg == "-i" || arg == "--interactive") {
62
  params.interactive = true;
63
  } else if (arg == "-ip" || arg == "--interactive-port") {
64
  params.interactive = true;
65
+ params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
66
  } else if (arg == "-h" || arg == "--help") {
67
  gpt_print_usage(argc, argv, params);
68
  exit(0);
69
  } else if (arg == "-f" || arg == "--file") {
70
+ get_next_arg(i, argc, argv, arg, params);
 
 
 
71
  std::ifstream file(argv[i]);
72
  if (!file) {
73
  fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
 
78
  params.prompt.pop_back();
79
  }
80
  } else if (arg == "-tt" || arg == "--token_test") {
81
+ params.token_test = get_next_arg(i, argc, argv, arg, params);
82
  }
83
  else {
84
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
 
97
  fprintf(stderr, " -h, --help show this help message and exit\n");
98
  fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
99
  fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
100
+ fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
101
  fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
102
  fprintf(stderr, " prompt to start generation with (default: random)\n");
103
  fprintf(stderr, " -f FNAME, --file FNAME\n");
 
764
 
765
  return 1.0f - (dist / std::max(s0.size(), s1.size()));
766
  }
767
+
768
+ bool sam_params_parse(int argc, char ** argv, sam_params & params) {
769
+ for (int i = 1; i < argc; i++) {
770
+ std::string arg = argv[i];
771
+
772
+ if (arg == "-s" || arg == "--seed") {
773
+ params.seed = std::stoi(argv[++i]);
774
+ } else if (arg == "-t" || arg == "--threads") {
775
+ params.n_threads = std::stoi(argv[++i]);
776
+ } else if (arg == "-m" || arg == "--model") {
777
+ params.model = argv[++i];
778
+ } else if (arg == "-i" || arg == "--inp") {
779
+ params.fname_inp = argv[++i];
780
+ } else if (arg == "-o" || arg == "--out") {
781
+ params.fname_out = argv[++i];
782
+ } else if (arg == "-h" || arg == "--help") {
783
+ sam_print_usage(argc, argv, params);
784
+ exit(0);
785
+ } else {
786
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
787
+ sam_print_usage(argc, argv, params);
788
+ exit(0);
789
+ }
790
+ }
791
+
792
+ return true;
793
+ }
794
+
795
+ void sam_print_usage(int argc, char ** argv, const sam_params & params) {
796
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
797
+ fprintf(stderr, "\n");
798
+ fprintf(stderr, "options:\n");
799
+ fprintf(stderr, " -h, --help show this help message and exit\n");
800
+ fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
801
+ fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
802
+ fprintf(stderr, " -m FNAME, --model FNAME\n");
803
+ fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
804
+ fprintf(stderr, " -i FNAME, --inp FNAME\n");
805
+ fprintf(stderr, " input file (default: %s)\n", params.fname_inp.c_str());
806
+ fprintf(stderr, " -o FNAME, --out FNAME\n");
807
+ fprintf(stderr, " output file (default: %s)\n", params.fname_out.c_str());
808
+ fprintf(stderr, "\n");
809
+ }
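
For context on the parsing change above: get_next_arg replaces the unchecked argv[++i] reads from the old code. It only consumes the next token when one exists and does not start with '-'; otherwise it prints an error plus the usage text and exits. A minimal standalone sketch of the same pattern (next_arg_or_die is a simplified stand-in for get_next_arg, with the gpt_params argument dropped so the snippet compiles on its own):

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // simplified stand-in for get_next_arg: require a value after the flag
    static std::string next_arg_or_die(int & i, int argc, char ** argv, const std::string & flag) {
        if (i + 1 < argc && argv[i + 1][0] != '-') {
            return argv[++i]; // consume and return the value
        }
        std::fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
        std::exit(1);
    }

    int main(int argc, char ** argv) {
        int n_threads = 4;
        for (int i = 1; i < argc; i++) {
            const std::string arg = argv[i];
            if (arg == "-t" || arg == "--threads") {
                n_threads = std::stoi(next_arg_or_die(i, argc, argv, arg));
            }
        }
        std::printf("n_threads = %d\n", n_threads);
        return 0;
    }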
examples/common.h CHANGED
@@ -11,7 +11,7 @@
11
  #define COMMON_SAMPLE_RATE 16000
12
 
13
  //
14
- // CLI argument parsing
15
  //
16
 
17
  struct gpt_params {
@@ -33,6 +33,8 @@ struct gpt_params {
33
 
34
  bool interactive = false;
35
  int32_t interactive_port = -1;
 
 
36
  };
37
 
38
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -155,3 +157,20 @@ bool vad_simple(
155
 
156
  // compute similarity between two strings using Levenshtein distance
157
  float similarity(const std::string & s0, const std::string & s1);
11
  #define COMMON_SAMPLE_RATE 16000
12
 
13
  //
14
+ // GPT CLI argument parsing
15
  //
16
 
17
  struct gpt_params {
 
33
 
34
  bool interactive = false;
35
  int32_t interactive_port = -1;
36
+
37
+ int32_t n_gpu_layers = 0;
38
  };
39
 
40
  bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
157
 
158
  // compute similarity between two strings using Levenshtein distance
159
  float similarity(const std::string & s0, const std::string & s1);
160
+
161
+ //
162
+ // SAM argument parsing
163
+ //
164
+
165
+ struct sam_params {
166
+ int32_t seed = -1; // RNG seed
167
+ int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
168
+
169
+ std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
170
+ std::string fname_inp = "img.jpg";
171
+ std::string fname_out = "img.out";
172
+ };
173
+
174
+ bool sam_params_parse(int argc, char ** argv, sam_params & params);
175
+
176
+ void sam_print_usage(int argc, char ** argv, const sam_params & params);
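
A minimal sketch of driving the new SAM argument parsing, assuming only the examples/common.h declarations shown above (the printf is illustrative):

    #include "common.h"

    #include <cstdio>

    int main(int argc, char ** argv) {
        sam_params params; // defaults: seed = -1, models/sam-vit-b/ggml-model-f16.bin, img.jpg -> img.out

        if (!sam_params_parse(argc, argv, params)) {
            return 1;
        }

        std::printf("model: %s, input: %s, output: %s\n",
                params.model.c_str(), params.fname_inp.c_str(), params.fname_out.c_str());

        return 0;
    }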
examples/talk.wasm/gpt-2.cpp CHANGED
@@ -191,9 +191,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
191
  // create the ggml context
192
  {
193
  struct ggml_init_params params = {
194
- .mem_size = ctx_size,
195
- .mem_buffer = NULL,
196
- .no_alloc = false,
197
  };
198
 
199
  model.ctx = ggml_init(params);
@@ -420,7 +420,6 @@ bool gpt2_eval(
420
 
421
  struct ggml_context * ctx0 = ggml_init(params);
422
  struct ggml_cgraph gf = {};
423
- gf.n_threads = n_threads;
424
 
425
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
426
  memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -442,7 +441,7 @@ bool gpt2_eval(
442
  // norm
443
  {
444
  // [ 768, N]
445
- cur = ggml_norm(ctx0, inpL);
446
 
447
  // cur = ln_1_g*cur + ln_1_b
448
  // [ 768, N]
@@ -589,7 +588,7 @@ bool gpt2_eval(
589
  {
590
  // norm
591
  {
592
- cur = ggml_norm(ctx0, inpFF);
593
 
594
  // cur = ln_2_g*cur + ln_2_b
595
  // [ 768, N]
@@ -644,7 +643,7 @@ bool gpt2_eval(
644
  // norm
645
  {
646
  // [ 768, N]
647
- inpL = ggml_norm(ctx0, inpL);
648
 
649
  // inpL = ln_f_g*inpL + ln_f_b
650
  // [ 768, N]
@@ -664,8 +663,8 @@ bool gpt2_eval(
664
  //inpL = ggml_soft_max(ctx0, inpL);
665
 
666
  // run the computation
667
- ggml_build_forward_expand(&gf, inpL);
668
- ggml_graph_compute (ctx0, &gf);
669
 
670
  //if (n_past%100 == 0) {
671
  // ggml_graph_print (&gf);
 
191
  // create the ggml context
192
  {
193
  struct ggml_init_params params = {
194
+ /*.mem_size =*/ ctx_size,
195
+ /*.mem_buffer =*/ NULL,
196
+ /*.no_alloc =*/ false,
197
  };
198
 
199
  model.ctx = ggml_init(params);
 
420
 
421
  struct ggml_context * ctx0 = ggml_init(params);
422
  struct ggml_cgraph gf = {};
 
423
 
424
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
425
  memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
441
  // norm
442
  {
443
  // [ 768, N]
444
+ cur = ggml_norm(ctx0, inpL, 1e-5f);
445
 
446
  // cur = ln_1_g*cur + ln_1_b
447
  // [ 768, N]
 
588
  {
589
  // norm
590
  {
591
+ cur = ggml_norm(ctx0, inpFF, 1e-5f);
592
 
593
  // cur = ln_2_g*cur + ln_2_b
594
  // [ 768, N]
 
643
  // norm
644
  {
645
  // [ 768, N]
646
+ inpL = ggml_norm(ctx0, inpL, 1e-5f);
647
 
648
  // inpL = ln_f_g*inpL + ln_f_b
649
  // [ 768, N]
 
663
  //inpL = ggml_soft_max(ctx0, inpL);
664
 
665
  // run the computation
666
+ ggml_build_forward_expand (&gf, inpL);
667
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
668
 
669
  //if (n_past%100 == 0) {
670
  // ggml_graph_print (&gf);
examples/talk/gpt-2.cpp CHANGED
@@ -379,6 +379,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
379
  // - embd_inp: the embeddings of the tokens in the context
380
  // - embd_w: the predicted logits for the next token
381
  //
 
382
  bool gpt2_eval(
383
  const gpt2_model & model,
384
  const int n_threads,
@@ -420,7 +421,6 @@ bool gpt2_eval(
420
 
421
  struct ggml_context * ctx0 = ggml_init(params);
422
  struct ggml_cgraph gf = {};
423
- gf.n_threads = n_threads;
424
 
425
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
426
  memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -442,7 +442,7 @@ bool gpt2_eval(
442
  // norm
443
  {
444
  // [ 768, N]
445
- cur = ggml_norm(ctx0, inpL);
446
 
447
  // cur = ln_1_g*cur + ln_1_b
448
  // [ 768, N]
@@ -589,7 +589,7 @@ bool gpt2_eval(
589
  {
590
  // norm
591
  {
592
- cur = ggml_norm(ctx0, inpFF);
593
 
594
  // cur = ln_2_g*cur + ln_2_b
595
  // [ 768, N]
@@ -644,7 +644,7 @@ bool gpt2_eval(
644
  // norm
645
  {
646
  // [ 768, N]
647
- inpL = ggml_norm(ctx0, inpL);
648
 
649
  // inpL = ln_f_g*inpL + ln_f_b
650
  // [ 768, N]
@@ -664,8 +664,8 @@ bool gpt2_eval(
664
  //inpL = ggml_soft_max(ctx0, inpL);
665
 
666
  // run the computation
667
- ggml_build_forward_expand(&gf, inpL);
668
- ggml_graph_compute (ctx0, &gf);
669
 
670
  //if (n_past%100 == 0) {
671
  // ggml_graph_print (&gf);
 
379
  // - embd_inp: the embeddings of the tokens in the context
380
  // - embd_w: the predicted logits for the next token
381
  //
382
+ // TODO: sync latest version from ggml repo
383
  bool gpt2_eval(
384
  const gpt2_model & model,
385
  const int n_threads,
 
421
 
422
  struct ggml_context * ctx0 = ggml_init(params);
423
  struct ggml_cgraph gf = {};
 
424
 
425
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
426
  memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
 
442
  // norm
443
  {
444
  // [ 768, N]
445
+ cur = ggml_norm(ctx0, inpL, 1e-5f);
446
 
447
  // cur = ln_1_g*cur + ln_1_b
448
  // [ 768, N]
 
589
  {
590
  // norm
591
  {
592
+ cur = ggml_norm(ctx0, inpFF, 1e-5f);
593
 
594
  // cur = ln_2_g*cur + ln_2_b
595
  // [ 768, N]
 
644
  // norm
645
  {
646
  // [ 768, N]
647
+ inpL = ggml_norm(ctx0, inpL, 1e-5f);
648
 
649
  // inpL = ln_f_g*inpL + ln_f_b
650
  // [ 768, N]
 
664
  //inpL = ggml_soft_max(ctx0, inpL);
665
 
666
  // run the computation
667
+ ggml_build_forward_expand (&gf, inpL);
668
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
669
 
670
  //if (n_past%100 == 0) {
671
  // ggml_graph_print (&gf);
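
Both gpt-2.cpp diffs follow the same updated ggml API: ggml_norm now takes the layer-norm epsilon as an explicit argument, and the thread count is passed to ggml_graph_compute_with_ctx instead of being stored in the graph (the removed gf.n_threads). A rough sketch of the pattern, assuming the synced ggml.h:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 768);
        ggml_set_f32(x, 1.0f);

        struct ggml_tensor * y = ggml_norm(ctx, x, 1e-5f); // eps is now an explicit argument

        struct ggml_cgraph gf = {};
        ggml_build_forward_expand(&gf, y);
        ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 4); // replaces gf.n_threads

        ggml_free(ctx);
        return 0;
    }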
ggml-alloc.c ADDED
@@ -0,0 +1,594 @@
1
+ #include "ggml-alloc.h"
2
+ #include "ggml.h"
3
+ #include <assert.h>
4
+ #include <stdarg.h>
5
+ #include <stdio.h>
6
+ #include <stdlib.h>
7
+ #include <string.h>
8
+
9
+ #define UNUSED(x) (void)(x)
10
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
11
+ #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
12
+
13
+ //#define GGML_ALLOCATOR_DEBUG
14
+
15
+ //#define AT_PRINTF printf
16
+ #define AT_PRINTF(...) ((void)0)
17
+
18
+ struct hash_node {
19
+ struct ggml_tensor * t;
20
+ int n_children;
21
+ int n_views;
22
+ };
23
+
24
+ static size_t hash(void * p) {
25
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
26
+ }
27
+
28
+ static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
29
+ size_t h = hash(t);
30
+
31
+ // linear probing
32
+ size_t i = h;
33
+ while (hash_table[i].t != NULL) {
34
+ if (hash_table[i].t == t) {
35
+ return &hash_table[i];
36
+ }
37
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
38
+ if (i == h) {
39
+ // hash table is full
40
+ GGML_ASSERT(false);
41
+ }
42
+ }
43
+
44
+ hash_table[i].t = t;
45
+ return &hash_table[i];
46
+ }
47
+
48
+ // TODO: GGML_PAD ?
49
+ static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
50
+ assert(alignment && !(alignment & (alignment - 1))); // power of 2
51
+ size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
52
+ return offset + align;
53
+ }
54
+
55
+ struct free_block {
56
+ void * addr;
57
+ size_t size;
58
+ };
59
+
60
+ #define MAX_FREE_BLOCKS 128
61
+
62
+ struct ggml_allocr {
63
+ void * data;
64
+ size_t size;
65
+ size_t alignment;
66
+ int n_free_blocks;
67
+ struct free_block free_blocks[MAX_FREE_BLOCKS];
68
+ struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
69
+ size_t max_size;
70
+ bool measure;
71
+ int parse_seq[GGML_MAX_CONCUR];
72
+ int parse_seq_len;
73
+
74
+ #ifdef GGML_ALLOCATOR_DEBUG
75
+ struct ggml_tensor * allocated_tensors[1024];
76
+ #endif
77
+ };
78
+
79
+ #ifdef GGML_ALLOCATOR_DEBUG
80
+ static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
81
+ for (int i = 0; i < 1024; i++) {
82
+ if (alloc->allocated_tensors[i] == NULL) {
83
+ alloc->allocated_tensors[i] = tensor;
84
+ return;
85
+ }
86
+ }
87
+ GGML_ASSERT(!"out of allocated_tensors");
88
+ }
89
+ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
90
+ for (int i = 0; i < 1024; i++) {
91
+ if (alloc->allocated_tensors[i] == tensor ||
92
+ (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
93
+ alloc->allocated_tensors[i] = NULL;
94
+ return;
95
+ }
96
+ }
97
+ printf("tried to free tensor %s not found\n", tensor->name);
98
+ GGML_ASSERT(!"tensor not found");
99
+ }
100
+ #endif
101
+
102
+
103
+ static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
104
+ return ggml_nbytes(tensor);
105
+
106
+ UNUSED(alloc);
107
+ }
108
+
109
+ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
110
+ size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
111
+ size = aligned_offset(NULL, size, alloc->alignment);
112
+
113
+ AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
114
+
115
+ size_t max_avail = 0;
116
+
117
+ // find the best fitting free block besides the last block
118
+ int best_fit_block = -1;
119
+ size_t best_fit_size = SIZE_MAX;
120
+ for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
121
+ struct free_block * block = &alloc->free_blocks[i];
122
+ max_avail = MAX(max_avail, block->size);
123
+ if (block->size >= size && block->size <= best_fit_size) {
124
+ best_fit_block = i;
125
+ best_fit_size = block->size;
126
+ }
127
+ }
128
+
129
+ AT_PRINTF("block %d\n", best_fit_block);
130
+
131
+ if (best_fit_block == -1) {
132
+ // the last block is our last resort
133
+ struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
134
+ if (block->size >= size) {
135
+ best_fit_block = alloc->n_free_blocks - 1;
136
+ max_avail = MAX(max_avail, block->size);
137
+ } else {
138
+ fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
139
+ __func__, size, max_avail);
140
+ GGML_ASSERT(!"not enough space in the buffer");
141
+ return;
142
+ }
143
+ }
144
+ struct free_block * block = &alloc->free_blocks[best_fit_block];
145
+ void * addr = block->addr;
146
+ block->addr = (char*)block->addr + size;
147
+ block->size -= size;
148
+ if (block->size == 0) {
149
+ // remove block if empty
150
+ alloc->n_free_blocks--;
151
+ for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
152
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
153
+ }
154
+ }
155
+
156
+ tensor->data = addr;
157
+
158
+ #ifdef GGML_ALLOCATOR_DEBUG
159
+ add_allocated_tensor(alloc, tensor);
160
+ size_t cur_max = (char*)addr - (char*)alloc->data + size;
161
+ if (cur_max > alloc->max_size) {
162
+ printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
163
+ for (int i = 0; i < 1024; i++) {
164
+ if (alloc->allocated_tensors[i]) {
165
+ printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
166
+ }
167
+ }
168
+ printf("\n");
169
+ }
170
+ #endif
171
+
172
+ alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
173
+ }
174
+
175
+ // this is a very naive implementation, but for our case the number of free blocks should be very small
176
+ static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
177
+ void * ptr = tensor->data;
178
+
179
+ if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
180
+ // the tensor was not allocated in this buffer
181
+ // this can happen because the graph allocator will try to free weights and other tensors from different buffers
182
+ // the easiest way to deal with this is just to ignore it
183
+ return;
184
+ }
185
+
186
+ size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
187
+ size = aligned_offset(NULL, size, alloc->alignment);
188
+ AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
189
+
190
+ #ifdef GGML_ALLOCATOR_DEBUG
191
+ remove_allocated_tensor(alloc, tensor);
192
+ #endif
193
+
194
+ // see if we can merge with an existing block
195
+ for (int i = 0; i < alloc->n_free_blocks; i++) {
196
+ struct free_block * block = &alloc->free_blocks[i];
197
+ // check if ptr is at the end of the block
198
+ if ((char*)block->addr + block->size == ptr) {
199
+ block->size += size;
200
+ // check if we can merge with the next block
201
+ if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
202
+ block->size += alloc->free_blocks[i+1].size;
203
+ alloc->n_free_blocks--;
204
+ for (int j = i+1; j < alloc->n_free_blocks; j++) {
205
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
206
+ }
207
+ }
208
+ return;
209
+ }
210
+ // check if ptr is at the beginning of the block
211
+ if ((char*)ptr + size == block->addr) {
212
+ block->addr = ptr;
213
+ block->size += size;
214
+ // check if we can merge with the previous block
215
+ if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
216
+ alloc->free_blocks[i-1].size += block->size;
217
+ alloc->n_free_blocks--;
218
+ for (int j = i; j < alloc->n_free_blocks; j++) {
219
+ alloc->free_blocks[j] = alloc->free_blocks[j+1];
220
+ }
221
+ }
222
+ return;
223
+ }
224
+ }
225
+ // otherwise, add a new block
226
+ GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
227
+ // insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
228
+ int insert_pos = 0;
229
+ while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
230
+ insert_pos++;
231
+ }
232
+ // shift all blocks from insert_pos onward to make room for the new block
233
+ for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
234
+ alloc->free_blocks[i] = alloc->free_blocks[i-1];
235
+ }
236
+ // insert the new block
237
+ alloc->free_blocks[insert_pos].addr = ptr;
238
+ alloc->free_blocks[insert_pos].size = size;
239
+ alloc->n_free_blocks++;
240
+ }
241
+
242
+ void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
243
+ for (int i = 0; i < n; i++) {
244
+ alloc->parse_seq[i] = list[i];
245
+ }
246
+ alloc->parse_seq_len = n;
247
+ }
248
+
249
+ void ggml_allocr_reset(struct ggml_allocr * alloc) {
250
+ alloc->n_free_blocks = 1;
251
+ size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
252
+ alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
253
+ alloc->free_blocks[0].size = alloc->size - align_offset;
254
+ }
255
+
256
+ struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
257
+ struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
258
+
259
+ *alloc = (struct ggml_allocr){
260
+ /*.data = */ data,
261
+ /*.size = */ size,
262
+ /*.alignment = */ alignment,
263
+ /*.n_free_blocks = */ 0,
264
+ /*.free_blocks = */ {{0}},
265
+ /*.hash_table = */ {{0}},
266
+ /*.max_size = */ 0,
267
+ /*.measure = */ false,
268
+ /*.parse_seq = */ {0},
269
+ /*.parse_seq_len = */ 0,
270
+ #ifdef GGML_ALLOCATOR_DEBUG
271
+ /*.allocated_tensors = */ {0},
272
+ #endif
273
+ };
274
+
275
+ ggml_allocr_reset(alloc);
276
+
277
+ return alloc;
278
+ }
279
+
280
+ // address and size of the buffer when measuring
281
+ // it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
282
+ static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
283
+ static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
284
+
285
+ struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
286
+ struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
287
+
288
+ *alloc = (struct ggml_allocr){
289
+ /*.data = */ MEASURE_BASE_ADDR,
290
+ /*.size = */ MEASURE_MAX_SIZE,
291
+ /*.alignment = */ alignment,
292
+ /*.n_free_blocks = */ 0,
293
+ /*.free_blocks = */ {{0}},
294
+ /*.hash_table = */ {{0}},
295
+ /*.max_size = */ 0,
296
+ /*.measure = */ true,
297
+ /*.parse_seq = */ {0},
298
+ /*.parse_seq_len = */ 0,
299
+ #ifdef GGML_ALLOCATOR_DEBUG
300
+ /*.allocated_tensors = */ {0},
301
+ #endif
302
+ };
303
+
304
+ ggml_allocr_reset(alloc);
305
+
306
+ return alloc;
307
+ }
308
+
309
+ void ggml_allocr_free(struct ggml_allocr * alloc) {
310
+ free(alloc);
311
+ }
312
+
313
+ bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
314
+ return alloc->measure;
315
+ }
316
+
317
+ //////////// compute graph allocator
318
+
319
+ static bool ggml_is_view(struct ggml_tensor * t) {
320
+ return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
321
+ t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
322
+ }
323
+
324
+ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
325
+ if (a->type != b->type) {
326
+ return false;
327
+ }
328
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
329
+ if (a->ne[i] != b->ne[i]) {
330
+ return false;
331
+ }
332
+ if (a->nb[i] != b->nb[i]) {
333
+ return false;
334
+ }
335
+ }
336
+ return true;
337
+ }
338
+
339
+ static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
340
+ switch (t->op) {
341
+ case GGML_OP_PERMUTE:
342
+ case GGML_OP_RESHAPE:
343
+ case GGML_OP_TRANSPOSE:
344
+ case GGML_OP_VIEW:
345
+ return t->src[0];
346
+ case GGML_OP_CPY:
347
+ return t->src[1];
348
+ default:
349
+ return NULL;
350
+ }
351
+ }
352
+
353
+ static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
354
+ struct ggml_tensor * parent = t;
355
+ do {
356
+ parent = get_view_parent(parent);
357
+ } while (ggml_is_view(parent));
358
+ return parent;
359
+ }
360
+
361
+ static bool ggml_op_can_inplace(enum ggml_op op) {
362
+ switch (op) {
363
+ case GGML_OP_SCALE:
364
+ case GGML_OP_DIAG_MASK_ZERO:
365
+ case GGML_OP_DIAG_MASK_INF:
366
+ case GGML_OP_ADD:
367
+ case GGML_OP_ADD1:
368
+ case GGML_OP_ACC:
369
+ case GGML_OP_SUB:
370
+ case GGML_OP_MUL:
371
+ case GGML_OP_DIV:
372
+ case GGML_OP_SQR:
373
+ case GGML_OP_SQRT:
374
+ case GGML_OP_LOG:
375
+ case GGML_OP_UNARY:
376
+ case GGML_OP_ROPE:
377
+ case GGML_OP_RMS_NORM:
378
+ case GGML_OP_SET:
379
+ case GGML_OP_SOFT_MAX:
380
+ case GGML_OP_CONT:
381
+ case GGML_OP_ADD_REL_POS:
382
+ return true;
383
+
384
+ default:
385
+ return false;
386
+ }
387
+ }
388
+
389
+ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
390
+ struct hash_node * ht = alloc->hash_table;
391
+ if (node->data == NULL) {
392
+ if (ggml_is_view(node)) {
393
+ size_t offset;
394
+ switch(node->op) {
395
+ case GGML_OP_VIEW:
396
+ memcpy(&offset, node->op_params, sizeof(size_t));
397
+ node->data = (char *) node->src[0]->data + offset;
398
+ break;
399
+ case GGML_OP_PERMUTE:
400
+ case GGML_OP_RESHAPE:
401
+ case GGML_OP_TRANSPOSE:
402
+ node->data = node->src[0]->data;
403
+ break;
404
+ case GGML_OP_CPY:
405
+ node->data = node->src[1]->data;
406
+ break;
407
+ default:
408
+ GGML_ASSERT(!"unknown view op");
409
+ break;
410
+ }
411
+ } else {
412
+ // see if we can reuse a parent's buffer (inplace)
413
+ if (ggml_op_can_inplace(node->op)) {
414
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
415
+ struct ggml_tensor * parent = node->src[i];
416
+ if (parent == NULL) {
417
+ break;
418
+ }
419
+
420
+ // if the node's data is external, then we cannot re-use it
421
+ if ((char *) parent->data < (char *) alloc->data ||
422
+ (char *) parent->data >= ((char *) alloc->data + alloc->size)) {
423
+ AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
424
+ continue;
425
+ }
426
+
427
+ struct hash_node * p_hn = hash_get(ht, parent);
428
+ if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
429
+ if (ggml_is_view(parent)) {
430
+ struct ggml_tensor * view_src = get_view_source(parent);
431
+ struct hash_node * view_src_hn = hash_get(ht, view_src);
432
+ if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
433
+ // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
434
+ // the parent's data that it will need later (same layout requirement). the problem is that then
435
+ // we cannot free the tensor because the original address of the allocation is lost.
436
+ // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
437
+ // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
438
+ AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
439
+ node->data = parent->data;
440
+ return;
441
+ }
442
+ }
443
+ else {
444
+ AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
445
+ node->data = parent->data;
446
+ return;
447
+ }
448
+ }
449
+ }
450
+ }
451
+ ggml_allocr_alloc(alloc, node);
452
+ }
453
+ }
454
+ }
455
+
456
+ static size_t ggml_allocator_alloc_graph_tensors_n(
457
+ struct ggml_allocr * alloc,
458
+ struct ggml_cgraph ** graphs, int n_graphs,
459
+ struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
460
+
461
+ // reset hash table
462
+ struct hash_node * ht = alloc->hash_table;
463
+ memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
464
+
465
+ // count number of children and views
466
+ for (int g = 0; g < n_graphs; g++) {
467
+ struct ggml_cgraph * gf = graphs[g];
468
+ for (int i = 0; i < gf->n_nodes; i++) {
469
+ struct ggml_tensor * node = gf->nodes[i];
470
+
471
+ if (ggml_is_view(node)) {
472
+ struct ggml_tensor * view_src = get_view_source(node);
473
+ hash_get(ht, view_src)->n_views += 1;
474
+ }
475
+
476
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
477
+ struct ggml_tensor * parent = node->src[j];
478
+ if (parent == NULL) {
479
+ break;
480
+ }
481
+ hash_get(ht, parent)->n_children += 1;
482
+ }
483
+ }
484
+ }
485
+
486
+ // allocate tensors
487
+ for (int g = 0; g < n_graphs; g++) {
488
+ struct ggml_cgraph * gf = graphs[g];
489
+ AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
490
+ // graph inputs are allocated first to ensure that they are not overwritten by each other
491
+ if (inputs != NULL && inputs[g] != NULL) {
492
+ for (int i = 0; inputs[g][i] != NULL; i++) {
493
+ struct ggml_tensor * input = inputs[g][i];
494
+ AT_PRINTF("input: %s\n", input->name);
495
+ allocate_node(alloc, input);
496
+ }
497
+ }
498
+ // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
499
+ int last_barrier_pos = 0;
500
+ int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
501
+
502
+ for (int ind = 0; ind < n_nodes; ind++) {
503
+ // allocate a node if there is no parse_seq or this is not a barrier
504
+ if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
505
+ int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
506
+ struct ggml_tensor * node = gf->nodes[i];
507
+
508
+ // allocate parents (leafs)
509
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
510
+ struct ggml_tensor * parent = node->src[j];
511
+ if (parent == NULL) {
512
+ break;
513
+ }
514
+ allocate_node(alloc, parent);
515
+ }
516
+
517
+ // allocate node
518
+ allocate_node(alloc, node);
519
+
520
+ AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
521
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
522
+ struct ggml_tensor * parent = node->src[j];
523
+ if (parent == NULL) {
524
+ break;
525
+ }
526
+ AT_PRINTF("%s", parent->name);
527
+ if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
528
+ AT_PRINTF(", ");
529
+ }
530
+ }
531
+ AT_PRINTF("\n");
532
+ }
533
+
534
+
535
+ // update parents
536
+ // update immediately if there is no parse_seq
537
+ // update only at barriers if there is parse_seq
538
+ if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
539
+ int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
540
+ int update_end = alloc->parse_seq_len ? ind : ind + 1;
541
+ for (int i = update_start; i < update_end; i++) {
542
+ int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
543
+ struct ggml_tensor * node = gf->nodes[node_i];
544
+
545
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
546
+ struct ggml_tensor * parent = node->src[j];
547
+ if (parent == NULL) {
548
+ break;
549
+ }
550
+ struct hash_node * p_hn = hash_get(ht, parent);
551
+ p_hn->n_children -= 1;
552
+
553
+ //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
554
+
555
+ if (p_hn->n_children == 0 && p_hn->n_views == 0) {
556
+ if (ggml_is_view(parent)) {
557
+ struct ggml_tensor * view_src = get_view_source(parent);
558
+ struct hash_node * view_src_hn = hash_get(ht, view_src);
559
+ view_src_hn->n_views -= 1;
560
+ AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
561
+ if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
562
+ ggml_allocator_free_tensor(alloc, view_src);
563
+ }
564
+ }
565
+ else {
566
+ if (parent->data != node->data) {
567
+ ggml_allocator_free_tensor(alloc, parent);
568
+ }
569
+ }
570
+ }
571
+ }
572
+ }
573
+ AT_PRINTF("\n");
574
+ if (alloc->parse_seq_len) {
575
+ last_barrier_pos = ind + 1;
576
+ }
577
+ }
578
+ }
579
+ // free graph outputs here that wouldn't be freed otherwise because they have no children
580
+ if (outputs != NULL && outputs[g] != NULL) {
581
+ for (int i = 0; outputs[g][i] != NULL; i++) {
582
+ struct ggml_tensor * output = outputs[g][i];
583
+ AT_PRINTF("output: %s\n", output->name);
584
+ ggml_allocator_free_tensor(alloc, output);
585
+ }
586
+ }
587
+ }
588
+
589
+ return alloc->max_size;
590
+ }
591
+
592
+ size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
593
+ return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
594
+ }
ggml-alloc.h ADDED
@@ -0,0 +1,26 @@
1
+ #pragma once
2
+
3
+ #include "ggml.h"
4
+
5
+ #ifdef __cplusplus
6
+ extern "C" {
7
+ #endif
8
+
9
+
10
+ GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
11
+ GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
12
+
13
+ // tell the allocator to parse nodes following the order described in the list
14
+ // you should call this if your graph are optimized to execute out-of-order
15
+ GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
16
+
17
+ GGML_API void ggml_allocr_free(struct ggml_allocr * alloc);
18
+ GGML_API bool ggml_allocr_is_measure(struct ggml_allocr * alloc);
19
+ GGML_API void ggml_allocr_reset(struct ggml_allocr * alloc);
20
+ GGML_API void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
21
+ GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
22
+
23
+
24
+ #ifdef __cplusplus
25
+ }
26
+ #endif
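
The intended use of the new allocator (inferred from the API above and the comments in ggml-alloc.c) is a two-pass scheme: build the graph once against a measuring allocator to learn the required buffer size, then rebuild it against a real buffer of that size. A rough sketch, where build_graph and the alignment value are placeholders:

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <stdlib.h>

    // placeholder: builds the compute graph in a no_alloc context, calling
    // ggml_allocr_alloc() / ggml_allocr_is_measure() on its input tensors as needed
    extern struct ggml_cgraph * build_graph(struct ggml_context * ctx, struct ggml_allocr * alloc);

    void example(struct ggml_context * ctx_measure, struct ggml_context * ctx_real) {
        const size_t alignment = 32; // placeholder; use the alignment the backend requires

        // pass 1: measure how much memory the graph needs
        struct ggml_allocr * alloc = ggml_allocr_new_measure(alignment);
        const size_t mem_size = ggml_allocr_alloc_graph(alloc, build_graph(ctx_measure, alloc));
        ggml_allocr_free(alloc);

        // pass 2: allocate a real buffer of that size and place the graph tensors in it
        void * buf = malloc(mem_size + alignment);
        alloc = ggml_allocr_new(buf, mem_size + alignment, alignment);
        struct ggml_cgraph * gf = build_graph(ctx_real, alloc);
        ggml_allocr_alloc_graph(alloc, gf);
        // ... compute gf here, then release the allocator and buffer ...
        ggml_allocr_free(alloc);
        free(buf);
    }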
ggml-cuda.cu CHANGED
The diff for this file is too large to render. See raw diff
 
ggml-cuda.h CHANGED
@@ -2,34 +2,44 @@
2
 
3
  #include "ggml.h"
4
 
5
  #ifdef __cplusplus
6
  extern "C" {
7
  #endif
8
 
9
  #define GGML_CUDA_MAX_DEVICES 16
10
 
11
- void ggml_init_cublas(void);
12
- void ggml_cuda_set_tensor_split(const float * tensor_split);
 
 
 
 
 
 
13
 
14
- void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
15
- bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
16
- size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
17
- void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
18
 
19
- // TODO: export these with GGML_API
20
- void * ggml_cuda_host_malloc(size_t size);
21
- void ggml_cuda_host_free(void * ptr);
22
 
23
- void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 
 
 
 
24
 
25
- void ggml_cuda_free_data(struct ggml_tensor * tensor);
26
- void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
27
- void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
28
- void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
29
- void ggml_cuda_set_main_device(int main_device);
30
- void ggml_cuda_set_scratch_size(size_t scratch_size);
31
- void ggml_cuda_free_scratch(void);
32
- bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
33
 
34
  #ifdef __cplusplus
35
  }
 
2
 
3
  #include "ggml.h"
4
 
5
+ #ifdef GGML_USE_HIPBLAS
6
+ #define GGML_CUDA_NAME "ROCm"
7
+ #define GGML_CUBLAS_NAME "hipBLAS"
8
+ #else
9
+ #define GGML_CUDA_NAME "CUDA"
10
+ #define GGML_CUBLAS_NAME "cuBLAS"
11
+ #endif
12
+
13
  #ifdef __cplusplus
14
  extern "C" {
15
  #endif
16
 
17
  #define GGML_CUDA_MAX_DEVICES 16
18
 
19
+ GGML_API void ggml_init_cublas(void);
20
+ GGML_API void * ggml_cuda_host_malloc(size_t size);
21
+ GGML_API void ggml_cuda_host_free(void * ptr);
22
+
23
+ GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
24
+ GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
25
+ GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
26
+ GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
27
 
28
+ GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
29
+ GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
30
+ GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 
31
 
32
+ GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
33
+ GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
 
34
 
35
+ GGML_API void ggml_cuda_set_main_device(int main_device);
36
+ GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
37
+ GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
38
+ GGML_API void ggml_cuda_free_scratch(void);
39
+ GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
40
 
41
+ GGML_API int ggml_cuda_get_device_count(void);
42
+ GGML_API void ggml_cuda_get_device_description(int device, char * description, size_t description_size);
43
 
44
  #ifdef __cplusplus
45
  }
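
Besides moving the existing declarations behind GGML_API, the header gains device-query helpers. A small sketch of how a caller might use them, guarded by GGML_USE_CUBLAS (the macro used for cuBLAS builds at the time of this sync):

    #ifdef GGML_USE_CUBLAS
    #include "ggml-cuda.h"

    #include <cstdio>

    void list_cuda_devices(void) {
        ggml_init_cublas();

        const int n_devices = ggml_cuda_get_device_count();
        for (int i = 0; i < n_devices; ++i) {
            char description[256];
            ggml_cuda_get_device_description(i, description, sizeof(description));
            std::printf("%s device %d: %s\n", GGML_CUDA_NAME, i, description);
        }
    }
    #endif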
ggml-metal.h CHANGED
@@ -24,6 +24,7 @@
24
 
25
  // max memory buffers that can be mapped to the device
26
  #define GGML_METAL_MAX_BUFFERS 16
 
27
 
28
  struct ggml_tensor;
29
  struct ggml_cgraph;
@@ -34,9 +35,16 @@ extern "C" {
34
 
35
  struct ggml_metal_context;
36
 
37
- struct ggml_metal_context * ggml_metal_init(void);
 
38
  void ggml_metal_free(struct ggml_metal_context * ctx);
39
 
40
  // creates a mapping between a host memory buffer and a device memory buffer
41
  // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
42
  // - the mapping is used during computation to determine the arguments of the compute kernels
@@ -57,6 +65,16 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
57
  // get data from the device into host memory
58
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
59
 
 
60
  // same as ggml_graph_compute but uses Metal
61
  // creates gf->n_threads command buffers in parallel
62
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
 
24
 
25
  // max memory buffers that can be mapped to the device
26
  #define GGML_METAL_MAX_BUFFERS 16
27
+ #define GGML_METAL_MAX_COMMAND_BUFFERS 32
28
 
29
  struct ggml_tensor;
30
  struct ggml_cgraph;
 
35
 
36
  struct ggml_metal_context;
37
 
38
+ // number of command buffers to use
39
+ struct ggml_metal_context * ggml_metal_init(int n_cb);
40
  void ggml_metal_free(struct ggml_metal_context * ctx);
41
 
42
+ void * ggml_metal_host_malloc(size_t n);
43
+ void ggml_metal_host_free (void * data);
44
+
45
+ // set the number of command buffers to use
46
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
47
+
48
  // creates a mapping between a host memory buffer and a device memory buffer
49
  // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
50
  // - the mapping is used during computation to determine the arguments of the compute kernels
 
65
  // get data from the device into host memory
66
  void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
67
 
68
+ // try to find operations that can be run concurrently in the graph
69
+ // you should run it again if the topology of your graph changes
70
+ void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
71
+
72
+ // if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
73
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
74
+
75
+ // output the concur_list for ggml_alloc
76
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
77
+
78
  // same as ggml_graph_compute but uses Metal
79
  // creates gf->n_threads command buffers in parallel
80
  void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
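
The new Metal entry points tie into ggml-alloc: after a concurrent schedule is found for the graph, its node order can be handed to the allocator via ggml_allocr_set_parse_seq. A sketch based only on the declarations above (error handling omitted; n_cb = 1 is an arbitrary choice):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-metal.h"

    void plan_metal_graph(struct ggml_allocr * alloc, struct ggml_cgraph * gf) {
        struct ggml_metal_context * ctx_metal = ggml_metal_init(/*n_cb =*/ 1);

        // look for operations that can be dispatched concurrently
        ggml_metal_graph_find_concurrency(ctx_metal, gf, /*check_mem =*/ false);

        // if a concurrent order was found, let the allocator follow it so buffers
        // are only reused at the barriers of that schedule
        const int n = ggml_metal_if_optimized(ctx_metal);
        if (n > 0) {
            ggml_allocr_set_parse_seq(alloc, ggml_metal_get_concur_list(ctx_metal), n);
        }

        ggml_metal_free(ctx_metal);
    }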
ggml-metal.m CHANGED
@@ -5,7 +5,11 @@
5
  #import <Foundation/Foundation.h>
6
 
7
  #import <Metal/Metal.h>
8
- #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
 
 
 
9
 
10
  #ifdef GGML_METAL_NDEBUG
11
  #define metal_printf(...)
@@ -15,6 +19,8 @@
15
 
16
  #define UNUSED(x) (void)(x)
17
 
 
 
18
  struct ggml_metal_buffer {
19
  const char * name;
20
 
@@ -25,21 +31,30 @@ struct ggml_metal_buffer {
25
  };
26
 
27
  struct ggml_metal_context {
28
- float * logits;
29
 
30
  id<MTLDevice> device;
31
  id<MTLCommandQueue> queue;
32
  id<MTLLibrary> library;
33
 
 
 
 
 
 
34
  int n_buffers;
35
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
36
 
 
 
 
37
  // custom kernels
38
  #define GGML_METAL_DECL_KERNEL(name) \
39
  id<MTLFunction> function_##name; \
40
  id<MTLComputePipelineState> pipeline_##name
41
 
42
  GGML_METAL_DECL_KERNEL(add);
 
43
  GGML_METAL_DECL_KERNEL(mul);
44
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
45
  GGML_METAL_DECL_KERNEL(scale);
@@ -51,6 +66,7 @@ struct ggml_metal_context {
51
  GGML_METAL_DECL_KERNEL(get_rows_f16);
52
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
 
54
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
55
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
56
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -61,11 +77,21 @@ struct ggml_metal_context {
61
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
62
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
63
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
 
64
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
65
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
66
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
67
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
68
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
 
  GGML_METAL_DECL_KERNEL(rope);
70
  GGML_METAL_DECL_KERNEL(alibi_f32);
71
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -86,22 +112,18 @@ static NSString * const msl_library_source = @"see metal.metal";
86
  @implementation GGMLMetalClass
87
  @end
88
 
89
- struct ggml_metal_context * ggml_metal_init(void) {
90
  fprintf(stderr, "%s: allocating\n", __func__);
91
 
92
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
93
 
 
94
  ctx->device = MTLCreateSystemDefaultDevice();
95
  ctx->queue = [ctx->device newCommandQueue];
96
  ctx->n_buffers = 0;
 
97
 
98
- // determine if we can use MPS
99
- if (MPSSupportsMTLDevice(ctx->device)) {
100
- fprintf(stderr, "%s: using MPS\n", __func__);
101
- } else {
102
- fprintf(stderr, "%s: not using MPS\n", __func__);
103
- GGML_ASSERT(false && "MPS not supported");
104
- }
105
 
106
  #if 0
107
  // compile from source string and show compile log
@@ -111,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
111
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
112
  if (error) {
113
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
114
- exit(1);
115
  }
116
  }
117
  #else
@@ -129,7 +151,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
129
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
130
  if (error) {
131
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
132
- exit(1);
133
  }
134
 
135
  #ifdef GGML_QKK_64
@@ -141,19 +163,27 @@ struct ggml_metal_context * ggml_metal_init(void) {
141
  #endif
142
  if (error) {
143
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
144
- exit(1);
145
  }
146
  }
147
  #endif
148
 
149
  // load kernels
150
  {
 
151
  #define GGML_METAL_ADD_KERNEL(name) \
152
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
153
- ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
154
- fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
 
 
 
 
 
155
 
156
  GGML_METAL_ADD_KERNEL(add);
 
157
  GGML_METAL_ADD_KERNEL(mul);
158
  GGML_METAL_ADD_KERNEL(mul_row);
159
  GGML_METAL_ADD_KERNEL(scale);
@@ -165,6 +195,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
165
  GGML_METAL_ADD_KERNEL(get_rows_f16);
166
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
167
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
 
168
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
169
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
170
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -175,11 +206,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
175
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
176
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
177
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
 
178
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
179
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
180
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
181
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
182
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
 
183
  GGML_METAL_ADD_KERNEL(rope);
184
  GGML_METAL_ADD_KERNEL(alibi_f32);
185
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -189,12 +230,12 @@ struct ggml_metal_context * ggml_metal_init(void) {
  #undef GGML_METAL_ADD_KERNEL
      }
 
-     fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-     fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
      if (ctx->device.maxTransferRate != 0) {
-         fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
      } else {
-         fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
      }
 
      return ctx;
@@ -202,12 +243,97 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
  void ggml_metal_free(struct ggml_metal_context * ctx) {
      fprintf(stderr, "%s: deallocating\n", __func__);
205
  for (int i = 0; i < ctx->n_buffers; ++i) {
206
  [ctx->buffers[i].metal release];
207
  }
208
  free(ctx);
209
  }
210
211
  // finds the Metal buffer that contains the tensor data on the GPU device
212
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
213
  // Metal buffer based on the host memory pointer
@@ -346,48 +472,154 @@ void ggml_metal_get_tensor(
346
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
347
  }
348
349
  void ggml_metal_graph_compute(
350
  struct ggml_metal_context * ctx,
351
  struct ggml_cgraph * gf) {
352
  metal_printf("%s: evaluating graph\n", __func__);
353
      // create multiple command buffers and enqueue them
      // then, we encode the graph into the command buffers in parallel
 
-     const int n_cb = gf->n_threads;
-
-     NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
 
      for (int i = 0; i < n_cb; ++i) {
-         command_buffers[i] = [ctx->queue commandBuffer];
 
          // enqueue the command buffers in order to specify their execution order
-         [command_buffers[i] enqueue];
-     }
 
-     // TODO: is this the best way to start threads?
-     dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
      for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-         const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
 
-         dispatch_async(queue, ^{
              size_t offs_src0 = 0;
              size_t offs_src1 = 0;
              size_t offs_dst = 0;
 
-             id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-             id<MTLComputeCommandEncoder> encoder = nil;
 
              const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-             const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
 
-             for (int i = node_start; i < node_end; ++i) {
                  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
-                 struct ggml_tensor * src0 = gf->nodes[i]->src0;
-                 struct ggml_tensor * src1 = gf->nodes[i]->src1;
                  struct ggml_tensor * dst = gf->nodes[i];
 
                  const int64_t ne00 = src0 ? src0->ne[0] : 0;
@@ -443,6 +675,7 @@ void ggml_metal_graph_compute(
443
  //}
444
 
445
  switch (dst->op) {
 
446
  case GGML_OP_RESHAPE:
447
  case GGML_OP_VIEW:
448
  case GGML_OP_TRANSPOSE:
@@ -452,14 +685,16 @@ void ggml_metal_graph_compute(
452
  } break;
453
  case GGML_OP_ADD:
454
  {
455
- if (encoder == nil) {
456
- encoder = [command_buffer computeCommandEncoder];
 
 
 
457
  }
458
-
459
- [encoder setComputePipelineState:ctx->pipeline_add];
460
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
461
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
462
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
 
463
 
464
  const int64_t n = ggml_nelements(dst);
465
 
@@ -467,10 +702,6 @@ void ggml_metal_graph_compute(
467
  } break;
468
  case GGML_OP_MUL:
469
  {
470
- if (encoder == nil) {
471
- encoder = [command_buffer computeCommandEncoder];
472
- }
473
-
474
  if (ggml_nelements(src1) == ne10) {
475
  // src1 is a row
476
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -488,10 +719,6 @@ void ggml_metal_graph_compute(
488
  } break;
489
  case GGML_OP_SCALE:
490
  {
491
- if (encoder == nil) {
492
- encoder = [command_buffer computeCommandEncoder];
493
- }
494
-
495
  const float scale = *(const float *) src1->data;
496
 
497
  [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -503,54 +730,46 @@ void ggml_metal_graph_compute(
503
 
504
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
505
  } break;
506
- case GGML_OP_SILU:
507
- {
508
- if (encoder == nil) {
509
- encoder = [command_buffer computeCommandEncoder];
510
- }
511
-
512
- [encoder setComputePipelineState:ctx->pipeline_silu];
513
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
514
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
515
-
516
- const int64_t n = ggml_nelements(dst);
517
-
518
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
519
- } break;
520
- case GGML_OP_RELU:
521
- {
522
- if (encoder == nil) {
523
- encoder = [command_buffer computeCommandEncoder];
524
- }
525
-
526
- [encoder setComputePipelineState:ctx->pipeline_relu];
527
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
528
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
529
-
530
- const int64_t n = ggml_nelements(dst);
531
-
532
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
533
  } break;
534
- case GGML_OP_GELU:
535
- {
536
- if (encoder == nil) {
537
- encoder = [command_buffer computeCommandEncoder];
538
- }
539
-
540
- [encoder setComputePipelineState:ctx->pipeline_gelu];
541
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
542
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
543
-
544
- const int64_t n = ggml_nelements(dst);
545
-
546
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
547
- } break;
548
  case GGML_OP_SOFT_MAX:
549
  {
550
- if (encoder == nil) {
551
- encoder = [command_buffer computeCommandEncoder];
552
- }
553
-
554
  const int nth = 32;
555
 
556
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -565,11 +784,7 @@ void ggml_metal_graph_compute(
565
  } break;
566
  case GGML_OP_DIAG_MASK_INF:
567
  {
568
- if (encoder == nil) {
569
- encoder = [command_buffer computeCommandEncoder];
570
- }
571
-
572
- const int n_past = ((int32_t *)(src1->data))[0];
573
 
574
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
575
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -585,53 +800,44 @@ void ggml_metal_graph_compute(
585
  // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
586
 
587
  GGML_ASSERT(ne00 == ne10);
588
- GGML_ASSERT(ne02 == ne12);
 
 
589
 
 
 
590
  if (ggml_is_contiguous(src0) &&
591
  ggml_is_contiguous(src1) &&
592
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
593
-
594
- if (encoder != nil) {
595
- [encoder endEncoding];
596
- encoder = nil;
597
- }
598
-
599
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
600
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
601
-
602
- // for F32 x F32 we use MPS
603
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
604
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
605
-
606
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
607
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
608
-
609
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
610
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
611
-
612
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
613
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
614
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
615
-
616
- // we need to do ne02 multiplications
617
- // TODO: is there a way to do this in parallel - currently very slow ..
618
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
619
- for (int64_t i02 = 0; i02 < ne02; ++i02) {
620
- size_t offs_src0_cur = offs_src0 + i02*nb02;
621
- size_t offs_src1_cur = offs_src1 + i02*nb12;
622
- size_t offs_dst_cur = offs_dst + i02*nb2;
623
-
624
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
625
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
626
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
627
-
628
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
629
  }
630
  } else {
631
- if (encoder == nil) {
632
- encoder = [command_buffer computeCommandEncoder];
633
- }
634
-
635
  int nth0 = 32;
636
  int nth1 = 1;
637
 
@@ -639,8 +845,6 @@ void ggml_metal_graph_compute(
639
  switch (src0t) {
640
  case GGML_TYPE_F16:
641
  {
642
- GGML_ASSERT(ne02 == ne12);
643
-
644
  nth0 = 64;
645
  nth1 = 1;
646
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -663,13 +867,22 @@ void ggml_metal_graph_compute(
663
  nth1 = 8;
664
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
665
  } break;
666
  case GGML_TYPE_Q2_K:
667
  {
668
  GGML_ASSERT(ne02 == 1);
669
  GGML_ASSERT(ne12 == 1);
670
 
671
- nth0 = 4;
672
- nth1 = 16;
673
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
674
  } break;
675
  case GGML_TYPE_Q3_K:
@@ -677,8 +890,8 @@ void ggml_metal_graph_compute(
677
  GGML_ASSERT(ne02 == 1);
678
  GGML_ASSERT(ne12 == 1);
679
 
680
- nth0 = 4;
681
- nth1 = 16;
682
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
683
  } break;
684
  case GGML_TYPE_Q4_K:
@@ -686,8 +899,8 @@ void ggml_metal_graph_compute(
686
  GGML_ASSERT(ne02 == 1);
687
  GGML_ASSERT(ne12 == 1);
688
 
689
- nth0 = 4;
690
- nth1 = 16;
691
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
692
  } break;
693
  case GGML_TYPE_Q5_K:
@@ -695,8 +908,8 @@ void ggml_metal_graph_compute(
695
  GGML_ASSERT(ne02 == 1);
696
  GGML_ASSERT(ne12 == 1);
697
 
698
- nth0 = 4;
699
- nth1 = 16;
700
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
701
  } break;
702
  case GGML_TYPE_Q6_K:
@@ -704,8 +917,8 @@ void ggml_metal_graph_compute(
704
  GGML_ASSERT(ne02 == 1);
705
  GGML_ASSERT(ne12 == 1);
706
 
707
- nth0 = 4;
708
- nth1 = 16;
709
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
710
  } break;
711
  default:
@@ -720,28 +933,36 @@ void ggml_metal_graph_compute(
720
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
721
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
722
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
723
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
724
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
725
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
726
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
727
- [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
728
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
729
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
730
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
731
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
732
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
733
-
734
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
735
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
736
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
737
  }
738
- else if (src0t == GGML_TYPE_Q2_K ||
739
- src0t == GGML_TYPE_Q3_K ||
740
- src0t == GGML_TYPE_Q4_K ||
741
- src0t == GGML_TYPE_Q5_K ||
742
- src0t == GGML_TYPE_Q6_K) {
743
- [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
744
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745
  } else {
746
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
747
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -750,14 +971,11 @@ void ggml_metal_graph_compute(
750
  } break;
751
  case GGML_OP_GET_ROWS:
752
  {
753
- if (encoder == nil) {
754
- encoder = [command_buffer computeCommandEncoder];
755
- }
756
-
757
  switch (src0->type) {
758
- case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
759
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
760
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
 
761
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
762
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
763
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -779,13 +997,10 @@ void ggml_metal_graph_compute(
779
  } break;
780
  case GGML_OP_RMS_NORM:
781
  {
782
- if (encoder == nil) {
783
- encoder = [command_buffer computeCommandEncoder];
784
- }
785
 
786
- const float eps = 1e-6f;
787
-
788
- const int nth = 256;
789
 
790
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
791
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -793,7 +1008,7 @@ void ggml_metal_graph_compute(
793
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
794
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
795
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
796
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
797
 
798
  const int64_t nrows = ggml_nrows(src0);
799
 
@@ -801,20 +1016,17 @@ void ggml_metal_graph_compute(
801
  } break;
802
  case GGML_OP_NORM:
803
  {
804
- if (encoder == nil) {
805
- encoder = [command_buffer computeCommandEncoder];
806
- }
807
-
808
- const float eps = 1e-5f;
809
 
810
  const int nth = 256;
811
 
812
  [encoder setComputePipelineState:ctx->pipeline_norm];
813
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
814
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
815
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
816
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
817
- [encoder setBytes:&eps length:sizeof( float) atIndex:4];
818
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
819
 
820
  const int64_t nrows = ggml_nrows(src0);
@@ -823,15 +1035,12 @@ void ggml_metal_graph_compute(
823
  } break;
824
  case GGML_OP_ALIBI:
825
  {
826
- if (encoder == nil) {
827
- encoder = [command_buffer computeCommandEncoder];
828
- }
829
-
830
  GGML_ASSERT((src0t == GGML_TYPE_F32));
831
 
832
- const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past);
833
- const int n_head = ((int32_t *) src1->data)[1];
834
- const float max_bias = ((float *) src1->data)[2];
 
835
 
836
  if (__builtin_popcount(n_head) != 1) {
837
  GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -860,51 +1069,53 @@ void ggml_metal_graph_compute(
860
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
861
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
862
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
 
863
  const int nth = 32;
 
864
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
865
  } break;
866
  case GGML_OP_ROPE:
867
  {
868
- if (encoder == nil) {
869
- encoder = [command_buffer computeCommandEncoder];
870
- }
871
-
872
- const int n_dims = ((int32_t *) src1->data)[1];
873
- const int mode = ((int32_t *) src1->data)[2];
874
 
875
- const int n_past = ((int32_t *)(src1->data))[0];
876
 
877
  [encoder setComputePipelineState:ctx->pipeline_rope];
878
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
879
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
880
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
881
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
882
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
883
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
884
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
885
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
886
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
887
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
888
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
889
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
890
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
891
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
892
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
893
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
894
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
895
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
896
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
897
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
898
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
 
 
899
 
900
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
901
  } break;
 
902
  case GGML_OP_CPY:
 
903
  {
904
- if (encoder == nil) {
905
- encoder = [command_buffer computeCommandEncoder];
906
- }
907
-
908
  const int nth = 32;
909
 
910
  switch (src0t) {
@@ -927,30 +1138,32 @@ void ggml_metal_graph_compute(
927
  default: GGML_ASSERT(false && "not implemented");
928
  }
929
 
930
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
931
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
932
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
933
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
934
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
935
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
936
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
937
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
938
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
939
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
940
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
941
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
942
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
943
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
944
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
945
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
946
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
947
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
948
 
949
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
950
  } break;
951
  default:
952
- fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
953
- GGML_ASSERT(false);
 
 
954
  }
955
  }
956
 
@@ -964,17 +1177,19 @@ void ggml_metal_graph_compute(
964
  }
965
 
966
  // wait for all threads to finish
967
- dispatch_barrier_sync(queue, ^{});
968
-
969
- [command_buffers[n_cb - 1] waitUntilCompleted];
970
 
971
  // check status of command buffers
972
  // needed to detect if the device ran out-of-memory for example (#1881)
973
  for (int i = 0; i < n_cb; i++) {
974
- MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
 
 
975
  if (status != MTLCommandBufferStatusCompleted) {
976
  fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
977
  GGML_ASSERT(false);
978
  }
979
  }
 
 
980
  }
 
5
  #import <Foundation/Foundation.h>
6
 
7
  #import <Metal/Metal.h>
8
+
9
+ #undef MIN
10
+ #undef MAX
11
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
12
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
13
 
14
  #ifdef GGML_METAL_NDEBUG
15
  #define metal_printf(...)
 
19
 
20
  #define UNUSED(x) (void)(x)
21
 
22
+ #define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
23
+
24
  struct ggml_metal_buffer {
25
  const char * name;
26
 
 
31
  };
32
 
33
  struct ggml_metal_context {
34
+ int n_cb;
35
 
36
  id<MTLDevice> device;
37
  id<MTLCommandQueue> queue;
38
  id<MTLLibrary> library;
39
 
40
+ id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
41
+ id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
42
+
43
+ dispatch_queue_t d_queue;
44
+
45
  int n_buffers;
46
  struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
47
 
48
+ int concur_list[GGML_MAX_CONCUR];
49
+ int concur_list_len;
50
+
51
  // custom kernels
52
  #define GGML_METAL_DECL_KERNEL(name) \
53
  id<MTLFunction> function_##name; \
54
  id<MTLComputePipelineState> pipeline_##name
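For orientation, a sketch of what the declaration macro above expands to for one kernel (illustrative; the member names follow directly from the macro):

    // GGML_METAL_DECL_KERNEL(add) expands to:
    //     id<MTLFunction>             function_add;
    //     id<MTLComputePipelineState> pipeline_add;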
55
 
56
  GGML_METAL_DECL_KERNEL(add);
57
+ GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
58
  GGML_METAL_DECL_KERNEL(mul);
59
  GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
60
  GGML_METAL_DECL_KERNEL(scale);
 
66
  GGML_METAL_DECL_KERNEL(get_rows_f16);
67
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
68
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
69
+ GGML_METAL_DECL_KERNEL(get_rows_q8_0);
70
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
71
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
72
  GGML_METAL_DECL_KERNEL(get_rows_q4_K);
 
77
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
78
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
79
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
80
+ GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
81
  GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
82
  GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
83
  GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
84
  GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
85
  GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
86
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
87
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
88
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
89
+ GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
90
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
91
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
92
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
93
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
94
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
95
  GGML_METAL_DECL_KERNEL(rope);
96
  GGML_METAL_DECL_KERNEL(alibi_f32);
97
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
 
112
  @implementation GGMLMetalClass
113
  @end
114
 
115
+ struct ggml_metal_context * ggml_metal_init(int n_cb) {
116
  fprintf(stderr, "%s: allocating\n", __func__);
117
 
118
  struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
119
 
120
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
121
  ctx->device = MTLCreateSystemDefaultDevice();
122
  ctx->queue = [ctx->device newCommandQueue];
123
  ctx->n_buffers = 0;
124
+ ctx->concur_list_len = 0;
125
 
126
+ ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
127
 
128
  #if 0
129
  // compile from source string and show compile log
 
133
  ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
134
  if (error) {
135
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
136
+ return NULL;
137
  }
138
  }
139
  #else
 
151
  NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
152
  if (error) {
153
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
154
+ return NULL;
155
  }
156
 
157
  #ifdef GGML_QKK_64
 
163
  #endif
164
  if (error) {
165
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
166
+ return NULL;
167
  }
168
  }
169
  #endif
170
 
171
  // load kernels
172
  {
173
+ NSError * error = nil;
174
  #define GGML_METAL_ADD_KERNEL(name) \
175
  ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
176
+ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
177
+ fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
178
+ (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
179
+ (int) ctx->pipeline_##name.threadExecutionWidth); \
180
+ if (error) { \
181
+ fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
182
+ return NULL; \
183
+ }
184
 
185
  GGML_METAL_ADD_KERNEL(add);
186
+ GGML_METAL_ADD_KERNEL(add_row);
187
  GGML_METAL_ADD_KERNEL(mul);
188
  GGML_METAL_ADD_KERNEL(mul_row);
189
  GGML_METAL_ADD_KERNEL(scale);
 
195
  GGML_METAL_ADD_KERNEL(get_rows_f16);
196
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
197
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
198
+ GGML_METAL_ADD_KERNEL(get_rows_q8_0);
199
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
200
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
201
  GGML_METAL_ADD_KERNEL(get_rows_q4_K);
 
206
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
207
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
208
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
209
+ GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
210
  GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
211
  GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
212
  GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
213
  GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
214
  GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
215
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
216
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
217
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
218
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
219
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
220
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
221
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
222
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
223
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
224
  GGML_METAL_ADD_KERNEL(rope);
225
  GGML_METAL_ADD_KERNEL(alibi_f32);
226
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
 
230
  #undef GGML_METAL_ADD_KERNEL
231
  }
232
 
233
+ fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
234
+ fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
235
  if (ctx->device.maxTransferRate != 0) {
236
+ fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
237
  } else {
238
+ fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
239
  }
240
 
241
  return ctx;
 
243
 
244
  void ggml_metal_free(struct ggml_metal_context * ctx) {
245
  fprintf(stderr, "%s: deallocating\n", __func__);
246
+ #define GGML_METAL_DEL_KERNEL(name) \
247
+ [ctx->function_##name release]; \
248
+ [ctx->pipeline_##name release];
249
+
250
+ GGML_METAL_DEL_KERNEL(add);
251
+ GGML_METAL_DEL_KERNEL(add_row);
252
+ GGML_METAL_DEL_KERNEL(mul);
253
+ GGML_METAL_DEL_KERNEL(mul_row);
254
+ GGML_METAL_DEL_KERNEL(scale);
255
+ GGML_METAL_DEL_KERNEL(silu);
256
+ GGML_METAL_DEL_KERNEL(relu);
257
+ GGML_METAL_DEL_KERNEL(gelu);
258
+ GGML_METAL_DEL_KERNEL(soft_max);
259
+ GGML_METAL_DEL_KERNEL(diag_mask_inf);
260
+ GGML_METAL_DEL_KERNEL(get_rows_f16);
261
+ GGML_METAL_DEL_KERNEL(get_rows_q4_0);
262
+ GGML_METAL_DEL_KERNEL(get_rows_q4_1);
263
+ GGML_METAL_DEL_KERNEL(get_rows_q8_0);
264
+ GGML_METAL_DEL_KERNEL(get_rows_q2_K);
265
+ GGML_METAL_DEL_KERNEL(get_rows_q3_K);
266
+ GGML_METAL_DEL_KERNEL(get_rows_q4_K);
267
+ GGML_METAL_DEL_KERNEL(get_rows_q5_K);
268
+ GGML_METAL_DEL_KERNEL(get_rows_q6_K);
269
+ GGML_METAL_DEL_KERNEL(rms_norm);
270
+ GGML_METAL_DEL_KERNEL(norm);
271
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
272
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
273
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
274
+ GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
275
+ GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
276
+ GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
277
+ GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
278
+ GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
279
+ GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
280
+ GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
281
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
282
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
283
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
284
+ GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
285
+ GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
286
+ GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
287
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
288
+ GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
289
+ GGML_METAL_DEL_KERNEL(rope);
290
+ GGML_METAL_DEL_KERNEL(alibi_f32);
291
+ GGML_METAL_DEL_KERNEL(cpy_f32_f16);
292
+ GGML_METAL_DEL_KERNEL(cpy_f32_f32);
293
+ GGML_METAL_DEL_KERNEL(cpy_f16_f16);
294
+
295
+ #undef GGML_METAL_DEL_KERNEL
296
+
297
  for (int i = 0; i < ctx->n_buffers; ++i) {
298
  [ctx->buffers[i].metal release];
299
  }
300
+
301
+ [ctx->library release];
302
+ [ctx->queue release];
303
+ [ctx->device release];
304
+
305
+ dispatch_release(ctx->d_queue);
306
+
307
  free(ctx);
308
  }
309
 
310
+ void * ggml_metal_host_malloc(size_t n) {
311
+ void * data = NULL;
312
+ const int result = posix_memalign((void **) &data, getpagesize(), n);
313
+ if (result != 0) {
314
+ fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
315
+ return NULL;
316
+ }
317
+
318
+ return data;
319
+ }
320
+
321
+ void ggml_metal_host_free(void * data) {
322
+ free(data);
323
+ }
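A minimal usage sketch of the two new host-memory helpers; the wrapper function name and the 16 MB size are illustrative assumptions, not part of the patch:

    // illustrative only: allocate, use and release a page-aligned host buffer
    static void example_host_buffer(void) {
        void * data = ggml_metal_host_malloc(16u*1024u*1024u); // arbitrary 16 MB, not from the diff
        if (data == NULL) {
            return; // posix_memalign failed inside the helper
        }
        // ... fill the buffer with tensor data, hand it to the backend, etc. ...
        ggml_metal_host_free(data); // plain free() under the hood
    }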
324
+
325
+ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
326
+ ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
327
+ }
328
+
329
+ int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
330
+ return ctx->concur_list_len;
331
+ }
332
+
333
+ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
334
+ return ctx->concur_list;
335
+ }
336
+
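A typical caller sequence for the new API surface (illustrative ordering; all function names appear in this diff, and gf is assumed to be an already-built ggml_cgraph):

    struct ggml_metal_context * ctx = ggml_metal_init(1);  // start with 1 command buffer
    ggml_metal_set_n_cb(ctx, 4);                           // later: split work over 4 command buffers
    ggml_metal_graph_find_concurrency(ctx, gf, true);      // optional concurrency analysis
    if (ggml_metal_if_optimized(ctx)) {
        // a concurrency list was built; graph_compute will dispatch concurrently
    }
    ggml_metal_graph_compute(ctx, gf);
    ggml_metal_free(ctx);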
337
  // finds the Metal buffer that contains the tensor data on the GPU device
338
  // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
339
  // Metal buffer based on the host memory pointer
 
472
  memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
473
  }
474
 
475
+ void ggml_metal_graph_find_concurrency(
476
+ struct ggml_metal_context * ctx,
477
+ struct ggml_cgraph * gf, bool check_mem) {
478
+ int search_depth = gf->n_nodes; // we only look for concurrency within this window to avoid spending too much time here
479
+ int nodes_unused[GGML_MAX_CONCUR];
480
+
481
+ for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
482
+ for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
483
+ ctx->concur_list_len = 0;
484
+
485
+ int n_left = gf->n_nodes;
486
+ int n_start = 0; // all nodes before n_start in the nodes_unused array have already been scheduled into ctx->concur_list
487
+ int level_pos = 0; // index in ctx->concur_list where the current level (layer) starts
488
+
489
+ while (n_left > 0) {
490
+ // number of nodes at a layer (that can be issued concurrently)
491
+ int concurrency = 0;
492
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
493
+ if (nodes_unused[i]) {
494
+ // if the requirements for gf->nodes[i] are satisfied
495
+ int exe_flag = 1;
496
+
497
+ // scan all srcs
498
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
499
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
500
+ if (src_cur) {
501
+ // a leaf node (GGML_OP_NONE with no grad) is always satisfied
502
+ // TODO: ggml_is_leaf()
503
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
504
+ continue;
505
+ }
506
+
507
+ // otherwise this src must be the output of a previously scheduled node
508
+ int is_found = 0;
509
+
510
+ // scan up to 2*search_depth entries back, because barrier entries were inserted
511
+ //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
512
+ for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
513
+ if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
514
+ is_found = 1;
515
+ break;
516
+ }
517
+ }
518
+ if (is_found == 0) {
519
+ exe_flag = 0;
520
+ break;
521
+ }
522
+ }
523
+ }
524
+ if (exe_flag && check_mem) {
525
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
526
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
527
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
528
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
529
+ for (int j = n_start; j < i; j++) {
530
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
531
+ && gf->nodes[j]->op != GGML_OP_VIEW \
532
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
533
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
534
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
535
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
536
+ continue;
537
+ }
538
+
539
+ exe_flag = 0;
540
+ }
541
+ }
542
+ }
543
+ if (exe_flag) {
544
+ ctx->concur_list[level_pos + concurrency] = i;
545
+ nodes_unused[i] = 0;
546
+ concurrency++;
547
+ ctx->concur_list_len++;
548
+ }
549
+ }
550
+ }
551
+ n_left -= concurrency;
552
+ // add a barrier between levels (layers)
553
+ ctx->concur_list[level_pos + concurrency] = -1;
554
+ ctx->concur_list_len++;
555
+ // jump all sorted nodes at nodes_bak
556
+ while (!nodes_unused[n_start]) {
557
+ n_start++;
558
+ }
559
+ level_pos += concurrency + 1;
560
+ }
561
+
562
+ if (ctx->concur_list_len > GGML_MAX_CONCUR) {
563
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
564
+ }
565
+ }
566
+
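The list built above is consumed in ggml_metal_graph_compute below: entries are node indices grouped into levels, and a -1 entry marks a barrier between levels. A minimal consumer sketch mirroring the loop further down in this diff:

    // walk the concurrency list level by level
    for (int ind = 0; ind < ctx->concur_list_len; ++ind) {
        const int i = ctx->concur_list[ind];
        if (i == -1) {
            // level boundary: the previous level's writes must be visible first
            // (the Metal path below issues memoryBarrierWithScope:MTLBarrierScopeBuffers)
            continue;
        }
        // encode gf->nodes[i]
    }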
567
  void ggml_metal_graph_compute(
568
  struct ggml_metal_context * ctx,
569
  struct ggml_cgraph * gf) {
570
  metal_printf("%s: evaluating graph\n", __func__);
571
 
572
+ @autoreleasepool {
573
+
574
+ // if there is ctx->concur_list, dispatch concurrently
575
+ // else fallback to serial dispatch
576
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
577
+
578
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
579
+
580
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
581
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
582
+
583
  // create multiple command buffers and enqueue them
584
  // then, we encode the graph into the command buffers in parallel
585
 
586
+ const int n_cb = ctx->n_cb;
 
 
587
 
588
  for (int i = 0; i < n_cb; ++i) {
589
+ ctx->command_buffers[i] = [ctx->queue commandBuffer];
590
 
591
  // enqueue the command buffers in order to specify their execution order
592
+ [ctx->command_buffers[i] enqueue];
 
593
 
594
+ ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
595
+ }
596
 
597
  for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
598
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
599
 
600
+ dispatch_async(ctx->d_queue, ^{
601
  size_t offs_src0 = 0;
602
  size_t offs_src1 = 0;
603
  size_t offs_dst = 0;
604
 
605
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
606
+ id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
 
607
 
608
  const int node_start = (cb_idx + 0) * n_nodes_per_cb;
609
+ const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
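+ // worked example of the split above (illustrative numbers): with n_nodes = 100 and n_cb = 8,
+ // n_nodes_per_cb = ceil(100/8) = 13, so cb 0 encodes nodes [0,13), cb 1 encodes [13,26), ...
+ // and the last cb encodes the remaining [91,100)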
610
+
611
+ for (int ind = node_start; ind < node_end; ++ind) {
612
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
613
+
614
+ if (i == -1) {
615
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
616
+ continue;
617
+ }
618
 
 
619
  metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
620
 
621
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
622
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
623
  struct ggml_tensor * dst = gf->nodes[i];
624
 
625
  const int64_t ne00 = src0 ? src0->ne[0] : 0;
 
675
  //}
676
 
677
  switch (dst->op) {
678
+ case GGML_OP_NONE:
679
  case GGML_OP_RESHAPE:
680
  case GGML_OP_VIEW:
681
  case GGML_OP_TRANSPOSE:
 
685
  } break;
686
  case GGML_OP_ADD:
687
  {
688
+ if (ggml_nelements(src1) == ne10) {
689
+ // src1 is a row
690
+ [encoder setComputePipelineState:ctx->pipeline_add_row];
691
+ } else {
692
+ [encoder setComputePipelineState:ctx->pipeline_add];
693
  }
 
 
694
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
695
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
696
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
697
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
698
 
699
  const int64_t n = ggml_nelements(dst);
700
 
 
702
  } break;
703
  case GGML_OP_MUL:
704
  {
 
705
  if (ggml_nelements(src1) == ne10) {
706
  // src1 is a row
707
  [encoder setComputePipelineState:ctx->pipeline_mul_row];
 
719
  } break;
720
  case GGML_OP_SCALE:
721
  {
722
  const float scale = *(const float *) src1->data;
723
 
724
  [encoder setComputePipelineState:ctx->pipeline_scale];
 
730
 
731
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
732
  } break;
733
+ case GGML_OP_UNARY:
734
+ switch (ggml_get_unary_op(gf->nodes[i])) {
735
+ case GGML_UNARY_OP_SILU:
736
+ {
737
+ [encoder setComputePipelineState:ctx->pipeline_silu];
738
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
739
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
740
+
741
+ const int64_t n = ggml_nelements(dst);
742
+
743
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
744
+ } break;
745
+ case GGML_UNARY_OP_RELU:
746
+ {
747
+ [encoder setComputePipelineState:ctx->pipeline_relu];
748
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
749
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
750
+
751
+ const int64_t n = ggml_nelements(dst);
752
+
753
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
754
+ } break;
755
+ case GGML_UNARY_OP_GELU:
756
+ {
757
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
758
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
759
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
760
+
761
+ const int64_t n = ggml_nelements(dst);
762
+
763
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
764
+ } break;
765
+ default:
766
+ {
767
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
768
+ GGML_ASSERT(false);
769
+ }
770
  } break;
771
  case GGML_OP_SOFT_MAX:
772
  {
773
  const int nth = 32;
774
 
775
  [encoder setComputePipelineState:ctx->pipeline_soft_max];
 
784
  } break;
785
  case GGML_OP_DIAG_MASK_INF:
786
  {
787
+ const int n_past = ((int32_t *)(dst->op_params))[0];
788
 
789
  [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
790
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
800
  // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
801
 
802
  GGML_ASSERT(ne00 == ne10);
803
+ // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
804
+ uint gqa = ne12/ne02;
805
+ GGML_ASSERT(ne03 == ne13);
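+ // gqa is the broadcast ratio of src1 over src0 along dim 2, passed to the kernels below;
+ // e.g. (hypothetical numbers) 32 query heads sharing 8 KV heads gives gqa = 32/8 = 4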
806
 
807
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
808
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
809
  if (ggml_is_contiguous(src0) &&
810
  ggml_is_contiguous(src1) &&
811
+ src1t == GGML_TYPE_F32 &&
812
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
813
+ ne00%32 == 0 &&
814
+ ne11 > 1) {
815
+ switch (src0->type) {
816
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
817
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
818
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
819
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
820
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
821
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
822
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
823
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
824
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
825
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
826
  }
827
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
828
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
829
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
830
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
831
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
832
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
833
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
834
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
835
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
836
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
837
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
838
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
839
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
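+ // worked example of the grid above (illustrative sizes): ne01 = 4096, ne11 = 512, ne12 = 1
+ // gives ceil(512/32) x ceil(4096/64) x 1 = 16 x 64 x 1 threadgroups of 128 threads each,
+ // with the 8192 bytes of threadgroup memory reserved just above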
840
  } else {
841
  int nth0 = 32;
842
  int nth1 = 1;
843
 
 
845
  switch (src0t) {
846
  case GGML_TYPE_F16:
847
  {
 
 
848
  nth0 = 64;
849
  nth1 = 1;
850
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
 
867
  nth1 = 8;
868
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
869
  } break;
870
+ case GGML_TYPE_Q8_0:
871
+ {
872
+ GGML_ASSERT(ne02 == 1);
873
+ GGML_ASSERT(ne12 == 1);
874
+
875
+ nth0 = 8;
876
+ nth1 = 8;
877
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
878
+ } break;
879
  case GGML_TYPE_Q2_K:
880
  {
881
  GGML_ASSERT(ne02 == 1);
882
  GGML_ASSERT(ne12 == 1);
883
 
884
+ nth0 = 2;
885
+ nth1 = 32;
886
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
887
  } break;
888
  case GGML_TYPE_Q3_K:
 
890
  GGML_ASSERT(ne02 == 1);
891
  GGML_ASSERT(ne12 == 1);
892
 
893
+ nth0 = 2;
894
+ nth1 = 32;
895
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
896
  } break;
897
  case GGML_TYPE_Q4_K:
 
899
  GGML_ASSERT(ne02 == 1);
900
  GGML_ASSERT(ne12 == 1);
901
 
902
+ nth0 = 2;
903
+ nth1 = 32;
904
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
905
  } break;
906
  case GGML_TYPE_Q5_K:
 
908
  GGML_ASSERT(ne02 == 1);
909
  GGML_ASSERT(ne12 == 1);
910
 
911
+ nth0 = 2;
912
+ nth1 = 32;
913
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
914
  } break;
915
  case GGML_TYPE_Q6_K:
 
917
  GGML_ASSERT(ne02 == 1);
918
  GGML_ASSERT(ne12 == 1);
919
 
920
+ nth0 = 2;
921
+ nth1 = 32;
922
  [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
923
  } break;
924
  default:
 
933
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
934
  [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
935
  [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
936
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
937
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
938
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
939
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
940
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
941
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
942
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
943
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
944
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
945
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
946
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
947
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
948
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
949
+
950
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
951
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
952
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
953
+ }
954
+ else if (src0t == GGML_TYPE_Q3_K) {
955
+ #ifdef GGML_QKK_64
956
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
957
+ #else
958
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
959
+ #endif
960
  }
961
+ else if (src0t == GGML_TYPE_Q5_K) {
962
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
963
+ }
964
+ else if (src0t == GGML_TYPE_Q6_K) {
965
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 
 
966
  } else {
967
  [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
968
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 
971
  } break;
972
  case GGML_OP_GET_ROWS:
973
  {
974
  switch (src0->type) {
975
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
976
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
977
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
978
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
979
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
980
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
981
  case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
 
997
  } break;
998
  case GGML_OP_RMS_NORM:
999
  {
1000
+ float eps;
1001
+ memcpy(&eps, dst->op_params, sizeof(float));
 
1002
 
1003
+ const int nth = 512;
 
 
1004
 
1005
  [encoder setComputePipelineState:ctx->pipeline_rms_norm];
1006
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
 
1008
  [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1009
  [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1010
  [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1011
+ [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
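+ // with nth = 512 threads per row this reserves 512/32 = 16 floats (64 bytes),
+ // i.e. one partial sum per 32-wide SIMD group; that the kernel reduces within
+ // SIMD groups first is an inference from this sizing, not shown in the diff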
1012
 
1013
  const int64_t nrows = ggml_nrows(src0);
1014
 
 
1016
  } break;
1017
  case GGML_OP_NORM:
1018
  {
1019
+ float eps;
1020
+ memcpy(&eps, dst->op_params, sizeof(float));
1021
 
1022
  const int nth = 256;
1023
 
1024
  [encoder setComputePipelineState:ctx->pipeline_norm];
1025
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1026
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1027
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1028
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
1029
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
1030
  [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
1031
 
1032
  const int64_t nrows = ggml_nrows(src0);
 
1035
  } break;
1036
  case GGML_OP_ALIBI:
1037
  {
1038
  GGML_ASSERT((src0t == GGML_TYPE_F32));
1039
 
1040
+ const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
1041
+ const int n_head = ((int32_t *) dst->op_params)[1];
1042
+ float max_bias;
1043
+ memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
1044
 
1045
  if (__builtin_popcount(n_head) != 1) {
1046
  GGML_ASSERT(false && "only power-of-two n_head implemented");
 
1069
  [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1070
  [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1071
  [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1072
+
1073
  const int nth = 32;
1074
+
1075
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1076
  } break;
1077
  case GGML_OP_ROPE:
1078
  {
1079
+ const int n_past = ((int32_t *) dst->op_params)[0];
1080
+ const int n_dims = ((int32_t *) dst->op_params)[1];
1081
+ const int mode = ((int32_t *) dst->op_params)[2];
1082
 
1083
+ float freq_base;
1084
+ float freq_scale;
1085
+ memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
1086
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
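+ // per the reads above, the ROPE op_params int32 words are laid out as:
+ //   [0] n_past, [1] n_dims, [2] mode, [4] freq_base (as float), [5] freq_scale (as float)
+ //   (word [3] is not read here)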
1087
 
1088
  [encoder setComputePipelineState:ctx->pipeline_rope];
1089
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1090
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1091
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1092
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1093
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1094
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1095
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1096
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1097
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1098
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1099
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1100
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1101
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1102
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1103
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1104
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1105
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1106
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1107
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
1108
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
1109
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
1110
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
1111
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
1112
 
1113
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1114
  } break;
1115
+ case GGML_OP_DUP:
1116
  case GGML_OP_CPY:
1117
+ case GGML_OP_CONT:
1118
  {
1119
  const int nth = 32;
1120
 
1121
  switch (src0t) {
 
1138
  default: GGML_ASSERT(false && "not implemented");
1139
  }
1140
 
1141
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1142
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1143
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
1144
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
1145
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
1146
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
1147
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
1148
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
1149
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
1150
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
1151
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
1152
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
1153
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
1154
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
1155
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
1156
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
1157
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
1158
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1159
 
1160
  [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
1161
  } break;
1162
  default:
1163
+ {
1164
+ fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
1165
+ GGML_ASSERT(false);
1166
+ }
1167
  }
1168
  }
1169
 
 
1177
  }
1178
 
1179
  // wait for all threads to finish
1180
+ dispatch_barrier_sync(ctx->d_queue, ^{});
 
 
1181
 
1182
  // check status of command buffers
1183
  // needed to detect if the device ran out-of-memory for example (#1881)
1184
  for (int i = 0; i < n_cb; i++) {
1185
+ [ctx->command_buffers[i] waitUntilCompleted];
1186
+
1187
+ MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
1188
  if (status != MTLCommandBufferStatusCompleted) {
1189
  fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
1190
  GGML_ASSERT(false);
1191
  }
1192
  }
1193
+
1194
+ }
1195
  }
ggml-metal.metal CHANGED
@@ -18,46 +18,11 @@ typedef struct {
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
  } block_q4_1;
20
 
21
- static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
22
- const int qk = QK4_0;
23
-
24
- assert(k % qk == 0);
25
-
26
- const int nb = k / qk;
27
-
28
- for (int i = 0; i < nb; i++) {
29
- const half d = x[i].d;
30
-
31
- for (int j = 0; j < qk/2; ++j) {
32
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
33
- const int x1 = (x[i].qs[j] >> 4) - 8;
34
-
35
- y[i*qk + j + 0 ] = x0*d;
36
- y[i*qk + j + qk/2] = x1*d;
37
- }
38
- }
39
- }
40
-
41
- static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
42
- const int qk = QK4_1;
43
-
44
- assert(k % qk == 0);
45
-
46
- const int nb = k / qk;
47
-
48
- for (int i = 0; i < nb; i++) {
49
- const half d = x[i].d;
50
- const half m = x[i].m;
51
-
52
- for (int j = 0; j < qk/2; ++j) {
53
- const int x0 = (x[i].qs[j] & 0x0F);
54
- const int x1 = (x[i].qs[j] >> 4);
55
-
56
- y[i*qk + j + 0 ] = x0*d + m;
57
- y[i*qk + j + qk/2] = x1*d + m;
58
- }
59
- }
60
- }
61
 
62
  kernel void kernel_add(
63
  device const float * src0,
@@ -67,6 +32,17 @@ kernel void kernel_add(
67
  dst[tpig] = src0[tpig] + src1[tpig];
68
  }
69
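The kernel list earlier in this diff also registers a row-broadcast variant, kernel_add_row, whose body is not visible in this rendering. A minimal sketch of such a kernel, assuming src1 holds a single row of ne00 floats (hypothetical, not the upstream body):

    // hypothetical sketch: broadcast-add a row (src1) onto every row of src0
    kernel void kernel_add_row(
            device const float * src0,
            device const float * src1,
            device       float * dst,
            constant   int64_t & ne00,
            uint tpig[[thread_position_in_grid]]) {
        dst[tpig] = src0[tpig] + src1[tpig % ne00];
    }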
70
  kernel void kernel_mul(
71
  device const float * src0,
72
  device const float * src1,
@@ -117,7 +93,12 @@ kernel void kernel_gelu(
117
  device float * dst,
118
  uint tpig[[thread_position_in_grid]]) {
119
  float x = src0[tpig];
120
- dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
121
  }
122
 
123
  kernel void kernel_soft_max(
@@ -208,54 +189,6 @@ kernel void kernel_diag_mask_inf(
  }
  }

- kernel void kernel_get_rows_f16(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- for (int j = 0; j < ne00; j++) {
- dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
- }
- }
-
- kernel void kernel_get_rows_q4_0(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_0(
- (device const block_q4_0 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
- }
-
- kernel void kernel_get_rows_q4_1(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_1(
- (device const block_q4_1 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
- }
-
  kernel void kernel_norm(
  device const void * src0,
  device float * dst,
@@ -331,26 +264,33 @@ kernel void kernel_rms_norm(
  threadgroup float * sum [[threadgroup(0)]],
  uint tgpig[[threadgroup_position_in_grid]],
  uint tpitg[[thread_position_in_threadgroup]],
  uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

  // parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00] * x[i00];
  }

- // reduce
  threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
  }
-
- // broadcast
  if (tpitg == 0) {
  sum[0] /= ne00;
  }

@@ -359,146 +299,201 @@ kernel void kernel_rms_norm(
359
  const float mean = sum[0];
360
  const float scale = 1.0f/sqrt(mean + eps);
361
 
362
- device float * y = dst + tgpig*ne00;
363
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
 
364
  y[i00] = x[i00] * scale;
365
  }
 
 
 
366
  }
367
 
368
- kernel void kernel_mul_mat_q4_0_f32(
369
- device const void * src0,
370
- device const float * src1,
371
- device float * dst,
372
- constant int64_t & ne00,
373
- constant int64_t & ne10,
374
- constant int64_t & ne0,
375
- threadgroup float * sum [[threadgroup(0)]],
376
- uint2 tgpig[[threadgroup_position_in_grid]],
377
- uint2 tpitg[[thread_position_in_threadgroup]],
378
- uint2 tptg[[threads_per_threadgroup]]) {
379
- const int nb = ne00/QK4_0;
380
-
381
- const int64_t r0 = tgpig.x;
382
- const int64_t r1 = tgpig.y;
383
-
384
- device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
385
- device const float * y = (device const float *) src1 + r1*ne10;
386
-
387
- const int nth = tptg.x*tptg.y;
388
- const int ith = tptg.y*tpitg.x + tpitg.y;
389
-
390
- const int ix = tpitg.y/4; // 0 or 1
391
- const int iy = tpitg.y - 4*ix; // 0...3
392
-
393
- const int first = 4 * iy;
394
-
395
- float sumf = 0;
396
-
397
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
398
-
399
- const float d = (float)x[i].d;
400
-
401
- device const uint8_t * xl = x[i].qs + first;
402
- device const float * yl = y + i * QK4_0 + first;
403
-
404
- float2 acc = {0.0f, 0.0f};
405
 
406
- for (int j = 0; j < 4; ++j) {
 
407
 
408
- acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
409
- acc[1] += yl[j] + yl[j+16];
 
410
 
 
 
411
  }
412
 
413
- sumf += d * (acc[0] - 8.f*acc[1]);
414
  }
415
 
416
- sum[ith] = sumf;
417
-
418
- //
419
- // Accumulate the sum from all threads in the threadgroup
420
- //
421
- threadgroup_barrier(mem_flags::mem_threadgroup);
422
- if (ith%4 == 0) {
423
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
424
- }
425
- threadgroup_barrier(mem_flags::mem_threadgroup);
426
- if (ith%16 == 0) {
427
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
428
- }
429
- threadgroup_barrier(mem_flags::mem_threadgroup);
430
- if (ith == 0) {
431
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
- dst[r1*ne0 + r0] = sum[0];
433
  }
434
  }
435
 
436
- kernel void kernel_mul_mat_q4_1_f32(
437
  device const void * src0,
438
  device const float * src1,
439
  device float * dst,
440
  constant int64_t & ne00,
441
- constant int64_t & ne10,
442
- constant int64_t & ne0,
443
- threadgroup float * sum [[threadgroup(0)]],
444
- uint2 tgpig[[threadgroup_position_in_grid]],
445
- uint2 tpitg[[thread_position_in_threadgroup]],
446
- uint2 tptg[[threads_per_threadgroup]]) {
447
- const int nb = ne00/QK4_1;
448
-
449
- const int64_t r0 = tgpig.x;
450
- const int64_t r1 = tgpig.y;
451
-
452
- device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
453
- device const float * y = (device const float *) src1 + r1*ne10;
454
-
455
- const uint nth = tptg.x*tptg.y;
456
- const uint ith = tptg.y*tpitg.x + tpitg.y;
457
-
458
- const int ix = tpitg.y/4; // 0 or 1
459
- const int iy = tpitg.y - 4*ix; // 0...3
460
-
461
- const int first = 4 * iy;
462
-
463
- float sumf = 0;
464
-
465
- for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
466
-
467
- const float d = (float)x[i].d;
468
- const float m = (float)x[i].m;
469
-
470
- device const uint8_t * xl = x[i].qs + first;
471
- device const float * yl = y + i * QK4_1 + first;
472
-
473
- float2 acc = {0.0f, 0.0f};
474
 
475
- for (int j = 0; j < 4; ++j) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
- acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
478
- acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
 
479
 
 
 
 
 
 
 
 
480
  }
481
 
482
- sumf += acc[0] + acc[1];
483
  }
484
 
485
- sum[ith] = sumf;
486
-
487
- //
488
- // Accumulate the sum from all threads in the threadgroup
489
- //
490
- threadgroup_barrier(mem_flags::mem_threadgroup);
491
- if (ith%4 == 0) {
492
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
493
- }
494
- threadgroup_barrier(mem_flags::mem_threadgroup);
495
- if (ith%16 == 0) {
496
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
497
- }
498
- threadgroup_barrier(mem_flags::mem_threadgroup);
499
- if (ith == 0) {
500
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
- dst[r1*ne0 + r0] = sum[0];
502
  }
503
  }
504
 
@@ -508,11 +503,13 @@ kernel void kernel_mul_mat_f16_f32(
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
  constant uint64_t & nb00,
  constant uint64_t & nb01,
  constant uint64_t & nb02,
  constant int64_t & ne10,
  constant int64_t & ne11,
  constant uint64_t & nb10,
  constant uint64_t & nb11,
  constant uint64_t & nb12,
@@ -528,7 +525,7 @@ kernel void kernel_mul_mat_f16_f32(
  const int64_t r1 = tgpig.y;
  const int64_t im = tgpig.z;

- device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

  sum[tpitg.x] = 0.0f;
@@ -615,17 +612,19 @@ kernel void kernel_rope(
  constant int & n_past,
  constant int & n_dims,
  constant int & mode,
  uint3 tpig[[thread_position_in_grid]]) {
  const int64_t i3 = tpig[2];
  const int64_t i2 = tpig[1];
  const int64_t i1 = tpig[0];

  const bool is_neox = mode & 2;
- const float theta_scale = pow(10000.0, -2.0f/n_dims);

  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

- float theta = (float)p;

  if (!is_neox) {
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -644,7 +643,25 @@ kernel void kernel_rope(
  dst_data[1] = x0*sin_theta + x1*cos_theta;
  }
  } else {
- // TODO: implement
  }
  }

@@ -863,779 +880,581 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
863
  return r;
864
  }
865
 
866
- //========================================== dequantization =============================
867
 
868
- static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
869
- assert(k % QK_K == 0);
870
- const int nb = k / QK_K;
 
 
 
 
 
 
 
 
 
 
 
 
871
 
872
- for (int i = 0; i < nb; i++) {
 
 
 
873
 
874
- const float d = x[i].d;
875
- const float min = x[i].dmin;
 
 
 
 
 
876
 
877
- device const uint8_t * q = x[i].qs;
878
 
879
  #if QK_K == 256
880
- int is = 0;
881
- float dl, ml;
882
- for (int n = 0; n < QK_K; n += 128) {
883
- int shift = 0;
884
- for (int j = 0; j < 4; ++j) {
885
-
886
- uint8_t sc = x[i].scales[is++];
887
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
888
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
889
-
890
- sc = x[i].scales[is++];
891
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
892
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
 
 
 
 
893
 
894
- shift += 2;
 
895
  }
896
- q += 32;
 
 
 
 
 
 
 
 
 
 
897
  }
 
 
 
898
  #else
899
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
- for (int l = 0; l < 16; ++l) {
904
- y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
- y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
- y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
- y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
 
908
  }
909
- y += QK_K;
 
 
910
  #endif
911
 
 
 
 
 
 
912
  }
913
  }
914
 
915
- static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
916
- assert(k % QK_K == 0);
917
- const int nb = k / QK_K;
918
-
919
  #if QK_K == 256
 
920
 
921
  const uint16_t kmask1 = 0x0303;
922
  const uint16_t kmask2 = 0x0f0f;
923
 
924
- uint16_t aux[8];
925
- thread const int8_t * scales = (thread const int8_t*)aux;
926
-
927
- for (int i = 0; i < nb; i++) {
 
 
 
928
 
929
- const float d_all = (float)(x[i].d);
 
930
 
931
- device const uint8_t * q = x[i].qs;
932
- device const uint8_t * h = x[i].hmask;
933
- uint8_t m = 1;
934
-
935
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
936
- aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
937
- aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
938
- aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
939
- aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
940
- aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
941
- aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
942
- aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
943
- aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
944
-
945
- int is = 0;
946
- float dl;
947
- for (int n = 0; n < QK_K; n += 128) {
948
- int shift = 0;
949
- for (int j = 0; j < 4; ++j) {
950
-
951
- dl = d_all * (scales[is++] - 32);
952
- for (int l = 0; l < 16; ++l) {
953
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
954
- }
955
 
956
- dl = d_all * (scales[is++] - 32);
957
- for (int l = 0; l < 16; ++l) {
958
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
959
- }
960
 
961
- shift += 2;
962
- m <<= 1;
963
- }
964
- q += 32;
965
- }
966
- }
967
- #else
968
- for (int i = 0; i < nb; i++) {
969
 
970
- const float d_all = (float)(x[i].d);
971
 
972
- device const uint8_t * q = x[i].qs;
973
- device const uint8_t * hm = x[i].hmask;
974
 
975
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
 
980
  for (int l = 0; l < 8; ++l) {
981
- uint8_t h = hm[l];
982
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
  }
991
- y += QK_K;
992
- }
993
- #endif
994
 
995
- }
 
 
 
996
 
997
- static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
998
- assert(k % QK_K == 0);
999
- const int nb = k / QK_K;
1000
 
1001
- for (int i = 0; i < nb; i++) {
 
1002
 
1003
- device const uint8_t * q = x[i].qs;
 
1004
 
1005
- #if QK_K == 256
1006
- const float d = x[i].d;
1007
- const float min = x[i].dmin;
1008
-
1009
- device const uint8_t * scales = x[i].scales;
1010
-
1011
- int is = 0;
1012
- for (int j = 0; j < QK_K; j += 64) {
1013
- const uchar4 sc = get_scale_min_k4(is, scales);
1014
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1015
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1016
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
1017
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1018
- q += 32; is += 2;
1019
- }
1020
- #else
1021
- device const uint8_t * s = x[i].scales;
1022
- device const half2 * dh = (device const half2 *)x[i].d;
1023
- const float2 d = (float2)dh[0];
1024
- const float d1 = d[0] * (s[0] & 0xF);
1025
- const float d2 = d[0] * (s[1] & 0xF);
1026
- const float m1 = d[1] * (s[0] >> 4);
1027
- const float m2 = d[1] * (s[1] >> 4);
1028
- for (int l = 0; l < 32; ++l) {
1029
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
- y[l+32] = d2 * (q[l] >> 4) - m2;
1031
  }
1032
- y += QK_K;
1033
- #endif
1034
 
1035
- }
1036
- }
1037
 
1038
- static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
1039
- assert(k % QK_K == 0);
1040
- const int nb = k / QK_K;
1041
 
1042
- #if QK_K == 256
1043
- for (int i = 0; i < nb; i++) {
1044
-
1045
- const float d = (float)(x[i].d);
1046
- const float min = (float)(x[i].dmin);
1047
-
1048
- device const uint8_t * ql = x[i].qs;
1049
- device const uint8_t * qh = x[i].qh;
1050
-
1051
- int is = 0;
1052
- uint8_t u1 = 1, u2 = 2;
1053
- for (int j = 0; j < QK_K; j += 64) {
1054
- const uchar4 sc = get_scale_min_k4(is, x[i].scales);
1055
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
1056
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
1057
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
1058
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
1059
- ql += 32; is += 2;
1060
- u1 <<= 2; u2 <<= 2;
1061
  }
1062
  }
 
1063
  #else
1064
- for (int i = 0; i < nb; i++) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1065
 
1066
- const float d = (float)x[i].d;
1067
 
1068
- device const uint8_t * ql = x[i].qs;
1069
- device const uint8_t * qh = x[i].qh;
1070
- device const int8_t * sc = x[i].scales;
 
 
 
 
 
 
 
 
 
1071
 
1072
- for (int l = 0; l < 8; ++l) {
1073
- y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
- y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
- y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
- y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
- y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
- y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
- y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
- y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
- }
1082
- y += QK_K;
1083
- }
1084
- #endif
1085
 
1086
- }
1087
 
1088
- static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
1089
- assert(k % QK_K == 0);
1090
- const int nb = k / QK_K;
1091
 
1092
- for (int i = 0; i < nb; i++) {
1093
-
1094
- device const uint8_t * ql = x[i].ql;
1095
- device const uint8_t * qh = x[i].qh;
1096
- device const int8_t * sc = x[i].scales;
1097
-
1098
- const float d = x[i].d;
1099
-
1100
- #if QK_K == 256
1101
- for (int n = 0; n < QK_K; n += 128) {
1102
- for (int l = 0; l < 32; ++l) {
1103
- int is = l/16;
1104
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1105
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1106
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1107
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1108
- y[l + 0] = d * sc[is + 0] * q1;
1109
- y[l + 32] = d * sc[is + 2] * q2;
1110
- y[l + 64] = d * sc[is + 4] * q3;
1111
- y[l + 96] = d * sc[is + 6] * q4;
1112
- }
1113
- y += 128;
1114
- ql += 64;
1115
- qh += 32;
1116
- sc += 8;
1117
  }
1118
- #else
1119
- for (int l = 0; l < 16; ++l) {
1120
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
- y[l+ 0] = d * sc[0] * q1;
1125
- y[l+16] = d * sc[1] * q2;
1126
- y[l+32] = d * sc[2] * q3;
1127
- y[l+48] = d * sc[3] * q4;
1128
- }
1129
- y += 64;
1130
- #endif
1131
- }
1132
- }
1133
-
1134
- kernel void kernel_get_rows_q2_K(
1135
- device const void * src0,
1136
- device const int * src1,
1137
- device float * dst,
1138
- constant int64_t & ne00,
1139
- constant uint64_t & nb01,
1140
- constant uint64_t & nb1,
1141
- uint tpig[[thread_position_in_grid]]) {
1142
- const int i = tpig;
1143
- const int r = ((device int32_t *) src1)[i];
1144
-
1145
- dequantize_row_q2_K(
1146
- (device const block_q2_K *) ((device char *) src0 + r*nb01),
1147
- (device float *) ((device char *) dst + i*nb1), ne00);
1148
- }
1149
-
1150
- kernel void kernel_get_rows_q3_K(
1151
- device const void * src0,
1152
- device const int * src1,
1153
- device float * dst,
1154
- constant int64_t & ne00,
1155
- constant uint64_t & nb01,
1156
- constant uint64_t & nb1,
1157
- uint tpig[[thread_position_in_grid]]) {
1158
- const int i = tpig;
1159
- const int r = ((device int32_t *) src1)[i];
1160
-
1161
- dequantize_row_q3_K(
1162
- (device const block_q3_K *) ((device char *) src0 + r*nb01),
1163
- (device float *) ((device char *) dst + i*nb1), ne00);
1164
- }
1165
-
1166
- kernel void kernel_get_rows_q4_K(
1167
- device const void * src0,
1168
- device const int * src1,
1169
- device float * dst,
1170
- constant int64_t & ne00,
1171
- constant uint64_t & nb01,
1172
- constant uint64_t & nb1,
1173
- uint tpig[[thread_position_in_grid]]) {
1174
- const int i = tpig;
1175
- const int r = ((device int32_t *) src1)[i];
1176
-
1177
- dequantize_row_q4_K(
1178
- (device const block_q4_K *) ((device char *) src0 + r*nb01),
1179
- (device float *) ((device char *) dst + i*nb1), ne00);
1180
- }
1181
-
1182
- kernel void kernel_get_rows_q5_K(
1183
- device const void * src0,
1184
- device const int * src1,
1185
- device float * dst,
1186
- constant int64_t & ne00,
1187
- constant uint64_t & nb01,
1188
- constant uint64_t & nb1,
1189
- uint tpig[[thread_position_in_grid]]) {
1190
- const int i = tpig;
1191
- const int r = ((device int32_t *) src1)[i];
1192
-
1193
- dequantize_row_q5_K(
1194
- (device const block_q5_K *) ((device char *) src0 + r*nb01),
1195
- (device float *) ((device char *) dst + i*nb1), ne00);
1196
- }
1197
-
1198
- kernel void kernel_get_rows_q6_K(
1199
- device const void * src0,
1200
- device const int * src1,
1201
- device float * dst,
1202
- constant int64_t & ne00,
1203
- constant uint64_t & nb01,
1204
- constant uint64_t & nb1,
1205
- uint tpig[[thread_position_in_grid]]) {
1206
- const int i = tpig;
1207
- const int r = ((device int32_t *) src1)[i];
1208
-
1209
- dequantize_row_q6_K(
1210
- (device const block_q6_K *) ((device char *) src0 + r*nb01),
1211
- (device float *) ((device char *) dst + i*nb1), ne00);
1212
- }
1213
-
1214
- //====================================== dot products =========================
1215
-
1216
- kernel void kernel_mul_mat_q2_K_f32(
1217
- device const void * src0,
1218
- device const float * src1,
1219
- device float * dst,
1220
- constant int64_t & ne00,
1221
- constant int64_t & ne10,
1222
- constant int64_t & ne0,
1223
- threadgroup float * sum [[threadgroup(0)]],
1224
- uint2 tgpig[[threadgroup_position_in_grid]],
1225
- uint2 tpitg[[thread_position_in_threadgroup]],
1226
- uint2 tptg[[threads_per_threadgroup]]) {
1227
-
1228
- const int nb = ne00/QK_K;
1229
-
1230
- const int64_t r0 = tgpig.x;
1231
- const int64_t r1 = tgpig.y;
1232
-
1233
- device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1234
- device const float * yy = (device const float *) src1 + r1*ne10;
1235
-
1236
- const int nth = tptg.x*tptg.y;
1237
- const int ith = tptg.y*tpitg.x + tpitg.y;
1238
-
1239
- float sumf = 0;
1240
-
1241
- #if QK_K == 256
1242
- const int tid = tpitg.y; // 0...16
1243
- const int il = tid/4; // 0...3
1244
- const int ir = tid%4; // 0...3
1245
- const int ip = il/2; // 0 or 1
1246
- const int shift1 = 4*(il%2);// 0 or 4
1247
- const int shift2 = shift1+2;// 2 or 6
1248
- const int n = 8;
1249
- const int is = 4*il + (n*ir)/16;
1250
-
1251
- const int y_offset = 64*il + n*ir;
1252
- const int q_offset = 32*ip + n*ir;
1253
-
1254
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1255
-
1256
- device const uint8_t * q = x[i].qs + q_offset;
1257
- device const uint8_t * scales = x[i].scales + is;
1258
-
1259
- uint8_t d1 = scales[0] & 0xF;
1260
- uint8_t d2 = scales[2] & 0xF;
1261
- uint8_t m1 = scales[0] >> 4;
1262
- uint8_t m2 = scales[2] >> 4;
1263
-
1264
- device const float * y = yy + i*QK_K + y_offset;
1265
-
1266
- float2 s = {0.f, 0.f};
1267
- float smin = 0;
1268
- for (int l = 0; l < n; ++l) {
1269
- s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
1270
- s[1] += y[l+32] * ((q[l] >> shift2) & 3);
1271
- smin += y[l+ 0] * m1 + y[l+32] * m2;
1272
- }
1273
-
1274
- const float dall = (float)x[i].d;
1275
- const float dmin = (float)x[i].dmin;
1276
-
1277
- sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1278
 
1279
  }
1280
- #else
1281
- const int il = 4 * tpitg.x;
1282
-
1283
- uint32_t aux[2];
1284
- thread const uint8_t * d = (thread const uint8_t *)aux;
1285
- thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1286
 
1287
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
-
1289
- device const uint8_t * q = x[i].qs + il;
1290
- device const float * y = yy + i*QK_K + il;
1291
-
1292
- const float dall = (float)x[i].d;
1293
- const float dmin = (float)x[i].dmin;
1294
-
1295
- device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
- aux[0] = a[0] & 0x0f0f0f0f;
1297
- aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
-
1299
- for (int l = 0; l < 4; ++l) {
1300
- sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
- + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
- + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
- + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
- }
1305
  }
1306
- #endif
1307
-
1308
- sum[ith] = sumf;
1309
 
1310
- //
1311
- // Accumulate the sum from all threads in the threadgroup
1312
- //
1313
- threadgroup_barrier(mem_flags::mem_threadgroup);
1314
- if (ith%4 == 0) {
1315
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1316
- }
1317
- threadgroup_barrier(mem_flags::mem_threadgroup);
1318
- if (ith%16 == 0) {
1319
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1320
- }
1321
- threadgroup_barrier(mem_flags::mem_threadgroup);
1322
- if (ith == 0) {
1323
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1324
- dst[r1*ne0 + r0] = sum[0];
1325
- }
1326
  }
 
1327
 
1328
- kernel void kernel_mul_mat_q3_K_f32(
 
1329
  device const void * src0,
1330
  device const float * src1,
1331
  device float * dst,
1332
  constant int64_t & ne00,
1333
- constant int64_t & ne10,
1334
- constant int64_t & ne0,
1335
- constant int64_t & ne1,
1336
- threadgroup float * sum [[threadgroup(0)]],
1337
- uint2 tgpig[[threadgroup_position_in_grid]],
1338
- uint2 tpitg[[thread_position_in_threadgroup]],
1339
- uint2 tptg[[threads_per_threadgroup]]) {
1340
-
1341
- const int nb = ne00/QK_K;
1342
-
1343
- const int64_t r0 = tgpig.x;
1344
- const int64_t r1 = tgpig.y;
1345
-
1346
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1347
- device const float * yy = (device const float *) src1 + r1*ne10;
1348
-
1349
- const int nth = tptg.x*tptg.y;
1350
- const int ith = tptg.y*tpitg.x + tpitg.y;
1351
-
1352
- #if QK_K == 256
1353
-
1354
- const uint8_t m3 = 3;
1355
- const int8_t m4 = 4;
1356
 
1357
- const uint16_t kmask1 = 0x0303;
1358
  const uint16_t kmask2 = 0x0f0f;
 
1359
 
1360
- const int tid = tpitg.y; // expecting 16
1361
- const int ip = tid/8; // 0 or 1
1362
- const int il = tid/2 - 4*ip; // 0...3
1363
- const int ir = tid%2;
1364
- const int n = 8;
1365
- const int l0 = n*ir;
1366
-
1367
- const uint8_t m = 1 << (4*ip + il);
1368
-
1369
- const int shift = 2*il;
1370
-
1371
- const uint16_t s_shift1 = 4*ip;
1372
- const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1373
- const int ik = 4 + (il%2);
1374
-
1375
- const int q_offset = 32*ip + l0;
1376
- const int y_offset = 128*ip + 32*il + l0;
1377
-
1378
- //float sumf = 0;
1379
- float sumf1 = 0, sumf2 = 0;
1380
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1381
-
1382
- const float d_all = (float)(x[i].d);
1383
-
1384
- device const uint8_t * q = x[i].qs + q_offset;
1385
- device const uint8_t * h = x[i].hmask + l0;
1386
- device const float * y = yy + i * QK_K + y_offset;
1387
-
1388
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1389
- const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1390
-
1391
- float s = 0;
1392
- for (int l = 0; l < n; ++l) {
1393
- s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1394
- }
1395
- float d = d_all * s;
1396
- sumf1 += d * scales[0];
1397
- sumf2 += d;
1398
- //sumf += d_all * s * (scales[0] - 32);
1399
 
1400
- s = 0;
1401
- for (int l = 0; l < n; ++l) {
1402
- s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
 
1403
  }
1404
- d = d_all * s;
1405
- sumf1 += d * scales[1];
1406
- sumf2 += d;
1407
- //sumf += d_all * s * (scales[1] - 32);
1408
 
1409
- }
1410
-
1411
- //sum[ith] = sumf;
1412
- sum[ith] = sumf1 - 32.f*sumf2;
1413
- #else
1414
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
- const int im = il/8; // 0, 0, 1, 1
1416
- const int in = il%8; // 0, 4, 0, 4
1417
-
1418
- float sumf = 0;
1419
-
1420
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
-
1422
- const float d_all = (float)(x[i].d);
1423
-
1424
- device const uint8_t * q = x[i].qs + il;
1425
- device const uint8_t * h = x[i].hmask + in;
1426
- device const float * y = yy + i * QK_K + il;
1427
-
1428
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
 
 
1432
 
1433
- for (int l = 0; l < 4; ++l) {
1434
- const uint8_t hm = h[l] >> im;
1435
- sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
- + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
- + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
- + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
 
 
 
 
 
1439
  }
1440
 
 
1441
  }
1442
 
1443
- sum[ith] = sumf;
1444
-
1445
- #endif
1446
-
1447
- //
1448
- // Accumulate the sum from all threads in the threadgroup
1449
- //
1450
- threadgroup_barrier(mem_flags::mem_threadgroup);
1451
- if (ith%4 == 0) {
1452
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1453
- }
1454
- threadgroup_barrier(mem_flags::mem_threadgroup);
1455
- if (ith%16 == 0) {
1456
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1457
- }
1458
- threadgroup_barrier(mem_flags::mem_threadgroup);
1459
- if (ith == 0) {
1460
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1461
- dst[r1*ne0 + r0] = sum[0];
1462
  }
1463
-
1464
  }
1465
-
1466
  kernel void kernel_mul_mat_q4_K_f32(
1467
  device const void * src0,
1468
  device const float * src1,
1469
  device float * dst,
1470
  constant int64_t & ne00,
1471
- constant int64_t & ne10,
1472
- constant int64_t & ne0,
1473
- threadgroup float * sum [[threadgroup(0)]],
1474
- uint2 tgpig[[threadgroup_position_in_grid]],
1475
- uint2 tpitg[[thread_position_in_threadgroup]],
1476
- uint2 tptg[[threads_per_threadgroup]]) {
1477
-
1478
- const int nb = ne00/QK_K;
1479
-
1480
- const int64_t r0 = tgpig.x;
1481
- const int64_t r1 = tgpig.y;
1482
-
1483
- const int nth = tptg.x*tptg.y;
1484
- const int ith = tptg.y*tpitg.x + tpitg.y;
1485
-
1486
- device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
- device const float * yy = (device const float *) src1 + r1*ne10;
1488
-
1489
- float sumf = 0;
1490
-
1491
- #if QK_K == 256
1492
-
1493
- const uint16_t kmask1 = 0x3f3f;
1494
- const uint16_t kmask2 = 0x0f0f;
1495
- const uint16_t kmask3 = 0xc0c0;
1496
-
1497
- const int tid = tpitg.y; // 0...16
1498
- const int il = tid/4; // 0...3
1499
- const int ir = tid - 4*il;// 0...3
1500
- const int n = 4;
1501
-
1502
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1503
- const int in = il%2;
1504
-
1505
- const int l0 = n*(2*ir + in);
1506
- const int q_offset = 32*im + l0;
1507
- const int y_offset = 64*im + l0;
1508
 
1509
- uchar2 sc1, sc2, sc3, sc4;
 
1510
 
1511
- for (int i = tpitg.x; i < nb; i += tptg.x) {
 
1512
 
1513
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1514
- device const uint8_t * q2 = q1 + 64;
1515
- device const float * y1 = yy + i*QK_K + y_offset;
1516
- device const float * y2 = y1 + 128;
1517
 
1518
- const float dall = (float)((x + i)->d);
1519
- const float dmin = (float)((x + i)->dmin);
1520
 
1521
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1522
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1523
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1524
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1525
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
1526
 
1527
- float4 s = {0.f, 0.f, 0.f, 0.f};
1528
- float smin = 0;
1529
- for (int l = 0; l < n; ++l) {
 
 
 
 
 
1530
 
1531
- s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4);
1532
- s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4);
1533
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
 
 
1534
 
 
 
 
1535
  }
1536
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1537
 
 
1538
  }
1539
- #else
1540
- uint16_t aux16[2];
1541
- thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
-
1543
- const int il = 4*tpitg.x;
1544
-
1545
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
 
1547
- device const uint8_t * q = x[i].qs + il;
1548
- device const float * y = yy + i * QK_K + il;
1549
-
1550
- const float d = (float)x[i].d[0];
1551
- const float m = (float)x[i].d[1];
1552
-
1553
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
- aux16[0] = a[0] & 0x0f0f;
1555
- aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
-
1557
- for (int l = 0; l < 4; ++l) {
1558
- sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
- + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
  }
1561
  }
1562
- #endif
1563
-
1564
- sum[ith] = sumf;
1565
-
1566
- //
1567
- // Accumulate the sum from all threads in the threadgroup
1568
- // This version is slightly faster than the commented out one below,
1569
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1570
- //
1571
- threadgroup_barrier(mem_flags::mem_threadgroup);
1572
- if (ith%4 == 0) {
1573
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1574
- }
1575
- threadgroup_barrier(mem_flags::mem_threadgroup);
1576
- if (ith%16 == 0) {
1577
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1578
- }
1579
- threadgroup_barrier(mem_flags::mem_threadgroup);
1580
- if (ith == 0) {
1581
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1582
- dst[r1*ne0 + r0] = sum[0];
1583
- }
1584
-
1585
- //// accumulate the sum from all threads in the threadgroup
1586
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1587
- //for (uint i = nth/2; i > 0; i /= 2) {
1588
- // if (ith < i) {
1589
- // sum[ith] += sum[ith + i];
1590
- // }
1591
- // threadgroup_barrier(mem_flags::mem_threadgroup);
1592
- //}
1593
-
1594
- //if (ith == 0) {
1595
- // dst[r1*ne0 + r0] = sum[0];
1596
- //}
1597
  }
 
1598
 
1599
  kernel void kernel_mul_mat_q5_K_f32(
1600
  device const void * src0,
1601
  device const float * src1,
1602
  device float * dst,
1603
  constant int64_t & ne00,
1604
- constant int64_t & ne10,
1605
- constant int64_t & ne0,
1606
- threadgroup float * sum [[threadgroup(0)]],
1607
- uint2 tgpig[[threadgroup_position_in_grid]],
1608
- uint2 tpitg[[thread_position_in_threadgroup]],
1609
- uint2 tptg[[threads_per_threadgroup]]) {
 
 
 
 
1610
 
1611
  const int nb = ne00/QK_K;
1612
 
1613
  const int64_t r0 = tgpig.x;
1614
  const int64_t r1 = tgpig.y;
 
1615
 
1616
- device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1617
- device const float * yy = (device const float *) src1 + r1*ne10;
 
 
1618
 
1619
- const int nth = tptg.x*tptg.y;
1620
- const int ith = tptg.y*tpitg.x + tpitg.y;
1621
 
1622
- float sumf = 0;
1623
 
1624
  #if QK_K == 256
 
 
1625
 
1626
  const uint16_t kmask1 = 0x3f3f;
1627
  const uint16_t kmask2 = 0x0f0f;
1628
  const uint16_t kmask3 = 0xc0c0;
1629
 
1630
- const int tid = tpitg.y; // 0...16
1631
- const int il = tid/4; // 0...3
1632
- const int ir = tid - 4*il;// 0...3
1633
- const int n = 4;
1634
-
1635
- const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
1636
- const int in = il%2;
1637
 
1638
- const int l0 = n*(2*ir + in);
1639
  const int q_offset = 32*im + l0;
1640
  const int y_offset = 64*im + l0;
1641
 
@@ -1644,78 +1463,113 @@ kernel void kernel_mul_mat_q5_K_f32(
1644
  const uint8_t hm3 = hm1 << 4;
1645
  const uint8_t hm4 = hm2 << 4;
1646
 
1647
- uchar2 sc1, sc2, sc3, sc4;
 
1648
 
1649
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1650
 
1651
- device const uint8_t * q1 = (x + i)->qs + q_offset;
1652
- device const uint8_t * q2 = q1 + 64;
1653
- device const uint8_t * qh = (x + i)->qh + l0;
1654
- device const float * y1 = yy + i*QK_K + y_offset;
1655
- device const float * y2 = y1 + 128;
1656
 
1657
- const float dall = (float)((x + i)->d);
1658
- const float dmin = (float)((x + i)->dmin);
 
 
1659
 
1660
- device const uint16_t * a = (device const uint16_t *)(x + i)->scales;
1661
- sc1 = as_type<uchar2>((uint16_t)(a[im+0] & kmask1));
1662
- sc2 = as_type<uchar2>((uint16_t)(a[im+2] & kmask1));
1663
- sc3 = as_type<uchar2>((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)));
1664
- sc4 = as_type<uchar2>((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)));
 
 
 
1665
 
1666
- float4 s = {0.f, 0.f, 0.f, 0.f};
1667
- float smin = 0;
1668
- for (int l = 0; l < n; ++l) {
 
1669
 
1670
- s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0));
1671
- s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0));
1672
- s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0));
1673
- s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0));
1674
- smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
1675
 
1676
  }
1677
- sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
 
1678
 
1679
  }
1680
  #else
1681
- const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
- const int im = il/8; // 0, 0, 1, 1
1683
- const int in = il%8; // 0, 4, 0, 4
 
 
 
 
 
1684
 
1685
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
 
1687
- const float d = (float)x[i].d;
 
 
 
 
 
 
 
1688
  device const uint8_t * q = x[i].qs + il;
1689
  device const uint8_t * h = x[i].qh + in;
1690
  device const int8_t * s = x[i].scales;
1691
- device const float * y = yy + i*QK_K + il;
1692
 
1693
- for (int l = 0; l < 4; ++l) {
1694
- const uint8_t hl = h[l] >> im;
1695
- sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
- + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
- + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
- + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
 
1699
  }
 
 
1700
  }
1701
  #endif
1702
- sum[ith] = sumf;
1703
 
1704
- //
1705
- // Accumulate the sum from all threads in the threadgroup
1706
- //
1707
- threadgroup_barrier(mem_flags::mem_threadgroup);
1708
- if (ith%4 == 0) {
1709
- sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
1710
- }
1711
- threadgroup_barrier(mem_flags::mem_threadgroup);
1712
- if (ith%16 == 0) {
1713
- sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
1714
- }
1715
- threadgroup_barrier(mem_flags::mem_threadgroup);
1716
- if (ith == 0) {
1717
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1718
- dst[r1*ne0 + r0] = sum[0];
1719
  }
1720
 
1721
  }
@@ -1725,12 +1579,16 @@ kernel void kernel_mul_mat_q6_K_f32(
1725
  device const float * src1,
1726
  device float * dst,
1727
  constant int64_t & ne00,
1728
- constant int64_t & ne10,
1729
- constant int64_t & ne0,
1730
- threadgroup float * sum [[threadgroup(0)]],
1731
- uint2 tgpig[[threadgroup_position_in_grid]],
1732
- uint2 tpitg[[thread_position_in_threadgroup]],
1733
- uint2 tptg[[threads_per_threadgroup]]) {
 
 
 
 
1734
 
1735
  const uint8_t kmask1 = 0x03;
1736
  const uint8_t kmask2 = 0x0C;
@@ -1741,20 +1599,20 @@ kernel void kernel_mul_mat_q6_K_f32(
1741
 
1742
  const int64_t r0 = tgpig.x;
1743
  const int64_t r1 = tgpig.y;
 
1744
 
1745
- device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1746
- device const float * yy = (device const float *) src1 + r1*ne10;
1747
-
1748
- const int nth = tptg.x*tptg.y;
1749
- const int ith = tptg.y*tpitg.x + tpitg.y;
1750
 
1751
  float sumf = 0;
1752
 
1753
  #if QK_K == 256
1754
- // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1755
- const int iqs = 16 * tpitg.y;
1756
- const int ip = iqs / 128; // 0 or 1
1757
- const int il = (iqs - 128*ip)/16; // 0...7
1758
  const int n = 4;
1759
  const int l0 = n*il;
1760
  const int is = 8*ip + l0/16;
@@ -1763,9 +1621,10 @@ kernel void kernel_mul_mat_q6_K_f32(
1763
  const int q_offset_l = 64*ip + l0;
1764
  const int q_offset_h = 32*ip + l0;
1765
 
1766
- for (int i = tpitg.x; i < nb; i += tptg.x) {
1767
 
1768
- device const uint8_t * ql = x[i].ql + q_offset_l;
 
1769
  device const uint8_t * qh = x[i].qh + q_offset_h;
1770
  device const int8_t * sc = x[i].scales + is;
1771
 
@@ -1775,19 +1634,21 @@ kernel void kernel_mul_mat_q6_K_f32(
1775
 
1776
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1777
  for (int l = 0; l < n; ++l) {
1778
- sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1779
- sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1780
- sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1781
- sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1782
  }
1783
 
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1785
 
1786
  }
 
1787
  #else
1788
- const int il = 4*tpitg.x; // 0, 4, 8, 12
 
1789
 
1790
- for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
  device const float * y = yy + i * QK_K + il;
1792
  device const uint8_t * ql = x[i].ql + il;
1793
  device const uint8_t * qh = x[i].qh + il;
@@ -1807,23 +1668,382 @@ kernel void kernel_mul_mat_q6_K_f32(
1807
 
1808
  #endif
1809
 
1810
- sum[ith] = sumf;
 
 
 
 
1811
 
1812
- //
1813
- // Accumulate the sum from all threads in the threadgroup
1814
- //
1815
- threadgroup_barrier(mem_flags::mem_threadgroup);
1816
- if (ith%4 == 0) {
1817
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
 
1818
  }
1819
- threadgroup_barrier(mem_flags::mem_threadgroup);
1820
- if (ith%16 == 0) {
1821
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
 
 
 
 
 
 
 
 
 
 
1822
  }
1823
- threadgroup_barrier(mem_flags::mem_threadgroup);
1824
- if (ith == 0) {
1825
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1826
- dst[r1*ne0 + r0] = sum[0];
  }
  }
 
18
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
19
  } block_q4_1;
20
 
21
+ #define QK8_0 32
22
+ typedef struct {
23
+ half d; // delta
24
+ int8_t qs[QK8_0]; // quants
25
+ } block_q8_0;
 
26
 
27
  kernel void kernel_add(
28
  device const float * src0,
 
32
  dst[tpig] = src0[tpig] + src1[tpig];
33
  }
34
 
35
+ // assumption: src1 is a row
36
+ // broadcast src1 into src0
37
+ kernel void kernel_add_row(
38
+ device const float * src0,
39
+ device const float * src1,
40
+ device float * dst,
41
+ constant int64_t & ne00,
42
+ uint tpig[[thread_position_in_grid]]) {
43
+ dst[tpig] = src0[tpig] + src1[tpig % ne00];
44
+ }
45
+
46
  kernel void kernel_mul(
47
  device const float * src0,
48
  device const float * src1,
 
93
  device float * dst,
94
  uint tpig[[thread_position_in_grid]]) {
95
  float x = src0[tpig];
96
+
97
+ // BEWARE !!!
98
+ // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
99
+ // This was observed with Falcon 7B and 40B models
100
+ //
101
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
102
  }
103
 
104
  kernel void kernel_soft_max(
 
189
  }
190
  }
 
 
192
  kernel void kernel_norm(
193
  device const void * src0,
194
  device float * dst,
 
264
  threadgroup float * sum [[threadgroup(0)]],
265
  uint tgpig[[threadgroup_position_in_grid]],
266
  uint tpitg[[thread_position_in_threadgroup]],
267
+ uint sgitg[[simdgroup_index_in_threadgroup]],
268
+ uint tiisg[[thread_index_in_simdgroup]],
269
  uint ntg[[threads_per_threadgroup]]) {
270
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
271
+ device const float * x_scalar = (device const float *) x;
272
+ float4 sumf=0;
273
+ float all_sum=0;
274
 
275
  // parallel sum
276
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
277
+ sumf += x[i00] * x[i00];
278
+ }
279
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
280
+ all_sum = simd_sum(all_sum);
281
+ if (tiisg == 0) {
282
+ sum[sgitg] = all_sum;
283
  }
284
 
 
285
  threadgroup_barrier(mem_flags::mem_threadgroup);
286
+ // broadcast, simd group number is ntg / 32
287
+ for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
288
+ if (tpitg < i) {
289
+ sum[tpitg] += sum[tpitg + i];
290
+ }
291
  }
 
 
292
  if (tpitg == 0) {
293
+ for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
294
  sum[0] /= ne00;
295
  }
296
 
 
299
  const float mean = sum[0];
300
  const float scale = 1.0f/sqrt(mean + eps);
301
 
302
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
303
+ device float * y_scalar = (device float *) y;
304
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
305
  y[i00] = x[i00] * scale;
306
  }
307
+ if (tpitg == 0) {
308
+ for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
309
+ }
310
  }
311
 
312
+ // function for calculating the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
313
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
314
+ // we assume that the yl's have been multiplied with the appropriate scale factor
315
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
316
+ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
317
+ float d = qb_curr->d;
318
+ float2 acc = 0.f;
319
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
320
+ for (int i = 0; i < 8; i+=2) {
321
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
322
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
323
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
324
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
325
+ }
326
+ return d * (sumy * -8.f + acc[0] + acc[1]);
327
+ }
 
328
 
329
+ // function for calculating the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
330
+ // il indicates where the q4 quants begin (0 or QK4_0/4)
331
+ // we assume that the yl's have been multiplied with the appropriate scale factor
332
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
333
+ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
334
+ float d = qb_curr->d;
335
+ float m = qb_curr->m;
336
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
337
+ float2 acc = 0.f;
338
+ for (int i = 0; i < 8; i+=2) {
339
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
340
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
341
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
342
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
343
+ }
344
+ return d * (acc[0] + acc[1]) + sumy * m;
345
+ }
346
 
347
+ // putting them in the kernel causes a significant performance penalty
348
+ #define N_DST 4 // each SIMD group works on 4 rows
349
+ #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
350
+ #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
351
+ // Note: This is a template, but strictly speaking it only applies to
352
+ // quantizations where the block size is 32. It also does not
353
+ // guard against the number of rows not being divisible by
354
+ // N_DST, so this is another explicit assumption of the implementation.
355
+ template<typename block_q_type, int nr, int nsg, int nw>
356
+ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
357
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
358
+ uint3 tgpig, uint tiisg, uint sgitg) {
359
+ const int nb = ne00/QK4_0;
360
+ const int r0 = tgpig.x;
361
+ const int r1 = tgpig.y;
362
+ const int im = tgpig.z;
363
+ const int first_row = (r0 * nsg + sgitg) * nr;
364
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
365
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
366
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
367
+ float yl[16]; // src1 vector cache
368
+ float sumf[nr]={0.f};
369
+
370
+ const int ix = tiisg/2;
371
+ const int il = 8*(tiisg%2);
372
+
373
+ device const float * yb = y + ix * QK4_0 + il;
374
+
375
+ // each thread in a SIMD group deals with half a block.
376
+ for (int ib = ix; ib < nb; ib += nw/2) {
377
+ float sumy = 0;
378
+ for (int i = 0; i < 8; i += 2) {
379
+ sumy += yb[i] + yb[i+1];
380
+ yl[i+0] = yb[i+ 0];
381
+ yl[i+1] = yb[i+ 1]/256.f;
382
+ sumy += yb[i+16] + yb[i+17];
383
+ yl[i+8] = yb[i+16]/16.f;
384
+ yl[i+9] = yb[i+17]/4096.f;
385
+ }
386
 
387
+ for (int row = 0; row < nr; row++) {
388
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
389
  }
390
 
391
+ yb += QK4_0 * 16;
392
  }
393
 
394
+ for (int row = 0; row < nr; ++row) {
395
+ const float tot = simd_sum(sumf[row]);
396
+ if (tiisg == 0 && first_row + row < ne01) {
397
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
398
+ }
 
 
 
 
 
 
 
 
 
 
 
 
399
  }
400
  }
401
 
402
+ kernel void kernel_mul_mat_q4_0_f32(
403
  device const void * src0,
404
  device const float * src1,
405
  device float * dst,
406
  constant int64_t & ne00,
407
+ constant int64_t & ne01[[buffer(4)]],
408
+ constant int64_t & ne02[[buffer(5)]],
409
+ constant int64_t & ne10[[buffer(9)]],
410
+ constant int64_t & ne12[[buffer(11)]],
411
+ constant int64_t & ne0[[buffer(15)]],
412
+ constant int64_t & ne1[[buffer(16)]],
413
+ constant uint & gqa[[buffer(17)]],
414
+ uint3 tgpig[[threadgroup_position_in_grid]],
415
+ uint tiisg[[thread_index_in_simdgroup]],
416
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
417
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
418
+ }
 
 
419
 
420
+ kernel void kernel_mul_mat_q4_1_f32(
421
+ device const void * src0,
422
+ device const float * src1,
423
+ device float * dst,
424
+ constant int64_t & ne00,
425
+ constant int64_t & ne01[[buffer(4)]],
426
+ constant int64_t & ne02[[buffer(5)]],
427
+ constant int64_t & ne10[[buffer(9)]],
428
+ constant int64_t & ne12[[buffer(11)]],
429
+ constant int64_t & ne0[[buffer(15)]],
430
+ constant int64_t & ne1[[buffer(16)]],
431
+ constant uint & gqa[[buffer(17)]],
432
+ uint3 tgpig[[threadgroup_position_in_grid]],
433
+ uint tiisg[[thread_index_in_simdgroup]],
434
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
435
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
436
+ }
437
 
438
+ kernel void kernel_mul_mat_q8_0_f32(
439
+ device const void * src0,
440
+ device const float * src1,
441
+ device float * dst,
442
+ constant int64_t & ne00,
443
+ constant int64_t & ne01[[buffer(4)]],
444
+ constant int64_t & ne02[[buffer(5)]],
445
+ constant int64_t & ne10[[buffer(9)]],
446
+ constant int64_t & ne12[[buffer(11)]],
447
+ constant int64_t & ne0[[buffer(15)]],
448
+ constant int64_t & ne1[[buffer(16)]],
449
+ constant uint & gqa[[buffer(17)]],
450
+ uint3 tgpig[[threadgroup_position_in_grid]],
451
+ uint tiisg[[thread_index_in_simdgroup]],
452
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
453
+ const int nr = N_DST;
454
+ const int nsg = N_SIMDGROUP;
455
+ const int nw = N_SIMDWIDTH;
456
+
457
+ const int nb = ne00/QK8_0;
458
+ const int r0 = tgpig.x;
459
+ const int r1 = tgpig.y;
460
+ const int im = tgpig.z;
461
+ const int first_row = (r0 * nsg + sgitg) * nr;
462
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
463
+ device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
464
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
465
+
466
+ float yl[16];
467
+ float sumf[nr]={0.f};
468
+
469
+ const int ix = tiisg/2;
470
+ const int il = tiisg%2;
471
+
472
+ device const float * yb = y + ix * QK8_0 + 16*il;
473
+
474
+ // each thread in a SIMD group deals with half a block.
475
+ for (int ib = ix; ib < nb; ib += nw/2) {
476
+ for (int i = 0; i < 16; ++i) {
477
+ yl[i] = yb[i];
478
+ }
479
 
480
+ for (int row = 0; row < nr; row++) {
481
+ device const int8_t * qs = x[ib+row*nb].qs + 16*il;
482
+ float sumq = 0.f;
483
+ for (int iq = 0; iq < 16; ++iq) {
484
+ sumq += qs[iq] * yl[iq];
485
+ }
486
+ sumf[row] += sumq*x[ib+row*nb].d;
487
  }
488
 
489
+ yb += QK8_0 * 16;
490
  }
491
 
492
+ for (int row = 0; row < nr; ++row) {
493
+ const float tot = simd_sum(sumf[row]);
494
+ if (tiisg == 0 && first_row + row < ne01) {
495
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
496
+ }
 
 
 
 
 
 
 
 
 
 
 
 
497
  }
498
  }
499
 
 
503
  device float * dst,
504
  constant int64_t & ne00,
505
  constant int64_t & ne01,
506
+ constant int64_t & ne02,
507
  constant uint64_t & nb00,
508
  constant uint64_t & nb01,
509
  constant uint64_t & nb02,
510
  constant int64_t & ne10,
511
  constant int64_t & ne11,
512
+ constant int64_t & ne12,
513
  constant uint64_t & nb10,
514
  constant uint64_t & nb11,
515
  constant uint64_t & nb12,
 
525
  const int64_t r1 = tgpig.y;
526
  const int64_t im = tgpig.z;
527
 
528
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
529
  device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
530
 
531
  sum[tpitg.x] = 0.0f;
 
612
  constant int & n_past,
613
  constant int & n_dims,
614
  constant int & mode,
615
+ constant float & freq_base,
616
+ constant float & freq_scale,
617
  uint3 tpig[[thread_position_in_grid]]) {
618
  const int64_t i3 = tpig[2];
619
  const int64_t i2 = tpig[1];
620
  const int64_t i1 = tpig[0];
621
 
622
  const bool is_neox = mode & 2;
623
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
624
 
625
  const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
626
 
627
+ float theta = freq_scale * (float)p;
628
 
629
  if (!is_neox) {
630
  for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 
643
  dst_data[1] = x0*sin_theta + x1*cos_theta;
644
  }
645
  } else {
646
+ for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
647
+ for (int64_t ic = 0; ic < n_dims; ic += 2) {
648
+ const float cos_theta = cos(theta);
649
+ const float sin_theta = sin(theta);
650
+
651
+ theta *= theta_scale;
652
+
653
+ const int64_t i0 = ib*n_dims + ic/2;
654
+
655
+ device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
656
+ device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
657
+
658
+ const float x0 = src[0];
659
+ const float x1 = src[n_dims/2];
660
+
661
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
662
+ dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
663
+ }
664
+ }
665
  }
666
  }
667
 
 
880
  return r;
881
  }
882
 
883
+ //====================================== dot products =========================
884
 
885
+ kernel void kernel_mul_mat_q2_K_f32(
886
+ device const void * src0,
887
+ device const float * src1,
888
+ device float * dst,
889
+ constant int64_t & ne00,
890
+ constant int64_t & ne01[[buffer(4)]],
891
+ constant int64_t & ne02[[buffer(5)]],
892
+ constant int64_t & ne10[[buffer(9)]],
893
+ constant int64_t & ne12[[buffer(11)]],
894
+ constant int64_t & ne0[[buffer(15)]],
895
+ constant int64_t & ne1[[buffer(16)]],
896
+ constant uint & gqa[[buffer(17)]],
897
+ uint3 tgpig[[threadgroup_position_in_grid]],
898
+ uint tiisg[[thread_index_in_simdgroup]],
899
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
900
 
901
+ const int nb = ne00/QK_K;
902
+ const int r0 = tgpig.x;
903
+ const int r1 = tgpig.y;
904
+ const int r2 = tgpig.z;
905
 
906
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
907
+ const int ib_row = first_row * nb;
908
+ const uint offset0 = r2/gqa*(nb*ne0);
909
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
910
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
911
+ float yl[32];
912
+ float sumf[N_DST]={0.f}, all_sum;
913
 
914
+ const int step = sizeof(block_q2_K) * nb;
915
 
916
  #if QK_K == 256
917
+ const int ix = tiisg/8; // 0...3
918
+ const int it = tiisg%8; // 0...7
919
+ const int im = it/4; // 0 or 1
920
+ const int ir = it%4; // 0...3
921
+ const int is = (8*ir)/16;// 0 or 1
922
+
923
+ device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
924
+
925
+ for (int ib = ix; ib < nb; ib += 4) {
926
+
927
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
928
+ for (int i = 0; i < 8; ++i) {
929
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
930
+ yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
931
+ yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
932
+ yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
933
+ }
934
 
935
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is;
936
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
937
+ device const half * dh = &x[ib].d;
938
+
939
+ for (int row = 0; row < N_DST; row++) {
940
+
941
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
942
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
943
+ for (int i = 0; i < 8; i += 2) {
944
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
945
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
946
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
947
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
948
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
949
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
950
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
951
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
952
  }
953
+ float dall = dh[0];
954
+ float dmin = dh[1] * 1.f/16.f;
955
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
956
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
957
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
958
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
959
+ dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
960
+
961
+ qs += step/2;
962
+ sc += step;
963
+ dh += step/2;
964
  }
965
+
966
+ y4 += 4 * QK_K;
967
+ }
968
  #else
969
+ const int ix = tiisg/2; // 0...15
970
+ const int it = tiisg%2; // 0...1
971
+
972
+ device const float * y4 = y + ix * QK_K + 8 * it;
973
+
974
+ for (int ib = ix; ib < nb; ib += 16) {
975
+
976
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
977
+ for (int i = 0; i < 8; ++i) {
978
+ yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
979
+ yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
980
+ yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
981
+ yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
982
+ }
983
+
984
+ device const uint8_t * sc = (device const uint8_t *)x[ib].scales;
985
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
986
+ device const half * dh = &x[ib].d;
987
+
988
+ for (int row = 0; row < N_DST; row++) {
989
+
990
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
991
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
992
+ for (int i = 0; i < 8; i += 2) {
993
+ acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
994
+ acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
995
+ acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
996
+ acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
997
+ acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
998
+ acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
999
+ acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
1000
+ acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
1001
+ }
1002
+
1003
+ float dall = dh[0];
1004
+ float dmin = dh[1];
1005
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
1006
+ (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
1007
+ (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
1008
+ (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
1009
+ dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
1010
+
1011
+ qs += step/2;
1012
+ sc += step;
1013
+ dh += step/2;
1014
  }
1015
+
1016
+ y4 += 16 * QK_K;
1017
+ }
1018
  #endif
1019
 
1020
+ for (int row = 0; row < N_DST; ++row) {
1021
+ all_sum = simd_sum(sumf[row]);
1022
+ if (tiisg == 0) {
1023
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1024
+ }
1025
  }
1026
  }
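For reference, the scale handling in the q2_K kernel above follows the usual q2_K convention: each scales[] byte packs a 4-bit sub-block scale in its low nibble and a 4-bit sub-block min in its high nibble, applied against the super-block d and dmin. A minimal CPU-side sketch of that decode (assuming ggml's block_q2_K fields d, dmin, scales; the helper name is illustrative only):

    #include <stdint.h>
    // decode one 2-bit weight q2 (0..3) from the 16-weight sub-block governed by scale_byte
    static inline float q2_K_decode_one(float d, float dmin, uint8_t scale_byte, int q2) {
        const float dl = d    * (scale_byte & 0xF); // low nibble: sub-block scale
        const float ml = dmin * (scale_byte >>  4); // high nibble: sub-block min
        return dl * (float) q2 - ml;                // same form the kernel accumulates against y
    }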
1027
 
 
 
 
 
1028
  #if QK_K == 256
1029
+ kernel void kernel_mul_mat_q3_K_f32(
1030
+ device const void * src0,
1031
+ device const float * src1,
1032
+ device float * dst,
1033
+ constant int64_t & ne00,
1034
+ constant int64_t & ne01[[buffer(4)]],
1035
+ constant int64_t & ne02[[buffer(5)]],
1036
+ constant int64_t & ne10[[buffer(9)]],
1037
+ constant int64_t & ne12[[buffer(11)]],
1038
+ constant int64_t & ne0[[buffer(15)]],
1039
+ constant int64_t & ne1[[buffer(16)]],
1040
+ constant uint & gqa[[buffer(17)]],
1041
+ uint3 tgpig[[threadgroup_position_in_grid]],
1042
+ uint tiisg[[thread_index_in_simdgroup]],
1043
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1044
+
1045
+ const int nb = ne00/QK_K;
1046
+
1047
+ const int64_t r0 = tgpig.x;
1048
+ const int64_t r1 = tgpig.y;
1049
+ const int64_t r2 = tgpig.z;
1050
+
1051
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1052
+ const uint offset0 = r2/gqa*(nb*ne0);
1053
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
1054
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1055
+
1056
+ float yl[16];
1057
 
1058
  const uint16_t kmask1 = 0x0303;
1059
  const uint16_t kmask2 = 0x0f0f;
1060
 
1061
+ const int tid = tiisg/2;
1062
+ const int ix = tiisg%2;
1063
+ const int ip = tid/8; // 0 or 1
1064
+ const int il = tid/2 - 4*ip; // 0...3
1065
+ const int ir = tid%2;
1066
+ const int n = 8;
1067
+ const int l0 = n*ir;
1068
 
1069
+ const uint16_t m1 = 1 << (4*ip + il);
1070
+ const uint16_t m2 = m1 << 8;
1071
 
1072
+ const int shift = 2*il;
1073
+ const uint16_t qm1 = 0x0003 << shift;
1074
+ const uint16_t qm2 = 0x0300 << shift;
1075
+ const int32_t v1 = 4 << shift;
1076
+ const int32_t v2 = 1024 << shift;
 
 
 
 
1077
 
1078
+ const uint16_t s_shift1 = 4*ip;
1079
+ const uint16_t s_shift2 = s_shift1 + 2*(il/2);
1080
+ const int ik = 4 + (il%2);
 
1081
 
1082
+ const int q_offset = 32*ip + l0;
1083
+ const int y_offset = 128*ip + 32*il + l0;
 
 
 
 
 
 
1084
 
1085
+ const int step = sizeof(block_q3_K) * nb / 2;
1086
 
1087
+ device const float * y1 = yy + ix*QK_K + y_offset;
 
1088
 
1089
+ float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1090
+ for (int i = ix; i < nb; i += 2) {
 
 
1091
 
1092
  for (int l = 0; l < 8; ++l) {
1093
+ yl[l+0] = y1[l+ 0];
1094
+ yl[l+8] = y1[l+16];
 
 
1095
  }
 
 
 
1096
 
1097
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1098
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1099
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1100
+ device const half * dh = &x[i].d;
1101
 
1102
+ for (int row = 0; row < 2; ++row) {
 
 
1103
 
1104
+ const float d_all = (float)dh[0];
1105
+ const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1106
 
1107
+ float s1 = 0, s2 = 0;
1108
+ for (int l = 0; l < n; l += 2) {
1109
+ const uint16_t qs = q[l/2];
1110
+ s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1111
+ s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1112
+ }
1113
+ float d = d_all * (s1 + 1.f/256.f * s2);
1114
+ sumf1[row] += d * scales[0];
1115
+ sumf2[row] += d;
1116
+
1117
+ s1 = s2 = 0;
1118
+ for (int l = 0; l < n; l += 2) {
1119
+ const uint16_t qs = q[l/2+8];
1120
+ s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1121
+ s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1122
+ }
1123
+ d = d_all * (s1 + 1.f/256.f * s2);
1124
+ sumf1[row] += d * scales[1];
1125
+ sumf2[row] += d;
1126
+
1127
+ q += step;
1128
+ h += step;
1129
+ a += step;
1130
+ dh += step;
1131
 
 
 
 
 
1132
  }
 
 
1133
 
1134
+ y1 += 2 * QK_K;
 
1135
 
1136
+ }
 
 
1137
 
1138
+ for (int row = 0; row < 2; ++row) {
1139
+ const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1140
+ const float tot = simd_sum(sumf);
1141
+ if (tiisg == 0) {
1142
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
 
 
 
 
1143
  }
1144
  }
1145
+ }
1146
  #else
1147
+ kernel void kernel_mul_mat_q3_K_f32(
1148
+ device const void * src0,
1149
+ device const float * src1,
1150
+ device float * dst,
1151
+ constant int64_t & ne00,
1152
+ constant int64_t & ne01[[buffer(4)]],
1153
+ constant int64_t & ne02[[buffer(5)]],
1154
+ constant int64_t & ne10[[buffer(9)]],
1155
+ constant int64_t & ne12[[buffer(11)]],
1156
+ constant int64_t & ne0[[buffer(15)]],
1157
+ constant int64_t & ne1[[buffer(16)]],
1158
+ constant uint & gqa[[buffer(17)]],
1159
+ uint3 tgpig[[threadgroup_position_in_grid]],
1160
+ uint tiisg[[thread_index_in_simdgroup]],
1161
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1162
 
1163
+ const int nb = ne00/QK_K;
1164
 
1165
+ const int64_t r0 = tgpig.x;
1166
+ const int64_t r1 = tgpig.y;
1167
+ const int64_t r2 = tgpig.z;
1168
+
1169
+ const int row = 2 * r0 + sgitg;
1170
+ const uint offset0 = r2/gqa*(nb*ne0);
1171
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1172
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1173
+ const int ix = tiisg/4;
1174
+ const int il = 4 * (tiisg%4);// 0, 4, 8, 12
1175
+ const int im = il/8; // 0, 0, 1, 1
1176
+ const int in = il%8; // 0, 4, 0, 4
1177
 
1178
+ float2 sum = {0.f, 0.f};
 
 
 
 
1179
 
1180
+ for (int i = ix; i < nb; i += 8) {
1181
 
1182
+ const float d_all = (float)(x[i].d);
 
 
1183
 
1184
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1185
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1186
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1187
+ device const float * y = yy + i * QK_K + il;
1188
+
1189
+ const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1190
+ const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1191
+ const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1192
+ const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1193
+
1194
+ for (int l = 0; l < 4; l += 2) {
1195
+ const uint16_t hm = h[l/2] >> im;
1196
+ sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1197
+ + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1198
+ + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1199
+ + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1200
+ sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1201
+ + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1202
+ + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1203
+ + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
 
 
 
 
 
1204
  }
 
 
 
 
1205
 
1206
  }
1207
+ const float sumf = sum[0] + sum[1] * 1.f/256.f;
 
 
 
 
 
1208
 
1209
+ const float tot = simd_sum(sumf);
1210
+ if (tiisg == 0) {
1211
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
 
 
 
 
1212
  }
 
 
 
1213
 
 
 
 
 
1214
  }
1215
+ #endif
1216
 
1217
+ #if QK_K == 256
1218
+ kernel void kernel_mul_mat_q4_K_f32(
1219
  device const void * src0,
1220
  device const float * src1,
1221
  device float * dst,
1222
  constant int64_t & ne00,
1223
+ constant int64_t & ne01[[buffer(4)]],
1224
+ constant int64_t & ne02[[buffer(5)]],
1225
+ constant int64_t & ne10[[buffer(9)]],
1226
+ constant int64_t & ne12[[buffer(11)]],
1227
+ constant int64_t & ne0[[buffer(15)]],
1228
+ constant int64_t & ne1[[buffer(16)]],
1229
+ constant uint & gqa[[buffer(17)]],
1230
+ uint3 tgpig[[threadgroup_position_in_grid]],
1231
+ uint tiisg[[thread_index_in_simdgroup]],
1232
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
 
 
1233
 
1234
+ const uint16_t kmask1 = 0x3f3f;
1235
  const uint16_t kmask2 = 0x0f0f;
1236
+ const uint16_t kmask3 = 0xc0c0;
1237
 
1238
+ const int ix = tiisg/8; // 0...3
1239
+ const int it = tiisg%8; // 0...7
1240
+ const int im = it/4; // 0 or 1
1241
+ const int ir = it%4; // 0...3
 
 
 
 
1242
 
1243
+ const int nb = ne00/QK_K;
1244
+ const int r0 = tgpig.x;
1245
+ const int r1 = tgpig.y;
1246
+ const int r2 = tgpig.z;
1247
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1248
+ const int ib_row = first_row * nb;
1249
+ const uint offset0 = r2/gqa*(nb*ne0);
1250
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1251
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1252
+ float yl[16];
1253
+ float yh[16];
1254
+ float sumf[N_DST]={0.f}, all_sum;
1255
+
1256
+ const int step = sizeof(block_q4_K) * nb / 2;
1257
+
1258
+ device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
1259
+
1260
+ uint16_t sc16[4];
1261
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1262
+
1263
+ for (int ib = ix; ib < nb; ib += 4) {
1264
+
1265
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1266
+ for (int i = 0; i < 8; ++i) {
1267
+ yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
1268
+ yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
1269
+ yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
1270
+ yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
1271
  }
 
 
 
 
1272
 
1273
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
1274
+ device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
1275
+ device const half * dh = &x[ib].d;
1276
+
1277
+ for (int row = 0; row < N_DST; row++) {
1278
+
1279
+ sc16[0] = sc[0] & kmask1;
1280
+ sc16[1] = sc[2] & kmask1;
1281
+ sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
1282
+ sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
1283
+
1284
+ device const uint16_t * q2 = q1 + 32;
1285
+
1286
+ float4 acc1 = {0.f, 0.f, 0.f, 0.f};
1287
+ float4 acc2 = {0.f, 0.f, 0.f, 0.f};
1288
+ for (int i = 0; i < 8; i += 2) {
1289
+ acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
1290
+ acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
1291
+ acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
1292
+ acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
1293
+ acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
1294
+ acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
1295
+ acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
1296
+ acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
1297
+ }
1298
 
1299
+ float dall = dh[0];
1300
+ float dmin = dh[1];
1301
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
1302
+ (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
1303
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
1304
+ (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
1305
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1306
+
1307
+ q1 += step;
1308
+ sc += step;
1309
+ dh += step;
1310
  }
1311
 
1312
+ y4 += 4 * QK_K;
1313
  }
1314
 
1315
+ for (int row = 0; row < N_DST; ++row) {
1316
+ all_sum = simd_sum(sumf[row]);
1317
+ if (tiisg == 0) {
1318
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
1319
+ }
 
 
 
 
1320
  }
 
1321
  }
1322
+ #else
1323
  kernel void kernel_mul_mat_q4_K_f32(
1324
  device const void * src0,
1325
  device const float * src1,
1326
  device float * dst,
1327
  constant int64_t & ne00,
1328
+ constant int64_t & ne01[[buffer(4)]],
1329
+ constant int64_t & ne02[[buffer(5)]],
1330
+ constant int64_t & ne10[[buffer(9)]],
1331
+ constant int64_t & ne12[[buffer(11)]],
1332
+ constant int64_t & ne0[[buffer(15)]],
1333
+ constant int64_t & ne1[[buffer(16)]],
1334
+ constant uint & gqa[[buffer(17)]],
1335
+ uint3 tgpig[[threadgroup_position_in_grid]],
1336
+ uint tiisg[[thread_index_in_simdgroup]],
1337
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
 
 
 
1338
 
1339
+ const int ix = tiisg/4; // 0...7
1340
+ const int it = tiisg%4; // 0...3
1341
 
1342
+ const int nb = ne00/QK_K;
1343
+ const int r0 = tgpig.x;
1344
+ const int r1 = tgpig.y;
1345
+ const int r2 = tgpig.z;
1346
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
1347
+ const int ib_row = first_row * nb;
1348
+ const uint offset0 = r2/gqa*(nb*ne0);
1349
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1350
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1351
+ float yl[8];
1352
+ float yh[8];
1353
+ float sumf[N_DST]={0.f}, all_sum;
1354
+
1355
+ const int step = sizeof(block_q4_K) * nb / 2;
1356
+
1357
+ device const float * y4 = y + ix * QK_K + 8 * it;
1358
+
1359
+ uint16_t sc16[4];
1360
+
1361
+ for (int ib = ix; ib < nb; ib += 8) {
1362
+
1363
+ float2 sumy = {0.f, 0.f};
1364
+ for (int i = 0; i < 8; ++i) {
1365
+ yl[i] = y4[i+ 0]; sumy[0] += yl[i];
1366
+ yh[i] = y4[i+32]; sumy[1] += yh[i];
1367
+ }
1368
 
1369
+ device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
1370
+ device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
1371
+ device const half * dh = x[ib].d;
 
1372
 
1373
+ for (int row = 0; row < N_DST; row++) {
 
1374
 
1375
+ sc16[0] = sc[0] & 0x000f;
1376
+ sc16[1] = sc[0] & 0x0f00;
1377
+ sc16[2] = sc[0] & 0x00f0;
1378
+ sc16[3] = sc[0] & 0xf000;
 
1379
 
1380
+ float2 acc1 = {0.f, 0.f};
1381
+ float2 acc2 = {0.f, 0.f};
1382
+ for (int i = 0; i < 8; i += 2) {
1383
+ acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
1384
+ acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
1385
+ acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
1386
+ acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
1387
+ }
1388
 
1389
+ float dall = dh[0];
1390
+ float dmin = dh[1];
1391
+ sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
1392
+ (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
1393
+ dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);
1394
 
1395
+ qs += step;
1396
+ sc += step;
1397
+ dh += step;
1398
  }
 
1399
 
1400
+ y4 += 8 * QK_K;
1401
  }
 
 
 
 
 
 
 
1402
 
1403
+ for (int row = 0; row < N_DST; ++row) {
1404
+ all_sum = simd_sum(sumf[row]);
1405
+ if (tiisg == 0) {
1406
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
 
 
 
 
1407
  }
1408
  }
 
 
 
 
1409
  }
1410
+ #endif
1411
 
1412
  kernel void kernel_mul_mat_q5_K_f32(
1413
  device const void * src0,
1414
  device const float * src1,
1415
  device float * dst,
1416
  constant int64_t & ne00,
1417
+ constant int64_t & ne01[[buffer(4)]],
1418
+ constant int64_t & ne02[[buffer(5)]],
1419
+ constant int64_t & ne10[[buffer(9)]],
1420
+ constant int64_t & ne12[[buffer(11)]],
1421
+ constant int64_t & ne0[[buffer(15)]],
1422
+ constant int64_t & ne1[[buffer(16)]],
1423
+ constant uint & gqa[[buffer(17)]],
1424
+ uint3 tgpig[[threadgroup_position_in_grid]],
1425
+ uint tiisg[[thread_index_in_simdgroup]],
1426
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1427
 
1428
  const int nb = ne00/QK_K;
1429
 
1430
  const int64_t r0 = tgpig.x;
1431
  const int64_t r1 = tgpig.y;
1432
+ const int r2 = tgpig.z;
1433
 
1434
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1435
+ const uint offset0 = r2/gqa*(nb*ne0);
1436
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1437
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
1438
 
1439
+ float sumf[2]={0.f};
 
1440
 
1441
+ const int step = sizeof(block_q5_K) * nb;
1442
 
1443
  #if QK_K == 256
1444
+ #
1445
+ float yl[16], yh[16];
1446
 
1447
  const uint16_t kmask1 = 0x3f3f;
1448
  const uint16_t kmask2 = 0x0f0f;
1449
  const uint16_t kmask3 = 0xc0c0;
1450
 
1451
+ const int tid = tiisg/4;
1452
+ const int ix = tiisg%4;
1453
+ const int im = tid/4;
1454
+ const int ir = tid%4;
1455
+ const int n = 8;
 
 
1456
 
1457
+ const int l0 = n*ir;
1458
  const int q_offset = 32*im + l0;
1459
  const int y_offset = 64*im + l0;
1460
 
 
1463
  const uint8_t hm3 = hm1 << 4;
1464
  const uint8_t hm4 = hm2 << 4;
1465
 
1466
+ uint16_t sc16[4];
1467
+ thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
1468
 
1469
+ device const float * y1 = yy + ix*QK_K + y_offset;
1470
 
1471
+ for (int i = ix; i < nb; i += 4) {
 
 
 
 
1472
 
1473
+ device const uint8_t * q1 = x[i].qs + q_offset;
1474
+ device const uint8_t * qh = x[i].qh + l0;
1475
+ device const half * dh = &x[i].d;
1476
+ device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
1477
 
1478
+ device const float * y2 = y1 + 128;
1479
+ float4 sumy = {0.f, 0.f, 0.f, 0.f};
1480
+ for (int l = 0; l < 8; ++l) {
1481
+ yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
1482
+ yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
1483
+ yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
1484
+ yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
1485
+ }
1486
 
1487
+ for (int row = 0; row < 2; ++row) {
1488
+
1489
+ device const uint8_t * q2 = q1 + 64;
1490
+
1491
+ sc16[0] = a[0] & kmask1;
1492
+ sc16[1] = a[2] & kmask1;
1493
+ sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
1494
+ sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
1495
+
1496
+ float4 acc = {0.f, 0.f, 0.f, 0.f};
1497
+ for (int l = 0; l < n; ++l) {
1498
+ uint8_t h = qh[l];
1499
+ acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
1500
+ acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
1501
+ acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
1502
+ acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
1503
+ }
1504
+ const float dall = dh[0];
1505
+ const float dmin = dh[1];
1506
+ sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
1507
+ dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
1508
 
1509
+ q1 += step;
1510
+ qh += step;
1511
+ dh += step/2;
1512
+ a += step/2;
 
1513
 
1514
  }
1515
+
1516
+ y1 += 4 * QK_K;
1517
 
1518
  }
1519
  #else
1520
+ float yl[8], yh[8];
1521
+
1522
+ const int il = 4 * (tiisg/8); // 0, 4, 8, 12
1523
+ const int ix = tiisg%8;
1524
+ const int im = il/8; // 0, 0, 1, 1
1525
+ const int in = il%8; // 0, 4, 0, 4
1526
+
1527
+ device const float * y = yy + ix*QK_K + il;
1528
 
1529
+ for (int i = ix; i < nb; i += 8) {
1530
 
1531
+ for (int l = 0; l < 4; ++l) {
1532
+ yl[l+0] = y[l+ 0];
1533
+ yl[l+4] = y[l+16];
1534
+ yh[l+0] = y[l+32];
1535
+ yh[l+4] = y[l+48];
1536
+ }
1537
+
1538
+ device const half * dh = &x[i].d;
1539
  device const uint8_t * q = x[i].qs + il;
1540
  device const uint8_t * h = x[i].qh + in;
1541
  device const int8_t * s = x[i].scales;
 
1542
 
1543
+ for (int row = 0; row < 2; ++row) {
1544
+
1545
+ const float d = dh[0];
1546
+
1547
+ float2 acc = {0.f, 0.f};
1548
+ for (int l = 0; l < 4; ++l) {
1549
+ const uint8_t hl = h[l] >> im;
1550
+ acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
1551
+ + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
1552
+ acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
1553
+ + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
1554
+ }
1555
+ sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
1556
+
1557
+ q += step;
1558
+ h += step;
1559
+ s += step;
1560
+ dh += step/2;
1561
+
1562
  }
1563
+
1564
+ y += 8 * QK_K;
1565
  }
1566
  #endif
 
1567
 
1568
+ for (int row = 0; row < 2; ++row) {
1569
+ const float tot = simd_sum(sumf[row]);
1570
+ if (tiisg == 0) {
1571
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
1572
+ }
 
 
 
 
1573
  }
1574
 
1575
  }
 
1579
  device const float * src1,
1580
  device float * dst,
1581
  constant int64_t & ne00,
1582
+ constant int64_t & ne01[[buffer(4)]],
1583
+ constant int64_t & ne02[[buffer(5)]],
1584
+ constant int64_t & ne10[[buffer(9)]],
1585
+ constant int64_t & ne12[[buffer(11)]],
1586
+ constant int64_t & ne0[[buffer(15)]],
1587
+ constant int64_t & ne1[[buffer(16)]],
1588
+ constant uint & gqa[[buffer(17)]],
1589
+ uint3 tgpig[[threadgroup_position_in_grid]],
1590
+ uint tiisg[[thread_index_in_simdgroup]],
1591
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1592
 
1593
  const uint8_t kmask1 = 0x03;
1594
  const uint8_t kmask2 = 0x0C;
 
1599
 
1600
  const int64_t r0 = tgpig.x;
1601
  const int64_t r1 = tgpig.y;
1602
+ const int r2 = tgpig.z;
1603
 
1604
+ const int row = 2 * r0 + sgitg;
1605
+ const uint offset0 = r2/gqa*(nb*ne0);
1606
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1607
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
 
1608
 
1609
  float sumf = 0;
1610
 
1611
  #if QK_K == 256
1612
+ const int tid = tiisg/2;
1613
+ const int ix = tiisg%2;
1614
+ const int ip = tid/8; // 0 or 1
1615
+ const int il = tid%8;
1616
  const int n = 4;
1617
  const int l0 = n*il;
1618
  const int is = 8*ip + l0/16;
 
1621
  const int q_offset_l = 64*ip + l0;
1622
  const int q_offset_h = 32*ip + l0;
1623
 
1624
+ for (int i = ix; i < nb; i += 2) {
1625
 
1626
+ device const uint8_t * q1 = x[i].ql + q_offset_l;
1627
+ device const uint8_t * q2 = q1 + 32;
1628
  device const uint8_t * qh = x[i].qh + q_offset_h;
1629
  device const int8_t * sc = x[i].scales + is;
1630
 
 
1634
 
1635
  float4 sums = {0.f, 0.f, 0.f, 0.f};
1636
  for (int l = 0; l < n; ++l) {
1637
+ sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1638
+ sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1639
+ sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
1640
+ sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1641
  }
1642
 
1643
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1644
 
1645
  }
1646
+
1647
  #else
1648
+ const int ix = tiisg/4;
1649
+ const int il = 4*(tiisg%4);
1650
 
1651
+ for (int i = ix; i < nb; i += 8) {
1652
  device const float * y = yy + i * QK_K + il;
1653
  device const uint8_t * ql = x[i].ql + il;
1654
  device const uint8_t * qh = x[i].qh + il;
 
1668
 
1669
  #endif
1670
 
1671
+ const float tot = simd_sum(sumf);
1672
+ if (tiisg == 0) {
1673
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
1674
+ }
1675
+ }
1676
 
1677
+ //============================= templates and their specializations =============================
1678
+
1679
+ template <typename type4x4>
1680
+ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
1681
+ half4x4 temp = *(((device half4x4 *)src));
1682
+ for (int i = 0; i < 16; i++){
1683
+ reg[i/4][i%4] = temp[i/4][i%4];
1684
  }
1685
+ }
1686
+
1687
+ template <typename type4x4>
1688
+ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
1689
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
1690
+ const half d = il ? (xb->d / 16.h) : xb->d;
1691
+ const half m = il ? ( -8.h * 16.h) : -8.h;
1692
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1693
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1694
+
1695
+ for (int i=0;i<8;i++) {
1696
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
1697
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
1698
  }
1699
+ }
1700
+
1701
+ template <typename type4x4>
1702
+ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
1703
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
1704
+ const half d = il ? (xb->d / 16.h) : xb->d;
1705
+ const half m = xb->m;
1706
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
1707
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
1708
+
1709
+ for (int i=0;i<8;i++) {
1710
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
1711
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
1712
+ }
1713
+ }
1714
+
1715
+ template <typename type4x4>
1716
+ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
1717
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
1718
+ const half d = xb->d;
1719
+
1720
+ for (int i=0;i<16;i++) {
1721
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
1722
+ }
1723
+ }
1724
+
1725
+ template <typename type4x4>
1726
+ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
1727
+ const half d = xb->d;
1728
+ const half min = xb->dmin;
1729
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1730
+ half dl, ml;
1731
+ uint8_t sc = xb->scales[il];
1732
+
1733
+ #if QK_K == 256
1734
+ q = q + 32*(il/8) + 16*(il&1);
1735
+ il = (il/2)%4;
1736
+ #endif
1737
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1738
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1739
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
1740
+ for (int i = 0; i < 16; ++i) {
1741
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1742
+ }
1743
+ }
1744
+
1745
+ template <typename type4x4>
1746
+ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
1747
+ const float d_all = (float)(xb->d);
1748
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
1749
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
1750
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1751
+
1752
+ #if QK_K == 256
1753
+ q = q + 32 * (il/8) + 16 * (il&1);
1754
+ h = h + 16 * (il&1);
1755
+ uint8_t m = 1 << (il/2);
1756
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
1757
+ ((il/4)>0 ? 12 : 3);
1758
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
1759
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
1760
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
1761
+ (scale_2&kmask2) | ((scale_1&kmask1) << 4);
1762
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
1763
+
1764
+ il = (il/2)%4;
1765
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1766
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1767
+
1768
+ for (int i = 0; i < 16; ++i) {
1769
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
1770
+ }
1771
+ #else
1772
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
1773
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
1774
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
1775
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
1776
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1777
+ uint8_t m = 1<<(il*2);
1778
+ for (int i = 0; i < 16; ++i) {
1779
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
1780
+ }
1781
+ #endif
1782
+ }
1783
+
1784
+ template <typename type4x4>
1785
+ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
1786
+ device const uint8_t * q = xb->qs;
1787
+
1788
+ #if QK_K == 256
1789
+ const float d = (float)(xb->d);
1790
+ const float min = (float)(xb->dmin);
1791
+ short is = (il/4) * 2;
1792
+ q = q + (il/4) * 32 + 16 * (il&1);
1793
+ il = il%4;
1794
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1795
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1796
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1797
+ #else
1798
+ q = q + 16 * (il&1);
1799
+ device const uint8_t * s = xb->scales;
1800
+ device const half2 * dh = (device const half2 *)xb->d;
1801
+ const float2 d = (float2)dh[0];
1802
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
1803
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
1804
+ #endif
1805
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1806
+ for (int i = 0; i < 16; ++i) {
1807
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
1808
+ }
1809
+ }
1810
+
1811
+ template <typename type4x4>
1812
+ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
1813
+ device const uint8_t * q = xb->qs;
1814
+ device const uint8_t * qh = xb->qh;
1815
+
1816
+ #if QK_K == 256
1817
+ const float d = (float)(xb->d);
1818
+ const float min = (float)(xb->dmin);
1819
+ short is = (il/4) * 2;
1820
+ q = q + 32 * (il/4) + 16 * (il&1);
1821
+ qh = qh + 16 * (il&1);
1822
+ uint8_t ul = 1 << (il/2);
1823
+ il = il%4;
1824
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
1825
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
1826
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
1827
+
1828
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1829
+ const float qh_val = il<2 ? 16.f : 256.f;
1830
+ for (int i = 0; i < 16; ++i) {
1831
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
1832
+ }
1833
+ #else
1834
+ q = q + 16 * (il&1);
1835
+ device const int8_t * s = xb->scales;
1836
+ const float dl = xb->d * s[il];
1837
+ uint8_t m = 1<<(il*2);
1838
+ const float coef = il<2 ? 1.f : 1.f/16.f;
1839
+ const ushort mask = il<2 ? 0x0F : 0xF0;
1840
+ for (int i = 0; i < 16; ++i) {
1841
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
1842
+ }
1843
+ #endif
1844
+ }
1845
+
1846
+ template <typename type4x4>
1847
+ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
1848
+ const float d_all = (float)(xb->d);
1849
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
1850
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
1851
+ device const int8_t * scales = (device const int8_t *)xb->scales;
1852
+
1853
+ #if QK_K == 256
1854
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
1855
+ qh = qh + 32*(il/8) + 16*(il&1);
1856
+ float sc = scales[(il%2) + 2 * ((il/2))];
1857
+ il = (il/2)%4;
1858
+ #else
1859
+ ql = ql + 16 * (il&1);
1860
+ float sc = scales[il];
1861
+ #endif
1862
+ for (int i = 0; i < 16; ++i) {
1863
+ uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
1864
+ uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
1865
+ const float coef = il>1 ? 1.f/16.f : 1.f;
1866
+ float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
1867
+ ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
1868
+ reg[i/4][i%4] = d_all * sc * q * coef;
1869
  }
1870
+ }
1871
+
1872
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
1873
+ kernel void kernel_get_rows(
1874
+ device const void * src0,
1875
+ device const int * src1,
1876
+ device float * dst,
1877
+ constant int64_t & ne00,
1878
+ constant uint64_t & nb01,
1879
+ constant uint64_t & nb1,
1880
+ uint tgpig[[threadgroup_position_in_grid]],
1881
+ uint tiitg[[thread_index_in_threadgroup]],
1882
+ uint tptg[[threads_per_threadgroup]]) {
1883
+ const int i = tgpig;
1884
+ const int r = ((device int32_t *) src1)[i];
1885
 
1886
+ for (int ind = tiitg; ind < ne00/16; ind += tptg) {
1887
+ float4x4 temp;
1888
+ dequantize_func(
1889
+ ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
1890
+ *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
1891
+ }
1892
  }
1893
+
1894
+ #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
1895
+ #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
1896
+ #define BLOCK_SIZE_K 32
1897
+ #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
1898
+ #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
1899
+ #define THREAD_PER_BLOCK 128
1900
+ #define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
1901
+ #define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
1902
+ #define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
1903
+ #define SG_MAT_ROW 8
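With these constants the kernel below produces one BLOCK_SIZE_M x BLOCK_SIZE_N = 64x32 output tile per threadgroup: THREAD_PER_BLOCK = 128 threads form four simdgroups of 32 threads, each simdgroup accumulates THREAD_MAT_M x THREAD_MAT_N = 4x2 simdgroup matrices of 8x8 (a 32x16 sub-tile), and the 2x2 arrangement of simdgroups selected by (sgitg & 1, sgitg >> 1) covers the full tile.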
1904
+
1905
+ // each block_q contains 16*nl weights
1906
+ template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
1907
+ kernel void kernel_mul_mm(device const uchar * src0,
1908
+ device const float * src1,
1909
+ device float * dst,
1910
+ constant int64_t & ne00,
1911
+ constant int64_t & ne02,
1912
+ constant int64_t & nb01,
1913
+ constant int64_t & nb02,
1914
+ constant int64_t & ne12,
1915
+ constant int64_t & ne0,
1916
+ constant int64_t & ne1,
1917
+ constant uint & gqa,
1918
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
1919
+ uint3 tgpig[[threadgroup_position_in_grid]],
1920
+ uint tiitg[[thread_index_in_threadgroup]],
1921
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1922
+
1923
+ threadgroup half * sa = ((threadgroup half *)shared_memory);
1924
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
1925
+
1926
+ const uint r0 = tgpig.y;
1927
+ const uint r1 = tgpig.x;
1928
+ const uint im = tgpig.z;
1929
+ // if this block is of 64x32 shape or smaller
1930
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
1931
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
1932
+ // a thread shouldn't load data outside of the matrix
1933
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
1934
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
1935
+
1936
+ simdgroup_half8x8 ma[4];
1937
+ simdgroup_float8x8 mb[2];
1938
+ simdgroup_float8x8 c_res[8];
1939
+ for (int i = 0; i < 8; i++){
1940
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
1941
+ }
1942
+
1943
+ short il = (tiitg % THREAD_PER_ROW);
1944
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
1945
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
1946
+ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1947
+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
1948
+
1949
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
1950
+ //load data and store to threadgroup memory
1951
+ half4x4 temp_a;
1952
+ dequantize_func(x, il, temp_a);
1953
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1954
+ #pragma unroll(16)
1955
+ for (int i = 0; i < 16; i++) {
1956
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
1957
+ + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
1958
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
1959
+ }
1960
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
1961
+ = *((device float2x4 *)y);
1962
+ il = (il + 2 < nl) ? il + 2 : il % 2;
1963
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
1964
+ y += BLOCK_SIZE_K;
1965
+
1966
+ threadgroup_barrier(mem_flags::mem_threadgroup);
1967
+ //load matrices from threadgroup memory and conduct outer products
1968
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
1969
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
1970
+ #pragma unroll(4)
1971
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
1972
+ #pragma unroll(4)
1973
+ for (int i = 0; i < 4; i++) {
1974
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
1975
+ }
1976
+ simdgroup_barrier(mem_flags::mem_none);
1977
+ #pragma unroll(2)
1978
+ for (int i = 0; i < 2; i++) {
1979
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
1980
+ }
1981
+
1982
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
1983
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
1984
+ #pragma unroll(8)
1985
+ for (int i = 0; i < 8; i++){
1986
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
1987
+ }
1988
+ }
1989
+ }
1990
+
1991
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
1992
+ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
1993
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
1994
+ for (int i = 0; i < 8; i++) {
1995
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
1996
+ }
1997
+ } else {
1998
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
1999
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2000
+ threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
2001
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
2002
+ for (int i = 0; i < 8; i++) {
2003
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
2004
+ }
2005
+
2006
+ threadgroup_barrier(mem_flags::mem_threadgroup);
2007
+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
2008
+ if (sgitg==0) {
2009
+ for (int i = 0; i < n_rows; i++) {
2010
+ for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
2011
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
2012
+ }
2013
+ }
2014
+ }
2015
+ }
2016
+ }
2017
+
2018
+ #if QK_K == 256
2019
+ #define QK_NL 16
2020
+ #else
2021
+ #define QK_NL 4
2022
+ #endif
2023
+
2024
+ typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
2025
+ constant uint64_t &, constant uint64_t &, uint, uint, uint);
2026
+
2027
+ template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
2028
+ template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
2029
+ template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
2030
+ template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
2031
+ template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
2032
+ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
2033
+ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
2034
+ template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
2035
+ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
2036
+
2037
+ typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
2038
+ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
2039
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
2040
+
2041
+ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
2042
+ template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
2043
+ template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
2044
+ template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
2045
+ template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
2046
+ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
2047
+ template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
2048
+ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
2049
+ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
ggml-opencl.cpp CHANGED
@@ -656,10 +656,14 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
656
  \n#if K_QUANTS_PER_ITERATION == 1\n
657
  const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
658
  const int is = 0;
 
659
  \n#else\n
 
660
  const int l0 = 4 * in; // 0, 4, 8, ..., 28
661
  const int is = in / 4;
 
662
  \n#endif\n
 
663
  const int ql_offset = 64*im + l0;
664
  const int qh_offset = 32*im + l0;
665
  const int s_offset = 8*im + is;
@@ -1376,7 +1380,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
1376
  const int64_t ne00 = src0->ne[0];
1377
  const int64_t ne01 = src0->ne[1];
1378
  const int64_t ne02 = src0->ne[2];
1379
- const int64_t ne03 = src0->ne[2];
1380
  const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
1381
  const int64_t ne10 = src1->ne[0];
1382
  const int64_t ne11 = src1->ne[1];
 
656
  \n#if K_QUANTS_PER_ITERATION == 1\n
657
  const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
658
  const int is = 0;
659
+
660
  \n#else\n
661
+
662
  const int l0 = 4 * in; // 0, 4, 8, ..., 28
663
  const int is = in / 4;
664
+
665
  \n#endif\n
666
+
667
  const int ql_offset = 64*im + l0;
668
  const int qh_offset = 32*im + l0;
669
  const int s_offset = 8*im + is;
 
1380
  const int64_t ne00 = src0->ne[0];
1381
  const int64_t ne01 = src0->ne[1];
1382
  const int64_t ne02 = src0->ne[2];
1383
+ const int64_t ne03 = src0->ne[3];
1384
  const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
1385
  const int64_t ne10 = src1->ne[0];
1386
  const int64_t ne11 = src1->ne[1];
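The substantive change in this hunk is the ne03 initialization: the old code copied src0->ne[2] into ne03, so the total element count was computed as ne00*ne01*ne02*ne02 and was wrong whenever the third and fourth extents differ; the fix reads the fourth extent:

    const int64_t ne03 = src0->ne[3];               // was src0->ne[2]
    const int64_t ne0  = ne00 * ne01 * ne02 * ne03; // total number of elements in src0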
ggml.c CHANGED
The diff for this file is too large to render. See raw diff
 
ggml.h CHANGED
@@ -65,7 +65,7 @@
65
  // ggml_set_f32(a, 3.0f);
66
  // ggml_set_f32(b, 4.0f);
67
  //
68
- // ggml_graph_compute(ctx0, &gf);
69
  //
70
  // printf("f = %f\n", ggml_get_f32_1d(f, 0));
71
  //
@@ -130,13 +130,16 @@
130
  // The data of the tensor is accessed via the "data" pointer. For example:
131
  //
132
  // {
133
- // struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
 
134
  //
135
- // // a[1, 2] = 1.0f;
136
- // *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
137
  //
138
- // // a[2, 0] = 2.0f;
139
- // *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
 
 
 
140
  //
141
  // ...
142
  // }
@@ -183,6 +186,15 @@
183
  # define GGML_API
184
  #endif
185
 
 
 
 
 
 
 
 
 
 
186
  #include <stdint.h>
187
  #include <stddef.h>
188
  #include <stdbool.h>
@@ -197,12 +209,29 @@
197
  #define GGML_MAX_NODES 4096
198
  #define GGML_MAX_PARAMS 256
199
  #define GGML_MAX_CONTEXTS 64
200
- #define GGML_MAX_OPT 4
201
- #define GGML_MAX_NAME 48
 
202
  #define GGML_DEFAULT_N_THREADS 4
203
 
 
 
 
204
  #define GGML_UNUSED(x) (void)(x)
205
 
 
 
206
  #define GGML_ASSERT(x) \
207
  do { \
208
  if (!(x)) { \
@@ -239,8 +268,9 @@
239
  extern "C" {
240
  #endif
241
 
242
- #ifdef __ARM_NEON
243
- // we use the built-in 16-bit float type
 
244
  typedef __fp16 ggml_fp16_t;
245
  #else
246
  typedef uint16_t ggml_fp16_t;
@@ -250,8 +280,8 @@ extern "C" {
250
  GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
251
  GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
252
 
253
- GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
254
- GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
255
 
256
  struct ggml_object;
257
  struct ggml_context;
@@ -324,20 +354,12 @@ extern "C" {
324
  GGML_OP_ARGMAX,
325
  GGML_OP_REPEAT,
326
  GGML_OP_REPEAT_BACK,
327
- GGML_OP_ABS,
328
- GGML_OP_SGN,
329
- GGML_OP_NEG,
330
- GGML_OP_STEP,
331
- GGML_OP_TANH,
332
- GGML_OP_ELU,
333
- GGML_OP_RELU,
334
- GGML_OP_GELU,
335
- GGML_OP_GELU_QUICK,
336
- GGML_OP_SILU,
337
  GGML_OP_SILU_BACK,
338
  GGML_OP_NORM, // normalize
339
  GGML_OP_RMS_NORM,
340
  GGML_OP_RMS_NORM_BACK,
 
341
 
342
  GGML_OP_MUL_MAT,
343
  GGML_OP_OUT_PROD,
@@ -363,16 +385,29 @@ extern "C" {
363
  GGML_OP_CLAMP,
364
  GGML_OP_CONV_1D,
365
  GGML_OP_CONV_2D,
 
 
 
 
 
366
 
367
  GGML_OP_FLASH_ATTN,
368
  GGML_OP_FLASH_FF,
369
  GGML_OP_FLASH_ATTN_BACK,
370
  GGML_OP_WIN_PART,
371
  GGML_OP_WIN_UNPART,
 
 
 
 
372
 
373
  GGML_OP_MAP_UNARY,
374
  GGML_OP_MAP_BINARY,
375
 
 
 
 
 
376
  GGML_OP_MAP_CUSTOM1,
377
  GGML_OP_MAP_CUSTOM2,
378
  GGML_OP_MAP_CUSTOM3,
@@ -383,6 +418,24 @@ extern "C" {
383
  GGML_OP_COUNT,
384
  };
385
 
 
 
 
 
 
386
 
387
  // ggml object
388
  struct ggml_object {
@@ -391,7 +444,9 @@ extern "C" {
391
 
392
  struct ggml_object * next;
393
 
394
- char padding[8];
 
 
395
  };
396
 
397
  static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
@@ -411,15 +466,13 @@ extern "C" {
411
  // compute data
412
  enum ggml_op op;
413
 
 
 
 
414
  bool is_param;
415
 
416
  struct ggml_tensor * grad;
417
- struct ggml_tensor * src0;
418
- struct ggml_tensor * src1;
419
- struct ggml_tensor * opt[GGML_MAX_OPT];
420
-
421
- // thread scheduling
422
- int n_tasks;
423
 
424
  // performance
425
  int perf_runs;
@@ -437,25 +490,46 @@ extern "C" {
437
 
438
  static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
439
 
 
 
 
 
440
  // computation graph
441
  struct ggml_cgraph {
442
  int n_nodes;
443
  int n_leafs;
444
- int n_threads;
445
-
446
- size_t work_size;
447
- struct ggml_tensor * work;
448
 
449
  struct ggml_tensor * nodes[GGML_MAX_NODES];
450
  struct ggml_tensor * grads[GGML_MAX_NODES];
451
  struct ggml_tensor * leafs[GGML_MAX_NODES];
452
 
 
 
453
  // performance
454
  int perf_runs;
455
  int64_t perf_cycles;
456
  int64_t perf_time_us;
457
  };
458
 
 
 
459
  // scratch buffer
460
  struct ggml_scratch {
461
  size_t offs;
@@ -509,6 +583,7 @@ extern "C" {
509
  GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
510
  GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
511
  GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
 
512
  GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
513
 
514
  GGML_API int ggml_blck_size (enum ggml_type type);
@@ -517,6 +592,7 @@ extern "C" {
517
 
518
  GGML_API const char * ggml_type_name(enum ggml_type type);
519
  GGML_API const char * ggml_op_name (enum ggml_op op);
 
520
 
521
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
522
 
@@ -529,6 +605,8 @@ extern "C" {
529
  GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
530
  GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
531
 
 
 
532
  // use this to compute the memory overhead of a tensor
533
  GGML_API size_t ggml_tensor_overhead(void);
534
 
@@ -540,6 +618,7 @@ extern "C" {
540
  GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
541
 
542
  GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
 
543
  GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
544
 
545
  GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
@@ -599,9 +678,11 @@ extern "C" {
599
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
600
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
601
 
602
- GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
603
- GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
604
- GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...);
 
 
605
 
606
  //
607
  // operations on tensors with backpropagation
@@ -611,6 +692,11 @@ extern "C" {
611
  struct ggml_context * ctx,
612
  struct ggml_tensor * a);
613
 
 
 
 
 
 
614
  GGML_API struct ggml_tensor * ggml_add(
615
  struct ggml_context * ctx,
616
  struct ggml_tensor * a,
@@ -735,6 +821,13 @@ extern "C" {
735
  struct ggml_tensor * a,
736
  struct ggml_tensor * b);
737
 
 
 
 
 
 
 
 
738
  GGML_API struct ggml_tensor * ggml_abs(
739
  struct ggml_context * ctx,
740
  struct ggml_tensor * a);
@@ -824,25 +917,42 @@ extern "C" {
824
  struct ggml_tensor * b);
825
 
826
  // normalize along rows
827
- // TODO: eps is hardcoded to 1e-5 for now
828
  GGML_API struct ggml_tensor * ggml_norm(
829
  struct ggml_context * ctx,
830
- struct ggml_tensor * a);
 
831
 
832
  GGML_API struct ggml_tensor * ggml_norm_inplace(
833
  struct ggml_context * ctx,
834
- struct ggml_tensor * a);
 
835
 
836
  GGML_API struct ggml_tensor * ggml_rms_norm(
837
  struct ggml_context * ctx,
838
- struct ggml_tensor * a);
 
839
 
840
  GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
841
  struct ggml_context * ctx,
842
- struct ggml_tensor * a);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
843
 
844
  // a - x
845
  // b - dy
 
846
  GGML_API struct ggml_tensor * ggml_rms_norm_back(
847
  struct ggml_context * ctx,
848
  struct ggml_tensor * a,
@@ -934,11 +1044,22 @@ extern "C" {
934
  struct ggml_tensor * a,
935
  struct ggml_tensor * b);
936
 
 
 
 
 
 
 
937
  // make contiguous
938
  GGML_API struct ggml_tensor * ggml_cont(
939
  struct ggml_context * ctx,
940
  struct ggml_tensor * a);
941
 
 
 
 
 
 
942
  // return view(a), b specifies the new shape
943
  // TODO: when we start computing gradient, make a copy instead of view
944
  GGML_API struct ggml_tensor * ggml_reshape(
@@ -1107,6 +1228,37 @@ extern "C" {
1107
  int mode,
1108
  int n_ctx);
1109
 
 
 
 
 
1110
  // rotary position embedding backward, i.e compute dx from dy
1111
  // a - dy
1112
  GGML_API struct ggml_tensor * ggml_rope_back(
@@ -1114,7 +1266,12 @@ extern "C" {
1114
  struct ggml_tensor * a,
1115
  int n_past,
1116
  int n_dims,
1117
- int mode);
 
 
 
 
 
1118
 
1119
  // alibi position embedding
1120
  // in-place, returns view(a)
@@ -1141,6 +1298,15 @@ extern "C" {
1141
  int p0, // padding
1142
  int d0); // dilation
1143
 
 
 
 
 
 
 
 
 
 
1144
  GGML_API struct ggml_tensor * ggml_conv_2d(
1145
  struct ggml_context * ctx,
1146
  struct ggml_tensor * a,
@@ -1152,14 +1318,70 @@ extern "C" {
1152
  int d0,
1153
  int d1);
1154
 
1155
- // conv_1d with padding = half
1156
- // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1157
- GGML_API struct ggml_tensor* ggml_conv_1d_ph(
 
 
 
 
1158
  struct ggml_context * ctx,
1159
  struct ggml_tensor * a,
1160
  struct ggml_tensor * b,
1161
- int s,
1162
- int d);
 
 
 
 
 
 
1163
 
1164
  GGML_API struct ggml_tensor * ggml_flash_attn(
1165
  struct ggml_context * ctx,
@@ -1204,6 +1426,37 @@ extern "C" {
1204
  int h0,
1205
  int w);
1206
 
 
 
 
 
1207
  // custom operators
1208
 
1209
  typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1213,63 +1466,129 @@ extern "C" {
1213
  typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1214
  typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1215
 
1216
- GGML_API struct ggml_tensor * ggml_map_unary_f32(
1217
  struct ggml_context * ctx,
1218
  struct ggml_tensor * a,
1219
- ggml_unary_op_f32_t fun);
 
1220
 
1221
- GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
1222
  struct ggml_context * ctx,
1223
  struct ggml_tensor * a,
1224
- ggml_unary_op_f32_t fun);
 
1225
 
1226
- GGML_API struct ggml_tensor * ggml_map_binary_f32(
1227
  struct ggml_context * ctx,
1228
  struct ggml_tensor * a,
1229
  struct ggml_tensor * b,
1230
- ggml_binary_op_f32_t fun);
 
1231
 
1232
- GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
1233
  struct ggml_context * ctx,
1234
  struct ggml_tensor * a,
1235
  struct ggml_tensor * b,
1236
- ggml_binary_op_f32_t fun);
 
1237
 
1238
- GGML_API struct ggml_tensor * ggml_map_custom1_f32(
1239
  struct ggml_context * ctx,
1240
  struct ggml_tensor * a,
1241
- ggml_custom1_op_f32_t fun);
 
1242
 
1243
- GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
1244
  struct ggml_context * ctx,
1245
  struct ggml_tensor * a,
1246
- ggml_custom1_op_f32_t fun);
 
1247
 
1248
- GGML_API struct ggml_tensor * ggml_map_custom2_f32(
1249
  struct ggml_context * ctx,
1250
  struct ggml_tensor * a,
1251
  struct ggml_tensor * b,
1252
- ggml_custom2_op_f32_t fun);
 
1253
 
1254
- GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
1255
  struct ggml_context * ctx,
1256
  struct ggml_tensor * a,
1257
  struct ggml_tensor * b,
1258
- ggml_custom2_op_f32_t fun);
 
1259
 
1260
- GGML_API struct ggml_tensor * ggml_map_custom3_f32(
1261
  struct ggml_context * ctx,
1262
  struct ggml_tensor * a,
1263
  struct ggml_tensor * b,
1264
  struct ggml_tensor * c,
1265
- ggml_custom3_op_f32_t fun);
 
1266
 
1267
- GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
1268
  struct ggml_context * ctx,
1269
  struct ggml_tensor * a,
1270
  struct ggml_tensor * b,
1271
  struct ggml_tensor * c,
1272
- ggml_custom3_op_f32_t fun);
 
 
 
 
 
 
1273
 
1274
  // loss function
1275
 
@@ -1290,15 +1609,28 @@ extern "C" {
1290
 
1291
  GGML_API void ggml_set_param(
1292
  struct ggml_context * ctx,
1293
- struct ggml_tensor * tensor);
 
1294
 
1295
  GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1296
 
1297
  GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
1298
  GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
1299
 
1300
- GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
1301
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
 
 
 
 
 
 
 
 
 
 
 
 
1302
 
1303
  GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
1304
 
@@ -1488,6 +1820,127 @@ extern "C" {
1488
 
1489
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1490
 
 
 
 
 
 
1491
  //
1492
  // system info
1493
  //
@@ -1516,25 +1969,28 @@ extern "C" {
1516
  //
1517
 
1518
  #ifdef __cplusplus
1519
- // restrict not standard in C++
1520
  #define GGML_RESTRICT
1521
  #else
1522
  #define GGML_RESTRICT restrict
1523
  #endif
1524
- typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
1525
- typedef void (*quantize_row_q_t) (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
1526
- typedef void (*vec_dot_q_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
1527
 
1528
  typedef struct {
1529
- dequantize_row_q_t dequantize_row_q;
1530
- quantize_row_q_t quantize_row_q;
1531
- quantize_row_q_t quantize_row_q_reference;
1532
- quantize_row_q_t quantize_row_q_dot;
1533
- vec_dot_q_t vec_dot_q;
1534
- enum ggml_type vec_dot_type;
1535
- } quantize_fns_t;
1536
-
1537
- quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
 
 
 
1538
 
1539
  #ifdef __cplusplus
1540
  }
 
65
  // ggml_set_f32(a, 3.0f);
66
  // ggml_set_f32(b, 4.0f);
67
  //
68
+ // ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
69
  //
70
  // printf("f = %f\n", ggml_get_f32_1d(f, 0));
71
  //
 
130
  // The data of the tensor is accessed via the "data" pointer. For example:
131
  //
132
  // {
133
+ // const int nx = 2;
134
+ // const int ny = 3;
135
  //
136
+ // struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
 
137
  //
138
+ // for (int y = 0; y < ny; y++) {
139
+ // for (int x = 0; x < nx; x++) {
140
+ // *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
141
+ // }
142
+ // }
143
  //
144
  // ...
145
  // }
 
186
  # define GGML_API
187
  #endif
188
 
189
+ // TODO: support for clang
190
+ #ifdef __GNUC__
191
+ # define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
192
+ #elif defined(_MSC_VER)
193
+ # define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
194
+ #else
195
+ # define GGML_DEPRECATED(func, hint) func
196
+ #endif
197
+
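For illustration, a minimal sketch of how this macro is intended to be used; my_old_fn and my_new_fn are hypothetical names, not part of ggml:

    // hypothetical declaration wrapped with the macro:
    GGML_DEPRECATED(GGML_API int my_old_fn(int x), "use my_new_fn instead");
    //
    // with GCC this expands to:
    //   GGML_API int my_old_fn(int x) __attribute__((deprecated("use my_new_fn instead")));
    // with MSVC to:
    //   __declspec(deprecated("use my_new_fn instead")) GGML_API int my_old_fn(int x);
    // otherwise the hint is dropped and the declaration is left unchanged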
198
  #include <stdint.h>
199
  #include <stddef.h>
200
  #include <stdbool.h>
 
209
  #define GGML_MAX_NODES 4096
210
  #define GGML_MAX_PARAMS 256
211
  #define GGML_MAX_CONTEXTS 64
212
+ #define GGML_MAX_SRC 6
213
+ #define GGML_MAX_NAME 64
214
+ #define GGML_MAX_OP_PARAMS 32
215
  #define GGML_DEFAULT_N_THREADS 4
216
 
217
+ #if UINTPTR_MAX == 0xFFFFFFFF
218
+ #define GGML_MEM_ALIGN 4
219
+ #else
220
+ #define GGML_MEM_ALIGN 16
221
+ #endif
222
+
223
+ #define GGML_EXIT_SUCCESS 0
224
+ #define GGML_EXIT_ABORTED 1
225
+
226
+ #define GGUF_MAGIC 0x46554747 // "GGUF"
227
+ #define GGUF_VERSION 2
228
+
229
+ #define GGUF_DEFAULT_ALIGNMENT 32
230
+
231
  #define GGML_UNUSED(x) (void)(x)
232
 
233
+ #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
234
+
235
  #define GGML_ASSERT(x) \
236
  do { \
237
  if (!(x)) { \
 
268
  extern "C" {
269
  #endif
270
 
271
+ #if defined(__ARM_NEON) && defined(__CUDACC__)
272
+ typedef half ggml_fp16_t;
273
+ #elif defined(__ARM_NEON)
274
  typedef __fp16 ggml_fp16_t;
275
  #else
276
  typedef uint16_t ggml_fp16_t;
 
280
  GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
281
  GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
282
 
283
+ GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
284
+ GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
285
 
286
  struct ggml_object;
287
  struct ggml_context;
 
354
  GGML_OP_ARGMAX,
355
  GGML_OP_REPEAT,
356
  GGML_OP_REPEAT_BACK,
357
+ GGML_OP_CONCAT,
 
 
 
 
 
 
 
 
 
358
  GGML_OP_SILU_BACK,
359
  GGML_OP_NORM, // normalize
360
  GGML_OP_RMS_NORM,
361
  GGML_OP_RMS_NORM_BACK,
362
+ GGML_OP_GROUP_NORM,
363
 
364
  GGML_OP_MUL_MAT,
365
  GGML_OP_OUT_PROD,
 
385
  GGML_OP_CLAMP,
386
  GGML_OP_CONV_1D,
387
  GGML_OP_CONV_2D,
388
+ GGML_OP_CONV_TRANSPOSE_2D,
389
+ GGML_OP_POOL_1D,
390
+ GGML_OP_POOL_2D,
391
+
392
+ GGML_OP_UPSCALE, // nearest interpolate
393
 
394
  GGML_OP_FLASH_ATTN,
395
  GGML_OP_FLASH_FF,
396
  GGML_OP_FLASH_ATTN_BACK,
397
  GGML_OP_WIN_PART,
398
  GGML_OP_WIN_UNPART,
399
+ GGML_OP_GET_REL_POS,
400
+ GGML_OP_ADD_REL_POS,
401
+
402
+ GGML_OP_UNARY,
403
 
404
  GGML_OP_MAP_UNARY,
405
  GGML_OP_MAP_BINARY,
406
 
407
+ GGML_OP_MAP_CUSTOM1_F32,
408
+ GGML_OP_MAP_CUSTOM2_F32,
409
+ GGML_OP_MAP_CUSTOM3_F32,
410
+
411
  GGML_OP_MAP_CUSTOM1,
412
  GGML_OP_MAP_CUSTOM2,
413
  GGML_OP_MAP_CUSTOM3,
 
418
  GGML_OP_COUNT,
419
  };
420
 
421
+ enum ggml_unary_op {
422
+ GGML_UNARY_OP_ABS,
423
+ GGML_UNARY_OP_SGN,
424
+ GGML_UNARY_OP_NEG,
425
+ GGML_UNARY_OP_STEP,
426
+ GGML_UNARY_OP_TANH,
427
+ GGML_UNARY_OP_ELU,
428
+ GGML_UNARY_OP_RELU,
429
+ GGML_UNARY_OP_GELU,
430
+ GGML_UNARY_OP_GELU_QUICK,
431
+ GGML_UNARY_OP_SILU,
432
+ };
433
+
434
+ enum ggml_object_type {
435
+ GGML_OBJECT_TENSOR,
436
+ GGML_OBJECT_GRAPH,
437
+ GGML_OBJECT_WORK_BUFFER
438
+ };
439
 
440
  // ggml object
441
  struct ggml_object {
 
444
 
445
  struct ggml_object * next;
446
 
447
+ enum ggml_object_type type;
448
+
449
+ char padding[4];
450
  };
451
 
452
  static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
 
466
  // compute data
467
  enum ggml_op op;
468
 
469
+ // op params - allocated as int32_t for alignment
470
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
471
+
472
  bool is_param;
473
 
474
  struct ggml_tensor * grad;
475
+ struct ggml_tensor * src[GGML_MAX_SRC];
 
 
 
 
 
476
 
477
  // performance
478
  int perf_runs;
 
490
 
491
  static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
492
 
493
+ // the compute plan that needs to be prepared for ggml_graph_compute()
494
+ // since https://github.com/ggerganov/ggml/issues/287
495
+ struct ggml_cplan {
496
+ size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
497
+ uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
498
+
499
+ int n_threads;
500
+
501
+ // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
502
+ int n_tasks[GGML_MAX_NODES];
503
+
504
+ // abort ggml_graph_compute when true
505
+ bool (*abort_callback)(void * data);
506
+ void * abort_callback_data;
507
+ };
508
+
509
+ // next prime after GGML_MAX_NODES
510
+ // #define GGML_GRAPH_HASHTABLE_SIZE 4099
511
+ // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
512
+ #define GGML_GRAPH_HASHTABLE_SIZE 8273
513
+
514
  // computation graph
515
  struct ggml_cgraph {
516
  int n_nodes;
517
  int n_leafs;
 
 
 
 
518
 
519
  struct ggml_tensor * nodes[GGML_MAX_NODES];
520
  struct ggml_tensor * grads[GGML_MAX_NODES];
521
  struct ggml_tensor * leafs[GGML_MAX_NODES];
522
 
523
+ void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
524
+
525
  // performance
526
  int perf_runs;
527
  int64_t perf_cycles;
528
  int64_t perf_time_us;
529
  };
530
 
531
+ static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
532
+
533
  // scratch buffer
534
  struct ggml_scratch {
535
  size_t offs;
 
583
  GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
584
  GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
585
  GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
586
+ GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
587
  GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
588
 
589
  GGML_API int ggml_blck_size (enum ggml_type type);
 
592
 
593
  GGML_API const char * ggml_type_name(enum ggml_type type);
594
  GGML_API const char * ggml_op_name (enum ggml_op op);
595
+ GGML_API const char * ggml_op_symbol(enum ggml_op op);
596
 
597
  GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
598
 
 
605
  GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
606
  GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
607
 
608
+ GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
609
+
610
  // use this to compute the memory overhead of a tensor
611
  GGML_API size_t ggml_tensor_overhead(void);
612
 
 
618
  GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
619
 
620
  GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
621
+ GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
622
  GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
623
 
624
  GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
 
678
  GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
679
  GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
680
 
681
+ GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
682
+
683
+ GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
684
+ GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
685
+ GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
686
 
687
  //
688
  // operations on tensors with backpropagation
 
692
  struct ggml_context * ctx,
693
  struct ggml_tensor * a);
694
 
695
+ // in-place, returns view(a)
696
+ GGML_API struct ggml_tensor * ggml_dup_inplace(
697
+ struct ggml_context * ctx,
698
+ struct ggml_tensor * a);
699
+
700
  GGML_API struct ggml_tensor * ggml_add(
701
  struct ggml_context * ctx,
702
  struct ggml_tensor * a,
 
821
  struct ggml_tensor * a,
822
  struct ggml_tensor * b);
823
 
824
+ // concat a and b on dim 2
825
+ // used in stable-diffusion
826
+ GGML_API struct ggml_tensor * ggml_concat(
827
+ struct ggml_context * ctx,
828
+ struct ggml_tensor * a,
829
+ struct ggml_tensor * b);
830
+
831
  GGML_API struct ggml_tensor * ggml_abs(
832
  struct ggml_context * ctx,
833
  struct ggml_tensor * a);
 
917
  struct ggml_tensor * b);
918
 
919
  // normalize along rows
 
920
  GGML_API struct ggml_tensor * ggml_norm(
921
  struct ggml_context * ctx,
922
+ struct ggml_tensor * a,
923
+ float eps);
924
 
925
  GGML_API struct ggml_tensor * ggml_norm_inplace(
926
  struct ggml_context * ctx,
927
+ struct ggml_tensor * a,
928
+ float eps);
929
 
930
  GGML_API struct ggml_tensor * ggml_rms_norm(
931
  struct ggml_context * ctx,
932
+ struct ggml_tensor * a,
933
+ float eps);
934
 
935
  GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
936
  struct ggml_context * ctx,
937
+ struct ggml_tensor * a,
938
+ float eps);
939
+
940
+ // group normalize along ne0*ne1*n_groups
941
+ // used in stable-diffusion
942
+ // TODO: eps is hardcoded to 1e-6 for now
943
+ GGML_API struct ggml_tensor * ggml_group_norm(
944
+ struct ggml_context * ctx,
945
+ struct ggml_tensor * a,
946
+ int n_groups);
947
+
948
+ GGML_API struct ggml_tensor * ggml_group_norm_inplace(
949
+ struct ggml_context * ctx,
950
+ struct ggml_tensor * a,
951
+ int n_groups);
952
 
953
  // a - x
954
  // b - dy
955
+ // TODO: update with configurable eps
956
  GGML_API struct ggml_tensor * ggml_rms_norm_back(
957
  struct ggml_context * ctx,
958
  struct ggml_tensor * a,
 
1044
  struct ggml_tensor * a,
1045
  struct ggml_tensor * b);
1046
 
1047
+ // a -> b, in-place, return view(b)
1048
+ GGML_API struct ggml_tensor * ggml_cpy_inplace(
1049
+ struct ggml_context * ctx,
1050
+ struct ggml_tensor * a,
1051
+ struct ggml_tensor * b);
1052
+
1053
  // make contiguous
1054
  GGML_API struct ggml_tensor * ggml_cont(
1055
  struct ggml_context * ctx,
1056
  struct ggml_tensor * a);
1057
 
1058
+ // make contiguous, in-place
1059
+ GGML_API struct ggml_tensor * ggml_cont_inplace(
1060
+ struct ggml_context * ctx,
1061
+ struct ggml_tensor * a);
1062
+
1063
  // return view(a), b specifies the new shape
1064
  // TODO: when we start computing gradient, make a copy instead of view
1065
  GGML_API struct ggml_tensor * ggml_reshape(
 
1228
  int mode,
1229
  int n_ctx);
1230
 
1231
+ // custom RoPE
1232
+ GGML_API struct ggml_tensor * ggml_rope_custom(
1233
+ struct ggml_context * ctx,
1234
+ struct ggml_tensor * a,
1235
+ int n_past,
1236
+ int n_dims,
1237
+ int mode,
1238
+ int n_ctx,
1239
+ float freq_base,
1240
+ float freq_scale);
1241
+
1242
+ // in-place, returns view(a)
1243
+ GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
1244
+ struct ggml_context * ctx,
1245
+ struct ggml_tensor * a,
1246
+ int n_past,
1247
+ int n_dims,
1248
+ int mode,
1249
+ int n_ctx,
1250
+ float freq_base,
1251
+ float freq_scale);
1252
+
1253
+ // xPos RoPE, in-place, returns view(a)
1254
+ GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
1255
+ struct ggml_context * ctx,
1256
+ struct ggml_tensor * a,
1257
+ int n_past,
1258
+ int n_dims,
1259
+ float base,
1260
+ bool down);
1261
+
1262
  // rotary position embedding backward, i.e compute dx from dy
1263
  // a - dy
1264
  GGML_API struct ggml_tensor * ggml_rope_back(
 
1266
  struct ggml_tensor * a,
1267
  int n_past,
1268
  int n_dims,
1269
+ int mode,
1270
+ int n_ctx,
1271
+ float freq_base,
1272
+ float freq_scale,
1273
+ float xpos_base,
1274
+ bool xpos_down);
1275
 
1276
  // alibi position embedding
1277
  // in-place, returns view(a)
 
1298
  int p0, // padding
1299
  int d0); // dilation
1300
 
1301
+ // conv_1d with padding = half
1302
+ // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1303
+ GGML_API struct ggml_tensor* ggml_conv_1d_ph(
1304
+ struct ggml_context * ctx,
1305
+ struct ggml_tensor * a,
1306
+ struct ggml_tensor * b,
1307
+ int s,
1308
+ int d);
1309
+
1310
  GGML_API struct ggml_tensor * ggml_conv_2d(
1311
  struct ggml_context * ctx,
1312
  struct ggml_tensor * a,
 
1318
  int d0,
1319
  int d1);
1320
 
1321
+
1322
+ // kernel size is a->ne[0] x a->ne[1]
1323
+ // stride is equal to kernel size
1324
+ // padding is zero
1325
+ // example:
1326
+ // a: 16 16 3 768
1327
+ // b: 1024 1024 3 1
1328
+ // res: 64 64 768 1
1329
+ // used in sam
1330
+ GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
1331
+ struct ggml_context * ctx,
1332
+ struct ggml_tensor * a,
1333
+ struct ggml_tensor * b);
1334
+
1335
+ // kernel size is a->ne[0] x a->ne[1]
1336
+ // stride is 1
1337
+ // padding is half
1338
+ // example:
1339
+ // a: 3 3 256 256
1340
+ // b: 64 64 256 1
1341
+ // res: 64 64 256 1
1342
+ // used in sam
1343
+ GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
1344
+ struct ggml_context * ctx,
1345
+ struct ggml_tensor * a,
1346
+ struct ggml_tensor * b);
1347
+
1348
+ GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
1349
  struct ggml_context * ctx,
1350
  struct ggml_tensor * a,
1351
  struct ggml_tensor * b,
1352
+ int stride);
1353
+
1354
+ enum ggml_op_pool {
1355
+ GGML_OP_POOL_MAX,
1356
+ GGML_OP_POOL_AVG,
1357
+ GGML_OP_POOL_COUNT,
1358
+ };
1359
+
1360
+ GGML_API struct ggml_tensor * ggml_pool_1d(
1361
+ struct ggml_context * ctx,
1362
+ struct ggml_tensor * a,
1363
+ enum ggml_op_pool op,
1364
+ int k0, // kernel size
1365
+ int s0, // stride
1366
+ int p0); // padding
1367
+
1368
+ GGML_API struct ggml_tensor * ggml_pool_2d(
1369
+ struct ggml_context * ctx,
1370
+ struct ggml_tensor * a,
1371
+ enum ggml_op_pool op,
1372
+ int k0,
1373
+ int k1,
1374
+ int s0,
1375
+ int s1,
1376
+ int p0,
1377
+ int p1);
1378
+
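A minimal usage sketch for the new pooling ops, assuming an existing ggml_context * ctx and an F32 tensor x (both placeholder names):

    // 1D max pooling: kernel 2, stride 2, no padding
    struct ggml_tensor * p1 = ggml_pool_1d(ctx, x, GGML_OP_POOL_MAX, 2, 2, 0);
    // 2x2 average pooling with stride 2 and no padding
    struct ggml_tensor * p2 = ggml_pool_2d(ctx, x, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);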
1379
+ // nearest interpolate
1380
+ // used in stable-diffusion
1381
+ GGML_API struct ggml_tensor * ggml_upscale(
1382
+ struct ggml_context * ctx,
1383
+ struct ggml_tensor * a,
1384
+ int scale_factor);
1385
 
1386
  GGML_API struct ggml_tensor * ggml_flash_attn(
1387
  struct ggml_context * ctx,
 
1426
  int h0,
1427
  int w);
1428
 
1429
+ GGML_API struct ggml_tensor * ggml_unary(
1430
+ struct ggml_context * ctx,
1431
+ struct ggml_tensor * a,
1432
+ enum ggml_unary_op op);
1433
+
1434
+ GGML_API struct ggml_tensor * ggml_unary_inplace(
1435
+ struct ggml_context * ctx,
1436
+ struct ggml_tensor * a,
1437
+ enum ggml_unary_op op);
1438
+
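A minimal sketch of the generic unary entry point, assuming an existing ggml_context * ctx and an F32 tensor x (placeholder names):

    struct ggml_tensor * y = ggml_unary        (ctx, x, GGML_UNARY_OP_RELU); // returns a new tensor
    struct ggml_tensor * z = ggml_unary_inplace(ctx, y, GGML_UNARY_OP_GELU); // in-place, returns view(y)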
1439
+ // used in sam
1440
+ GGML_API struct ggml_tensor * ggml_get_rel_pos(
1441
+ struct ggml_context * ctx,
1442
+ struct ggml_tensor * a,
1443
+ int qh,
1444
+ int kh);
1445
+
1446
+ // used in sam
1447
+
1448
+ GGML_API struct ggml_tensor * ggml_add_rel_pos(
1449
+ struct ggml_context * ctx,
1450
+ struct ggml_tensor * a,
1451
+ struct ggml_tensor * pw,
1452
+ struct ggml_tensor * ph);
1453
+
1454
+ GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
1455
+ struct ggml_context * ctx,
1456
+ struct ggml_tensor * a,
1457
+ struct ggml_tensor * pw,
1458
+ struct ggml_tensor * ph);
1459
+
1460
  // custom operators
1461
 
1462
  typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
 
1466
  typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1467
  typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
1468
 
1469
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
1470
  struct ggml_context * ctx,
1471
  struct ggml_tensor * a,
1472
+ ggml_unary_op_f32_t fun),
1473
+ "use ggml_map_custom1 instead");
1474
 
1475
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
1476
  struct ggml_context * ctx,
1477
  struct ggml_tensor * a,
1478
+ ggml_unary_op_f32_t fun),
1479
+ "use ggml_map_custom1_inplace instead");
1480
 
1481
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
1482
  struct ggml_context * ctx,
1483
  struct ggml_tensor * a,
1484
  struct ggml_tensor * b,
1485
+ ggml_binary_op_f32_t fun),
1486
+ "use ggml_map_custom2 instead");
1487
 
1488
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
1489
  struct ggml_context * ctx,
1490
  struct ggml_tensor * a,
1491
  struct ggml_tensor * b,
1492
+ ggml_binary_op_f32_t fun),
1493
+ "use ggml_map_custom2_inplace instead");
1494
 
1495
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
1496
  struct ggml_context * ctx,
1497
  struct ggml_tensor * a,
1498
+ ggml_custom1_op_f32_t fun),
1499
+ "use ggml_map_custom1 instead");
1500
 
1501
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
1502
  struct ggml_context * ctx,
1503
  struct ggml_tensor * a,
1504
+ ggml_custom1_op_f32_t fun),
1505
+ "use ggml_map_custom1_inplace instead");
1506
 
1507
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
1508
  struct ggml_context * ctx,
1509
  struct ggml_tensor * a,
1510
  struct ggml_tensor * b,
1511
+ ggml_custom2_op_f32_t fun),
1512
+ "use ggml_map_custom2 instead");
1513
 
1514
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
1515
  struct ggml_context * ctx,
1516
  struct ggml_tensor * a,
1517
  struct ggml_tensor * b,
1518
+ ggml_custom2_op_f32_t fun),
1519
+ "use ggml_map_custom2_inplace instead");
1520
 
1521
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
1522
  struct ggml_context * ctx,
1523
  struct ggml_tensor * a,
1524
  struct ggml_tensor * b,
1525
  struct ggml_tensor * c,
1526
+ ggml_custom3_op_f32_t fun),
1527
+ "use ggml_map_custom3 instead");
1528
 
1529
+ GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
1530
  struct ggml_context * ctx,
1531
  struct ggml_tensor * a,
1532
  struct ggml_tensor * b,
1533
  struct ggml_tensor * c,
1534
+ ggml_custom3_op_f32_t fun),
1535
+ "use ggml_map_custom3_inplace instead");
1536
+
1537
+ // custom operators v2
1538
+
1539
+ typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
1540
+ typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
1541
+ typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
1542
+
1543
+ #define GGML_N_TASKS_MAX -1
1544
+
1545
+ GGML_API struct ggml_tensor * ggml_map_custom1(
1546
+ struct ggml_context * ctx,
1547
+ struct ggml_tensor * a,
1548
+ ggml_custom1_op_t fun,
1549
+ int n_tasks,
1550
+ void * userdata);
1551
+
1552
+ GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
1553
+ struct ggml_context * ctx,
1554
+ struct ggml_tensor * a,
1555
+ ggml_custom1_op_t fun,
1556
+ int n_tasks,
1557
+ void * userdata);
1558
+
1559
+ GGML_API struct ggml_tensor * ggml_map_custom2(
1560
+ struct ggml_context * ctx,
1561
+ struct ggml_tensor * a,
1562
+ struct ggml_tensor * b,
1563
+ ggml_custom2_op_t fun,
1564
+ int n_tasks,
1565
+ void * userdata);
1566
+
1567
+ GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
1568
+ struct ggml_context * ctx,
1569
+ struct ggml_tensor * a,
1570
+ struct ggml_tensor * b,
1571
+ ggml_custom2_op_t fun,
1572
+ int n_tasks,
1573
+ void * userdata);
1574
+
1575
+ GGML_API struct ggml_tensor * ggml_map_custom3(
1576
+ struct ggml_context * ctx,
1577
+ struct ggml_tensor * a,
1578
+ struct ggml_tensor * b,
1579
+ struct ggml_tensor * c,
1580
+ ggml_custom3_op_t fun,
1581
+ int n_tasks,
1582
+ void * userdata);
1583
+
1584
+ GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
1585
+ struct ggml_context * ctx,
1586
+ struct ggml_tensor * a,
1587
+ struct ggml_tensor * b,
1588
+ struct ggml_tensor * c,
1589
+ ggml_custom3_op_t fun,
1590
+ int n_tasks,
1591
+ void * userdata);
1592
 
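A sketch of a user-defined op with the v2 API, assuming a ggml_context * ctx and a contiguous F32 tensor x (placeholder names). The callback is invoked once per worker thread with its index ith and the thread count nth, and each call handles elements ith, ith+nth, ith+2*nth, ...:

    // multiply every element of a by a scalar passed through userdata
    static void scale_f32(struct ggml_tensor * dst, const struct ggml_tensor * a,
                          int ith, int nth, void * userdata) {
        const float   s = *(const float *) userdata;
        const int64_t n = ggml_nelements(dst);
        for (int64_t i = ith; i < n; i += nth) {
            ((float *) dst->data)[i] = s * ((const float *) a->data)[i];
        }
    }

    // inside graph-building code:
    //   static float s = 2.0f;
    //   struct ggml_tensor * y = ggml_map_custom1(ctx, x, scale_f32, GGML_N_TASKS_MAX, &s);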
1593
  // loss function
1594
 
 
1609
 
1610
  GGML_API void ggml_set_param(
1611
  struct ggml_context * ctx,
1612
+ struct ggml_tensor * tensor);
1613
+
1614
 
1615
  GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
1616
 
1617
  GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
1618
  GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
1619
 
1620
+ // graph allocation in a context
1621
+ GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
1622
+ GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
1623
+ GGML_API size_t ggml_graph_overhead(void);
1624
+
1625
+ // ggml_graph_plan() has to be called before ggml_graph_compute()
1626
+ // when plan.work_size > 0, caller must allocate memory for plan.work_data
1627
+ GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
1628
+ GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
1629
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
1630
+
1631
+ // same as ggml_graph_compute() but the work data is allocated as a part of the context
1632
+ // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
1633
+ GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
1634
 
1635
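A sketch of the new two-step compute flow that replaces the old ggml_graph_compute(ctx, &gf); gf, n_threads and ctx0 are placeholder names, and the work buffer uses plain malloc/free from <stdlib.h>:

    struct ggml_cplan plan = ggml_graph_plan(&gf, n_threads);
    if (plan.work_size > 0) {
        plan.work_data = malloc(plan.work_size); // caller-owned work buffer
    }
    ggml_graph_compute(&gf, &plan);              // returns GGML_EXIT_SUCCESS or GGML_EXIT_ABORTED
    free(plan.work_data);

    // or, when the context has spare memory for the work data:
    //   ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);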
  GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
1636
 
 
1820
 
1821
  GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
1822
 
1823
+ //
1824
+ // gguf
1825
+ //
1826
+
1827
+ enum gguf_type {
1828
+ GGUF_TYPE_UINT8 = 0,
1829
+ GGUF_TYPE_INT8 = 1,
1830
+ GGUF_TYPE_UINT16 = 2,
1831
+ GGUF_TYPE_INT16 = 3,
1832
+ GGUF_TYPE_UINT32 = 4,
1833
+ GGUF_TYPE_INT32 = 5,
1834
+ GGUF_TYPE_FLOAT32 = 6,
1835
+ GGUF_TYPE_BOOL = 7,
1836
+ GGUF_TYPE_STRING = 8,
1837
+ GGUF_TYPE_ARRAY = 9,
1838
+ GGUF_TYPE_UINT64 = 10,
1839
+ GGUF_TYPE_INT64 = 11,
1840
+ GGUF_TYPE_FLOAT64 = 12,
1841
+ GGUF_TYPE_COUNT, // marks the end of the enum
1842
+ };
1843
+
1844
+ struct gguf_context;
1845
+
1846
+ struct gguf_init_params {
1847
+ bool no_alloc;
1848
+
1849
+ // if not NULL, create a ggml_context and allocate the tensor data in it
1850
+ struct ggml_context ** ctx;
1851
+ };
1852
+
1853
+ GGML_API struct gguf_context * gguf_init_empty(void);
1854
+ GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
1855
+ //GGML_API struct gguf_context * gguf_init_from_buffer(..);
1856
+
1857
+ GGML_API void gguf_free(struct gguf_context * ctx);
1858
+
1859
+ GGML_API const char * gguf_type_name(enum gguf_type type);
1860
+
1861
+ GGML_API int gguf_get_version (struct gguf_context * ctx);
1862
+ GGML_API size_t gguf_get_alignment (struct gguf_context * ctx);
1863
+ GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
1864
+ GGML_API void * gguf_get_data (struct gguf_context * ctx);
1865
+
1866
+ GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
1867
+ GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
1868
+ GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
1869
+
1870
+ GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
1871
+ GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
1872
+
1873
+ // results are undefined if the wrong type is used for the key
1874
+ GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
1875
+ GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);
1876
+ GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
1877
+ GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i);
1878
+ GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
1879
+ GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
1880
+ GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
1881
+ GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
1882
+ GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
1883
+ GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
1884
+ GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
1885
+ GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
1886
+ GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
1887
+ GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
1888
+ GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
1889
+
1890
+ GGML_API int gguf_get_n_tensors (struct gguf_context * ctx);
1891
+ GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name);
1892
+ GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
1893
+ GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i);
1894
+
1895
+ // overrides existing values or adds a new one
1896
+ GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
1897
+ GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
1898
+ GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
1899
+ GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
1900
+ GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
1901
+ GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
1902
+ GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
1903
+ GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
1904
+ GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
1905
+ GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
1906
+ GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
1907
+ GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
1908
+ GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
1909
+ GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
1910
+
1911
+ // set or add KV pairs from another context
1912
+ GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
1913
+
1914
+ // manage tensor info
1915
+ GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
1916
+ GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
1917
+ GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
1918
+
1919
+ // writing gguf files can be done in 2 ways:
1920
+ //
1921
+ // - write the entire gguf_context to a binary file in a single pass:
1922
+ //
1923
+ // gguf_write_to_file(ctx, fname);
1924
+ //
1925
+ // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
1926
+ //
1927
+ // FILE * f = fopen(fname, "wb");
1928
+ // fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
1929
+ // fwrite(f, ...);
1930
+ // void * data = malloc(gguf_get_meta_size(ctx)); gguf_get_meta_data(ctx, data);
1931
+ // fseek(f, 0, SEEK_SET);
1932
+ // fwrite(f, data, gguf_get_meta_size(ctx));
1933
+ // free(data);
1934
+ // fclose(f);
1935
+ //
1936
+
1937
+ // write the entire context to a binary file
1938
+ GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
1939
+
1940
+ // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
1941
+ GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
1942
+ GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data);
1943
+
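For the read side, a minimal sketch that inspects metadata without loading tensor data; "model.gguf" is a placeholder path and printf requires <stdio.h>:

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
    if (gctx != NULL) {
        printf("version: %d, tensors: %d\n", gguf_get_version(gctx), gguf_get_n_tensors(gctx));
        for (int i = 0; i < gguf_get_n_kv(gctx); ++i) {
            printf("kv[%d]: %s (%s)\n", i,
                gguf_get_key(gctx, i), gguf_type_name(gguf_get_kv_type(gctx, i)));
        }
        gguf_free(gctx);
    }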
1944
  //
1945
  // system info
1946
  //
 
1969
  //
1970
 
1971
  #ifdef __cplusplus
1972
+ // restrict not standard in C++
1973
  #define GGML_RESTRICT
1974
  #else
1975
  #define GGML_RESTRICT restrict
1976
  #endif
1977
+ typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
1978
+ typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
1979
+ typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
1980
 
1981
  typedef struct {
1982
+ const char * type_name;
1983
+ int blck_size;
1984
+ size_t type_size;
1985
+ bool is_quantized;
1986
+ ggml_to_float_t to_float;
1987
+ ggml_from_float_t from_float;
1988
+ ggml_from_float_t from_float_reference;
1989
+ ggml_vec_dot_t vec_dot;
1990
+ enum ggml_type vec_dot_type;
1991
+ } ggml_type_traits_t;
1992
+
1993
+ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
1994
 
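A sketch of querying the type-traits table that replaces quantize_fns_t / ggml_internal_get_quantize_fn; GGML_TYPE_Q4_0 is one of the quantized types from enum ggml_type, and printf requires <stdio.h>:

    ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q4_0);
    printf("%s: %d values per block, %zu bytes per block, quantized: %d\n",
        tt.type_name, tt.blck_size, tt.type_size, tt.is_quantized);
    // tt.to_float(block, dst, tt.blck_size) would dequantize one block into dst floats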
1995
  #ifdef __cplusplus
1996
  }
whisper.cpp CHANGED
@@ -441,6 +441,7 @@ struct whisper_hparams {
441
  int32_t n_text_layer = 4;
442
  int32_t n_mels = 80;
443
  int32_t ftype = 1;
 
444
  };
445
 
446
  // audio encoding layer
@@ -1578,7 +1579,7 @@ static bool whisper_encode_internal(
1578
  {
1579
  wstate.use_buf(ctx0, 0);
1580
 
1581
- cur = ggml_norm(ctx0, inpL);
1582
 
1583
  // cur = ln_0_w*cur + ln_0_b
1584
  cur = ggml_add(ctx0,
@@ -1725,7 +1726,7 @@ static bool whisper_encode_internal(
1725
  {
1726
  wstate.use_buf(ctx0, 0);
1727
 
1728
- cur = ggml_norm(ctx0, inpFF);
1729
 
1730
  wstate.use_buf(ctx0, 1);
1731
 
@@ -1788,7 +1789,7 @@ static bool whisper_encode_internal(
1788
  {
1789
  wstate.use_buf(ctx0, 0);
1790
 
1791
- cur = ggml_norm(ctx0, cur);
1792
 
1793
  wstate.use_buf(ctx0, 1);
1794
 
@@ -1805,10 +1806,9 @@ static bool whisper_encode_internal(
1805
  // run the computation
1806
  {
1807
  struct ggml_cgraph gf = {};
1808
- gf.n_threads = n_threads;
1809
 
1810
- ggml_build_forward_expand(&gf, cur);
1811
- ggml_graph_compute(ctx0, &gf);
1812
 
1813
  //ggml_graph_print(&gf);
1814
  }
@@ -1851,12 +1851,11 @@ static bool whisper_encode_internal(
1851
  // pre-compute cross-attention memory
1852
  {
1853
  struct ggml_cgraph gf = {};
1854
- gf.n_threads = n_threads;
1855
 
1856
  // TODO: hack to disconnect the encoded features from the previous graph
1857
  cur->op = GGML_OP_NONE;
1858
- cur->src0 = nullptr;
1859
- cur->src1 = nullptr;
1860
 
1861
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1862
  auto& layer = model.layers_decoder[il];
@@ -1894,7 +1893,7 @@ static bool whisper_encode_internal(
1894
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1895
  }
1896
 
1897
- ggml_graph_compute(ctx0, &gf);
1898
  //ggml_graph_print(&gf);
1899
  }
1900
 
@@ -1965,7 +1964,6 @@ static bool whisper_decode_internal(
1965
  struct ggml_context * ctx0 = ggml_init(params);
1966
 
1967
  struct ggml_cgraph gf = {};
1968
- gf.n_threads = n_threads;
1969
 
1970
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1971
  memcpy(embd->data, tokens, N*ggml_element_size(embd));
@@ -1992,7 +1990,7 @@ static bool whisper_decode_internal(
1992
  {
1993
  wstate.use_buf(ctx0, 0);
1994
 
1995
- cur = ggml_norm(ctx0, inpL);
1996
 
1997
  // cur = ln_0_w*cur + ln_0_b
1998
  cur = ggml_add(ctx0,
@@ -2119,7 +2117,7 @@ static bool whisper_decode_internal(
2119
  {
2120
  wstate.use_buf(ctx0, 0);
2121
 
2122
- cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
2123
 
2124
  // cur = ln_0_w*cur + ln_0_b
2125
  cur = ggml_add(ctx0,
@@ -2229,7 +2227,7 @@ static bool whisper_decode_internal(
2229
  {
2230
  wstate.use_buf(ctx0, 0);
2231
 
2232
- cur = ggml_norm(ctx0, inpFF);
2233
 
2234
  wstate.use_buf(ctx0, 1);
2235
 
@@ -2284,7 +2282,7 @@ static bool whisper_decode_internal(
2284
  {
2285
  wstate.use_buf(ctx0, 0);
2286
 
2287
- cur = ggml_norm(ctx0, cur);
2288
 
2289
  wstate.use_buf(ctx0, 1);
2290
 
@@ -2308,8 +2306,8 @@ static bool whisper_decode_internal(
2308
 
2309
  // run the computation
2310
  {
2311
- ggml_build_forward_expand(&gf, logits);
2312
- ggml_graph_compute (ctx0, &gf);
2313
  }
2314
 
2315
  // extract logits for all N tokens
@@ -2358,7 +2356,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
2358
  static float sin_vals[SIN_COS_N_COUNT];
2359
  static float cos_vals[SIN_COS_N_COUNT];
2360
 
2361
- // In FFT, we frequently use sine and cosine operations with the same values.
2362
  // We can use precalculated values to speed up the process.
2363
  static void fill_sin_cos_table() {
2364
  static bool is_filled = false;
@@ -5165,17 +5163,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5165
 
5166
  struct ggml_cgraph gf = ggml_build_forward(c);
5167
 
5168
- gf.n_threads = n_threads;
5169
-
5170
  double tsum = 0.0;
5171
 
5172
  // heat-up
5173
- ggml_graph_compute(ctx0, &gf);
5174
 
5175
  for (int i = 0; i < n_max; ++i) {
5176
  const int64_t t0 = ggml_time_us();
5177
 
5178
- ggml_graph_compute(ctx0, &gf);
5179
 
5180
  const int64_t t1 = ggml_time_us();
5181
 
 
441
  int32_t n_text_layer = 4;
442
  int32_t n_mels = 80;
443
  int32_t ftype = 1;
444
+ float eps = 1e-5f;
445
  };
446
 
447
  // audio encoding layer
 
1579
  {
1580
  wstate.use_buf(ctx0, 0);
1581
 
1582
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
1583
 
1584
  // cur = ln_0_w*cur + ln_0_b
1585
  cur = ggml_add(ctx0,
 
1726
  {
1727
  wstate.use_buf(ctx0, 0);
1728
 
1729
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
1730
 
1731
  wstate.use_buf(ctx0, 1);
1732
 
 
1789
  {
1790
  wstate.use_buf(ctx0, 0);
1791
 
1792
+ cur = ggml_norm(ctx0, cur, hparams.eps);
1793
 
1794
  wstate.use_buf(ctx0, 1);
1795
 
 
1806
  // run the computation
1807
  {
1808
  struct ggml_cgraph gf = {};
 
1809
 
1810
+ ggml_build_forward_expand (&gf, cur);
1811
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
1812
 
1813
  //ggml_graph_print(&gf);
1814
  }
 
1851
  // pre-compute cross-attention memory
1852
  {
1853
  struct ggml_cgraph gf = {};
 
1854
 
1855
  // TODO: hack to disconnect the encoded features from the previous graph
1856
  cur->op = GGML_OP_NONE;
1857
+ cur->src[0] = nullptr;
1858
+ cur->src[1] = nullptr;
1859
 
1860
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
1861
  auto& layer = model.layers_decoder[il];
 
1893
  ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
1894
  }
1895
 
1896
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
1897
  //ggml_graph_print(&gf);
1898
  }
1899
 
 
1964
  struct ggml_context * ctx0 = ggml_init(params);
1965
 
1966
  struct ggml_cgraph gf = {};
 
1967
 
1968
  struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
1969
  memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
1990
  {
1991
  wstate.use_buf(ctx0, 0);
1992
 
1993
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
1994
 
1995
  // cur = ln_0_w*cur + ln_0_b
1996
  cur = ggml_add(ctx0,
 
2117
  {
2118
  wstate.use_buf(ctx0, 0);
2119
 
2120
+ cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
2121
 
2122
  // cur = ln_0_w*cur + ln_0_b
2123
  cur = ggml_add(ctx0,
 
2227
  {
2228
  wstate.use_buf(ctx0, 0);
2229
 
2230
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
2231
 
2232
  wstate.use_buf(ctx0, 1);
2233
 
 
2282
  {
2283
  wstate.use_buf(ctx0, 0);
2284
 
2285
+ cur = ggml_norm(ctx0, cur, hparams.eps);
2286
 
2287
  wstate.use_buf(ctx0, 1);
2288
 
 
2306
 
2307
  // run the computation
2308
  {
2309
+ ggml_build_forward_expand (&gf, logits);
2310
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
2311
  }
2312
 
2313
  // extract logits for all N tokens
 
2356
  static float sin_vals[SIN_COS_N_COUNT];
2357
  static float cos_vals[SIN_COS_N_COUNT];
2358
 
2359
+ // In FFT, we frequently use sine and cosine operations with the same values.
2360
  // We can use precalculated values to speed up the process.
2361
  static void fill_sin_cos_table() {
2362
  static bool is_filled = false;
 
5163
 
5164
  struct ggml_cgraph gf = ggml_build_forward(c);
5165
 
 
 
5166
  double tsum = 0.0;
5167
 
5168
  // heat-up
5169
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
5170
 
5171
  for (int i = 0; i < n_max; ++i) {
5172
  const int64_t t0 = ggml_time_us();
5173
 
5174
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
5175
 
5176
  const int64_t t1 = ggml_time_us();
5177