mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-19 09:02:14 +02:00
whisper : fine-tuning grammar functionality
This commit is contained in:
@@ -31,8 +31,9 @@ struct whisper_params {
|
||||
int32_t max_tokens = 32;
|
||||
int32_t audio_ctx = 0;
|
||||
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
float vad_thold = 0.6f;
|
||||
float freq_thold = 100.0f;
|
||||
|
||||
float grammar_penalty = 100.0f;
|
||||
|
||||
bool speed_up = false;
|
||||
@@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.n_threads = params.n_threads;
|
||||
|
||||
// disable fallback - seems not useful for command recognition
|
||||
wparams.temperature_inc = 0.0f;
|
||||
|
||||
wparams.audio_ctx = params.audio_ctx;
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
@@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
|
||||
|
||||
// general-purpose mode
|
||||
// freely transcribe the voice into text
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
||||
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
|
||||
bool is_running = true;
|
||||
bool have_prompt = false;
|
||||
bool ask_prompt = true;
|
||||
@@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
//const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
||||
//const std::string k_prompt = "Начало.";
|
||||
const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди.";
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
||||
@@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
// prepend the prompt audio
|
||||
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
// append 1 second of silence
|
||||
pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f);
|
||||
|
||||
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
||||
|
||||
prob = 100.0f*(prob - prob0);
|
||||
@@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str());
|
||||
if (best_len == 0) {
|
||||
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
|
||||
} else {
|
||||
|
114
whisper.cpp
114
whisper.cpp
@@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init(
|
||||
static void whisper_suppress_invalid_grammar(
|
||||
whisper_context & ctx,
|
||||
const whisper_full_params & params,
|
||||
std::vector<float> & logits,
|
||||
std::vector<float> & logprobs,
|
||||
const whisper_grammar & grammar) {
|
||||
|
||||
if (grammar.rules.empty() || grammar.stacks.empty()) {
|
||||
@@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar(
|
||||
std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
|
||||
std::vector<whisper_grammar_candidate> candidates_grammar;
|
||||
|
||||
size_t size = logits.size();
|
||||
for (whisper_token id = 0; id < size; ++id) {
|
||||
size_t size = logprobs.size();
|
||||
for (whisper_token id = 0; id < (int) size; ++id) {
|
||||
const std::string & text = ctx.vocab.id_to_token[id];
|
||||
if (!text.empty() && text.rfind("[_", 0) != 0) {
|
||||
candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
|
||||
@@ -3893,14 +3893,18 @@ static void whisper_suppress_invalid_grammar(
|
||||
}
|
||||
|
||||
const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
|
||||
|
||||
for (const auto & reject : rejects) {
|
||||
if (logits[reject.id] > 0) {
|
||||
logits[reject.id] /= params.grammar_penalty;
|
||||
} else {
|
||||
logits[reject.id] *= params.grammar_penalty;
|
||||
}
|
||||
logprobs[reject.id] -= params.grammar_penalty;
|
||||
}
|
||||
// fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
||||
|
||||
// when the grammar does not allow any continuation, we don't want to penalize the EOT token
|
||||
// TODO: is there are better way to do this?
|
||||
printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
|
||||
if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
|
||||
logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
|
||||
}
|
||||
//fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
|
||||
}
|
||||
|
||||
static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
|
||||
@@ -3908,10 +3912,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
|
||||
return;
|
||||
}
|
||||
|
||||
// fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str());
|
||||
fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());
|
||||
|
||||
const std::string & text = ctx.vocab.id_to_token[token];
|
||||
|
||||
|
||||
if (text.rfind("[_", 0) == 0) {
|
||||
// fprintf(stderr, " (skipped)\n");
|
||||
return;
|
||||
@@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
||||
/*.grammar_rules =*/ nullptr,
|
||||
/*.n_grammar_rules =*/ 0,
|
||||
/*.i_start_rule =*/ 0,
|
||||
/*.grammar_penalty =*/ 1000.0f,
|
||||
/*.grammar_penalty =*/ 100.0f,
|
||||
};
|
||||
|
||||
switch (strategy) {
|
||||
@@ -4181,12 +4185,18 @@ static void whisper_process_logits(
|
||||
logits[vocab.token_translate] = -INFINITY;
|
||||
logits[vocab.token_transcribe] = -INFINITY;
|
||||
|
||||
// suppress lang tokens
|
||||
for (size_t i = 0; i < g_lang.size(); ++i) {
|
||||
logits[whisper_token_lang(&ctx, i)] = -INFINITY;
|
||||
}
|
||||
|
||||
// suppress prev token
|
||||
logits[vocab.token_prev] = -INFINITY;
|
||||
|
||||
if (params.logits_filter_callback) {
|
||||
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
|
||||
}
|
||||
|
||||
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
||||
|
||||
// suppress non-speech tokens
|
||||
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
|
||||
if (params.suppress_non_speech_tokens) {
|
||||
@@ -4293,10 +4303,19 @@ static void whisper_process_logits(
|
||||
//log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
|
||||
|
||||
if (timestamp_logprob > max_text_token_logprob) {
|
||||
//printf("sampling timestamp\n");
|
||||
for (int i = 0; i < vocab.token_beg; ++i) {
|
||||
logits[i] = -INFINITY;
|
||||
logprobs[i] = -INFINITY;
|
||||
}
|
||||
} else {
|
||||
//printf("sampling text\n");
|
||||
for (int i = vocab.token_beg; i < n_logits; ++i) {
|
||||
logits[i] = -INFINITY;
|
||||
logprobs[i] = -INFINITY;
|
||||
}
|
||||
|
||||
whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4312,34 +4331,57 @@ static void whisper_process_logits(
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// print first 100 logits - token string : logit
|
||||
for (int i = 0; i < 100; i++) {
|
||||
const auto token = vocab.id_to_token.at(i);
|
||||
const auto prob = probs[i];
|
||||
const auto logit = logits[i];
|
||||
const auto logprob = logprobs[i];
|
||||
printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
|
||||
//for (int i = 0; i < 10; i++) {
|
||||
// const auto token = vocab.id_to_token.at(i);
|
||||
// const auto prob = probs[i];
|
||||
// const auto logit = logits[i];
|
||||
// const auto logprob = logprobs[i];
|
||||
// printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
|
||||
//}
|
||||
|
||||
// print sorted
|
||||
{
|
||||
std::vector<std::pair<float, int>> pairs;
|
||||
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
pairs.push_back(std::make_pair(probs[i], i));
|
||||
}
|
||||
|
||||
std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
|
||||
return a.first > b.first;
|
||||
});
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
const auto token = vocab.id_to_token.at(pairs[i].second);
|
||||
const auto prob = pairs[i].first;
|
||||
const auto logit = logits[pairs[i].second];
|
||||
const auto logprob = logprobs[pairs[i].second];
|
||||
printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
|
||||
}
|
||||
|
||||
printf("----------------\n");
|
||||
}
|
||||
|
||||
// "And", "and", " And", " and"
|
||||
printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
||||
printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
||||
printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
||||
printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
||||
printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
||||
//printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
|
||||
//printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
|
||||
//printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
|
||||
//printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
|
||||
//printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
|
||||
|
||||
printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
||||
printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
||||
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
||||
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
||||
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
||||
//printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
|
||||
//printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
|
||||
//printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
|
||||
//printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
|
||||
//printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
|
||||
|
||||
printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
||||
printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
||||
printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
||||
printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
||||
printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
||||
//printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
|
||||
//printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
|
||||
//printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
|
||||
//printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
|
||||
//printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user