From a0d4f8e65ca03247ef385552a34be11ef6f1a871 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Jan 2023 21:45:05 +0200 Subject: [PATCH] main : make whisper_print_segment_callback() more readable (close #371) --- examples/main/main.cpp | 144 +++++++++++++++++++---------------------- 1 file changed, 68 insertions(+), 76 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 93108947..ef903aa7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -176,90 +176,82 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const int n_segments = whisper_full_n_segments(ctx); + std::string speaker = ""; + + int64_t t0; + int64_t t1; + // print the last n_new segments const int s0 = n_segments - n_new; + if (s0 == 0) { printf("\n"); } for (int i = s0; i < n_segments; i++) { - if (params.no_timestamps) { - if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; - } - } - - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); - } - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - printf("%s", text); - } - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - - std::string speaker; - - if (params.diarize && pcmf32s.size() == 2) { - const int64_t n_samples = pcmf32s[0].size(); - - const int64_t is0 = timestamp_to_sample(t0, n_samples); - const int64_t is1 = timestamp_to_sample(t1, n_samples); - - double energy0 = 0.0f; - double energy1 = 0.0f; - - for (int64_t j = is0; j < is1; j++) { - energy0 += fabs(pcmf32s[0][j]); - energy1 += fabs(pcmf32s[1][j]); - } - - if (energy0 > 1.1*energy1) { - speaker = "(speaker 0)"; - } else if (energy1 > 1.1*energy0) { - speaker = "(speaker 1)"; - } else { - speaker = "(speaker ?)"; - } - - //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); - } - - if (params.print_colors) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; - } - } - - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); - } - printf("\n"); - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - - printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text); - } + if (!params.no_timestamps || params.diarize) { + t0 = whisper_full_get_segment_t0(ctx, i); + t1 = whisper_full_get_segment_t1(ctx, i); } + + if (!params.no_timestamps) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + } + + if (params.diarize && pcmf32s.size() == 2) { + + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "(speaker 0)"; + } else if (energy1 > 1.1*energy0) { + speaker = "(speaker 1)"; + } else { + speaker = "(speaker ?)"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + } + + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } + } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); + } + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + + printf("%s%s", speaker.c_str(), text); + } + + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) { + printf("\n"); + } + + fflush(stdout); } }