Spaces:
Sleeping
Sleeping
shibukazu
commited on
whisper : suppress non-speech-related token outputs (#473)
Browse files* add non-speech-token suppression
* add suppress non-speech_tokens param
- whisper.cpp +36 -0
- whisper.h +1 -0
whisper.cpp
CHANGED
|
@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|
| 2936 |
/*.language =*/ "en",
|
| 2937 |
|
| 2938 |
/*.suppress_blank =*/ true,
|
|
|
|
| 2939 |
|
| 2940 |
/*.temperature =*/ 0.0f,
|
| 2941 |
/*.max_initial_ts =*/ 1.0f,
|
|
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
|
|
| 3077 |
return res;
|
| 3078 |
}
|
| 3079 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3080 |
// process the logits for the selected decoder
|
| 3081 |
// - applies logit filters
|
| 3082 |
// - computes logprobs and probs
|
|
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
|
|
| 3137 |
logits[vocab.token_translate] = -INFINITY;
|
| 3138 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3140 |
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
| 3141 |
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
|
| 3142 |
{
|
|
|
|
| 2936 |
/*.language =*/ "en",
|
| 2937 |
|
| 2938 |
/*.suppress_blank =*/ true,
|
| 2939 |
+
/*.suppress_non_speech_tokens =*/true,
|
| 2940 |
|
| 2941 |
/*.temperature =*/ 0.0f,
|
| 2942 |
/*.max_initial_ts =*/ 1.0f,
|
|
|
|
| 3078 |
return res;
|
| 3079 |
}
|
| 3080 |
|
| 3081 |
+
static const std::vector<std::string> non_speech_tokens
|
| 3082 |
+
{
|
| 3083 |
+
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
|
| 3084 |
+
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
|
| 3085 |
+
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
|
| 3086 |
+
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
|
| 3087 |
+
};
|
| 3088 |
+
|
| 3089 |
// process the logits for the selected decoder
|
| 3090 |
// - applies logit filters
|
| 3091 |
// - computes logprobs and probs
|
|
|
|
| 3146 |
logits[vocab.token_translate] = -INFINITY;
|
| 3147 |
logits[vocab.token_transcribe] = -INFINITY;
|
| 3148 |
|
| 3149 |
+
|
| 3150 |
+
// suppress non-speech tokens
|
| 3151 |
+
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
| 3152 |
+
if (params.suppress_non_speech_tokens)
|
| 3153 |
+
{
|
| 3154 |
+
for (const std::string &token : non_speech_tokens)
|
| 3155 |
+
{
|
| 3156 |
+
std::string suppress_tokens[] = {token, " " + token};
|
| 3157 |
+
for (const std::string &suppress_token : suppress_tokens)
|
| 3158 |
+
{
|
| 3159 |
+
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
|
| 3160 |
+
{
|
| 3161 |
+
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
|
| 3162 |
+
}
|
| 3163 |
+
}
|
| 3164 |
+
}
|
| 3165 |
+
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
| 3166 |
+
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
|
| 3167 |
+
{
|
| 3168 |
+
logits[vocab.token_to_id.at(" -")] = -INFINITY;
|
| 3169 |
+
}
|
| 3170 |
+
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
|
| 3171 |
+
{
|
| 3172 |
+
logits[vocab.token_to_id.at(" '")] = -INFINITY;
|
| 3173 |
+
}
|
| 3174 |
+
}
|
| 3175 |
+
|
| 3176 |
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
| 3177 |
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
|
| 3178 |
{
|
whisper.h
CHANGED
|
@@ -285,6 +285,7 @@ extern "C" {
|
|
| 285 |
|
| 286 |
// common decoding parameters:
|
| 287 |
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
|
|
|
| 288 |
|
| 289 |
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
|
| 290 |
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|
|
|
|
| 285 |
|
| 286 |
// common decoding parameters:
|
| 287 |
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
| 288 |
+
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
| 289 |
|
| 290 |
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
|
| 291 |
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|