ggml : sync (ggml-alloc, GPU, eps, etc.) (#1220)

* ggml : sync (ggml-alloc, GPU, eps, etc.)
* ggml : fix build
* wasm : fix build

Files changed:
- bindings/javascript/libwhisper.worker.js +1 -1
- bindings/javascript/whisper.js +0 -0
- examples/common.cpp +73 -21
- examples/common.h +20 -1
- examples/talk.wasm/gpt-2.cpp +8 -9
- examples/talk/gpt-2.cpp +6 -6
- ggml-alloc.c +594 -0
- ggml-alloc.h +26 -0
- ggml-cuda.cu +0 -0
- ggml-cuda.h +28 -18
- ggml-metal.h +19 -1
- ggml-metal.m +473 -258
- ggml-metal.metal +1185 -965
- ggml-opencl.cpp +5 -1
- ggml.c +0 -0
- ggml.h +539 -83
- whisper.cpp +18 -22
bindings/javascript/libwhisper.worker.js
CHANGED

@@ -1 +1 @@
-
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:
+
"use strict";var Module={};var ENVIRONMENT_IS_NODE=typeof process=="object"&&typeof process.versions=="object"&&typeof process.versions.node=="string";if(ENVIRONMENT_IS_NODE){var nodeWorkerThreads=require("worker_threads");var parentPort=nodeWorkerThreads.parentPort;parentPort.on("message",data=>onmessage({data:data}));var fs=require("fs");Object.assign(global,{self:global,require:require,Module:Module,location:{href:__filename},Worker:nodeWorkerThreads.Worker,importScripts:f=>(0,eval)(fs.readFileSync(f,"utf8")+"//# sourceURL="+f),postMessage:msg=>parentPort.postMessage(msg),performance:global.performance||{now:Date.now}})}var initializedJS=false;function threadPrintErr(){var text=Array.prototype.slice.call(arguments).join(" ");if(ENVIRONMENT_IS_NODE){fs.writeSync(2,text+"\n");return}console.error(text)}function threadAlert(){var text=Array.prototype.slice.call(arguments).join(" ");postMessage({cmd:"alert",text:text,threadId:Module["_pthread_self"]()})}var err=threadPrintErr;self.alert=threadAlert;Module["instantiateWasm"]=(info,receiveInstance)=>{var module=Module["wasmModule"];Module["wasmModule"]=null;var instance=new WebAssembly.Instance(module,info);return receiveInstance(instance)};self.onunhandledrejection=e=>{throw e.reason||e};function handleMessage(e){try{if(e.data.cmd==="load"){let messageQueue=[];self.onmessage=e=>messageQueue.push(e);self.startWorker=instance=>{Module=instance;postMessage({"cmd":"loaded"});for(let msg of messageQueue){handleMessage(msg)}self.onmessage=handleMessage};Module["wasmModule"]=e.data.wasmModule;for(const handler of e.data.handlers){Module[handler]=(...args)=>{postMessage({cmd:"callHandler",handler:handler,args:args})}}Module["wasmMemory"]=e.data.wasmMemory;Module["buffer"]=Module["wasmMemory"].buffer;Module["ENVIRONMENT_IS_PTHREAD"]=true;if(typeof e.data.urlOrBlob=="string"){importScripts(e.data.urlOrBlob)}else{var objectUrl=URL.createObjectURL(e.data.urlOrBlob);importScripts(objectUrl);URL.revokeObjectURL(objectUrl)}whisper_factory(Module)}else if(e.data.cmd==="run"){Module["__emscripten_thread_init"](e.data.pthread_ptr,0,0,1);Module["__emscripten_thread_mailbox_await"](e.data.pthread_ptr);Module["establishStackSpace"]();Module["PThread"].receiveObjectTransfer(e.data);Module["PThread"].threadInitTLS();if(!initializedJS){Module["__embind_initialize_bindings"]();initializedJS=true}try{Module["invokeEntryPoint"](e.data.start_routine,e.data.arg)}catch(ex){if(ex!="unwind"){throw ex}}}else if(e.data.cmd==="cancel"){if(Module["_pthread_self"]()){Module["__emscripten_thread_exit"](-1)}}else if(e.data.target==="setimmediate"){}else if(e.data.cmd==="checkMailbox"){if(initializedJS){Module["checkMailbox"]()}}else if(e.data.cmd){err(`worker.js received unknown command ${e.data.cmd}`);err(e.data)}}catch(ex){if(Module["__emscripten_thread_crashed"]){Module["__emscripten_thread_crashed"]()}throw ex}}self.onmessage=handleMessage;

bindings/javascript/whisper.js
CHANGED

The diff for this file is too large to render. See raw diff.

examples/common.cpp
CHANGED

@@ -1,3 +1,5 @@
+#define _USE_MATH_DEFINES // for M_PI
+
 #include "common.h"
 
 // third-party utilities
@@ -13,53 +15,59 @@
 #include <codecvt>
 #include <sstream>
 
-#ifndef M_PI
-#define M_PI 3.14159265358979323846
-#endif
-
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+// Function to check if the next argument exists
+std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
+    if (i + 1 < argc && argv[i + 1][0] != '-') {
+        return argv[++i];
+    } else {
+        fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
+        gpt_print_usage(argc, argv, params);
+        exit(0);
+    }
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     for (int i = 1; i < argc; i++) {
         std::string arg = argv[i];
 
         if (arg == "-s" || arg == "--seed") {
-            params.seed = std::stoi(argv[++i]);
+            params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-t" || arg == "--threads") {
-            params.n_threads = std::stoi(argv[++i]);
+            params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
+        } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
+            params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-p" || arg == "--prompt") {
-            params.prompt = argv[++i];
+            params.prompt = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-n" || arg == "--n_predict") {
-            params.n_predict = std::stoi(argv[++i]);
+            params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
        } else if (arg == "--top_k") {
-            params.top_k = std::stoi(argv[++i]);
+            params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--top_p") {
-            params.top_p = std::stof(argv[++i]);
+            params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--temp") {
-            params.temp = std::stof(argv[++i]);
+            params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--repeat-last-n") {
-            params.repeat_last_n = std::stoi(argv[++i]);
+            params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "--repeat-penalty") {
-            params.repeat_penalty = std::stof(argv[++i]);
+            params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-b" || arg == "--batch_size") {
-            params.n_batch = std::stoi(argv[++i]);
+            params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-m" || arg == "--model") {
-            params.model = argv[++i];
+            params.model = get_next_arg(i, argc, argv, arg, params);
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "-ip" || arg == "--interactive-port") {
             params.interactive = true;
-            params.interactive_port = std::stoi(argv[++i]);
+            params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
         } else if (arg == "-f" || arg == "--file") {
-            if (++i > argc) {
-                fprintf(stderr, "Invalid file param");
-                break;
-            }
+            get_next_arg(i, argc, argv, arg, params);
             std::ifstream file(argv[i]);
             if (!file) {
                 fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@@ -70,7 +78,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 params.prompt.pop_back();
             }
         } else if (arg == "-tt" || arg == "--token_test") {
-            params.token_test = argv[++i];
+            params.token_test = get_next_arg(i, argc, argv, arg, params);
         }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -89,6 +97,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -ngl N, --gpu-layers N  number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: random)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME\n");
@@ -755,3 +764,46 @@ float similarity(const std::string & s0, const std::string & s1) {
 
     return 1.0f - (dist / std::max(s0.size(), s1.size()));
 }
+
+bool sam_params_parse(int argc, char ** argv, sam_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-s" || arg == "--seed") {
+            params.seed = std::stoi(argv[++i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            params.n_threads = std::stoi(argv[++i]);
+        } else if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-i" || arg == "--inp") {
+            params.fname_inp = argv[++i];
+        } else if (arg == "-o" || arg == "--out") {
+            params.fname_out = argv[++i];
+        } else if (arg == "-h" || arg == "--help") {
+            sam_print_usage(argc, argv, params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            sam_print_usage(argc, argv, params);
+            exit(0);
+        }
+    }
+
+    return true;
+}
+
+void sam_print_usage(int argc, char ** argv, const sam_params & params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
+    fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "  -i FNAME, --inp FNAME\n");
+    fprintf(stderr, "                        input file (default: %s)\n", params.fname_inp.c_str());
+    fprintf(stderr, "  -o FNAME, --out FNAME\n");
+    fprintf(stderr, "                        output file (default: %s)\n", params.fname_out.c_str());
+    fprintf(stderr, "\n");
+}

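The new get_next_arg helper validates that a flag actually has a value before consuming it, and the new -ngl / --gpu-layers option fills the n_gpu_layers field. Below is a minimal, hypothetical driver showing how the updated parser is typically called; only gpt_params and gpt_params_parse come from the diff above, the rest is illustrative.

// Hypothetical driver for the updated gpt_params_parse (not part of this commit).
#include <cstdio>
#include "common.h"

int main(int argc, char ** argv) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // n_gpu_layers is the new field populated by "-ngl N" / "--gpu-layers N";
    // a missing value now triggers gpt_print_usage() and exit(0) via get_next_arg.
    fprintf(stderr, "seed = %d, threads = %d, gpu layers = %d\n",
            params.seed, params.n_threads, params.n_gpu_layers);

    return 0;
}
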
examples/common.h
CHANGED

@@ -11,7 +11,7 @@
 #define COMMON_SAMPLE_RATE 16000
 
 //
-// CLI argument parsing
+// GPT CLI argument parsing
 //
 
 struct gpt_params {
@@ -33,6 +33,8 @@ struct gpt_params {
 
     bool interactive = false;
     int32_t interactive_port = -1;
+
+    int32_t n_gpu_layers = 0;
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
@@ -155,3 +157,20 @@ bool vad_simple(
 
 // compute similarity between two strings using Levenshtein distance
 float similarity(const std::string & s0, const std::string & s1);
+
+//
+// SAM argument parsing
+//
+
+struct sam_params {
+    int32_t seed      = -1; // RNG seed
+    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
+
+    std::string model     = "models/sam-vit-b/ggml-model-f16.bin"; // model path
+    std::string fname_inp = "img.jpg";
+    std::string fname_out = "img.out";
+};
+
+bool sam_params_parse(int argc, char ** argv, sam_params & params);
+
+void sam_print_usage(int argc, char ** argv, const sam_params & params);

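For completeness, a short sketch of how the new SAM entry points declared above are expected to be driven; the surrounding main() is hypothetical and only sam_params, sam_params_parse and sam_print_usage come from this header.

// Hypothetical usage of the new SAM argument-parsing API (not part of this commit).
#include <cstdio>
#include "common.h"

int main(int argc, char ** argv) {
    sam_params params;

    if (!sam_params_parse(argc, argv, params)) {
        sam_print_usage(argc, argv, params);
        return 1;
    }

    fprintf(stderr, "model = %s, input = %s, output = %s\n",
            params.model.c_str(), params.fname_inp.c_str(), params.fname_out.c_str());

    return 0;
}
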
examples/talk.wasm/gpt-2.cpp
CHANGED

@@ -191,9 +191,9 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
     // create the ggml context
     {
         struct ggml_init_params params = {
-
-
-
+            /*.mem_size   =*/ ctx_size,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ false,
         };
 
         model.ctx = ggml_init(params);
@@ -420,7 +420,6 @@ bool gpt2_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -442,7 +441,7 @@ bool gpt2_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, 1e-5f);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -589,7 +588,7 @@ bool gpt2_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, 1e-5f);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -644,7 +643,7 @@ bool gpt2_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, 1e-5f);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
@@ -664,8 +663,8 @@ bool gpt2_eval(
     //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-
+    ggml_build_forward_expand (&gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);

examples/talk/gpt-2.cpp
CHANGED

@@ -379,6 +379,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
 // - embd_inp: the embeddings of the tokens in the context
 // - embd_w:   the predicted logits for the next token
 //
+// TODO: sync latest version from ggml repo
 bool gpt2_eval(
         const gpt2_model & model,
         const int n_threads,
@@ -420,7 +421,6 @@ bool gpt2_eval(
 
     struct ggml_context * ctx0 = ggml_init(params);
     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -442,7 +442,7 @@ bool gpt2_eval(
         // norm
         {
             // [ 768, N]
-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, 1e-5f);
 
             // cur = ln_1_g*cur + ln_1_b
             // [ 768, N]
@@ -589,7 +589,7 @@ bool gpt2_eval(
         {
             // norm
             {
-                cur = ggml_norm(ctx0, inpFF);
+                cur = ggml_norm(ctx0, inpFF, 1e-5f);
 
                 // cur = ln_2_g*cur + ln_2_b
                 // [ 768, N]
@@ -644,7 +644,7 @@ bool gpt2_eval(
     // norm
     {
         // [ 768, N]
-        inpL = ggml_norm(ctx0, inpL);
+        inpL = ggml_norm(ctx0, inpL, 1e-5f);
 
         // inpL = ln_f_g*inpL + ln_f_b
         // [ 768, N]
@@ -664,8 +664,8 @@ bool gpt2_eval(
     //inpL = ggml_soft_max(ctx0, inpL);
 
     // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-
+    ggml_build_forward_expand (&gf, inpL);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print   (&gf);

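Both gpt-2.cpp examples follow the same post-sync pattern: ggml_norm now takes an explicit epsilon, and the thread count is passed to ggml_graph_compute_with_ctx instead of being stored in gf.n_threads. A minimal sketch of that pattern, using placeholder tensors rather than the real model code, is shown below.

// Minimal sketch (not from the PR) of the post-sync compute pattern used above.
#include "ggml.h"

static void run_norm_example(struct ggml_context * ctx0, struct ggml_tensor * inpL, int n_threads) {
    struct ggml_cgraph gf = {};

    // eps is now an explicit argument (previously ggml_norm(ctx0, inpL))
    struct ggml_tensor * cur = ggml_norm(ctx0, inpL, 1e-5f);

    ggml_build_forward_expand (&gf, cur);
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); // replaces the old gf.n_threads mechanism
}
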
ggml-alloc.c
ADDED

@@ -0,0 +1,594 @@
| 1 |
+
#include "ggml-alloc.h"
|
| 2 |
+
#include "ggml.h"
|
| 3 |
+
#include <assert.h>
|
| 4 |
+
#include <stdarg.h>
|
| 5 |
+
#include <stdio.h>
|
| 6 |
+
#include <stdlib.h>
|
| 7 |
+
#include <string.h>
|
| 8 |
+
|
| 9 |
+
#define UNUSED(x) (void)(x)
|
| 10 |
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
| 11 |
+
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
| 12 |
+
|
| 13 |
+
//#define GGML_ALLOCATOR_DEBUG
|
| 14 |
+
|
| 15 |
+
//#define AT_PRINTF printf
|
| 16 |
+
#define AT_PRINTF(...) ((void)0)
|
| 17 |
+
|
| 18 |
+
struct hash_node {
|
| 19 |
+
struct ggml_tensor * t;
|
| 20 |
+
int n_children;
|
| 21 |
+
int n_views;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
static size_t hash(void * p) {
|
| 25 |
+
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
|
| 29 |
+
size_t h = hash(t);
|
| 30 |
+
|
| 31 |
+
// linear probing
|
| 32 |
+
size_t i = h;
|
| 33 |
+
while (hash_table[i].t != NULL) {
|
| 34 |
+
if (hash_table[i].t == t) {
|
| 35 |
+
return &hash_table[i];
|
| 36 |
+
}
|
| 37 |
+
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
| 38 |
+
if (i == h) {
|
| 39 |
+
// hash table is full
|
| 40 |
+
GGML_ASSERT(false);
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
hash_table[i].t = t;
|
| 45 |
+
return &hash_table[i];
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// TODO: GGML_PAD ?
|
| 49 |
+
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
|
| 50 |
+
assert(alignment && !(alignment & (alignment - 1))); // power of 2
|
| 51 |
+
size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
|
| 52 |
+
return offset + align;
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
struct free_block {
|
| 56 |
+
void * addr;
|
| 57 |
+
size_t size;
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
#define MAX_FREE_BLOCKS 128
|
| 61 |
+
|
| 62 |
+
struct ggml_allocr {
|
| 63 |
+
void * data;
|
| 64 |
+
size_t size;
|
| 65 |
+
size_t alignment;
|
| 66 |
+
int n_free_blocks;
|
| 67 |
+
struct free_block free_blocks[MAX_FREE_BLOCKS];
|
| 68 |
+
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
|
| 69 |
+
size_t max_size;
|
| 70 |
+
bool measure;
|
| 71 |
+
int parse_seq[GGML_MAX_CONCUR];
|
| 72 |
+
int parse_seq_len;
|
| 73 |
+
|
| 74 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 75 |
+
struct ggml_tensor * allocated_tensors[1024];
|
| 76 |
+
#endif
|
| 77 |
+
};
|
| 78 |
+
|
| 79 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 80 |
+
static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 81 |
+
for (int i = 0; i < 1024; i++) {
|
| 82 |
+
if (alloc->allocated_tensors[i] == NULL) {
|
| 83 |
+
alloc->allocated_tensors[i] = tensor;
|
| 84 |
+
return;
|
| 85 |
+
}
|
| 86 |
+
}
|
| 87 |
+
GGML_ASSERT(!"out of allocated_tensors");
|
| 88 |
+
}
|
| 89 |
+
static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 90 |
+
for (int i = 0; i < 1024; i++) {
|
| 91 |
+
if (alloc->allocated_tensors[i] == tensor ||
|
| 92 |
+
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
|
| 93 |
+
alloc->allocated_tensors[i] = NULL;
|
| 94 |
+
return;
|
| 95 |
+
}
|
| 96 |
+
}
|
| 97 |
+
printf("tried to free tensor %s not found\n", tensor->name);
|
| 98 |
+
GGML_ASSERT(!"tensor not found");
|
| 99 |
+
}
|
| 100 |
+
#endif
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
static size_t ggml_allocator_get_alloc_size(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 104 |
+
return ggml_nbytes(tensor);
|
| 105 |
+
|
| 106 |
+
UNUSED(alloc);
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 110 |
+
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
| 111 |
+
size = aligned_offset(NULL, size, alloc->alignment);
|
| 112 |
+
|
| 113 |
+
AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size);
|
| 114 |
+
|
| 115 |
+
size_t max_avail = 0;
|
| 116 |
+
|
| 117 |
+
// find the best fitting free block besides the last block
|
| 118 |
+
int best_fit_block = -1;
|
| 119 |
+
size_t best_fit_size = SIZE_MAX;
|
| 120 |
+
for (int i = 0; i < alloc->n_free_blocks - 1; i++) {
|
| 121 |
+
struct free_block * block = &alloc->free_blocks[i];
|
| 122 |
+
max_avail = MAX(max_avail, block->size);
|
| 123 |
+
if (block->size >= size && block->size <= best_fit_size) {
|
| 124 |
+
best_fit_block = i;
|
| 125 |
+
best_fit_size = block->size;
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
AT_PRINTF("block %d\n", best_fit_block);
|
| 130 |
+
|
| 131 |
+
if (best_fit_block == -1) {
|
| 132 |
+
// the last block is our last resort
|
| 133 |
+
struct free_block * block = &alloc->free_blocks[alloc->n_free_blocks - 1];
|
| 134 |
+
if (block->size >= size) {
|
| 135 |
+
best_fit_block = alloc->n_free_blocks - 1;
|
| 136 |
+
max_avail = MAX(max_avail, block->size);
|
| 137 |
+
} else {
|
| 138 |
+
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, largest block available %zu)\n",
|
| 139 |
+
__func__, size, max_avail);
|
| 140 |
+
GGML_ASSERT(!"not enough space in the buffer");
|
| 141 |
+
return;
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
struct free_block * block = &alloc->free_blocks[best_fit_block];
|
| 145 |
+
void * addr = block->addr;
|
| 146 |
+
block->addr = (char*)block->addr + size;
|
| 147 |
+
block->size -= size;
|
| 148 |
+
if (block->size == 0) {
|
| 149 |
+
// remove block if empty
|
| 150 |
+
alloc->n_free_blocks--;
|
| 151 |
+
for (int j = best_fit_block; j < alloc->n_free_blocks; j++) {
|
| 152 |
+
alloc->free_blocks[j] = alloc->free_blocks[j+1];
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
tensor->data = addr;
|
| 157 |
+
|
| 158 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 159 |
+
add_allocated_tensor(alloc, tensor);
|
| 160 |
+
size_t cur_max = (char*)addr - (char*)alloc->data + size;
|
| 161 |
+
if (cur_max > alloc->max_size) {
|
| 162 |
+
printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
|
| 163 |
+
for (int i = 0; i < 1024; i++) {
|
| 164 |
+
if (alloc->allocated_tensors[i]) {
|
| 165 |
+
printf("%s (%.2f MB) ", alloc->allocated_tensors[i]->name, ggml_nbytes(alloc->allocated_tensors[i]) / 1024.0 / 1024.0);
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
printf("\n");
|
| 169 |
+
}
|
| 170 |
+
#endif
|
| 171 |
+
|
| 172 |
+
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// this is a very naive implementation, but for our case the number of free blocks should be very small
|
| 176 |
+
static void ggml_allocator_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
|
| 177 |
+
void * ptr = tensor->data;
|
| 178 |
+
|
| 179 |
+
if (ptr < alloc->data || (char*)ptr >= (char*)alloc->data + alloc->max_size) {
|
| 180 |
+
// the tensor was not allocated in this buffer
|
| 181 |
+
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
|
| 182 |
+
// the easiest way to deal with this is just to ignore it
|
| 183 |
+
return;
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
size_t size = ggml_allocator_get_alloc_size(alloc, tensor);
|
| 187 |
+
size = aligned_offset(NULL, size, alloc->alignment);
|
| 188 |
+
AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks);
|
| 189 |
+
|
| 190 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 191 |
+
remove_allocated_tensor(alloc, tensor);
|
| 192 |
+
#endif
|
| 193 |
+
|
| 194 |
+
// see if we can merge with an existing block
|
| 195 |
+
for (int i = 0; i < alloc->n_free_blocks; i++) {
|
| 196 |
+
struct free_block * block = &alloc->free_blocks[i];
|
| 197 |
+
// check if ptr is at the end of the block
|
| 198 |
+
if ((char*)block->addr + block->size == ptr) {
|
| 199 |
+
block->size += size;
|
| 200 |
+
// check if we can merge with the next block
|
| 201 |
+
if (i < alloc->n_free_blocks - 1 && (char*)block->addr + block->size == alloc->free_blocks[i+1].addr) {
|
| 202 |
+
block->size += alloc->free_blocks[i+1].size;
|
| 203 |
+
alloc->n_free_blocks--;
|
| 204 |
+
for (int j = i+1; j < alloc->n_free_blocks; j++) {
|
| 205 |
+
alloc->free_blocks[j] = alloc->free_blocks[j+1];
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
return;
|
| 209 |
+
}
|
| 210 |
+
// check if ptr is at the beginning of the block
|
| 211 |
+
if ((char*)ptr + size == block->addr) {
|
| 212 |
+
block->addr = ptr;
|
| 213 |
+
block->size += size;
|
| 214 |
+
// check if we can merge with the previous block
|
| 215 |
+
if (i > 0 && (char*)alloc->free_blocks[i-1].addr + alloc->free_blocks[i-1].size == block->addr) {
|
| 216 |
+
alloc->free_blocks[i-1].size += block->size;
|
| 217 |
+
alloc->n_free_blocks--;
|
| 218 |
+
for (int j = i; j < alloc->n_free_blocks; j++) {
|
| 219 |
+
alloc->free_blocks[j] = alloc->free_blocks[j+1];
|
| 220 |
+
}
|
| 221 |
+
}
|
| 222 |
+
return;
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
// otherwise, add a new block
|
| 226 |
+
GGML_ASSERT(alloc->n_free_blocks < MAX_FREE_BLOCKS && "out of free blocks");
|
| 227 |
+
// insert the new block in the correct position to keep the array sorted by address (to make merging blocks faster)
|
| 228 |
+
int insert_pos = 0;
|
| 229 |
+
while (insert_pos < alloc->n_free_blocks && alloc->free_blocks[insert_pos].addr < ptr) {
|
| 230 |
+
insert_pos++;
|
| 231 |
+
}
|
| 232 |
+
// shift all blocks from insert_pos onward to make room for the new block
|
| 233 |
+
for (int i = alloc->n_free_blocks; i > insert_pos; i--) {
|
| 234 |
+
alloc->free_blocks[i] = alloc->free_blocks[i-1];
|
| 235 |
+
}
|
| 236 |
+
// insert the new block
|
| 237 |
+
alloc->free_blocks[insert_pos].addr = ptr;
|
| 238 |
+
alloc->free_blocks[insert_pos].size = size;
|
| 239 |
+
alloc->n_free_blocks++;
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
|
| 243 |
+
for (int i = 0; i < n; i++) {
|
| 244 |
+
alloc->parse_seq[i] = list[i];
|
| 245 |
+
}
|
| 246 |
+
alloc->parse_seq_len = n;
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
void ggml_allocr_reset(struct ggml_allocr * alloc) {
|
| 250 |
+
alloc->n_free_blocks = 1;
|
| 251 |
+
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
|
| 252 |
+
alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
|
| 253 |
+
alloc->free_blocks[0].size = alloc->size - align_offset;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
|
| 257 |
+
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
| 258 |
+
|
| 259 |
+
*alloc = (struct ggml_allocr){
|
| 260 |
+
/*.data = */ data,
|
| 261 |
+
/*.size = */ size,
|
| 262 |
+
/*.alignment = */ alignment,
|
| 263 |
+
/*.n_free_blocks = */ 0,
|
| 264 |
+
/*.free_blocks = */ {{0}},
|
| 265 |
+
/*.hash_table = */ {{0}},
|
| 266 |
+
/*.max_size = */ 0,
|
| 267 |
+
/*.measure = */ false,
|
| 268 |
+
/*.parse_seq = */ {0},
|
| 269 |
+
/*.parse_seq_len = */ 0,
|
| 270 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 271 |
+
/*.allocated_tensors = */ {0},
|
| 272 |
+
#endif
|
| 273 |
+
};
|
| 274 |
+
|
| 275 |
+
ggml_allocr_reset(alloc);
|
| 276 |
+
|
| 277 |
+
return alloc;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
// address and size of the buffer when measuring
|
| 281 |
+
// it needs to be large enough to fit all the tensors, but it cannot overlap with other existing buffers
|
| 282 |
+
static void * const MEASURE_BASE_ADDR = (void *) 0x1000;
|
| 283 |
+
static const size_t MEASURE_MAX_SIZE = 1ULL<<40; // 1 TB
|
| 284 |
+
|
| 285 |
+
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
|
| 286 |
+
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */);
|
| 287 |
+
|
| 288 |
+
*alloc = (struct ggml_allocr){
|
| 289 |
+
/*.data = */ MEASURE_BASE_ADDR,
|
| 290 |
+
/*.size = */ MEASURE_MAX_SIZE,
|
| 291 |
+
/*.alignment = */ alignment,
|
| 292 |
+
/*.n_free_blocks = */ 0,
|
| 293 |
+
/*.free_blocks = */ {{0}},
|
| 294 |
+
/*.hash_table = */ {{0}},
|
| 295 |
+
/*.max_size = */ 0,
|
| 296 |
+
/*.measure = */ true,
|
| 297 |
+
/*.parse_seq = */ {0},
|
| 298 |
+
/*.parse_seq_len = */ 0,
|
| 299 |
+
#ifdef GGML_ALLOCATOR_DEBUG
|
| 300 |
+
/*.allocated_tensors = */ {0},
|
| 301 |
+
#endif
|
| 302 |
+
};
|
| 303 |
+
|
| 304 |
+
ggml_allocr_reset(alloc);
|
| 305 |
+
|
| 306 |
+
return alloc;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
void ggml_allocr_free(struct ggml_allocr * alloc) {
|
| 310 |
+
free(alloc);
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
|
| 314 |
+
return alloc->measure;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
//////////// compute graph allocator
|
| 318 |
+
|
| 319 |
+
static bool ggml_is_view(struct ggml_tensor * t) {
|
| 320 |
+
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
|
| 321 |
+
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_CPY;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
|
| 325 |
+
if (a->type != b->type) {
|
| 326 |
+
return false;
|
| 327 |
+
}
|
| 328 |
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
| 329 |
+
if (a->ne[i] != b->ne[i]) {
|
| 330 |
+
return false;
|
| 331 |
+
}
|
| 332 |
+
if (a->nb[i] != b->nb[i]) {
|
| 333 |
+
return false;
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
return true;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
static struct ggml_tensor * get_view_parent(struct ggml_tensor * t) {
|
| 340 |
+
switch (t->op) {
|
| 341 |
+
case GGML_OP_PERMUTE:
|
| 342 |
+
case GGML_OP_RESHAPE:
|
| 343 |
+
case GGML_OP_TRANSPOSE:
|
| 344 |
+
case GGML_OP_VIEW:
|
| 345 |
+
return t->src[0];
|
| 346 |
+
case GGML_OP_CPY:
|
| 347 |
+
return t->src[1];
|
| 348 |
+
default:
|
| 349 |
+
return NULL;
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
static struct ggml_tensor * get_view_source(struct ggml_tensor * t) {
|
| 354 |
+
struct ggml_tensor * parent = t;
|
| 355 |
+
do {
|
| 356 |
+
parent = get_view_parent(parent);
|
| 357 |
+
} while (ggml_is_view(parent));
|
| 358 |
+
return parent;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
static bool ggml_op_can_inplace(enum ggml_op op) {
|
| 362 |
+
switch (op) {
|
| 363 |
+
case GGML_OP_SCALE:
|
| 364 |
+
case GGML_OP_DIAG_MASK_ZERO:
|
| 365 |
+
case GGML_OP_DIAG_MASK_INF:
|
| 366 |
+
case GGML_OP_ADD:
|
| 367 |
+
case GGML_OP_ADD1:
|
| 368 |
+
case GGML_OP_ACC:
|
| 369 |
+
case GGML_OP_SUB:
|
| 370 |
+
case GGML_OP_MUL:
|
| 371 |
+
case GGML_OP_DIV:
|
| 372 |
+
case GGML_OP_SQR:
|
| 373 |
+
case GGML_OP_SQRT:
|
| 374 |
+
case GGML_OP_LOG:
|
| 375 |
+
case GGML_OP_UNARY:
|
| 376 |
+
case GGML_OP_ROPE:
|
| 377 |
+
case GGML_OP_RMS_NORM:
|
| 378 |
+
case GGML_OP_SET:
|
| 379 |
+
case GGML_OP_SOFT_MAX:
|
| 380 |
+
case GGML_OP_CONT:
|
| 381 |
+
case GGML_OP_ADD_REL_POS:
|
| 382 |
+
return true;
|
| 383 |
+
|
| 384 |
+
default:
|
| 385 |
+
return false;
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
|
| 390 |
+
struct hash_node * ht = alloc->hash_table;
|
| 391 |
+
if (node->data == NULL) {
|
| 392 |
+
if (ggml_is_view(node)) {
|
| 393 |
+
size_t offset;
|
| 394 |
+
switch(node->op) {
|
| 395 |
+
case GGML_OP_VIEW:
|
| 396 |
+
memcpy(&offset, node->op_params, sizeof(size_t));
|
| 397 |
+
node->data = (char *) node->src[0]->data + offset;
|
| 398 |
+
break;
|
| 399 |
+
case GGML_OP_PERMUTE:
|
| 400 |
+
case GGML_OP_RESHAPE:
|
| 401 |
+
case GGML_OP_TRANSPOSE:
|
| 402 |
+
node->data = node->src[0]->data;
|
| 403 |
+
break;
|
| 404 |
+
case GGML_OP_CPY:
|
| 405 |
+
node->data = node->src[1]->data;
|
| 406 |
+
break;
|
| 407 |
+
default:
|
| 408 |
+
GGML_ASSERT(!"unknown view op");
|
| 409 |
+
break;
|
| 410 |
+
}
|
| 411 |
+
} else {
|
| 412 |
+
// see if we can reuse a parent's buffer (inplace)
|
| 413 |
+
if (ggml_op_can_inplace(node->op)) {
|
| 414 |
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 415 |
+
struct ggml_tensor * parent = node->src[i];
|
| 416 |
+
if (parent == NULL) {
|
| 417 |
+
break;
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
// if the node's data is external, then we cannot re-use it
|
| 421 |
+
if ((char *) parent->data < (char *) alloc->data ||
|
| 422 |
+
(char *) parent->data >= ((char *) alloc->data + alloc->size)) {
|
| 423 |
+
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
|
| 424 |
+
continue;
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
struct hash_node * p_hn = hash_get(ht, parent);
|
| 428 |
+
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
|
| 429 |
+
if (ggml_is_view(parent)) {
|
| 430 |
+
struct ggml_tensor * view_src = get_view_source(parent);
|
| 431 |
+
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
| 432 |
+
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
|
| 433 |
+
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
|
| 434 |
+
// the parent's data that it will need later (same layout requirement). the problem is that then
|
| 435 |
+
// we cannot free the tensor because the original address of the allocation is lost.
|
| 436 |
+
// adding a view_src pointer to the tensor would solve this and simplify the code dealing with views
|
| 437 |
+
// for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data)
|
| 438 |
+
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
|
| 439 |
+
node->data = parent->data;
|
| 440 |
+
return;
|
| 441 |
+
}
|
| 442 |
+
}
|
| 443 |
+
else {
|
| 444 |
+
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
|
| 445 |
+
node->data = parent->data;
|
| 446 |
+
return;
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
}
|
| 450 |
+
}
|
| 451 |
+
ggml_allocr_alloc(alloc, node);
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
static size_t ggml_allocator_alloc_graph_tensors_n(
|
| 457 |
+
struct ggml_allocr * alloc,
|
| 458 |
+
struct ggml_cgraph ** graphs, int n_graphs,
|
| 459 |
+
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
|
| 460 |
+
|
| 461 |
+
// reset hash table
|
| 462 |
+
struct hash_node * ht = alloc->hash_table;
|
| 463 |
+
memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
|
| 464 |
+
|
| 465 |
+
// count number of children and views
|
| 466 |
+
for (int g = 0; g < n_graphs; g++) {
|
| 467 |
+
struct ggml_cgraph * gf = graphs[g];
|
| 468 |
+
for (int i = 0; i < gf->n_nodes; i++) {
|
| 469 |
+
struct ggml_tensor * node = gf->nodes[i];
|
| 470 |
+
|
| 471 |
+
if (ggml_is_view(node)) {
|
| 472 |
+
struct ggml_tensor * view_src = get_view_source(node);
|
| 473 |
+
hash_get(ht, view_src)->n_views += 1;
|
| 474 |
+
}
|
| 475 |
+
|
| 476 |
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 477 |
+
struct ggml_tensor * parent = node->src[j];
|
| 478 |
+
if (parent == NULL) {
|
| 479 |
+
break;
|
| 480 |
+
}
|
| 481 |
+
hash_get(ht, parent)->n_children += 1;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
// allocate tensors
|
| 487 |
+
for (int g = 0; g < n_graphs; g++) {
|
| 488 |
+
struct ggml_cgraph * gf = graphs[g];
|
| 489 |
+
AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
|
| 490 |
+
// graph inputs are allocated first to ensure that they are not overwritten by each other
|
| 491 |
+
if (inputs != NULL && inputs[g] != NULL) {
|
| 492 |
+
for (int i = 0; inputs[g][i] != NULL; i++) {
|
| 493 |
+
struct ggml_tensor * input = inputs[g][i];
|
| 494 |
+
AT_PRINTF("input: %s\n", input->name);
|
| 495 |
+
allocate_node(alloc, input);
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
// if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
|
| 499 |
+
int last_barrier_pos = 0;
|
| 500 |
+
int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
|
| 501 |
+
|
| 502 |
+
for (int ind = 0; ind < n_nodes; ind++) {
|
| 503 |
+
// allocate a node if there is no parse_seq or this is not a barrier
|
| 504 |
+
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
|
| 505 |
+
int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
|
| 506 |
+
struct ggml_tensor * node = gf->nodes[i];
|
| 507 |
+
|
| 508 |
+
// allocate parents (leafs)
|
| 509 |
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 510 |
+
struct ggml_tensor * parent = node->src[j];
|
| 511 |
+
if (parent == NULL) {
|
| 512 |
+
break;
|
| 513 |
+
}
|
| 514 |
+
allocate_node(alloc, parent);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
// allocate node
|
| 518 |
+
allocate_node(alloc, node);
|
| 519 |
+
|
| 520 |
+
AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
|
| 521 |
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 522 |
+
struct ggml_tensor * parent = node->src[j];
|
| 523 |
+
if (parent == NULL) {
|
| 524 |
+
break;
|
| 525 |
+
}
|
| 526 |
+
AT_PRINTF("%s", parent->name);
|
| 527 |
+
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
|
| 528 |
+
AT_PRINTF(", ");
|
| 529 |
+
}
|
| 530 |
+
}
|
| 531 |
+
AT_PRINTF("\n");
|
| 532 |
+
}
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
// update parents
|
| 536 |
+
// update immediately if there is no parse_seq
|
| 537 |
+
// update only at barriers if there is parse_seq
|
| 538 |
+
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] == -1) {
|
| 539 |
+
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
|
| 540 |
+
int update_end = alloc->parse_seq_len ? ind : ind + 1;
|
| 541 |
+
for (int i = update_start; i < update_end; i++) {
|
| 542 |
+
int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
|
| 543 |
+
struct ggml_tensor * node = gf->nodes[node_i];
|
| 544 |
+
|
| 545 |
+
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 546 |
+
struct ggml_tensor * parent = node->src[j];
|
| 547 |
+
if (parent == NULL) {
|
| 548 |
+
break;
|
| 549 |
+
}
|
| 550 |
+
struct hash_node * p_hn = hash_get(ht, parent);
|
| 551 |
+
p_hn->n_children -= 1;
|
| 552 |
+
|
| 553 |
+
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
|
| 554 |
+
|
| 555 |
+
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
|
| 556 |
+
if (ggml_is_view(parent)) {
|
| 557 |
+
struct ggml_tensor * view_src = get_view_source(parent);
|
| 558 |
+
struct hash_node * view_src_hn = hash_get(ht, view_src);
|
| 559 |
+
view_src_hn->n_views -= 1;
|
| 560 |
+
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
|
| 561 |
+
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
|
| 562 |
+
ggml_allocator_free_tensor(alloc, view_src);
|
| 563 |
+
}
|
| 564 |
+
}
|
| 565 |
+
else {
|
| 566 |
+
if (parent->data != node->data) {
|
| 567 |
+
ggml_allocator_free_tensor(alloc, parent);
|
| 568 |
+
}
|
| 569 |
+
}
|
| 570 |
+
}
|
| 571 |
+
}
|
| 572 |
+
}
|
| 573 |
+
AT_PRINTF("\n");
|
| 574 |
+
if (alloc->parse_seq_len) {
|
| 575 |
+
last_barrier_pos = ind + 1;
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
}
|
| 579 |
+
// free graph outputs here that wouldn't be freed otherwise because they have no children
|
| 580 |
+
if (outputs != NULL && outputs[g] != NULL) {
|
| 581 |
+
for (int i = 0; outputs[g][i] != NULL; i++) {
|
| 582 |
+
struct ggml_tensor * output = outputs[g][i];
|
| 583 |
+
AT_PRINTF("output: %s\n", output->name);
|
| 584 |
+
ggml_allocator_free_tensor(alloc, output);
|
| 585 |
+
}
|
| 586 |
+
}
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
return alloc->max_size;
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
|
| 593 |
+
return ggml_allocator_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL);
|
| 594 |
+
}
|
ggml-alloc.h
ADDED

@@ -0,0 +1,26 @@
+#pragma once
+
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+
+GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
+GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
+
+// tell the allocator to parse nodes following the order described in the list
+// you should call this if your graph are optimized to execute out-of-order
+GGML_API void   ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
+
+GGML_API void   ggml_allocr_free(struct ggml_allocr * alloc);
+GGML_API bool   ggml_allocr_is_measure(struct ggml_allocr * alloc);
+GGML_API void   ggml_allocr_reset(struct ggml_allocr * alloc);
+GGML_API void   ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
+GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
+
+
+#ifdef __cplusplus
+}
+#endif

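The header above is the entire public surface of the new allocator. A typical two-pass workflow is a measure pass to size the buffer, then a real pass over the same graph; the sketch below is a rough illustration under assumptions (the build_graph callback, buffer handling and the 32-byte alignment are placeholders), only the ggml_allocr_* calls come from the API above.

// Rough sketch of the intended ggml-alloc workflow (not code from this PR).
// build_graph() stands in for whatever constructs the model's compute graph
// in a no_alloc ggml context.
#include <stdlib.h>
#include "ggml.h"
#include "ggml-alloc.h"

static const size_t tensor_alignment = 32; // assumed alignment, adjust per backend

void example_alloc(struct ggml_cgraph * (*build_graph)(void)) {
    // 1) measure pass: computes the required buffer size without touching real memory
    struct ggml_allocr * allocr = ggml_allocr_new_measure(tensor_alignment);
    size_t mem_size = ggml_allocr_alloc_graph(allocr, build_graph()) + tensor_alignment;
    ggml_allocr_free(allocr);

    // 2) real pass: place the graph's tensors inside a buffer of the measured size
    void * buf = malloc(mem_size);
    allocr = ggml_allocr_new(buf, mem_size, tensor_alignment);

    ggml_allocr_reset(allocr);                       // can be reused across evaluations
    ggml_allocr_alloc_graph(allocr, build_graph());  // assigns tensor->data inside buf

    // ... run the graph, e.g. with ggml_graph_compute_with_ctx ...

    ggml_allocr_free(allocr);
    free(buf);
}
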
ggml-cuda.cu
CHANGED

The diff for this file is too large to render. See raw diff.

ggml-cuda.h
CHANGED

@@ -2,34 +2,44 @@
 
 #include "ggml.h"
 
+#ifdef GGML_USE_HIPBLAS
+#define GGML_CUDA_NAME "ROCm"
+#define GGML_CUBLAS_NAME "hipBLAS"
+#else
+#define GGML_CUDA_NAME "CUDA"
+#define GGML_CUBLAS_NAME "cuBLAS"
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
 #define GGML_CUDA_MAX_DEVICES 16
 
-void ggml_init_cublas(void);
-void
-void
-
-
-void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
-
-void *
-void ggml_cuda_host_free(void * ptr);
-void
-void
-void
-void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
-void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
-void ggml_cuda_set_main_device(int main_device);
-void ggml_cuda_set_scratch_size(size_t scratch_size);
-void ggml_cuda_free_scratch(void);
-bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+GGML_API void   ggml_init_cublas(void);
+GGML_API void * ggml_cuda_host_malloc(size_t size);
+GGML_API void   ggml_cuda_host_free(void * ptr);
+
+GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
+GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
+GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
+GGML_API void   ggml_cuda_set_main_device(int main_device);
+GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
+GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);
+GGML_API void   ggml_cuda_free_scratch(void);
+GGML_API bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
+
+GGML_API int    ggml_cuda_get_device_count(void);
+GGML_API void   ggml_cuda_get_device_description(int device, char * description, size_t description_size);
 
 #ifdef __cplusplus
 }

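Two of the newly exported entry points are simple to demonstrate. A small sketch of enumerating the visible devices with the new getters (the printing scaffold is illustrative, only the ggml_cuda_* calls and GGML_CUDA_NAME come from the header above):

// Illustrative use of the newly exported device-query functions.
#include <stdio.h>
#include "ggml-cuda.h"

void list_cuda_devices(void) {
    const int n_devices = ggml_cuda_get_device_count();

    for (int i = 0; i < n_devices; i++) {
        char description[256];
        ggml_cuda_get_device_description(i, description, sizeof(description));
        printf("%s device %d: %s\n", GGML_CUDA_NAME, i, description);
    }
}
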
ggml-metal.h
CHANGED
|
@@ -24,6 +24,7 @@
 
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 16
+#define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
 struct ggml_cgraph;
@@ -34,9 +35,16 @@ extern "C" {
 
 struct ggml_metal_context;
 
-struct ggml_metal_context * ggml_metal_init(void);
+// number of command buffers to use
+struct ggml_metal_context * ggml_metal_init(int n_cb);
 void ggml_metal_free(struct ggml_metal_context * ctx);
 
+void * ggml_metal_host_malloc(size_t n);
+void   ggml_metal_host_free  (void * data);
+
+// set the number of command buffers to use
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
+
 // creates a mapping between a host memory buffer and a device memory buffer
 // - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
 // - the mapping is used during computation to determine the arguments of the compute kernels
@@ -57,6 +65,16 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 // get data from the device into host memory
 void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
 
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf, bool check_mem);
+
+// if the graph has been optimized for concurrently dispatch, return length of the concur_list if optimized
+int ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
+// output the concur_list for ggml_alloc
+int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
+
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
 void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
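Taken together, the new header entries describe a small lifecycle: create the context with a command-buffer count, optionally let it reorder the graph for concurrent dispatch, then compute. A minimal usage sketch, assuming the host buffers were already mapped with ggml_metal_add_buffer() and that "gf" is a built ggml_cgraph:

/* sketch only - not part of the commit */
struct ggml_metal_context * mctx = ggml_metal_init(1);     // start with 1 command buffer
if (mctx == NULL) {
    /* Metal unavailable - fall back to the CPU path */
}

ggml_metal_set_n_cb(mctx, 2);                              // can be raised later

ggml_metal_graph_find_concurrency(mctx, gf, true);         // optional reordering pass
if (ggml_metal_if_optimized(mctx)) {
    int * concur = ggml_metal_get_concur_list(mctx);       // e.g. hand this to ggml-alloc
    (void) concur;
}

ggml_metal_graph_compute(mctx, gf);                        // run the graph on the GPU
ggml_metal_free(mctx);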
ggml-metal.m
CHANGED
|
@@ -5,7 +5,11 @@
 #import <Foundation/Foundation.h>
 
 #import <Metal/Metal.h>
-#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
@@ -15,6 +19,8 @@
 
 #define UNUSED(x) (void)(x)
 
 struct ggml_metal_buffer {
     const char * name;
@@ -25,21 +31,30 @@ struct ggml_metal_buffer {
 };
 
 struct ggml_metal_context {
 
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
     // custom kernels
 #define GGML_METAL_DECL_KERNEL(name) \
     id<MTLFunction>             function_##name; \
     id<MTLComputePipelineState> pipeline_##name
 
     GGML_METAL_DECL_KERNEL(add);
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
@@ -51,6 +66,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
     GGML_METAL_DECL_KERNEL(get_rows_q4_K);
@@ -61,11 +77,21 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(alibi_f32);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -86,22 +112,18 @@ static NSString * const msl_library_source = @"see metal.metal";
 @implementation GGMLMetalClass
 @end
 
-struct ggml_metal_context * ggml_metal_init(void) {
     fprintf(stderr, "%s: allocating\n", __func__);
 
     struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
 
     ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
 
-
-    if (MPSSupportsMTLDevice(ctx->device)) {
-        fprintf(stderr, "%s: using MPS\n",  __func__);
-    } else {
-        fprintf(stderr, "%s: not using MPS\n", __func__);
-        GGML_ASSERT(false && "MPS not supported");
-    }
 
 #if 0
     // compile from source string and show compile log
@@ -111,7 +133,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
         }
     }
 #else
@@ -129,7 +151,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
         }
 
 #ifdef GGML_QKK_64
@@ -141,19 +163,27 @@ struct ggml_metal_context * ggml_metal_init(void) {
 #endif
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
-            exit(1);
         }
     }
 #endif
 
     // load kernels
     {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:nil]; \
-        fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name);
 
         GGML_METAL_ADD_KERNEL(add);
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
@@ -165,6 +195,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
         GGML_METAL_ADD_KERNEL(get_rows_q4_K);
@@ -175,11 +206,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -189,12 +230,12 @@ struct ggml_metal_context * ggml_metal_init(void) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory             = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate              = %8.2f MB/s\n",  __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate              = built-in GPU\n", __func__);
     }
 
     return ctx;
@@ -202,12 +243,97 @@ struct ggml_metal_context * ggml_metal_init(void) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     fprintf(stderr, "%s: deallocating\n", __func__);
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
     free(ctx);
 }
 
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
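The three comment lines above describe how tensor data is located on the GPU: every registered host buffer is assumed to map 1-to-1 onto a Metal buffer, so the device buffer and offset can be recovered from the host pointer alone. A hypothetical, plain-C illustration of that range search (none of these names come from the file):

/* illustration only - hypothetical names, not the commit's code */
#include <stddef.h>

struct mapped_buf { const char * base; size_t size; /* plus the device-side buffer handle */ };

static int find_mapped_buf(const struct mapped_buf * bufs, int n, const void * host_ptr, size_t * offs) {
    const char * p = (const char *) host_ptr;
    for (int i = 0; i < n; ++i) {
        if (p >= bufs[i].base && p < bufs[i].base + bufs[i].size) {
            *offs = (size_t) (p - bufs[i].base);   // offset of the tensor inside the mapped region
            return i;                              // caller uses buffer i's device-side handle
        }
    }
    return -1;                                     // host pointer was never mapped
}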
@@ -346,48 +472,154 @@ void ggml_metal_get_tensor(
     memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
              struct ggml_cgraph * gf) {
     metal_printf("%s: evaluating graph\n", __func__);
 
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
-    const int n_cb = gf->n_threads;
-
-    NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
 
     for (int i = 0; i < n_cb; ++i) {
-        command_buffers[i] = [ctx->queue commandBuffer];
 
         // enqueue the command buffers in order to specify their execution order
-        [command_buffers[i] enqueue];
-    }
 
-    // TODO: is this the best way to start threads?
-    dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
 
-        dispatch_async(queue, ^{
             size_t offs_src0 = 0;
             size_t offs_src1 = 0;
             size_t offs_dst  = 0;
 
-            id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
-
-            id<MTLComputeCommandEncoder> encoder = nil;
 
             const int node_start = (cb_idx + 0) * n_nodes_per_cb;
-            const int node_end   = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
 
-            for (int i = node_start; i < node_end; ++i) {
                 metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
-                struct ggml_tensor * src0 = gf->nodes[i]->src0;
-                struct ggml_tensor * src1 = gf->nodes[i]->src1;
                 struct ggml_tensor * dst  = gf->nodes[i];
 
                 const int64_t ne00 = src0 ? src0->ne[0] : 0;
@@ -443,6 +675,7 @@ void ggml_metal_graph_compute(
                 //}
 
                 switch (dst->op) {
                     case GGML_OP_RESHAPE:
                     case GGML_OP_VIEW:
                     case GGML_OP_TRANSPOSE:
@@ -452,14 +685,16 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
                             }
-
-                            [encoder setComputePipelineState:ctx->pipeline_add];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                             const int64_t n = ggml_nelements(dst);
@@ -467,10 +702,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_MUL:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -488,10 +719,6 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SCALE:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             const float scale = *(const float *) src1->data;
 
                             [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -503,54 +730,46 @@ void ggml_metal_graph_compute(
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_SILU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_silu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
-                    case GGML_OP_RELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_relu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
-                    case GGML_OP_GELU:
-                        {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            [encoder setComputePipelineState:ctx->pipeline_gelu];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-
-                            const int64_t n = ggml_nelements(dst);
-
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             const int nth = 32;
 
                             [encoder setComputePipelineState:ctx->pipeline_soft_max];
@@ -565,11 +784,7 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_DIAG_MASK_INF:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            const int n_past = ((int32_t *)(src1->data))[0];
 
                             [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -585,53 +800,44 @@ void ggml_metal_graph_compute(
                         // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
 
                         GGML_ASSERT(ne00 == ne10);
-                        GGML_ASSERT(ne02 == ne12);
 
                         if (ggml_is_contiguous(src0) &&
                             ggml_is_contiguous(src1) &&
-                            (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
-                            if (encoder != nil) {
-                                [encoder endEncoding];
-                                encoder = nil;
-                            }
-
-                            MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-                            MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
-                            // for F32 x F32 we use MPS
-                            MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                                matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
-                            MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                                matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
-                            MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
-                                matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
-                            MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
-                                initWithDevice:ctx->device transposeLeft:false transposeRight:true
-                                    resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
-                            // we need to do ne02 multiplications
-                            // TODO: is there a way to do this in parallel - currently very slow ..
-                            // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
-                            for (int64_t i02 = 0; i02 < ne02; ++i02) {
-                                size_t offs_src0_cur = offs_src0 + i02*nb02;
-                                size_t offs_src1_cur = offs_src1 + i02*nb12;
-                                size_t offs_dst_cur  = offs_dst  + i02*nb2;
-
-                                MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
-                                MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
-                                MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst_cur  descriptor:desc ];
-
-                                [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
-                            }
                         } else {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             int nth0 = 32;
                             int nth1 = 1;
@@ -639,8 +845,6 @@ void ggml_metal_graph_compute(
                             switch (src0t) {
                                 case GGML_TYPE_F16:
                                     {
-                                        GGML_ASSERT(ne02 == ne12);
-
                                         nth0 = 64;
                                         nth1 = 1;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
@@ -663,13 +867,22 @@ void ggml_metal_graph_compute(
                                         nth1 = 8;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
                                     } break;
                                 case GGML_TYPE_Q2_K:
                                     {
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                                     } break;
                                 case GGML_TYPE_Q3_K:
@@ -677,8 +890,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                                     } break;
                                 case GGML_TYPE_Q4_K:
@@ -686,8 +899,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                     } break;
                                 case GGML_TYPE_Q5_K:
@@ -695,8 +908,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                                     } break;
                                 case GGML_TYPE_Q6_K:
@@ -704,8 +917,8 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne02 == 1);
                                         GGML_ASSERT(ne12 == 1);
 
-                                        nth0 =
-                                        nth1 =
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                                     } break;
                                 default:
@@ -720,28 +933,36 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
-
-                            if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
-                                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src0t == GGML_TYPE_Q2_K ||
-                                     src0t == GGML_TYPE_Q3_K ||
-                                     src0t == GGML_TYPE_Q4_K ||
-                                     src0t == GGML_TYPE_Q5_K ||
-                                     src0t == GGML_TYPE_Q6_K) {
-                                [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -750,14 +971,11 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_GET_ROWS:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             switch (src0->type) {
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
                                 case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
@@ -779,13 +997,10 @@ void ggml_metal_graph_compute(
                         } break;
                    case GGML_OP_RMS_NORM:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
 
-                            const float eps = 1e-6f;
-
-                            const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -793,7 +1008,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -801,20 +1016,17 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            const float eps = 1e-5f;
 
                             const int nth = 256;
 
                             [encoder setComputePipelineState:ctx->pipeline_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
                             [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
@@ -823,15 +1035,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ALIBI:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             GGML_ASSERT((src0t == GGML_TYPE_F32));
 
-                            const int   n_past   = ((int32_t *) src1->data)[0];
-                            const int   n_head   = ((int32_t *) src1->data)[1];
-                            const float max_bias = ((float *)   src1->data)[2];
 
                             if (__builtin_popcount(n_head) != 1) {
                                 GGML_ASSERT(false && "only power-of-two n_head implemented");
@@ -860,51 +1069,53 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
                             [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
                             const int nth = 32;
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ROPE:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
-                            const int n_past = ((int32_t *) src1->data)[0];
-                            const int n_dims = ((int32_t *) src1->data)[1];
-                            const int mode   = ((int32_t *) src1->data)[2];
 
                             [encoder setComputePipelineState:ctx->pipeline_rope];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&n_past length:sizeof(   int) atIndex:18];
-                            [encoder setBytes:&n_dims length:sizeof(   int) atIndex:19];
-                            [encoder setBytes:&mode   length:sizeof(   int) atIndex:20];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_CPY:
                         {
-                            if (encoder == nil) {
-                                encoder = [command_buffer computeCommandEncoder];
-                            }
-
                             const int nth = 32;
 
                             switch (src0t) {
@@ -927,30 +1138,32 @@ void ggml_metal_graph_compute(
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                            [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                            [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                            [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                            [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                            [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                            [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                            [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     default:
-                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
-                        GGML_ASSERT(false);
                     }
                 }
 
@@ -964,17 +1177,19 @@ void ggml_metal_graph_compute(
     }
 
     // wait for all threads to finish
-    dispatch_barrier_sync(queue, ^{});
-
-    [command_buffers[n_cb - 1] waitUntilCompleted];
 
     // check status of command buffers
     // needed to detect if the device ran out-of-memory for example (#1881)
     for (int i = 0; i < n_cb; i++) {
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [command_buffers[i] status];
        if (status != MTLCommandBufferStatusCompleted) {
            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
            GGML_ASSERT(false);
        }
    }
}
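Both the original dispatch code above and the rewritten version that follows split the graph into contiguous node ranges, one range per command buffer, using a ceiling division. A self-contained illustration of that arithmetic with made-up sizes:

/* illustration only - sizes are made up */
#include <stdio.h>

int main(void) {
    const int n_nodes = 10;
    const int n_cb    = 3;                                     // number of command buffers
    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;    // ceil(10/3) = 4

    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
        const int node_start = cb_idx * n_nodes_per_cb;
        int       node_end   = (cb_idx + 1) * n_nodes_per_cb;
        if (node_end > n_nodes) node_end = n_nodes;            // last buffer takes the remainder
        printf("cb %d: nodes [%d, %d)\n", cb_idx, node_start, node_end);
    }
    return 0;
}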
|
|
|
|
| 5 |
#import <Foundation/Foundation.h>
|
| 6 |
|
| 7 |
#import <Metal/Metal.h>
|
| 8 |
+
|
| 9 |
+
#undef MIN
|
| 10 |
+
#undef MAX
|
| 11 |
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
| 12 |
+
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
| 13 |
|
| 14 |
#ifdef GGML_METAL_NDEBUG
|
| 15 |
#define metal_printf(...)
|
|
|
|
| 19 |
|
| 20 |
#define UNUSED(x) (void)(x)
|
| 21 |
|
| 22 |
+
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
|
| 23 |
+
|
| 24 |
struct ggml_metal_buffer {
|
| 25 |
const char * name;
|
| 26 |
|
|
|
|
| 31 |
};
|
| 32 |
|
| 33 |
struct ggml_metal_context {
|
| 34 |
+
int n_cb;
|
| 35 |
|
| 36 |
id<MTLDevice> device;
|
| 37 |
id<MTLCommandQueue> queue;
|
| 38 |
id<MTLLibrary> library;
|
| 39 |
|
| 40 |
+
id<MTLCommandBuffer> command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
|
| 41 |
+
id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
|
| 42 |
+
|
| 43 |
+
dispatch_queue_t d_queue;
|
| 44 |
+
|
| 45 |
int n_buffers;
|
| 46 |
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
|
| 47 |
|
| 48 |
+
int concur_list[GGML_MAX_CONCUR];
|
| 49 |
+
int concur_list_len;
|
| 50 |
+
|
| 51 |
// custom kernels
|
| 52 |
#define GGML_METAL_DECL_KERNEL(name) \
|
| 53 |
id<MTLFunction> function_##name; \
|
| 54 |
id<MTLComputePipelineState> pipeline_##name
|
| 55 |
|
| 56 |
GGML_METAL_DECL_KERNEL(add);
|
| 57 |
+
GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
|
| 58 |
GGML_METAL_DECL_KERNEL(mul);
|
| 59 |
GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
|
| 60 |
GGML_METAL_DECL_KERNEL(scale);
|
|
|
|
| 66 |
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
| 67 |
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
| 68 |
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
|
| 69 |
+
GGML_METAL_DECL_KERNEL(get_rows_q8_0);
|
| 70 |
GGML_METAL_DECL_KERNEL(get_rows_q2_K);
|
| 71 |
GGML_METAL_DECL_KERNEL(get_rows_q3_K);
|
| 72 |
GGML_METAL_DECL_KERNEL(get_rows_q4_K);
|
|
|
|
| 77 |
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
| 78 |
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
| 79 |
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
| 80 |
+
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
| 81 |
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
|
| 82 |
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
|
| 83 |
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
|
| 84 |
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
|
| 85 |
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
|
| 86 |
+
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
| 87 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
| 88 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
|
| 89 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
|
| 90 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
|
| 91 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
|
| 92 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
|
| 93 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
| 94 |
+
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
| 95 |
GGML_METAL_DECL_KERNEL(rope);
|
| 96 |
GGML_METAL_DECL_KERNEL(alibi_f32);
|
| 97 |
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
|
|
|
| 112 |
@implementation GGMLMetalClass
|
| 113 |
@end
|
| 114 |
|
| 115 |
+
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
| 116 |
fprintf(stderr, "%s: allocating\n", __func__);
|
| 117 |
|
| 118 |
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
| 119 |
|
| 120 |
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
| 121 |
ctx->device = MTLCreateSystemDefaultDevice();
|
| 122 |
ctx->queue = [ctx->device newCommandQueue];
|
| 123 |
ctx->n_buffers = 0;
|
| 124 |
+
ctx->concur_list_len = 0;
|
| 125 |
|
| 126 |
+
ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
#if 0
|
| 129 |
// compile from source string and show compile log
|
|
|
|
| 133 |
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
| 134 |
if (error) {
|
| 135 |
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
| 136 |
+
return NULL;
|
| 137 |
}
|
| 138 |
}
|
| 139 |
#else
|
|
|
|
| 151 |
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
| 152 |
if (error) {
|
| 153 |
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
| 154 |
+
return NULL;
|
| 155 |
}
|
| 156 |
|
| 157 |
#ifdef GGML_QKK_64
|
|
|
|
| 163 |
#endif
|
| 164 |
if (error) {
|
| 165 |
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
| 166 |
+
return NULL;
|
| 167 |
}
|
| 168 |
}
|
| 169 |
#endif
|
| 170 |
|
| 171 |
// load kernels
|
| 172 |
{
|
| 173 |
+
NSError * error = nil;
|
| 174 |
#define GGML_METAL_ADD_KERNEL(name) \
|
| 175 |
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
| 176 |
+
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
| 177 |
+
fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
| 178 |
+
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
| 179 |
+
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
| 180 |
+
if (error) { \
|
| 181 |
+
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
| 182 |
+
return NULL; \
|
| 183 |
+
}
|
| 184 |
|
| 185 |
GGML_METAL_ADD_KERNEL(add);
|
| 186 |
+
GGML_METAL_ADD_KERNEL(add_row);
|
| 187 |
GGML_METAL_ADD_KERNEL(mul);
|
| 188 |
GGML_METAL_ADD_KERNEL(mul_row);
|
| 189 |
GGML_METAL_ADD_KERNEL(scale);
|
|
|
|
| 195 |
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
| 196 |
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
| 197 |
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
|
| 198 |
+
GGML_METAL_ADD_KERNEL(get_rows_q8_0);
|
| 199 |
GGML_METAL_ADD_KERNEL(get_rows_q2_K);
|
| 200 |
GGML_METAL_ADD_KERNEL(get_rows_q3_K);
|
| 201 |
GGML_METAL_ADD_KERNEL(get_rows_q4_K);
|
|
|
|
| 206 |
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
| 207 |
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
| 208 |
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
| 209 |
+
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
| 210 |
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
|
| 211 |
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
|
| 212 |
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
|
| 213 |
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
|
| 214 |
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
|
| 215 |
+
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
| 216 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
|
| 217 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
|
| 218 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
|
| 219 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
|
| 220 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
|
| 221 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
|
| 222 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
| 223 |
+
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
| 224 |
GGML_METAL_ADD_KERNEL(rope);
|
| 225 |
GGML_METAL_ADD_KERNEL(alibi_f32);
|
| 226 |
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
|
|
|
| 230 |
#undef GGML_METAL_ADD_KERNEL
|
| 231 |
}
|
| 232 |
|
| 233 |
+
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
| 234 |
+
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
| 235 |
if (ctx->device.maxTransferRate != 0) {
|
| 236 |
+
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
| 237 |
} else {
|
| 238 |
+
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
|
| 239 |
}
|
| 240 |
|
| 241 |
return ctx;
|
|
|
|
| 243 |
|
| 244 |
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
| 245 |
fprintf(stderr, "%s: deallocating\n", __func__);
|
| 246 |
+
#define GGML_METAL_DEL_KERNEL(name) \
|
| 247 |
+
[ctx->function_##name release]; \
|
| 248 |
+
[ctx->pipeline_##name release];
|
| 249 |
+
|
| 250 |
+
GGML_METAL_DEL_KERNEL(add);
|
| 251 |
+
GGML_METAL_DEL_KERNEL(add_row);
|
| 252 |
+
GGML_METAL_DEL_KERNEL(mul);
|
| 253 |
+
GGML_METAL_DEL_KERNEL(mul_row);
|
| 254 |
+
GGML_METAL_DEL_KERNEL(scale);
|
| 255 |
+
GGML_METAL_DEL_KERNEL(silu);
|
| 256 |
+
GGML_METAL_DEL_KERNEL(relu);
|
| 257 |
+
GGML_METAL_DEL_KERNEL(gelu);
|
| 258 |
+
GGML_METAL_DEL_KERNEL(soft_max);
|
| 259 |
+
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
| 260 |
+
GGML_METAL_DEL_KERNEL(get_rows_f16);
|
| 261 |
+
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
|
| 262 |
+
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
|
| 263 |
+
GGML_METAL_DEL_KERNEL(get_rows_q8_0);
|
| 264 |
+
GGML_METAL_DEL_KERNEL(get_rows_q2_K);
|
| 265 |
+
GGML_METAL_DEL_KERNEL(get_rows_q3_K);
|
| 266 |
+
GGML_METAL_DEL_KERNEL(get_rows_q4_K);
|
| 267 |
+
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
| 268 |
+
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
| 269 |
+
GGML_METAL_DEL_KERNEL(rms_norm);
|
| 270 |
+
GGML_METAL_DEL_KERNEL(norm);
|
| 271 |
+
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
| 272 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
| 273 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
| 274 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
| 275 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
|
| 276 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
|
| 277 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
|
| 278 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
|
| 279 |
+
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
|
| 280 |
+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
| 281 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
|
| 282 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
|
| 283 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
|
| 284 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
|
| 285 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
|
| 286 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
|
| 287 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
| 288 |
+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
| 289 |
+
GGML_METAL_DEL_KERNEL(rope);
|
| 290 |
+
GGML_METAL_DEL_KERNEL(alibi_f32);
|
| 291 |
+
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
| 292 |
+
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
| 293 |
+
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
| 294 |
+
|
| 295 |
+
#undef GGML_METAL_DEL_KERNEL
|
| 296 |
+
|
| 297 |
for (int i = 0; i < ctx->n_buffers; ++i) {
|
| 298 |
[ctx->buffers[i].metal release];
|
| 299 |
}
|
| 300 |
+
|
| 301 |
+
[ctx->library release];
|
| 302 |
+
[ctx->queue release];
|
| 303 |
+
[ctx->device release];
|
| 304 |
+
|
| 305 |
+
dispatch_release(ctx->d_queue);
|
| 306 |
+
|
| 307 |
free(ctx);
|
| 308 |
}
|
| 309 |
|
| 310 |
+
void * ggml_metal_host_malloc(size_t n) {
|
| 311 |
+
void * data = NULL;
|
| 312 |
+
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
| 313 |
+
if (result != 0) {
|
| 314 |
+
fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
|
| 315 |
+
return NULL;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
return data;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
void ggml_metal_host_free(void * data) {
|
| 322 |
+
free(data);
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
|
| 326 |
+
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
int ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
|
| 330 |
+
return ctx->concur_list_len;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
| 334 |
+
return ctx->concur_list;
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
// finds the Metal buffer that contains the tensor data on the GPU device
|
| 338 |
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
|
| 339 |
// Metal buffer based on the host memory pointer
|
|
|
|
| 472 |
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
|
| 473 |
}
|
| 474 |
|
| 475 |
+
void ggml_metal_graph_find_concurrency(
|
| 476 |
+
struct ggml_metal_context * ctx,
|
| 477 |
+
struct ggml_cgraph * gf, bool check_mem) {
|
| 478 |
+
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
|
| 479 |
+
int nodes_unused[GGML_MAX_CONCUR];
|
| 480 |
+
|
| 481 |
+
for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
|
| 482 |
+
for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
|
| 483 |
+
ctx->concur_list_len = 0;
|
| 484 |
+
|
| 485 |
+
int n_left = gf->n_nodes;
|
| 486 |
+
int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
|
| 487 |
+
int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
|
| 488 |
+
|
| 489 |
+
while (n_left > 0) {
|
| 490 |
+
// number of nodes at a layer (that can be issued concurrently)
|
| 491 |
+
int concurrency = 0;
|
| 492 |
+
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
|
| 493 |
+
if (nodes_unused[i]) {
|
| 494 |
+
// if the requirements for gf->nodes[i] are satisfied
|
| 495 |
+
int exe_flag = 1;
|
| 496 |
+
|
| 497 |
+
// scan all srcs
|
| 498 |
+
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
|
| 499 |
+
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
|
| 500 |
+
if (src_cur) {
|
| 501 |
+
// if is leaf nodes it's satisfied.
|
| 502 |
+
// TODO: ggml_is_leaf()
|
| 503 |
+
if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
|
| 504 |
+
continue;
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
// otherwise this src should be the output from previous nodes.
|
| 508 |
+
int is_found = 0;
|
| 509 |
+
|
| 510 |
+
// scan 2*search_depth back because we inserted barrier.
|
| 511 |
+
//for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
|
| 512 |
+
for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
|
| 513 |
+
if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
|
| 514 |
+
is_found = 1;
|
| 515 |
+
break;
|
| 516 |
+
}
|
| 517 |
+
}
|
| 518 |
+
if (is_found == 0) {
|
| 519 |
+
exe_flag = 0;
|
| 520 |
+
break;
|
| 521 |
+
}
|
| 522 |
+
}
|
| 523 |
+
}
|
| 524 |
+
if (exe_flag && check_mem) {
|
| 525 |
+
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
|
| 526 |
+
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
|
| 527 |
+
int64_t data_start = (int64_t) gf->nodes[i]->data;
|
| 528 |
+
int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
|
| 529 |
+
for (int j = n_start; j < i; j++) {
|
| 530 |
+
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
|
| 531 |
+
&& gf->nodes[j]->op != GGML_OP_VIEW \
|
| 532 |
+
&& gf->nodes[j]->op != GGML_OP_TRANSPOSE \
|
| 533 |
+
&& gf->nodes[j]->op != GGML_OP_PERMUTE) {
|
| 534 |
+
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
|
| 535 |
+
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
|
| 536 |
+
continue;
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
exe_flag = 0;
|
| 540 |
+
}
|
| 541 |
+
}
|
| 542 |
+
}
|
| 543 |
+
if (exe_flag) {
|
| 544 |
+
ctx->concur_list[level_pos + concurrency] = i;
|
| 545 |
+
nodes_unused[i] = 0;
|
| 546 |
+
concurrency++;
|
| 547 |
+
ctx->concur_list_len++;
|
| 548 |
+
}
|
| 549 |
+
}
|
| 550 |
+
}
|
| 551 |
+
n_left -= concurrency;
|
| 552 |
+
// adding a barrier different layer
|
| 553 |
+
ctx->concur_list[level_pos + concurrency] = -1;
|
| 554 |
+
ctx->concur_list_len++;
|
| 555 |
+
// jump all sorted nodes at nodes_bak
|
| 556 |
+
while (!nodes_unused[n_start]) {
|
| 557 |
+
n_start++;
|
| 558 |
+
}
|
| 559 |
+
level_pos += concurrency + 1;
|
| 560 |
+
}
|
| 561 |
+
|
| 562 |
+
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
| 563 |
+
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
|
| 564 |
+
}
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
void ggml_metal_graph_compute(
|
| 568 |
struct ggml_metal_context * ctx,
|
| 569 |
struct ggml_cgraph * gf) {
|
| 570 |
metal_printf("%s: evaluating graph\n", __func__);
|
| 571 |
|
| 572 |
+
@autoreleasepool {
|
| 573 |
+
|
| 574 |
+
// if there is ctx->concur_list, dispatch concurrently
|
| 575 |
+
// else fallback to serial dispatch
|
| 576 |
+
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
| 577 |
+
|
| 578 |
+
const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
|
| 579 |
+
|
| 580 |
+
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
|
| 581 |
+
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
|
| 582 |
+
|
| 583 |
// create multiple command buffers and enqueue them
|
| 584 |
// then, we encode the graph into the command buffers in parallel
|
| 585 |
|
| 586 |
+
const int n_cb = ctx->n_cb;
|
|
|
|
|
|
|
| 587 |
|
| 588 |
for (int i = 0; i < n_cb; ++i) {
|
| 589 |
+
ctx->command_buffers[i] = [ctx->queue commandBuffer];
|
| 590 |
|
| 591 |
// enqueue the command buffers in order to specify their execution order
|
| 592 |
+
[ctx->command_buffers[i] enqueue];
|
|
|
|
| 593 |
|
| 594 |
+
ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
|
| 595 |
+
}
|
| 596 |
|
| 597 |
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
| 598 |
+
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
| 599 |
|
| 600 |
+
dispatch_async(ctx->d_queue, ^{
|
| 601 |
size_t offs_src0 = 0;
|
| 602 |
size_t offs_src1 = 0;
|
| 603 |
size_t offs_dst = 0;
|
| 604 |
|
| 605 |
+
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
| 606 |
+
id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
|
|
|
|
| 607 |
|
| 608 |
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
| 609 |
+
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
| 610 |
+
|
| 611 |
+
for (int ind = node_start; ind < node_end; ++ind) {
|
| 612 |
+
const int i = has_concur ? ctx->concur_list[ind] : ind;
|
| 613 |
+
|
| 614 |
+
if (i == -1) {
|
| 615 |
+
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
| 616 |
+
continue;
|
| 617 |
+
}
|
| 618 |
|
|
|
|
| 619 |
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
| 620 |
|
| 621 |
+
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
| 622 |
+
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
| 623 |
struct ggml_tensor * dst = gf->nodes[i];
|
| 624 |
|
| 625 |
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
|
|
|
| 675 |
//}
|
| 676 |
|
| 677 |
switch (dst->op) {
|
| 678 |
+
case GGML_OP_NONE:
|
| 679 |
case GGML_OP_RESHAPE:
|
| 680 |
case GGML_OP_VIEW:
|
| 681 |
case GGML_OP_TRANSPOSE:
|
|
|
|
| 685 |
} break;
|
| 686 |
case GGML_OP_ADD:
|
| 687 |
{
|
| 688 |
+
if (ggml_nelements(src1) == ne10) {
|
| 689 |
+
// src1 is a row
|
| 690 |
+
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
| 691 |
+
} else {
|
| 692 |
+
[encoder setComputePipelineState:ctx->pipeline_add];
|
| 693 |
}
|
|
|
|
|
|
|
| 694 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 695 |
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 696 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 697 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
| 698 |
|
| 699 |
const int64_t n = ggml_nelements(dst);
|
| 700 |
|
|
|
|
| 702 |
} break;
|
| 703 |
case GGML_OP_MUL:
|
| 704 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
if (ggml_nelements(src1) == ne10) {
|
| 706 |
// src1 is a row
|
| 707 |
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
|
|
|
| 719 |
} break;
|
| 720 |
case GGML_OP_SCALE:
|
| 721 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 722 |
const float scale = *(const float *) src1->data;
|
| 723 |
|
| 724 |
[encoder setComputePipelineState:ctx->pipeline_scale];
|
|
|
|
| 730 |
|
| 731 |
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 732 |
} break;
|
| 733 |
+
case GGML_OP_UNARY:
|
| 734 |
+
switch (ggml_get_unary_op(gf->nodes[i])) {
|
| 735 |
+
case GGML_UNARY_OP_SILU:
|
| 736 |
+
{
|
| 737 |
+
[encoder setComputePipelineState:ctx->pipeline_silu];
|
| 738 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 739 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 740 |
+
|
| 741 |
+
const int64_t n = ggml_nelements(dst);
|
| 742 |
+
|
| 743 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 744 |
+
} break;
|
| 745 |
+
case GGML_UNARY_OP_RELU:
|
| 746 |
+
{
|
| 747 |
+
[encoder setComputePipelineState:ctx->pipeline_relu];
|
| 748 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 749 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 750 |
+
|
| 751 |
+
const int64_t n = ggml_nelements(dst);
|
| 752 |
+
|
| 753 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 754 |
+
} break;
|
| 755 |
+
case GGML_UNARY_OP_GELU:
|
| 756 |
+
{
|
| 757 |
+
[encoder setComputePipelineState:ctx->pipeline_gelu];
|
| 758 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 759 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
| 760 |
+
|
| 761 |
+
const int64_t n = ggml_nelements(dst);
|
| 762 |
+
|
| 763 |
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
| 764 |
+
} break;
|
| 765 |
+
default:
|
| 766 |
+
{
|
| 767 |
+
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
| 768 |
+
GGML_ASSERT(false);
|
| 769 |
+
}
|
| 770 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 771 |
case GGML_OP_SOFT_MAX:
|
| 772 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
const int nth = 32;
|
| 774 |
|
| 775 |
[encoder setComputePipelineState:ctx->pipeline_soft_max];
|
|
|
|
| 784 |
} break;
|
| 785 |
case GGML_OP_DIAG_MASK_INF:
|
| 786 |
{
|
| 787 |
+
const int n_past = ((int32_t *)(dst->op_params))[0];
|
|
|
|
|
|
|
|
|
|
|
|
|
| 788 |
|
| 789 |
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
|
| 790 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
| 800 |
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
|
| 801 |
|
| 802 |
GGML_ASSERT(ne00 == ne10);
|
| 803 |
+
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
|
| 804 |
+
uint gqa = ne12/ne02;
|
| 805 |
+
GGML_ASSERT(ne03 == ne13);
|
| 806 |
|
| 807 |
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
| 808 |
+
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
| 809 |
if (ggml_is_contiguous(src0) &&
|
| 810 |
ggml_is_contiguous(src1) &&
|
| 811 |
+
src1t == GGML_TYPE_F32 &&
|
| 812 |
+
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
| 813 |
+
ne00%32 == 0 &&
|
| 814 |
+
ne11 > 1) {
|
| 815 |
+
switch (src0->type) {
|
| 816 |
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
|
| 817 |
+
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
|
| 818 |
+
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
|
| 819 |
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
|
| 820 |
+
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
|
| 821 |
+
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
|
| 822 |
+
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
|
| 823 |
+
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
| 824 |
+
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
| 825 |
+
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 826 |
}
|
| 827 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
| 828 |
+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
| 829 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 830 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
| 831 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
| 832 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
|
| 833 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
|
| 834 |
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
| 835 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
|
| 836 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
|
| 837 |
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
|
| 838 |
+
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
| 839 |
+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
| 840 |
} else {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
int nth0 = 32;
|
| 842 |
int nth1 = 1;
|
| 843 |
|
|
|
|
| 845 |
switch (src0t) {
|
| 846 |
case GGML_TYPE_F16:
|
| 847 |
{
|
|
|
|
|
|
|
| 848 |
nth0 = 64;
|
| 849 |
nth1 = 1;
|
| 850 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
|
|
| 867 |
nth1 = 8;
|
| 868 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
|
| 869 |
} break;
|
| 870 |
+
case GGML_TYPE_Q8_0:
|
| 871 |
+
{
|
| 872 |
+
GGML_ASSERT(ne02 == 1);
|
| 873 |
+
GGML_ASSERT(ne12 == 1);
|
| 874 |
+
|
| 875 |
+
nth0 = 8;
|
| 876 |
+
nth1 = 8;
|
| 877 |
+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
|
| 878 |
+
} break;
|
| 879 |
case GGML_TYPE_Q2_K:
|
| 880 |
{
|
| 881 |
GGML_ASSERT(ne02 == 1);
|
| 882 |
GGML_ASSERT(ne12 == 1);
|
| 883 |
|
| 884 |
+
nth0 = 2;
|
| 885 |
+
nth1 = 32;
|
| 886 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
|
| 887 |
} break;
|
| 888 |
case GGML_TYPE_Q3_K:
|
|
|
|
| 890 |
GGML_ASSERT(ne02 == 1);
|
| 891 |
GGML_ASSERT(ne12 == 1);
|
| 892 |
|
| 893 |
+
nth0 = 2;
|
| 894 |
+
nth1 = 32;
|
| 895 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
| 896 |
} break;
|
| 897 |
case GGML_TYPE_Q4_K:
|
|
|
|
| 899 |
GGML_ASSERT(ne02 == 1);
|
| 900 |
GGML_ASSERT(ne12 == 1);
|
| 901 |
|
| 902 |
+
nth0 = 2;
|
| 903 |
+
nth1 = 32;
|
| 904 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
| 905 |
} break;
|
| 906 |
case GGML_TYPE_Q5_K:
|
|
|
|
| 908 |
GGML_ASSERT(ne02 == 1);
|
| 909 |
GGML_ASSERT(ne12 == 1);
|
| 910 |
|
| 911 |
+
nth0 = 2;
|
| 912 |
+
nth1 = 32;
|
| 913 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
|
| 914 |
} break;
|
| 915 |
case GGML_TYPE_Q6_K:
|
|
|
|
| 917 |
GGML_ASSERT(ne02 == 1);
|
| 918 |
GGML_ASSERT(ne12 == 1);
|
| 919 |
|
| 920 |
+
nth0 = 2;
|
| 921 |
+
nth1 = 32;
|
| 922 |
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
| 923 |
} break;
|
| 924 |
default:
|
|
|
|
| 933 |
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
| 934 |
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
| 935 |
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
| 936 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
| 937 |
+
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
| 938 |
+
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
| 939 |
+
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
| 940 |
+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
| 941 |
+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
| 942 |
+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
|
| 943 |
+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
|
| 944 |
+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
|
| 945 |
+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
|
| 946 |
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
|
| 947 |
+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
|
| 948 |
+
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
| 949 |
+
|
| 950 |
+
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
| 951 |
+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
| 952 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 953 |
+
}
|
| 954 |
+
else if (src0t == GGML_TYPE_Q3_K) {
|
| 955 |
+
#ifdef GGML_QKK_64
|
| 956 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 957 |
+
#else
|
| 958 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 959 |
+
#endif
|
| 960 |
}
|
| 961 |
+
else if (src0t == GGML_TYPE_Q5_K) {
|
| 962 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
| 963 |
+
}
|
| 964 |
+
else if (src0t == GGML_TYPE_Q6_K) {
|
| 965 |
+
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
|
|
|
| 966 |
} else {
|
| 967 |
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
| 968 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
|
| 971 |
} break;
|
| 972 |
case GGML_OP_GET_ROWS:
|
| 973 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 974 |
switch (src0->type) {
|
| 975 |
+
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
| 976 |
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
| 977 |
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
|
| 978 |
+
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
|
| 979 |
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
|
| 980 |
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
|
| 981 |
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
|
|
|
|
| 997 |
} break;
|
        case GGML_OP_RMS_NORM:
            {
+               float eps;
+               memcpy(&eps, dst->op_params, sizeof(float));

+               const int nth = 512;

                [encoder setComputePipelineState:ctx->pipeline_rms_norm];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+               [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];

                const int64_t nrows = ggml_nrows(src0);
            } break;
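The eps read above is one of the small API changes in this sync: the epsilon now travels inside the op's op_params words instead of being hard-coded in the backend, so the host code recovers it with memcpy rather than a pointer cast. A self-contained C illustration of that round trip (the local op_params array stands in for the real tensor field):

    #include <stdint.h>
    #include <string.h>
    #include <stdio.h>

    int main(void) {
        int32_t op_params[8] = {0};

        /* producer side: store a float in the first 32-bit slot without type-punning UB */
        const float eps_in = 1e-6f;
        memcpy(&op_params[0], &eps_in, sizeof(float));

        /* consumer side: what the Metal backend does above */
        float eps;
        memcpy(&eps, op_params, sizeof(float));

        printf("eps = %g\n", eps);
        return 0;
    }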
        case GGML_OP_NORM:
            {
+               float eps;
+               memcpy(&eps, dst->op_params, sizeof(float));

                const int nth = 256;

                [encoder setComputePipelineState:ctx->pipeline_norm];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+               [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+               [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+               [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

                const int64_t nrows = ggml_nrows(src0);
            } break;
        case GGML_OP_ALIBI:
            {
                GGML_ASSERT((src0t == GGML_TYPE_F32));

+               const int   n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+               const int   n_head = ((int32_t *) dst->op_params)[1];
+               float       max_bias;
+               memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));

                if (__builtin_popcount(n_head) != 1) {
                    GGML_ASSERT(false && "only power-of-two n_head implemented");

                [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
                [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
                [encoder setBytes:&m0  length:sizeof(   float) atIndex:18];
+
                const int nth = 32;
+
                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_ROPE:
            {
+               const int n_past = ((int32_t *) dst->op_params)[0];
+               const int n_dims = ((int32_t *) dst->op_params)[1];
+               const int mode   = ((int32_t *) dst->op_params)[2];

+               float freq_base;
+               float freq_scale;
+               memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float));
+               memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

                [encoder setComputePipelineState:ctx->pipeline_rope];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+               [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+               [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+               [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+               [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+               [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+               [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+               [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+               [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+               [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+               [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+               [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+               [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+               [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+               [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+               [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+               [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
+               [encoder setBytes:&n_past     length:sizeof(  int) atIndex:18];
+               [encoder setBytes:&n_dims     length:sizeof(  int) atIndex:19];
+               [encoder setBytes:&mode       length:sizeof(  int) atIndex:20];
+               [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+               [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
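freq_base and freq_scale (indices 21 and 22) are the two newly plumbed RoPE parameters. As a rough scalar sketch of how such parameters usually enter the rotation angle in ggml-style RoPE (non-neox path; this is an illustration under that assumption, not the exact Metal kernel):

    #include <math.h>

    /* rotate one (x0, x1) pair of a head dimension, ggml-style RoPE;
       i0 is the even dimension index of the pair                      */
    static void rope_pair(float * x0, float * x1, int pos, int i0, int n_dims,
                          float freq_base, float freq_scale) {
        /* per-dimension frequency freq_base^(-i0/n_dims), applied to a scaled position */
        const float theta = freq_scale * (float) pos * powf(freq_base, -((float) i0)/n_dims);
        const float c = cosf(theta);
        const float s = sinf(theta);

        const float v0 = *x0;
        const float v1 = *x1;
        *x0 = v0*c - v1*s;
        *x1 = v0*s + v1*c;
    }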
+       case GGML_OP_DUP:
        case GGML_OP_CPY:
+       case GGML_OP_CONT:
            {
                const int nth = 32;

                switch (src0t) {
                    default: GGML_ASSERT(false && "not implemented");
                }

+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+               [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+               [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+               [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+               [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+               [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+               [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+               [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+               [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+               [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
+               [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
+               [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
+               [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
+               [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
+               [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
+               [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
+               [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];

                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        default:
+           {
+               fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+               GGML_ASSERT(false);
+           }
        }
    }
    }

    // wait for all threads to finish
+   dispatch_barrier_sync(ctx->d_queue, ^{});

    // check status of command buffers
    // needed to detect if the device ran out-of-memory for example (#1881)
    for (int i = 0; i < n_cb; i++) {
+       [ctx->command_buffers[i] waitUntilCompleted];
+
+       MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
        if (status != MTLCommandBufferStatusCompleted) {
            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
            GGML_ASSERT(false);
        }
    }
+
+   }
}
ggml-metal.metal
CHANGED
@@ -18,46 +18,11 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;

-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const half d = x[i].d;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F) - 8;
-            const int x1 = (x[i].qs[j] >>   4) - 8;
-
-            y[i*qk + j + 0   ] = x0*d;
-            y[i*qk + j + qk/2] = x1*d;
-        }
-    }
-}
-
-static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
-    const int qk = QK4_1;
-
-    assert(k % qk == 0);
-
-    const int nb = k / qk;
-
-    for (int i = 0; i < nb; i++) {
-        const half d = x[i].d;
-        const half m = x[i].m;
-
-        for (int j = 0; j < qk/2; ++j) {
-            const int x0 = (x[i].qs[j] & 0x0F);
-            const int x1 = (x[i].qs[j] >>   4);
-
-            y[i*qk + j + 0   ] = x0*d + m;
-            y[i*qk + j + qk/2] = x1*d + m;
-        }
-    }
-}

 kernel void kernel_add(
     device const float * src0,
@@ -67,6 +32,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
 }

 kernel void kernel_mul(
     device const float * src0,
     device const float * src1,
@@ -117,7 +93,12 @@ kernel void kernel_gelu(
     device float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
 }

 kernel void kernel_soft_max(
@@ -208,54 +189,6 @@ kernel void kernel_diag_mask_inf(
     }
 }

-kernel void kernel_get_rows_f16(
-    device const void * src0,
-    device const int * src1,
-    device float * dst,
-    constant int64_t & ne00,
-    constant uint64_t & nb01,
-    constant uint64_t & nb1,
-    uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    for (int j = 0; j < ne00; j++) {
-        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
-    }
-}
-
-kernel void kernel_get_rows_q4_0(
-    device const void * src0,
-    device const int * src1,
-    device float * dst,
-    constant int64_t & ne00,
-    constant uint64_t & nb01,
-    constant uint64_t & nb1,
-    uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q4_0(
-        (device const block_q4_0 *) ((device char *) src0 + r*nb01),
-        (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_1(
-    device const void * src0,
-    device const int * src1,
-    device float * dst,
-    constant int64_t & ne00,
-    constant uint64_t & nb01,
-    constant uint64_t & nb1,
-    uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q4_1(
-        (device const block_q4_1 *) ((device char *) src0 + r*nb01),
-        (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
 kernel void kernel_norm(
     device const void * src0,
     device float * dst,
@@ -331,26 +264,33 @@ kernel void kernel_rms_norm(
     threadgroup float * sum [[threadgroup(0)]],
     uint tgpig[[threadgroup_position_in_grid]],
     uint tpitg[[thread_position_in_threadgroup]],
     uint ntg[[threads_per_threadgroup]]) {

     // parallel sum
     }

-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-
-    // broadcast
     if (tpitg == 0) {
         sum[0] /= ne00;
     }
@@ -359,146 +299,201 @@ kernel void kernel_rms_norm(
     const float mean = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);

         y[i00] = x[i00] * scale;
     }
 }

-    device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
-    device const float * y = (device const float *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    const int ix = tpitg.y/4;           // 0 or 1
-    const int iy = tpitg.y - 4*ix;      // 0...3
-
-    const int first = 4 * iy;
-
-    float sumf = 0;
-
-    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
-        const float d = (float)x[i].d;
-
-        device const uint8_t * xl = x[i].qs + first;
-        device const float   * yl = y + i * QK4_0 + first;
-
-        float2 acc = {0.0f, 0.0f};
     }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
     }
 }

-kernel void
     device const void * src0,
     device const float * src1,
     device float * dst,
     constant int64_t & ne00,
-    device const float * y = (device const float *) src1 + r1*ne10;
-
-    const uint nth = tptg.x*tptg.y;
-    const uint ith = tptg.y*tpitg.x + tpitg.y;
-
-    const int ix = tpitg.y/4;           // 0 or 1
-    const int iy = tpitg.y - 4*ix;      // 0...3
-
-    const int first = 4 * iy;
-
-    float sumf = 0;
-
-    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
-
-        const float d = (float)x[i].d;
-        const float m = (float)x[i].m;
-
-        device const uint8_t * xl = x[i].qs + first;
-        device const float   * yl = y + i * QK4_1 + first;
-
-        float2 acc = {0.0f, 0.0f};
     }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
     }
 }
@@ -508,11 +503,13 @@ kernel void kernel_mul_mat_f16_f32(
     device float * dst,
     constant int64_t & ne00,
     constant int64_t & ne01,
     constant uint64_t & nb00,
     constant uint64_t & nb01,
     constant uint64_t & nb02,
     constant int64_t & ne10,
     constant int64_t & ne11,
     constant uint64_t & nb10,
     constant uint64_t & nb11,
     constant uint64_t & nb12,
@@ -528,7 +525,7 @@ kernel void kernel_mul_mat_f16_f32(
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;

-    device const half * x = (device const half *) (src0 + r0*nb01 + im*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

     sum[tpitg.x] = 0.0f;
@@ -615,17 +612,19 @@ kernel void kernel_rope(
     constant     int & n_past,
     constant     int & n_dims,
     constant     int & mode,
     uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i3 = tpig[2];
     const int64_t i2 = tpig[1];
     const int64_t i1 = tpig[0];

     const bool is_neox = mode & 2;

     const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

     if (!is_neox) {
         for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
@@ -644,7 +643,25 @@ kernel void kernel_rope(
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         }
     } else {
     }
 }
@@ -863,779 +880,581 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
     return r;
 }

 #if QK_K == 256
 #else
 #endif
 }
 }

-static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
 #if QK_K == 256

     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;

-        aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
-        aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
-        aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
-        aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
-        aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
-        aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
-        aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
-        aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
-
-        int is = 0;
-        float dl;
-        for (int n = 0; n < QK_K; n += 128) {
-            int shift = 0;
-            for (int j = 0; j < 4; ++j) {
-
-                dl = d_all * (scales[is++] - 32);
-                for (int l = 0; l < 16; ++l) {
-                    *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
-                }
-            }
-        }
-        q += 32;
-    }
-}
-#else
-    for (int i = 0; i < nb; i++) {
-
-        device const uint8_t * hm = x[i].hmask;
-
-        const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
-        const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);

         for (int l = 0; l < 8; ++l) {
-            y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
-            y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
-            y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
-            y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
-            y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
-            y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
-            y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
         }
-        y += QK_K;
-    }
-#endif

-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;

-#if QK_K == 256
-        const float d = x[i].d;
-        const float min = x[i].dmin;
-
-        device const uint8_t * scales = x[i].scales;
-
-        int is = 0;
-        for (int j = 0; j < QK_K; j += 64) {
-            const uchar4 sc = get_scale_min_k4(is, scales);
-            const float d1 = d * sc[0]; const float m1 = min * sc[1];
-            const float d2 = d * sc[2]; const float m2 = min * sc[3];
-            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
-            q += 32; is += 2;
-        }
-#else
-        device const uint8_t * s = x[i].scales;
-        device const half2 * dh = (device const half2 *)x[i].d;
-        const float2 d = (float2)dh[0];
-        const float d1 = d[0] * (s[0] & 0xF);
-        const float d2 = d[0] * (s[1] & 0xF);
-        const float m1 = d[1] * (s[0] >> 4);
-        const float m2 = d[1] * (s[1] >> 4);
-        for (int l = 0; l < 32; ++l) {
-            y[l+ 0] = d1 * (q[l] & 0xF) - m1;
-            y[l+32] = d2 * (q[l] >>  4) - m2;
         }
-        y += QK_K;
-#endif
-    }

-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;
-
-        device const uint8_t * ql = x[i].qs;
-        device const uint8_t * qh = x[i].qh;
-
-        int is = 0;
-        uint8_t u1 = 1, u2 = 2;
-        for (int j = 0; j < QK_K; j += 64) {
-            const uchar4 sc = get_scale_min_k4(is, x[i].scales);
-            const float d1 = d * sc[0]; const float m1 = min * sc[1];
-            const float d2 = d * sc[2]; const float m2 = min * sc[3];
-            for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
-            for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >>  4) + (qh[l] & u2 ? 16 : 0)) - m2;
-            ql += 32; is += 2;
-            u1 <<= 2; u2 <<= 2;
         }
     }
 #else
-            y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
-            y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
-            y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
-            y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
-            y[l+32] = d * sc[2] * ((ql[l+ 0] >>  4) - (qh[l] & 0x10 ? 0 : 16));
-            y[l+40] = d * sc[2] * ((ql[l+ 8] >>  4) - (qh[l] & 0x20 ? 0 : 16));
-            y[l+48] = d * sc[3] * ((ql[l+16] >>  4) - (qh[l] & 0x40 ? 0 : 16));
-            y[l+56] = d * sc[3] * ((ql[l+24] >>  4) - (qh[l] & 0x80 ? 0 : 16));
-        }
-        y += QK_K;
-    }
-#endif

-    assert(k % QK_K == 0);
-    const int nb = k / QK_K;

-        }
-        y += 128;
-        ql += 64;
-        qh += 32;
-        sc += 8;
     }
-#else
-    for (int l = 0; l < 16; ++l) {
-        const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
-        const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
-        const int8_t q3 = (int8_t)((ql[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
-        const int8_t q4 = (int8_t)((ql[l+16] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
-        y[l+ 0] = d * sc[0] * q1;
-        y[l+16] = d * sc[1] * q2;
-        y[l+32] = d * sc[2] * q3;
-        y[l+48] = d * sc[3] * q4;
-    }
-    y += 64;
-#endif
-    }
-}
-
-kernel void kernel_get_rows_q2_K(
-    device const void * src0,
-    device const int * src1,
-    device float * dst,
-    constant int64_t & ne00,
-    constant uint64_t & nb01,
-    constant uint64_t & nb1,
-    uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    dequantize_row_q2_K(
-        (device const block_q2_K *) ((device char *) src0 + r*nb01),
-        (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q3_K(
-kernel void kernel_get_rows_q4_K(
-kernel void kernel_get_rows_q5_K(
-kernel void kernel_get_rows_q6_K(
-
-//====================================== dot products =========================
-
-kernel void kernel_mul_mat_q2_K_f32(
-    device const void * src0,
-    device const float * src1,
-    device float * dst,
-    constant int64_t & ne00,
-    constant int64_t & ne10,
-    constant int64_t & ne0,
-    threadgroup float * sum [[threadgroup(0)]],
-    uint2 tgpig[[threadgroup_position_in_grid]],
-    uint2 tpitg[[thread_position_in_threadgroup]],
-    uint2 tptg[[threads_per_threadgroup]]) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-
-    device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
-    device const float * yy = (device const float *) src1 + r1*ne10;
-
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
-
-    float sumf = 0;
-
-#if QK_K == 256
-    const int tid = tpitg.y;    // 0...16
-    const int il = tid/4;       // 0...3
-    const int ir = tid%4;       // 0...3
-    const int ip = il/2;        // 0 or 1
-    const int shift1 = 4*(il%2);// 0 or 4
-    const int shift2 = shift1+2;// 2 or 6
-    const int n = 8;
-    const int is = 4*il + (n*ir)/16;
-
-    const int y_offset = 64*il + n*ir;
-    const int q_offset = 32*ip + n*ir;
-
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
-
-        device const uint8_t * q = x[i].qs + q_offset;
-        device const uint8_t * scales = x[i].scales + is;
-
-        uint8_t d1 = scales[0] & 0xF;
-        uint8_t d2 = scales[2] & 0xF;
-        uint8_t m1 = scales[0] >> 4;
-        uint8_t m2 = scales[2] >> 4;
-
-        device const float * y = yy + i*QK_K + y_offset;
-
-        float2 s = {0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < n; ++l) {
-            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3);
-            s[1] += y[l+32] * ((q[l] >> shift2) & 3);
-            smin += y[l+ 0] * m1 + y[l+32] * m2;
-        }
-
-        const float dall = (float)x[i].d;
-        const float dmin = (float)x[i].dmin;
-
-        sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
     }
-#endif
-
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
 }
@@ -1644,78 +1463,113 @@ kernel void kernel_mul_mat_q5_K_f32(
     const uint8_t hm3 = hm1 << 4;
     const uint8_t hm4 = hm2 << 4;

-        device const uint8_t * q2 = q1 + 64;
-        device const uint8_t * qh = (x + i)->qh + l0;
-        device const float * y1 = yy + i*QK_K + y_offset;
-        device const float * y2 = y1 + 128;

-            smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1];
         }
     }
 #else

         device const uint8_t * q = x[i].qs + il;
         device const uint8_t * h = x[i].qh + in;
         device const int8_t  * s = x[i].scales;
-        device const float * y = yy + i*QK_K + il;
     }
 }
 #endif
-    sum[ith] = sumf;

-        sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
     }

 }
@@ -1725,12 +1579,16 @@ kernel void kernel_mul_mat_q6_K_f32(
     device const float * src1,
     device float * dst,
     constant int64_t & ne00,

     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -1741,20 +1599,20 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

-    const int ith = tptg.y*tpitg.x + tpitg.y;

     float sumf = 0;

 #if QK_K == 256
     const int n = 4;
     const int l0 = n*il;
     const int is = 8*ip + l0/16;
@@ -1763,9 +1621,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;

         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
@@ -1775,19 +1634,21 @@ kernel void kernel_mul_mat_q6_K_f32(
     float4 sums = {0.f, 0.f, 0.f, 0.f};
     for (int l = 0; l < n; ++l) {
     }

     sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 }
 #else

     device const float   * y  = yy + i * QK_K + il;
     device const uint8_t * ql = x[i].ql + il;
     device const uint8_t * qh = x[i].qh + il;
@@ -1807,23 +1668,382 @@ kernel void kernel_mul_mat_q6_K_f32(

 #endif

     }
     }
 }
 }
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;

+#define QK8_0 32
+typedef struct {
+    half    d;         // delta
+    int8_t  qs[QK8_0]; // quants
+} block_q8_0;

 kernel void kernel_add(
     device const float * src0,

     dst[tpig] = src0[tpig] + src1[tpig];
 }

+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+    device const float * src0,
+    device const float * src1,
+    device       float * dst,
+    constant   int64_t & ne00,
+    uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+}
+
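kernel_add_row broadcasts the single row src1 across every row of src0 through the tpig % ne00 indexing, so one dispatch over all elements of src0 covers the whole tensor. The equivalent CPU loop in C, to make the addressing explicit (illustrative helper, not part of ggml):

    #include <stdint.h>

    /* dst[r, c] = src0[r, c] + src1[c] for all rows r, i.e. src1 is broadcast along the row dimension */
    static void add_row_f32(const float * src0, const float * src1, float * dst,
                            int64_t nrows, int64_t ne00) {
        for (int64_t i = 0; i < nrows*ne00; ++i) {
            dst[i] = src0[i] + src1[i % ne00];
        }
    }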

 kernel void kernel_mul(
     device const float * src0,
     device const float * src1,
     device float * dst,
     uint tpig[[thread_position_in_grid]]) {
     float x = src0[tpig];
+
+    // BEWARE !!!
+    // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+    // This was observed with Falcon 7B and 40B models
+    //
+    dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
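The kernel keeps the usual tanh-based GELU approximation and only swaps tanh for precise::tanh to avoid the NaNs mentioned in the comment. A plain C reference of the same approximation, using the standard constants 0.044715 and sqrt(2/pi), for comparison only:

    #include <math.h>

    /* tanh-based GELU approximation, matching the formula used in the kernel above */
    static float gelu_approx(float x) {
        const float GELU_COEF_A    = 0.044715f;
        const float SQRT_2_OVER_PI = 0.79788456080286535588f; /* sqrt(2/pi) */
        return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }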

 kernel void kernel_soft_max(
 }
 }

 kernel void kernel_norm(
     device const void * src0,
     device       float * dst,
     threadgroup float * sum [[threadgroup(0)]],
     uint tgpig[[threadgroup_position_in_grid]],
     uint tpitg[[thread_position_in_threadgroup]],
+    uint sgitg[[simdgroup_index_in_threadgroup]],
+    uint tiisg[[thread_index_in_simdgroup]],
     uint ntg[[threads_per_threadgroup]]) {
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float * x_scalar = (device const float *) x;
+    float4 sumf=0;
+    float all_sum=0;

     // parallel sum
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+    all_sum = simd_sum(all_sum);
+    if (tiisg == 0) {
+        sum[sgitg] = all_sum;
     }

     threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast, simd group number is ntg / 32
+    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
+        if (tpitg < i) {
+            sum[tpitg] += sum[tpitg + i];
+        }
     }
     if (tpitg == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
         sum[0] /= ne00;
     }

     const float mean = sum[0];
     const float scale = 1.0f/sqrt(mean + eps);

+    device float4 * y = (device float4 *) (dst + tgpig*ne00);
+    device float * y_scalar = (device float *) y;
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
+    if (tpitg == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+    }
 }

+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc[0] + acc[1]);
+}

+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
// putting them in the kernel cause a significant performance penalty
|
| 348 |
+
#define N_DST 4 // each SIMD group works on 4 rows
|
| 349 |
+
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
| 350 |
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
| 351 |
+
//Note: This is a template, but strictly speaking it only applies to
|
| 352 |
+
// quantizations where the block size is 32. It also does not
|
| 353 |
+
// giard against the number of rows not being divisible by
|
| 354 |
+
// N_DST, so this is another explicit assumption of the implementation.
|
| 355 |
+
template<typename block_q_type, int nr, int nsg, int nw>
|
| 356 |
+
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
| 357 |
+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
|
| 358 |
+
uint3 tgpig, uint tiisg, uint sgitg) {
|
| 359 |
+
const int nb = ne00/QK4_0;
|
| 360 |
+
const int r0 = tgpig.x;
|
| 361 |
+
const int r1 = tgpig.y;
|
| 362 |
+
const int im = tgpig.z;
|
| 363 |
+
const int first_row = (r0 * nsg + sgitg) * nr;
|
| 364 |
+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
|
| 365 |
+
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
|
| 366 |
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
| 367 |
+
float yl[16]; // src1 vector cache
|
| 368 |
+
float sumf[nr]={0.f};
|
| 369 |
+
|
| 370 |
+
const int ix = tiisg/2;
|
| 371 |
+
const int il = 8*(tiisg%2);
|
| 372 |
+
|
| 373 |
+
device const float * yb = y + ix * QK4_0 + il;
|
| 374 |
+
|
| 375 |
+
// each thread in a SIMD group deals with half a block.
|
| 376 |
+
for (int ib = ix; ib < nb; ib += nw/2) {
|
| 377 |
+
float sumy = 0;
|
| 378 |
+
for (int i = 0; i < 8; i += 2) {
|
| 379 |
+
sumy += yb[i] + yb[i+1];
|
| 380 |
+
yl[i+0] = yb[i+ 0];
|
| 381 |
+
yl[i+1] = yb[i+ 1]/256.f;
|
| 382 |
+
sumy += yb[i+16] + yb[i+17];
|
| 383 |
+
yl[i+8] = yb[i+16]/16.f;
|
| 384 |
+
yl[i+9] = yb[i+17]/4096.f;
|
| 385 |
+
}
|
| 386 |
|
| 387 |
+
for (int row = 0; row < nr; row++) {
|
| 388 |
+
sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
|
| 389 |
}
|
| 390 |
|
| 391 |
+
yb += QK4_0 * 16;
|
| 392 |
}
|
| 393 |
|
| 394 |
+
for (int row = 0; row < nr; ++row) {
|
| 395 |
+
const float tot = simd_sum(sumf[row]);
|
| 396 |
+
if (tiisg == 0 && first_row + row < ne01) {
|
| 397 |
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
|
| 398 |
+
}
|
    }
}

+kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
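For orientation, the quantity each of these kernels accumulates per block is the ordinary q4_0 dot product. A scalar CPU sketch of that reference computation (plain C++; it assumes ggml's q4_0 layout of one fp16 scale followed by 16 bytes of packed nibbles, with the low nibble holding element i and the high nibble element i + 16, both recentered by -8 -- the same -8 that shows up as `sumy * -8.f` above; the fp16 scale is taken as an already-converted float here):

    #include <cstdint>

    enum { QK4_0 = 32 };

    // dot(dequant(one q4_0 block), y[0..31]) -- scalar reference, not the Metal kernel
    float q4_0_dot_ref(float d, const uint8_t qs[QK4_0/2], const float * y) {
        float acc = 0.f;
        for (int i = 0; i < QK4_0/2; ++i) {
            const int q0 = (qs[i] & 0x0F) - 8; // low nibble  -> element i
            const int q1 = (qs[i] >>   4) - 8; // high nibble -> element i + 16
            acc += y[i] * q0 + y[i + 16] * q1;
        }
        return d * acc;
    }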

+kernel void kernel_mul_mat_q4_1_f32(
+       device const  void * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}

+kernel void kernel_mul_mat_q8_0_f32(
+       device const  void * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int nr  = N_DST;
+    const int nsg = N_SIMDGROUP;
+    const int nw  = N_SIMDWIDTH;
+
+    const int nb = ne00/QK8_0;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];
+    float sumf[nr]={0.f};
+
+    const int ix = tiisg/2;
+    const int il = tiisg%2;
+
+    device const float * yb = y + ix * QK8_0 + 16*il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
+        for (int i = 0; i < 16; ++i) {
+            yl[i] = yb[i];
+        }

+        for (int row = 0; row < nr; row++) {
+            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            float sumq = 0.f;
+            for (int iq = 0; iq < 16; ++iq) {
+                sumq += qs[iq] * yl[iq];
+            }
+            sumf[row] += sumq*x[ib+row*nb].d;
        }

+        yb += QK8_0 * 16;
    }

+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+        }
    }
}

        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
+       constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
+       constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,

    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

+   device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    sum[tpitg.x] = 0.0f;

        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
+       constant     float & freq_base,
+       constant     float & freq_scale,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
+   const float theta_scale = pow(freq_base, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

+   float theta = freq_scale * (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {

            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
+       for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
+           for (int64_t ic = 0; ic < n_dims; ic += 2) {
+               const float cos_theta = cos(theta);
+               const float sin_theta = sin(theta);
+
+               theta *= theta_scale;
+
+               const int64_t i0 = ib*n_dims + ic/2;
+
+               device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+               device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+               const float x0 = src[0];
+               const float x1 = src[n_dims/2];
+
+               dst_data[0]        = x0*cos_theta - x1*sin_theta;
+               dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+           }
+       }
    }
}

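The two new arguments feed straight into the rotary-embedding angles: for position p and pair index ic the angle is freq_scale * p * freq_base^(-ic/n_dims), and the kernel produces exactly that by starting from freq_scale * p and multiplying by theta_scale = freq_base^(-2/n_dims) after every pair. A small host-side check of that recurrence (plain C++, example values only):

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_dims     = 8;        // example head dimension
        const float freq_base  = 10000.f;  // common default, assumed here
        const float freq_scale = 1.0f;
        const int   p          = 5;        // token position

        const float theta_scale = powf(freq_base, -2.0f/n_dims);

        float theta = freq_scale * (float)p;
        for (int ic = 0; ic < n_dims; ic += 2) {
            const float closed = freq_scale * p * powf(freq_base, -(float)ic/n_dims);
            printf("pair %d: recurrence %.6f, closed form %.6f\n", ic/2, theta, closed);
            theta *= theta_scale; // same update the kernel performs per pair
        }
        return 0;
    }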
    return r;
}

+//====================================== dot products =========================

+kernel void kernel_mul_mat_q2_K_f32(
+       device const  void * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int r2 = tgpig.z;

+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    float yl[32];
+    float sumf[N_DST]={0.f}, all_sum;

+    const int step = sizeof(block_q2_K) * nb;

#if QK_K == 256
+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3
+    const int is = (8*ir)/16;// 0 or 1
+
+    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
+        }

+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }
+            float dall = dh[0];
+            float dmin = dh[1] * 1.f/16.f;
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
        }
+
+        y4 += 4 * QK_K;
+    }
#else
+    const int ix = tiisg/2;  // 0...15
+    const int it = tiisg%2;  // 0...1
+
+    device const float * y4 = y + ix * QK_K + 8 * it;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
+            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
+            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
+            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
+        }
+
+        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
+                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
+                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
+                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
+                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
+                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
+                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
+                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
+            }
+
+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
+                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
+                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
+                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
+                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));
+
+            qs += step/2;
+            sc += step;
+            dh += step/2;
        }
+
+        y4 += 16 * QK_K;
+    }
#endif

+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+        }
    }
}

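In the q2_K path above, each 16-element sub-group of a block stores a 4-bit scale in the low nibble and a 4-bit minimum in the high nibble of one `scales` byte, while the whole block shares the super-scales `d` and `dmin`; an element is therefore reconstructed as d*scale*q - dmin*min, which is exactly what the dall/dmin expression accumulates. A scalar sketch of that reconstruction for a single value (plain C++, illustrative names):

    #include <cstdint>

    // dequantize one q2_K element
    // q2      : its 2-bit quant, already shifted down to bits 0..1
    // sc_byte : the sub-group's byte from block.scales
    // d, dmin : the block's super-scales, already converted from fp16
    float q2_K_dequant_one(float d, float dmin, uint8_t sc_byte, int q2) {
        const int scale = sc_byte & 0xF; // low nibble: scale
        const int min   = sc_byte >> 4;  // high nibble: minimum
        return d * scale * q2 - dmin * min;
    }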
#if QK_K == 256
+kernel void kernel_mul_mat_q3_K_f32(
+       device const  void * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t r2 = tgpig.z;
+
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+
+    float yl[16];

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
+    const int ip  = tid/8;          // 0 or 1
+    const int il  = tid/2 - 4*ip;   // 0...3
+    const int ir  = tid%2;
+    const int n   = 8;
+    const int l0  = n*ir;

+    const uint16_t m1 = 1 << (4*ip + il);
+    const uint16_t m2 = m1 << 8;

+    const int shift = 2*il;
+    const uint16_t qm1 = 0x0003 << shift;
+    const uint16_t qm2 = 0x0300 << shift;
+    const int32_t v1 = 4 << shift;
+    const int32_t v2 = 1024 << shift;

+    const uint16_t s_shift1 = 4*ip;
+    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
+    const int ik = 4 + (il%2);

+    const int q_offset = 32*ip + l0;
+    const int y_offset = 128*ip + 32*il + l0;

+    const int step = sizeof(block_q3_K) * nb / 2;

+    device const float * y1 = yy + ix*QK_K + y_offset;

+    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 2) {

        for (int l = 0; l < 8; ++l) {
+           yl[l+0] = y1[l+ 0];
+           yl[l+8] = y1[l+16];
        }

+       device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
+       device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
+       device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+       device const half     * dh = &x[i].d;

+       for (int row = 0; row < 2; ++row) {

+           const float d_all = (float)dh[0];
+           const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

+           float s1 = 0, s2 = 0;
+           for (int l = 0; l < n; l += 2) {
+               const uint16_t qs = q[l/2];
+               s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+               s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+           }
+           float d = d_all * (s1 + 1.f/256.f * s2);
+           sumf1[row] += d * scales[0];
+           sumf2[row] += d;
+
+           s1 = s2 = 0;
+           for (int l = 0; l < n; l += 2) {
+               const uint16_t qs = q[l/2+8];
+               s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+               s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+           }
+           d = d_all * (s1 + 1.f/256.f * s2);
+           sumf1[row] += d * scales[1];
+           sumf2[row] += d;
+
+           q  += step;
+           h  += step;
+           a  += step;
+           dh += step;

        }

+       y1 += 2 * QK_K;

+   }

+   for (int row = 0; row < 2; ++row) {
+       const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+       const float tot = simd_sum(sumf);
+       if (tiisg == 0) {
+           dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
        }
    }
+}
#else
+kernel void kernel_mul_mat_q3_K_f32(
+       device const  void * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    const int nb = ne00/QK_K;

+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+    const int64_t r2 = tgpig.z;
+
+    const int row = 2 * r0 + sgitg;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    const int ix = tiisg/4;
+    const int il = 4 * (tiisg%4);// 0, 4, 8, 12
+    const int im = il/8;         // 0, 0, 1, 1
+    const int in = il%8;         // 0, 4, 0, 4

+    float2 sum = {0.f, 0.f};

+    for (int i = ix; i < nb; i += 8) {

+       const float d_all = (float)(x[i].d);

+       device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
+       device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
+       device const uint16_t * s = (device const uint16_t *)(x[i].scales);
+       device const float    * y = yy + i * QK_K + il;
+
+       const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
+       const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
+       const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
+       const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
+
+       for (int l = 0; l < 4; l += 2) {
+           const uint16_t hm = h[l/2] >> im;
+           sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :    4))
+                   + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 :   16))
+                   + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 :   64))
+                   + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 :  256));
+           sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
+                   + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
+                   + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
+                   + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
        }

    }
+   const float sumf = sum[0] + sum[1] * 1.f/256.f;

+   const float tot = simd_sum(sumf);
+   if (tiisg == 0) {
+       dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
    }

}
+#endif
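Both q3_K branches rebuild each 3-bit quant from two pieces: the 2 low bits live in `qs` and the third bit in `hmask`, and instead of assembling the full value the kernels subtract a (shifted) 4 whenever the high bit is not set, which is algebraically the same thing. A scalar sketch of that equivalence (plain C++, illustrative names):

    // reconstruct one signed q3_K quant
    // q_lo2    : the 2 low bits, already shifted down to bits 0..1
    // hbit_set : whether the matching hmask bit is set
    int q3_K_quant(int q_lo2, bool hbit_set) {
        // stored value is (q_lo2 | hbit<<2) and the format recenters it by -4,
        // so this branch equals (q_lo2 | hbit<<2) - 4
        return hbit_set ? q_lo2 : q_lo2 - 4;
    }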

+#if QK_K == 256
+kernel void kernel_mul_mat_q4_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;

+    const int ix = tiisg/8;  // 0...3
+    const int it = tiisg%8;  // 0...7
+    const int im = it/4;     // 0 or 1
+    const int ir = it%4;     // 0...3

+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int r2 = tgpig.z;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    float yl[16];
+    float yh[16];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int step = sizeof(block_q4_K) * nb / 2;
+
+    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;
+
+    uint16_t sc16[4];
+    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
+
+    for (int ib = ix; ib < nb; ib += 4) {
+
+        float4 sumy = {0.f, 0.f, 0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
+            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
+            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
+            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
        }

+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
+        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
+        device const half     * dh = &x[ib].d;
+
+        for (int row = 0; row < N_DST; row++) {
+
+            sc16[0] = sc[0] & kmask1;
+            sc16[1] = sc[2] & kmask1;
+            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
+            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
+
+            device const uint16_t * q2 = q1 + 32;
+
+            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
+            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
+                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
+                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
+                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
+                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
+                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
+                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+            }

+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
+                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
+                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
+                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
+
+            q1 += step;
+            sc += step;
+            dh += step;
        }

+        y4 += 4 * QK_K;
    }

+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
+        }
    }
}
+#else
kernel void kernel_mul_mat_q4_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

+    const int ix = tiisg/4;  // 0...7
+    const int it = tiisg%4;  // 0...3

+    const int nb = ne00/QK_K;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int r2 = tgpig.z;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int ib_row = first_row * nb;
+    const uint offset0 = r2/gqa*(nb*ne0);
+    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
+    float yl[8];
+    float yh[8];
+    float sumf[N_DST]={0.f}, all_sum;
+
+    const int step = sizeof(block_q4_K) * nb / 2;
+
+    device const float * y4 = y + ix * QK_K + 8 * it;
+
+    uint16_t sc16[4];
+
+    for (int ib = ix; ib < nb; ib += 8) {
+
+        float2 sumy = {0.f, 0.f};
+        for (int i = 0; i < 8; ++i) {
+            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
+            yh[i] = y4[i+32]; sumy[1] += yh[i];
+        }

+        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
+        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
+        device const half     * dh = x[ib].d;

+        for (int row = 0; row < N_DST; row++) {

+            sc16[0] = sc[0] & 0x000f;
+            sc16[1] = sc[0] & 0x0f00;
+            sc16[2] = sc[0] & 0x00f0;
+            sc16[3] = sc[0] & 0xf000;

+            float2 acc1 = {0.f, 0.f};
+            float2 acc2 = {0.f, 0.f};
+            for (int i = 0; i < 8; i += 2) {
+                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
+                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
+                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
+                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
+            }

+            float dall = dh[0];
+            float dmin = dh[1];
+            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
+                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
+                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);

+            qs += step;
+            sc += step;
+            dh += step;
        }

+        y4 += 8 * QK_K;
    }

+    for (int row = 0; row < N_DST; ++row) {
+        all_sum = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
+#endif
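The kmask1/kmask2/kmask3 manipulation in the 256-element q4_K branch unpacks the block's 12-byte scales area, where each of the 8 sub-blocks carries a 6-bit scale and a 6-bit minimum: the first 8 bytes hold the low 6 bits and the last 4 bytes hold the 2-bit high parts. A byte-level sketch of that unpacking (plain C++, mirroring ggml's get_scale_min_k4 as I understand it, so treat the exact byte layout as an assumption):

    #include <cstdint>

    // 6-bit scale and min of sub-block j (0..7) from the 12-byte q4_K scales array
    void q4_K_scale_min(int j, const uint8_t * scales, uint8_t * sc, uint8_t * mn) {
        if (j < 4) {
            *sc = scales[j]     & 63;
            *mn = scales[j + 4] & 63;
        } else {
            *sc = (scales[j + 4] & 0x0F) | ((scales[j - 4] >> 6) << 4);
            *mn = (scales[j + 4] >>   4) | ((scales[j    ] >> 6) << 4);
        }
    }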

kernel void kernel_mul_mat_q5_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
+   const int r2 = tgpig.z;

+   const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+   const uint offset0 = r2/gqa*(nb*ne0);
+   device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+   device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

+   float sumf[2]={0.f};

+   const int step = sizeof(block_q5_K) * nb;

#if QK_K == 256
+#
+   float yl[16], yh[16];

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

+   const int tid = tiisg/4;
+   const int ix  = tiisg%4;
+   const int im  = tid/4;
+   const int ir  = tid%4;
+   const int n   = 8;

+   const int l0 = n*ir;
    const int q_offset = 32*im + l0;
    const int y_offset = 64*im + l0;

    const uint8_t hm3 = hm1 << 4;
    const uint8_t hm4 = hm2 << 4;

+   uint16_t sc16[4];
+   thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

+   device const float * y1 = yy + ix*QK_K + y_offset;

+   for (int i = ix; i < nb; i += 4) {

+       device const uint8_t * q1 = x[i].qs + q_offset;
+       device const uint8_t * qh = x[i].qh + l0;
+       device const half    * dh = &x[i].d;
+       device const uint16_t * a = (device const uint16_t *)x[i].scales + im;

+       device const float * y2 = y1 + 128;
+       float4 sumy = {0.f, 0.f, 0.f, 0.f};
+       for (int l = 0; l < 8; ++l) {
+           yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
+           yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
+           yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
+           yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
+       }

+       for (int row = 0; row < 2; ++row) {
+
+           device const uint8_t * q2 = q1 + 64;
+
+           sc16[0] = a[0] & kmask1;
+           sc16[1] = a[2] & kmask1;
+           sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
+           sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
+
+           float4 acc = {0.f, 0.f, 0.f, 0.f};
+           for (int l = 0; l < n; ++l) {
+               uint8_t h = qh[l];
+               acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
+               acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
+               acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
+               acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+           }
+           const float dall = dh[0];
+           const float dmin = dh[1];
+           sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+                        dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

+           q1 += step;
+           qh += step;
+           dh += step/2;
+           a  += step/2;

        }
+
+       y1 += 4 * QK_K;

    }
#else
+   float yl[8], yh[8];
+
+   const int il = 4 * (tiisg/8); // 0, 4, 8, 12
+   const int ix = tiisg%8;
+   const int im = il/8;          // 0, 0, 1, 1
+   const int in = il%8;          // 0, 4, 0, 4
+
+   device const float * y = yy + ix*QK_K + il;

+   for (int i = ix; i < nb; i += 8) {

+       for (int l = 0; l < 4; ++l) {
+           yl[l+0] = y[l+ 0];
+           yl[l+4] = y[l+16];
+           yh[l+0] = y[l+32];
+           yh[l+4] = y[l+48];
+       }
+
+       device const half    * dh = &x[i].d;
        device const uint8_t * q = x[i].qs + il;
        device const uint8_t * h = x[i].qh + in;
        device const int8_t  * s = x[i].scales;

+       for (int row = 0; row < 2; ++row) {
+
+           const float d = dh[0];
+
+           float2 acc = {0.f, 0.f};
+           for (int l = 0; l < 4; ++l) {
+               const uint8_t hl = h[l] >> im;
+               acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 :  16))
+                       + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 :  16));
+               acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
+                       + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
+           }
+           sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
+
+           q  += step;
+           h  += step;
+           s  += step;
+           dh += step/2;
+
        }
+
+       y += 8 * QK_K;
    }
#endif

+   for (int row = 0; row < 2; ++row) {
+       const float tot = simd_sum(sumf[row]);
+       if (tiisg == 0) {
+           dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+       }
    }

}
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
+       constant   int64_t & ne01[[buffer(4)]],
+       constant   int64_t & ne02[[buffer(5)]],
+       constant   int64_t & ne10[[buffer(9)]],
+       constant   int64_t & ne12[[buffer(11)]],
+       constant   int64_t & ne0[[buffer(15)]],
+       constant   int64_t & ne1[[buffer(16)]],
+       constant   uint    & gqa[[buffer(17)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiisg[[thread_index_in_simdgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
+   const int r2 = tgpig.z;

+   const int row = 2 * r0 + sgitg;
+   const uint offset0 = r2/gqa*(nb*ne0);
+   device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+   device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

    float sumf = 0;

#if QK_K == 256
+   const int tid = tiisg/2;
+   const int ix  = tiisg%2;
+   const int ip  = tid/8;   // 0 or 1
+   const int il  = tid%8;
    const int n = 4;
    const int l0 = n*il;
    const int is = 8*ip + l0/16;

    const int q_offset_l = 64*ip + l0;
    const int q_offset_h = 32*ip + l0;

+   for (int i = ix; i < nb; i += 2) {

+       device const uint8_t * q1 = x[i].ql + q_offset_l;
+       device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;

        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
+           sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+           sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+           sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+           sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }

        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

    }
+
#else
+   const int ix = tiisg/4;
+   const int il = 4*(tiisg%4);

+   for (int i = ix; i < nb; i += 8) {
        device const float   * y  = yy + i * QK_K + il;
        device const uint8_t * ql = x[i].ql + il;
        device const uint8_t * qh = x[i].qh + il;

#endif

+   const float tot = simd_sum(sumf);
+   if (tiisg == 0) {
+       dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+   }
+}
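q6_K splits every 6-bit quant into 4 low bits in `ql` and 2 high bits in `qh`, then recenters by 32; kmask1..kmask4 simply pick which 2-bit slice of the shared `qh` byte belongs to which of the four interleaved outputs. A scalar sketch of the basic reconstruction (plain C++, illustrative):

    // rebuild one signed q6_K value
    // lo4 : 4 low bits from ql (already isolated)
    // hi2 : 2 high bits from qh (already isolated, in bits 0..1)
    int q6_K_quant(int lo4, int hi2) {
        return ((hi2 << 4) | lo4) - 32; // 6-bit value recentered around zero
    }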

+//============================= templates and their specializations =============================
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+    half4x4 temp = *(((device half4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
    }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const half d = il ? (xb->d / 16.h) : xb->d;
+    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+    for (int i = 0; i < 8; i++) {
+        reg[i/2][2*(i%2)  ] = (((qs[i] & mask0)     ) + m) * d;
+        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
    }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const half d = il ? (xb->d / 16.h) : xb->d;
+    const half m = xb->m;
+    const ushort mask0 = il ? 0x00F0 : 0x000F;
+    const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+    for (int i = 0; i < 8; i++) {
+        reg[i/2][2*(i%2)  ] = (((qs[i] & mask0)     ) * d) + m;
+        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const half d = xb->d;
+
+    for (int i = 0; i < 16; i++) {
+        reg[i/4][i%4] = (qs[i + 16*il] * d);
+    }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+    const half d = xb->d;
+    const half min = xb->dmin;
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    half dl, ml;
+    uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+    q = q + 32*(il/8) + 16*(il&1);
+    il = (il/2)%4;
+#endif
+    half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uchar mask = il>1 ? (il>2 ? 192    : 48)    : (il>0 ? 12    : 3);
+    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+    const float d_all = (float)(xb->d);
+    device const uint8_t * q = (device const uint8_t *)xb->qs;
+    device const uint8_t * h = (device const uint8_t *)xb->hmask;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    q = q + 32 * (il/8) + 16 * (il&1);
+    h = h + 16 * (il&1);
+    uint8_t m = 1 << (il/2);
+    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+                                 ((il/4)>0 ? 12  : 3);
+    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
+                                (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+
+    il = (il/2)%4;
+    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+    }
+#else
+    float kcoef = il&1 ? 1.f/16.f : 1.f;
+    uint16_t kmask = il&1 ? 0xF0 : 0x0F;
+    float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
+    uint8_t m = 1<<(il*2);
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+    }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q = xb->qs;
+
+#if QK_K == 256
+    const float d = (float)(xb->d);
+    const float min = (float)(xb->dmin);
+    short is = (il/4) * 2;
+    q = q + (il/4) * 32 + 16 * (il&1);
+    il = il%4;
+    const uchar4 sc = get_scale_min_k4(is, xb->scales);
+    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
+    const float ml = il<2 ? min * sc[1] : min * sc[3];
+#else
+    q = q + 16 * (il&1);
+    device const uint8_t * s = xb->scales;
+    device const half2 * dh = (device const half2 *)xb->d;
+    const float2 d = (float2)dh[0];
+    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
+#endif
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q  = xb->qs;
+    device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+    const float d = (float)(xb->d);
+    const float min = (float)(xb->dmin);
+    short is = (il/4) * 2;
+    q  = q + 32 * (il/4) + 16 * (il&1);
+    qh = qh + 16 * (il&1);
+    uint8_t ul = 1 << (il/2);
+    il = il%4;
+    const uchar4 sc = get_scale_min_k4(is, xb->scales);
+    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
+    const float ml = il<2 ? min * sc[1] : min * sc[3];
+
+    const ushort mask  = il<2 ? 0x0F : 0xF0;
+    const float qh_val = il<2 ? 16.f : 256.f;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+    }
+#else
+    q = q + 16 * (il&1);
+    device const int8_t * s = xb->scales;
+    const float dl = xb->d * s[il];
+    uint8_t m = 1<<(il*2);
+    const float  coef = il<2 ? 1.f  : 1.f/16.f;
+    const ushort mask = il<2 ? 0x0F : 0xF0;
+    for (int i = 0; i < 16; ++i) {
+        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+    }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+    const float d_all = (float)(xb->d);
+    device const uint8_t * ql = (device const uint8_t *)xb->ql;
+    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+    qh = qh + 32*(il/8) + 16*(il&1);
+    float sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2)%4;
+#else
+    ql = ql + 16 * (il&1);
+    float sc = scales[il];
+#endif
+    for (int i = 0; i < 16; ++i) {
+        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+        uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+        const float coef = il>1 ? 1.f/16.f : 1.f;
+        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
+                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
+        reg[i/4][i%4] = d_all * sc * q * coef;
    }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+       device const void * src0,
+       device const  int * src1,
+       device      float * dst,
+       constant  int64_t & ne00,
+       constant uint64_t & nb01,
+       constant uint64_t & nb1,
+       uint tgpig[[threadgroup_position_in_grid]],
+       uint tiitg[[thread_index_in_threadgroup]],
+       uint tptg[[threads_per_threadgroup]]) {
+    const int i = tgpig;
+    const int r = ((device int32_t *) src1)[i];

+    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+        float4x4 temp;
+        dequantize_func(
+            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+    }
}
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread takes 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread takes 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+       device const float * src1,
+       device       float * dst,
+       constant   int64_t & ne00,
+       constant   int64_t & ne02,
+       constant   int64_t & nb01,
+       constant   int64_t & nb02,
+       constant   int64_t & ne12,
+       constant   int64_t & ne0,
+       constant   int64_t & ne1,
+       constant   uint    & gqa,
+       threadgroup uchar * shared_memory [[threadgroup(0)]],
+       uint3 tgpig[[threadgroup_position_in_grid]],
+       uint  tiitg[[thread_index_in_threadgroup]],
+       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = ((threadgroup half  *)shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        //load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        #pragma unroll(16)
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
+            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
+            = *((device float2x4 *)y);
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        //load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        #pragma unroll(4)
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+            #pragma unroll(8)
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+                                    + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg==0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
|
| 2018 |
+
#if QK_K == 256
|
| 2019 |
+
#define QK_NL 16
|
| 2020 |
+
#else
|
| 2021 |
+
#define QK_NL 4
|
| 2022 |
+
#endif
|
| 2023 |
+
|
| 2024 |
+
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
| 2025 |
+
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
| 2026 |
+
|
| 2027 |
+
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
| 2028 |
+
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
| 2029 |
+
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
| 2030 |
+
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
| 2031 |
+
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
| 2032 |
+
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
| 2033 |
+
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 2034 |
+
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 2035 |
+
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
| 2036 |
+
|
| 2037 |
+
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
| 2038 |
+
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
| 2039 |
+
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
| 2040 |
+
|
| 2041 |
+
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
| 2042 |
+
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
| 2043 |
+
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
| 2044 |
+
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
| 2045 |
+
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
| 2046 |
+
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
| 2047 |
+
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
| 2048 |
+
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
| 2049 |
+
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
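The "+ 4096" offset that splits the kernel's threadgroup buffer into the `sa` (half) and `sb` (float) tiles follows directly from the tile sizes defined above. The host-side dispatch is in ggml-metal.m and is not part of this diff, so the snippet below is only a consistency check of those constants, not the actual dispatch code; `sizeof(short)` stands in for Metal's 2-byte half.

#include <stdio.h>

int main(void) {
    const int block_m = 64, block_n = 32, block_k = 32;          // BLOCK_SIZE_M/N/K above
    const size_t sa_bytes = (size_t) block_m*block_k*sizeof(short); // A tile in half: 64*32*2 = 4096
    const size_t sb_bytes = (size_t) block_n*block_k*sizeof(float); // B tile in float: 32*32*4 = 4096
    printf("sa = %zu, sb = %zu, total threadgroup bytes = %zu\n",
           sa_bytes, sb_bytes, sa_bytes + sb_bytes);              // 4096, 4096, 8192
    return 0;
}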
ggml-opencl.cpp
CHANGED

@@ -656,10 +656,14 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
 \n#if K_QUANTS_PER_ITERATION == 1\n
     const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
     const int is = 0;
+
 \n#else\n
+
     const int l0 = 4 * in; // 0, 4, 8, ..., 28
     const int is = in / 4;
+
 \n#endif\n
+
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;

@@ -1376,7 +1380,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[
+    const int64_t ne03 = src0->ne[3];
     const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
ggml.c
CHANGED

The diff for this file is too large to render. See raw diff.
ggml.h
CHANGED

@@ -65,7 +65,7 @@
 // ggml_set_f32(a, 3.0f);
 // ggml_set_f32(b, 4.0f);
 //
-//
+// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
 //
 // printf("f = %f\n", ggml_get_f32_1d(f, 0));
 //
@@ -130,13 +130,16 @@
 // The data of the tensor is accessed via the "data" pointer. For example:
 //
 // {
-//
+//     const int nx = 2;
+//     const int ny = 3;
 //
-//
-//     *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
+//     struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
 //
-//
-//
+//     for (int y = 0; y < ny; y++) {
+//         for (int x = 0; x < nx; x++) {
+//             *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
+//         }
+//     }
 //
 // ...
 // }

@@ -183,6 +186,15 @@
 # define GGML_API
 #endif

+// TODO: support for clang
+#ifdef __GNUC__
+#    define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define GGML_DEPRECATED(func, hint) func
+#endif
+
 #include <stdint.h>
 #include <stddef.h>
 #include <stdbool.h>
@@ -197,12 +209,29 @@
 #define GGML_MAX_NODES 4096
 #define GGML_MAX_PARAMS 256
 #define GGML_MAX_CONTEXTS 64
-#define
-#define GGML_MAX_NAME
+#define GGML_MAX_SRC 6
+#define GGML_MAX_NAME 64
+#define GGML_MAX_OP_PARAMS 32
 #define GGML_DEFAULT_N_THREADS 4

+#if UINTPTR_MAX == 0xFFFFFFFF
+    #define GGML_MEM_ALIGN 4
+#else
+    #define GGML_MEM_ALIGN 16
+#endif
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
+#define GGUF_MAGIC   0x46554747 // "GGUF"
+#define GGUF_VERSION 2
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
 #define GGML_UNUSED(x) (void)(x)

+#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
+
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
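The new GGML_PAD macro rounds a size up to a multiple of n; the bit trick it uses is only valid when n is a power of two (which GGML_MEM_ALIGN always is). A minimal check, assuming ggml.h is on the include path:

#include "ggml.h"
#include <stdio.h>

int main(void) {
    printf("%d\n", GGML_PAD(100, 16));                 // 112: rounded up to the next multiple of 16
    printf("%d\n", GGML_PAD(96,  16));                 // 96:  already aligned, unchanged
    printf("%d\n", GGML_PAD(1000, GGML_MEM_ALIGN));    // padded to the allocator's alignment
    return 0;
}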
@@ -239,8 +268,9 @@
 extern "C" {
 #endif

-#
-
+#if defined(__ARM_NEON) && defined(__CUDACC__)
+    typedef half ggml_fp16_t;
+#elif defined(__ARM_NEON)
     typedef __fp16 ggml_fp16_t;
 #else
     typedef uint16_t ggml_fp16_t;

@@ -250,8 +280,8 @@ extern "C" {
     GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
     GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);

-    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y,
-    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y,
+    GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
+    GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);

     struct ggml_object;
     struct ggml_context;
@@ -324,20 +354,12 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
-
-        GGML_OP_SGN,
-        GGML_OP_NEG,
-        GGML_OP_STEP,
-        GGML_OP_TANH,
-        GGML_OP_ELU,
-        GGML_OP_RELU,
-        GGML_OP_GELU,
-        GGML_OP_GELU_QUICK,
-        GGML_OP_SILU,
+        GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
+        GGML_OP_GROUP_NORM,

         GGML_OP_MUL_MAT,
         GGML_OP_OUT_PROD,

@@ -363,16 +385,29 @@ extern "C" {
         GGML_OP_CLAMP,
         GGML_OP_CONV_1D,
         GGML_OP_CONV_2D,
+        GGML_OP_CONV_TRANSPOSE_2D,
+        GGML_OP_POOL_1D,
+        GGML_OP_POOL_2D,
+
+        GGML_OP_UPSCALE, // nearest interpolate

         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,
+        GGML_OP_GET_REL_POS,
+        GGML_OP_ADD_REL_POS,
+
+        GGML_OP_UNARY,

         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,

+        GGML_OP_MAP_CUSTOM1_F32,
+        GGML_OP_MAP_CUSTOM2_F32,
+        GGML_OP_MAP_CUSTOM3_F32,
+
        GGML_OP_MAP_CUSTOM1,
        GGML_OP_MAP_CUSTOM2,
        GGML_OP_MAP_CUSTOM3,

@@ -383,6 +418,24 @@ extern "C" {
         GGML_OP_COUNT,
     };

+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+    };
+
+    enum ggml_object_type {
+        GGML_OBJECT_TENSOR,
+        GGML_OBJECT_GRAPH,
+        GGML_OBJECT_WORK_BUFFER
+    };

     // ggml object
     struct ggml_object {

@@ -391,7 +444,9 @@ extern "C" {

         struct ggml_object * next;

-
+        enum ggml_object_type type;
+
+        char padding[4];
     };

     static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);

@@ -411,15 +466,13 @@ extern "C" {
         // compute data
         enum ggml_op op;

+        // op params - allocated as int32_t for alignment
+        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+
         bool is_param;

         struct ggml_tensor * grad;
-        struct ggml_tensor *
-        struct ggml_tensor * src1;
-        struct ggml_tensor * opt[GGML_MAX_OPT];
-
-        // thread scheduling
-        int n_tasks;
+        struct ggml_tensor * src[GGML_MAX_SRC];

         // performance
         int perf_runs;
@@ -437,25 +490,46 @@ extern "C" {

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

+    // the compute plan that needs to be prepared for ggml_graph_compute()
+    // since https://github.com/ggerganov/ggml/issues/287
+    struct ggml_cplan {
+        size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+        uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+        int n_threads;
+
+        // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
+        int n_tasks[GGML_MAX_NODES];
+
+        // abort ggml_graph_compute when true
+        bool (*abort_callback)(void * data);
+        void * abort_callback_data;
+    };
+
+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
         int n_leafs;
-        int n_threads;
-
-        size_t work_size;
-        struct ggml_tensor * work;

         struct ggml_tensor * nodes[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];

+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         int perf_runs;
         int64_t perf_cycles;
         int64_t perf_time_us;
     };

+    static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
+
     // scratch buffer
     struct ggml_scratch {
         size_t offs;
@@ -509,6 +583,7 @@ extern "C" {
     GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);
     GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);
     GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes_pad  (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
     GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);

     GGML_API int     ggml_blck_size   (enum ggml_type type);

@@ -517,6 +592,7 @@ extern "C" {

     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);

     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);

@@ -529,6 +605,8 @@ extern "C" {
     GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_permuted  (const struct ggml_tensor * tensor);

+    GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);

@@ -540,6 +618,7 @@ extern "C" {
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);

     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);

     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
@@ -599,9 +678,11 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

-    GGML_API
-
-    GGML_API
+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
+    GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
+    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);

     //
     // operations on tensors with backpropagation

@@ -611,6 +692,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_add(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

@@ -735,6 +821,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    // concat a and b on dim 2
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_concat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     GGML_API struct ggml_tensor * ggml_abs(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
@@ -824,25 +917,42 @@ extern "C" {
             struct ggml_tensor  * b);

     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor * a
+            struct ggml_tensor  * a,
+            float                 eps);

     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor * a
+            struct ggml_tensor  * a,
+            float                 eps);

     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor * a
+            struct ggml_tensor  * a,
+            float                 eps);

     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor * a
+            struct ggml_tensor  * a,
+            float                 eps);
+
+    // group normalize along ne0*ne1*n_groups
+    // used in stable-diffusion
+    // TODO: eps is hardcoded to 1e-6 for now
+    GGML_API struct ggml_tensor * ggml_group_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);
+
+    GGML_API struct ggml_tensor * ggml_group_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_groups);

     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
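With this change the normalization epsilon is no longer hardcoded in ggml; callers pass it explicitly. A minimal sketch of a caller under the new signatures follows; the 1e-5f value and the helper name are illustrative only (whisper.cpp keeps its own constant, see the whisper.cpp hunks further down where the old one-argument calls are removed).

#include "ggml.h"

// layer_norm_sketch: normalize x along rows with an explicitly chosen eps.
// Scale/shift with the model's weight and bias would follow as before.
static struct ggml_tensor * layer_norm_sketch(struct ggml_context * ctx,
                                              struct ggml_tensor  * x) {
    const float eps = 1e-5f;               // illustrative choice, not a library default
    return ggml_norm(ctx, x, eps);         // previously: ggml_norm(ctx, x);
}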
@@ -934,11 +1044,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    // a -> b, in-place, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // make contiguous, in-place
+    GGML_API struct ggml_tensor * ggml_cont_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(

@@ -1107,6 +1228,37 @@ extern "C" {
             int                   mode,
             int                   n_ctx);

+    // custom RoPE
+    GGML_API struct ggml_tensor * ggml_rope_custom(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale);
+
+    // xPos RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            float                 base,
+            bool                  down);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(

@@ -1114,7 +1266,12 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_dims,
-            int                   mode
+            int                   mode,
+            int                   n_ctx,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 xpos_base,
+            bool                  xpos_down);

     // alibi position embedding
     // in-place, returns view(a)

@@ -1141,6 +1298,15 @@ extern "C" {
             int                   p0,  // padding
             int                   d0); // dilation

+    // conv_1d with padding = half
+    // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+    GGML_API struct ggml_tensor* ggml_conv_1d_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            int                   s,
+            int                   d);
+
     GGML_API struct ggml_tensor * ggml_conv_2d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1152,14 +1318,70 @@ extern "C" {
             int                   d0,
             int                   d1);

-
-    //
-
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is equal to kernel size
+    // padding is zero
+    // example:
+    // a:     16   16    3  768
+    // b:   1024 1024    3    1
+    // res:   64   64  768    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // kernel size is a->ne[0] x a->ne[1]
+    // stride is 1
+    // padding is half
+    // example:
+    // a:      3    3    256  256
+    // b:     64   64    256    1
+    // res:   64   64    256    1
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            int
-
+            int                   stride);
+
+    enum ggml_op_pool {
+        GGML_OP_POOL_MAX,
+        GGML_OP_POOL_AVG,
+        GGML_OP_POOL_COUNT,
+    };
+
+    GGML_API struct ggml_tensor * ggml_pool_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0, // kernel size
+            int                   s0, // stride
+            int                   p0); // padding
+
+    GGML_API struct ggml_tensor * ggml_pool_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_op_pool     op,
+            int                   k0,
+            int                   k1,
+            int                   s0,
+            int                   s1,
+            int                   p0,
+            int                   p1);
+
+    // nearest interpolate
+    // used in stable-diffusion
+    GGML_API struct ggml_tensor * ggml_upscale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   scale_factor);

     GGML_API struct ggml_tensor * ggml_flash_attn(
             struct ggml_context * ctx,

@@ -1204,6 +1426,37 @@ extern "C" {
             int                   h0,
             int                   w);

+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
+    // used in sam
+    GGML_API struct ggml_tensor * ggml_get_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   qh,
+            int                   kh);
+
+    // used in sam
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
+    GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * pw,
+            struct ggml_tensor  * ph);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -1213,63 +1466,129 @@ extern "C" {
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);

-    GGML_API struct ggml_tensor * ggml_map_unary_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            ggml_unary_op_f32_t fun)
+            ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1 instead");

-    GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            ggml_unary_op_f32_t fun)
+            ggml_unary_op_f32_t   fun),
+        "use ggml_map_custom1_inplace instead");

-    GGML_API struct ggml_tensor * ggml_map_binary_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            ggml_binary_op_f32_t fun)
+            ggml_binary_op_f32_t  fun),
+        "use ggml_map_custom2 instead");

-    GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            ggml_binary_op_f32_t fun)
+            ggml_binary_op_f32_t  fun),
+        "use ggml_map_custom2_inplace instead");

-    GGML_API struct ggml_tensor * ggml_map_custom1_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            ggml_custom1_op_f32_t fun)
+            ggml_custom1_op_f32_t fun),
+        "use ggml_map_custom1 instead");

-    GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            ggml_custom1_op_f32_t fun)
+            ggml_custom1_op_f32_t fun),
+        "use ggml_map_custom1_inplace instead");

-    GGML_API struct ggml_tensor * ggml_map_custom2_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            ggml_custom2_op_f32_t fun)
+            ggml_custom2_op_f32_t fun),
+        "use ggml_map_custom2 instead");

-    GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
-            ggml_custom2_op_f32_t fun)
+            ggml_custom2_op_f32_t fun),
+        "use ggml_map_custom2_inplace instead");

-    GGML_API struct ggml_tensor * ggml_map_custom3_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             struct ggml_tensor  * c,
-            ggml_custom3_op_f32_t fun)
+            ggml_custom3_op_f32_t fun),
+        "use ggml_map_custom3 instead");

-    GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * b,
             struct ggml_tensor  * c,
-            ggml_custom3_op_f32_t fun)
+            ggml_custom3_op_f32_t fun),
+        "use ggml_map_custom3_inplace instead");
+
+    // custom operators v2
+
+    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
+    typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
+
+    #define GGML_N_TASKS_MAX -1
+
+    GGML_API struct ggml_tensor * ggml_map_custom1(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            ggml_custom1_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            ggml_custom1_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            ggml_custom2_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            ggml_custom2_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            ggml_custom3_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);
+
+    GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            ggml_custom3_op_t     fun,
+            int                   n_tasks,
+            void                * userdata);

     // loss function
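The old f32 map callbacks are now deprecated in favour of the "custom operators v2" API above, whose callbacks receive a thread index (ith), a thread count (nth), and a userdata pointer. A minimal sketch of a v2 callback follows; the function names and the row-wise work split are this example's own conventions, assuming contiguous F32 tensors.

#include "ggml.h"

// scale_rows: multiply every element of `a` by a factor passed via userdata,
// writing into `dst` (same shape as `a`). Rows are split across threads by ith/nth.
static void scale_rows(struct ggml_tensor * dst, const struct ggml_tensor * a,
                       int ith, int nth, void * userdata) {
    const float factor = *(const float *) userdata;
    const int64_t nr = ggml_nrows(a);
    for (int64_t ir = ith; ir < nr; ir += nth) {
        const float * src = (const float *) ((const char *) a->data   + ir*a->nb[1]);
        float       * out = (float       *) ((      char *) dst->data + ir*dst->nb[1]);
        for (int64_t i = 0; i < a->ne[0]; ++i) {
            out[i] = factor*src[i];
        }
    }
}

// usage sketch (x is an F32 tensor already created in ctx):
//   static float factor = 0.5f;
//   struct ggml_tensor * y = ggml_map_custom1(ctx, x, scale_rows, GGML_N_TASKS_MAX, &factor);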
@@ -1290,15 +1609,28 @@ extern "C" {

     GGML_API void ggml_set_param(
             struct ggml_context * ctx,
-            struct ggml_tensor
+            struct ggml_tensor  * tensor);
+

     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

-
-    GGML_API
+    // graph allocation in a context
+    GGML_API struct ggml_cgraph * ggml_new_graph        (struct ggml_context * ctx);
+    GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API size_t ggml_graph_overhead(void);
+
+    // ggml_graph_plan() has to be called before ggml_graph_compute()
+    // when plan.work_size > 0, caller must allocate memory for plan.work_data
+    GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+    GGML_API int               ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API void              ggml_graph_reset  (struct ggml_cgraph * cgraph);
+
+    // same as ggml_graph_compute() but the work data is allocated as a part of the context
+    // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+    GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);

     GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
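This is the change that removes n_threads and the work buffer from ggml_cgraph: the caller now asks for a plan first and supplies the work buffer, or uses ggml_graph_compute_with_ctx() to let the context own it. A minimal sketch of the two-step path, with error handling omitted and the helper name invented for illustration:

#include "ggml.h"
#include <stdlib.h>

// compute_graph_sketch: plan, allocate the work buffer the plan asks for, run, free.
static void compute_graph_sketch(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    uint8_t * work = NULL;
    if (plan.work_size > 0) {
        work = malloc(plan.work_size);   // caller-owned, as the header comment requires
        plan.work_data = work;
    }

    ggml_graph_compute(gf, &plan);

    free(work);
}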
@@ -1488,6 +1820,127 @@ extern "C" {

     GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);

+    //
+    // gguf
+    //
+
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API int    gguf_get_version    (struct gguf_context * ctx);
+    GGML_API size_t gguf_get_alignment  (struct gguf_context * ctx);
+    GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
+    GGML_API void * gguf_get_data       (struct gguf_context * ctx);
+
+    GGML_API int          gguf_get_n_kv(struct gguf_context * ctx);
+    GGML_API int          gguf_find_key(struct gguf_context * ctx, const char * key);
+    GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
+
+    GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
+    GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
+
+    // results are undefined if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (struct gguf_context * ctx, int i);
+    GGML_API int8_t       gguf_get_val_i8  (struct gguf_context * ctx, int i);
+    GGML_API uint16_t     gguf_get_val_u16 (struct gguf_context * ctx, int i);
+    GGML_API int16_t      gguf_get_val_i16 (struct gguf_context * ctx, int i);
+    GGML_API uint32_t     gguf_get_val_u32 (struct gguf_context * ctx, int i);
+    GGML_API int32_t      gguf_get_val_i32 (struct gguf_context * ctx, int i);
+    GGML_API float        gguf_get_val_f32 (struct gguf_context * ctx, int i);
+    GGML_API uint64_t     gguf_get_val_u64 (struct gguf_context * ctx, int i);
+    GGML_API int64_t      gguf_get_val_i64 (struct gguf_context * ctx, int i);
+    GGML_API double       gguf_get_val_f64 (struct gguf_context * ctx, int i);
+    GGML_API bool         gguf_get_val_bool(struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
+    GGML_API int          gguf_get_arr_n   (struct gguf_context * ctx, int i);
+    GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
+    GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
+
+    GGML_API int    gguf_get_n_tensors    (struct gguf_context * ctx);
+    GGML_API int    gguf_find_tensor      (struct gguf_context * ctx, const char * name);
+    GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
+    GGML_API char * gguf_get_tensor_name  (struct gguf_context * ctx, int i);
+
+    // overrides existing values or adds a new one
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
+
+    // manage tensor info
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
+
+    // writing gguf files can be done in 2 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fwrite(f, ...);
+    //   void * data = gguf_meta_get_meta_data(ctx);
+    //   fseek(f, 0, SEEK_SET);
+    //   fwrite(f, data, gguf_get_meta_size(ctx));
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
+    GGML_API void   gguf_get_meta_data(struct gguf_context * ctx, void * data);
+
     //
     // system info
     //
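A minimal sketch of the single-pass writing path described in the comment block above. The key names and the "model.gguf" filename are illustrative, not part of the API; the tensor is assumed to have already been created and named in a ggml context.

#include "ggml.h"

// write_example_gguf: create an in-memory gguf context, attach metadata and one
// tensor, then write everything (metadata + tensor data) in a single pass.
static void write_example_gguf(struct ggml_tensor * weights) {
    struct gguf_context * gctx = gguf_init_empty();

    gguf_set_val_str(gctx, "general.name",     "example");   // illustrative keys
    gguf_set_val_u32(gctx, "example.n_layers", 12);

    gguf_add_tensor(gctx, weights);                           // records tensor info and data

    gguf_write_to_file(gctx, "model.gguf", /*only_meta =*/ false);
    gguf_free(gctx);
}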
@@ -1516,25 +1969,28 @@ extern "C" {
|
|
| 1516 |
//
|
| 1517 |
|
| 1518 |
#ifdef __cplusplus
|
| 1519 |
-
|
| 1520 |
#define GGML_RESTRICT
|
| 1521 |
#else
|
| 1522 |
#define GGML_RESTRICT restrict
|
| 1523 |
#endif
|
| 1524 |
-
typedef void (*
|
| 1525 |
-
typedef void (*
|
| 1526 |
-
typedef void (*
|
| 1527 |
|
| 1528 |
typedef struct {
|
| 1529 |
-
|
| 1530 |
-
|
| 1531 |
-
|
| 1532 |
-
|
| 1533 |
-
|
| 1534 |
-
|
| 1535 |
-
|
| 1536 |
-
|
| 1537 |
-
|
|
|
|
|
|
|
|
|
|
| 1538 |
|
| 1539 |
#ifdef __cplusplus
|
| 1540 |
}
|
|
|
|
| 65 |
// ggml_set_f32(a, 3.0f);
|
| 66 |
// ggml_set_f32(b, 4.0f);
|
| 67 |
//
|
| 68 |
+
// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
|
| 69 |
//
|
| 70 |
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
|
| 71 |
//
|
|
|
|
| 130 |
// The data of the tensor is accessed via the "data" pointer. For example:
|
| 131 |
//
|
| 132 |
// {
|
| 133 |
+
// const int nx = 2;
|
| 134 |
+
// const int ny = 3;
|
| 135 |
//
|
| 136 |
+
// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
|
|
|
|
| 137 |
//
|
| 138 |
+
// for (int y = 0; y < ny; y++) {
|
| 139 |
+
// for (int x = 0; x < nx; x++) {
|
| 140 |
+
// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
|
| 141 |
+
// }
|
| 142 |
+
// }
|
| 143 |
//
|
| 144 |
// ...
|
| 145 |
// }
|
|
|
|
| 186 |
# define GGML_API
|
| 187 |
#endif
|
| 188 |
|
| 189 |
+
// TODO: support for clang
|
| 190 |
+
#ifdef __GNUC__
|
| 191 |
+
# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
|
| 192 |
+
#elif defined(_MSC_VER)
|
| 193 |
+
# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
|
| 194 |
+
#else
|
| 195 |
+
# define GGML_DEPRECATED(func, hint) func
|
| 196 |
+
#endif
|
| 197 |
+
|
| 198 |
#include <stdint.h>
|
| 199 |
#include <stddef.h>
|
| 200 |
#include <stdbool.h>
|
|
|
|
| 209 |
#define GGML_MAX_NODES 4096
|
| 210 |
#define GGML_MAX_PARAMS 256
|
| 211 |
#define GGML_MAX_CONTEXTS 64
|
| 212 |
+
#define GGML_MAX_SRC 6
|
| 213 |
+
#define GGML_MAX_NAME 64
|
| 214 |
+
#define GGML_MAX_OP_PARAMS 32
|
| 215 |
#define GGML_DEFAULT_N_THREADS 4
|
| 216 |
|
| 217 |
+
#if UINTPTR_MAX == 0xFFFFFFFF
|
| 218 |
+
#define GGML_MEM_ALIGN 4
|
| 219 |
+
#else
|
| 220 |
+
#define GGML_MEM_ALIGN 16
|
| 221 |
+
#endif
|
| 222 |
+
|
| 223 |
+
#define GGML_EXIT_SUCCESS 0
|
| 224 |
+
#define GGML_EXIT_ABORTED 1
|
| 225 |
+
|
| 226 |
+
#define GGUF_MAGIC 0x46554747 // "GGUF"
|
| 227 |
+
#define GGUF_VERSION 2
|
| 228 |
+
|
| 229 |
+
#define GGUF_DEFAULT_ALIGNMENT 32
|
| 230 |
+
|
| 231 |
#define GGML_UNUSED(x) (void)(x)
|
| 232 |
|
| 233 |
+
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
|
| 234 |
+
|
| 235 |
#define GGML_ASSERT(x) \
|
| 236 |
do { \
|
| 237 |
if (!(x)) { \
|
|
|
|
| 268 |
extern "C" {
|
| 269 |
#endif
|
| 270 |
|
| 271 |
+
#if defined(__ARM_NEON) && defined(__CUDACC__)
|
| 272 |
+
typedef half ggml_fp16_t;
|
| 273 |
+
#elif defined(__ARM_NEON)
|
| 274 |
typedef __fp16 ggml_fp16_t;
|
| 275 |
#else
|
| 276 |
typedef uint16_t ggml_fp16_t;
|
|
|
|
| 280 |
GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
|
| 281 |
GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
|
| 282 |
|
| 283 |
+
GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n);
|
| 284 |
+
GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n);
|
| 285 |
|
| 286 |
struct ggml_object;
|
| 287 |
struct ggml_context;
|
|
|
|
| 354 |
GGML_OP_ARGMAX,
|
| 355 |
GGML_OP_REPEAT,
|
| 356 |
GGML_OP_REPEAT_BACK,
|
| 357 |
+
GGML_OP_CONCAT,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
GGML_OP_SILU_BACK,
|
| 359 |
GGML_OP_NORM, // normalize
|
| 360 |
GGML_OP_RMS_NORM,
|
| 361 |
GGML_OP_RMS_NORM_BACK,
|
| 362 |
+
GGML_OP_GROUP_NORM,
|
| 363 |
|
| 364 |
GGML_OP_MUL_MAT,
|
| 365 |
GGML_OP_OUT_PROD,
|
|
|
|
| 385 |
GGML_OP_CLAMP,
|
| 386 |
GGML_OP_CONV_1D,
|
| 387 |
GGML_OP_CONV_2D,
|
| 388 |
+
GGML_OP_CONV_TRANSPOSE_2D,
|
| 389 |
+
GGML_OP_POOL_1D,
|
| 390 |
+
GGML_OP_POOL_2D,
|
| 391 |
+
|
| 392 |
+
GGML_OP_UPSCALE, // nearest interpolate
|
| 393 |
|
| 394 |
GGML_OP_FLASH_ATTN,
|
| 395 |
GGML_OP_FLASH_FF,
|
| 396 |
GGML_OP_FLASH_ATTN_BACK,
|
| 397 |
GGML_OP_WIN_PART,
|
| 398 |
GGML_OP_WIN_UNPART,
|
| 399 |
+
GGML_OP_GET_REL_POS,
|
| 400 |
+
GGML_OP_ADD_REL_POS,
|
| 401 |
+
|
| 402 |
+
GGML_OP_UNARY,
|
| 403 |
|
| 404 |
GGML_OP_MAP_UNARY,
|
| 405 |
GGML_OP_MAP_BINARY,
|
| 406 |
|
| 407 |
+
GGML_OP_MAP_CUSTOM1_F32,
|
| 408 |
+
GGML_OP_MAP_CUSTOM2_F32,
|
| 409 |
+
GGML_OP_MAP_CUSTOM3_F32,
|
| 410 |
+
|
| 411 |
GGML_OP_MAP_CUSTOM1,
|
| 412 |
GGML_OP_MAP_CUSTOM2,
|
| 413 |
GGML_OP_MAP_CUSTOM3,
|
|
|
|
| 418 |
GGML_OP_COUNT,
|
| 419 |
};
|
| 420 |
|
| 421 |
+
enum ggml_unary_op {
|
| 422 |
+
GGML_UNARY_OP_ABS,
|
| 423 |
+
GGML_UNARY_OP_SGN,
|
| 424 |
+
GGML_UNARY_OP_NEG,
|
| 425 |
+
GGML_UNARY_OP_STEP,
|
| 426 |
+
GGML_UNARY_OP_TANH,
|
| 427 |
+
GGML_UNARY_OP_ELU,
|
| 428 |
+
GGML_UNARY_OP_RELU,
|
| 429 |
+
GGML_UNARY_OP_GELU,
|
| 430 |
+
GGML_UNARY_OP_GELU_QUICK,
|
| 431 |
+
GGML_UNARY_OP_SILU,
|
| 432 |
+
};
|
| 433 |
+
|
| 434 |
+
enum ggml_object_type {
|
| 435 |
+
GGML_OBJECT_TENSOR,
|
| 436 |
+
GGML_OBJECT_GRAPH,
|
| 437 |
+
GGML_OBJECT_WORK_BUFFER
|
| 438 |
+
};
|
| 439 |
|
| 440 |
// ggml object
|
| 441 |
struct ggml_object {
|
|
|
|
| 444 |
|
| 445 |
struct ggml_object * next;
|
| 446 |
|
| 447 |
+
enum ggml_object_type type;
|
| 448 |
+
|
| 449 |
+
char padding[4];
|
| 450 |
};
|
| 451 |
|
| 452 |
static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
|
|
|
|
| 466 |
// compute data
|
| 467 |
enum ggml_op op;
|
| 468 |
|
| 469 |
+
// op params - allocated as int32_t for alignment
|
| 470 |
+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
| 471 |
+
|
| 472 |
bool is_param;
|
| 473 |
|
| 474 |
struct ggml_tensor * grad;
|
| 475 |
+
struct ggml_tensor * src[GGML_MAX_SRC];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
// performance
|
| 478 |
int perf_runs;
|
|
|
|
| 490 |
|
| 491 |
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
|
| 492 |
|
| 493 |
+
// the compute plan that needs to be prepared for ggml_graph_compute()
|
| 494 |
+
// since https://github.com/ggerganov/ggml/issues/287
|
| 495 |
+
struct ggml_cplan {
|
| 496 |
+
size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
|
| 497 |
+
uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
|
| 498 |
+
|
| 499 |
+
int n_threads;
|
| 500 |
+
|
| 501 |
+
// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
|
| 502 |
+
int n_tasks[GGML_MAX_NODES];
|
| 503 |
+
|
| 504 |
+
// abort ggml_graph_compute when true
|
| 505 |
+
bool (*abort_callback)(void * data);
|
| 506 |
+
void * abort_callback_data;
|
| 507 |
+
};
|
| 508 |
+
|
| 509 |
+
// next prime after GGML_MAX_NODES
|
| 510 |
+
// #define GGML_GRAPH_HASHTABLE_SIZE 4099
|
| 511 |
+
// next prime after GGML_MAX_NODES * 2 (nodes + leafs)
|
| 512 |
+
#define GGML_GRAPH_HASHTABLE_SIZE 8273
|
| 513 |
+
|
| 514 |
// computation graph
|
| 515 |
struct ggml_cgraph {
|
| 516 |
int n_nodes;
|
| 517 |
int n_leafs;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
struct ggml_tensor * nodes[GGML_MAX_NODES];
|
| 520 |
struct ggml_tensor * grads[GGML_MAX_NODES];
|
| 521 |
struct ggml_tensor * leafs[GGML_MAX_NODES];
|
| 522 |
|
| 523 |
+
void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
|
| 524 |
+
|
| 525 |
// performance
|
| 526 |
int perf_runs;
|
| 527 |
int64_t perf_cycles;
|
| 528 |
int64_t perf_time_us;
|
| 529 |
};
|
| 530 |
|
| 531 |
+
static const size_t GGML_GRAPH_SIZE = sizeof(struct ggml_cgraph);
|
| 532 |
+
|
| 533 |
// scratch buffer
|
| 534 |
struct ggml_scratch {
|
| 535 |
size_t offs;
|
|
|
|
| 583 |
GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
|
| 584 |
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
|
| 585 |
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
|
| 586 |
+
GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
|
| 587 |
GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
|
| 588 |
|
| 589 |
GGML_API int ggml_blck_size (enum ggml_type type);
|
|
|
|
| 592 |
|
| 593 |
GGML_API const char * ggml_type_name(enum ggml_type type);
|
| 594 |
GGML_API const char * ggml_op_name (enum ggml_op op);
|
| 595 |
+
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
| 596 |
|
| 597 |
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
| 598 |
|
|
|
|
| 605 |
GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
|
| 606 |
GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
|
| 607 |
|
| 608 |
+
GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
|
| 609 |
+
|
| 610 |
// use this to compute the memory overhead of a tensor
|
| 611 |
GGML_API size_t ggml_tensor_overhead(void);
|
| 612 |
|
|
|
|
| 618 |
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
|
| 619 |
|
| 620 |
GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
|
| 621 |
+
GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
|
| 622 |
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
|
| 623 |
|
| 624 |
GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
|
|
|
|
| 678 |
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
|
| 679 |
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
|
| 680 |
|
| 681 |
+
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
|
| 682 |
+
|
| 683 |
+
GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
|
| 684 |
+
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
|
| 685 |
+
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
|
| 686 |
|
| 687 |
//
|
| 688 |
// operations on tensors with backpropagation
|
|
|
|
| 692 |
struct ggml_context * ctx,
|
| 693 |
struct ggml_tensor * a);
|
| 694 |
|
| 695 |
+
// in-place, returns view(a)
|
| 696 |
+
GGML_API struct ggml_tensor * ggml_dup_inplace(
|
| 697 |
+
struct ggml_context * ctx,
|
| 698 |
+
struct ggml_tensor * a);
|
| 699 |
+
|
| 700 |
GGML_API struct ggml_tensor * ggml_add(
|
| 701 |
struct ggml_context * ctx,
|
| 702 |
struct ggml_tensor * a,
|
|
|
|
| 821 |
struct ggml_tensor * a,
|
| 822 |
struct ggml_tensor * b);
|
| 823 |
|
| 824 |
+
// concat a and b on dim 2
|
| 825 |
+
// used in stable-diffusion
|
| 826 |
+
GGML_API struct ggml_tensor * ggml_concat(
|
| 827 |
+
struct ggml_context * ctx,
|
| 828 |
+
struct ggml_tensor * a,
|
| 829 |
+
struct ggml_tensor * b);
|
| 830 |
+
|
| 831 |
GGML_API struct ggml_tensor * ggml_abs(
|
| 832 |
struct ggml_context * ctx,
|
| 833 |
struct ggml_tensor * a);
|
|
|
|
| 917 |
struct ggml_tensor * b);
|
| 918 |
|
| 919 |
// normalize along rows
|
|
|
|
| 920 |
GGML_API struct ggml_tensor * ggml_norm(
|
| 921 |
struct ggml_context * ctx,
|
| 922 |
+
struct ggml_tensor * a,
|
| 923 |
+
float eps);
|
| 924 |
|
| 925 |
GGML_API struct ggml_tensor * ggml_norm_inplace(
|
| 926 |
struct ggml_context * ctx,
|
| 927 |
+
struct ggml_tensor * a,
|
| 928 |
+
float eps);
|
| 929 |
|
| 930 |
GGML_API struct ggml_tensor * ggml_rms_norm(
|
| 931 |
struct ggml_context * ctx,
|
| 932 |
+
struct ggml_tensor * a,
|
| 933 |
+
float eps);
|
| 934 |
|
| 935 |
GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
|
| 936 |
struct ggml_context * ctx,
|
| 937 |
+
struct ggml_tensor * a,
|
| 938 |
+
float eps);
|
| 939 |
+
|
| 940 |
+
// group normalize along ne0*ne1*n_groups
|
| 941 |
+
// used in stable-diffusion
|
| 942 |
+
// TODO: eps is hardcoded to 1e-6 for now
|
| 943 |
+
GGML_API struct ggml_tensor * ggml_group_norm(
|
| 944 |
+
struct ggml_context * ctx,
|
| 945 |
+
struct ggml_tensor * a,
|
| 946 |
+
int n_groups);
|
| 947 |
+
|
| 948 |
+
GGML_API struct ggml_tensor * ggml_group_norm_inplace(
|
| 949 |
+
struct ggml_context * ctx,
|
| 950 |
+
struct ggml_tensor * a,
|
| 951 |
+
int n_groups);
|
| 952 |
|
| 953 |
// a - x
|
| 954 |
// b - dy
|
| 955 |
+
// TODO: update with configurable eps
|
| 956 |
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
| 957 |
struct ggml_context * ctx,
|
| 958 |
struct ggml_tensor * a,
|
|
|
|
| 1044 |
struct ggml_tensor * a,
|
| 1045 |
struct ggml_tensor * b);
|
| 1046 |
|
| 1047 |
+
// a -> b, in-place, return view(b)
|
| 1048 |
+
GGML_API struct ggml_tensor * ggml_cpy_inplace(
|
| 1049 |
+
struct ggml_context * ctx,
|
| 1050 |
+
struct ggml_tensor * a,
|
| 1051 |
+
struct ggml_tensor * b);
|
| 1052 |
+
|
| 1053 |
// make contiguous
|
| 1054 |
GGML_API struct ggml_tensor * ggml_cont(
|
| 1055 |
struct ggml_context * ctx,
|
| 1056 |
struct ggml_tensor * a);
|
| 1057 |
|
| 1058 |
+
// make contiguous, in-place
|
| 1059 |
+
GGML_API struct ggml_tensor * ggml_cont_inplace(
|
| 1060 |
+
struct ggml_context * ctx,
|
| 1061 |
+
struct ggml_tensor * a);
|
| 1062 |
+
|
| 1063 |
// return view(a), b specifies the new shape
|
| 1064 |
// TODO: when we start computing gradient, make a copy instead of view
|
| 1065 |
GGML_API struct ggml_tensor * ggml_reshape(
|
|
|
|
| 1228 |
int mode,
|
| 1229 |
int n_ctx);
|
| 1230 |
|
| 1231 |
+
// custom RoPE
|
| 1232 |
+
GGML_API struct ggml_tensor * ggml_rope_custom(
|
| 1233 |
+
struct ggml_context * ctx,
|
| 1234 |
+
struct ggml_tensor * a,
|
| 1235 |
+
int n_past,
|
| 1236 |
+
int n_dims,
|
| 1237 |
+
int mode,
|
| 1238 |
+
int n_ctx,
|
| 1239 |
+
float freq_base,
|
| 1240 |
+
float freq_scale);
|
| 1241 |
+
|
| 1242 |
+
// in-place, returns view(a)
|
| 1243 |
+
GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
|
| 1244 |
+
struct ggml_context * ctx,
|
| 1245 |
+
struct ggml_tensor * a,
|
| 1246 |
+
int n_past,
|
| 1247 |
+
int n_dims,
|
| 1248 |
+
int mode,
|
| 1249 |
+
int n_ctx,
|
| 1250 |
+
float freq_base,
|
| 1251 |
+
float freq_scale);
|
| 1252 |
+
|
| 1253 |
+
// xPos RoPE, in-place, returns view(a)
|
| 1254 |
+
GGML_API struct ggml_tensor * ggml_rope_xpos_inplace(
|
| 1255 |
+
struct ggml_context * ctx,
|
| 1256 |
+
struct ggml_tensor * a,
|
| 1257 |
+
int n_past,
|
| 1258 |
+
int n_dims,
|
| 1259 |
+
float base,
|
| 1260 |
+
bool down);
|
| 1261 |
+
|
| 1262 |
// rotary position embedding backward, i.e compute dx from dy
|
| 1263 |
// a - dy
|
| 1264 |
GGML_API struct ggml_tensor * ggml_rope_back(
|
|
|
|
| 1266 |
struct ggml_tensor * a,
|
| 1267 |
int n_past,
|
| 1268 |
int n_dims,
|
| 1269 |
+
int mode,
|
| 1270 |
+
int n_ctx,
|
| 1271 |
+
float freq_base,
|
| 1272 |
+
float freq_scale,
|
| 1273 |
+
float xpos_base,
|
| 1274 |
+
bool xpos_down);
|
| 1275 |
|
| 1276 |
// alibi position embedding
|
| 1277 |
// in-place, returns view(a)
|
|
|
|
| 1298 |
int p0, // padding
|
| 1299 |
int d0); // dilation
|
| 1300 |
|
| 1301 |
+
// conv_1d with padding = half
|
| 1302 |
+
// alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
|
| 1303 |
+
GGML_API struct ggml_tensor* ggml_conv_1d_ph(
|
| 1304 |
+
struct ggml_context * ctx,
|
| 1305 |
+
struct ggml_tensor * a,
|
| 1306 |
+
struct ggml_tensor * b,
|
| 1307 |
+
int s,
|
| 1308 |
+
int d);
|
| 1309 |
+
|
| 1310 |
GGML_API struct ggml_tensor * ggml_conv_2d(
|
| 1311 |
struct ggml_context * ctx,
|
| 1312 |
struct ggml_tensor * a,
|
|
|
|
| 1318 |
int d0,
|
| 1319 |
int d1);
|
| 1320 |
|
| 1321 |
+
|
| 1322 |
+
// kernel size is a->ne[0] x a->ne[1]
|
| 1323 |
+
// stride is equal to kernel size
|
| 1324 |
+
// padding is zero
|
| 1325 |
+
// example:
|
| 1326 |
+
// a: 16 16 3 768
|
| 1327 |
+
// b: 1024 1024 3 1
|
| 1328 |
+
// res: 64 64 768 1
|
| 1329 |
+
// used in sam
|
| 1330 |
+
GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
|
| 1331 |
+
struct ggml_context * ctx,
|
| 1332 |
+
struct ggml_tensor * a,
|
| 1333 |
+
struct ggml_tensor * b);
|
| 1334 |
+
|
| 1335 |
+
// kernel size is a->ne[0] x a->ne[1]
|
| 1336 |
+
// stride is 1
|
| 1337 |
+
// padding is half
|
| 1338 |
+
// example:
|
| 1339 |
+
// a: 3 3 256 256
|
| 1340 |
+
// b: 64 64 256 1
|
| 1341 |
+
// res: 64 64 256 1
|
| 1342 |
+
// used in sam
|
| 1343 |
+
GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
|
| 1344 |
+
struct ggml_context * ctx,
|
| 1345 |
+
struct ggml_tensor * a,
|
| 1346 |
+
struct ggml_tensor * b);
|
| 1347 |
+
|
| 1348 |
+
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
|
| 1349 |
struct ggml_context * ctx,
|
| 1350 |
struct ggml_tensor * a,
|
| 1351 |
struct ggml_tensor * b,
|
| 1352 |
+
int stride);
|
| 1353 |
+
|
| 1354 |
+
enum ggml_op_pool {
|
| 1355 |
+
GGML_OP_POOL_MAX,
|
| 1356 |
+
GGML_OP_POOL_AVG,
|
| 1357 |
+
GGML_OP_POOL_COUNT,
|
| 1358 |
+
};
|
| 1359 |
+
|
| 1360 |
+
GGML_API struct ggml_tensor * ggml_pool_1d(
|
| 1361 |
+
struct ggml_context * ctx,
|
| 1362 |
+
struct ggml_tensor * a,
|
| 1363 |
+
enum ggml_op_pool op,
|
| 1364 |
+
int k0, // kernel size
|
| 1365 |
+
int s0, // stride
|
| 1366 |
+
int p0); // padding
|
| 1367 |
+
|
| 1368 |
+
GGML_API struct ggml_tensor * ggml_pool_2d(
|
| 1369 |
+
struct ggml_context * ctx,
|
| 1370 |
+
struct ggml_tensor * a,
|
| 1371 |
+
enum ggml_op_pool op,
|
| 1372 |
+
int k0,
|
| 1373 |
+
int k1,
|
| 1374 |
+
int s0,
|
| 1375 |
+
int s1,
|
| 1376 |
+
int p0,
|
| 1377 |
+
int p1);
|
| 1378 |
+
|
| 1379 |
+
// nearest interpolate
|
| 1380 |
+
// used in stable-diffusion
|
| 1381 |
+
GGML_API struct ggml_tensor * ggml_upscale(
|
| 1382 |
+
struct ggml_context * ctx,
|
| 1383 |
+
struct ggml_tensor * a,
|
| 1384 |
+
int scale_factor);
|
| 1385 |
|
| 1386 |
GGML_API struct ggml_tensor * ggml_flash_attn(
|
| 1387 |
struct ggml_context * ctx,
|
|
|
|
| 1426 |
int h0,
|
| 1427 |
int w);
|
| 1428 |
|
| 1429 |
+
GGML_API struct ggml_tensor * ggml_unary(
|
| 1430 |
+
struct ggml_context * ctx,
|
| 1431 |
+
struct ggml_tensor * a,
|
| 1432 |
+
enum ggml_unary_op op);
|
| 1433 |
+
|
| 1434 |
+
GGML_API struct ggml_tensor * ggml_unary_inplace(
|
| 1435 |
+
struct ggml_context * ctx,
|
| 1436 |
+
struct ggml_tensor * a,
|
| 1437 |
+
enum ggml_unary_op op);
|
| 1438 |
+
|
| 1439 |
+
// used in sam
|
| 1440 |
+
GGML_API struct ggml_tensor * ggml_get_rel_pos(
|
| 1441 |
+
struct ggml_context * ctx,
|
| 1442 |
+
struct ggml_tensor * a,
|
| 1443 |
+
int qh,
|
| 1444 |
+
int kh);
|
| 1445 |
+
|
| 1446 |
+
// used in sam
|
| 1447 |
+
|
| 1448 |
+
GGML_API struct ggml_tensor * ggml_add_rel_pos(
|
| 1449 |
+
struct ggml_context * ctx,
|
| 1450 |
+
struct ggml_tensor * a,
|
| 1451 |
+
struct ggml_tensor * pw,
|
| 1452 |
+
struct ggml_tensor * ph);
|
| 1453 |
+
|
| 1454 |
+
GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
|
| 1455 |
+
struct ggml_context * ctx,
|
| 1456 |
+
struct ggml_tensor * a,
|
| 1457 |
+
struct ggml_tensor * pw,
|
| 1458 |
+
struct ggml_tensor * ph);
|
| 1459 |
+
|
| 1460 |
// custom operators
|
| 1461 |
|
| 1462 |
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
|
|
|
| 1466 |
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
|
| 1467 |
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
|
| 1468 |
|
| 1469 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32(
|
| 1470 |
struct ggml_context * ctx,
|
| 1471 |
struct ggml_tensor * a,
|
| 1472 |
+
ggml_unary_op_f32_t fun),
|
| 1473 |
+
"use ggml_map_custom1 instead");
|
| 1474 |
|
| 1475 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32(
|
| 1476 |
struct ggml_context * ctx,
|
| 1477 |
struct ggml_tensor * a,
|
| 1478 |
+
ggml_unary_op_f32_t fun),
|
| 1479 |
+
"use ggml_map_custom1_inplace instead");
|
| 1480 |
|
| 1481 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32(
|
| 1482 |
struct ggml_context * ctx,
|
| 1483 |
struct ggml_tensor * a,
|
| 1484 |
struct ggml_tensor * b,
|
| 1485 |
+
ggml_binary_op_f32_t fun),
|
| 1486 |
+
"use ggml_map_custom2 instead");
|
| 1487 |
|
| 1488 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32(
|
| 1489 |
struct ggml_context * ctx,
|
| 1490 |
struct ggml_tensor * a,
|
| 1491 |
struct ggml_tensor * b,
|
| 1492 |
+
ggml_binary_op_f32_t fun),
|
| 1493 |
+
"use ggml_map_custom2_inplace instead");
|
| 1494 |
|
| 1495 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32(
|
| 1496 |
struct ggml_context * ctx,
|
| 1497 |
struct ggml_tensor * a,
|
| 1498 |
+
ggml_custom1_op_f32_t fun),
|
| 1499 |
+
"use ggml_map_custom1 instead");
|
| 1500 |
|
| 1501 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32(
|
| 1502 |
struct ggml_context * ctx,
|
| 1503 |
struct ggml_tensor * a,
|
| 1504 |
+
ggml_custom1_op_f32_t fun),
|
| 1505 |
+
"use ggml_map_custom1_inplace instead");
|
| 1506 |
|
| 1507 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32(
|
| 1508 |
struct ggml_context * ctx,
|
| 1509 |
struct ggml_tensor * a,
|
| 1510 |
struct ggml_tensor * b,
|
| 1511 |
+
ggml_custom2_op_f32_t fun),
|
| 1512 |
+
"use ggml_map_custom2 instead");
|
| 1513 |
|
| 1514 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32(
|
| 1515 |
struct ggml_context * ctx,
|
| 1516 |
struct ggml_tensor * a,
|
| 1517 |
struct ggml_tensor * b,
|
| 1518 |
+
ggml_custom2_op_f32_t fun),
|
| 1519 |
+
"use ggml_map_custom2_inplace instead");
|
| 1520 |
|
| 1521 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32(
|
| 1522 |
struct ggml_context * ctx,
|
| 1523 |
struct ggml_tensor * a,
|
| 1524 |
struct ggml_tensor * b,
|
| 1525 |
struct ggml_tensor * c,
|
| 1526 |
+
ggml_custom3_op_f32_t fun),
|
| 1527 |
+
"use ggml_map_custom3 instead");
|
| 1528 |
|
| 1529 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32(
|
| 1530 |
struct ggml_context * ctx,
|
| 1531 |
struct ggml_tensor * a,
|
| 1532 |
struct ggml_tensor * b,
|
| 1533 |
struct ggml_tensor * c,
|
| 1534 |
+
ggml_custom3_op_f32_t fun),
|
| 1535 |
+
"use ggml_map_custom3_inplace instead");
|
| 1536 |
+
|
| 1537 |
+
// custom operators v2
|
| 1538 |
+
|
| 1539 |
+
typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
|
| 1540 |
+
typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
|
| 1541 |
+
typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);
|
| 1542 |
+
|
| 1543 |
+
#define GGML_N_TASKS_MAX -1
|
| 1544 |
+
|
| 1545 |
+
GGML_API struct ggml_tensor * ggml_map_custom1(
|
| 1546 |
+
struct ggml_context * ctx,
|
| 1547 |
+
struct ggml_tensor * a,
|
| 1548 |
+
ggml_custom1_op_t fun,
|
| 1549 |
+
int n_tasks,
|
| 1550 |
+
void * userdata);
|
| 1551 |
+
|
| 1552 |
+
GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
|
| 1553 |
+
struct ggml_context * ctx,
|
| 1554 |
+
struct ggml_tensor * a,
|
| 1555 |
+
ggml_custom1_op_t fun,
|
| 1556 |
+
int n_tasks,
|
| 1557 |
+
void * userdata);
|
| 1558 |
+
|
| 1559 |
+
GGML_API struct ggml_tensor * ggml_map_custom2(
|
| 1560 |
+
struct ggml_context * ctx,
|
| 1561 |
+
struct ggml_tensor * a,
|
| 1562 |
+
struct ggml_tensor * b,
|
| 1563 |
+
ggml_custom2_op_t fun,
|
| 1564 |
+
int n_tasks,
|
| 1565 |
+
void * userdata);
|
| 1566 |
+
|
| 1567 |
+
GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
|
| 1568 |
+
struct ggml_context * ctx,
|
| 1569 |
+
struct ggml_tensor * a,
|
| 1570 |
+
struct ggml_tensor * b,
|
| 1571 |
+
ggml_custom2_op_t fun,
|
| 1572 |
+
int n_tasks,
|
| 1573 |
+
void * userdata);
|
| 1574 |
+
|
| 1575 |
+
GGML_API struct ggml_tensor * ggml_map_custom3(
|
| 1576 |
+
struct ggml_context * ctx,
|
| 1577 |
+
struct ggml_tensor * a,
|
| 1578 |
+
struct ggml_tensor * b,
|
| 1579 |
+
struct ggml_tensor * c,
|
| 1580 |
+
ggml_custom3_op_t fun,
|
| 1581 |
+
int n_tasks,
|
| 1582 |
+
void * userdata);
|
| 1583 |
+
|
| 1584 |
+
GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
|
| 1585 |
+
struct ggml_context * ctx,
|
| 1586 |
+
struct ggml_tensor * a,
|
| 1587 |
+
struct ggml_tensor * b,
|
| 1588 |
+
struct ggml_tensor * c,
|
| 1589 |
+
ggml_custom3_op_t fun,
|
| 1590 |
+
int n_tasks,
|
| 1591 |
+
void * userdata);
|
| 1592 |
|
| 1593 |
// loss function
|
| 1594 |
|
|
|
|
| 1609 |
|
| 1610 |
GGML_API void ggml_set_param(
|
| 1611 |
struct ggml_context * ctx,
|
| 1612 |
+
struct ggml_tensor * tensor);
|
| 1613 |
+
|
| 1614 |
|
| 1615 |
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
| 1616 |
|
| 1617 |
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
|
| 1618 |
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
|
| 1619 |
|
| 1620 |
+
// graph allocation in a context
|
| 1621 |
+
GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx);
|
| 1622 |
+
GGML_API struct ggml_cgraph * ggml_build_forward_ctx(struct ggml_context * ctx, struct ggml_tensor * tensor);
|
| 1623 |
+
GGML_API size_t ggml_graph_overhead(void);
|
| 1624 |
+
|
| 1625 |
+
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
| 1626 |
+
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
| 1627 |
+
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
|
| 1628 |
+
GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
| 1629 |
+
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
|
| 1630 |
+
|
| 1631 |
+
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
| 1632 |
+
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
| 1633 |
+
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
|
| 1634 |
|
| 1635 |
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
|
| 1636 |
|
|
|
|
| 1820 |
|
| 1821 |
GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
|
| 1822 |
|
| 1823 |
+
//
|
| 1824 |
+
// gguf
|
| 1825 |
+
//
|
| 1826 |
+
|
| 1827 |
+
enum gguf_type {
|
| 1828 |
+
GGUF_TYPE_UINT8 = 0,
|
| 1829 |
+
GGUF_TYPE_INT8 = 1,
|
| 1830 |
+
GGUF_TYPE_UINT16 = 2,
|
| 1831 |
+
GGUF_TYPE_INT16 = 3,
|
| 1832 |
+
GGUF_TYPE_UINT32 = 4,
|
| 1833 |
+
GGUF_TYPE_INT32 = 5,
|
| 1834 |
+
GGUF_TYPE_FLOAT32 = 6,
|
| 1835 |
+
GGUF_TYPE_BOOL = 7,
|
| 1836 |
+
GGUF_TYPE_STRING = 8,
|
| 1837 |
+
GGUF_TYPE_ARRAY = 9,
|
| 1838 |
+
GGUF_TYPE_UINT64 = 10,
|
| 1839 |
+
GGUF_TYPE_INT64 = 11,
|
| 1840 |
+
GGUF_TYPE_FLOAT64 = 12,
|
| 1841 |
+
GGUF_TYPE_COUNT, // marks the end of the enum
|
| 1842 |
+
};
|
| 1843 |
+
|
| 1844 |
+
struct gguf_context;
|
| 1845 |
+
|
| 1846 |
+
struct gguf_init_params {
|
| 1847 |
+
bool no_alloc;
|
| 1848 |
+
|
| 1849 |
+
// if not NULL, create a ggml_context and allocate the tensor data in it
|
| 1850 |
+
struct ggml_context ** ctx;
|
| 1851 |
+
};
|
| 1852 |
+
|
| 1853 |
+
GGML_API struct gguf_context * gguf_init_empty(void);
|
| 1854 |
+
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
|
| 1855 |
+
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
|
| 1856 |
+
|
| 1857 |
+
GGML_API void gguf_free(struct gguf_context * ctx);
|
| 1858 |
+
|
| 1859 |
+
GGML_API const char * gguf_type_name(enum gguf_type type);
|
| 1860 |
+
|
| 1861 |
+
GGML_API int gguf_get_version (struct gguf_context * ctx);
|
| 1862 |
+
GGML_API size_t gguf_get_alignment (struct gguf_context * ctx);
|
| 1863 |
+
GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx);
|
| 1864 |
+
GGML_API void * gguf_get_data (struct gguf_context * ctx);
|
| 1865 |
+
|
| 1866 |
+
GGML_API int gguf_get_n_kv(struct gguf_context * ctx);
|
| 1867 |
+
GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key);
|
| 1868 |
+
GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i);
|
| 1869 |
+
|
| 1870 |
+
GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
|
| 1871 |
+
GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
|
| 1872 |
+
|
| 1873 |
+
// results are undefined if the wrong type is used for the key
|
| 1874 |
+
GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i);
|
| 1875 |
+
GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i);
|
| 1876 |
+
GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
|
| 1877 |
+
GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i);
|
| 1878 |
+
GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i);
|
| 1879 |
+
GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i);
|
| 1880 |
+
GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i);
|
| 1881 |
+
GGML_API uint64_t gguf_get_val_u64 (struct gguf_context * ctx, int i);
|
| 1882 |
+
GGML_API int64_t gguf_get_val_i64 (struct gguf_context * ctx, int i);
|
| 1883 |
+
GGML_API double gguf_get_val_f64 (struct gguf_context * ctx, int i);
|
| 1884 |
+
GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i);
|
| 1885 |
+
GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i);
|
| 1886 |
+
GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i);
|
| 1887 |
+
GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i);
|
| 1888 |
+
GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i);
|
| 1889 |
+
|
| 1890 |
+
GGML_API int gguf_get_n_tensors (struct gguf_context * ctx);
|
| 1891 |
+
GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name);
|
| 1892 |
+
GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i);
|
| 1893 |
+
GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i);
|
| 1894 |
+
|
| 1895 |
+
// overrides existing values or adds a new one
|
| 1896 |
+
GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val);
|
| 1897 |
+
GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val);
|
| 1898 |
+
GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
|
| 1899 |
+
GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val);
|
| 1900 |
+
GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
|
| 1901 |
+
GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val);
|
| 1902 |
+
GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val);
|
| 1903 |
+
GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
|
| 1904 |
+
GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val);
|
| 1905 |
+
GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val);
|
| 1906 |
+
GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val);
|
| 1907 |
+
GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
|
| 1908 |
+
GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
|
| 1909 |
+
GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
|
| 1910 |
+
|
| 1911 |
+
// set or add KV pairs from another context
|
| 1912 |
+
GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
|
| 1913 |
+
|
| 1914 |
+
// manage tensor info
|
| 1915 |
+
GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
|
| 1916 |
+
GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
|
| 1917 |
+
GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
|
| 1918 |
+
|
+    // writing gguf files can be done in 2 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
+    //   fwrite(f, ...);
+    //   void * data = gguf_meta_get_meta_data(ctx);
+    //   fseek(f, 0, SEEK_SET);
+    //   fwrite(f, data, gguf_get_meta_size(ctx));
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx);
+    GGML_API void   gguf_get_meta_data(struct gguf_context * ctx, void * data);
+
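The comment block above describes the two-pass write with illustrative pseudo-calls; a compilable sketch of the same idea (not part of the diff), using only the gguf_get_meta_size/gguf_get_meta_data declarations above and standard fwrite argument order, with the per-tensor writes left elided:

// sketch only: reserve space for the meta data, write tensors, then back-fill the header
#include <stdio.h>
#include <stdlib.h>
#include "ggml.h"

static void write_gguf_two_pass(struct gguf_context * ctx, const char * fname) {
    FILE * f = fopen(fname, "wb");
    if (!f) return;

    // pass 1: leave room for the meta data, then write the tensor data after it
    const size_t meta_size = gguf_get_meta_size(ctx);
    fseek(f, (long) meta_size, SEEK_SET);
    // ... fwrite(tensor_data, 1, tensor_nbytes, f) for each tensor ...

    // pass 2: go back to the start of the file and fill in the meta data
    void * meta = malloc(meta_size);
    gguf_get_meta_data(ctx, meta);
    fseek(f, 0, SEEK_SET);
    fwrite(meta, 1, meta_size, f);

    free(meta);
    fclose(f);
}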
     //
     // system info
     //

     [ ... unchanged declarations omitted by the diff view ... ]

     //

 #ifdef  __cplusplus
+// restrict not standard in C++
 #define GGML_RESTRICT
 #else
 #define GGML_RESTRICT restrict
 #endif
+    typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+    typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int k);
+    typedef void (*ggml_vec_dot_t)   (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);

     typedef struct {
+        const char      * type_name;
+        int               blck_size;
+        size_t            type_size;
+        bool              is_quantized;
+        ggml_to_float_t   to_float;
+        ggml_from_float_t from_float;
+        ggml_from_float_t from_float_reference;
+        ggml_vec_dot_t    vec_dot;
+        enum ggml_type    vec_dot_type;
+    } ggml_type_traits_t;
+
+    ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);

 #ifdef  __cplusplus
 }
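ggml_type_traits_t bundles each tensor type's block geometry and conversion kernels. A minimal sketch of querying it (not part of the diff; GGML_TYPE_Q4_0 is assumed from the ggml_type enum defined earlier in this header):

// sketch only: inspect the traits of a quantized type
#include <stdio.h>
#include "ggml.h"

int main(void) {
    const ggml_type_traits_t traits = ggml_internal_get_type_traits(GGML_TYPE_Q4_0);

    // block geometry: elements per block and bytes per block
    printf("%s: blck_size=%d type_size=%zu is_quantized=%d\n",
           traits.type_name, traits.blck_size, traits.type_size, (int) traits.is_quantized);

    // traits.to_float / traits.from_float convert whole blocks to/from float, e.g.
    // traits.to_float(quantized_block, dst_floats, traits.blck_size);
    return 0;
}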
whisper.cpp
CHANGED
@@ -441,6 +441,7 @@ struct whisper_hparams {
     int32_t n_text_layer  = 4;
     int32_t n_mels        = 80;
     int32_t ftype         = 1;
+    float   eps           = 1e-5f;
 };

 // audio encoding layer
@@ -1578,7 +1579,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 0);

-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);

             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -1725,7 +1726,7 @@ static bool whisper_encode_internal(
         {
             wstate.use_buf(ctx0, 0);

-            cur = ggml_norm(ctx0, inpFF);
+            cur = ggml_norm(ctx0, inpFF, hparams.eps);

             wstate.use_buf(ctx0, 1);

@@ -1788,7 +1789,7 @@ static bool whisper_encode_internal(
     {
         wstate.use_buf(ctx0, 0);

-        cur = ggml_norm(ctx0, cur);
+        cur = ggml_norm(ctx0, cur, hparams.eps);

         wstate.use_buf(ctx0, 1);

@@ -1805,10 +1806,9 @@ static bool whisper_encode_internal(
     // run the computation
     {
         struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;

-        ggml_build_forward_expand(&gf, cur);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_build_forward_expand  (&gf, cur);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

         //ggml_graph_print(&gf);
     }
@@ -1851,12 +1851,11 @@ static bool whisper_encode_internal(
     // pre-compute cross-attention memory
     {
         struct ggml_cgraph gf = {};
-        gf.n_threads = n_threads;

         // TODO: hack to disconnect the encoded features from the previous graph
         cur->op = GGML_OP_NONE;
-        cur->src0 = nullptr;
-        cur->src1 = nullptr;
+        cur->src[0] = nullptr;
+        cur->src[1] = nullptr;

         for (int il = 0; il < model.hparams.n_text_layer; ++il) {
             auto& layer = model.layers_decoder[il];
@@ -1894,7 +1893,7 @@ static bool whisper_encode_internal(
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
         }

-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
         //ggml_graph_print(&gf);
     }

@@ -1965,7 +1964,6 @@ static bool whisper_decode_internal(
     struct ggml_context * ctx0 = ggml_init(params);

     struct ggml_cgraph gf = {};
-    gf.n_threads = n_threads;

     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
@@ -1992,7 +1990,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);

-            cur = ggml_norm(ctx0, inpL);
+            cur = ggml_norm(ctx0, inpL, hparams.eps);

             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2119,7 +2117,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);

-            cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+            cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here

             // cur = ln_0_w*cur + ln_0_b
             cur = ggml_add(ctx0,
@@ -2229,7 +2227,7 @@ static bool whisper_decode_internal(
         {
             wstate.use_buf(ctx0, 0);

-            cur = ggml_norm(ctx0, inpFF);
+            cur = ggml_norm(ctx0, inpFF, hparams.eps);

             wstate.use_buf(ctx0, 1);

@@ -2284,7 +2282,7 @@ static bool whisper_decode_internal(
     {
         wstate.use_buf(ctx0, 0);

-        cur = ggml_norm(ctx0, cur);
+        cur = ggml_norm(ctx0, cur, hparams.eps);

         wstate.use_buf(ctx0, 1);

@@ -2308,8 +2306,8 @@ static bool whisper_decode_internal(

     // run the computation
     {
-        ggml_build_forward_expand(&gf, logits);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_build_forward_expand  (&gf, logits);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
     }

     // extract logits for all N tokens
@@ -2358,7 +2356,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 static float sin_vals[SIN_COS_N_COUNT];
 static float cos_vals[SIN_COS_N_COUNT];

-// In FFT, we frequently use sine and cosine operations with the same values.
+// In FFT, we frequently use sine and cosine operations with the same values.
 // We can use precalculated values to speed up the process.
 static void fill_sin_cos_table() {
     static bool is_filled = false;
@@ -5165,17 +5163,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {

     struct ggml_cgraph gf = ggml_build_forward(c);

-    gf.n_threads = n_threads;
-
     double tsum = 0.0;

     // heat-up
-    ggml_graph_compute(ctx0, &gf);
+    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

     for (int i = 0; i < n_max; ++i) {
         const int64_t t0 = ggml_time_us();

-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);

         const int64_t t1 = ggml_time_us();
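All of the whisper.cpp hunks above make the same two adjustments: ggml_norm now takes the epsilon explicitly (whisper_hparams::eps), and graphs are executed with ggml_graph_compute_with_ctx instead of setting gf.n_threads. A minimal sketch of the updated pattern (not part of the diff; the ggml context, input tensor and thread count are assumed to come from the caller):

// sketch only: normalize a tensor and run the graph with the post-sync APIs
#include "ggml.h"

static void run_layer_norm(struct ggml_context * ctx0, struct ggml_tensor * inp, int n_threads) {
    const float eps = 1e-5f;                               // same default as whisper_hparams::eps

    struct ggml_tensor * cur = ggml_norm(ctx0, inp, eps);  // eps is now an explicit argument

    struct ggml_cgraph gf = {};                            // no gf.n_threads field to set anymore
    ggml_build_forward_expand(&gf, cur);
    ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);     // thread count passed at compute time
}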