diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 142c4e31..ce063211 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -17,6 +17,7 @@ const whisperParamsMock = { comma_in_time: false, translate: true, no_timestamps: false, + detect_language: false, audio_ctx: 0, max_len: 0, prompt: "", @@ -30,8 +31,9 @@ const whisperParamsMock = { describe("Run whisper.node", () => { test("it should receive a non-empty value", async () => { let result = await whisperAsync(whisperParamsMock); + console.log(result); - expect(result.length).toBeGreaterThan(0); + expect(result['transcription'].length).toBeGreaterThan(0); }, 10000); }); diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 8181ca24..67b1ec92 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -38,6 +38,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool no_prints = false; + bool detect_language= false; bool use_gpu = true; bool flash_attn = false; bool comma_in_time = true; @@ -130,6 +131,11 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper void cb_log_disable(enum ggml_log_level, const char *, void *) {} +struct whisper_result { + std::vector> segments; + std::string language; +}; + class ProgressWorker : public Napi::AsyncWorker { public: ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env) @@ -160,15 +166,27 @@ class ProgressWorker : public Napi::AsyncWorker { void OnOK() override { Napi::HandleScope scope(Env()); - Napi::Object res = Napi::Array::New(Env(), result.size()); - for (uint64_t i = 0; i < result.size(); ++i) { + + if (params.detect_language) { + Napi::Object resultObj = Napi::Object::New(Env()); + resultObj.Set("language", Napi::String::New(Env(), result.language)); + Callback().Call({Env().Null(), resultObj}); + } + + Napi::Object returnObj = Napi::Object::New(Env()); + if (!result.language.empty()) { + returnObj.Set("language", Napi::String::New(Env(), result.language)); + } + Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size()); + for (uint64_t i = 0; i < result.segments.size(); ++i) { Napi::Object tmp = Napi::Array::New(Env(), 3); for (uint64_t j = 0; j < 3; ++j) { - tmp[j] = Napi::String::New(Env(), result[i][j]); + tmp[j] = Napi::String::New(Env(), result.segments[i][j]); } - res[i] = tmp; - } - Callback().Call({Env().Null(), res}); + transcriptionArray[i] = tmp; + } + returnObj.Set("transcription", transcriptionArray); + Callback().Call({Env().Null(), returnObj}); } // Progress callback function - using thread-safe function @@ -185,12 +203,12 @@ class ProgressWorker : public Napi::AsyncWorker { private: whisper_params params; - std::vector> result; + whisper_result result; Napi::Env env; Napi::ThreadSafeFunction tsfn; // Custom run function with progress callback support - int run_with_progress(whisper_params ¶ms, std::vector> &result) { + int run_with_progress(whisper_params ¶ms, whisper_result & result) { if (params.no_prints) { whisper_log_set(cb_log_disable, NULL); } @@ -279,7 +297,8 @@ class ProgressWorker : public Napi::AsyncWorker { wparams.print_timestamps = !params.no_timestamps; wparams.print_special = params.print_special; wparams.translate = params.translate; - wparams.language = params.language.c_str(); + wparams.language = params.detect_language ? "auto" : params.language.c_str(); + wparams.detect_language = params.detect_language; wparams.n_threads = params.n_threads; wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; @@ -330,18 +349,22 @@ class ProgressWorker : public Napi::AsyncWorker { return 10; } } - } + } + if (params.detect_language || params.language == "auto") { + result.language = whisper_lang_str(whisper_full_lang_id(ctx)); + } const int n_segments = whisper_full_n_segments(ctx); - result.resize(n_segments); + result.segments.resize(n_segments); + for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - result[i].emplace_back(to_timestamp(t0, params.comma_in_time)); - result[i].emplace_back(to_timestamp(t1, params.comma_in_time)); - result[i].emplace_back(text); + result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time)); + result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time)); + result.segments[i].emplace_back(text); } whisper_print_timings(ctx); @@ -366,6 +389,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { bool flash_attn = whisper_params.Get("flash_attn").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); + bool detect_language = whisper_params.Get("detect_language").As(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); bool comma_in_time = whisper_params.Get("comma_in_time").As(); int32_t max_len = whisper_params.Get("max_len").As(); @@ -418,6 +442,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { params.max_context = max_context; params.print_progress = print_progress; params.prompt = prompt; + params.detect_language = detect_language; Napi::Function callback = info[1].As(); // Create a new Worker class with progress callback support diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 408d6d33..9324d6fa 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -17,6 +17,7 @@ const whisperParams = { comma_in_time: false, translate: true, no_timestamps: false, + detect_language: false, audio_ctx: 0, max_len: 0, progress_callback: (progress) => { @@ -31,6 +32,8 @@ const params = Object.fromEntries( const [key, value] = item.slice(2).split("="); if (key === "audio_ctx") { whisperParams[key] = parseInt(value); + } else if (key === "detect_language") { + whisperParams[key] = value === "true"; } else { whisperParams[key] = value; }