From b5055396706a0a597b6de6d8d9e15be6da5f75a6 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 2 Jun 2025 14:58:05 +0200 Subject: [PATCH] node : add language detection support (#3190) This commit adds support for language detection in the Whisper Node.js addon example. It also updates the node addon to return an object instead of an array as the result. The motivation for this change is to enable the inclusion of the detected language in the result, in addition to the transcription segments. For example, when using the `detect_language` option, the result will now be: ```console { language: 'en' } ``` And if the `language` option is set to "auto", it will also return: ```console { language: 'en', transcription: [ [ '00:00:00.000', '00:00:07.600', ' And so my fellow Americans, ask not what your country can do for you,' ], [ '00:00:07.600', '00:00:10.600', ' ask what you can do for your country.' ] ] } ``` --- examples/addon.node/__test__/whisper.spec.js | 4 +- examples/addon.node/addon.cpp | 53 ++++++++++++++------ examples/addon.node/index.js | 3 ++ 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 142c4e31..ce063211 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -17,6 +17,7 @@ const whisperParamsMock = { comma_in_time: false, translate: true, no_timestamps: false, + detect_language: false, audio_ctx: 0, max_len: 0, prompt: "", @@ -30,8 +31,9 @@ const whisperParamsMock = { describe("Run whisper.node", () => { test("it should receive a non-empty value", async () => { let result = await whisperAsync(whisperParamsMock); + console.log(result); - expect(result.length).toBeGreaterThan(0); + expect(result['transcription'].length).toBeGreaterThan(0); }, 10000); }); diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 8181ca24..67b1ec92 100644 --- 
a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -38,6 +38,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool no_prints = false; + bool detect_language= false; bool use_gpu = true; bool flash_attn = false; bool comma_in_time = true; @@ -130,6 +131,11 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper void cb_log_disable(enum ggml_log_level, const char *, void *) {} +struct whisper_result { + std::vector> segments; + std::string language; +}; + class ProgressWorker : public Napi::AsyncWorker { public: ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env) @@ -160,15 +166,27 @@ class ProgressWorker : public Napi::AsyncWorker { void OnOK() override { Napi::HandleScope scope(Env()); - Napi::Object res = Napi::Array::New(Env(), result.size()); - for (uint64_t i = 0; i < result.size(); ++i) { + + if (params.detect_language) { + Napi::Object resultObj = Napi::Object::New(Env()); + resultObj.Set("language", Napi::String::New(Env(), result.language)); + Callback().Call({Env().Null(), resultObj}); + } + + Napi::Object returnObj = Napi::Object::New(Env()); + if (!result.language.empty()) { + returnObj.Set("language", Napi::String::New(Env(), result.language)); + } + Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size()); + for (uint64_t i = 0; i < result.segments.size(); ++i) { Napi::Object tmp = Napi::Array::New(Env(), 3); for (uint64_t j = 0; j < 3; ++j) { - tmp[j] = Napi::String::New(Env(), result[i][j]); + tmp[j] = Napi::String::New(Env(), result.segments[i][j]); } - res[i] = tmp; - } - Callback().Call({Env().Null(), res}); + transcriptionArray[i] = tmp; + } + returnObj.Set("transcription", transcriptionArray); + Callback().Call({Env().Null(), returnObj}); } // Progress callback function - using thread-safe function @@ -185,12 +203,12 @@ class ProgressWorker : public Napi::AsyncWorker { private: 
whisper_params params; - std::vector> result; + whisper_result result; Napi::Env env; Napi::ThreadSafeFunction tsfn; // Custom run function with progress callback support - int run_with_progress(whisper_params ¶ms, std::vector> &result) { + int run_with_progress(whisper_params ¶ms, whisper_result & result) { if (params.no_prints) { whisper_log_set(cb_log_disable, NULL); } @@ -279,7 +297,8 @@ class ProgressWorker : public Napi::AsyncWorker { wparams.print_timestamps = !params.no_timestamps; wparams.print_special = params.print_special; wparams.translate = params.translate; - wparams.language = params.language.c_str(); + wparams.language = params.detect_language ? "auto" : params.language.c_str(); + wparams.detect_language = params.detect_language; wparams.n_threads = params.n_threads; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; @@ -330,18 +349,22 @@ class ProgressWorker : public Napi::AsyncWorker { return 10; } } - } + } + if (params.detect_language || params.language == "auto") { + result.language = whisper_lang_str(whisper_full_lang_id(ctx)); + } const int n_segments = whisper_full_n_segments(ctx); - result.resize(n_segments); + result.segments.resize(n_segments); + for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - result[i].emplace_back(to_timestamp(t0, params.comma_in_time)); - result[i].emplace_back(to_timestamp(t1, params.comma_in_time)); - result[i].emplace_back(text); + result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time)); + result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time)); + result.segments[i].emplace_back(text); } whisper_print_timings(ctx); @@ -366,6 +389,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { bool flash_attn = 
whisper_params.Get("flash_attn").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); + bool detect_language = whisper_params.Get("detect_language").As(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); bool comma_in_time = whisper_params.Get("comma_in_time").As(); int32_t max_len = whisper_params.Get("max_len").As(); @@ -418,6 +442,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { params.max_context = max_context; params.print_progress = print_progress; params.prompt = prompt; + params.detect_language = detect_language; Napi::Function callback = info[1].As(); // Create a new Worker class with progress callback support diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 408d6d33..9324d6fa 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -17,6 +17,7 @@ const whisperParams = { comma_in_time: false, translate: true, no_timestamps: false, + detect_language: false, audio_ctx: 0, max_len: 0, progress_callback: (progress) => { @@ -31,6 +32,8 @@ const params = Object.fromEntries( const [key, value] = item.slice(2).split("="); if (key === "audio_ctx") { whisperParams[key] = parseInt(value); + } else if (key === "detect_language") { + whisperParams[key] = value === "true"; } else { whisperParams[key] = value; }