From 0463028bc2a5774fe7361c8ac37bef440725bcd7 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Mon, 6 Nov 2023 17:04:24 +0800 Subject: [PATCH] whisper : add context param to disable gpu (#1293) * whisper : check state->ctx_metal not null * whisper : add whisper_context_params { use_gpu } * whisper : new API with params & deprecate old API * examples : use no-gpu param && whisper_init_from_file_with_params * whisper.objc : enable metal & disable on simulator * whisper.swiftui, metal : enable metal & support load default.metallib * whisper.android : use new API * bindings : use new API * addon.node : fix build & test * bindings : updata java binding * bindings : add missing whisper_context_default_params_by_ref WHISPER_API for java * metal : use SWIFTPM_MODULE_BUNDLE for GGML_SWIFT and reuse library load * metal : move bundle var into block * metal : use SWIFT_PACKAGE instead of GGML_SWIFT * style : minor updates --------- Co-authored-by: Georgi Gerganov --- bindings/go/whisper.go | 2 +- .../ggerganov/whispercpp/WhisperContext.java | 4 +- .../ggerganov/whispercpp/WhisperCpp.java | 61 ++++-- .../whispercpp/WhisperCppJnaLibrary.java | 22 +- .../params/WhisperContextParams.java | 31 +++ bindings/javascript/emscripten.cpp | 2 +- bindings/ruby/ext/ruby_whisper.cpp | 2 +- examples/addon.node/__test__/whisper.spec.js | 1 + examples/addon.node/addon.cpp | 7 +- examples/addon.node/index.js | 1 + examples/bench.wasm/emscripten.cpp | 2 +- examples/bench/bench.cpp | 15 +- examples/command.wasm/emscripten.cpp | 2 +- examples/command/command.cpp | 8 +- examples/lsp/lsp.cpp | 7 +- examples/main/main.cpp | 8 +- examples/stream.wasm/emscripten.cpp | 2 +- examples/stream/stream.cpp | 52 +++-- examples/talk-llama/talk-llama.cpp | 92 ++++---- examples/talk.wasm/emscripten.cpp | 2 +- examples/talk/talk.cpp | 7 +- .../app/src/main/jni/whisper/jni.c | 4 +- .../whisper.objc.xcodeproj/project.pbxproj | 4 +- .../whisper.objc/ViewController.m | 8 +- .../whisper.cpp.swift/LibWhisper.swift | 7 +- .../whisper.swiftui.xcodeproj/project.pbxproj | 14 +- examples/whisper.wasm/emscripten.cpp | 2 +- whisper.cpp | 204 +++++++++++------- whisper.h | 54 ++++- 29 files changed, 439 insertions(+), 188 deletions(-) create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index e605d8e0..b77e103c 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -103,7 +103,7 @@ var ( func Whisper_init(path string) *Context { cPath := C.CString(path) defer C.free(unsafe.Pointer(cPath)) - if ctx := C.whisper_init_from_file(cPath); ctx != nil { + if ctx := C.whisper_init_from_file_with_params(cPath, C.whisper_context_default_params()); ctx != nil { return (*Context)(ctx) } else { return nil diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java index 22d4ce87..0498eb4d 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java @@ -4,6 +4,7 @@ import com.sun.jna.Structure; import com.sun.jna.ptr.PointerByReference; import io.github.ggerganov.whispercpp.ggml.GgmlType; import io.github.ggerganov.whispercpp.WhisperModel; +import io.github.ggerganov.whispercpp.params.WhisperContextParams; import java.util.List; @@ -23,8 +24,9 @@ public class WhisperContext extends Structure { public PointerByReference vocab; public PointerByReference state; - /** populated by whisper_init_from_file() */ + /** populated by whisper_init_from_file_with_params() */ String path_model; + WhisperContextParams params; // public static class ByReference extends WhisperContext implements Structure.ByReference { // } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java index 9bc1a860..4a250403 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java @@ -2,6 +2,7 @@ package io.github.ggerganov.whispercpp; import com.sun.jna.Native; import com.sun.jna.Pointer; +import io.github.ggerganov.whispercpp.params.WhisperContextParams; import io.github.ggerganov.whispercpp.params.WhisperFullParams; import io.github.ggerganov.whispercpp.params.WhisperSamplingStrategy; @@ -15,8 +16,9 @@ import java.io.IOException; public class WhisperCpp implements AutoCloseable { private WhisperCppJnaLibrary lib = WhisperCppJnaLibrary.instance; private Pointer ctx = null; - private Pointer greedyPointer = null; - private Pointer beamPointer = null; + private Pointer paramsPointer = null; + private Pointer greedyParamsPointer = null; + private Pointer beamParamsPointer = null; public File modelDir() { String modelDirPath = System.getenv("XDG_CACHE_HOME"); @@ -31,6 +33,18 @@ public class WhisperCpp implements AutoCloseable { * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") */ public void initContext(String modelPath) throws FileNotFoundException { + initContextImpl(modelPath, getContextDefaultParams()); + } + + /** + * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") + * @param params - params to use when initialising the context + */ + public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException { + initContextImpl(modelPath, params); + } + + private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException { if (ctx != null) { lib.whisper_free(ctx); } @@ -43,13 +57,26 @@ public class WhisperCpp implements AutoCloseable { modelPath = new File(modelDir(), modelPath).getAbsolutePath(); } - ctx = lib.whisper_init_from_file(modelPath); + ctx = lib.whisper_init_from_file_with_params(modelPath, params); if (ctx == null) { throw new FileNotFoundException(modelPath); } } + /** + * Provides default params which can be used with `whisper_init_from_file_with_params()` etc. + * Because this function allocates memory for the params, the caller must call either: + * - call `whisper_free_context_params()` + * - `Native.free(Pointer.nativeValue(pointer));` + */ + public WhisperContextParams getContextDefaultParams() { + paramsPointer = lib.whisper_context_default_params_by_ref(); + WhisperContextParams params = new WhisperContextParams(paramsPointer); + params.read(); + return params; + } + /** * Provides default params which can be used with `whisper_full()` etc. * Because this function allocates memory for the params, the caller must call either: @@ -63,15 +90,15 @@ public class WhisperCpp implements AutoCloseable { // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy. if (strategy == WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY) { - if (greedyPointer == null) { - greedyPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); + if (greedyParamsPointer == null) { + greedyParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); } - pointer = greedyPointer; + pointer = greedyParamsPointer; } else { - if (beamPointer == null) { - beamPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); + if (beamParamsPointer == null) { + beamParamsPointer = lib.whisper_full_default_params_by_ref(strategy.ordinal()); } - pointer = beamPointer; + pointer = beamParamsPointer; } WhisperFullParams params = new WhisperFullParams(pointer); @@ -93,13 +120,17 @@ public class WhisperCpp implements AutoCloseable { } private void freeParams() { - if (greedyPointer != null) { - Native.free(Pointer.nativeValue(greedyPointer)); - greedyPointer = null; + if (paramsPointer != null) { + Native.free(Pointer.nativeValue(paramsPointer)); + paramsPointer = null; } - if (beamPointer != null) { - Native.free(Pointer.nativeValue(beamPointer)); - beamPointer = null; + if (greedyParamsPointer != null) { + Native.free(Pointer.nativeValue(greedyParamsPointer)); + greedyParamsPointer = null; + } + if (beamParamsPointer != null) { + Native.free(Pointer.nativeValue(beamParamsPointer)); + beamParamsPointer = null; } } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index ad9faa0b..56a37380 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -5,6 +5,7 @@ import com.sun.jna.Native; import com.sun.jna.Pointer; import io.github.ggerganov.whispercpp.model.WhisperModelLoader; import io.github.ggerganov.whispercpp.model.WhisperTokenData; +import io.github.ggerganov.whispercpp.params.WhisperContextParams; import io.github.ggerganov.whispercpp.params.WhisperFullParams; public interface WhisperCppJnaLibrary extends Library { @@ -13,12 +14,31 @@ public interface WhisperCppJnaLibrary extends Library { String whisper_print_system_info(); /** - * Allocate (almost) all memory needed for the model by loading from a file. + * DEPRECATED. Allocate (almost) all memory needed for the model by loading from a file. * * @param path_model Path to the model file * @return Whisper context on success, null on failure */ Pointer whisper_init_from_file(String path_model); + + /** + * Provides default params which can be used with `whisper_init_from_file_with_params()` etc. + * Because this function allocates memory for the params, the caller must call either: + * - call `whisper_free_context_params()` + * - `Native.free(Pointer.nativeValue(pointer));` + */ + Pointer whisper_context_default_params_by_ref(); + + void whisper_free_context_params(Pointer params); + + /** + * Allocate (almost) all memory needed for the model by loading from a file. + * + * @param path_model Path to the model file + * @param params Pointer to whisper_context_params + * @return Whisper context on success, null on failure + */ + Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params); /** * Allocate (almost) all memory needed for the model by loading from a buffer. diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java new file mode 100644 index 00000000..cf98d2c3 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java @@ -0,0 +1,31 @@ +package io.github.ggerganov.whispercpp.params; + +import com.sun.jna.*; + +import java.util.Arrays; +import java.util.List; + +/** + * Parameters for the whisper_init_from_file_with_params() function. + * If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + * whisper_context_default_params() + */ +public class WhisperContextParams extends Structure { + + public WhisperContextParams(Pointer p) { + super(p); + } + + /** Use GPU for inference Number (default = true) */ + public CBool use_gpu; + + /** Use GPU for inference Number (default = true) */ + public void useGpu(boolean enable) { + use_gpu = enable ? CBool.TRUE : CBool.FALSE; + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("use_gpu"); + } +} diff --git a/bindings/javascript/emscripten.cpp b/bindings/javascript/emscripten.cpp index 789ad8b5..b442c1fc 100644 --- a/bindings/javascript/emscripten.cpp +++ b/bindings/javascript/emscripten.cpp @@ -20,7 +20,7 @@ struct whisper_context * g_context; EMSCRIPTEN_BINDINGS(whisper) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { if (g_context == nullptr) { - g_context = whisper_init_from_file(path_model.c_str()); + g_context = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_context != nullptr) { return true; } else { diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 82027d42..86af9391 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -87,7 +87,7 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); if (rw->context == nullptr) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 845af2f0..d102fe76 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -11,6 +11,7 @@ const whisperParamsMock = { language: "en", model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), + use_gpu: true, }; describe("Run whisper.node", () => { diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 52e80ad8..30acbc6a 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -36,6 +36,7 @@ struct whisper_params { bool print_colors = false; bool print_progress = false; bool no_timestamps = false; + bool use_gpu = true; std::string language = "en"; std::string prompt; @@ -153,7 +154,9 @@ int run(whisper_params ¶ms, std::vector> &result) { // whisper init - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { fprintf(stderr, "error: failed to initialize whisper context\n"); @@ -315,10 +318,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string language = whisper_params.Get("language").As(); std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); + bool use_gpu = whisper_params.Get("use_gpu").As(); params.language = language; params.model = model; params.fname_inp.emplace_back(input); + params.use_gpu = use_gpu; Napi::Function callback = info[1].As(); Worker* worker = new Worker(callback, params); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index d511cdc2..3c642937 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -11,6 +11,7 @@ const whisperParams = { language: "en", model: path.join(__dirname, "../../models/ggml-base.en.bin"), fname_inp: "../../samples/jfk.wav", + use_gpu: true, }; const arguments = process.argv.slice(2); diff --git a/examples/bench.wasm/emscripten.cpp b/examples/bench.wasm/emscripten.cpp index 09e9d55d..3624bbc4 100644 --- a/examples/bench.wasm/emscripten.cpp +++ b/examples/bench.wasm/emscripten.cpp @@ -57,7 +57,7 @@ EMSCRIPTEN_BINDINGS(bench) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file(path_model.c_str()); + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_contexts[i] != nullptr) { if (g_worker.joinable()) { g_worker.join(); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index dfb1d11b..9f50b3b6 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -11,6 +11,8 @@ struct whisper_params { int32_t what = 0; // what to benchmark: 0 - whisper ecoder, 1 - memcpy, 2 - ggml_mul_mat std::string model = "models/ggml-base.en.bin"; + + bool use_gpu = true; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -23,9 +25,10 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-w" || arg == "--what") { params.what = atoi(argv[++i]); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -45,6 +48,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -w N, --what N [%-7d] what to benchmark:\n", params.what); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " %-7s 0 - whisper\n", ""); fprintf(stderr, " %-7s 1 - memcpy\n", ""); fprintf(stderr, " %-7s 2 - ggml_mul_mat\n", ""); @@ -54,7 +58,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para int whisper_bench_full(const whisper_params & params) { // whisper init - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); { fprintf(stderr, "\n"); diff --git a/examples/command.wasm/emscripten.cpp b/examples/command.wasm/emscripten.cpp index e739656d..528ff6ab 100644 --- a/examples/command.wasm/emscripten.cpp +++ b/examples/command.wasm/emscripten.cpp @@ -243,7 +243,7 @@ EMSCRIPTEN_BINDINGS(command) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file(path_model.c_str()); + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_contexts[i] != nullptr) { g_running = true; if (g_worker.joinable()) { diff --git a/examples/command/command.cpp b/examples/command/command.cpp index d39af730..7045f5ff 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -38,6 +38,7 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool no_timestamps = true; + bool use_gpu = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -68,6 +69,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } @@ -101,6 +103,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); @@ -610,7 +613,10 @@ int main(int argc, char ** argv) { // whisper init - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); // print some info about the processing { diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index b8001b95..8d8b6ffa 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -30,6 +30,7 @@ struct whisper_params { bool translate = false; bool print_special = false; bool print_energy = false; + bool use_gpu = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; @@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } else { @@ -102,6 +104,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, "\n"); @@ -432,7 +435,9 @@ int main(int argc, char ** argv) { } // whisper init - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); // init audio audio_async audio(30*1000); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bed0789f..e43dfe3f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -90,6 +90,7 @@ struct whisper_params { bool print_progress = false; bool no_timestamps = false; bool log_score = false; + bool use_gpu = true; std::string language = "en"; std::string prompt; @@ -165,6 +166,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -221,6 +223,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, "\n"); } @@ -877,7 +880,10 @@ int main(int argc, char ** argv) { // whisper init - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { fprintf(stderr, "error: failed to initialize whisper context\n"); diff --git a/examples/stream.wasm/emscripten.cpp b/examples/stream.wasm/emscripten.cpp index 144a14d2..71acffba 100644 --- a/examples/stream.wasm/emscripten.cpp +++ b/examples/stream.wasm/emscripten.cpp @@ -132,7 +132,7 @@ EMSCRIPTEN_BINDINGS(stream) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file(path_model.c_str()); + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_contexts[i] != nullptr) { g_running = true; if (g_worker.joinable()) { diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index c8a452d1..47f1780b 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -48,11 +48,12 @@ struct whisper_params { bool no_context = true; bool no_timestamps = false; bool tinydiarize = false; + bool save_audio = false; // save audio to wav file + bool use_gpu = true; std::string language = "en"; std::string model = "models/ggml-base.en.bin"; std::string fname_out; - bool save_audio = false; // save audio to wav file }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); @@ -65,25 +66,26 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); } - else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); } - else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); } - else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } - else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } - else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if ( arg == "--step") { params.step_ms = std::stoi(argv[++i]); } + else if ( arg == "--length") { params.length_ms = std::stoi(argv[++i]); } + else if ( arg == "--keep") { params.keep_ms = std::stoi(argv[++i]); } + else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } + else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } + else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); @@ -118,8 +120,9 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true"); fprintf(stderr, "\n"); } @@ -163,7 +166,10 @@ int main(int argc, char ** argv) { exit(0); } - struct whisper_context * ctx = whisper_init_from_file(params.model.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); std::vector pcmf32 (n_samples_30s, 0.0f); std::vector pcmf32_old; @@ -424,4 +430,4 @@ int main(int argc, char ** argv) { whisper_free(ctx); return 0; -} \ No newline at end of file +} diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index e497690e..6cc30c16 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -63,6 +63,7 @@ struct whisper_params { bool print_energy = false; bool no_timestamps = true; bool verbose_prompt = false; + bool use_gpu = true; std::string person = "Georgi"; std::string language = "en"; @@ -84,25 +85,26 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { whisper_print_usage(argc, argv, params); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); } - else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } - else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } - else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } - else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } - else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } - else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } - else if (arg == "--session") { params.path_session = argv[++i];} - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } - else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } - else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; } - else if (arg == "--prompt-file") { + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); } + else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } + else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } + else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } + else if (arg == "--session") { params.path_session = argv[++i];} + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } + else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } + else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; } + else if (arg == "--prompt-file") { std::ifstream file(argv[++i]); std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (params.prompt.back() == '\n') { @@ -110,6 +112,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { } } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); @@ -125,27 +128,28 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms); - fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); - fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); - fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); - fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); - fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); - fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str()); - fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); - fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", ""); - fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n"); - fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms); + fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); + fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); + fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -vp, --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); + fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str()); + fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); + fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", ""); + fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n"); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, "\n"); } @@ -252,7 +256,10 @@ int main(int argc, char ** argv) { // whisper init - struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str()); + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); // llama init @@ -269,6 +276,9 @@ int main(int argc, char ** argv) { lcparams.seed = 1; lcparams.f16_kv = true; lcparams.n_threads = params.n_threads; + if (!params.use_gpu) { + lcparams.n_gpu_layers = 0; + } struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams); diff --git a/examples/talk.wasm/emscripten.cpp b/examples/talk.wasm/emscripten.cpp index 1ea97029..6d30b295 100644 --- a/examples/talk.wasm/emscripten.cpp +++ b/examples/talk.wasm/emscripten.cpp @@ -271,7 +271,7 @@ EMSCRIPTEN_BINDINGS(talk) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file(path_model.c_str()); + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_contexts[i] != nullptr) { g_running = true; if (g_worker.joinable()) { diff --git a/examples/talk/talk.cpp b/examples/talk/talk.cpp index 346d9d48..cdb1a230 100644 --- a/examples/talk/talk.cpp +++ b/examples/talk/talk.cpp @@ -31,6 +31,7 @@ struct whisper_params { bool print_special = false; bool print_energy = false; bool no_timestamps = true; + bool use_gpu = true; std::string person = "Santa"; std::string language = "en"; @@ -61,6 +62,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } @@ -94,6 +96,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); @@ -181,8 +184,10 @@ int main(int argc, char ** argv) { } // whisper init + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; - struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str()); + struct whisper_context * ctx_wsp = whisper_init_from_file_with_params(params.model_wsp.c_str(), cparams); // gpt init diff --git a/examples/whisper.android/app/src/main/jni/whisper/jni.c b/examples/whisper.android/app/src/main/jni/whisper/jni.c index c437d099..a8b3ded4 100644 --- a/examples/whisper.android/app/src/main/jni/whisper/jni.c +++ b/examples/whisper.android/app/src/main/jni/whisper/jni.c @@ -127,7 +127,7 @@ static struct whisper_context *whisper_init_from_asset( .close = &asset_close }; - return whisper_init(&loader); + return whisper_init_with_params(&loader, whisper_context_default_params()); } JNIEXPORT jlong JNICALL @@ -147,7 +147,7 @@ Java_com_whispercppdemo_whisper_WhisperLib_00024Companion_initContext( UNUSED(thiz); struct whisper_context *context = NULL; const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL); - context = whisper_init_from_file(model_path_chars); + context = whisper_init_from_file_with_params(model_path_chars, whisper_context_default_params()); (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars); return (jlong) context; } diff --git a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj index 06af23e6..fd884cf3 100644 --- a/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj +++ b/examples/whisper.objc/whisper.objc.xcodeproj/project.pbxproj @@ -17,8 +17,8 @@ 18627C8629052BE000BD2A04 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8529052BE000BD2A04 /* Assets.xcassets */; }; 18627C8929052BE000BD2A04 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 18627C8729052BE000BD2A04 /* LaunchScreen.storyboard */; }; 18627C8C29052BE000BD2A04 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = 18627C8B29052BE000BD2A04 /* main.m */; }; - 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML"; }; }; - 18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE"; }; }; + 18627C9429052C4900BD2A04 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9329052C4900BD2A04 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK -DGGML_USE_METAL"; }; }; + 18627C9629052C5800BD2A04 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 18627C9529052C5800BD2A04 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL"; }; }; 18627C9B29052CFF00BD2A04 /* ggml-base.en.bin in Resources */ = {isa = PBXBuildFile; fileRef = 18627C9A29052CFF00BD2A04 /* ggml-base.en.bin */; }; 18ABE15A2AF556340044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1572AF556340044A204 /* ggml-backend.c */; }; 18ABE15B2AF556340044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1592AF556340044A204 /* ggml-quants.c */; }; diff --git a/examples/whisper.objc/whisper.objc/ViewController.m b/examples/whisper.objc/whisper.objc/ViewController.m index 8a1e876c..151b05d9 100644 --- a/examples/whisper.objc/whisper.objc/ViewController.m +++ b/examples/whisper.objc/whisper.objc/ViewController.m @@ -61,7 +61,13 @@ void AudioInputCallback(void * inUserData, NSLog(@"Loading model from %@", modelPath); // create ggml context - stateInp.ctx = whisper_init_from_file([modelPath UTF8String]); + + struct whisper_context_params cparams = whisper_context_default_params(); +#if TARGET_OS_SIMULATOR + cparams.use_gpu = false; + NSLog(@"Running on simulator, using CPU"); +#endif + stateInp.ctx = whisper_init_from_file_with_params([modelPath UTF8String], cparams); // check if the model was loaded successfully if (stateInp.ctx == NULL) { diff --git a/examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift b/examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift index e9645b34..95e1aeef 100644 --- a/examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift +++ b/examples/whisper.swiftui/whisper.cpp.swift/LibWhisper.swift @@ -55,7 +55,12 @@ actor WhisperContext { } static func createContext(path: String) throws -> WhisperContext { - let context = whisper_init_from_file(path) + var params = whisper_context_default_params() +#if targetEnvironment(simulator) + params.use_gpu = false + print("Running on the simulator, using CPU") +#endif + let context = whisper_init_from_file_with_params(path, params) if let context { return WhisperContext(context: context) } else { diff --git a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj index 832a2a1b..605240da 100644 --- a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj +++ b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj @@ -16,13 +16,15 @@ 0AAC5D9D29539CCF003032C3 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9C29539CCF003032C3 /* ContentView.swift */; }; 0AAC5D9F29539CD0003032C3 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5D9E29539CD0003032C3 /* Assets.xcassets */; }; 0AAC5DA329539CD0003032C3 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */; }; - 0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-Wno-shorten-64-to-32"; }; }; - 0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -Wno-shorten-64-to-32"; }; }; + 0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC729539EB0003032C3 /* whisper.cpp */; settings = {COMPILER_FLAGS = "-DGGML_USE_METAL -Wno-shorten-64-to-32"; }; }; + 0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DC929539EB0003032C3 /* ggml.c */; settings = {COMPILER_FLAGS = "-DGGML_USE_ACCELERATE -DGGML_USE_METAL -Wno-shorten-64-to-32"; }; }; 0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */; }; 0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */ = {isa = PBXBuildFile; fileRef = 0AAC5DD02953A394003032C3 /* LibWhisper.swift */; }; 18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE14C2AF555FA0044A204 /* ggml-backend.c */; }; 18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */ = {isa = PBXBuildFile; fileRef = 18ABE1512AF555FA0044A204 /* ggml-quants.c */; }; 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */ = {isa = PBXBuildFile; fileRef = 18AED47F2AB21F2B009D854F /* ggml-alloc.c */; }; + 7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08252ACFA3A400AF3530 /* ggml-metal.m */; settings = {COMPILER_FLAGS = "-framework Foundation -framework Metal -framework MetalKit -fno-objc-arc"; }; }; + 7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 7FCB08272ACFA48500AF3530 /* ggml-metal.metal */; }; /* End PBXBuildFile section */ /* Begin PBXFileReference section */ @@ -52,6 +54,9 @@ 18ABE1512AF555FA0044A204 /* ggml-quants.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-quants.c"; sourceTree = ""; }; 18AED47F2AB21F2B009D854F /* ggml-alloc.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = "ggml-alloc.c"; sourceTree = ""; }; 18AED4802AB21F2B009D854F /* ggml-alloc.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-alloc.h"; sourceTree = ""; }; + 7FCB081E2ACFA04400AF3530 /* ggml-metal.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "ggml-metal.h"; sourceTree = ""; }; + 7FCB08252ACFA3A400AF3530 /* ggml-metal.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = "ggml-metal.m"; sourceTree = ""; }; + 7FCB08272ACFA48500AF3530 /* ggml-metal.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = "ggml-metal.metal"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -135,6 +140,9 @@ 0AAC5DC529539E89003032C3 /* whisper.cpp */ = { isa = PBXGroup; children = ( + 7FCB08272ACFA48500AF3530 /* ggml-metal.metal */, + 7FCB081E2ACFA04400AF3530 /* ggml-metal.h */, + 7FCB08252ACFA3A400AF3530 /* ggml-metal.m */, 18ABE14E2AF555FA0044A204 /* ggml-backend-impl.h */, 18ABE14C2AF555FA0044A204 /* ggml-backend.c */, 18ABE14D2AF555FA0044A204 /* ggml-backend.h */, @@ -258,10 +266,12 @@ 0AAC5DCC29539EB1003032C3 /* ggml.c in Sources */, 18ABE1532AF555FA0044A204 /* ggml-quants.c in Sources */, 0AAC5DCE2953A05C003032C3 /* WhisperState.swift in Sources */, + 7FCB08282ACFA48500AF3530 /* ggml-metal.metal in Sources */, 0AAC5DD12953A394003032C3 /* LibWhisper.swift in Sources */, 0AA7514C2953B569001EE061 /* RiffWaveUtils.swift in Sources */, 0AAC5DCB29539EB1003032C3 /* whisper.cpp in Sources */, 0AA7514E2953D958001EE061 /* Recorder.swift in Sources */, + 7FCB08262ACFA3A400AF3530 /* ggml-metal.m in Sources */, 18AED4812AB21F2B009D854F /* ggml-alloc.c in Sources */, 18ABE1522AF555FA0044A204 /* ggml-backend.c in Sources */, ); diff --git a/examples/whisper.wasm/emscripten.cpp b/examples/whisper.wasm/emscripten.cpp index db1ff789..b84893de 100644 --- a/examples/whisper.wasm/emscripten.cpp +++ b/examples/whisper.wasm/emscripten.cpp @@ -24,7 +24,7 @@ EMSCRIPTEN_BINDINGS(whisper) { for (size_t i = 0; i < g_contexts.size(); ++i) { if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file(path_model.c_str()); + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); if (g_contexts[i] != nullptr) { return i + 1; } else { diff --git a/whisper.cpp b/whisper.cpp index 3e36d362..70446049 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -736,7 +736,7 @@ struct whisper_state { int lang_id = 0; // english by default - std::string path_model; // populated by whisper_init_from_file() + std::string path_model; // populated by whisper_init_from_file_with_params() #ifdef WHISPER_USE_COREML whisper_coreml_context * ctx_coreml = nullptr; #endif @@ -770,7 +770,8 @@ struct whisper_context { whisper_vocab vocab; whisper_state * state = nullptr; - std::string path_model; // populated by whisper_init_from_file() + std::string path_model; // populated by whisper_init_from_file_with_params() + whisper_context_params params; }; static void whisper_default_log(const char * text) { @@ -2930,59 +2931,64 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } #ifdef GGML_USE_METAL - state->ctx_metal = ggml_metal_init(1); - if (!state->ctx_metal) { - log("%s: ggml_metal_init() failed\n", __func__); - delete state; - return nullptr; + if (ctx->params.use_gpu) { + state->ctx_metal = ggml_metal_init(1); + if (!state->ctx_metal) { + log("%s: ggml_metal_init() failed\n", __func__); + delete state; + return nullptr; + } } - log("%s: Metal context initialized\n", __func__); + if (state->ctx_metal) { + log("%s: Metal context initialized\n", __func__); - // this allocates all Metal resources and memory buffers + // this allocates all Metal resources and memory buffers - void * data_ptr = NULL; - size_t data_size = 0; + void * data_ptr = NULL; + size_t data_size = 0; - // TODO: add mmap support - //if (params.use_mmap) { - // data_ptr = ctx->model.mapping->addr; - // data_size = ctx->model.mapping->size; - //} else { - // data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - // data_size = ggml_get_mem_size (ctx->model.ctx); - //} + // TODO: add mmap support + //if (params.use_mmap) { + // data_ptr = ctx->model.mapping->addr; + // data_size = ctx->model.mapping->size; + //} else { + // data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + // data_size = ggml_get_mem_size (ctx->model.ctx); + //} - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); #define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - delete state; \ - return nullptr; \ - } + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + delete state; \ + return nullptr; \ + } - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); #undef WHISPER_METAL_CHECK_BUF + + } #endif state->rng = std::mt19937(0); @@ -3039,7 +3045,14 @@ int whisper_ctx_init_openvino_encoder( #endif } -struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { +struct whisper_context_params whisper_context_default_params() { + struct whisper_context_params result = { + /*.use_gpu =*/ true, + }; + return result; +} + +struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { log("%s: loading model from '%s'\n", __func__, path_model); auto fin = std::ifstream(path_model, std::ios::binary); @@ -3068,7 +3081,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model fin->close(); }; - auto ctx = whisper_init_no_state(&loader); + auto ctx = whisper_init_with_params_no_state(&loader, params); if (ctx) { ctx->path_model = path_model; @@ -3077,7 +3090,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model return ctx; } -struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { +struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) { struct buf_context { uint8_t* buffer; size_t size; @@ -3111,13 +3124,14 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t loader.close = [](void * /*ctx*/) { }; - return whisper_init_no_state(&loader); + return whisper_init_with_params_no_state(&loader, params); } -struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { +struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { ggml_time_init(); whisper_context * ctx = new whisper_context; + ctx->params = params; if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); @@ -3131,8 +3145,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa return ctx; } -struct whisper_context * whisper_init_from_file(const char * path_model) { - whisper_context * ctx = whisper_init_from_file_no_state(path_model); +struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params); if (!ctx) { return nullptr; } @@ -3146,34 +3160,58 @@ struct whisper_context * whisper_init_from_file(const char * path_model) { return ctx; } +struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_file(const char * path_model) { + return whisper_init_from_file_with_params(path_model, whisper_context_default_params()); +} + struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { - whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size); - if (!ctx) { - return nullptr; - } - - ctx->state = whisper_init_state(ctx); - if (!ctx->state) { - whisper_free(ctx); - return nullptr; - } - - return ctx; + return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params()); } struct whisper_context * whisper_init(struct whisper_model_loader * loader) { - whisper_context * ctx = whisper_init_no_state(loader); - if (!ctx) { - return nullptr; - } + return whisper_init_with_params(loader, whisper_context_default_params()); +} - ctx->state = whisper_init_state(ctx); - if (!ctx->state) { - whisper_free(ctx); - return nullptr; - } +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { + return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params()); +} - return ctx; +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { + return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { + return whisper_init_with_params_no_state(loader, whisper_context_default_params()); } void whisper_free_state(struct whisper_state * state) @@ -3230,6 +3268,12 @@ void whisper_free(struct whisper_context * ctx) { } } +void whisper_free_context_params(struct whisper_context_params * params) { + if (params) { + delete params; + } +} + void whisper_free_params(struct whisper_full_params * params) { if (params) { delete params; @@ -3698,6 +3742,14 @@ const char * whisper_print_system_info(void) { //////////////////////////////////////////////////////////////////////////// +struct whisper_context_params * whisper_context_default_params_by_ref() { + struct whisper_context_params params = whisper_context_default_params(); + + struct whisper_context_params* result = new whisper_context_params(); + *result = params; + return result; +} + struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { struct whisper_full_params params = whisper_full_default_params(strategy); @@ -4507,17 +4559,19 @@ int whisper_full_with_state( // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 #ifdef GGML_USE_METAL + if (state->ctx_metal) { #define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - return 0; \ - } + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + return 0; \ + } - const std::string kv_name = "kv_self_" + std::to_string(j); - auto & kv_self = decoder.kv_self; + const std::string kv_name = "kv_self_" + std::to_string(j); + auto & kv_self = decoder.kv_self; - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); #undef WHISPER_METAL_CHECK_BUF + } #endif } } diff --git a/whisper.h b/whisper.h index c3118c9c..300fc4ba 100644 --- a/whisper.h +++ b/whisper.h @@ -5,6 +5,14 @@ #include #include +#ifdef __GNUC__ +# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define WHISPER_DEPRECATED(func, hint) func +#endif + #ifdef WHISPER_SHARED # ifdef _WIN32 # ifdef WHISPER_BUILD @@ -71,6 +79,10 @@ extern "C" { typedef int whisper_token; + struct whisper_context_params { + bool use_gpu; + }; + typedef struct whisper_token_data { whisper_token id; // token id whisper_token tid; // forced timestamp token id @@ -99,15 +111,40 @@ extern "C" { // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. // Return NULL on failure - WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params); // These are the same as the above, but the internal state of the context is not allocated automatically // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) - WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params); + + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), + "use whisper_init_from_file_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader), + "use whisper_init_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model), + "use whisper_init_from_file_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader), + "use whisper_init_with_params_no_state instead" + ); WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); @@ -132,6 +169,7 @@ extern "C" { WHISPER_API void whisper_free (struct whisper_context * ctx); WHISPER_API void whisper_free_state(struct whisper_state * state); WHISPER_API void whisper_free_params(struct whisper_full_params * params); + WHISPER_API void whisper_free_context_params(struct whisper_context_params * params); // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the default state of the provided whisper context. @@ -442,7 +480,9 @@ extern "C" { void * logits_filter_callback_user_data; }; - // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params() + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() + WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(); + WHISPER_API struct whisper_context_params whisper_context_default_params(void); WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);