mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-08 10:47:16 +02:00
node : add language detection support (#3190)
This commit add support for language detection in the Whisper Node.js addon example. It also updates the node addon to return an object instead of an array as the results. The motivation for this change is to enable the inclusion of the detected language in the result, in addition to the transcription segments. For example, when using the `detect_language` option, the result will now be: ```console { language: 'en' } ``` And if the `language` option is set to "auto", it will also return: ```console { language: 'en', transcription: [ [ '00:00:00.000', '00:00:07.600', ' And so my fellow Americans, ask not what your country can do for you,' ], [ '00:00:07.600', '00:00:10.600', ' ask what you can do for your country.' ] ] } ```
This commit is contained in:
parent
7fd6fa8097
commit
b505539670
@ -17,6 +17,7 @@ const whisperParamsMock = {
|
||||
comma_in_time: false,
|
||||
translate: true,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
prompt: "",
|
||||
@ -30,8 +31,9 @@ const whisperParamsMock = {
|
||||
describe("Run whisper.node", () => {
|
||||
test("it should receive a non-empty value", async () => {
|
||||
let result = await whisperAsync(whisperParamsMock);
|
||||
console.log(result);
|
||||
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
expect(result['transcription'].length).toBeGreaterThan(0);
|
||||
}, 10000);
|
||||
});
|
||||
|
||||
|
@ -38,6 +38,7 @@ struct whisper_params {
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool no_prints = false;
|
||||
bool detect_language= false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
bool comma_in_time = true;
|
||||
@ -130,6 +131,11 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
|
||||
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
|
||||
|
||||
struct whisper_result {
|
||||
std::vector<std::vector<std::string>> segments;
|
||||
std::string language;
|
||||
};
|
||||
|
||||
class ProgressWorker : public Napi::AsyncWorker {
|
||||
public:
|
||||
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
|
||||
@ -160,15 +166,27 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
|
||||
void OnOK() override {
|
||||
Napi::HandleScope scope(Env());
|
||||
Napi::Object res = Napi::Array::New(Env(), result.size());
|
||||
for (uint64_t i = 0; i < result.size(); ++i) {
|
||||
|
||||
if (params.detect_language) {
|
||||
Napi::Object resultObj = Napi::Object::New(Env());
|
||||
resultObj.Set("language", Napi::String::New(Env(), result.language));
|
||||
Callback().Call({Env().Null(), resultObj});
|
||||
}
|
||||
|
||||
Napi::Object returnObj = Napi::Object::New(Env());
|
||||
if (!result.language.empty()) {
|
||||
returnObj.Set("language", Napi::String::New(Env(), result.language));
|
||||
}
|
||||
Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
|
||||
for (uint64_t i = 0; i < result.segments.size(); ++i) {
|
||||
Napi::Object tmp = Napi::Array::New(Env(), 3);
|
||||
for (uint64_t j = 0; j < 3; ++j) {
|
||||
tmp[j] = Napi::String::New(Env(), result[i][j]);
|
||||
tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
|
||||
}
|
||||
res[i] = tmp;
|
||||
transcriptionArray[i] = tmp;
|
||||
}
|
||||
Callback().Call({Env().Null(), res});
|
||||
returnObj.Set("transcription", transcriptionArray);
|
||||
Callback().Call({Env().Null(), returnObj});
|
||||
}
|
||||
|
||||
// Progress callback function - using thread-safe function
|
||||
@ -185,12 +203,12 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
|
||||
private:
|
||||
whisper_params params;
|
||||
std::vector<std::vector<std::string>> result;
|
||||
whisper_result result;
|
||||
Napi::Env env;
|
||||
Napi::ThreadSafeFunction tsfn;
|
||||
|
||||
// Custom run function with progress callback support
|
||||
int run_with_progress(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
||||
int run_with_progress(whisper_params ¶ms, whisper_result & result) {
|
||||
if (params.no_prints) {
|
||||
whisper_log_set(cb_log_disable, NULL);
|
||||
}
|
||||
@ -279,7 +297,8 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special = params.print_special;
|
||||
wparams.translate = params.translate;
|
||||
wparams.language = params.language.c_str();
|
||||
wparams.language = params.detect_language ? "auto" : params.language.c_str();
|
||||
wparams.detect_language = params.detect_language;
|
||||
wparams.n_threads = params.n_threads;
|
||||
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
|
||||
wparams.offset_ms = params.offset_t_ms;
|
||||
@ -332,16 +351,20 @@ class ProgressWorker : public Napi::AsyncWorker {
|
||||
}
|
||||
}
|
||||
|
||||
if (params.detect_language || params.language == "auto") {
|
||||
result.language = whisper_lang_str(whisper_full_lang_id(ctx));
|
||||
}
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
result.resize(n_segments);
|
||||
result.segments.resize(n_segments);
|
||||
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
|
||||
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
|
||||
result[i].emplace_back(text);
|
||||
result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
|
||||
result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
|
||||
result.segments[i].emplace_back(text);
|
||||
}
|
||||
|
||||
whisper_print_timings(ctx);
|
||||
@ -366,6 +389,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
|
||||
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
|
||||
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
|
||||
bool detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
|
||||
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
|
||||
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
|
||||
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
|
||||
@ -418,6 +442,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
|
||||
params.max_context = max_context;
|
||||
params.print_progress = print_progress;
|
||||
params.prompt = prompt;
|
||||
params.detect_language = detect_language;
|
||||
|
||||
Napi::Function callback = info[1].As<Napi::Function>();
|
||||
// Create a new Worker class with progress callback support
|
||||
|
@ -17,6 +17,7 @@ const whisperParams = {
|
||||
comma_in_time: false,
|
||||
translate: true,
|
||||
no_timestamps: false,
|
||||
detect_language: false,
|
||||
audio_ctx: 0,
|
||||
max_len: 0,
|
||||
progress_callback: (progress) => {
|
||||
@ -31,6 +32,8 @@ const params = Object.fromEntries(
|
||||
const [key, value] = item.slice(2).split("=");
|
||||
if (key === "audio_ctx") {
|
||||
whisperParams[key] = parseInt(value);
|
||||
} else if (key === "detect_language") {
|
||||
whisperParams[key] = value === "true";
|
||||
} else {
|
||||
whisperParams[key] = value;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user