Compare commits

...

2 Commits

6 changed files with 112 additions and 14 deletions

View File

@ -72,7 +72,7 @@ int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
void whisper_print_segment(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -250,7 +250,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback = whisper_print_segment;
wparams.new_segment_callback_user_data = &user_data;
}

View File

@ -109,6 +109,73 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n");
}
// User data handed to whisper_logits_filter() through the opaque `user_data` pointer.
// Both members are non-owning pointers; the pointed-to vectors must outlive the
// whisper_full() call that invokes the filter.
struct whisper_logits_filter_user_data {
// command strings (not read by whisper_logits_filter() itself)
std::vector<std::string> * allowed_commands;
// one token-id sequence per command; the filter indexes each sequence by decoding position
// NOTE(review): presumably allowed_tokens[i] is the tokenization of allowed_commands[i] - confirm at the call site
std::vector<std::vector<whisper_token>> * allowed_tokens;
};
void whisper_logits_filter(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data){
const auto & allowed_tokens = *((whisper_logits_filter_user_data *) user_data)->allowed_tokens;
printf("n_tokens = %d\n", n_tokens);
for (int i = 0; i < n_tokens; i++) {
printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, tokens[i].id), logits[i]);
}
if (n_tokens == 0) {
return;
}
std::vector<std::pair<whisper_token, float>> pool;
for (int i = 0; i < (int) allowed_tokens.size(); i++) {
const int n = (int) allowed_tokens[i].size();
if (n_tokens > n) {
continue;
}
const whisper_token id = allowed_tokens[i][n_tokens - 1];
pool.push_back({ id, logits[id] });
}
if (pool.empty()) {
return;
}
printf("applying logits filter, pool size = %d\n", (int) pool.size());
const int ibeg = whisper_token_beg(ctx);
double sum_all = 0.0;
for (int i = 0; i < ibeg; ++i) {
if (logits[i] == -INFINITY) {
continue;
}
sum_all += logits[i];
}
double sum_pool = 0.0;
for (int i = 0; i < (int) pool.size(); ++i) {
sum_pool += pool[i].second;
}
printf("sum_all = %.2f, sum_pool = %.2f\n", sum_all, sum_pool);
for (int i = 0; i < ibeg; ++i) {
logits[i] = -INFINITY;
}
for (int i = 0; i < (int) pool.size(); ++i) {
//logits[pool[i].first] = pool[i].second / sum_pool * sum_all;
logits[pool[i].first] = pool[i].second;
printf(" - '%s' (%.2f)\n", whisper_token_to_str(ctx, pool[i].first), logits[pool[i].first]);
}
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
@ -131,6 +198,8 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature_inc = -1.0f;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
@ -334,22 +403,31 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = 1;
//wparams.max_tokens = 1;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature_inc = -1.0f;
wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size();
whisper_logits_filter_user_data user_data = { &allowed_commands, &allowed_tokens };
wparams.logits_filter_callback = whisper_logits_filter;
wparams.logits_filter_callback_user_data = &user_data;
// run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
break;
}
fprintf(stdout, "%s: text - '%s'\n", __func__, whisper_full_get_segment_text(ctx, 0));
// estimate command probability
// NOTE: not optimal
{
@ -436,7 +514,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
// always-prompt mode
// transcribe the voice into text after valid prompt
int always_prompt_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
int process_always_prompt(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool ask_prompt = true;
@ -496,7 +574,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
const float sim = similarity(prompt, k_prompt);
//debug
//fprintf(stdout, "command size: %i\n", command_length);
//fprintf(stdout, "command size: %d, sim: %f\n", (int) command.size(), sim);
if ((sim > 0.7f) && (command.size() > 0)) {
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
@ -676,7 +754,7 @@ int main(int argc, char ** argv) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
ret_val = process_always_prompt(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
}

View File

@ -1,13 +1,13 @@
#pragma once
#include <SDL.h>
#include <SDL_audio.h>
#include <atomic>
#include <cstdint>
#include <vector>
#include <mutex>
#include <SDL.h>
#include <SDL_audio.h>
//
// SDL Audio capture
//

View File

@ -193,7 +193,7 @@ struct whisper_print_user_data {
const std::vector<std::vector<float>> * pcmf32s;
};
void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
void whisper_print_segment(struct whisper_context * ctx, int n_new, void * user_data) {
const auto & params = *((whisper_print_user_data *) user_data)->params;
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
@ -597,7 +597,7 @@ int main(int argc, char ** argv) {
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback = whisper_print_segment;
wparams.new_segment_callback_user_data = &user_data;
}

View File

@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};
switch (strategy) {
@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
@ -3145,6 +3148,9 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@ -3848,7 +3854,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
unsigned int cur_c = 0;
uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];

View File

@ -243,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
//
// ctx       - the whisper context the decoder is running on
// tokens    - the decoder's current token sequence (presumably the tokens decoded so far - confirm in whisper_process_logits)
// n_tokens  - number of entries in tokens
// logits    - the decoder's logits buffer; the callback may modify it in place
// user_data - the logits_filter_callback_user_data pointer from whisper_full_params, passed through unchanged
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function
// If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
@ -315,6 +325,10 @@ extern "C" {
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);