whisper : by default disable non-speech tokens suppression (#473)

This seems to be causing hallucinations at the end of the audio, e.g.:

"Thank you for listening"
"Amen"
..
This commit is contained in:
Georgi Gerganov 2023-02-15 21:48:49 +02:00
parent 2407ae8ef0
commit a94897bcde
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -2936,7 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.language =*/ "en", /*.language =*/ "en",
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/true, /*.suppress_non_speech_tokens =*/ false,
/*.temperature =*/ 0.0f, /*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
@@ -3078,8 +3078,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
return res; return res;
} }
static const std::vector<std::string> non_speech_tokens static const std::vector<std::string> non_speech_tokens = {
{
"\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
"_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--", "_", "`", "{", "|", "}", "~", "", "", "", "", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
@@ -3149,26 +3148,21 @@ static void whisper_process_logits(
// suppress non-speech tokens // suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) if (params.suppress_non_speech_tokens) {
{ for (const std::string & token : non_speech_tokens) {
for (const std::string &token : non_speech_tokens) const std::string suppress_tokens[] = {token, " " + token};
{ for (const std::string & suppress_token : suppress_tokens) {
std::string suppress_tokens[] = {token, " " + token}; if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) {
for (const std::string &suppress_token : suppress_tokens)
{
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
{
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
} }
} }
} }
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) {
{
logits[vocab.token_to_id.at(" -")] = -INFINITY; logits[vocab.token_to_id.at(" -")] = -INFINITY;
} }
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) {
{
logits[vocab.token_to_id.at(" '")] = -INFINITY; logits[vocab.token_to_id.at(" '")] = -INFINITY;
} }
} }