whisper : add new-segment callback

The callback can be used to process new segments as they are generated.
See main for sample usage: it prints each resulting segment during
inference.
This commit is contained in:
Georgi Gerganov 2022-10-22 21:06:50 +03:00
parent 8f95c25aed
commit 7affd309d3
3 changed files with 81 additions and 39 deletions

View File

@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
const whisper_params & params = *(whisper_params *) user_data;
const int n_segments = whisper_full_n_segments(ctx);
// print the last segment
const int i = n_segments - 1;
if (i == 0) {
printf("\n");
}
if (params.no_timestamps) {
if (params.print_colors) {
// TODO
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
fflush(stdout);
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special_tokens == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
bool output_txt(struct whisper_context * ctx, const char * fname) { bool output_txt(struct whisper_context * ctx, const char * fname) {
std::ofstream fout(fname); std::ofstream fout(fname);
if (!fout.is_open()) { if (!fout.is_open()) {
@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
{ {
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.print_realtime = !params.print_colors; wparams.print_realtime = false;
wparams.print_progress = false; wparams.print_progress = false;
wparams.print_timestamps = !params.no_timestamps; wparams.print_timestamps = !params.no_timestamps;
wparams.print_special_tokens = params.print_special_tokens; wparams.print_special_tokens = params.print_special_tokens;
@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.offset_ms = params.offset_t_ms; wparams.offset_ms = params.offset_t_ms;
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &params;
}
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]); fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 7; return 7;
} }
// print result
if (!wparams.print_realtime) {
printf("\n");
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
if (params.no_timestamps) {
if (params.print_colors) {
// TODO
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("%s", text);
fflush(stdout);
}
} else {
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
if (params.print_colors) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
const char * text = whisper_full_get_token_text(ctx, i, j);
const float p = whisper_full_get_token_p (ctx, i, j);
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
}
printf("\n");
} else {
const char * text = whisper_full_get_segment_text(ctx, i);
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
}
}
}
}
printf("\n"); printf("\n");
// output to text file // output to text file

View File

@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ -1, /*.beam_width =*/ -1,
/*.n_best =*/ -1, /*.n_best =*/ -1,
}, },
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
}; };
} break; } break;
case WHISPER_SAMPLING_BEAM_SEARCH: case WHISPER_SAMPLING_BEAM_SEARCH:
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.beam_width =*/ 10, /*.beam_width =*/ 10,
/*.n_best =*/ 5, /*.n_best =*/ 5,
}, },
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
}; };
} break; } break;
} }
@ -2549,6 +2555,9 @@ int whisper_full(
for (int j = i0; j <= i; j++) { for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
} }
text = ""; text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
@ -2576,6 +2585,9 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) { for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]); result_all.back().tokens.push_back(tokens_cur[j]);
} }
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
}
} }
} }
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str(); return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
} }
// Get the token id of the specified token in the specified segment.
// Both indices are used as-is; this function performs no range checks of its own.
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
    const auto & segment = ctx->result_all[i_segment];
    const auto & token   = segment.tokens[i_token];

    return token.id;
}
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].p; return ctx->result_all[i_segment].tokens[i_token].p;
} }

View File

@ -160,6 +160,11 @@ extern "C" {
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
}; };
// Text segment callback
// Called on every newly generated text segment
//   ctx       - the context that produced the segment
//   user_data - the pointer stored in whisper_full_params.new_segment_callback_user_data,
//               passed through unchanged
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
struct whisper_full_params { struct whisper_full_params {
enum whisper_sampling_strategy strategy; enum whisper_sampling_strategy strategy;
@ -184,6 +189,9 @@ extern "C" {
int beam_width; int beam_width;
int n_best; int n_best;
} beam_search; } beam_search;
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
}; };
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -212,6 +220,7 @@ extern "C" {
// Get the token text of the specified token in the specified segment. // Get the token text of the specified token in the specified segment.
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
// Get the probability of the specified token in the specified segment. // Get the probability of the specified token in the specified segment.
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token); WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);