whisper : fine-tuning grammar functionality

This commit is contained in:
Georgi Gerganov
2023-09-06 17:05:05 +03:00
parent 97ebb48b99
commit b8f34d1ed7
2 changed files with 92 additions and 40 deletions

View File

@@ -31,8 +31,9 @@ struct whisper_params {
int32_t max_tokens = 32; int32_t max_tokens = 32;
int32_t audio_ctx = 0; int32_t audio_ctx = 0;
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
float grammar_penalty = 100.0f; float grammar_penalty = 100.0f;
bool speed_up = false; bool speed_up = false;
@@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.language = params.language.c_str(); wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
// disable fallback - seems not useful for command recognition
wparams.temperature_inc = 0.0f;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
@@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// general-purpose mode // general-purpose mode
// freely transcribe the voice into text // freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) { int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true; bool is_running = true;
bool have_prompt = false; bool have_prompt = false;
bool ask_prompt = true; bool ask_prompt = true;
@@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
std::vector<float> pcmf32_cur; std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt; std::vector<float> pcmf32_prompt;
const std::string k_prompt = "Ok Whisper, start listening for commands."; //const std::string k_prompt = "Ok Whisper, start listening for commands.";
//const std::string k_prompt = "Начало.";
const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди.";
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__); fprintf(stderr, "%s: general-purpose mode\n", __func__);
@@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// prepend the prompt audio // prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
// append 1 second of silence
pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f);
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
prob = 100.0f*(prob - prob0); prob = 100.0f*(prob - prob0);
@@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
} }
} }
fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str());
if (best_len == 0) { if (best_len == 0) {
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
} else { } else {

View File

@@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init(
static void whisper_suppress_invalid_grammar( static void whisper_suppress_invalid_grammar(
whisper_context & ctx, whisper_context & ctx,
const whisper_full_params & params, const whisper_full_params & params,
std::vector<float> & logits, std::vector<float> & logprobs,
const whisper_grammar & grammar) { const whisper_grammar & grammar) {
if (grammar.rules.empty() || grammar.stacks.empty()) { if (grammar.rules.empty() || grammar.stacks.empty()) {
@@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar(
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded; std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
std::vector<whisper_grammar_candidate> candidates_grammar; std::vector<whisper_grammar_candidate> candidates_grammar;
size_t size = logits.size(); size_t size = logprobs.size();
for (whisper_token id = 0; id < size; ++id) { for (whisper_token id = 0; id < (int) size; ++id) {
const std::string & text = ctx.vocab.id_to_token[id]; const std::string & text = ctx.vocab.id_to_token[id];
if (!text.empty() && text.rfind("[_", 0) != 0) { if (!text.empty() && text.rfind("[_", 0) != 0) {
candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
@@ -3893,14 +3893,18 @@ static void whisper_suppress_invalid_grammar(
} }
const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
for (const auto & reject : rejects) { for (const auto & reject : rejects) {
if (logits[reject.id] > 0) { logprobs[reject.id] -= params.grammar_penalty;
logits[reject.id] /= params.grammar_penalty;
} else {
logits[reject.id] *= params.grammar_penalty;
}
} }
// fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
// when the grammar does not allow any continuation, we don't want to penalize the EOT token
// TODO: is there are better way to do this?
printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
}
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
} }
static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
@@ -3908,10 +3912,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
return; return;
} }
// fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str()); fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
const std::string & text = ctx.vocab.id_to_token[token]; const std::string & text = ctx.vocab.id_to_token[token];
if (text.rfind("[_", 0) == 0) { if (text.rfind("[_", 0) == 0) {
// fprintf(stderr, " (skipped)\n"); // fprintf(stderr, " (skipped)\n");
return; return;
@@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.grammar_rules =*/ nullptr, /*.grammar_rules =*/ nullptr,
/*.n_grammar_rules =*/ 0, /*.n_grammar_rules =*/ 0,
/*.i_start_rule =*/ 0, /*.i_start_rule =*/ 0,
/*.grammar_penalty =*/ 1000.0f, /*.grammar_penalty =*/ 100.0f,
}; };
switch (strategy) { switch (strategy) {
@@ -4181,12 +4185,18 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY; logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY;
// suppress lang tokens
for (size_t i = 0; i < g_lang.size(); ++i) {
logits[whisper_token_lang(&ctx, i)] = -INFINITY;
}
// suppress prev token
logits[vocab.token_prev] = -INFINITY;
if (params.logits_filter_callback) { if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
} }
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
// suppress non-speech tokens // suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) { if (params.suppress_non_speech_tokens) {
@@ -4293,10 +4303,19 @@ static void whisper_process_logits(
//log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
if (timestamp_logprob > max_text_token_logprob) { if (timestamp_logprob > max_text_token_logprob) {
//printf("sampling timestamp\n");
for (int i = 0; i < vocab.token_beg; ++i) { for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY; logits[i] = -INFINITY;
logprobs[i] = -INFINITY; logprobs[i] = -INFINITY;
} }
} else {
//printf("sampling text\n");
for (int i = vocab.token_beg; i < n_logits; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
} }
} }
} }
@@ -4312,34 +4331,57 @@ static void whisper_process_logits(
} }
} }
#if 0 #if 1
// print first 100 logits - token string : logit // print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) { //for (int i = 0; i < 10; i++) {
const auto token = vocab.id_to_token.at(i); // const auto token = vocab.id_to_token.at(i);
const auto prob = probs[i]; // const auto prob = probs[i];
const auto logit = logits[i]; // const auto logit = logits[i];
const auto logprob = logprobs[i]; // const auto logprob = logprobs[i];
printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
//}
// print sorted
{
std::vector<std::pair<float, int>> pairs;
for (int i = 0; i < n_logits; ++i) {
pairs.push_back(std::make_pair(probs[i], i));
}
std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
return a.first > b.first;
});
for (int i = 0; i < 10; i++) {
const auto token = vocab.id_to_token.at(pairs[i].second);
const auto prob = pairs[i].first;
const auto logit = logits[pairs[i].second];
const auto logprob = logprobs[pairs[i].second];
printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
}
printf("----------------\n");
} }
// "And", "and", " And", " and" // "And", "and", " And", " and"
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif #endif
} }