Spaces:
Running
Running
whisper : use flash attention (#2152)
Browse files* whisper : use flash attention in the encoder
* whisper : add kv_pad
* whisper : remove extra backend instance (huh?)
* whisper : use FA for cross-attention
* whisper : use FA for self-attention
* whisper : simplify encoder FA
* whisper : add flash_attn runtime parameter
* scripts : add bench log
* scripts : add M1 Pro bench log
- examples/bench/bench.cpp +11 -6
- examples/command/command.cpp +6 -1
- examples/lsp/lsp.cpp +7 -1
- examples/main/main.cpp +7 -2
- examples/server/server.cpp +6 -1
- examples/stream/stream.cpp +6 -1
- examples/talk-llama/talk-llama.cpp +7 -2
- examples/talk/talk.cpp +6 -1
- examples/wchess/wchess.cmd/wchess.cmd.cpp +6 -1
- scripts/bench-all-gg.txt +298 -0
- scripts/bench-all.sh +19 -6
- whisper.cpp +278 -151
- whisper.h +1 -0
examples/bench/bench.cpp
CHANGED
|
@@ -12,7 +12,8 @@ struct whisper_params {
|
|
| 12 |
|
| 13 |
std::string model = "models/ggml-base.en.bin";
|
| 14 |
|
| 15 |
-
bool use_gpu
|
|
|
|
| 16 |
};
|
| 17 |
|
| 18 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
@@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 25 |
whisper_print_usage(argc, argv, params);
|
| 26 |
exit(0);
|
| 27 |
}
|
| 28 |
-
else if (arg == "-t" || arg == "--threads")
|
| 29 |
-
else if (arg == "-m" || arg == "--model")
|
| 30 |
-
else if (arg == "-w" || arg == "--what")
|
| 31 |
-
else if (arg == "-ng" || arg == "--no-gpu")
|
|
|
|
| 32 |
else {
|
| 33 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 34 |
whisper_print_usage(argc, argv, params);
|
|
@@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 49 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 50 |
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
| 51 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 52 |
fprintf(stderr, " %-7s 0 - whisper\n", "");
|
| 53 |
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
| 54 |
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
|
@@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) {
|
|
| 59 |
// whisper init
|
| 60 |
|
| 61 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
|
| 64 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 65 |
|
|
|
|
| 12 |
|
| 13 |
std::string model = "models/ggml-base.en.bin";
|
| 14 |
|
| 15 |
+
bool use_gpu = true;
|
| 16 |
+
bool flash_attn = false;
|
| 17 |
};
|
| 18 |
|
| 19 |
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
|
|
|
|
| 26 |
whisper_print_usage(argc, argv, params);
|
| 27 |
exit(0);
|
| 28 |
}
|
| 29 |
+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
| 30 |
+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 31 |
+
else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
|
| 32 |
+
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 33 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 34 |
else {
|
| 35 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 36 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 51 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 52 |
fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
|
| 53 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 54 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
|
| 55 |
fprintf(stderr, " %-7s 0 - whisper\n", "");
|
| 56 |
fprintf(stderr, " %-7s 1 - memcpy\n", "");
|
| 57 |
fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
|
|
|
|
| 62 |
// whisper init
|
| 63 |
|
| 64 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 65 |
+
|
| 66 |
+
cparams.use_gpu = params.use_gpu;
|
| 67 |
+
cparams.flash_attn = params.flash_attn;
|
| 68 |
|
| 69 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 70 |
|
examples/command/command.cpp
CHANGED
|
@@ -44,6 +44,7 @@ struct whisper_params {
|
|
| 44 |
bool print_energy = false;
|
| 45 |
bool no_timestamps = true;
|
| 46 |
bool use_gpu = true;
|
|
|
|
| 47 |
|
| 48 |
std::string language = "en";
|
| 49 |
std::string model = "models/ggml-base.en.bin";
|
|
@@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 80 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 81 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 82 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 83 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 84 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 85 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
@@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 118 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 119 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 120 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 121 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 122 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 123 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
|
@@ -696,7 +699,9 @@ int main(int argc, char ** argv) {
|
|
| 696 |
// whisper init
|
| 697 |
|
| 698 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 699 |
-
|
|
|
|
|
|
|
| 700 |
|
| 701 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 702 |
|
|
|
|
| 44 |
bool print_energy = false;
|
| 45 |
bool no_timestamps = true;
|
| 46 |
bool use_gpu = true;
|
| 47 |
+
bool flash_attn = false;
|
| 48 |
|
| 49 |
std::string language = "en";
|
| 50 |
std::string model = "models/ggml-base.en.bin";
|
|
|
|
| 81 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 82 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 83 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 84 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 85 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 86 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 87 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
|
|
| 120 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 121 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 122 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 123 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
| 124 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 125 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 126 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
|
|
|
| 699 |
// whisper init
|
| 700 |
|
| 701 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 702 |
+
|
| 703 |
+
cparams.use_gpu = params.use_gpu;
|
| 704 |
+
cparams.flash_attn = params.flash_attn;
|
| 705 |
|
| 706 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 707 |
|
examples/lsp/lsp.cpp
CHANGED
|
@@ -31,6 +31,7 @@ struct whisper_params {
|
|
| 31 |
bool print_special = false;
|
| 32 |
bool print_energy = false;
|
| 33 |
bool use_gpu = true;
|
|
|
|
| 34 |
|
| 35 |
std::string language = "en";
|
| 36 |
std::string model = "models/ggml-base.en.bin";
|
|
@@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 74 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 75 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 76 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 77 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 78 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 79 |
else {
|
|
@@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 105 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 106 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 107 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 108 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 109 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 110 |
fprintf(stderr, "\n");
|
|
@@ -436,7 +439,10 @@ int main(int argc, char ** argv) {
|
|
| 436 |
|
| 437 |
// whisper init
|
| 438 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 439 |
-
|
|
|
|
|
|
|
|
|
|
| 440 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 441 |
// init audio
|
| 442 |
|
|
|
|
| 31 |
bool print_special = false;
|
| 32 |
bool print_energy = false;
|
| 33 |
bool use_gpu = true;
|
| 34 |
+
bool flash_attn = false;
|
| 35 |
|
| 36 |
std::string language = "en";
|
| 37 |
std::string model = "models/ggml-base.en.bin";
|
|
|
|
| 75 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 76 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 77 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 78 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 79 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 80 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 81 |
else {
|
|
|
|
| 107 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 108 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 109 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 110 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
| 111 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 112 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 113 |
fprintf(stderr, "\n");
|
|
|
|
| 439 |
|
| 440 |
// whisper init
|
| 441 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 442 |
+
|
| 443 |
+
cparams.use_gpu = params.use_gpu;
|
| 444 |
+
cparams.flash_attn = params.flash_attn;
|
| 445 |
+
|
| 446 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 447 |
// init audio
|
| 448 |
|
examples/main/main.cpp
CHANGED
|
@@ -70,6 +70,7 @@ struct whisper_params {
|
|
| 70 |
bool no_timestamps = false;
|
| 71 |
bool log_score = false;
|
| 72 |
bool use_gpu = true;
|
|
|
|
| 73 |
|
| 74 |
std::string language = "en";
|
| 75 |
std::string prompt;
|
|
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 168 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 169 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 170 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 171 |
-
else if (
|
|
|
|
| 172 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 173 |
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
| 174 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 234 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 235 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 236 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 237 |
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
| 238 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 239 |
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
|
@@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
|
|
| 977 |
// whisper init
|
| 978 |
|
| 979 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 980 |
-
|
|
|
|
|
|
|
| 981 |
|
| 982 |
if (!params.dtw.empty()) {
|
| 983 |
cparams.dtw_token_timestamps = true;
|
|
|
|
| 70 |
bool no_timestamps = false;
|
| 71 |
bool log_score = false;
|
| 72 |
bool use_gpu = true;
|
| 73 |
+
bool flash_attn = false;
|
| 74 |
|
| 75 |
std::string language = "en";
|
| 76 |
std::string prompt;
|
|
|
|
| 169 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 170 |
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
| 171 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 172 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 173 |
+
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
| 174 |
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
| 175 |
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
| 176 |
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
|
|
|
| 236 |
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
| 237 |
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
| 238 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 239 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
| 240 |
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
| 241 |
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
| 242 |
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
|
|
|
| 980 |
// whisper init
|
| 981 |
|
| 982 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 983 |
+
|
| 984 |
+
cparams.use_gpu = params.use_gpu;
|
| 985 |
+
cparams.flash_attn = params.flash_attn;
|
| 986 |
|
| 987 |
if (!params.dtw.empty()) {
|
| 988 |
cparams.dtw_token_timestamps = true;
|
examples/server/server.cpp
CHANGED
|
@@ -75,6 +75,7 @@ struct whisper_params {
|
|
| 75 |
bool print_progress = false;
|
| 76 |
bool no_timestamps = false;
|
| 77 |
bool use_gpu = true;
|
|
|
|
| 78 |
|
| 79 |
std::string language = "en";
|
| 80 |
std::string prompt = "";
|
|
@@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
|
| 178 |
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
| 179 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 180 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 181 |
// server params
|
| 182 |
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
| 183 |
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
|
@@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
|
|
| 502 |
}
|
| 503 |
// whisper init
|
| 504 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
| 506 |
if (!params.dtw.empty()) {
|
| 507 |
cparams.dtw_token_timestamps = true;
|
| 508 |
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
|
|
|
|
| 75 |
bool print_progress = false;
|
| 76 |
bool no_timestamps = false;
|
| 77 |
bool use_gpu = true;
|
| 78 |
+
bool flash_attn = false;
|
| 79 |
|
| 80 |
std::string language = "en";
|
| 81 |
std::string prompt = "";
|
|
|
|
| 179 |
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
| 180 |
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
| 181 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 182 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 183 |
// server params
|
| 184 |
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
| 185 |
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
|
|
|
| 504 |
}
|
| 505 |
// whisper init
|
| 506 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 507 |
+
|
| 508 |
+
cparams.use_gpu = params.use_gpu;
|
| 509 |
+
cparams.flash_attn = params.flash_attn;
|
| 510 |
+
|
| 511 |
if (!params.dtw.empty()) {
|
| 512 |
cparams.dtw_token_timestamps = true;
|
| 513 |
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
|
examples/stream/stream.cpp
CHANGED
|
@@ -36,6 +36,7 @@ struct whisper_params {
|
|
| 36 |
bool tinydiarize = false;
|
| 37 |
bool save_audio = false; // save audio to wav file
|
| 38 |
bool use_gpu = true;
|
|
|
|
| 39 |
|
| 40 |
std::string language = "en";
|
| 41 |
std::string model = "models/ggml-base.en.bin";
|
|
@@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 72 |
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
| 73 |
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
|
| 74 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 75 |
|
| 76 |
else {
|
| 77 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
|
@@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 109 |
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
| 110 |
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
|
| 111 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 112 |
fprintf(stderr, "\n");
|
| 113 |
}
|
| 114 |
|
|
@@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
|
|
| 153 |
}
|
| 154 |
|
| 155 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 156 |
-
|
|
|
|
|
|
|
| 157 |
|
| 158 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 159 |
|
|
|
|
| 36 |
bool tinydiarize = false;
|
| 37 |
bool save_audio = false; // save audio to wav file
|
| 38 |
bool use_gpu = true;
|
| 39 |
+
bool flash_attn = false;
|
| 40 |
|
| 41 |
std::string language = "en";
|
| 42 |
std::string model = "models/ggml-base.en.bin";
|
|
|
|
| 73 |
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
| 74 |
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
|
| 75 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 76 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 77 |
|
| 78 |
else {
|
| 79 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
|
|
|
| 111 |
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
| 112 |
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
|
| 113 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
|
| 114 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
|
| 115 |
fprintf(stderr, "\n");
|
| 116 |
}
|
| 117 |
|
|
|
|
| 156 |
}
|
| 157 |
|
| 158 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 159 |
+
|
| 160 |
+
cparams.use_gpu = params.use_gpu;
|
| 161 |
+
cparams.flash_attn = params.flash_attn;
|
| 162 |
|
| 163 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 164 |
|
examples/talk-llama/talk-llama.cpp
CHANGED
|
@@ -66,6 +66,7 @@ struct whisper_params {
|
|
| 66 |
bool no_timestamps = true;
|
| 67 |
bool verbose_prompt = false;
|
| 68 |
bool use_gpu = true;
|
|
|
|
| 69 |
|
| 70 |
std::string person = "Georgi";
|
| 71 |
std::string bot_name = "LLaMA";
|
|
@@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 105 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 106 |
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
|
| 107 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 108 |
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
| 109 |
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
|
| 110 |
else if (arg == "--session") { params.path_session = argv[++i]; }
|
|
@@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 123 |
}
|
| 124 |
}
|
| 125 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
| 126 |
-
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 127 |
else {
|
| 128 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 129 |
whisper_print_usage(argc, argv, params);
|
|
@@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 154 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 155 |
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
|
| 156 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 157 |
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
| 158 |
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
|
| 159 |
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
|
|
@@ -285,7 +287,9 @@ int main(int argc, char ** argv) {
|
|
| 285 |
// whisper init
|
| 286 |
|
| 287 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 288 |
-
|
|
|
|
|
|
|
| 289 |
|
| 290 |
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
| 291 |
if (!ctx_wsp) {
|
|
@@ -316,6 +320,7 @@ int main(int argc, char ** argv) {
|
|
| 316 |
lcparams.n_ctx = 2048;
|
| 317 |
lcparams.seed = 1;
|
| 318 |
lcparams.n_threads = params.n_threads;
|
|
|
|
| 319 |
|
| 320 |
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
| 321 |
|
|
|
|
| 66 |
bool no_timestamps = true;
|
| 67 |
bool verbose_prompt = false;
|
| 68 |
bool use_gpu = true;
|
| 69 |
+
bool flash_attn = false;
|
| 70 |
|
| 71 |
std::string person = "Georgi";
|
| 72 |
std::string bot_name = "LLaMA";
|
|
|
|
| 106 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 107 |
else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
|
| 108 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 109 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 110 |
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
| 111 |
else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
|
| 112 |
else if (arg == "--session") { params.path_session = argv[++i]; }
|
|
|
|
| 125 |
}
|
| 126 |
}
|
| 127 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
|
|
| 128 |
else {
|
| 129 |
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
| 130 |
whisper_print_usage(argc, argv, params);
|
|
|
|
| 155 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 156 |
fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
|
| 157 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 158 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
| 159 |
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
| 160 |
fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
|
| 161 |
fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
|
|
|
|
| 287 |
// whisper init
|
| 288 |
|
| 289 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 290 |
+
|
| 291 |
+
cparams.use_gpu = params.use_gpu;
|
| 292 |
+
cparams.flash_attn = params.flash_attn;
|
| 293 |
|
| 294 |
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
| 295 |
if (!ctx_wsp) {
|
|
|
|
| 320 |
lcparams.n_ctx = 2048;
|
| 321 |
lcparams.seed = 1;
|
| 322 |
lcparams.n_threads = params.n_threads;
|
| 323 |
+
lcparams.flash_attn = params.flash_attn;
|
| 324 |
|
| 325 |
struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
|
| 326 |
|
examples/talk/talk.cpp
CHANGED
|
@@ -32,6 +32,7 @@ struct whisper_params {
|
|
| 32 |
bool print_energy = false;
|
| 33 |
bool no_timestamps = true;
|
| 34 |
bool use_gpu = true;
|
|
|
|
| 35 |
|
| 36 |
std::string person = "Santa";
|
| 37 |
std::string language = "en";
|
|
@@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 64 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 65 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 66 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 67 |
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
| 68 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 69 |
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
|
@@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 99 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 100 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 101 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 102 |
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
| 103 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 104 |
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
|
@@ -188,7 +191,9 @@ int main(int argc, char ** argv) {
|
|
| 188 |
|
| 189 |
// whisper init
|
| 190 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
|
| 193 |
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
| 194 |
|
|
|
|
| 32 |
bool print_energy = false;
|
| 33 |
bool no_timestamps = true;
|
| 34 |
bool use_gpu = true;
|
| 35 |
+
bool flash_attn = false;
|
| 36 |
|
| 37 |
std::string person = "Santa";
|
| 38 |
std::string language = "en";
|
|
|
|
| 65 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 66 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 67 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 68 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 69 |
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
|
| 70 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 71 |
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
|
|
|
|
| 101 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 102 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 103 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 104 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
| 105 |
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
|
| 106 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 107 |
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
|
|
|
|
| 191 |
|
| 192 |
// whisper init
|
| 193 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 194 |
+
|
| 195 |
+
cparams.use_gpu = params.use_gpu;
|
| 196 |
+
cparams.flash_attn = params.flash_attn;
|
| 197 |
|
| 198 |
struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
|
| 199 |
|
examples/wchess/wchess.cmd/wchess.cmd.cpp
CHANGED
|
@@ -32,6 +32,7 @@ struct whisper_params {
|
|
| 32 |
bool print_energy = false;
|
| 33 |
bool no_timestamps = true;
|
| 34 |
bool use_gpu = true;
|
|
|
|
| 35 |
|
| 36 |
std::string language = "en";
|
| 37 |
std::string model = "models/ggml-base.en.bin";
|
|
@@ -61,6 +62,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|
| 61 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 62 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 63 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
|
|
|
| 64 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 65 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 66 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
|
@@ -92,6 +94,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|
| 92 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 93 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 94 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
|
|
|
| 95 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 96 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 97 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
@@ -183,7 +186,9 @@ int main(int argc, char ** argv) {
|
|
| 183 |
// whisper init
|
| 184 |
|
| 185 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 186 |
-
|
|
|
|
|
|
|
| 187 |
|
| 188 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 189 |
if (!ctx) {
|
|
|
|
| 32 |
bool print_energy = false;
|
| 33 |
bool no_timestamps = true;
|
| 34 |
bool use_gpu = true;
|
| 35 |
+
bool flash_attn = false;
|
| 36 |
|
| 37 |
std::string language = "en";
|
| 38 |
std::string model = "models/ggml-base.en.bin";
|
|
|
|
| 62 |
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
|
| 63 |
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
|
| 64 |
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
| 65 |
+
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during decoding\n", params.flash_attn ? "true" : "false");
|
| 66 |
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
|
| 67 |
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
| 68 |
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
|
|
|
|
| 94 |
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
| 95 |
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
|
| 96 |
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
| 97 |
+
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
| 98 |
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
| 99 |
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
| 100 |
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
|
|
|
|
| 186 |
// whisper init
|
| 187 |
|
| 188 |
struct whisper_context_params cparams = whisper_context_default_params();
|
| 189 |
+
|
| 190 |
+
cparams.use_gpu = params.use_gpu;
|
| 191 |
+
cparams.flash_attn = params.flash_attn;
|
| 192 |
|
| 193 |
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
| 194 |
if (!ctx) {
|
scripts/bench-all-gg.txt
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## M1 Pro
|
| 2 |
+
|
| 3 |
+
make -j && ./scripts/bench-all.sh 8
|
| 4 |
+
|
| 5 |
+
Running memcpy benchmark
|
| 6 |
+
|
| 7 |
+
memcpy: 39.10 GB/s (heat-up)
|
| 8 |
+
memcpy: 44.75 GB/s ( 1 thread)
|
| 9 |
+
memcpy: 44.78 GB/s ( 1 thread)
|
| 10 |
+
memcpy: 44.97 GB/s ( 2 thread)
|
| 11 |
+
memcpy: 48.04 GB/s ( 3 thread)
|
| 12 |
+
memcpy: 50.55 GB/s ( 4 thread)
|
| 13 |
+
memcpy: 55.20 GB/s ( 5 thread)
|
| 14 |
+
memcpy: 65.60 GB/s ( 6 thread)
|
| 15 |
+
memcpy: 70.64 GB/s ( 7 thread)
|
| 16 |
+
memcpy: 73.34 GB/s ( 8 thread)
|
| 17 |
+
sum: -5120002535.000000
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
make -j && ./scripts/bench-all.sh 1 0 0
|
| 21 |
+
|
| 22 |
+
Running ggml_mul_mat benchmark with 1 threads
|
| 23 |
+
|
| 24 |
+
64 x 64: Q4_0 237.1 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs)
|
| 25 |
+
64 x 64: Q5_0 136.4 GFLOPS (128 runs) | Q5_1 135.6 GFLOPS (128 runs) | Q8_0 243.1 GFLOPS (128 runs)
|
| 26 |
+
64 x 64: F16 140.4 GFLOPS (128 runs) | F32 316.6 GFLOPS (128 runs)
|
| 27 |
+
128 x 128: Q4_0 496.6 GFLOPS (128 runs) | Q4_1 348.6 GFLOPS (128 runs)
|
| 28 |
+
128 x 128: Q5_0 273.2 GFLOPS (128 runs) | Q5_1 274.1 GFLOPS (128 runs) | Q8_0 505.1 GFLOPS (128 runs)
|
| 29 |
+
128 x 128: F16 300.4 GFLOPS (128 runs) | F32 653.9 GFLOPS (128 runs)
|
| 30 |
+
256 x 256: Q4_0 791.7 GFLOPS (128 runs) | Q4_1 615.3 GFLOPS (128 runs)
|
| 31 |
+
256 x 256: Q5_0 651.0 GFLOPS (128 runs) | Q5_1 674.7 GFLOPS (128 runs) | Q8_0 803.1 GFLOPS (128 runs)
|
| 32 |
+
256 x 256: F16 869.6 GFLOPS (128 runs) | F32 957.2 GFLOPS (128 runs)
|
| 33 |
+
512 x 512: Q4_0 973.3 GFLOPS (128 runs) | Q4_1 897.9 GFLOPS (128 runs)
|
| 34 |
+
512 x 512: Q5_0 1078.8 GFLOPS (128 runs) | Q5_1 998.4 GFLOPS (128 runs) | Q8_0 752.4 GFLOPS (128 runs)
|
| 35 |
+
512 x 512: F16 892.5 GFLOPS (128 runs) | F32 1399.6 GFLOPS (128 runs)
|
| 36 |
+
1024 x 1024: Q4_0 1402.7 GFLOPS (128 runs) | Q4_1 1218.5 GFLOPS (128 runs)
|
| 37 |
+
1024 x 1024: Q5_0 1444.8 GFLOPS (128 runs) | Q5_1 1444.7 GFLOPS (128 runs) | Q8_0 1395.7 GFLOPS (128 runs)
|
| 38 |
+
1024 x 1024: F16 1524.1 GFLOPS (128 runs) | F32 1726.6 GFLOPS (128 runs)
|
| 39 |
+
2048 x 2048: Q4_0 1479.4 GFLOPS ( 87 runs) | Q4_1 1378.5 GFLOPS ( 81 runs)
|
| 40 |
+
2048 x 2048: Q5_0 1454.6 GFLOPS ( 85 runs) | Q5_1 1462.9 GFLOPS ( 86 runs) | Q8_0 1483.2 GFLOPS ( 87 runs)
|
| 41 |
+
2048 x 2048: F16 1488.0 GFLOPS ( 87 runs) | F32 1538.2 GFLOPS ( 90 runs)
|
| 42 |
+
4096 x 4096: Q4_0 1509.7 GFLOPS ( 11 runs) | Q4_1 1433.0 GFLOPS ( 11 runs)
|
| 43 |
+
4096 x 4096: Q5_0 1422.4 GFLOPS ( 11 runs) | Q5_1 1437.0 GFLOPS ( 11 runs) | Q8_0 1523.0 GFLOPS ( 12 runs)
|
| 44 |
+
4096 x 4096: F16 1551.3 GFLOPS ( 12 runs) | F32 1451.0 GFLOPS ( 11 runs)
|
| 45 |
+
|
| 46 |
+
| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 47 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 48 |
+
| M1 Pro | METAL | tiny | 1 | 0 | 39.21 | 1.74 | 0.61 | 0.04 | 22c96b4 |
|
| 49 |
+
| M1 Pro | METAL | base | 1 | 0 | 70.76 | 2.60 | 0.93 | 0.06 | 22c96b4 |
|
| 50 |
+
| M1 Pro | METAL | small | 1 | 0 | 217.28 | 6.42 | 2.14 | 0.17 | 22c96b4 |
|
| 51 |
+
| M1 Pro | METAL | medium | 1 | 0 | 596.74 | 14.43 | 4.75 | 0.45 | 22c96b4 |
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
make -j && ./scripts/bench-all.sh 1 1 1
|
| 55 |
+
|
| 56 |
+
| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 57 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 58 |
+
| M1 Pro | METAL | tiny | 1 | 1 | 30.77 | 1.59 | 0.54 | 0.03 | 22c96b4 |
|
| 59 |
+
| M1 Pro | METAL | base | 1 | 1 | 60.42 | 2.29 | 0.81 | 0.05 | 22c96b4 |
|
| 60 |
+
| M1 Pro | METAL | small | 1 | 1 | 183.82 | 5.12 | 1.81 | 0.14 | 22c96b4 |
|
| 61 |
+
| M1 Pro | METAL | medium | 1 | 1 | 517.92 | 11.60 | 4.01 | 0.38 | 22c96b4 |
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## M2 Ultra
|
| 65 |
+
|
| 66 |
+
make -j && ./scripts/bench-all.sh 8
|
| 67 |
+
|
| 68 |
+
Running memcpy benchmark
|
| 69 |
+
|
| 70 |
+
memcpy: 46.58 GB/s (heat-up)
|
| 71 |
+
memcpy: 54.16 GB/s ( 1 thread)
|
| 72 |
+
memcpy: 54.23 GB/s ( 1 thread)
|
| 73 |
+
memcpy: 99.63 GB/s ( 2 thread)
|
| 74 |
+
memcpy: 140.59 GB/s ( 3 thread)
|
| 75 |
+
memcpy: 176.52 GB/s ( 4 thread)
|
| 76 |
+
memcpy: 158.90 GB/s ( 5 thread)
|
| 77 |
+
memcpy: 163.00 GB/s ( 6 thread)
|
| 78 |
+
memcpy: 189.69 GB/s ( 7 thread)
|
| 79 |
+
memcpy: 197.15 GB/s ( 8 thread)
|
| 80 |
+
sum: -5120002007.000000
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
make -j && ./scripts/bench-all.sh 1
|
| 84 |
+
|
| 85 |
+
Running ggml_mul_mat benchmark with 1 threads
|
| 86 |
+
|
| 87 |
+
64 x 64: Q4_0 245.8 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs)
|
| 88 |
+
64 x 64: Q5_0 115.7 GFLOPS (128 runs) | Q5_1 125.9 GFLOPS (128 runs) | Q8_0 215.8 GFLOPS (128 runs)
|
| 89 |
+
64 x 64: F16 139.5 GFLOPS (128 runs) | F32 337.2 GFLOPS (128 runs)
|
| 90 |
+
128 x 128: Q4_0 494.8 GFLOPS (128 runs) | Q4_1 350.4 GFLOPS (128 runs)
|
| 91 |
+
128 x 128: Q5_0 257.1 GFLOPS (128 runs) | Q5_1 261.4 GFLOPS (128 runs) | Q8_0 509.4 GFLOPS (128 runs)
|
| 92 |
+
128 x 128: F16 302.3 GFLOPS (128 runs) | F32 672.8 GFLOPS (128 runs)
|
| 93 |
+
256 x 256: Q4_0 795.7 GFLOPS (128 runs) | Q4_1 663.7 GFLOPS (128 runs)
|
| 94 |
+
256 x 256: Q5_0 737.8 GFLOPS (128 runs) | Q5_1 757.6 GFLOPS (128 runs) | Q8_0 827.7 GFLOPS (128 runs)
|
| 95 |
+
256 x 256: F16 872.6 GFLOPS (128 runs) | F32 956.3 GFLOPS (128 runs)
|
| 96 |
+
512 x 512: Q4_0 1188.0 GFLOPS (128 runs) | Q4_1 1085.0 GFLOPS (128 runs)
|
| 97 |
+
512 x 512: Q5_0 1421.1 GFLOPS (128 runs) | Q5_1 1454.9 GFLOPS (128 runs) | Q8_0 1191.4 GFLOPS (128 runs)
|
| 98 |
+
512 x 512: F16 1577.4 GFLOPS (128 runs) | F32 1982.0 GFLOPS (128 runs)
|
| 99 |
+
1024 x 1024: Q4_0 2342.6 GFLOPS (128 runs) | Q4_1 1955.8 GFLOPS (128 runs)
|
| 100 |
+
1024 x 1024: Q5_0 2306.7 GFLOPS (128 runs) | Q5_1 2217.0 GFLOPS (128 runs) | Q8_0 2230.7 GFLOPS (128 runs)
|
| 101 |
+
1024 x 1024: F16 2593.8 GFLOPS (128 runs) | F32 3269.0 GFLOPS (128 runs)
|
| 102 |
+
2048 x 2048: Q4_0 3735.7 GFLOPS (128 runs) | Q4_1 3205.3 GFLOPS (128 runs)
|
| 103 |
+
2048 x 2048: Q5_0 3584.5 GFLOPS (128 runs) | Q5_1 3621.7 GFLOPS (128 runs) | Q8_0 3622.3 GFLOPS (128 runs)
|
| 104 |
+
2048 x 2048: F16 3763.6 GFLOPS (128 runs) | F32 4153.3 GFLOPS (128 runs)
|
| 105 |
+
4096 x 4096: Q4_0 3891.1 GFLOPS ( 29 runs) | Q4_1 3554.0 GFLOPS ( 26 runs)
|
| 106 |
+
4096 x 4096: Q5_0 3753.1 GFLOPS ( 28 runs) | Q5_1 3750.1 GFLOPS ( 28 runs) | Q8_0 3768.5 GFLOPS ( 28 runs)
|
| 107 |
+
4096 x 4096: F16 3864.2 GFLOPS ( 29 runs) | F32 3970.5 GFLOPS ( 29 runs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
make -j && ./scripts/bench-all.sh 1 1 0
|
| 111 |
+
|
| 112 |
+
| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 113 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 114 |
+
| M2 ULTRA | METAL | tiny | 1 | 0 | 12.32 | 1.35 | 0.49 | 0.01 | 22c96b4 |
|
| 115 |
+
| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 11.65 | 1.30 | 0.51 | 0.01 | 22c96b4 |
|
| 116 |
+
| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 12.08 | 1.30 | 0.51 | 0.01 | 22c96b4 |
|
| 117 |
+
| M2 ULTRA | METAL | base | 1 | 0 | 17.58 | 1.90 | 0.76 | 0.02 | 22c96b4 |
|
| 118 |
+
| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 18.89 | 1.86 | 0.79 | 0.02 | 22c96b4 |
|
| 119 |
+
| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 20.69 | 1.88 | 0.79 | 0.02 | 22c96b4 |
|
| 120 |
+
| M2 ULTRA | METAL | small | 1 | 0 | 49.32 | 3.85 | 1.71 | 0.05 | 22c96b4 |
|
| 121 |
+
| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 54.91 | 3.81 | 1.82 | 0.06 | 22c96b4 |
|
| 122 |
+
| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 54.92 | 3.81 | 1.79 | 0.06 | 22c96b4 |
|
| 123 |
+
| M2 ULTRA | METAL | medium | 1 | 0 | 134.34 | 8.04 | 3.82 | 0.13 | 22c96b4 |
|
| 124 |
+
| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 151.68 | 7.59 | 4.07 | 0.14 | 22c96b4 |
|
| 125 |
+
| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 151.58 | 7.67 | 4.07 | 0.14 | 22c96b4 |
|
| 126 |
+
| M2 ULTRA | METAL | medium-dis | 1 | 0 | 120.82 | 1.07 | 0.41 | 0.02 | 22c96b4 |
|
| 127 |
+
| M2 ULTRA | METAL | large-v2 | 1 | 0 | 235.63 | 12.27 | 5.85 | 0.22 | 22c96b4 |
|
| 128 |
+
| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 273.38 | 11.17 | 6.40 | 0.26 | 22c96b4 |
|
| 129 |
+
| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 272.44 | 11.32 | 6.29 | 0.26 | 22c96b4 |
|
| 130 |
+
| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 212.51 | 1.20 | 0.47 | 0.02 | 22c96b4 |
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
make -j && ./scripts/bench-all.sh 1 1 1
|
| 134 |
+
|
| 135 |
+
| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 136 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 137 |
+
| M2 ULTRA | METAL | tiny | 1 | 1 | 9.07 | 1.33 | 0.45 | 0.01 | 22c96b4 |
|
| 138 |
+
| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 9.74 | 1.33 | 0.47 | 0.01 | 22c96b4 |
|
| 139 |
+
| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 8.93 | 1.31 | 0.46 | 0.01 | 22c96b4 |
|
| 140 |
+
| M2 ULTRA | METAL | base | 1 | 1 | 15.75 | 1.87 | 0.71 | 0.02 | 22c96b4 |
|
| 141 |
+
| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 17.04 | 1.83 | 0.74 | 0.02 | 22c96b4 |
|
| 142 |
+
| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 17.17 | 1.83 | 0.74 | 0.02 | 22c96b4 |
|
| 143 |
+
| M2 ULTRA | METAL | small | 1 | 1 | 42.33 | 3.64 | 1.60 | 0.05 | 22c96b4 |
|
| 144 |
+
| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 47.61 | 3.63 | 1.70 | 0.05 | 22c96b4 |
|
| 145 |
+
| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 47.70 | 3.66 | 1.68 | 0.05 | 22c96b4 |
|
| 146 |
+
| M2 ULTRA | METAL | medium | 1 | 1 | 114.42 | 7.53 | 3.55 | 0.11 | 22c96b4 |
|
| 147 |
+
| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 132.63 | 7.02 | 3.77 | 0.13 | 22c96b4 |
|
| 148 |
+
| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 132.28 | 7.10 | 3.76 | 0.13 | 22c96b4 |
|
| 149 |
+
| M2 ULTRA | METAL | medium-dis | 1 | 1 | 102.34 | 1.01 | 0.42 | 0.01 | 22c96b4 |
|
| 150 |
+
| M2 ULTRA | METAL | large-v2 | 1 | 1 | 203.01 | 11.03 | 5.45 | 0.20 | 22c96b4 |
|
| 151 |
+
| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 240.05 | 10.18 | 5.98 | 0.23 | 22c96b4 |
|
| 152 |
+
| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 239.22 | 10.23 | 5.87 | 0.23 | 22c96b4 |
|
| 153 |
+
| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 181.14 | 1.14 | 0.48 | 0.02 | 22c96b4 |
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
## Ryzen 9 5950X + RTX 2060
|
| 158 |
+
|
| 159 |
+
make -j && ./scripts/bench-all.sh 8 0 0
|
| 160 |
+
|
| 161 |
+
Running memcpy benchmark
|
| 162 |
+
|
| 163 |
+
memcpy: 12.36 GB/s (heat-up)
|
| 164 |
+
memcpy: 12.33 GB/s ( 1 thread)
|
| 165 |
+
memcpy: 12.38 GB/s ( 1 thread)
|
| 166 |
+
memcpy: 14.48 GB/s ( 2 thread)
|
| 167 |
+
memcpy: 15.00 GB/s ( 3 thread)
|
| 168 |
+
memcpy: 14.77 GB/s ( 4 thread)
|
| 169 |
+
memcpy: 14.60 GB/s ( 5 thread)
|
| 170 |
+
memcpy: 14.57 GB/s ( 6 thread)
|
| 171 |
+
memcpy: 14.34 GB/s ( 7 thread)
|
| 172 |
+
memcpy: 14.40 GB/s ( 8 thread)
|
| 173 |
+
sum: -5119998076.000000
|
| 174 |
+
|
| 175 |
+
Running ggml_mul_mat benchmark with 8 threads
|
| 176 |
+
|
| 177 |
+
64 x 64: Q4_0 3.1 GFLOPS (128 runs) | Q4_1 3.1 GFLOPS (128 runs)
|
| 178 |
+
64 x 64: Q5_0 3.0 GFLOPS (128 runs) | Q5_1 2.9 GFLOPS (128 runs) | Q8_0 3.1 GFLOPS (128 runs)
|
| 179 |
+
64 x 64: F16 3.0 GFLOPS (128 runs) | F32 3.0 GFLOPS (128 runs)
|
| 180 |
+
128 x 128: Q4_0 21.1 GFLOPS (128 runs) | Q4_1 20.3 GFLOPS (128 runs)
|
| 181 |
+
128 x 128: Q5_0 20.6 GFLOPS (128 runs) | Q5_1 20.4 GFLOPS (128 runs) | Q8_0 22.1 GFLOPS (128 runs)
|
| 182 |
+
128 x 128: F16 21.7 GFLOPS (128 runs) | F32 21.7 GFLOPS (128 runs)
|
| 183 |
+
256 x 256: Q4_0 105.7 GFLOPS (128 runs) | Q4_1 94.4 GFLOPS (128 runs)
|
| 184 |
+
256 x 256: Q5_0 94.8 GFLOPS (128 runs) | Q5_1 87.5 GFLOPS (128 runs) | Q8_0 107.2 GFLOPS (128 runs)
|
| 185 |
+
256 x 256: F16 95.1 GFLOPS (128 runs) | F32 94.3 GFLOPS (128 runs)
|
| 186 |
+
512 x 512: Q4_0 214.7 GFLOPS (128 runs) | Q4_1 189.8 GFLOPS (128 runs)
|
| 187 |
+
512 x 512: Q5_0 187.7 GFLOPS (128 runs) | Q5_1 176.2 GFLOPS (128 runs) | Q8_0 252.2 GFLOPS (128 runs)
|
| 188 |
+
512 x 512: F16 220.8 GFLOPS (128 runs) | F32 218.3 GFLOPS (128 runs)
|
| 189 |
+
1024 x 1024: Q4_0 333.7 GFLOPS (128 runs) | Q4_1 305.8 GFLOPS (128 runs)
|
| 190 |
+
1024 x 1024: Q5_0 283.2 GFLOPS (128 runs) | Q5_1 268.2 GFLOPS (125 runs) | Q8_0 394.1 GFLOPS (128 runs)
|
| 191 |
+
1024 x 1024: F16 355.0 GFLOPS (128 runs) | F32 313.0 GFLOPS (128 runs)
|
| 192 |
+
2048 x 2048: Q4_0 395.0 GFLOPS ( 23 runs) | Q4_1 380.6 GFLOPS ( 23 runs)
|
| 193 |
+
2048 x 2048: Q5_0 336.6 GFLOPS ( 20 runs) | Q5_1 318.4 GFLOPS ( 19 runs) | Q8_0 482.6 GFLOPS ( 29 runs)
|
| 194 |
+
2048 x 2048: F16 424.5 GFLOPS ( 25 runs) | F32 337.7 GFLOPS ( 20 runs)
|
| 195 |
+
4096 x 4096: Q4_0 412.8 GFLOPS ( 4 runs) | Q4_1 405.1 GFLOPS ( 3 runs)
|
| 196 |
+
4096 x 4096: Q5_0 346.0 GFLOPS ( 3 runs) | Q5_1 334.6 GFLOPS ( 3 runs) | Q8_0 502.6 GFLOPS ( 4 runs)
|
| 197 |
+
4096 x 4096: F16 412.5 GFLOPS ( 4 runs) | F32 274.0 GFLOPS ( 3 runs)
|
| 198 |
+
|
| 199 |
+
| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 200 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 201 |
+
| Ryzen 9 5950X | AVX2 | tiny | 8 | 0 | 195.29 | 1.57 | 0.51 | 0.26 | 22c96b4 |
|
| 202 |
+
| Ryzen 9 5950X | AVX2 | tiny-q5_0 | 8 | 0 | 213.33 | 1.10 | 0.50 | 0.30 | 22c96b4 |
|
| 203 |
+
| Ryzen 9 5950X | AVX2 | tiny-q5_1 | 8 | 0 | 219.38 | 1.18 | 0.53 | 0.32 | 22c96b4 |
|
| 204 |
+
| Ryzen 9 5950X | AVX2 | base | 8 | 0 | 424.85 | 3.71 | 1.03 | 0.46 | 22c96b4 |
|
| 205 |
+
| Ryzen 9 5950X | AVX2 | base-q5_0 | 8 | 0 | 473.61 | 1.81 | 0.82 | 0.52 | 22c96b4 |
|
| 206 |
+
| Ryzen 9 5950X | AVX2 | base-q5_1 | 8 | 0 | 484.14 | 1.92 | 0.85 | 0.56 | 22c96b4 |
|
| 207 |
+
| Ryzen 9 5950X | AVX2 | small | 8 | 0 | 1458.32 | 12.66 | 3.09 | 1.26 | 22c96b4 |
|
| 208 |
+
| Ryzen 9 5950X | AVX2 | small-q5_0 | 8 | 0 | 1673.22 | 6.42 | 2.18 | 1.45 | 22c96b4 |
|
| 209 |
+
| Ryzen 9 5950X | AVX2 | small-q5_1 | 8 | 0 | 1724.78 | 6.72 | 2.32 | 1.52 | 22c96b4 |
|
| 210 |
+
| Ryzen 9 5950X | AVX2 | medium | 8 | 0 | 4333.87 | 36.80 | 8.56 | 3.37 | 22c96b4 |
|
| 211 |
+
| Ryzen 9 5950X | AVX2 | medium-q5_0 | 8 | 0 | 5194.09 | 19.21 | 5.71 | 3.97 | 22c96b4 |
|
| 212 |
+
| Ryzen 9 5950X | AVX2 | medium-q5_1 | 8 | 0 | 5450.39 | 20.01 | 5.99 | 4.17 | 22c96b4 |
|
| 213 |
+
| Ryzen 9 5950X | AVX2 | medium-dis | 8 | 0 | 3995.19 | 5.08 | 1.21 | 0.55 | 22c96b4 |
|
| 214 |
+
| Ryzen 9 5950X | AVX2 | large-v2 | 8 | 0 | 8056.16 | 69.74 | 16.11 | 6.13 | 22c96b4 |
|
| 215 |
+
| Ryzen 9 5950X | AVX2 | large-v2-q5_0 | 8 | 0 | 9799.58 | 35.16 | 10.49 | 7.28 | 22c96b4 |
|
| 216 |
+
| Ryzen 9 5950X | AVX2 | large-v2-q5_1 | 8 | 0 | ms | 36.74 | 11.02 | 7.65 | 22c96b4 |
|
| 217 |
+
| Ryzen 9 5950X | AVX2 | large-v2-dis | 8 | 0 | 7490.03 | 7.40 | 1.70 | 0.72 | 22c96b4 |
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
|
| 221 |
+
|
| 222 |
+
| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 223 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 224 |
+
| RTX 2060 | AVX2 CUDA | tiny | 8 | 0 | 12.54 | 0.93 | 0.29 | 0.02 | 22c96b4 |
|
| 225 |
+
| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 0 | 12.73 | 0.98 | 0.24 | 0.02 | 22c96b4 |
|
| 226 |
+
| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 0 | 12.72 | 0.99 | 0.24 | 0.02 | 22c96b4 |
|
| 227 |
+
| RTX 2060 | AVX2 CUDA | base | 8 | 0 | 24.14 | 1.28 | 0.41 | 0.03 | 22c96b4 |
|
| 228 |
+
| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 0 | 24.58 | 1.38 | 0.35 | 0.03 | 22c96b4 |
|
| 229 |
+
| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 0 | 24.58 | 1.37 | 0.35 | 0.03 | 22c96b4 |
|
| 230 |
+
| RTX 2060 | AVX2 CUDA | small | 8 | 0 | 74.70 | 2.91 | 0.84 | 0.07 | 22c96b4 |
|
| 231 |
+
| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 0 | 76.12 | 2.84 | 0.77 | 0.08 | 22c96b4 |
|
| 232 |
+
| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 0 | 76.14 | 2.84 | 0.76 | 0.08 | 22c96b4 |
|
| 233 |
+
| RTX 2060 | AVX2 CUDA | medium | 8 | 0 | 200.69 | 6.46 | 1.83 | 0.17 | 22c96b4 |
|
| 234 |
+
| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 0 | 204.80 | 5.90 | 1.65 | 0.19 | 22c96b4 |
|
| 235 |
+
| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 0 | 205.61 | 5.85 | 1.61 | 0.19 | 22c96b4 |
|
| 236 |
+
| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 0 | 186.17 | 0.86 | 0.24 | 0.02 | 22c96b4 |
|
| 237 |
+
| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 0 | 347.22 | 10.36 | 2.82 | 0.29 | 22c96b4 |
|
| 238 |
+
| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 0 | 357.06 | 8.81 | 2.58 | 0.34 | 22c96b4 |
|
| 239 |
+
| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 0 | 356.97 | 8.62 | 2.49 | 0.33 | 22c96b4 |
|
| 240 |
+
| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 0 | 318.05 | 1.03 | 0.34 | 0.04 | 22c96b4 |
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
|
| 244 |
+
|
| 245 |
+
| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 246 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 247 |
+
| RTX 2060 | AVX2 CUDA | tiny | 8 | 1 | 7.21 | 0.76 | 0.29 | 0.02 | 22c96b4 |
|
| 248 |
+
| RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 1 | 7.42 | 0.82 | 0.18 | 0.02 | 22c96b4 |
|
| 249 |
+
| RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 1 | 7.38 | 0.82 | 0.18 | 0.02 | 22c96b4 |
|
| 250 |
+
| RTX 2060 | AVX2 CUDA | base | 8 | 1 | 13.49 | 1.04 | 0.36 | 0.02 | 22c96b4 |
|
| 251 |
+
| RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 1 | 13.94 | 1.13 | 0.26 | 0.03 | 22c96b4 |
|
| 252 |
+
| RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 1 | 13.94 | 1.14 | 0.26 | 0.03 | 22c96b4 |
|
| 253 |
+
| RTX 2060 | AVX2 CUDA | small | 8 | 1 | 42.81 | 2.33 | 0.69 | 0.05 | 22c96b4 |
|
| 254 |
+
| RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 1 | 44.43 | 2.25 | 0.59 | 0.06 | 22c96b4 |
|
| 255 |
+
| RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 1 | 44.11 | 2.24 | 0.58 | 0.06 | 22c96b4 |
|
| 256 |
+
| RTX 2060 | AVX2 CUDA | medium | 8 | 1 | 115.47 | 5.17 | 1.45 | 0.11 | 22c96b4 |
|
| 257 |
+
| RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 1 | 120.37 | 4.63 | 1.25 | 0.13 | 22c96b4 |
|
| 258 |
+
| RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 1 | 120.28 | 4.55 | 1.21 | 0.13 | 22c96b4 |
|
| 259 |
+
| RTX 2060 | AVX2 CUDA | medium-dis | 8 | 1 | 101.69 | 0.75 | 0.20 | 0.02 | 22c96b4 |
|
| 260 |
+
| RTX 2060 | AVX2 CUDA | large-v2 | 8 | 1 | 205.67 | 8.49 | 2.19 | 0.18 | 22c96b4 |
|
| 261 |
+
| RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 1 | 214.07 | 6.88 | 1.94 | 0.22 | 22c96b4 |
|
| 262 |
+
| RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 1 | 213.98 | 6.70 | 1.86 | 0.22 | 22c96b4 |
|
| 263 |
+
| RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 1 | 176.71 | 0.91 | 0.31 | 0.03 | 22c96b4 |
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# V100
|
| 269 |
+
|
| 270 |
+
WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
|
| 271 |
+
|
| 272 |
+
| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 273 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 274 |
+
| V100 | AVX2 CUDA | tiny | 1 | 0 | 6.21 | 1.11 | 0.30 | 0.02 | 22c96b4 |
|
| 275 |
+
| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 0 | 5.97 | 1.10 | 0.26 | 0.02 | 22c96b4 |
|
| 276 |
+
| V100 | AVX2 CUDA | base | 1 | 0 | 10.95 | 1.47 | 0.42 | 0.03 | 22c96b4 |
|
| 277 |
+
| V100 | AVX2 CUDA | base-q5_1 | 1 | 0 | 11.13 | 1.53 | 0.36 | 0.03 | 22c96b4 |
|
| 278 |
+
| V100 | AVX2 CUDA | small | 1 | 0 | 31.57 | 2.96 | 0.84 | 0.05 | 22c96b4 |
|
| 279 |
+
| V100 | AVX2 CUDA | small-q5_1 | 1 | 0 | 32.19 | 3.14 | 0.75 | 0.05 | 22c96b4 |
|
| 280 |
+
| V100 | AVX2 CUDA | medium | 1 | 0 | 85.88 | 6.49 | 1.80 | 0.10 | 22c96b4 |
|
| 281 |
+
| V100 | AVX2 CUDA | medium-q5_0 | 1 | 0 | 87.53 | 5.82 | 1.37 | 0.10 | 22c96b4 |
|
| 282 |
+
| V100 | AVX2 CUDA | large-v2 | 1 | 0 | 142.23 | 8.92 | 2.62 | 0.15 | 22c96b4 |
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
|
| 286 |
+
|
| 287 |
+
| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
|
| 288 |
+
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
|
| 289 |
+
| V100 | AVX2 CUDA | tiny | 1 | 1 | 3.96 | 0.82 | 0.24 | 0.02 | 22c96b4 |
|
| 290 |
+
| V100 | AVX2 CUDA | tiny-q5_1 | 1 | 1 | 4.05 | 0.85 | 0.18 | 0.02 | 22c96b4 |
|
| 291 |
+
| V100 | AVX2 CUDA | base | 1 | 1 | 7.21 | 1.16 | 0.36 | 0.02 | 22c96b4 |
|
| 292 |
+
| V100 | AVX2 CUDA | base-q5_1 | 1 | 1 | 7.39 | 1.21 | 0.26 | 0.02 | 22c96b4 |
|
| 293 |
+
| V100 | AVX2 CUDA | small | 1 | 1 | 19.81 | 2.41 | 0.71 | 0.04 | 22c96b4 |
|
| 294 |
+
| V100 | AVX2 CUDA | small-q5_1 | 1 | 1 | 20.50 | 2.31 | 0.51 | 0.04 | 22c96b4 |
|
| 295 |
+
| V100 | AVX2 CUDA | medium | 1 | 1 | 56.02 | 4.89 | 1.44 | 0.07 | 22c96b4 |
|
| 296 |
+
| V100 | AVX2 CUDA | medium-q5_0 | 1 | 1 | 57.85 | 4.73 | 1.09 | 0.08 | 22c96b4 |
|
| 297 |
+
| V100 | AVX2 CUDA | large-v2 | 1 | 1 | 92.73 | 7.18 | 2.14 | 0.10 | 22c96b4 |
|
| 298 |
+
|
scripts/bench-all.sh
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
|
| 3 |
# Helper script to run the bench tool on all models and print the results in share-able format
|
| 4 |
|
| 5 |
-
printf "Usage: ./bench.sh [n_threads] [encoder-only]\n"
|
| 6 |
|
| 7 |
if [ -z "$1" ]; then
|
| 8 |
n_threads=4
|
|
@@ -11,12 +11,19 @@ else
|
|
| 11 |
fi
|
| 12 |
|
| 13 |
encoder_only=0
|
| 14 |
-
if [ -z "$2" ]; then
|
| 15 |
encoder_only=0
|
| 16 |
else
|
| 17 |
encoder_only=$2
|
| 18 |
fi
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
models=( \
|
| 21 |
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
| 22 |
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
|
@@ -44,13 +51,19 @@ if [ "$encoder_only" -eq 0 ]; then
|
|
| 44 |
printf "\n"
|
| 45 |
fi
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
for model in "${models[@]}"; do
|
| 51 |
# actual run
|
| 52 |
# store stderr output in a variable in order to parse it later
|
| 53 |
-
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
|
| 54 |
ret=$?
|
| 55 |
|
| 56 |
# parse the output:
|
|
@@ -95,6 +108,6 @@ for model in "${models[@]}"; do
|
|
| 95 |
commit=$(git rev-parse --short HEAD)
|
| 96 |
|
| 97 |
if [ $ret -eq 0 ]; then
|
| 98 |
-
printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
| 99 |
fi
|
| 100 |
done
|
|
|
|
| 2 |
|
| 3 |
# Helper script to run the bench tool on all models and print the results in share-able format
|
| 4 |
|
| 5 |
+
printf "Usage: ./bench.sh [n_threads] [encoder-only] [flash-attn]\n"
|
| 6 |
|
| 7 |
if [ -z "$1" ]; then
|
| 8 |
n_threads=4
|
|
|
|
| 11 |
fi
|
| 12 |
|
| 13 |
encoder_only=0
|
| 14 |
+
if [ -z "$2" ] || [ "$2" -eq 0 ]; then
|
| 15 |
encoder_only=0
|
| 16 |
else
|
| 17 |
encoder_only=$2
|
| 18 |
fi
|
| 19 |
|
| 20 |
+
fattn=""
|
| 21 |
+
if [ -z "$3" ] || [ "$3" -eq 0 ]; then
|
| 22 |
+
fattn=""
|
| 23 |
+
else
|
| 24 |
+
fattn="-fa"
|
| 25 |
+
fi
|
| 26 |
+
|
| 27 |
models=( \
|
| 28 |
"tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
|
| 29 |
"base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
|
|
|
|
| 51 |
printf "\n"
|
| 52 |
fi
|
| 53 |
|
| 54 |
+
if [ "$fattn" == "-fa" ]; then
|
| 55 |
+
fattn_i=1
|
| 56 |
+
else
|
| 57 |
+
fattn_i=0
|
| 58 |
+
fi
|
| 59 |
+
|
| 60 |
+
printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "FA" "Enc." "Dec." "Bch5" "PP" "Commit"
|
| 61 |
+
printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
|
| 62 |
|
| 63 |
for model in "${models[@]}"; do
|
| 64 |
# actual run
|
| 65 |
# store stderr output in a variable in order to parse it later
|
| 66 |
+
output=$(./bench -m ./models/ggml-$model.bin -t $n_threads $fattn 2>&1)
|
| 67 |
ret=$?
|
| 68 |
|
| 69 |
# parse the output:
|
|
|
|
| 108 |
commit=$(git rev-parse --short HEAD)
|
| 109 |
|
| 110 |
if [ $ret -eq 0 ]; then
|
| 111 |
+
printf "| <todo> | <todo> | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$fattn_i" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
|
| 112 |
fi
|
| 113 |
done
|
whisper.cpp
CHANGED
|
@@ -809,14 +809,15 @@ struct whisper_state {
|
|
| 809 |
// shared between all decoders
|
| 810 |
whisper_kv_cache kv_cross;
|
| 811 |
|
|
|
|
|
|
|
|
|
|
| 812 |
whisper_mel mel;
|
| 813 |
|
| 814 |
whisper_batch batch;
|
| 815 |
|
| 816 |
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
| 817 |
|
| 818 |
-
ggml_backend_t backend = nullptr;
|
| 819 |
-
|
| 820 |
// ggml-alloc:
|
| 821 |
// - stores meta info about the intermediate tensors into the `meta` buffers
|
| 822 |
// - stores the actual tensor data into the `data` buffers
|
|
@@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|
| 902 |
}
|
| 903 |
|
| 904 |
static bool kv_cache_init(
|
| 905 |
-
const struct whisper_hparams & hparams,
|
| 906 |
struct whisper_kv_cache & cache,
|
| 907 |
ggml_backend_t backend,
|
| 908 |
ggml_type wtype,
|
|
|
|
|
|
|
| 909 |
int n_ctx) {
|
| 910 |
-
const int64_t n_text_state = hparams.n_text_state;
|
| 911 |
-
const int64_t n_text_layer = hparams.n_text_layer;
|
| 912 |
-
|
| 913 |
const int64_t n_mem = n_text_layer*n_ctx;
|
| 914 |
const int64_t n_elements = n_text_state*n_mem;
|
| 915 |
|
|
@@ -941,6 +940,8 @@ static bool kv_cache_init(
|
|
| 941 |
return false;
|
| 942 |
}
|
| 943 |
|
|
|
|
|
|
|
| 944 |
return true;
|
| 945 |
}
|
| 946 |
|
|
@@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp(
|
|
| 1068 |
}
|
| 1069 |
}
|
| 1070 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 1072 |
static bool aheads_masks_init(
|
| 1073 |
const whisper_context_params & cparams,
|
|
@@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
| 1872 |
const int n_head = hparams.n_audio_head;
|
| 1873 |
const int n_layer = hparams.n_audio_layer;
|
| 1874 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1875 |
struct ggml_init_params params = {
|
| 1876 |
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
| 1877 |
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
|
@@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
| 1884 |
|
| 1885 |
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
| 1886 |
|
| 1887 |
-
const float KQscale = 1.0f/sqrtf(float(
|
| 1888 |
|
| 1889 |
// ===================================================================
|
| 1890 |
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
@@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
| 1934 |
|
| 1935 |
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
| 1936 |
|
| 1937 |
-
//Qcur = ggml_scale(ctx0, Qcur, pow(float(
|
| 1938 |
|
| 1939 |
// note: no bias for Key
|
| 1940 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
| 1941 |
layer.attn_k_w,
|
| 1942 |
cur);
|
| 1943 |
|
| 1944 |
-
//Kcur = ggml_scale(ctx0, Kcur, pow(float(
|
| 1945 |
|
| 1946 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
| 1947 |
layer.attn_v_w,
|
|
@@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|
| 1955 |
ggml_permute(ctx0,
|
| 1956 |
ggml_cpy(ctx0,
|
| 1957 |
Qcur,
|
| 1958 |
-
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32,
|
| 1959 |
0, 2, 1, 3);
|
| 1960 |
|
| 1961 |
-
|
| 1962 |
-
|
| 1963 |
-
|
| 1964 |
-
Kcur,
|
| 1965 |
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
| 1966 |
-
0, 2, 1, 3);
|
| 1967 |
-
|
| 1968 |
-
// K * Q
|
| 1969 |
-
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 1970 |
|
| 1971 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1972 |
|
| 1973 |
-
|
| 1974 |
-
|
| 1975 |
-
|
| 1976 |
-
|
| 1977 |
-
|
| 1978 |
-
|
| 1979 |
-
1, 2, 0, 3),
|
| 1980 |
-
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
| 1981 |
-
);
|
| 1982 |
|
| 1983 |
-
|
| 1984 |
|
| 1985 |
-
|
| 1986 |
-
|
| 1987 |
-
|
| 1988 |
-
|
| 1989 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1990 |
}
|
| 1991 |
|
| 1992 |
// projection
|
|
@@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
| 2085 |
const int n_state = hparams.n_audio_state;
|
| 2086 |
const int n_head = hparams.n_audio_head;
|
| 2087 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2088 |
struct ggml_init_params params = {
|
| 2089 |
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
| 2090 |
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
|
@@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
| 2097 |
|
| 2098 |
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
| 2099 |
|
| 2100 |
-
const float Kscale = pow(float(
|
| 2101 |
|
| 2102 |
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
| 2103 |
auto & layer = model.layers_decoder[il];
|
| 2104 |
|
| 2105 |
-
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
| 2106 |
layer.cross_attn_k_w,
|
| 2107 |
cur);
|
| 2108 |
|
| 2109 |
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
| 2110 |
|
| 2111 |
-
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
| 2112 |
layer.cross_attn_v_w,
|
| 2113 |
cur);
|
| 2114 |
|
|
@@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|
| 2116 |
Vcross,
|
| 2117 |
layer.cross_attn_v_b);
|
| 2118 |
|
| 2119 |
-
|
|
|
|
| 2120 |
|
| 2121 |
-
|
| 2122 |
-
|
| 2123 |
-
|
| 2124 |
|
| 2125 |
-
|
| 2126 |
-
|
| 2127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2128 |
|
| 2129 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
| 2130 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
|
@@ -2195,7 +2261,7 @@ static bool whisper_encode_internal(
|
|
| 2195 |
}
|
| 2196 |
|
| 2197 |
if (!whisper_encode_external(wstate)) {
|
| 2198 |
-
if (!ggml_graph_compute_helper(
|
| 2199 |
return false;
|
| 2200 |
}
|
| 2201 |
} else {
|
|
@@ -2218,7 +2284,7 @@ static bool whisper_encode_internal(
|
|
| 2218 |
return false;
|
| 2219 |
}
|
| 2220 |
|
| 2221 |
-
if (!ggml_graph_compute_helper(
|
| 2222 |
return false;
|
| 2223 |
}
|
| 2224 |
}
|
|
@@ -2234,7 +2300,7 @@ static bool whisper_encode_internal(
|
|
| 2234 |
return false;
|
| 2235 |
}
|
| 2236 |
|
| 2237 |
-
if (!ggml_graph_compute_helper(
|
| 2238 |
return false;
|
| 2239 |
}
|
| 2240 |
}
|
|
@@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2263 |
const int n_head = hparams.n_text_head;
|
| 2264 |
const int n_layer = hparams.n_text_layer;
|
| 2265 |
|
|
|
|
|
|
|
| 2266 |
const int n_tokens = batch.n_tokens;
|
| 2267 |
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 2268 |
|
| 2269 |
-
const
|
| 2270 |
-
|
|
|
|
|
|
|
| 2271 |
|
| 2272 |
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
| 2273 |
|
|
@@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2289 |
ggml_set_name(position, "position");
|
| 2290 |
ggml_set_input(position);
|
| 2291 |
|
| 2292 |
-
const float KQscale = pow(float(
|
| 2293 |
|
| 2294 |
-
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
| 2295 |
ggml_set_name(KQ_mask, "KQ_mask");
|
| 2296 |
ggml_set_input(KQ_mask);
|
| 2297 |
|
|
|
|
|
|
|
| 2298 |
// token encoding + position encoding
|
| 2299 |
struct ggml_tensor * cur =
|
| 2300 |
ggml_add(ctx0,
|
|
@@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2350 |
Vcur,
|
| 2351 |
layer.attn_v_b);
|
| 2352 |
|
| 2353 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2354 |
|
| 2355 |
-
|
| 2356 |
-
|
| 2357 |
-
|
| 2358 |
-
|
|
|
|
|
|
|
|
|
|
| 2359 |
|
| 2360 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
| 2361 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
|
@@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2365 |
|
| 2366 |
struct ggml_tensor * Q =
|
| 2367 |
ggml_permute(ctx0,
|
| 2368 |
-
ggml_reshape_3d(ctx0, Qcur,
|
| 2369 |
0, 2, 1, 3);
|
| 2370 |
|
| 2371 |
struct ggml_tensor * K =
|
| 2372 |
ggml_view_3d(ctx0, kv_self.k,
|
| 2373 |
-
|
| 2374 |
ggml_element_size(kv_self.k)*n_state,
|
| 2375 |
-
ggml_element_size(kv_self.k)*
|
| 2376 |
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
| 2377 |
|
| 2378 |
-
|
| 2379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2380 |
|
| 2381 |
-
|
| 2382 |
|
| 2383 |
-
|
| 2384 |
-
|
| 2385 |
-
|
| 2386 |
-
|
| 2387 |
-
|
| 2388 |
-
|
| 2389 |
|
| 2390 |
-
|
| 2391 |
|
| 2392 |
-
|
| 2393 |
|
| 2394 |
-
|
| 2395 |
-
|
| 2396 |
-
|
|
|
|
| 2397 |
}
|
| 2398 |
|
| 2399 |
// projection
|
|
@@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|
| 2432 |
Qcur,
|
| 2433 |
layer.cross_attn_q_b);
|
| 2434 |
|
| 2435 |
-
Qcur = ggml_scale(ctx0, Qcur, KQscale);
|
| 2436 |
-
|
| 2437 |
-
// Kcross is already scaled
|
| 2438 |
-
struct ggml_tensor * Kcross =
|
| 2439 |
-
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
| 2440 |
-
n_state/n_head, n_audio_ctx, n_head,
|
| 2441 |
-
ggml_element_size(wstate.kv_cross.k)*n_state,
|
| 2442 |
-
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
| 2443 |
-
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
| 2444 |
-
|
| 2445 |
-
//struct ggml_tensor * Vcross =
|
| 2446 |
-
// ggml_reshape_3d(ctx0,
|
| 2447 |
-
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
| 2448 |
-
// n_state/n_head, n_head, n_audio_ctx);
|
| 2449 |
-
|
| 2450 |
-
//struct ggml_tensor * V_trans =
|
| 2451 |
-
// ggml_cpy(ctx0,
|
| 2452 |
-
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
| 2453 |
-
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
| 2454 |
-
|
| 2455 |
-
struct ggml_tensor * V =
|
| 2456 |
-
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
| 2457 |
-
n_audio_ctx, n_state/n_head, n_head,
|
| 2458 |
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
| 2459 |
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
| 2460 |
-
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
| 2461 |
-
|
| 2462 |
-
// ------
|
| 2463 |
-
|
| 2464 |
struct ggml_tensor * Q =
|
| 2465 |
ggml_permute(ctx0,
|
| 2466 |
-
ggml_reshape_3d(ctx0, Qcur,
|
| 2467 |
0, 2, 1, 3);
|
| 2468 |
|
| 2469 |
-
|
| 2470 |
-
|
| 2471 |
-
|
| 2472 |
-
|
| 2473 |
-
|
| 2474 |
-
|
| 2475 |
-
|
| 2476 |
-
// );
|
| 2477 |
|
| 2478 |
-
|
| 2479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2480 |
|
| 2481 |
-
|
| 2482 |
|
| 2483 |
-
|
| 2484 |
-
|
| 2485 |
-
|
| 2486 |
-
|
| 2487 |
-
|
| 2488 |
-
|
| 2489 |
-
|
| 2490 |
-
|
| 2491 |
-
|
| 2492 |
-
|
| 2493 |
-
|
| 2494 |
-
|
| 2495 |
-
|
| 2496 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2497 |
}
|
| 2498 |
}
|
| 2499 |
-
}
|
| 2500 |
|
| 2501 |
-
|
| 2502 |
|
| 2503 |
-
|
| 2504 |
|
| 2505 |
-
|
| 2506 |
-
|
| 2507 |
-
|
| 2508 |
-
|
| 2509 |
}
|
| 2510 |
|
| 2511 |
// projection
|
|
@@ -2638,7 +2733,9 @@ static bool whisper_decode_internal(
|
|
| 2638 |
return false;
|
| 2639 |
}
|
| 2640 |
|
| 2641 |
-
|
|
|
|
|
|
|
| 2642 |
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
| 2643 |
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
| 2644 |
}
|
|
@@ -2672,9 +2769,10 @@ static bool whisper_decode_internal(
|
|
| 2672 |
struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
|
| 2673 |
|
| 2674 |
auto & kv_self = wstate.kv_self;
|
| 2675 |
-
const int32_t n_kv = kv_self.n;
|
| 2676 |
|
| 2677 |
-
|
|
|
|
|
|
|
| 2678 |
|
| 2679 |
float * data = wstate.inp_mask.data();
|
| 2680 |
memset(data, 0, ggml_nbytes(KQ_mask));
|
|
@@ -2690,6 +2788,12 @@ static bool whisper_decode_internal(
|
|
| 2690 |
}
|
| 2691 |
}
|
| 2692 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2693 |
}
|
| 2694 |
|
| 2695 |
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
|
|
@@ -2697,7 +2801,7 @@ static bool whisper_decode_internal(
|
|
| 2697 |
|
| 2698 |
logits = gf->nodes[gf->n_nodes - 1];
|
| 2699 |
|
| 2700 |
-
if (!ggml_graph_compute_helper(
|
| 2701 |
return false;
|
| 2702 |
}
|
| 2703 |
}
|
|
@@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
| 3144 |
|
| 3145 |
whisper_state * state = new whisper_state;
|
| 3146 |
|
| 3147 |
-
state->backend = whisper_backend_init(ctx->params);
|
| 3148 |
-
if (!state->backend) {
|
| 3149 |
-
WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
|
| 3150 |
-
whisper_free_state(state);
|
| 3151 |
-
return nullptr;
|
| 3152 |
-
}
|
| 3153 |
-
|
| 3154 |
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
| 3155 |
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
| 3156 |
const int factor = 3;
|
| 3157 |
|
| 3158 |
-
if (!kv_cache_init(
|
|
|
|
|
|
|
|
|
|
| 3159 |
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
| 3160 |
whisper_free_state(state);
|
| 3161 |
return nullptr;
|
|
@@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
| 3166 |
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3167 |
}
|
| 3168 |
|
| 3169 |
-
if (!kv_cache_init(
|
|
|
|
|
|
|
|
|
|
| 3170 |
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
| 3171 |
whisper_free_state(state);
|
| 3172 |
return nullptr;
|
|
@@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|
| 3177 |
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3178 |
}
|
| 3179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3180 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3181 |
if (ctx->params.dtw_token_timestamps) {
|
| 3182 |
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
|
@@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder(
|
|
| 3347 |
struct whisper_context_params whisper_context_default_params() {
|
| 3348 |
struct whisper_context_params result = {
|
| 3349 |
/*.use_gpu =*/ true,
|
|
|
|
| 3350 |
/*.gpu_device =*/ 0,
|
| 3351 |
|
| 3352 |
/*.dtw_token_timestamps =*/ false,
|
|
@@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
|
|
| 3445 |
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
| 3446 |
ggml_time_init();
|
| 3447 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3448 |
whisper_context * ctx = new whisper_context;
|
| 3449 |
ctx->params = params;
|
| 3450 |
|
|
@@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) {
|
|
| 3533 |
if (state) {
|
| 3534 |
kv_cache_free(state->kv_self);
|
| 3535 |
kv_cache_free(state->kv_cross);
|
|
|
|
| 3536 |
|
| 3537 |
#ifdef WHISPER_USE_COREML
|
| 3538 |
if (state->ctx_coreml != nullptr) {
|
|
@@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) {
|
|
| 3555 |
ggml_gallocr_free(state->alloc_cross.alloc);
|
| 3556 |
ggml_gallocr_free(state->alloc_decode.alloc);
|
| 3557 |
|
| 3558 |
-
ggml_backend_free(state->backend);
|
| 3559 |
-
|
| 3560 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3561 |
aheads_masks_free(state->aheads_masks);
|
| 3562 |
|
|
|
|
| 809 |
// shared between all decoders
|
| 810 |
whisper_kv_cache kv_cross;
|
| 811 |
|
| 812 |
+
// padded buffer for flash-attention
|
| 813 |
+
whisper_kv_cache kv_pad;
|
| 814 |
+
|
| 815 |
whisper_mel mel;
|
| 816 |
|
| 817 |
whisper_batch batch;
|
| 818 |
|
| 819 |
whisper_decoder decoders[WHISPER_MAX_DECODERS];
|
| 820 |
|
|
|
|
|
|
|
| 821 |
// ggml-alloc:
|
| 822 |
// - stores meta info about the intermediate tensors into the `meta` buffers
|
| 823 |
// - stores the actual tensor data into the `data` buffers
|
|
|
|
| 903 |
}
|
| 904 |
|
| 905 |
static bool kv_cache_init(
|
|
|
|
| 906 |
struct whisper_kv_cache & cache,
|
| 907 |
ggml_backend_t backend,
|
| 908 |
ggml_type wtype,
|
| 909 |
+
int64_t n_text_state,
|
| 910 |
+
int64_t n_text_layer,
|
| 911 |
int n_ctx) {
|
|
|
|
|
|
|
|
|
|
| 912 |
const int64_t n_mem = n_text_layer*n_ctx;
|
| 913 |
const int64_t n_elements = n_text_state*n_mem;
|
| 914 |
|
|
|
|
| 940 |
return false;
|
| 941 |
}
|
| 942 |
|
| 943 |
+
ggml_backend_buffer_clear(cache.buffer, 0);
|
| 944 |
+
|
| 945 |
return true;
|
| 946 |
}
|
| 947 |
|
|
|
|
| 1069 |
}
|
| 1070 |
}
|
| 1071 |
|
| 1072 |
+
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
|
| 1073 |
+
if (!wctx.params.flash_attn) {
|
| 1074 |
+
return 1u;
|
| 1075 |
+
}
|
| 1076 |
+
|
| 1077 |
+
#ifdef GGML_USE_METAL
|
| 1078 |
+
if (ggml_backend_is_metal(wctx.backend)) {
|
| 1079 |
+
return 32u;
|
| 1080 |
+
}
|
| 1081 |
+
#endif
|
| 1082 |
+
|
| 1083 |
+
#ifdef GGML_USE_CUDA
|
| 1084 |
+
if (ggml_backend_is_cuda(wctx.backend)) {
|
| 1085 |
+
return 256u;
|
| 1086 |
+
}
|
| 1087 |
+
#endif
|
| 1088 |
+
|
| 1089 |
+
return 1u;
|
| 1090 |
+
}
|
| 1091 |
+
|
| 1092 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 1093 |
static bool aheads_masks_init(
|
| 1094 |
const whisper_context_params & cparams,
|
|
|
|
| 1893 |
const int n_head = hparams.n_audio_head;
|
| 1894 |
const int n_layer = hparams.n_audio_layer;
|
| 1895 |
|
| 1896 |
+
const int n_state_head = n_state/n_head;
|
| 1897 |
+
|
| 1898 |
+
auto & kv_pad = wstate.kv_pad;
|
| 1899 |
+
|
| 1900 |
+
WHISPER_ASSERT(!!kv_pad.ctx);
|
| 1901 |
+
|
| 1902 |
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
| 1903 |
+
|
| 1904 |
struct ggml_init_params params = {
|
| 1905 |
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
| 1906 |
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
|
|
|
| 1913 |
|
| 1914 |
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
| 1915 |
|
| 1916 |
+
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
| 1917 |
|
| 1918 |
// ===================================================================
|
| 1919 |
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
|
|
|
| 1963 |
|
| 1964 |
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
| 1965 |
|
| 1966 |
+
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
| 1967 |
|
| 1968 |
// note: no bias for Key
|
| 1969 |
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
| 1970 |
layer.attn_k_w,
|
| 1971 |
cur);
|
| 1972 |
|
| 1973 |
+
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
| 1974 |
|
| 1975 |
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
| 1976 |
layer.attn_v_w,
|
|
|
|
| 1984 |
ggml_permute(ctx0,
|
| 1985 |
ggml_cpy(ctx0,
|
| 1986 |
Qcur,
|
| 1987 |
+
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
|
| 1988 |
0, 2, 1, 3);
|
| 1989 |
|
| 1990 |
+
if (wctx.params.flash_attn) {
|
| 1991 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
|
| 1992 |
+
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1993 |
|
| 1994 |
+
struct ggml_tensor * K =
|
| 1995 |
+
ggml_view_3d(ctx0, kv_pad.k,
|
| 1996 |
+
n_state_head, n_ctx_pad, n_head,
|
| 1997 |
+
ggml_element_size(kv_pad.k)*n_state,
|
| 1998 |
+
ggml_element_size(kv_pad.k)*n_state_head,
|
| 1999 |
+
0);
|
| 2000 |
|
| 2001 |
+
struct ggml_tensor * V =
|
| 2002 |
+
ggml_view_3d(ctx0, kv_pad.v,
|
| 2003 |
+
n_state_head, n_ctx_pad, n_head,
|
| 2004 |
+
ggml_element_size(kv_pad.v)*n_state,
|
| 2005 |
+
ggml_element_size(kv_pad.v)*n_state_head,
|
| 2006 |
+
0);
|
|
|
|
|
|
|
|
|
|
| 2007 |
|
| 2008 |
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
|
| 2009 |
|
| 2010 |
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
|
| 2011 |
+
} else {
|
| 2012 |
+
struct ggml_tensor * K =
|
| 2013 |
+
ggml_permute(ctx0,
|
| 2014 |
+
ggml_cpy(ctx0,
|
| 2015 |
+
Kcur,
|
| 2016 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
|
| 2017 |
+
0, 2, 1, 3);
|
| 2018 |
+
|
| 2019 |
+
// K * Q
|
| 2020 |
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 2021 |
+
|
| 2022 |
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
| 2023 |
+
|
| 2024 |
+
struct ggml_tensor * V =
|
| 2025 |
+
ggml_cpy(ctx0,
|
| 2026 |
+
ggml_permute(ctx0,
|
| 2027 |
+
ggml_reshape_3d(ctx0,
|
| 2028 |
+
Vcur,
|
| 2029 |
+
n_state_head, n_head, n_ctx),
|
| 2030 |
+
1, 2, 0, 3),
|
| 2031 |
+
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
|
| 2032 |
+
);
|
| 2033 |
+
|
| 2034 |
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
| 2035 |
+
|
| 2036 |
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 2037 |
+
|
| 2038 |
+
cur = ggml_cpy(ctx0,
|
| 2039 |
+
KQV_merged,
|
| 2040 |
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
| 2041 |
+
}
|
| 2042 |
}
|
| 2043 |
|
| 2044 |
// projection
|
|
|
|
| 2137 |
const int n_state = hparams.n_audio_state;
|
| 2138 |
const int n_head = hparams.n_audio_head;
|
| 2139 |
|
| 2140 |
+
const int n_state_head = n_state/n_head;
|
| 2141 |
+
|
| 2142 |
+
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
| 2143 |
+
|
| 2144 |
struct ggml_init_params params = {
|
| 2145 |
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
| 2146 |
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
|
|
|
| 2153 |
|
| 2154 |
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
| 2155 |
|
| 2156 |
+
const float Kscale = pow(float(n_state_head), -0.25);
|
| 2157 |
|
| 2158 |
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
| 2159 |
auto & layer = model.layers_decoder[il];
|
| 2160 |
|
| 2161 |
+
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
|
| 2162 |
layer.cross_attn_k_w,
|
| 2163 |
cur);
|
| 2164 |
|
| 2165 |
Kcross = ggml_scale(ctx0, Kcross, Kscale);
|
| 2166 |
|
| 2167 |
+
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
|
| 2168 |
layer.cross_attn_v_w,
|
| 2169 |
cur);
|
| 2170 |
|
|
|
|
| 2172 |
Vcross,
|
| 2173 |
layer.cross_attn_v_b);
|
| 2174 |
|
| 2175 |
+
struct ggml_tensor * k;
|
| 2176 |
+
struct ggml_tensor * v;
|
| 2177 |
|
| 2178 |
+
if (wctx.params.flash_attn) {
|
| 2179 |
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
| 2180 |
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
|
| 2181 |
|
| 2182 |
+
v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
|
| 2183 |
+
(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
|
| 2184 |
+
} else {
|
| 2185 |
+
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
| 2186 |
+
|
| 2187 |
+
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
|
| 2188 |
+
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
| 2189 |
+
|
| 2190 |
+
v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
| 2191 |
+
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
| 2192 |
+
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
| 2193 |
+
}
|
| 2194 |
|
| 2195 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
| 2196 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
|
|
|
| 2261 |
}
|
| 2262 |
|
| 2263 |
if (!whisper_encode_external(wstate)) {
|
| 2264 |
+
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
|
| 2265 |
return false;
|
| 2266 |
}
|
| 2267 |
} else {
|
|
|
|
| 2284 |
return false;
|
| 2285 |
}
|
| 2286 |
|
| 2287 |
+
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
|
| 2288 |
return false;
|
| 2289 |
}
|
| 2290 |
}
|
|
|
|
| 2300 |
return false;
|
| 2301 |
}
|
| 2302 |
|
| 2303 |
+
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
|
| 2304 |
return false;
|
| 2305 |
}
|
| 2306 |
}
|
|
|
|
| 2329 |
const int n_head = hparams.n_text_head;
|
| 2330 |
const int n_layer = hparams.n_text_layer;
|
| 2331 |
|
| 2332 |
+
const int n_state_head = n_state/n_head;
|
| 2333 |
+
|
| 2334 |
const int n_tokens = batch.n_tokens;
|
| 2335 |
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
| 2336 |
|
| 2337 |
+
const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
|
| 2338 |
+
|
| 2339 |
+
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
| 2340 |
+
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
| 2341 |
|
| 2342 |
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
| 2343 |
|
|
|
|
| 2359 |
ggml_set_name(position, "position");
|
| 2360 |
ggml_set_input(position);
|
| 2361 |
|
| 2362 |
+
const float KQscale = pow(float(n_state_head), -0.25);
|
| 2363 |
|
| 2364 |
+
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
|
| 2365 |
ggml_set_name(KQ_mask, "KQ_mask");
|
| 2366 |
ggml_set_input(KQ_mask);
|
| 2367 |
|
| 2368 |
+
struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
|
| 2369 |
+
|
| 2370 |
// token encoding + position encoding
|
| 2371 |
struct ggml_tensor * cur =
|
| 2372 |
ggml_add(ctx0,
|
|
|
|
| 2422 |
Vcur,
|
| 2423 |
layer.attn_v_b);
|
| 2424 |
|
| 2425 |
+
struct ggml_tensor * k;
|
| 2426 |
+
struct ggml_tensor * v;
|
| 2427 |
+
|
| 2428 |
+
if (wctx.params.flash_attn) {
|
| 2429 |
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
| 2430 |
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
| 2431 |
+
|
| 2432 |
+
v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
|
| 2433 |
+
(ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
|
| 2434 |
+
} else {
|
| 2435 |
+
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
|
| 2436 |
|
| 2437 |
+
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
|
| 2438 |
+
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
|
| 2439 |
+
|
| 2440 |
+
v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
|
| 2441 |
+
( n_ctx)*ggml_element_size(kv_self.v),
|
| 2442 |
+
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
|
| 2443 |
+
}
|
| 2444 |
|
| 2445 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
| 2446 |
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
|
|
|
| 2450 |
|
| 2451 |
struct ggml_tensor * Q =
|
| 2452 |
ggml_permute(ctx0,
|
| 2453 |
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
| 2454 |
0, 2, 1, 3);
|
| 2455 |
|
| 2456 |
struct ggml_tensor * K =
|
| 2457 |
ggml_view_3d(ctx0, kv_self.k,
|
| 2458 |
+
n_state_head, n_kv, n_head,
|
| 2459 |
ggml_element_size(kv_self.k)*n_state,
|
| 2460 |
+
ggml_element_size(kv_self.k)*n_state_head,
|
| 2461 |
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
| 2462 |
|
| 2463 |
+
if (wctx.params.flash_attn) {
|
| 2464 |
+
struct ggml_tensor * V =
|
| 2465 |
+
ggml_view_3d(ctx0, kv_self.v,
|
| 2466 |
+
n_state_head, n_kv, n_head,
|
| 2467 |
+
ggml_element_size(kv_self.v)*n_state,
|
| 2468 |
+
ggml_element_size(kv_self.v)*n_state_head,
|
| 2469 |
+
ggml_element_size(kv_self.v)*n_state*n_ctx*il);
|
| 2470 |
+
|
| 2471 |
+
cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
|
| 2472 |
+
|
| 2473 |
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
| 2474 |
+
} else {
|
| 2475 |
+
// K * Q
|
| 2476 |
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
| 2477 |
|
| 2478 |
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
|
| 2479 |
|
| 2480 |
+
struct ggml_tensor * V =
|
| 2481 |
+
ggml_view_3d(ctx0, kv_self.v,
|
| 2482 |
+
n_kv, n_state_head, n_head,
|
| 2483 |
+
n_ctx*ggml_element_size(kv_self.v),
|
| 2484 |
+
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
|
| 2485 |
+
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
|
| 2486 |
|
| 2487 |
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
| 2488 |
|
| 2489 |
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 2490 |
|
| 2491 |
+
cur = ggml_cpy(ctx0,
|
| 2492 |
+
KQV_merged,
|
| 2493 |
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
| 2494 |
+
}
|
| 2495 |
}
|
| 2496 |
|
| 2497 |
// projection
|
|
|
|
| 2530 |
Qcur,
|
| 2531 |
layer.cross_attn_q_b);
|
| 2532 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2533 |
struct ggml_tensor * Q =
|
| 2534 |
ggml_permute(ctx0,
|
| 2535 |
+
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
| 2536 |
0, 2, 1, 3);
|
| 2537 |
|
| 2538 |
+
if (wctx.params.flash_attn) {
|
| 2539 |
+
struct ggml_tensor * Kcross =
|
| 2540 |
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
| 2541 |
+
n_state_head, n_audio_ctx_pad, n_head,
|
| 2542 |
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
| 2543 |
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
| 2544 |
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
|
|
|
|
| 2545 |
|
| 2546 |
+
struct ggml_tensor * Vcross =
|
| 2547 |
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
| 2548 |
+
n_state_head, n_audio_ctx_pad, n_head,
|
| 2549 |
+
ggml_element_size(wstate.kv_cross.v)*n_state,
|
| 2550 |
+
ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
| 2551 |
+
ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
|
| 2552 |
|
| 2553 |
+
cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
|
| 2554 |
|
| 2555 |
+
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
|
| 2556 |
+
} else {
|
| 2557 |
+
struct ggml_tensor * Kcross =
|
| 2558 |
+
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
| 2559 |
+
n_state_head, n_audio_ctx, n_head,
|
| 2560 |
+
ggml_element_size(wstate.kv_cross.k)*n_state,
|
| 2561 |
+
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
| 2562 |
+
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
| 2563 |
+
|
| 2564 |
+
struct ggml_tensor * Vcross =
|
| 2565 |
+
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
| 2566 |
+
n_audio_ctx, n_state_head, n_head,
|
| 2567 |
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
| 2568 |
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
| 2569 |
+
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
| 2570 |
+
|
| 2571 |
+
// ------
|
| 2572 |
+
|
| 2573 |
+
// K * Q
|
| 2574 |
+
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
|
| 2575 |
+
|
| 2576 |
+
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
| 2577 |
+
|
| 2578 |
+
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 2579 |
+
if (wctx.params.dtw_token_timestamps) {
|
| 2580 |
+
if (wstate.aheads_masks.m[il] != nullptr) {
|
| 2581 |
+
struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
|
| 2582 |
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
| 2583 |
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
| 2584 |
+
aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
|
| 2585 |
+
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
|
| 2586 |
+
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
|
| 2587 |
+
aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
|
| 2588 |
+
if (aheads_cross_QKs == NULL) {
|
| 2589 |
+
aheads_cross_QKs = aheads_KQs;
|
| 2590 |
+
} else {
|
| 2591 |
+
aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
|
| 2592 |
+
}
|
| 2593 |
}
|
| 2594 |
}
|
|
|
|
| 2595 |
|
| 2596 |
+
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
|
| 2597 |
|
| 2598 |
+
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
| 2599 |
|
| 2600 |
+
cur = ggml_cpy(ctx0,
|
| 2601 |
+
KQV_merged,
|
| 2602 |
+
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
| 2603 |
+
}
|
| 2604 |
}
|
| 2605 |
|
| 2606 |
// projection
|
|
|
|
| 2733 |
return false;
|
| 2734 |
}
|
| 2735 |
|
| 2736 |
+
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
|
| 2737 |
+
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
|
| 2738 |
+
|
| 2739 |
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
|
| 2740 |
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
|
| 2741 |
}
|
|
|
|
| 2769 |
struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
|
| 2770 |
|
| 2771 |
auto & kv_self = wstate.kv_self;
|
|
|
|
| 2772 |
|
| 2773 |
+
const int32_t n_kv = kv_self.n;
|
| 2774 |
+
|
| 2775 |
+
wstate.inp_mask.resize(ggml_nelements(KQ_mask));
|
| 2776 |
|
| 2777 |
float * data = wstate.inp_mask.data();
|
| 2778 |
memset(data, 0, ggml_nbytes(KQ_mask));
|
|
|
|
| 2788 |
}
|
| 2789 |
}
|
| 2790 |
}
|
| 2791 |
+
|
| 2792 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 2793 |
+
for (int j = 0; j < n_kv; ++j) {
|
| 2794 |
+
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
|
| 2795 |
+
}
|
| 2796 |
+
}
|
| 2797 |
}
|
| 2798 |
|
| 2799 |
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
|
|
|
|
| 2801 |
|
| 2802 |
logits = gf->nodes[gf->n_nodes - 1];
|
| 2803 |
|
| 2804 |
+
if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
|
| 2805 |
return false;
|
| 2806 |
}
|
| 2807 |
}
|
|
|
|
| 3248 |
|
| 3249 |
whisper_state * state = new whisper_state;
|
| 3250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3251 |
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
| 3252 |
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
| 3253 |
const int factor = 3;
|
| 3254 |
|
| 3255 |
+
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
|
| 3256 |
+
ctx->model.hparams.n_text_state,
|
| 3257 |
+
ctx->model.hparams.n_text_layer,
|
| 3258 |
+
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
| 3259 |
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
| 3260 |
whisper_free_state(state);
|
| 3261 |
return nullptr;
|
|
|
|
| 3266 |
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3267 |
}
|
| 3268 |
|
| 3269 |
+
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
|
| 3270 |
+
ctx->model.hparams.n_text_state,
|
| 3271 |
+
ctx->model.hparams.n_text_layer,
|
| 3272 |
+
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
| 3273 |
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
| 3274 |
whisper_free_state(state);
|
| 3275 |
return nullptr;
|
|
|
|
| 3280 |
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3281 |
}
|
| 3282 |
|
| 3283 |
+
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
|
| 3284 |
+
ctx->model.hparams.n_audio_state,
|
| 3285 |
+
1,
|
| 3286 |
+
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
| 3287 |
+
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
| 3288 |
+
whisper_free_state(state);
|
| 3289 |
+
return nullptr;
|
| 3290 |
+
}
|
| 3291 |
+
|
| 3292 |
+
{
|
| 3293 |
+
const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
|
| 3294 |
+
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
|
| 3295 |
+
}
|
| 3296 |
+
|
| 3297 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3298 |
if (ctx->params.dtw_token_timestamps) {
|
| 3299 |
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
|
|
|
| 3464 |
struct whisper_context_params whisper_context_default_params() {
|
| 3465 |
struct whisper_context_params result = {
|
| 3466 |
/*.use_gpu =*/ true,
|
| 3467 |
+
/*.flash_attn =*/ false,
|
| 3468 |
/*.gpu_device =*/ 0,
|
| 3469 |
|
| 3470 |
/*.dtw_token_timestamps =*/ false,
|
|
|
|
| 3563 |
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
|
| 3564 |
ggml_time_init();
|
| 3565 |
|
| 3566 |
+
if (params.flash_attn && params.dtw_token_timestamps) {
|
| 3567 |
+
WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
|
| 3568 |
+
params.dtw_token_timestamps = false;
|
| 3569 |
+
}
|
| 3570 |
+
|
| 3571 |
+
WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
|
| 3572 |
+
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
|
| 3573 |
+
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
|
| 3574 |
+
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
|
| 3575 |
+
|
| 3576 |
whisper_context * ctx = new whisper_context;
|
| 3577 |
ctx->params = params;
|
| 3578 |
|
|
|
|
| 3661 |
if (state) {
|
| 3662 |
kv_cache_free(state->kv_self);
|
| 3663 |
kv_cache_free(state->kv_cross);
|
| 3664 |
+
kv_cache_free(state->kv_pad);
|
| 3665 |
|
| 3666 |
#ifdef WHISPER_USE_COREML
|
| 3667 |
if (state->ctx_coreml != nullptr) {
|
|
|
|
| 3684 |
ggml_gallocr_free(state->alloc_cross.alloc);
|
| 3685 |
ggml_gallocr_free(state->alloc_decode.alloc);
|
| 3686 |
|
|
|
|
|
|
|
| 3687 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
| 3688 |
aheads_masks_free(state->aheads_masks);
|
| 3689 |
|
whisper.h
CHANGED
|
@@ -113,6 +113,7 @@ extern "C" {
|
|
| 113 |
|
| 114 |
struct whisper_context_params {
|
| 115 |
bool use_gpu;
|
|
|
|
| 116 |
int gpu_device; // CUDA device
|
| 117 |
|
| 118 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|
|
|
|
| 113 |
|
| 114 |
struct whisper_context_params {
|
| 115 |
bool use_gpu;
|
| 116 |
+
bool flash_attn;
|
| 117 |
int gpu_device; // CUDA device
|
| 118 |
|
| 119 |
// [EXPERIMENTAL] Token-level timestamps with DTW
|