mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-26 16:48:50 +01:00
whisper : add new-segment callback
Can be used to process new segments as they are being generated. Sample usage in main, for printing the resulting segments during the inference.
This commit is contained in:
parent
8f95c25aed
commit
7affd309d3
95
main.cpp
95
main.cpp
@ -141,6 +141,55 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
|
||||||
|
const whisper_params & params = *(whisper_params *) user_data;
|
||||||
|
|
||||||
|
const int n_segments = whisper_full_n_segments(ctx);
|
||||||
|
|
||||||
|
// print the last segment
|
||||||
|
const int i = n_segments - 1;
|
||||||
|
if (i == 0) {
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.no_timestamps) {
|
||||||
|
if (params.print_colors) {
|
||||||
|
// TODO
|
||||||
|
} else {
|
||||||
|
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||||
|
printf("%s", text);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||||
|
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||||
|
|
||||||
|
if (params.print_colors) {
|
||||||
|
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||||
|
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||||
|
if (params.print_special_tokens == false) {
|
||||||
|
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
|
||||||
|
if (id >= whisper_token_eot(ctx)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||||
|
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||||
|
|
||||||
|
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||||
|
|
||||||
|
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
} else {
|
||||||
|
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||||
|
|
||||||
|
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||||
std::ofstream fout(fname);
|
std::ofstream fout(fname);
|
||||||
if (!fout.is_open()) {
|
if (!fout.is_open()) {
|
||||||
@ -294,7 +343,7 @@ int main(int argc, char ** argv) {
|
|||||||
{
|
{
|
||||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
wparams.print_realtime = !params.print_colors;
|
wparams.print_realtime = false;
|
||||||
wparams.print_progress = false;
|
wparams.print_progress = false;
|
||||||
wparams.print_timestamps = !params.no_timestamps;
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
wparams.print_special_tokens = params.print_special_tokens;
|
wparams.print_special_tokens = params.print_special_tokens;
|
||||||
@ -303,49 +352,17 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.n_threads = params.n_threads;
|
wparams.n_threads = params.n_threads;
|
||||||
wparams.offset_ms = params.offset_t_ms;
|
wparams.offset_ms = params.offset_t_ms;
|
||||||
|
|
||||||
|
// this callback is called on each new segment
|
||||||
|
if (!wparams.print_realtime) {
|
||||||
|
wparams.new_segment_callback = whisper_print_segment_callback;
|
||||||
|
wparams.new_segment_callback_user_data = ¶ms;
|
||||||
|
}
|
||||||
|
|
||||||
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
||||||
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
|
||||||
return 7;
|
return 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
// print result
|
|
||||||
if (!wparams.print_realtime) {
|
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
const int n_segments = whisper_full_n_segments(ctx);
|
|
||||||
for (int i = 0; i < n_segments; ++i) {
|
|
||||||
if (params.no_timestamps) {
|
|
||||||
if (params.print_colors) {
|
|
||||||
// TODO
|
|
||||||
} else {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
printf("%s", text);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
|
||||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
|
||||||
|
|
||||||
if (params.print_colors) {
|
|
||||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
|
||||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
|
||||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
|
||||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
|
||||||
|
|
||||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
|
||||||
|
|
||||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
} else {
|
|
||||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
|
||||||
|
|
||||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
// output to text file
|
// output to text file
|
||||||
|
16
whisper.cpp
16
whisper.cpp
@ -2320,6 +2320,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.beam_width =*/ -1,
|
/*.beam_width =*/ -1,
|
||||||
/*.n_best =*/ -1,
|
/*.n_best =*/ -1,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/*.new_segment_callback =*/ nullptr,
|
||||||
|
/*.new_segment_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
case WHISPER_SAMPLING_BEAM_SEARCH:
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
||||||
@ -2348,6 +2351,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.beam_width =*/ 10,
|
/*.beam_width =*/ 10,
|
||||||
/*.n_best =*/ 5,
|
/*.n_best =*/ 5,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/*.new_segment_callback =*/ nullptr,
|
||||||
|
/*.new_segment_callback_user_data =*/ nullptr,
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
}
|
}
|
||||||
@ -2549,6 +2555,9 @@ int whisper_full(
|
|||||||
for (int j = i0; j <= i; j++) {
|
for (int j = i0; j <= i; j++) {
|
||||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||||
}
|
}
|
||||||
|
if (params.new_segment_callback) {
|
||||||
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
text = "";
|
text = "";
|
||||||
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
|
||||||
@ -2576,6 +2585,9 @@ int whisper_full(
|
|||||||
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
for (int j = i0; j < (int) tokens_cur.size(); j++) {
|
||||||
result_all.back().tokens.push_back(tokens_cur[j]);
|
result_all.back().tokens.push_back(tokens_cur[j]);
|
||||||
}
|
}
|
||||||
|
if (params.new_segment_callback) {
|
||||||
|
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2609,6 +2621,10 @@ const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_seg
|
|||||||
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||||
|
return ctx->result_all[i_segment].tokens[i_token].id;
|
||||||
|
}
|
||||||
|
|
||||||
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
|
||||||
return ctx->result_all[i_segment].tokens[i_token].p;
|
return ctx->result_all[i_segment].tokens[i_token].p;
|
||||||
}
|
}
|
||||||
|
@ -160,6 +160,11 @@ extern "C" {
|
|||||||
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Text segment callback
|
||||||
|
// Called on every newly generated text segment
|
||||||
|
// Use the whisper_full_...() functions to obtain the text segments
|
||||||
|
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
|
||||||
|
|
||||||
struct whisper_full_params {
|
struct whisper_full_params {
|
||||||
enum whisper_sampling_strategy strategy;
|
enum whisper_sampling_strategy strategy;
|
||||||
|
|
||||||
@ -184,6 +189,9 @@ extern "C" {
|
|||||||
int beam_width;
|
int beam_width;
|
||||||
int n_best;
|
int n_best;
|
||||||
} beam_search;
|
} beam_search;
|
||||||
|
|
||||||
|
whisper_new_segment_callback new_segment_callback;
|
||||||
|
void * new_segment_callback_user_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
|
||||||
@ -212,6 +220,7 @@ extern "C" {
|
|||||||
|
|
||||||
// Get the token text of the specified token in the specified segment.
|
// Get the token text of the specified token in the specified segment.
|
||||||
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
WHISPER_API const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
|
||||||
// Get the probability of the specified token in the specified segment.
|
// Get the probability of the specified token in the specified segment.
|
||||||
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
|
||||||
|
Loading…
Reference in New Issue
Block a user