mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-30 14:14:41 +02:00
whisper : add detect-language mode (#853)
* add detectlanguage flag * renaming and help * no idea why that last one didn't commit * run language detection if dl is set * help message fix * various fixes * fix quitting * fix language being english on print
This commit is contained in:
parent
be5911a9f3
commit
b806420873
@ -66,6 +66,7 @@ struct whisper_params {
|
|||||||
|
|
||||||
bool speed_up = false;
|
bool speed_up = false;
|
||||||
bool translate = false;
|
bool translate = false;
|
||||||
|
bool detect_language= false;
|
||||||
bool diarize = false;
|
bool diarize = false;
|
||||||
bool split_on_word = false;
|
bool split_on_word = false;
|
||||||
bool no_fallback = false;
|
bool no_fallback = false;
|
||||||
@ -141,6 +142,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||||
|
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
|
||||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||||
@ -191,6 +193,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
|||||||
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
|
||||||
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "false" : "true");
|
||||||
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
|
||||||
|
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
|
||||||
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
|
||||||
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
|
||||||
@ -739,6 +742,9 @@ int main(int argc, char ** argv) {
|
|||||||
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (params.detect_language) {
|
||||||
|
params.language = "auto";
|
||||||
|
}
|
||||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||||
params.n_threads, params.n_processors,
|
params.n_threads, params.n_processors,
|
||||||
@ -761,6 +767,7 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.print_special = params.print_special;
|
wparams.print_special = params.print_special;
|
||||||
wparams.translate = params.translate;
|
wparams.translate = params.translate;
|
||||||
wparams.language = params.language.c_str();
|
wparams.language = params.language.c_str();
|
||||||
|
wparams.detect_language = params.detect_language;
|
||||||
wparams.n_threads = params.n_threads;
|
wparams.n_threads = params.n_threads;
|
||||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||||
wparams.offset_ms = params.offset_t_ms;
|
wparams.offset_ms = params.offset_t_ms;
|
||||||
|
@ -3312,6 +3312,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.prompt_n_tokens =*/ 0,
|
/*.prompt_n_tokens =*/ 0,
|
||||||
|
|
||||||
/*.language =*/ "en",
|
/*.language =*/ "en",
|
||||||
|
/*.detect_language =*/ false,
|
||||||
|
|
||||||
/*.suppress_blank =*/ true,
|
/*.suppress_blank =*/ true,
|
||||||
/*.suppress_non_speech_tokens =*/ false,
|
/*.suppress_non_speech_tokens =*/ false,
|
||||||
@ -3898,7 +3899,7 @@ int whisper_full_with_state(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// auto-detect language if not specified
|
// auto-detect language if not specified
|
||||||
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0) {
|
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
|
||||||
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
|
||||||
|
|
||||||
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
|
||||||
@ -3910,6 +3911,9 @@ int whisper_full_with_state(
|
|||||||
params.language = whisper_lang_str(lang_id);
|
params.language = whisper_lang_str(lang_id);
|
||||||
|
|
||||||
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
fprintf(stderr, "%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
|
||||||
|
if (params.detect_language) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.token_timestamps) {
|
if (params.token_timestamps) {
|
||||||
|
@ -365,6 +365,7 @@ extern "C" {
|
|||||||
|
|
||||||
// for auto-detection, set to nullptr, "" or "auto"
|
// for auto-detection, set to nullptr, "" or "auto"
|
||||||
const char * language;
|
const char * language;
|
||||||
|
bool detect_language;
|
||||||
|
|
||||||
// common decoding parameters:
|
// common decoding parameters:
|
||||||
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
||||||
|
Loading…
Reference in New Issue
Block a user