whisper : fine-tuning grammar functionality
@@ -31,8 +31,9 @@ struct whisper_params {
     int32_t max_tokens = 32;
     int32_t audio_ctx = 0;

     float vad_thold = 0.6f;
     float freq_thold = 100.0f;

     float grammar_penalty = 100.0f;
+
     bool speed_up = false;
@@ -138,6 +139,9 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
     wparams.language = params.language.c_str();
     wparams.n_threads = params.n_threads;

+    // disable fallback - seems not useful for command recognition
+    wparams.temperature_inc = 0.0f;
+
     wparams.audio_ctx = params.audio_ctx;
     wparams.speed_up = params.speed_up;

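Note: temperature_inc drives whisper.cpp's fallback mechanism, where a failed decode is retried at progressively higher temperature; setting the increment to 0.0f keeps command recognition to a single deterministic greedy pass. A minimal sketch of a command-style setup built around that idea (the fields are real whisper_full_params members, but the helper name, flag values and pcmf32 buffer are assumptions, not taken from this commit):

    // Sketch only: single-pass greedy decoding for short voice commands.
    #include "whisper.h"
    #include <vector>

    int run_command_pass(struct whisper_context * ctx, const std::vector<float> & pcmf32) {
        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

        wparams.print_progress  = false;
        wparams.no_context      = true;
        wparams.single_segment  = true;   // commands are short, one segment is enough
        wparams.temperature_inc = 0.0f;   // no fallback retries at higher temperature

        // run a single decoding pass over the captured audio
        return whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size());
    }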
@@ -508,7 +512,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi

 // general-purpose mode
 // freely transcribe the voice into text
-int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
+int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
     bool is_running = true;
     bool have_prompt = false;
     bool ask_prompt = true;
@@ -519,7 +523,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
     std::vector<float> pcmf32_cur;
     std::vector<float> pcmf32_prompt;

-    const std::string k_prompt = "Ok Whisper, start listening for commands.";
+    //const std::string k_prompt = "Ok Whisper, start listening for commands.";
+    //const std::string k_prompt = "Начало.";
+    const std::string k_prompt = "Добре Уиспър, започни да слушаш за команди.";

     fprintf(stderr, "\n");
     fprintf(stderr, "%s: general-purpose mode\n", __func__);
@@ -578,6 +584,9 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
                 // prepend the prompt audio
                 pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());

+                // append 1 second of silence
+                pcmf32_cur.insert(pcmf32_cur.end(), 1000*WHISPER_SAMPLE_RATE/1000, 0.0f);
+
                 const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));

                 prob = 100.0f*(prob - prob0);
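Note: 1000*WHISPER_SAMPLE_RATE/1000 is simply WHISPER_SAMPLE_RATE samples, i.e. one second of audio at whisper's fixed 16 kHz input rate; writing it this way keeps the millisecond count visible and easy to tweak. The same padding as a small helper (sketch only, the helper name is illustrative):

    // Append `ms` milliseconds of silence to a 16 kHz mono float buffer.
    // WHISPER_SAMPLE_RATE is the real constant from whisper.h.
    #include "whisper.h"
    #include <vector>

    static void append_silence(std::vector<float> & pcmf32, int ms) {
        const size_t n_samples = (size_t) ms * WHISPER_SAMPLE_RATE / 1000; // 1000 ms -> 16000 samples
        pcmf32.insert(pcmf32.end(), n_samples, 0.0f);
    }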
@@ -604,6 +613,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
                     }
                 }

+                fprintf(stdout, "%s: DEBUG: txt = '%s'\n", __func__, txt.c_str());
                 if (best_len == 0) {
                     fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
                 } else {

whisper.cpp
@@ -3865,7 +3865,7 @@ static struct whisper_grammar whisper_grammar_init(
 static void whisper_suppress_invalid_grammar(
         whisper_context & ctx,
         const whisper_full_params & params,
-        std::vector<float> & logits,
+        std::vector<float> & logprobs,
         const whisper_grammar & grammar) {

     if (grammar.rules.empty() || grammar.stacks.empty()) {
@@ -3883,8 +3883,8 @@ static void whisper_suppress_invalid_grammar(
     std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
     std::vector<whisper_grammar_candidate> candidates_grammar;

-    size_t size = logits.size();
-    for (whisper_token id = 0; id < size; ++id) {
+    size_t size = logprobs.size();
+    for (whisper_token id = 0; id < (int) size; ++id) {
         const std::string & text = ctx.vocab.id_to_token[id];
         if (!text.empty() && text.rfind("[_", 0) != 0) {
             candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
@@ -3893,14 +3893,18 @@ static void whisper_suppress_invalid_grammar(
         }
     }

     const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);

     for (const auto & reject : rejects) {
-        if (logits[reject.id] > 0) {
-            logits[reject.id] /= params.grammar_penalty;
-        } else {
-            logits[reject.id] *= params.grammar_penalty;
-        }
+        logprobs[reject.id] -= params.grammar_penalty;
     }
-    // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
+
+    // when the grammar does not allow any continuation, we don't want to penalize the EOT token
+    // TODO: is there are better way to do this?
+    printf("rejects.size() = %zu, whisper_token_eot(&ctx) - 2 = %d\n", rejects.size(), whisper_token_eot(&ctx) - 2);
+    if ((int) rejects.size() < whisper_token_eot(&ctx) - 2) {
+        logprobs[whisper_token_eot(&ctx)] -= params.grammar_penalty;
+    }
+    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
 }

 static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
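Note: this is the core behavioural change. Previously the penalty scaled raw logits (divide when positive, multiply when negative), so the effective damping depended on each logit's sign and magnitude; now a fixed amount is subtracted from the log-probability of every grammar-rejected token, which divides its probability by the same factor exp(grammar_penalty) regardless of sign. The EOT token is spared whenever the grammar rejects essentially the whole vocabulary, so decoding can still terminate. A standalone sketch of the two behaviours on toy values (not whisper.cpp code):

    #include <cstdio>

    int main() {
        const float penalty = 100.0f;
        const float logits[3]   = {  5.0f,  0.1f, -3.0f };  // example raw logits
        const float logprobs[3] = { -0.5f, -2.0f, -6.0f };  // example log-probabilities

        for (int i = 0; i < 3; ++i) {
            // old behaviour: sign-dependent scaling of the logit
            const float old_penalized = logits[i] > 0 ? logits[i] / penalty : logits[i] * penalty;
            // new behaviour: flat subtraction in log space
            const float new_penalized = logprobs[i] - penalty;
            printf("logit %6.2f -> %10.4f | logprob %6.2f -> %10.2f\n",
                   logits[i], old_penalized, logprobs[i], new_penalized);
        }
        return 0;
    }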
@@ -3908,10 +3912,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar
         return;
     }

-    // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str());
+    fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());

     const std::string & text = ctx.vocab.id_to_token[token];

     if (text.rfind("[_", 0) == 0) {
         // fprintf(stderr, " (skipped)\n");
         return;
@@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.grammar_rules =*/ nullptr,
         /*.n_grammar_rules =*/ 0,
         /*.i_start_rule =*/ 0,
-        /*.grammar_penalty =*/ 1000.0f,
+        /*.grammar_penalty =*/ 100.0f,
     };

     switch (strategy) {
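Note: the default penalty drops from 1000 to 100 to match the new subtract-from-logprob scheme, and 100 is already effectively infinite: exp(-100) is on the order of 1e-44, while exp(-1000) underflows a float to exactly zero, so the smaller default loses nothing in practice. A quick standalone check (toy program, not whisper.cpp code):

    #include <cmath>
    #include <cstdio>

    int main() {
        // probability factor left after subtracting the penalty from a logprob
        printf("exp(-100)  = %g\n", (double) std::exp(-100.0f));   // ~3.7e-44, a float subnormal
        printf("exp(-1000) = %g\n", (double) std::exp(-1000.0f));  // underflows to 0
        return 0;
    }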
@@ -4181,12 +4185,18 @@ static void whisper_process_logits(
         logits[vocab.token_translate] = -INFINITY;
         logits[vocab.token_transcribe] = -INFINITY;

+        // suppress lang tokens
+        for (size_t i = 0; i < g_lang.size(); ++i) {
+            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
+        }
+
+        // suppress prev token
+        logits[vocab.token_prev] = -INFINITY;
+
         if (params.logits_filter_callback) {
             params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
         }

-        whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
-
         // suppress non-speech tokens
         // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
         if (params.suppress_non_speech_tokens) {
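Note: the unconditional grammar call is removed from this spot (it reappears in the next hunk, applied to logprobs on the text-sampling branch only), while language tokens and the special prev token are now masked unconditionally by forcing their logits to -INFINITY, so they can never be sampled. The masking pattern in isolation (standalone sketch with illustrative token ids, not whisper.cpp code):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static void suppress_tokens(std::vector<float> & logits, const std::vector<int> & ids) {
        for (const int id : ids) {
            logits[id] = -INFINITY;  // exp(-inf) == 0 after the softmax
        }
    }

    int main() {
        std::vector<float> logits = { 1.0f, 2.0f, 0.5f, -1.0f };
        suppress_tokens(logits, { 1, 3 });  // e.g. language tokens, the "prev" token

        for (size_t i = 0; i < logits.size(); ++i) {
            printf("logits[%zu] = %f\n", i, logits[i]);
        }
        return 0;
    }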
@@ -4293,10 +4303,19 @@ static void whisper_process_logits(
             //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);

             if (timestamp_logprob > max_text_token_logprob) {
+                //printf("sampling timestamp\n");
                 for (int i = 0; i < vocab.token_beg; ++i) {
                     logits[i] = -INFINITY;
                     logprobs[i] = -INFINITY;
                 }
+            } else {
+                //printf("sampling text\n");
+                for (int i = vocab.token_beg; i < n_logits; ++i) {
+                    logits[i] = -INFINITY;
+                    logprobs[i] = -INFINITY;
+                }
+
+                whisper_suppress_invalid_grammar(ctx, params, logprobs, decoder.grammar);
             }
         }
     }
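Note: the grammar constraint now lives on the text branch of the timestamp-vs-text decision: if timestamps win, all text tokens are masked and the grammar has nothing to do; otherwise all timestamp tokens are masked and whisper_suppress_invalid_grammar runs over the remaining text tokens in logprob space. A standalone sketch of that either/or structure (names and the callback are illustrative, not whisper.cpp code):

    #include <cmath>
    #include <functional>
    #include <vector>

    static void choose_branch(std::vector<float> & logprobs,
                              int token_beg,                 // first timestamp token id
                              bool sample_timestamp,
                              const std::function<void(std::vector<float> &)> & apply_grammar) {
        const int n = (int) logprobs.size();
        if (sample_timestamp) {
            for (int i = 0; i < token_beg; ++i) logprobs[i] = -INFINITY;  // drop text tokens
        } else {
            for (int i = token_beg; i < n; ++i) logprobs[i] = -INFINITY;  // drop timestamp tokens
            apply_grammar(logprobs);                                      // constrain the text tokens
        }
    }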
@@ -4312,34 +4331,57 @@ static void whisper_process_logits(
         }
     }

-#if 0
+#if 1
     // print first 100 logits - token string : logit
-    for (int i = 0; i < 100; i++) {
-        const auto token = vocab.id_to_token.at(i);
-        const auto prob = probs[i];
-        const auto logit = logits[i];
-        const auto logprob = logprobs[i];
-        printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //for (int i = 0; i < 10; i++) {
+    //    const auto token = vocab.id_to_token.at(i);
+    //    const auto prob = probs[i];
+    //    const auto logit = logits[i];
+    //    const auto logprob = logprobs[i];
+    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
+    //}
+
+    // print sorted
+    {
+        std::vector<std::pair<float, int>> pairs;
+
+        for (int i = 0; i < n_logits; ++i) {
+            pairs.push_back(std::make_pair(probs[i], i));
+        }
+
+        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
+            return a.first > b.first;
+        });
+
+        for (int i = 0; i < 10; i++) {
+            const auto token = vocab.id_to_token.at(pairs[i].second);
+            const auto prob = pairs[i].first;
+            const auto logit = logits[pairs[i].second];
+            const auto logprob = logprobs[pairs[i].second];
+            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
+        }
+
+        printf("----------------\n");
     }

     // "And", "and", " And", " and"
-    printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
-    printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
-    printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
-    printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
-    printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);
+    //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]);
+    //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]);
+    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
+    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
+    //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]);

-    printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
-    printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
-    printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
-    printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
-    printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
+    //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]);
+    //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]);
+    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
+    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
+    //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);

-    printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
-    printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
-    printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
-    printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
-    printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
+    //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]);
+    //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]);
+    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
+    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
+    //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]);
 #endif
 }

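Note: the debug block is switched on (#if 0 becomes #if 1) and the flat first-100-tokens dump is replaced by a top-10 listing sorted by probability. For reference, the same top-k dump can also be written with std::partial_sort, which only orders the first k entries instead of sorting the whole vocabulary (standalone sketch on toy data, not the commit's code):

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
        std::vector<float> probs = { 0.01f, 0.40f, 0.05f, 0.30f, 0.24f };  // toy distribution
        const int k = 3;

        std::vector<std::pair<float, int>> pairs;
        for (int i = 0; i < (int) probs.size(); ++i) {
            pairs.push_back(std::make_pair(probs[i], i));
        }

        std::partial_sort(pairs.begin(), pairs.begin() + k, pairs.end(),
                          [](const std::pair<float, int> & a, const std::pair<float, int> & b) {
                              return a.first > b.first;
                          });

        for (int i = 0; i < k; ++i) {
            printf("rank %d : id=%d prob=%.2f\n", i, pairs[i].second, pairs[i].first);
        }
        return 0;
    }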