whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058)

* add HuggingFace mirror to download ggml model

* support tdrz via simple hack overriding solm tokens

* fix incorrect translate/transcribe token ids (they differ between English-only and multilingual models, so they are no longer static const)

* add apollo 13 sample for tdrz demo

* render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token

* extend whisper_segment with speaker_turn_next field and save in json output

* fix failing go build

* slipped in some python syntax whoops

* whisper : finalize tinydiarize support (add flag + fixes)

* whisper : tdrz support for word-level timestamps (respect max_len)

* java : try to fix tests after adding tdrz_enable flag

* main : remove TODO leftover

* java : fix params order list after adding "tdrz_enable"

* whisper : fix solm and add nosp token

* main : print tinydiarize help

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Akash Mahajan, 2023-07-03 23:45:00 -07:00 (committed by GitHub)
parent fdf58a6668
commit c8d0f5fe98
8 changed files with 215 additions and 130 deletions
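
For orientation, a minimal sketch of how the new option is driven through the C API (not part of this commit; the model path is illustrative and assumes a tdrz-capable ggml model, e.g. the one fetched by models/download-ggml-model.sh small.en-tdrz):

    #include "whisper.h"

    #include <vector>

    int main() {
        // hypothetical model path, for illustration only
        struct whisper_context * ctx = whisper_init_from_file("models/ggml-small.en-tdrz.bin");
        if (!ctx) {
            return 1;
        }

        // normally: 16 kHz mono PCM decoded from an audio file
        std::vector<float> pcmf32(16000*10, 0.0f);

        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        wparams.tdrz_enable = true; // [TDRZ] let the decoder emit the speaker-turn (solm) token

        const int ret = whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size());

        // per-segment results, including the new speaker-turn flag, are read back via
        // whisper_full_get_segment_* (see the whisper.h changes below)

        whisper_free(ctx);
        return ret;
    }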

File: Makefile

@@ -308,12 +308,16 @@ samples:
  @wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
  @wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
  @wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
+ @wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
  @echo "Converting to 16-bit WAV ..."
  @ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
  @ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
  @ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
+ @rm samples/*.ogg
  @ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
  @rm samples/mm1.wav
+ @ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
+ @rm samples/a13.mp3
  #
  # Models

File: Go bindings (whisper.go)

@@ -270,13 +270,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
  }
  // Task tokens
- func Whisper_token_translate() Token {
- return Token(C.whisper_token_translate())
+ func (ctx *Context) Whisper_token_translate() Token {
+ return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
  }
  // Task tokens
- func Whisper_token_transcribe() Token {
- return Token(C.whisper_token_transcribe())
+ func (ctx *Context) Whisper_token_transcribe() Token {
+ return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
  }
  // Performance information

File: Java bindings (WhisperCppJnaLibrary.java)

@@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library {
  int whisper_token_lang(Pointer ctx, int lang_id);
  // Task tokens
- int whisper_token_translate();
- int whisper_token_transcribe();
+ int whisper_token_translate (Pointer ctx);
+ int whisper_token_transcribe(Pointer ctx);
  // Performance information from the default state.
  void whisper_print_timings(Pointer ctx);

File: Java bindings (WhisperFullParams.java)

@@ -137,6 +137,14 @@ public class WhisperFullParams extends Structure {
  /** Overwrite the audio context size (0 = use default). */
  public int audio_ctx;
+ /** Enable tinydiarize (default = false) */
+ public CBool tdrz_enable;
+ /** Enable tinydiarize (default = false) */
+ public void tdrzEnable(boolean enable) {
+     tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
+ }
  /** Tokens to provide to the whisper decoder as an initial prompt.
  * These are prepended to any existing text context from a previous call. */
  public String initial_prompt;
@@ -302,7 +310,7 @@ public class WhisperFullParams extends Structure {
  "no_context", "single_segment",
  "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
  "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
- "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
+ "tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
  "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
  "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
  "new_segment_callback", "new_segment_callback_user_data",

File: examples/main/main.cpp

@@ -68,28 +68,32 @@ struct whisper_params {
  float entropy_thold = 2.40f;
  float logprob_thold = -1.00f;
  bool speed_up = false;
  bool translate = false;
  bool detect_language = false;
  bool diarize = false;
+ bool tinydiarize = false;
  bool split_on_word = false;
  bool no_fallback = false;
  bool output_txt = false;
  bool output_vtt = false;
  bool output_srt = false;
  bool output_wts = false;
  bool output_csv = false;
  bool output_jsn = false;
  bool output_lrc = false;
  bool print_special = false;
  bool print_colors = false;
  bool print_progress = false;
  bool no_timestamps = false;
  std::string language = "en";
  std::string prompt;
  std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
  std::string model = "models/ggml-base.en.bin";
+ // [TDRZ] speaker turn string
+ std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
  std::vector<std::string> fname_inp = {};
  std::vector<std::string> fname_out = {};
@@ -115,41 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
  whisper_print_usage(argc, argv, params);
  exit(0);
  }
  else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
  else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
  else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
  else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
  else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
  else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
  else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
  else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
  else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
  else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
  else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
  else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
  else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
  else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
  else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
+ else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
  else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
  else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
  else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
  else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
  else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
  else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
  else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
  else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
  else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
  else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
  else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
  else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
  else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
  else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
  else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
  else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
  else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
  else if ( arg == "--prompt") { params.prompt = argv[++i]; }
  else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
  else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
  else {
  fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  whisper_print_usage(argc, argv, params);
@@ -182,6 +187,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
  fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
  fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
  fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
+ fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
  fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
  fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
  fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
@@ -297,6 +303,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
  printf("%s%s", speaker.c_str(), text);
  }
+ if (params.tinydiarize) {
+     if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
+         printf("%s", params.tdrz_speaker_turn.c_str());
+     }
+ }
  // with timestamps or speakers: each segment on new line
  if (!params.no_timestamps || params.diarize) {
  printf("\n");
@@ -564,6 +576,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
  const int n_segments = whisper_full_n_segments(ctx);
  for (int i = 0; i < n_segments; ++i) {
  const char * text = whisper_full_get_segment_text(ctx, i);
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
@@ -576,11 +589,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
  value_i("from", t0 * 10, false);
  value_i("to", t1 * 10, true);
  end_obj(false);
- value_s("text", text, !params.diarize);
+ value_s("text", text, !params.diarize && !params.tinydiarize);
  if (params.diarize && pcmf32s.size() == 2) {
  value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
  }
+ if (params.tinydiarize) {
+     value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
+ }
  end_obj(i == (n_segments - 1));
  }
@@ -777,6 +794,12 @@ int main(int argc, char ** argv) {
  exit(0);
  }
+ if (params.diarize && params.tinydiarize) {
+     fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
+     whisper_print_usage(argc, argv, params);
+     exit(0);
+ }
  // whisper init
  struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -818,11 +841,12 @@ int main(int argc, char ** argv) {
  if (params.detect_language) {
  params.language = "auto";
  }
- fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
+ fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
  __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
  params.n_threads, params.n_processors,
  params.language.c_str(),
  params.translate ? "translate" : "transcribe",
+ params.tinydiarize ? "tdrz = 1, " : "",
  params.no_timestamps ? 0 : 1);
  fprintf(stderr, "\n");
@@ -853,6 +877,8 @@ int main(int argc, char ** argv) {
  wparams.speed_up = params.speed_up;
+ wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
  wparams.initial_prompt = params.prompt.c_str();
  wparams.greedy.best_of = params.best_of;
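
When both --output-json and --tinydiarize are set, each segment object written by output_json() gains the new boolean. An illustrative fragment (only the keys added or changed above are shown; the values are made up):

    "text": " Okay, Houston, we've had a problem here.",
    "speaker_turn_next": true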

File: models/download-ggml-model.sh

@@ -22,7 +22,7 @@ function get_script_path() {
  models_path="$(get_script_path)"
  # Whisper models
- models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
+ models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" )
  # list available models
  function list_models {
@@ -50,6 +50,12 @@ if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
  exit 1
  fi
+ # check if model contains `tdrz` and update the src and pfx accordingly
+ if [[ $model == *"tdrz"* ]]; then
+     src="https://huggingface.co/akashmjn/tinydiarize-whisper.cpp"
+     pfx="resolve/main/ggml"
+ fi
  # download ggml model
  printf "Downloading ggml model $model from '$src' ...\n"

File: whisper.cpp

@@ -380,16 +380,18 @@ struct whisper_vocab {
  std::map<token, id> token_to_id;
  std::map<id, token> id_to_token;
- id token_eot = 50256;
- id token_sot = 50257;
- id token_prev = 50360;
- id token_solm = 50361; // ??
- id token_not = 50362; // no timestamps
- id token_beg = 50363;
- // available tasks
- static const id token_translate = 50358;
- static const id token_transcribe = 50359;
+ // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
+ id token_eot = 50256;
+ id token_sot = 50257;
+ // task tokens (used only for multilingual models)
+ id token_translate = 50357;
+ id token_transcribe = 50358;
+ // other special tokens
+ id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
+ id token_prev = 50360;
+ id token_nosp = 50361;
+ id token_not = 50362; // no timestamps
+ id token_beg = 50363; // begin timestamps
  bool is_multilingual() const {
  return n_vocab == 51865;
@@ -403,6 +405,8 @@ struct whisper_segment {
  std::string text;
  std::vector<whisper_token_data> tokens;
+ bool speaker_turn_next;
  };
  // medium
@@ -966,8 +970,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  if (vocab.is_multilingual()) {
  vocab.token_eot++;
  vocab.token_sot++;
- vocab.token_prev++;
+ vocab.token_translate++;
+ vocab.token_transcribe++;
  vocab.token_solm++;
+ vocab.token_prev++;
+ vocab.token_nosp++;
  vocab.token_not++;
  vocab.token_beg++;
  }
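
As a standalone check of the shift applied above (not code from the commit; the values come from the whisper_vocab defaults shown earlier):

    #include <cassert>

    int main() {
        int token_solm = 50359;          // English-only id of the [TDRZ] speaker-turn token
        const bool multilingual = true;  // i.e. vocab.is_multilingual(), n_vocab == 51865
        if (multilingual) {
            token_solm++;                // same ++ pattern as whisper_model_load()
        }
        assert(token_solm == 50360);
        return 0;
    }
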
@@ -981,8 +988,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
  word = "[_EOT_]";
  } else if (i == vocab.token_sot) {
  word = "[_SOT_]";
+ } else if (i == vocab.token_solm) {
+     word = "[_SOLM_]";
  } else if (i == vocab.token_prev) {
  word = "[_PREV_]";
+ } else if (i == vocab.token_nosp) {
+     word = "[_NOSP_]";
  } else if (i == vocab.token_not) {
  word = "[_NOT_]";
  } else if (i == vocab.token_beg) {
@@ -3208,12 +3219,16 @@ whisper_token whisper_token_sot(struct whisper_context * ctx) {
  return ctx->vocab.token_sot;
  }
+ whisper_token whisper_token_solm(struct whisper_context * ctx) {
+     return ctx->vocab.token_solm;
+ }
  whisper_token whisper_token_prev(struct whisper_context * ctx) {
  return ctx->vocab.token_prev;
  }
- whisper_token whisper_token_solm(struct whisper_context * ctx) {
- return ctx->vocab.token_solm;
+ whisper_token whisper_token_nosp(struct whisper_context * ctx) {
+ return ctx->vocab.token_nosp;
  }
  whisper_token whisper_token_not(struct whisper_context * ctx) {
@@ -3228,12 +3243,12 @@ whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
  return whisper_token_sot(ctx) + 1 + lang_id;
  }
- whisper_token whisper_token_translate(void) {
- return whisper_vocab::token_translate;
+ whisper_token whisper_token_translate(struct whisper_context * ctx) {
+ return ctx->vocab.token_translate;
  }
- whisper_token whisper_token_transcribe(void) {
- return whisper_vocab::token_transcribe;
+ whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
+ return ctx->vocab.token_transcribe;
  }
  void whisper_print_timings(struct whisper_context * ctx) {
@@ -3305,51 +3320,53 @@ struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sam
  struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
  struct whisper_full_params result = {
  /*.strategy =*/ strategy,
  /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
  /*.n_max_text_ctx =*/ 16384,
  /*.offset_ms =*/ 0,
  /*.duration_ms =*/ 0,
  /*.translate =*/ false,
  /*.no_context =*/ true,
  /*.single_segment =*/ false,
  /*.print_special =*/ false,
  /*.print_progress =*/ true,
  /*.print_realtime =*/ false,
  /*.print_timestamps =*/ true,
  /*.token_timestamps =*/ false,
  /*.thold_pt =*/ 0.01f,
  /*.thold_ptsum =*/ 0.01f,
  /*.max_len =*/ 0,
  /*.split_on_word =*/ false,
  /*.max_tokens =*/ 0,
  /*.speed_up =*/ false,
  /*.audio_ctx =*/ 0,
+ /*.tdrz_enable =*/ false,
  /*.initial_prompt =*/ nullptr,
  /*.prompt_tokens =*/ nullptr,
  /*.prompt_n_tokens =*/ 0,
  /*.language =*/ "en",
  /*.detect_language =*/ false,
  /*.suppress_blank =*/ true,
  /*.suppress_non_speech_tokens =*/ false,
  /*.temperature =*/ 0.0f,
  /*.max_initial_ts =*/ 1.0f,
  /*.length_penalty =*/ -1.0f,
  /*.temperature_inc =*/ 0.4f,
  /*.entropy_thold =*/ 2.4f,
  /*.logprob_thold =*/ -1.0f,
  /*.no_speech_thold =*/ 0.6f,
  /*.greedy =*/ {
  /*.best_of =*/ -1,
  },
@@ -3430,6 +3447,7 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
  state.result_all.back().text = std::move(text);
  state.result_all.back().t1 = token.t0;
  state.result_all.back().tokens.resize(i);
+ state.result_all.back().speaker_turn_next = false;
  state.result_all.push_back({});
  state.result_all.back().t0 = token.t0;
@@ -3441,6 +3459,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta
  segment.tokens.begin() + i,
  segment.tokens.end());
+ state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
  acc = 0;
  text = "";
@@ -3519,9 +3539,14 @@ static void whisper_process_logits(
  // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
  logits[vocab.token_not] = -INFINITY;
- // suppress sot and solm tokens
+ // suppress sot and nosp tokens
  logits[vocab.token_sot] = -INFINITY;
- logits[vocab.token_solm] = -INFINITY;
+ logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
+ // [TDRZ] when tinydiarize is disabled, suppress solm token
+ if (params.tdrz_enable == false) {
+     logits[vocab.token_solm] = -INFINITY;
+ }
  // suppress task tokens
  logits[vocab.token_translate] = -INFINITY;
@@ -4018,9 +4043,9 @@ int whisper_full_with_state(
  state->lang_id = lang_id;
  prompt_init.push_back(whisper_token_lang(ctx, lang_id));
  if (params.translate) {
- prompt_init.push_back(whisper_token_translate());
+ prompt_init.push_back(whisper_token_translate(ctx));
  } else {
- prompt_init.push_back(whisper_token_transcribe());
+ prompt_init.push_back(whisper_token_transcribe(ctx));
  }
  }
@@ -4500,23 +4525,27 @@ int whisper_full_with_state(
  prompt_past.push_back(tokens_cur[i].id);
  }
+ // store the text from this iteration
  if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
  int i0 = 0;
  auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
  std::string text;
+ bool speaker_turn_next = false;
  for (int i = 0; i < (int) tokens_cur.size(); i++) {
  //printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
  // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
  // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
- if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
- } else {
+ if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
  text += whisper_token_to_str(ctx, tokens_cur[i].id);
  }
+ // [TDRZ] record if speaker turn was predicted after current segment
+ if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
+     speaker_turn_next = true;
+ }
  if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
  const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
@@ -4535,7 +4564,7 @@ int whisper_full_with_state(
  //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
- result_all.push_back({ tt0, tt1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next });
  for (int j = i0; j <= i; j++) {
  result_all.back().tokens.push_back(tokens_cur[j]);
  }
@@ -4561,6 +4590,7 @@ int whisper_full_with_state(
  i--;
  t0 = t1;
  i0 = i + 1;
+ speaker_turn_next = false;
  }
  }
@@ -4579,7 +4609,7 @@ int whisper_full_with_state(
  }
  }
- result_all.push_back({ tt0, tt1, text, {} });
+ result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next });
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
  result_all.back().tokens.push_back(tokens_cur[j]);
  }
@@ -4759,6 +4789,10 @@ int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)
  return ctx->state->result_all[i_segment].t1;
  }
+ bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
+     return ctx->state->result_all[i_segment].speaker_turn_next;
+ }
  const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
  return state->result_all[i_segment].text.c_str();
  }

File: whisper.h

@@ -277,15 +277,16 @@ extern "C" {
  // Special tokens
  WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx);
  WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx);
- WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
  WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx);
  WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx);
  WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx);
  WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id);
  // Task tokens
- WHISPER_API whisper_token whisper_token_translate (void);
- WHISPER_API whisper_token whisper_token_transcribe(void);
+ WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx);
+ WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx);
  // Performance information from the default state.
  WHISPER_API void whisper_print_timings(struct whisper_context * ctx);
@@ -358,6 +359,9 @@ extern "C" {
  bool speed_up; // speed-up the audio by 2x using Phase Vocoder
  int audio_ctx; // overwrite the audio context size (0 = use default)
+ // [EXPERIMENTAL] [TDRZ] tinydiarize
+ bool tdrz_enable; // enable tinydiarize speaker turn detection
  // tokens to provide to the whisper decoder as initial prompt
  // these are prepended to any existing text context from a previous call
  const char * initial_prompt;
@@ -460,6 +464,9 @@ extern "C" {
  WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment);
  WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
+ // Get whether the next segment is predicted as a speaker turn
+ WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
  // Get the text of the specified segment
  WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment);
  WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
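
A minimal consumer-side sketch for the new accessor (assumes whisper_full() has already run on ctx with wparams.tdrz_enable = true; this mirrors what examples/main does):

    #include <cstdio>

    #include "whisper.h"

    // print each segment and mark predicted speaker changes
    static void print_segments_with_turns(struct whisper_context * ctx) {
        const int n_segments = whisper_full_n_segments(ctx);
        for (int i = 0; i < n_segments; ++i) {
            printf("%s", whisper_full_get_segment_text(ctx, i));
            if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
                printf(" [SPEAKER_TURN]"); // a speaker change is predicted after this segment
            }
            printf("\n");
        }
    }
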
@@ -488,9 +495,9 @@ extern "C" {
  // Temporary helpers needed for exposing ggml interface
- WHISPER_API int whisper_bench_memcpy(int n_threads);
- WHISPER_API const char * whisper_bench_memcpy_str(int n_threads);
- WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
+ WHISPER_API int whisper_bench_memcpy (int n_threads);
+ WHISPER_API const char * whisper_bench_memcpy_str (int n_threads);
+ WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads);
  WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
  #ifdef __cplusplus