ggerganov committed on
Commit
27c0a97
·
unverified ·
1 Parent(s): b445508

whisper : use flash attention (#2152)

Browse files

* whisper : use flash attention in the encoder

* whisper : add kv_pad

* whisper : remove extra backend instance (huh?)

* whisper : use FA for cross-attention

* whisper : use FA for self-attention

* whisper : simplify encoder FA

* whisper : add flash_attn runtime parameter

* scripts : add bench log

* scripts : add M1 Pro bench log

examples/bench/bench.cpp CHANGED
@@ -12,7 +12,8 @@ struct whisper_params {
12
 
13
  std::string model = "models/ggml-base.en.bin";
14
 
15
- bool use_gpu = true;
 
16
  };
17
 
18
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
@@ -25,10 +26,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
25
  whisper_print_usage(argc, argv, params);
26
  exit(0);
27
  }
28
- else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
29
- else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
30
- else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
31
- else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
32
  else {
33
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
34
  whisper_print_usage(argc, argv, params);
@@ -49,6 +51,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
49
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
50
  fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
51
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
52
  fprintf(stderr, " %-7s 0 - whisper\n", "");
53
  fprintf(stderr, " %-7s 1 - memcpy\n", "");
54
  fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
@@ -59,7 +62,9 @@ int whisper_bench_full(const whisper_params & params) {
59
  // whisper init
60
 
61
  struct whisper_context_params cparams = whisper_context_default_params();
62
- cparams.use_gpu = params.use_gpu;
 
 
63
 
64
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
65
 
 
12
 
13
  std::string model = "models/ggml-base.en.bin";
14
 
15
+ bool use_gpu = true;
16
+ bool flash_attn = false;
17
  };
18
 
19
  void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
 
26
  whisper_print_usage(argc, argv, params);
27
  exit(0);
28
  }
29
+ else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
30
+ else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
31
+ else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); }
32
+ else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
33
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
34
  else {
35
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
36
  whisper_print_usage(argc, argv, params);
 
51
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
52
  fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what);
53
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
54
+ fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false");
55
  fprintf(stderr, " %-7s 0 - whisper\n", "");
56
  fprintf(stderr, " %-7s 1 - memcpy\n", "");
57
  fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", "");
 
62
  // whisper init
63
 
64
  struct whisper_context_params cparams = whisper_context_default_params();
65
+
66
+ cparams.use_gpu = params.use_gpu;
67
+ cparams.flash_attn = params.flash_attn;
68
 
69
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
70
 
examples/command/command.cpp CHANGED
@@ -44,6 +44,7 @@ struct whisper_params {
44
  bool print_energy = false;
45
  bool no_timestamps = true;
46
  bool use_gpu = true;
 
47
 
48
  std::string language = "en";
49
  std::string model = "models/ggml-base.en.bin";
@@ -80,6 +81,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
80
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
81
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
82
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
83
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
84
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
85
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
@@ -118,6 +120,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
118
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
119
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
120
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
121
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
122
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
123
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
@@ -696,7 +699,9 @@ int main(int argc, char ** argv) {
696
  // whisper init
697
 
698
  struct whisper_context_params cparams = whisper_context_default_params();
699
- cparams.use_gpu = params.use_gpu;
 
 
700
 
701
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
702
 
 
44
  bool print_energy = false;
45
  bool no_timestamps = true;
46
  bool use_gpu = true;
47
+ bool flash_attn = false;
48
 
49
  std::string language = "en";
50
  std::string model = "models/ggml-base.en.bin";
 
81
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
82
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
83
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
84
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
85
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
86
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
87
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
 
120
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
121
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
122
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
123
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
124
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
125
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
126
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
 
699
  // whisper init
700
 
701
  struct whisper_context_params cparams = whisper_context_default_params();
702
+
703
+ cparams.use_gpu = params.use_gpu;
704
+ cparams.flash_attn = params.flash_attn;
705
 
706
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
707
 
examples/lsp/lsp.cpp CHANGED
@@ -31,6 +31,7 @@ struct whisper_params {
31
  bool print_special = false;
32
  bool print_energy = false;
33
  bool use_gpu = true;
 
34
 
35
  std::string language = "en";
36
  std::string model = "models/ggml-base.en.bin";
@@ -74,6 +75,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
74
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
75
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
76
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
77
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
78
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
79
  else {
@@ -105,6 +107,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
105
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
106
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
107
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
108
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
109
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
110
  fprintf(stderr, "\n");
@@ -436,7 +439,10 @@ int main(int argc, char ** argv) {
436
 
437
  // whisper init
438
  struct whisper_context_params cparams = whisper_context_default_params();
439
- cparams.use_gpu = params.use_gpu;
 
 
 
440
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
441
  // init audio
442
 
 
31
  bool print_special = false;
32
  bool print_energy = false;
33
  bool use_gpu = true;
34
+ bool flash_attn = false;
35
 
36
  std::string language = "en";
37
  std::string model = "models/ggml-base.en.bin";
 
75
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
76
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
77
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
78
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
79
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
80
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
81
  else {
 
107
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
108
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
109
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
110
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
111
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
112
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
113
  fprintf(stderr, "\n");
 
439
 
440
  // whisper init
441
  struct whisper_context_params cparams = whisper_context_default_params();
442
+
443
+ cparams.use_gpu = params.use_gpu;
444
+ cparams.flash_attn = params.flash_attn;
445
+
446
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
447
  // init audio
448
 
examples/main/main.cpp CHANGED
@@ -70,6 +70,7 @@ struct whisper_params {
70
  bool no_timestamps = false;
71
  bool log_score = false;
72
  bool use_gpu = true;
 
73
 
74
  std::string language = "en";
75
  std::string prompt;
@@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
168
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
169
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
170
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
171
- else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
 
172
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
173
  else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
174
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
@@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
234
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
235
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
236
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
237
  fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
238
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
239
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
@@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
977
  // whisper init
978
 
979
  struct whisper_context_params cparams = whisper_context_default_params();
980
- cparams.use_gpu = params.use_gpu;
 
 
981
 
982
  if (!params.dtw.empty()) {
983
  cparams.dtw_token_timestamps = true;
 
70
  bool no_timestamps = false;
71
  bool log_score = false;
72
  bool use_gpu = true;
73
+ bool flash_attn = false;
74
 
75
  std::string language = "en";
76
  std::string prompt;
 
169
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
170
  else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
171
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
172
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
173
+ else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
174
  else if ( arg == "--grammar") { params.grammar = argv[++i]; }
175
  else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
176
  else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
 
236
  fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
237
  fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
238
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
239
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
240
  fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
241
  fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
242
  fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
 
980
  // whisper init
981
 
982
  struct whisper_context_params cparams = whisper_context_default_params();
983
+
984
+ cparams.use_gpu = params.use_gpu;
985
+ cparams.flash_attn = params.flash_attn;
986
 
987
  if (!params.dtw.empty()) {
988
  cparams.dtw_token_timestamps = true;
examples/server/server.cpp CHANGED
@@ -75,6 +75,7 @@ struct whisper_params {
75
  bool print_progress = false;
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
 
78
 
79
  std::string language = "en";
80
  std::string prompt = "";
@@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
178
  else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
179
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
180
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
181
  // server params
182
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
183
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
502
  }
503
  // whisper init
504
  struct whisper_context_params cparams = whisper_context_default_params();
505
- cparams.use_gpu = params.use_gpu;
 
 
 
506
  if (!params.dtw.empty()) {
507
  cparams.dtw_token_timestamps = true;
508
  cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
 
75
  bool print_progress = false;
76
  bool no_timestamps = false;
77
  bool use_gpu = true;
78
+ bool flash_attn = false;
79
 
80
  std::string language = "en";
81
  std::string prompt = "";
 
179
  else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
180
  else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
181
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
182
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
183
  // server params
184
  else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
185
  else if ( arg == "--host") { sparams.hostname = argv[++i]; }
 
504
  }
505
  // whisper init
506
  struct whisper_context_params cparams = whisper_context_default_params();
507
+
508
+ cparams.use_gpu = params.use_gpu;
509
+ cparams.flash_attn = params.flash_attn;
510
+
511
  if (!params.dtw.empty()) {
512
  cparams.dtw_token_timestamps = true;
513
  cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
examples/stream/stream.cpp CHANGED
@@ -36,6 +36,7 @@ struct whisper_params {
36
  bool tinydiarize = false;
37
  bool save_audio = false; // save audio to wav file
38
  bool use_gpu = true;
 
39
 
40
  std::string language = "en";
41
  std::string model = "models/ggml-base.en.bin";
@@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
72
  else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
73
  else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
74
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
75
 
76
  else {
77
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
109
  fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
110
  fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
111
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
 
112
  fprintf(stderr, "\n");
113
  }
114
 
@@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
153
  }
154
 
155
  struct whisper_context_params cparams = whisper_context_default_params();
156
- cparams.use_gpu = params.use_gpu;
 
 
157
 
158
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
159
 
 
36
  bool tinydiarize = false;
37
  bool save_audio = false; // save audio to wav file
38
  bool use_gpu = true;
39
+ bool flash_attn = false;
40
 
41
  std::string language = "en";
42
  std::string model = "models/ggml-base.en.bin";
 
73
  else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
74
  else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
75
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
76
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
77
 
78
  else {
79
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
 
111
  fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
112
  fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
113
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
114
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
115
  fprintf(stderr, "\n");
116
  }
117
 
 
156
  }
157
 
158
  struct whisper_context_params cparams = whisper_context_default_params();
159
+
160
+ cparams.use_gpu = params.use_gpu;
161
+ cparams.flash_attn = params.flash_attn;
162
 
163
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
164
 
examples/talk-llama/talk-llama.cpp CHANGED
@@ -66,6 +66,7 @@ struct whisper_params {
66
  bool no_timestamps = true;
67
  bool verbose_prompt = false;
68
  bool use_gpu = true;
 
69
 
70
  std::string person = "Georgi";
71
  std::string bot_name = "LLaMA";
@@ -105,6 +106,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
105
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
106
  else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
107
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
108
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
109
  else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
110
  else if (arg == "--session") { params.path_session = argv[++i]; }
@@ -123,7 +125,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
123
  }
124
  }
125
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
126
- else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
127
  else {
128
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
129
  whisper_print_usage(argc, argv, params);
@@ -154,6 +155,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
154
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
155
  fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
156
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
157
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
158
  fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
159
  fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
@@ -285,7 +287,9 @@ int main(int argc, char ** argv) {
285
  // whisper init
286
 
287
  struct whisper_context_params cparams = whisper_context_default_params();
288
- cparams.use_gpu = params.use_gpu;
 
 
289
 
290
  struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
291
  if (!ctx_wsp) {
@@ -316,6 +320,7 @@ int main(int argc, char ** argv) {
316
  lcparams.n_ctx = 2048;
317
  lcparams.seed = 1;
318
  lcparams.n_threads = params.n_threads;
 
319
 
320
  struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
321
 
 
66
  bool no_timestamps = true;
67
  bool verbose_prompt = false;
68
  bool use_gpu = true;
69
+ bool flash_attn = false;
70
 
71
  std::string person = "Georgi";
72
  std::string bot_name = "LLaMA";
 
106
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
107
  else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
108
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
109
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
110
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
111
  else if (arg == "-bn" || arg == "--bot-name") { params.bot_name = argv[++i]; }
112
  else if (arg == "--session") { params.path_session = argv[++i]; }
 
125
  }
126
  }
127
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
 
128
  else {
129
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
130
  whisper_print_usage(argc, argv, params);
 
155
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
156
  fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
157
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
158
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
159
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
160
  fprintf(stderr, " -bn NAME, --bot-name NAME [%-7s] bot name (to display)\n", params.bot_name.c_str());
161
  fprintf(stderr, " -w TEXT, --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
 
287
  // whisper init
288
 
289
  struct whisper_context_params cparams = whisper_context_default_params();
290
+
291
+ cparams.use_gpu = params.use_gpu;
292
+ cparams.flash_attn = params.flash_attn;
293
 
294
  struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
295
  if (!ctx_wsp) {
 
320
  lcparams.n_ctx = 2048;
321
  lcparams.seed = 1;
322
  lcparams.n_threads = params.n_threads;
323
+ lcparams.flash_attn = params.flash_attn;
324
 
325
  struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
326
 
examples/talk/talk.cpp CHANGED
@@ -32,6 +32,7 @@ struct whisper_params {
32
  bool print_energy = false;
33
  bool no_timestamps = true;
34
  bool use_gpu = true;
 
35
 
36
  std::string person = "Santa";
37
  std::string language = "en";
@@ -64,6 +65,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
64
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
65
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
66
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
67
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
68
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
69
  else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
@@ -99,6 +101,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
99
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
100
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
101
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
102
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
103
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
104
  fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
@@ -188,7 +191,9 @@ int main(int argc, char ** argv) {
188
 
189
  // whisper init
190
  struct whisper_context_params cparams = whisper_context_default_params();
191
- cparams.use_gpu = params.use_gpu;
 
 
192
 
193
  struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
194
 
 
32
  bool print_energy = false;
33
  bool no_timestamps = true;
34
  bool use_gpu = true;
35
+ bool flash_attn = false;
36
 
37
  std::string person = "Santa";
38
  std::string language = "en";
 
65
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
66
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
67
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
68
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
69
  else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
70
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
71
  else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
 
101
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
102
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
103
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
104
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
105
  fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
106
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
107
  fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
 
191
 
192
  // whisper init
193
  struct whisper_context_params cparams = whisper_context_default_params();
194
+
195
+ cparams.use_gpu = params.use_gpu;
196
+ cparams.flash_attn = params.flash_attn;
197
 
198
  struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams);
199
 
examples/wchess/wchess.cmd/wchess.cmd.cpp CHANGED
@@ -32,6 +32,7 @@ struct whisper_params {
32
  bool print_energy = false;
33
  bool no_timestamps = true;
34
  bool use_gpu = true;
 
35
 
36
  std::string language = "en";
37
  std::string model = "models/ggml-base.en.bin";
@@ -61,6 +62,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
61
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
62
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
63
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
 
64
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
65
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
66
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
@@ -92,6 +94,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
92
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
93
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
94
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
 
95
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
96
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
97
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
@@ -183,7 +186,9 @@ int main(int argc, char ** argv) {
183
  // whisper init
184
 
185
  struct whisper_context_params cparams = whisper_context_default_params();
186
- cparams.use_gpu = params.use_gpu;
 
 
187
 
188
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
189
  if (!ctx) {
 
32
  bool print_energy = false;
33
  bool no_timestamps = true;
34
  bool use_gpu = true;
35
+ bool flash_attn = false;
36
 
37
  std::string language = "en";
38
  std::string model = "models/ggml-base.en.bin";
 
62
  fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
63
  fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
64
  fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
65
+ fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during decoding\n", params.flash_attn ? "true" : "false");
66
  fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
67
  fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
68
  fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
 
94
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
95
  else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
96
  else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
97
+ else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
98
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
99
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
100
  else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
 
186
  // whisper init
187
 
188
  struct whisper_context_params cparams = whisper_context_default_params();
189
+
190
+ cparams.use_gpu = params.use_gpu;
191
+ cparams.flash_attn = params.flash_attn;
192
 
193
  struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
194
  if (!ctx) {
scripts/bench-all-gg.txt ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## M1 Pro
2
+
3
+ make -j && ./scripts/bench-all.sh 8
4
+
5
+ Running memcpy benchmark
6
+
7
+ memcpy: 39.10 GB/s (heat-up)
8
+ memcpy: 44.75 GB/s ( 1 thread)
9
+ memcpy: 44.78 GB/s ( 1 thread)
10
+ memcpy: 44.97 GB/s ( 2 thread)
11
+ memcpy: 48.04 GB/s ( 3 thread)
12
+ memcpy: 50.55 GB/s ( 4 thread)
13
+ memcpy: 55.20 GB/s ( 5 thread)
14
+ memcpy: 65.60 GB/s ( 6 thread)
15
+ memcpy: 70.64 GB/s ( 7 thread)
16
+ memcpy: 73.34 GB/s ( 8 thread)
17
+ sum: -5120002535.000000
18
+
19
+
20
+ make -j && ./scripts/bench-all.sh 1 0 0
21
+
22
+ Running ggml_mul_mat benchmark with 1 threads
23
+
24
+ 64 x 64: Q4_0 237.1 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs)
25
+ 64 x 64: Q5_0 136.4 GFLOPS (128 runs) | Q5_1 135.6 GFLOPS (128 runs) | Q8_0 243.1 GFLOPS (128 runs)
26
+ 64 x 64: F16 140.4 GFLOPS (128 runs) | F32 316.6 GFLOPS (128 runs)
27
+ 128 x 128: Q4_0 496.6 GFLOPS (128 runs) | Q4_1 348.6 GFLOPS (128 runs)
28
+ 128 x 128: Q5_0 273.2 GFLOPS (128 runs) | Q5_1 274.1 GFLOPS (128 runs) | Q8_0 505.1 GFLOPS (128 runs)
29
+ 128 x 128: F16 300.4 GFLOPS (128 runs) | F32 653.9 GFLOPS (128 runs)
30
+ 256 x 256: Q4_0 791.7 GFLOPS (128 runs) | Q4_1 615.3 GFLOPS (128 runs)
31
+ 256 x 256: Q5_0 651.0 GFLOPS (128 runs) | Q5_1 674.7 GFLOPS (128 runs) | Q8_0 803.1 GFLOPS (128 runs)
32
+ 256 x 256: F16 869.6 GFLOPS (128 runs) | F32 957.2 GFLOPS (128 runs)
33
+ 512 x 512: Q4_0 973.3 GFLOPS (128 runs) | Q4_1 897.9 GFLOPS (128 runs)
34
+ 512 x 512: Q5_0 1078.8 GFLOPS (128 runs) | Q5_1 998.4 GFLOPS (128 runs) | Q8_0 752.4 GFLOPS (128 runs)
35
+ 512 x 512: F16 892.5 GFLOPS (128 runs) | F32 1399.6 GFLOPS (128 runs)
36
+ 1024 x 1024: Q4_0 1402.7 GFLOPS (128 runs) | Q4_1 1218.5 GFLOPS (128 runs)
37
+ 1024 x 1024: Q5_0 1444.8 GFLOPS (128 runs) | Q5_1 1444.7 GFLOPS (128 runs) | Q8_0 1395.7 GFLOPS (128 runs)
38
+ 1024 x 1024: F16 1524.1 GFLOPS (128 runs) | F32 1726.6 GFLOPS (128 runs)
39
+ 2048 x 2048: Q4_0 1479.4 GFLOPS ( 87 runs) | Q4_1 1378.5 GFLOPS ( 81 runs)
40
+ 2048 x 2048: Q5_0 1454.6 GFLOPS ( 85 runs) | Q5_1 1462.9 GFLOPS ( 86 runs) | Q8_0 1483.2 GFLOPS ( 87 runs)
41
+ 2048 x 2048: F16 1488.0 GFLOPS ( 87 runs) | F32 1538.2 GFLOPS ( 90 runs)
42
+ 4096 x 4096: Q4_0 1509.7 GFLOPS ( 11 runs) | Q4_1 1433.0 GFLOPS ( 11 runs)
43
+ 4096 x 4096: Q5_0 1422.4 GFLOPS ( 11 runs) | Q5_1 1437.0 GFLOPS ( 11 runs) | Q8_0 1523.0 GFLOPS ( 12 runs)
44
+ 4096 x 4096: F16 1551.3 GFLOPS ( 12 runs) | F32 1451.0 GFLOPS ( 11 runs)
45
+
46
+ | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
47
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
48
+ | M1 Pro | METAL | tiny | 1 | 0 | 39.21 | 1.74 | 0.61 | 0.04 | 22c96b4 |
49
+ | M1 Pro | METAL | base | 1 | 0 | 70.76 | 2.60 | 0.93 | 0.06 | 22c96b4 |
50
+ | M1 Pro | METAL | small | 1 | 0 | 217.28 | 6.42 | 2.14 | 0.17 | 22c96b4 |
51
+ | M1 Pro | METAL | medium | 1 | 0 | 596.74 | 14.43 | 4.75 | 0.45 | 22c96b4 |
52
+
53
+
54
+ make -j && ./scripts/bench-all.sh 1 1 1
55
+
56
+ | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
57
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
58
+ | M1 Pro | METAL | tiny | 1 | 1 | 30.77 | 1.59 | 0.54 | 0.03 | 22c96b4 |
59
+ | M1 Pro | METAL | base | 1 | 1 | 60.42 | 2.29 | 0.81 | 0.05 | 22c96b4 |
60
+ | M1 Pro | METAL | small | 1 | 1 | 183.82 | 5.12 | 1.81 | 0.14 | 22c96b4 |
61
+ | M1 Pro | METAL | medium | 1 | 1 | 517.92 | 11.60 | 4.01 | 0.38 | 22c96b4 |
62
+
63
+
64
+ ## M2 Ultra
65
+
66
+ make -j && ./scripts/bench-all.sh 8
67
+
68
+ Running memcpy benchmark
69
+
70
+ memcpy: 46.58 GB/s (heat-up)
71
+ memcpy: 54.16 GB/s ( 1 thread)
72
+ memcpy: 54.23 GB/s ( 1 thread)
73
+ memcpy: 99.63 GB/s ( 2 thread)
74
+ memcpy: 140.59 GB/s ( 3 thread)
75
+ memcpy: 176.52 GB/s ( 4 thread)
76
+ memcpy: 158.90 GB/s ( 5 thread)
77
+ memcpy: 163.00 GB/s ( 6 thread)
78
+ memcpy: 189.69 GB/s ( 7 thread)
79
+ memcpy: 197.15 GB/s ( 8 thread)
80
+ sum: -5120002007.000000
81
+
82
+
83
+ make -j && ./scripts/bench-all.sh 1
84
+
85
+ Running ggml_mul_mat benchmark with 1 threads
86
+
87
+ 64 x 64: Q4_0 245.8 GFLOPS (128 runs) | Q4_1 168.6 GFLOPS (128 runs)
88
+ 64 x 64: Q5_0 115.7 GFLOPS (128 runs) | Q5_1 125.9 GFLOPS (128 runs) | Q8_0 215.8 GFLOPS (128 runs)
89
+ 64 x 64: F16 139.5 GFLOPS (128 runs) | F32 337.2 GFLOPS (128 runs)
90
+ 128 x 128: Q4_0 494.8 GFLOPS (128 runs) | Q4_1 350.4 GFLOPS (128 runs)
91
+ 128 x 128: Q5_0 257.1 GFLOPS (128 runs) | Q5_1 261.4 GFLOPS (128 runs) | Q8_0 509.4 GFLOPS (128 runs)
92
+ 128 x 128: F16 302.3 GFLOPS (128 runs) | F32 672.8 GFLOPS (128 runs)
93
+ 256 x 256: Q4_0 795.7 GFLOPS (128 runs) | Q4_1 663.7 GFLOPS (128 runs)
94
+ 256 x 256: Q5_0 737.8 GFLOPS (128 runs) | Q5_1 757.6 GFLOPS (128 runs) | Q8_0 827.7 GFLOPS (128 runs)
95
+ 256 x 256: F16 872.6 GFLOPS (128 runs) | F32 956.3 GFLOPS (128 runs)
96
+ 512 x 512: Q4_0 1188.0 GFLOPS (128 runs) | Q4_1 1085.0 GFLOPS (128 runs)
97
+ 512 x 512: Q5_0 1421.1 GFLOPS (128 runs) | Q5_1 1454.9 GFLOPS (128 runs) | Q8_0 1191.4 GFLOPS (128 runs)
98
+ 512 x 512: F16 1577.4 GFLOPS (128 runs) | F32 1982.0 GFLOPS (128 runs)
99
+ 1024 x 1024: Q4_0 2342.6 GFLOPS (128 runs) | Q4_1 1955.8 GFLOPS (128 runs)
100
+ 1024 x 1024: Q5_0 2306.7 GFLOPS (128 runs) | Q5_1 2217.0 GFLOPS (128 runs) | Q8_0 2230.7 GFLOPS (128 runs)
101
+ 1024 x 1024: F16 2593.8 GFLOPS (128 runs) | F32 3269.0 GFLOPS (128 runs)
102
+ 2048 x 2048: Q4_0 3735.7 GFLOPS (128 runs) | Q4_1 3205.3 GFLOPS (128 runs)
103
+ 2048 x 2048: Q5_0 3584.5 GFLOPS (128 runs) | Q5_1 3621.7 GFLOPS (128 runs) | Q8_0 3622.3 GFLOPS (128 runs)
104
+ 2048 x 2048: F16 3763.6 GFLOPS (128 runs) | F32 4153.3 GFLOPS (128 runs)
105
+ 4096 x 4096: Q4_0 3891.1 GFLOPS ( 29 runs) | Q4_1 3554.0 GFLOPS ( 26 runs)
106
+ 4096 x 4096: Q5_0 3753.1 GFLOPS ( 28 runs) | Q5_1 3750.1 GFLOPS ( 28 runs) | Q8_0 3768.5 GFLOPS ( 28 runs)
107
+ 4096 x 4096: F16 3864.2 GFLOPS ( 29 runs) | F32 3970.5 GFLOPS ( 29 runs)
108
+
109
+
110
+ make -j && ./scripts/bench-all.sh 1 1 0
111
+
112
+ | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
113
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
114
+ | M2 ULTRA | METAL | tiny | 1 | 0 | 12.32 | 1.35 | 0.49 | 0.01 | 22c96b4 |
115
+ | M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 11.65 | 1.30 | 0.51 | 0.01 | 22c96b4 |
116
+ | M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 12.08 | 1.30 | 0.51 | 0.01 | 22c96b4 |
117
+ | M2 ULTRA | METAL | base | 1 | 0 | 17.58 | 1.90 | 0.76 | 0.02 | 22c96b4 |
118
+ | M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 18.89 | 1.86 | 0.79 | 0.02 | 22c96b4 |
119
+ | M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 20.69 | 1.88 | 0.79 | 0.02 | 22c96b4 |
120
+ | M2 ULTRA | METAL | small | 1 | 0 | 49.32 | 3.85 | 1.71 | 0.05 | 22c96b4 |
121
+ | M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 54.91 | 3.81 | 1.82 | 0.06 | 22c96b4 |
122
+ | M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 54.92 | 3.81 | 1.79 | 0.06 | 22c96b4 |
123
+ | M2 ULTRA | METAL | medium | 1 | 0 | 134.34 | 8.04 | 3.82 | 0.13 | 22c96b4 |
124
+ | M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 151.68 | 7.59 | 4.07 | 0.14 | 22c96b4 |
125
+ | M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 151.58 | 7.67 | 4.07 | 0.14 | 22c96b4 |
126
+ | M2 ULTRA | METAL | medium-dis | 1 | 0 | 120.82 | 1.07 | 0.41 | 0.02 | 22c96b4 |
127
+ | M2 ULTRA | METAL | large-v2 | 1 | 0 | 235.63 | 12.27 | 5.85 | 0.22 | 22c96b4 |
128
+ | M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 273.38 | 11.17 | 6.40 | 0.26 | 22c96b4 |
129
+ | M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 272.44 | 11.32 | 6.29 | 0.26 | 22c96b4 |
130
+ | M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 212.51 | 1.20 | 0.47 | 0.02 | 22c96b4 |
131
+
132
+
133
+ make -j && ./scripts/bench-all.sh 1 1 1
134
+
135
+ | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
136
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
137
+ | M2 ULTRA | METAL | tiny | 1 | 1 | 9.07 | 1.33 | 0.45 | 0.01 | 22c96b4 |
138
+ | M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 9.74 | 1.33 | 0.47 | 0.01 | 22c96b4 |
139
+ | M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 8.93 | 1.31 | 0.46 | 0.01 | 22c96b4 |
140
+ | M2 ULTRA | METAL | base | 1 | 1 | 15.75 | 1.87 | 0.71 | 0.02 | 22c96b4 |
141
+ | M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 17.04 | 1.83 | 0.74 | 0.02 | 22c96b4 |
142
+ | M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 17.17 | 1.83 | 0.74 | 0.02 | 22c96b4 |
143
+ | M2 ULTRA | METAL | small | 1 | 1 | 42.33 | 3.64 | 1.60 | 0.05 | 22c96b4 |
144
+ | M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 47.61 | 3.63 | 1.70 | 0.05 | 22c96b4 |
145
+ | M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 47.70 | 3.66 | 1.68 | 0.05 | 22c96b4 |
146
+ | M2 ULTRA | METAL | medium | 1 | 1 | 114.42 | 7.53 | 3.55 | 0.11 | 22c96b4 |
147
+ | M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 132.63 | 7.02 | 3.77 | 0.13 | 22c96b4 |
148
+ | M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 132.28 | 7.10 | 3.76 | 0.13 | 22c96b4 |
149
+ | M2 ULTRA | METAL | medium-dis | 1 | 1 | 102.34 | 1.01 | 0.42 | 0.01 | 22c96b4 |
150
+ | M2 ULTRA | METAL | large-v2 | 1 | 1 | 203.01 | 11.03 | 5.45 | 0.20 | 22c96b4 |
151
+ | M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 240.05 | 10.18 | 5.98 | 0.23 | 22c96b4 |
152
+ | M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 239.22 | 10.23 | 5.87 | 0.23 | 22c96b4 |
153
+ | M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 181.14 | 1.14 | 0.48 | 0.02 | 22c96b4 |
154
+
155
+
156
+
157
+ ## Ryzen 9 5950X + RTX 2060
158
+
159
+ make -j && ./scripts/bench-all.sh 8 0 0
160
+
161
+ Running memcpy benchmark
162
+
163
+ memcpy: 12.36 GB/s (heat-up)
164
+ memcpy: 12.33 GB/s ( 1 thread)
165
+ memcpy: 12.38 GB/s ( 1 thread)
166
+ memcpy: 14.48 GB/s ( 2 thread)
167
+ memcpy: 15.00 GB/s ( 3 thread)
168
+ memcpy: 14.77 GB/s ( 4 thread)
169
+ memcpy: 14.60 GB/s ( 5 thread)
170
+ memcpy: 14.57 GB/s ( 6 thread)
171
+ memcpy: 14.34 GB/s ( 7 thread)
172
+ memcpy: 14.40 GB/s ( 8 thread)
173
+ sum: -5119998076.000000
174
+
175
+ Running ggml_mul_mat benchmark with 8 threads
176
+
177
+ 64 x 64: Q4_0 3.1 GFLOPS (128 runs) | Q4_1 3.1 GFLOPS (128 runs)
178
+ 64 x 64: Q5_0 3.0 GFLOPS (128 runs) | Q5_1 2.9 GFLOPS (128 runs) | Q8_0 3.1 GFLOPS (128 runs)
179
+ 64 x 64: F16 3.0 GFLOPS (128 runs) | F32 3.0 GFLOPS (128 runs)
180
+ 128 x 128: Q4_0 21.1 GFLOPS (128 runs) | Q4_1 20.3 GFLOPS (128 runs)
181
+ 128 x 128: Q5_0 20.6 GFLOPS (128 runs) | Q5_1 20.4 GFLOPS (128 runs) | Q8_0 22.1 GFLOPS (128 runs)
182
+ 128 x 128: F16 21.7 GFLOPS (128 runs) | F32 21.7 GFLOPS (128 runs)
183
+ 256 x 256: Q4_0 105.7 GFLOPS (128 runs) | Q4_1 94.4 GFLOPS (128 runs)
184
+ 256 x 256: Q5_0 94.8 GFLOPS (128 runs) | Q5_1 87.5 GFLOPS (128 runs) | Q8_0 107.2 GFLOPS (128 runs)
185
+ 256 x 256: F16 95.1 GFLOPS (128 runs) | F32 94.3 GFLOPS (128 runs)
186
+ 512 x 512: Q4_0 214.7 GFLOPS (128 runs) | Q4_1 189.8 GFLOPS (128 runs)
187
+ 512 x 512: Q5_0 187.7 GFLOPS (128 runs) | Q5_1 176.2 GFLOPS (128 runs) | Q8_0 252.2 GFLOPS (128 runs)
188
+ 512 x 512: F16 220.8 GFLOPS (128 runs) | F32 218.3 GFLOPS (128 runs)
189
+ 1024 x 1024: Q4_0 333.7 GFLOPS (128 runs) | Q4_1 305.8 GFLOPS (128 runs)
190
+ 1024 x 1024: Q5_0 283.2 GFLOPS (128 runs) | Q5_1 268.2 GFLOPS (125 runs) | Q8_0 394.1 GFLOPS (128 runs)
191
+ 1024 x 1024: F16 355.0 GFLOPS (128 runs) | F32 313.0 GFLOPS (128 runs)
192
+ 2048 x 2048: Q4_0 395.0 GFLOPS ( 23 runs) | Q4_1 380.6 GFLOPS ( 23 runs)
193
+ 2048 x 2048: Q5_0 336.6 GFLOPS ( 20 runs) | Q5_1 318.4 GFLOPS ( 19 runs) | Q8_0 482.6 GFLOPS ( 29 runs)
194
+ 2048 x 2048: F16 424.5 GFLOPS ( 25 runs) | F32 337.7 GFLOPS ( 20 runs)
195
+ 4096 x 4096: Q4_0 412.8 GFLOPS ( 4 runs) | Q4_1 405.1 GFLOPS ( 3 runs)
196
+ 4096 x 4096: Q5_0 346.0 GFLOPS ( 3 runs) | Q5_1 334.6 GFLOPS ( 3 runs) | Q8_0 502.6 GFLOPS ( 4 runs)
197
+ 4096 x 4096: F16 412.5 GFLOPS ( 4 runs) | F32 274.0 GFLOPS ( 3 runs)
198
+
199
+ | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
200
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
201
+ | Ryzen 9 5950X | AVX2 | tiny | 8 | 0 | 195.29 | 1.57 | 0.51 | 0.26 | 22c96b4 |
202
+ | Ryzen 9 5950X | AVX2 | tiny-q5_0 | 8 | 0 | 213.33 | 1.10 | 0.50 | 0.30 | 22c96b4 |
203
+ | Ryzen 9 5950X | AVX2 | tiny-q5_1 | 8 | 0 | 219.38 | 1.18 | 0.53 | 0.32 | 22c96b4 |
204
+ | Ryzen 9 5950X | AVX2 | base | 8 | 0 | 424.85 | 3.71 | 1.03 | 0.46 | 22c96b4 |
205
+ | Ryzen 9 5950X | AVX2 | base-q5_0 | 8 | 0 | 473.61 | 1.81 | 0.82 | 0.52 | 22c96b4 |
206
+ | Ryzen 9 5950X | AVX2 | base-q5_1 | 8 | 0 | 484.14 | 1.92 | 0.85 | 0.56 | 22c96b4 |
207
+ | Ryzen 9 5950X | AVX2 | small | 8 | 0 | 1458.32 | 12.66 | 3.09 | 1.26 | 22c96b4 |
208
+ | Ryzen 9 5950X | AVX2 | small-q5_0 | 8 | 0 | 1673.22 | 6.42 | 2.18 | 1.45 | 22c96b4 |
209
+ | Ryzen 9 5950X | AVX2 | small-q5_1 | 8 | 0 | 1724.78 | 6.72 | 2.32 | 1.52 | 22c96b4 |
210
+ | Ryzen 9 5950X | AVX2 | medium | 8 | 0 | 4333.87 | 36.80 | 8.56 | 3.37 | 22c96b4 |
211
+ | Ryzen 9 5950X | AVX2 | medium-q5_0 | 8 | 0 | 5194.09 | 19.21 | 5.71 | 3.97 | 22c96b4 |
212
+ | Ryzen 9 5950X | AVX2 | medium-q5_1 | 8 | 0 | 5450.39 | 20.01 | 5.99 | 4.17 | 22c96b4 |
213
+ | Ryzen 9 5950X | AVX2 | medium-dis | 8 | 0 | 3995.19 | 5.08 | 1.21 | 0.55 | 22c96b4 |
214
+ | Ryzen 9 5950X | AVX2 | large-v2 | 8 | 0 | 8056.16 | 69.74 | 16.11 | 6.13 | 22c96b4 |
215
+ | Ryzen 9 5950X | AVX2 | large-v2-q5_0 | 8 | 0 | 9799.58 | 35.16 | 10.49 | 7.28 | 22c96b4 |
216
+ | Ryzen 9 5950X | AVX2 | large-v2-q5_1 | 8 | 0 | ms | 36.74 | 11.02 | 7.65 | 22c96b4 |
217
+ | Ryzen 9 5950X | AVX2 | large-v2-dis | 8 | 0 | 7490.03 | 7.40 | 1.70 | 0.72 | 22c96b4 |
218
+
219
+
220
+ WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
221
+
222
+ | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
223
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
224
+ | RTX 2060 | AVX2 CUDA | tiny | 8 | 0 | 12.54 | 0.93 | 0.29 | 0.02 | 22c96b4 |
225
+ | RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 0 | 12.73 | 0.98 | 0.24 | 0.02 | 22c96b4 |
226
+ | RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 0 | 12.72 | 0.99 | 0.24 | 0.02 | 22c96b4 |
227
+ | RTX 2060 | AVX2 CUDA | base | 8 | 0 | 24.14 | 1.28 | 0.41 | 0.03 | 22c96b4 |
228
+ | RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 0 | 24.58 | 1.38 | 0.35 | 0.03 | 22c96b4 |
229
+ | RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 0 | 24.58 | 1.37 | 0.35 | 0.03 | 22c96b4 |
230
+ | RTX 2060 | AVX2 CUDA | small | 8 | 0 | 74.70 | 2.91 | 0.84 | 0.07 | 22c96b4 |
231
+ | RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 0 | 76.12 | 2.84 | 0.77 | 0.08 | 22c96b4 |
232
+ | RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 0 | 76.14 | 2.84 | 0.76 | 0.08 | 22c96b4 |
233
+ | RTX 2060 | AVX2 CUDA | medium | 8 | 0 | 200.69 | 6.46 | 1.83 | 0.17 | 22c96b4 |
234
+ | RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 0 | 204.80 | 5.90 | 1.65 | 0.19 | 22c96b4 |
235
+ | RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 0 | 205.61 | 5.85 | 1.61 | 0.19 | 22c96b4 |
236
+ | RTX 2060 | AVX2 CUDA | medium-dis | 8 | 0 | 186.17 | 0.86 | 0.24 | 0.02 | 22c96b4 |
237
+ | RTX 2060 | AVX2 CUDA | large-v2 | 8 | 0 | 347.22 | 10.36 | 2.82 | 0.29 | 22c96b4 |
238
+ | RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 0 | 357.06 | 8.81 | 2.58 | 0.34 | 22c96b4 |
239
+ | RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 0 | 356.97 | 8.62 | 2.49 | 0.33 | 22c96b4 |
240
+ | RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 0 | 318.05 | 1.03 | 0.34 | 0.04 | 22c96b4 |
241
+
242
+
243
+ WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
244
+
245
+ | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
246
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
247
+ | RTX 2060 | AVX2 CUDA | tiny | 8 | 1 | 7.21 | 0.76 | 0.29 | 0.02 | 22c96b4 |
248
+ | RTX 2060 | AVX2 CUDA | tiny-q5_0 | 8 | 1 | 7.42 | 0.82 | 0.18 | 0.02 | 22c96b4 |
249
+ | RTX 2060 | AVX2 CUDA | tiny-q5_1 | 8 | 1 | 7.38 | 0.82 | 0.18 | 0.02 | 22c96b4 |
250
+ | RTX 2060 | AVX2 CUDA | base | 8 | 1 | 13.49 | 1.04 | 0.36 | 0.02 | 22c96b4 |
251
+ | RTX 2060 | AVX2 CUDA | base-q5_0 | 8 | 1 | 13.94 | 1.13 | 0.26 | 0.03 | 22c96b4 |
252
+ | RTX 2060 | AVX2 CUDA | base-q5_1 | 8 | 1 | 13.94 | 1.14 | 0.26 | 0.03 | 22c96b4 |
253
+ | RTX 2060 | AVX2 CUDA | small | 8 | 1 | 42.81 | 2.33 | 0.69 | 0.05 | 22c96b4 |
254
+ | RTX 2060 | AVX2 CUDA | small-q5_0 | 8 | 1 | 44.43 | 2.25 | 0.59 | 0.06 | 22c96b4 |
255
+ | RTX 2060 | AVX2 CUDA | small-q5_1 | 8 | 1 | 44.11 | 2.24 | 0.58 | 0.06 | 22c96b4 |
256
+ | RTX 2060 | AVX2 CUDA | medium | 8 | 1 | 115.47 | 5.17 | 1.45 | 0.11 | 22c96b4 |
257
+ | RTX 2060 | AVX2 CUDA | medium-q5_0 | 8 | 1 | 120.37 | 4.63 | 1.25 | 0.13 | 22c96b4 |
258
+ | RTX 2060 | AVX2 CUDA | medium-q5_1 | 8 | 1 | 120.28 | 4.55 | 1.21 | 0.13 | 22c96b4 |
259
+ | RTX 2060 | AVX2 CUDA | medium-dis | 8 | 1 | 101.69 | 0.75 | 0.20 | 0.02 | 22c96b4 |
260
+ | RTX 2060 | AVX2 CUDA | large-v2 | 8 | 1 | 205.67 | 8.49 | 2.19 | 0.18 | 22c96b4 |
261
+ | RTX 2060 | AVX2 CUDA | large-v2-q5_0 | 8 | 1 | 214.07 | 6.88 | 1.94 | 0.22 | 22c96b4 |
262
+ | RTX 2060 | AVX2 CUDA | large-v2-q5_1 | 8 | 1 | 213.98 | 6.70 | 1.86 | 0.22 | 22c96b4 |
263
+ | RTX 2060 | AVX2 CUDA | large-v2-dis | 8 | 1 | 176.71 | 0.91 | 0.31 | 0.03 | 22c96b4 |
264
+
265
+
266
+
267
+
268
+ # V100
269
+
270
+ WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 0
271
+
272
+ | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
273
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
274
+ | V100 | AVX2 CUDA | tiny | 1 | 0 | 6.21 | 1.11 | 0.30 | 0.02 | 22c96b4 |
275
+ | V100 | AVX2 CUDA | tiny-q5_1 | 1 | 0 | 5.97 | 1.10 | 0.26 | 0.02 | 22c96b4 |
276
+ | V100 | AVX2 CUDA | base | 1 | 0 | 10.95 | 1.47 | 0.42 | 0.03 | 22c96b4 |
277
+ | V100 | AVX2 CUDA | base-q5_1 | 1 | 0 | 11.13 | 1.53 | 0.36 | 0.03 | 22c96b4 |
278
+ | V100 | AVX2 CUDA | small | 1 | 0 | 31.57 | 2.96 | 0.84 | 0.05 | 22c96b4 |
279
+ | V100 | AVX2 CUDA | small-q5_1 | 1 | 0 | 32.19 | 3.14 | 0.75 | 0.05 | 22c96b4 |
280
+ | V100 | AVX2 CUDA | medium | 1 | 0 | 85.88 | 6.49 | 1.80 | 0.10 | 22c96b4 |
281
+ | V100 | AVX2 CUDA | medium-q5_0 | 1 | 0 | 87.53 | 5.82 | 1.37 | 0.10 | 22c96b4 |
282
+ | V100 | AVX2 CUDA | large-v2 | 1 | 0 | 142.23 | 8.92 | 2.62 | 0.15 | 22c96b4 |
283
+
284
+
285
+ WHISPER_CUDA=1 make -j && ./scripts/bench-all.sh 8 1 1
286
+
287
+ | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit |
288
+ | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
289
+ | V100 | AVX2 CUDA | tiny | 1 | 1 | 3.96 | 0.82 | 0.24 | 0.02 | 22c96b4 |
290
+ | V100 | AVX2 CUDA | tiny-q5_1 | 1 | 1 | 4.05 | 0.85 | 0.18 | 0.02 | 22c96b4 |
291
+ | V100 | AVX2 CUDA | base | 1 | 1 | 7.21 | 1.16 | 0.36 | 0.02 | 22c96b4 |
292
+ | V100 | AVX2 CUDA | base-q5_1 | 1 | 1 | 7.39 | 1.21 | 0.26 | 0.02 | 22c96b4 |
293
+ | V100 | AVX2 CUDA | small | 1 | 1 | 19.81 | 2.41 | 0.71 | 0.04 | 22c96b4 |
294
+ | V100 | AVX2 CUDA | small-q5_1 | 1 | 1 | 20.50 | 2.31 | 0.51 | 0.04 | 22c96b4 |
295
+ | V100 | AVX2 CUDA | medium | 1 | 1 | 56.02 | 4.89 | 1.44 | 0.07 | 22c96b4 |
296
+ | V100 | AVX2 CUDA | medium-q5_0 | 1 | 1 | 57.85 | 4.73 | 1.09 | 0.08 | 22c96b4 |
297
+ | V100 | AVX2 CUDA | large-v2 | 1 | 1 | 92.73 | 7.18 | 2.14 | 0.10 | 22c96b4 |
298
+
scripts/bench-all.sh CHANGED
@@ -2,7 +2,7 @@
2
 
3
  # Helper script to run the bench tool on all models and print the results in share-able format
4
 
5
- printf "Usage: ./bench.sh [n_threads] [encoder-only]\n"
6
 
7
  if [ -z "$1" ]; then
8
  n_threads=4
@@ -11,12 +11,19 @@ else
11
  fi
12
 
13
  encoder_only=0
14
- if [ -z "$2" ]; then
15
  encoder_only=0
16
  else
17
  encoder_only=$2
18
  fi
19
 
 
 
 
 
 
 
 
20
  models=( \
21
  "tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
22
  "base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
@@ -44,13 +51,19 @@ if [ "$encoder_only" -eq 0 ]; then
44
  printf "\n"
45
  fi
46
 
47
- printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "Bch5" "PP" "Commit"
48
- printf "| %6s | %6s | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
 
 
 
 
 
 
49
 
50
  for model in "${models[@]}"; do
51
  # actual run
52
  # store stderr output in a variable in order to parse it later
53
- output=$(./bench -m ./models/ggml-$model.bin -t $n_threads 2>&1)
54
  ret=$?
55
 
56
  # parse the output:
@@ -95,6 +108,6 @@ for model in "${models[@]}"; do
95
  commit=$(git rev-parse --short HEAD)
96
 
97
  if [ $ret -eq 0 ]; then
98
- printf "| <todo> | <todo> | %16s | %13s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
99
  fi
100
  done
 
2
 
3
  # Helper script to run the bench tool on all models and print the results in share-able format
4
 
5
+ printf "Usage: ./bench.sh [n_threads] [encoder-only] [flash-attn]\n"
6
 
7
  if [ -z "$1" ]; then
8
  n_threads=4
 
11
  fi
12
 
13
  encoder_only=0
14
+ if [ -z "$2" ] || [ "$2" -eq 0 ]; then
15
  encoder_only=0
16
  else
17
  encoder_only=$2
18
  fi
19
 
20
+ fattn=""
21
+ if [ -z "$3" ] || [ "$3" -eq 0 ]; then
22
+ fattn=""
23
+ else
24
+ fattn="-fa"
25
+ fi
26
+
27
  models=( \
28
  "tiny" "tiny-q4_0" "tiny-q4_1" "tiny-q5_0" "tiny-q5_1" "tiny-q8_0" \
29
  "base" "base-q4_0" "base-q4_1" "base-q5_0" "base-q5_1" "base-q8_0" \
 
51
  printf "\n"
52
  fi
53
 
54
+ if [ "$fattn" == "-fa" ]; then
55
+ fattn_i=1
56
+ else
57
+ fattn_i=0
58
+ fi
59
+
60
+ printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "FA" "Enc." "Dec." "Bch5" "PP" "Commit"
61
+ printf "| %6s | %6s | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---" "---"
62
 
63
  for model in "${models[@]}"; do
64
  # actual run
65
  # store stderr output in a variable in order to parse it later
66
+ output=$(./bench -m ./models/ggml-$model.bin -t $n_threads $fattn 2>&1)
67
  ret=$?
68
 
69
  # parse the output:
 
108
  commit=$(git rev-parse --short HEAD)
109
 
110
  if [ $ret -eq 0 ]; then
111
+ printf "| <todo> | <todo> | %16s | %13s | %3s | %3s | %7s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$fattn_i" "$encode_time" "$decode_time" "$batchd_time" "$prompt_time" "$commit"
112
  fi
113
  done
whisper.cpp CHANGED
@@ -809,14 +809,15 @@ struct whisper_state {
809
  // shared between all decoders
810
  whisper_kv_cache kv_cross;
811
 
 
 
 
812
  whisper_mel mel;
813
 
814
  whisper_batch batch;
815
 
816
  whisper_decoder decoders[WHISPER_MAX_DECODERS];
817
 
818
- ggml_backend_t backend = nullptr;
819
-
820
  // ggml-alloc:
821
  // - stores meta info about the intermediate tensors into the `meta` buffers
822
  // - stores the actual tensor data into the `data` buffers
@@ -902,14 +903,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
902
  }
903
 
904
  static bool kv_cache_init(
905
- const struct whisper_hparams & hparams,
906
  struct whisper_kv_cache & cache,
907
  ggml_backend_t backend,
908
  ggml_type wtype,
 
 
909
  int n_ctx) {
910
- const int64_t n_text_state = hparams.n_text_state;
911
- const int64_t n_text_layer = hparams.n_text_layer;
912
-
913
  const int64_t n_mem = n_text_layer*n_ctx;
914
  const int64_t n_elements = n_text_state*n_mem;
915
 
@@ -941,6 +940,8 @@ static bool kv_cache_init(
941
  return false;
942
  }
943
 
 
 
944
  return true;
945
  }
946
 
@@ -1068,6 +1069,26 @@ static void whisper_kv_cache_seq_cp(
1068
  }
1069
  }
1070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1071
  // [EXPERIMENTAL] Token-level timestamps with DTW
1072
  static bool aheads_masks_init(
1073
  const whisper_context_params & cparams,
@@ -1872,6 +1893,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1872
  const int n_head = hparams.n_audio_head;
1873
  const int n_layer = hparams.n_audio_layer;
1874
 
 
 
 
 
 
 
 
 
1875
  struct ggml_init_params params = {
1876
  /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1877
  /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
@@ -1884,7 +1913,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1884
 
1885
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1886
 
1887
- const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
1888
 
1889
  // ===================================================================
1890
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
@@ -1934,14 +1963,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1934
 
1935
  Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
1936
 
1937
- //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
1938
 
1939
  // note: no bias for Key
1940
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1941
  layer.attn_k_w,
1942
  cur);
1943
 
1944
- //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
1945
 
1946
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1947
  layer.attn_v_w,
@@ -1955,38 +1984,61 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
1955
  ggml_permute(ctx0,
1956
  ggml_cpy(ctx0,
1957
  Qcur,
1958
- ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
1959
  0, 2, 1, 3);
1960
 
1961
- struct ggml_tensor * K =
1962
- ggml_permute(ctx0,
1963
- ggml_cpy(ctx0,
1964
- Kcur,
1965
- ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
1966
- 0, 2, 1, 3);
1967
-
1968
- // K * Q
1969
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
1970
 
1971
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
 
 
 
 
 
1972
 
1973
- struct ggml_tensor * V =
1974
- ggml_cpy(ctx0,
1975
- ggml_permute(ctx0,
1976
- ggml_reshape_3d(ctx0,
1977
- Vcur,
1978
- n_state/n_head, n_head, n_ctx),
1979
- 1, 2, 0, 3),
1980
- ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
1981
- );
1982
 
1983
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
1984
 
1985
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
1986
-
1987
- cur = ggml_cpy(ctx0,
1988
- KQV_merged,
1989
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1990
  }
1991
 
1992
  // projection
@@ -2085,6 +2137,10 @@ static struct ggml_cgraph * whisper_build_graph_cross(
2085
  const int n_state = hparams.n_audio_state;
2086
  const int n_head = hparams.n_audio_head;
2087
 
 
 
 
 
2088
  struct ggml_init_params params = {
2089
  /*.mem_size =*/ wstate.alloc_cross.meta.size(),
2090
  /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
@@ -2097,18 +2153,18 @@ static struct ggml_cgraph * whisper_build_graph_cross(
2097
 
2098
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
2099
 
2100
- const float Kscale = pow(float(n_state) / n_head, -0.25);
2101
 
2102
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
2103
  auto & layer = model.layers_decoder[il];
2104
 
2105
- struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
2106
  layer.cross_attn_k_w,
2107
  cur);
2108
 
2109
  Kcross = ggml_scale(ctx0, Kcross, Kscale);
2110
 
2111
- struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
2112
  layer.cross_attn_v_w,
2113
  cur);
2114
 
@@ -2116,15 +2172,25 @@ static struct ggml_cgraph * whisper_build_graph_cross(
2116
  Vcross,
2117
  layer.cross_attn_v_b);
2118
 
2119
- Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
 
2120
 
2121
- struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k,
2122
- n_state*n_ctx,
2123
- (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
2124
 
2125
- struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
2126
- ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
2127
- (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
 
 
 
 
 
 
 
 
 
2128
 
2129
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
2130
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
@@ -2195,7 +2261,7 @@ static bool whisper_encode_internal(
2195
  }
2196
 
2197
  if (!whisper_encode_external(wstate)) {
2198
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2199
  return false;
2200
  }
2201
  } else {
@@ -2218,7 +2284,7 @@ static bool whisper_encode_internal(
2218
  return false;
2219
  }
2220
 
2221
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2222
  return false;
2223
  }
2224
  }
@@ -2234,7 +2300,7 @@ static bool whisper_encode_internal(
2234
  return false;
2235
  }
2236
 
2237
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2238
  return false;
2239
  }
2240
  }
@@ -2263,11 +2329,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2263
  const int n_head = hparams.n_text_head;
2264
  const int n_layer = hparams.n_text_layer;
2265
 
 
 
2266
  const int n_tokens = batch.n_tokens;
2267
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2268
 
2269
- const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
2270
- const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
 
 
2271
 
2272
  //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2273
 
@@ -2289,12 +2359,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2289
  ggml_set_name(position, "position");
2290
  ggml_set_input(position);
2291
 
2292
- const float KQscale = pow(float(n_state)/n_head, -0.25);
2293
 
2294
- struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
2295
  ggml_set_name(KQ_mask, "KQ_mask");
2296
  ggml_set_input(KQ_mask);
2297
 
 
 
2298
  // token encoding + position encoding
2299
  struct ggml_tensor * cur =
2300
  ggml_add(ctx0,
@@ -2350,12 +2422,25 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2350
  Vcur,
2351
  layer.attn_v_b);
2352
 
2353
- Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
 
 
 
 
 
 
 
 
 
2354
 
2355
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2356
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2357
- ( n_ctx)*ggml_element_size(kv_self.v),
2358
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
 
 
2359
 
2360
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
2361
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2365,35 +2450,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2365
 
2366
  struct ggml_tensor * Q =
2367
  ggml_permute(ctx0,
2368
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2369
  0, 2, 1, 3);
2370
 
2371
  struct ggml_tensor * K =
2372
  ggml_view_3d(ctx0, kv_self.k,
2373
- n_state/n_head, n_kv, n_head,
2374
  ggml_element_size(kv_self.k)*n_state,
2375
- ggml_element_size(kv_self.k)*n_state/n_head,
2376
  ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2377
 
2378
- // K * Q
2379
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
 
 
 
 
 
 
 
 
 
 
 
2380
 
2381
- struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
2382
 
2383
- struct ggml_tensor * V =
2384
- ggml_view_3d(ctx0, kv_self.v,
2385
- n_kv, n_state/n_head, n_head,
2386
- n_ctx*ggml_element_size(kv_self.v),
2387
- n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
2388
- n_ctx*ggml_element_size(kv_self.v)*n_state*il);
2389
 
2390
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2391
 
2392
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2393
 
2394
- cur = ggml_cpy(ctx0,
2395
- KQV_merged,
2396
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
 
2397
  }
2398
 
2399
  // projection
@@ -2432,80 +2530,77 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
2432
  Qcur,
2433
  layer.cross_attn_q_b);
2434
 
2435
- Qcur = ggml_scale(ctx0, Qcur, KQscale);
2436
-
2437
- // Kcross is already scaled
2438
- struct ggml_tensor * Kcross =
2439
- ggml_view_3d(ctx0, wstate.kv_cross.k,
2440
- n_state/n_head, n_audio_ctx, n_head,
2441
- ggml_element_size(wstate.kv_cross.k)*n_state,
2442
- ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
2443
- ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2444
-
2445
- //struct ggml_tensor * Vcross =
2446
- // ggml_reshape_3d(ctx0,
2447
- // ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
2448
- // n_state/n_head, n_head, n_audio_ctx);
2449
-
2450
- //struct ggml_tensor * V_trans =
2451
- // ggml_cpy(ctx0,
2452
- // ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
2453
- // ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
2454
-
2455
- struct ggml_tensor * V =
2456
- ggml_view_3d(ctx0, wstate.kv_cross.v,
2457
- n_audio_ctx, n_state/n_head, n_head,
2458
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
2459
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
2460
- n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
2461
-
2462
- // ------
2463
-
2464
  struct ggml_tensor * Q =
2465
  ggml_permute(ctx0,
2466
- ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
2467
  0, 2, 1, 3);
2468
 
2469
- // K * Q
2470
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
2471
-
2472
- //struct ggml_tensor * KQ_scaled =
2473
- // ggml_scale(ctx0,
2474
- // KQ,
2475
- // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2476
- // );
2477
 
2478
- // no masking for cross-attention
2479
- //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
 
 
 
2480
 
2481
- struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
2482
 
2483
- // [EXPERIMENTAL] Token-level timestamps with DTW
2484
- if (wctx.params.dtw_token_timestamps) {
2485
- if (wstate.aheads_masks.m[il] != nullptr) {
2486
- struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
2487
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
2488
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
2489
- aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
2490
- aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
2491
- aheads_KQs = ggml_cont(ctx0, aheads_KQs);
2492
- aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
2493
- if (aheads_cross_QKs == NULL) {
2494
- aheads_cross_QKs = aheads_KQs;
2495
- } else {
2496
- aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2497
  }
2498
  }
2499
- }
2500
 
2501
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2502
 
2503
- struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2504
 
2505
- // cur = KQV_merged.contiguous().view(n_state, n_tokens)
2506
- cur = ggml_cpy(ctx0,
2507
- KQV_merged,
2508
- ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2509
  }
2510
 
2511
  // projection
@@ -2638,7 +2733,9 @@ static bool whisper_decode_internal(
2638
  return false;
2639
  }
2640
 
2641
- kv_self.n = whisper_kv_cache_cell_max(kv_self);
 
 
2642
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2643
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2644
  }
@@ -2672,9 +2769,10 @@ static bool whisper_decode_internal(
2672
  struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
2673
 
2674
  auto & kv_self = wstate.kv_self;
2675
- const int32_t n_kv = kv_self.n;
2676
 
2677
- wstate.inp_mask.resize(n_kv*n_tokens);
 
 
2678
 
2679
  float * data = wstate.inp_mask.data();
2680
  memset(data, 0, ggml_nbytes(KQ_mask));
@@ -2690,6 +2788,12 @@ static bool whisper_decode_internal(
2690
  }
2691
  }
2692
  }
 
 
 
 
 
 
2693
  }
2694
 
2695
  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
@@ -2697,7 +2801,7 @@ static bool whisper_decode_internal(
2697
 
2698
  logits = gf->nodes[gf->n_nodes - 1];
2699
 
2700
- if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
2701
  return false;
2702
  }
2703
  }
@@ -3144,18 +3248,14 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3144
 
3145
  whisper_state * state = new whisper_state;
3146
 
3147
- state->backend = whisper_backend_init(ctx->params);
3148
- if (!state->backend) {
3149
- WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
3150
- whisper_free_state(state);
3151
- return nullptr;
3152
- }
3153
-
3154
  // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3155
  // in theory, there can be a case where this is not enough, but in practice it should always be enough
3156
  const int factor = 3;
3157
 
3158
- if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
 
 
 
3159
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3160
  whisper_free_state(state);
3161
  return nullptr;
@@ -3166,7 +3266,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3166
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
3167
  }
3168
 
3169
- if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
 
 
 
3170
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3171
  whisper_free_state(state);
3172
  return nullptr;
@@ -3177,6 +3280,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
3177
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
3178
  }
3179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3180
  // [EXPERIMENTAL] Token-level timestamps with DTW
3181
  if (ctx->params.dtw_token_timestamps) {
3182
  if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
@@ -3347,6 +3464,7 @@ int whisper_ctx_init_openvino_encoder(
3347
  struct whisper_context_params whisper_context_default_params() {
3348
  struct whisper_context_params result = {
3349
  /*.use_gpu =*/ true,
 
3350
  /*.gpu_device =*/ 0,
3351
 
3352
  /*.dtw_token_timestamps =*/ false,
@@ -3445,6 +3563,16 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu
3445
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
3446
  ggml_time_init();
3447
 
 
 
 
 
 
 
 
 
 
 
3448
  whisper_context * ctx = new whisper_context;
3449
  ctx->params = params;
3450
 
@@ -3533,6 +3661,7 @@ void whisper_free_state(struct whisper_state * state) {
3533
  if (state) {
3534
  kv_cache_free(state->kv_self);
3535
  kv_cache_free(state->kv_cross);
 
3536
 
3537
  #ifdef WHISPER_USE_COREML
3538
  if (state->ctx_coreml != nullptr) {
@@ -3555,8 +3684,6 @@ void whisper_free_state(struct whisper_state * state) {
3555
  ggml_gallocr_free(state->alloc_cross.alloc);
3556
  ggml_gallocr_free(state->alloc_decode.alloc);
3557
 
3558
- ggml_backend_free(state->backend);
3559
-
3560
  // [EXPERIMENTAL] Token-level timestamps with DTW
3561
  aheads_masks_free(state->aheads_masks);
3562
 
 
809
  // shared between all decoders
810
  whisper_kv_cache kv_cross;
811
 
812
+ // padded buffer for flash-attention
813
+ whisper_kv_cache kv_pad;
814
+
815
  whisper_mel mel;
816
 
817
  whisper_batch batch;
818
 
819
  whisper_decoder decoders[WHISPER_MAX_DECODERS];
820
 
 
 
821
  // ggml-alloc:
822
  // - stores meta info about the intermediate tensors into the `meta` buffers
823
  // - stores the actual tensor data into the `data` buffers
 
903
  }
904
 
905
  static bool kv_cache_init(
 
906
  struct whisper_kv_cache & cache,
907
  ggml_backend_t backend,
908
  ggml_type wtype,
909
+ int64_t n_text_state,
910
+ int64_t n_text_layer,
911
  int n_ctx) {
 
 
 
912
  const int64_t n_mem = n_text_layer*n_ctx;
913
  const int64_t n_elements = n_text_state*n_mem;
914
 
 
940
  return false;
941
  }
942
 
943
+ ggml_backend_buffer_clear(cache.buffer, 0);
944
+
945
  return true;
946
  }
947
 
 
1069
  }
1070
  }
1071
 
1072
+ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
1073
+ if (!wctx.params.flash_attn) {
1074
+ return 1u;
1075
+ }
1076
+
1077
+ #ifdef GGML_USE_METAL
1078
+ if (ggml_backend_is_metal(wctx.backend)) {
1079
+ return 32u;
1080
+ }
1081
+ #endif
1082
+
1083
+ #ifdef GGML_USE_CUDA
1084
+ if (ggml_backend_is_cuda(wctx.backend)) {
1085
+ return 256u;
1086
+ }
1087
+ #endif
1088
+
1089
+ return 1u;
1090
+ }
1091
+
1092
  // [EXPERIMENTAL] Token-level timestamps with DTW
1093
  static bool aheads_masks_init(
1094
  const whisper_context_params & cparams,
 
1893
  const int n_head = hparams.n_audio_head;
1894
  const int n_layer = hparams.n_audio_layer;
1895
 
1896
+ const int n_state_head = n_state/n_head;
1897
+
1898
+ auto & kv_pad = wstate.kv_pad;
1899
+
1900
+ WHISPER_ASSERT(!!kv_pad.ctx);
1901
+
1902
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
1903
+
1904
  struct ggml_init_params params = {
1905
  /*.mem_size =*/ wstate.alloc_encode.meta.size(),
1906
  /*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
 
1913
 
1914
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
1915
 
1916
+ const float KQscale = 1.0f/sqrtf(float(n_state_head));
1917
 
1918
  // ===================================================================
1919
  // NOTE: experimenting with partial evaluation of the encoder (ignore)
 
1963
 
1964
  Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
1965
 
1966
+ //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
1967
 
1968
  // note: no bias for Key
1969
  struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
1970
  layer.attn_k_w,
1971
  cur);
1972
 
1973
+ //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
1974
 
1975
  struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
1976
  layer.attn_v_w,
 
1984
  ggml_permute(ctx0,
1985
  ggml_cpy(ctx0,
1986
  Qcur,
1987
+ ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
1988
  0, 2, 1, 3);
1989
 
1990
+ if (wctx.params.flash_attn) {
1991
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
1992
+ ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
 
 
 
 
 
 
1993
 
1994
+ struct ggml_tensor * K =
1995
+ ggml_view_3d(ctx0, kv_pad.k,
1996
+ n_state_head, n_ctx_pad, n_head,
1997
+ ggml_element_size(kv_pad.k)*n_state,
1998
+ ggml_element_size(kv_pad.k)*n_state_head,
1999
+ 0);
2000
 
2001
+ struct ggml_tensor * V =
2002
+ ggml_view_3d(ctx0, kv_pad.v,
2003
+ n_state_head, n_ctx_pad, n_head,
2004
+ ggml_element_size(kv_pad.v)*n_state,
2005
+ ggml_element_size(kv_pad.v)*n_state_head,
2006
+ 0);
 
 
 
2007
 
2008
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f);
2009
 
2010
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
2011
+ } else {
2012
+ struct ggml_tensor * K =
2013
+ ggml_permute(ctx0,
2014
+ ggml_cpy(ctx0,
2015
+ Kcur,
2016
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
2017
+ 0, 2, 1, 3);
2018
+
2019
+ // K * Q
2020
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2021
+
2022
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2023
+
2024
+ struct ggml_tensor * V =
2025
+ ggml_cpy(ctx0,
2026
+ ggml_permute(ctx0,
2027
+ ggml_reshape_3d(ctx0,
2028
+ Vcur,
2029
+ n_state_head, n_head, n_ctx),
2030
+ 1, 2, 0, 3),
2031
+ ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
2032
+ );
2033
+
2034
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2035
+
2036
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2037
+
2038
+ cur = ggml_cpy(ctx0,
2039
+ KQV_merged,
2040
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
2041
+ }
2042
  }
2043
 
2044
  // projection
 
2137
  const int n_state = hparams.n_audio_state;
2138
  const int n_head = hparams.n_audio_head;
2139
 
2140
+ const int n_state_head = n_state/n_head;
2141
+
2142
+ const int n_ctx_pad = GGML_PAD(n_ctx, 256);
2143
+
2144
  struct ggml_init_params params = {
2145
  /*.mem_size =*/ wstate.alloc_cross.meta.size(),
2146
  /*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
 
2153
 
2154
  struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
2155
 
2156
+ const float Kscale = pow(float(n_state_head), -0.25);
2157
 
2158
  for (int il = 0; il < model.hparams.n_text_layer; ++il) {
2159
  auto & layer = model.layers_decoder[il];
2160
 
2161
+ struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
2162
  layer.cross_attn_k_w,
2163
  cur);
2164
 
2165
  Kcross = ggml_scale(ctx0, Kcross, Kscale);
2166
 
2167
+ struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
2168
  layer.cross_attn_v_w,
2169
  cur);
2170
 
 
2172
  Vcross,
2173
  layer.cross_attn_v_b);
2174
 
2175
+ struct ggml_tensor * k;
2176
+ struct ggml_tensor * v;
2177
 
2178
+ if (wctx.params.flash_attn) {
2179
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
2180
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
2181
 
2182
+ v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
2183
+ (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
2184
+ } else {
2185
+ Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
2186
+
2187
+ k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
2188
+ (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
2189
+
2190
+ v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
2191
+ ( n_ctx)*ggml_element_size(wstate.kv_cross.v),
2192
+ (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
2193
+ }
2194
 
2195
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
2196
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
 
2261
  }
2262
 
2263
  if (!whisper_encode_external(wstate)) {
2264
+ if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
2265
  return false;
2266
  }
2267
  } else {
 
2284
  return false;
2285
  }
2286
 
2287
+ if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
2288
  return false;
2289
  }
2290
  }
 
2300
  return false;
2301
  }
2302
 
2303
+ if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
2304
  return false;
2305
  }
2306
  }
 
2329
  const int n_head = hparams.n_text_head;
2330
  const int n_layer = hparams.n_text_layer;
2331
 
2332
+ const int n_state_head = n_state/n_head;
2333
+
2334
  const int n_tokens = batch.n_tokens;
2335
  const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
2336
 
2337
+ const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
2338
+
2339
+ const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
2340
+ const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
2341
 
2342
  //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
2343
 
 
2359
  ggml_set_name(position, "position");
2360
  ggml_set_input(position);
2361
 
2362
+ const float KQscale = pow(float(n_state_head), -0.25);
2363
 
2364
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
2365
  ggml_set_name(KQ_mask, "KQ_mask");
2366
  ggml_set_input(KQ_mask);
2367
 
2368
+ struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
2369
+
2370
  // token encoding + position encoding
2371
  struct ggml_tensor * cur =
2372
  ggml_add(ctx0,
 
2422
  Vcur,
2423
  layer.attn_v_b);
2424
 
2425
+ struct ggml_tensor * k;
2426
+ struct ggml_tensor * v;
2427
+
2428
+ if (wctx.params.flash_attn) {
2429
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2430
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2431
+
2432
+ v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
2433
+ (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
2434
+ } else {
2435
+ Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
2436
 
2437
+ k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
2438
+ (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
2439
+
2440
+ v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
2441
+ ( n_ctx)*ggml_element_size(kv_self.v),
2442
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
2443
+ }
2444
 
2445
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
2446
  ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
 
2450
 
2451
  struct ggml_tensor * Q =
2452
  ggml_permute(ctx0,
2453
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2454
  0, 2, 1, 3);
2455
 
2456
  struct ggml_tensor * K =
2457
  ggml_view_3d(ctx0, kv_self.k,
2458
+ n_state_head, n_kv, n_head,
2459
  ggml_element_size(kv_self.k)*n_state,
2460
+ ggml_element_size(kv_self.k)*n_state_head,
2461
  ggml_element_size(kv_self.k)*n_state*n_ctx*il);
2462
 
2463
+ if (wctx.params.flash_attn) {
2464
+ struct ggml_tensor * V =
2465
+ ggml_view_3d(ctx0, kv_self.v,
2466
+ n_state_head, n_kv, n_head,
2467
+ ggml_element_size(kv_self.v)*n_state,
2468
+ ggml_element_size(kv_self.v)*n_state_head,
2469
+ ggml_element_size(kv_self.v)*n_state*n_ctx*il);
2470
+
2471
+ cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f);
2472
+
2473
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2474
+ } else {
2475
+ // K * Q
2476
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
2477
 
2478
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
2479
 
2480
+ struct ggml_tensor * V =
2481
+ ggml_view_3d(ctx0, kv_self.v,
2482
+ n_kv, n_state_head, n_head,
2483
+ n_ctx*ggml_element_size(kv_self.v),
2484
+ n_ctx*ggml_element_size(kv_self.v)*n_state_head,
2485
+ n_ctx*ggml_element_size(kv_self.v)*n_state*il);
2486
 
2487
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
2488
 
2489
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2490
 
2491
+ cur = ggml_cpy(ctx0,
2492
+ KQV_merged,
2493
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2494
+ }
2495
  }
2496
 
2497
  // projection
 
2530
  Qcur,
2531
  layer.cross_attn_q_b);
2532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2533
  struct ggml_tensor * Q =
2534
  ggml_permute(ctx0,
2535
+ ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
2536
  0, 2, 1, 3);
2537
 
2538
+ if (wctx.params.flash_attn) {
2539
+ struct ggml_tensor * Kcross =
2540
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
2541
+ n_state_head, n_audio_ctx_pad, n_head,
2542
+ ggml_element_size(wstate.kv_cross.k)*n_state,
2543
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
2544
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
 
2545
 
2546
+ struct ggml_tensor * Vcross =
2547
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
2548
+ n_state_head, n_audio_ctx_pad, n_head,
2549
+ ggml_element_size(wstate.kv_cross.v)*n_state,
2550
+ ggml_element_size(wstate.kv_cross.v)*n_state_head,
2551
+ ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
2552
 
2553
+ cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f);
2554
 
2555
+ cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
2556
+ } else {
2557
+ struct ggml_tensor * Kcross =
2558
+ ggml_view_3d(ctx0, wstate.kv_cross.k,
2559
+ n_state_head, n_audio_ctx, n_head,
2560
+ ggml_element_size(wstate.kv_cross.k)*n_state,
2561
+ ggml_element_size(wstate.kv_cross.k)*n_state_head,
2562
+ ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
2563
+
2564
+ struct ggml_tensor * Vcross =
2565
+ ggml_view_3d(ctx0, wstate.kv_cross.v,
2566
+ n_audio_ctx, n_state_head, n_head,
2567
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
2568
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
2569
+ n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
2570
+
2571
+ // ------
2572
+
2573
+ // K * Q
2574
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
2575
+
2576
+ struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
2577
+
2578
+ // [EXPERIMENTAL] Token-level timestamps with DTW
2579
+ if (wctx.params.dtw_token_timestamps) {
2580
+ if (wstate.aheads_masks.m[il] != nullptr) {
2581
+ struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
2582
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
2583
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
2584
+ aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
2585
+ aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
2586
+ aheads_KQs = ggml_cont(ctx0, aheads_KQs);
2587
+ aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
2588
+ if (aheads_cross_QKs == NULL) {
2589
+ aheads_cross_QKs = aheads_KQs;
2590
+ } else {
2591
+ aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs);
2592
+ }
2593
  }
2594
  }
 
2595
 
2596
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
2597
 
2598
+ struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
2599
 
2600
+ cur = ggml_cpy(ctx0,
2601
+ KQV_merged,
2602
+ ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
2603
+ }
2604
  }
2605
 
2606
  // projection
 
2733
  return false;
2734
  }
2735
 
2736
+ const uint32_t pad = whisper_kv_cache_get_padding(wctx);
2737
+ kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
2738
+
2739
  //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
2740
  //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
2741
  }
 
2769
  struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
2770
 
2771
  auto & kv_self = wstate.kv_self;
 
2772
 
2773
+ const int32_t n_kv = kv_self.n;
2774
+
2775
+ wstate.inp_mask.resize(ggml_nelements(KQ_mask));
2776
 
2777
  float * data = wstate.inp_mask.data();
2778
  memset(data, 0, ggml_nbytes(KQ_mask));
 
2788
  }
2789
  }
2790
  }
2791
+
2792
+ for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
2793
+ for (int j = 0; j < n_kv; ++j) {
2794
+ data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
2795
+ }
2796
+ }
2797
  }
2798
 
2799
  ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
 
2801
 
2802
  logits = gf->nodes[gf->n_nodes - 1];
2803
 
2804
+ if (!ggml_graph_compute_helper(wctx.backend, gf, n_threads)) {
2805
  return false;
2806
  }
2807
  }
 
3248
 
3249
  whisper_state * state = new whisper_state;
3250
 
 
 
 
 
 
 
 
3251
  // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
3252
  // in theory, there can be a case where this is not enough, but in practice it should always be enough
3253
  const int factor = 3;
3254
 
3255
+ if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
3256
+ ctx->model.hparams.n_text_state,
3257
+ ctx->model.hparams.n_text_layer,
3258
+ GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
3259
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3260
  whisper_free_state(state);
3261
  return nullptr;
 
3266
  WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
3267
  }
3268
 
3269
+ if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
3270
+ ctx->model.hparams.n_text_state,
3271
+ ctx->model.hparams.n_text_layer,
3272
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3273
  WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
3274
  whisper_free_state(state);
3275
  return nullptr;
 
3280
  WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
3281
  }
3282
 
3283
+ if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
3284
+ ctx->model.hparams.n_audio_state,
3285
+ 1,
3286
+ GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
3287
+ WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
3288
+ whisper_free_state(state);
3289
+ return nullptr;
3290
+ }
3291
+
3292
+ {
3293
+ const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
3294
+ WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
3295
+ }
3296
+
3297
  // [EXPERIMENTAL] Token-level timestamps with DTW
3298
  if (ctx->params.dtw_token_timestamps) {
3299
  if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
 
3464
  struct whisper_context_params whisper_context_default_params() {
3465
  struct whisper_context_params result = {
3466
  /*.use_gpu =*/ true,
3467
+ /*.flash_attn =*/ false,
3468
  /*.gpu_device =*/ 0,
3469
 
3470
  /*.dtw_token_timestamps =*/ false,
 
3563
  struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
3564
  ggml_time_init();
3565
 
3566
+ if (params.flash_attn && params.dtw_token_timestamps) {
3567
+ WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
3568
+ params.dtw_token_timestamps = false;
3569
+ }
3570
+
3571
+ WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu);
3572
+ WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
3573
+ WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
3574
+ WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);
3575
+
3576
  whisper_context * ctx = new whisper_context;
3577
  ctx->params = params;
3578
 
 
3661
  if (state) {
3662
  kv_cache_free(state->kv_self);
3663
  kv_cache_free(state->kv_cross);
3664
+ kv_cache_free(state->kv_pad);
3665
 
3666
  #ifdef WHISPER_USE_COREML
3667
  if (state->ctx_coreml != nullptr) {
 
3684
  ggml_gallocr_free(state->alloc_cross.alloc);
3685
  ggml_gallocr_free(state->alloc_decode.alloc);
3686
 
 
 
3687
  // [EXPERIMENTAL] Token-level timestamps with DTW
3688
  aheads_masks_free(state->aheads_masks);
3689
 
whisper.h CHANGED
@@ -113,6 +113,7 @@ extern "C" {
113
 
114
  struct whisper_context_params {
115
  bool use_gpu;
 
116
  int gpu_device; // CUDA device
117
 
118
  // [EXPERIMENTAL] Token-level timestamps with DTW
 
113
 
114
  struct whisper_context_params {
115
  bool use_gpu;
116
+ bool flash_attn;
117
  int gpu_device; // CUDA device
118
 
119
  // [EXPERIMENTAL] Token-level timestamps with DTW