From 594a121f3ed1aad292b36e8f375c71553f6266ba Mon Sep 17 00:00:00 2001 From: Page-MS <120176945+Page-MS@users.noreply.github.com> Date: Wed, 26 Mar 2025 03:30:59 -0400 Subject: [PATCH 1/5] readme : add note about SDL2 (#2946) Precise the README section about real time audio processing, stating that sdl2 is needed. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e6e96731..dfe7f143 100644 --- a/README.md +++ b/README.md @@ -427,7 +427,8 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume This is a naive example of performing real-time inference on audio from your microphone. The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously. -More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10). +More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10). +You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly. ```bash cmake -B build -DWHISPER_SDL2=ON From 2699e1485a9180cc666a937cf1b3c17da4168bf3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 26 Mar 2025 14:49:12 +0100 Subject: [PATCH 2/5] bindings.javascript : update test instructions [no ci] (#2951) This commit updates the instructions for running the test in the JavaScript bindings README file. The motivation for this is for Node.js versions after v16.4.0 the `--experimental-wasm-threads` and `--experimental-wasm-simd` flags are no longer required and they generate the following errors: ```console $ node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js node: bad option: --experimental-wasm-threads node: bad option: --experimental-wasm-simd ``` --- bindings/javascript/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bindings/javascript/README.md b/bindings/javascript/README.md index 87f34805..5e726e13 100644 --- a/bindings/javascript/README.md +++ b/bindings/javascript/README.md @@ -33,6 +33,9 @@ mkdir build-em && cd build-em emcmake cmake .. && make -j # run test +node ../tests/test-whisper.js + +# For Node.js versions prior to v16.4.0, experimental features need to be enabled: node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js # publish npm package From 0b43a02be809a063ff5ac3eaab5a7fcd54be0eb3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 26 Mar 2025 15:01:28 +0100 Subject: [PATCH 3/5] bindings.java : enable copyLibs task [no ci] (#2949) * bindings.java : enable copyLibs task [no ci] This commit adds a dependency on the copyLibs task to the sourcesJar and jar tasks. This ensures that the libwhisper.so file is copied to the correct location before the jar is built. It also sets the executable bit on the gradlew file. * bindings.java : add copyLibs dep for processResources [no ci] This will otherwise cause builds to fail after doing an initial build. * bindings.java : pass structs by value to native code This commit refactors the code to pass the structs by value to the native code. This is done by creating a ByValue class for each struct and using it in the Java code. The motivation for this change is that without this application crashes due to what I believe was memory mis-alignement. When the structs were passed to the native code they would be att different memory locations. Passing by value overcomes this issue and considering that the structs hold parementers (context and full params) it might be alright do to this. These changes allow all the tests to pass. * bindings.java : fix javadoc warnings [no ci] * bindings.java : fix libwhisper.dylib path in build.gradle [no ci] This commit fixes the copyLibwhisperDynlib task in the build.gradle file to copy the correct libwhisper.dylib file from build/src. --- bindings/java/build.gradle | 14 ++++- bindings/java/gradlew | 0 .../whispercpp/WhisperConstants.java | 24 +++++++ .../ggerganov/whispercpp/WhisperContext.java | 39 +++++------- .../ggerganov/whispercpp/WhisperCpp.java | 39 +++++++----- .../whispercpp/WhisperCppJnaLibrary.java | 17 ++--- .../callbacks/GgmlAbortCallback.java | 17 +++++ .../whispercpp/params/WhisperAhead.java | 30 +++++++++ .../whispercpp/params/WhisperAheads.java | 41 ++++++++++++ .../params/WhisperContextParams.java | 62 +++++++++++++++++-- .../whispercpp/params/WhisperFullParams.java | 56 +++++++++++++---- .../ggerganov/whispercpp/WhisperCppTest.java | 2 +- 12 files changed, 274 insertions(+), 67 deletions(-) mode change 100644 => 100755 bindings/java/gradlew create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperConstants.java create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/GgmlAbortCallback.java create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAhead.java create mode 100644 bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAheads.java diff --git a/bindings/java/build.gradle b/bindings/java/build.gradle index 75f3a9cd..eb1a5c07 100644 --- a/bindings/java/build.gradle +++ b/bindings/java/build.gradle @@ -25,13 +25,13 @@ sourceSets { } tasks.register('copyLibwhisperDynlib', Copy) { - from '../../build' - include 'libwhisper.dynlib' + from '../../build/src' + include 'libwhisper.dylib' into 'build/generated/resources/main/darwin' } tasks.register('copyLibwhisperSo', Copy) { - from '../../build' + from '../../build/src' include 'libwhisper.so' into 'build/generated/resources/main/linux-x86-64' } @@ -55,7 +55,12 @@ java { withJavadocJar() } +sourcesJar() { + dependsOn copyLibs +} + jar { + dependsOn copyLibs exclude '**/whisper_java.exp', '**/whisper_java.lib' } @@ -67,6 +72,9 @@ tasks.withType(Test) { useJUnitPlatform() } +test.dependsOn copyLibs +processResources.dependsOn copyLibs + dependencies { implementation "net.java.dev.jna:jna:5.13.0" testImplementation "org.junit.jupiter:junit-jupiter:5.9.2" diff --git a/bindings/java/gradlew b/bindings/java/gradlew old mode 100644 new mode 100755 diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperConstants.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperConstants.java new file mode 100644 index 00000000..0c828f1d --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperConstants.java @@ -0,0 +1,24 @@ +package io.github.ggerganov.whispercpp; + +/** + * Presets for alignment heads in DTW token timestamps + */ +public class WhisperConstants { + // Alignment heads presets + public static final int WHISPER_AHEADS_NONE = 0; + public static final int WHISPER_AHEADS_TINY_EN = 1; + public static final int WHISPER_AHEADS_TINY = 2; + public static final int WHISPER_AHEADS_BASE_EN = 3; + public static final int WHISPER_AHEADS_BASE = 4; + public static final int WHISPER_AHEADS_SMALL_EN = 5; + public static final int WHISPER_AHEADS_SMALL = 6; + public static final int WHISPER_AHEADS_MEDIUM_EN = 7; + public static final int WHISPER_AHEADS_MEDIUM = 8; + public static final int WHISPER_AHEADS_LARGE_V1 = 9; + public static final int WHISPER_AHEADS_LARGE_V2 = 10; + public static final int WHISPER_AHEADS_LARGE_V3 = 11; + public static final int WHISPER_AHEADS_LARGE_V3_TURBO = 12; + public static final int WHISPER_AHEADS_CUSTOM = 13; + public static final int WHISPER_AHEADS_N_TOP_MOST = 14; + public static final int WHISPER_AHEADS_COUNT = 15; +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java index 0498eb4d..7ac124ed 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperContext.java @@ -1,7 +1,9 @@ package io.github.ggerganov.whispercpp; +import com.sun.jna.NativeLong; import com.sun.jna.Structure; import com.sun.jna.ptr.PointerByReference; +import com.sun.jna.Pointer; import io.github.ggerganov.whispercpp.ggml.GgmlType; import io.github.ggerganov.whispercpp.WhisperModel; import io.github.ggerganov.whispercpp.params.WhisperContextParams; @@ -9,33 +11,26 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams; import java.util.List; public class WhisperContext extends Structure { - int t_load_us = 0; - int t_start_us = 0; + public NativeLong t_load_us; + public NativeLong t_start_us; /** weight type (FP32 / FP16 / QX) */ - GgmlType wtype = GgmlType.GGML_TYPE_F16; + public GgmlType wtype = GgmlType.GGML_TYPE_F16; /** intermediate type (FP32 or FP16) */ - GgmlType itype = GgmlType.GGML_TYPE_F16; + public GgmlType itype = GgmlType.GGML_TYPE_F16; -// WhisperModel model; - public PointerByReference model; -// whisper_vocab vocab; -// whisper_state * state = nullptr; - public PointerByReference vocab; - public PointerByReference state; + public WhisperContextParams.ByValue params; + + public Pointer model; + public Pointer vocab; + public Pointer state; /** populated by whisper_init_from_file_with_params() */ - String path_model; - WhisperContextParams params; + public Pointer path_model; -// public static class ByReference extends WhisperContext implements Structure.ByReference { -// } -// -// public static class ByValue extends WhisperContext implements Structure.ByValue { -// } -// -// @Override -// protected List getFieldOrder() { -// return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model"); -// } + @Override + protected List getFieldOrder() { + return List.of("t_load_us", "t_start_us", "wtype", "itype", + "params", "model", "vocab", "state", "path_model"); + } } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java index 4c1594d5..621d8c63 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCpp.java @@ -43,11 +43,11 @@ public class WhisperCpp implements AutoCloseable { * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en") * @param params - params to use when initialising the context */ - public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException { + public void initContext(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException { initContextImpl(modelPath, params); } - private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException { + private void initContextImpl(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException { if (ctx != null) { lib.whisper_free(ctx); } @@ -69,15 +69,13 @@ public class WhisperCpp implements AutoCloseable { /** * Provides default params which can be used with `whisper_init_from_file_with_params()` etc. - * Because this function allocates memory for the params, the caller must call either: - * - call `whisper_free_context_params()` - * - `Native.free(Pointer.nativeValue(pointer));` + * Returns a ByValue instance to ensure proper parameter passing to native code. */ - public WhisperContextParams getContextDefaultParams() { - paramsPointer = lib.whisper_context_default_params_by_ref(); - WhisperContextParams params = new WhisperContextParams(paramsPointer); - params.read(); - return params; + public WhisperContextParams.ByValue getContextDefaultParams() { + WhisperContextParams.ByValue valueParams = new WhisperContextParams.ByValue( + lib.whisper_context_default_params_by_ref()); + valueParams.read(); + return valueParams; } /** @@ -88,7 +86,7 @@ public class WhisperCpp implements AutoCloseable { * * @param strategy - GREEDY */ - public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) { + public WhisperFullParams.ByValue getFullDefaultParams(WhisperSamplingStrategy strategy) { Pointer pointer; // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy. @@ -104,7 +102,7 @@ public class WhisperCpp implements AutoCloseable { pointer = beamParamsPointer; } - WhisperFullParams params = new WhisperFullParams(pointer); + WhisperFullParams.ByValue params = new WhisperFullParams.ByValue(pointer); params.read(); return params; } @@ -138,15 +136,21 @@ public class WhisperCpp implements AutoCloseable { } /** - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. * Not thread safe for same context * Uses the specified decoding strategy to obtain the text. */ - public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException { + public String fullTranscribe(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException { if (ctx == null) { throw new IllegalStateException("Model not initialised"); } + /* + WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue( + lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal())); + valueParams.read(); + */ + if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) { throw new IOException("Failed to process audio"); } @@ -163,12 +167,17 @@ public class WhisperCpp implements AutoCloseable { return str.toString().trim(); } + public List fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException { if (ctx == null) { throw new IllegalStateException("Model not initialised"); } - if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) { + WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue( + lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal())); + valueParams.read(); + + if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) { throw new IOException("Failed to process audio"); } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java index 1a73cee1..1cd2449f 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java @@ -38,7 +38,7 @@ public interface WhisperCppJnaLibrary extends Library { * @param params Pointer to whisper_context_params * @return Whisper context on success, null on failure */ - Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params); + Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams.ByValue params); /** * Allocate (almost) all memory needed for the model by loading from a buffer. @@ -180,12 +180,12 @@ public interface WhisperCppJnaLibrary extends Library { /** * @return the id of the specified language, returns -1 if not found. * Examples: - * "de" -> 2 - * "german" -> 2 + * "de" -> 2 + * "german" -> 2 */ int whisper_lang_id(String lang); - /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */ + /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */ String whisper_lang_str(int id); /** @@ -268,20 +268,21 @@ public interface WhisperCppJnaLibrary extends Library { void whisper_free_params(Pointer params); /** - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text * Not thread safe for same context * Uses the specified decoding strategy to obtain the text. */ - int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples); + int whisper_full(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples); - int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples); + public int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams.ByValue params, float[] samples, int n_samples); + //int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples); // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() // Result is stored in the default state of the context // Not thread safe if executed in parallel on the same context. // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. - int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors); + int whisper_full_parallel(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples, int n_processors); /** * Number of generated text segments. diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/GgmlAbortCallback.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/GgmlAbortCallback.java new file mode 100644 index 00000000..244e4191 --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/callbacks/GgmlAbortCallback.java @@ -0,0 +1,17 @@ +package io.github.ggerganov.whispercpp.callbacks; + +import com.sun.jna.Callback; + +/** + * Callback for aborting GGML computation + * Maps to the C typedef: bool (*ggml_abort_callback)(void * data) + */ +public interface GgmlAbortCallback extends Callback { + /** + * Return true to abort the computation, false to continue + * + * @param data User data passed to the callback + * @return true to abort, false to continue + */ + boolean invoke(com.sun.jna.Pointer data); +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAhead.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAhead.java new file mode 100644 index 00000000..39691dcb --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAhead.java @@ -0,0 +1,30 @@ +package io.github.ggerganov.whispercpp.params; +import com.sun.jna.*; +import java.util.Arrays; +import java.util.List; + +public class WhisperAhead extends Structure { + + public int n_text_layer; + + public int n_head; + + public WhisperAhead() { + super(); + } + + public WhisperAhead(int textLayer, int head) { + super(); + this.n_text_layer = textLayer; + this.n_head = head; + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("n_text_layer", "n_head"); + } + + public static class ByReference extends WhisperAhead implements Structure.ByReference {} + + public static class ByValue extends WhisperAhead implements Structure.ByValue {} +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAheads.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAheads.java new file mode 100644 index 00000000..bca5eb0a --- /dev/null +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperAheads.java @@ -0,0 +1,41 @@ +package io.github.ggerganov.whispercpp.params; +import com.sun.jna.*; +import java.util.Arrays; +import java.util.List; + +public class WhisperAheads extends Structure { + public NativeLong n_heads; + + public Pointer heads; + + public WhisperAheads() { + super(); + } + + /** + * Create alignment heads from an array of WhisperAhead objects + */ + public void setHeads(WhisperAhead[] aheadsArray) { + this.n_heads = new NativeLong(aheadsArray.length); + + int structSize = aheadsArray[0].size(); + Memory mem = new Memory(structSize * aheadsArray.length); + + for (int i = 0; i < aheadsArray.length; i++) { + aheadsArray[i].write(); + byte[] buffer = aheadsArray[i].getPointer().getByteArray(0, structSize); + mem.write(i * structSize, buffer, 0, buffer.length); + } + + this.heads = mem; + } + + @Override + protected List getFieldOrder() { + return Arrays.asList("n_heads", "heads"); + } + + public static class ByReference extends WhisperAheads implements Structure.ByReference {} + + public static class ByValue extends WhisperAheads implements Structure.ByValue {} +} diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java index cf98d2c3..4bcdb6b0 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperContextParams.java @@ -1,7 +1,5 @@ package io.github.ggerganov.whispercpp.params; - import com.sun.jna.*; - import java.util.Arrays; import java.util.List; @@ -11,21 +9,73 @@ import java.util.List; * whisper_context_default_params() */ public class WhisperContextParams extends Structure { - public WhisperContextParams(Pointer p) { super(p); } - /** Use GPU for inference Number (default = true) */ + public WhisperContextParams() { + super(); + } + + /** Use GPU for inference (default = true) */ public CBool use_gpu; - /** Use GPU for inference Number (default = true) */ + /** Use flash attention (default = false) */ + public CBool flash_attn; + + /** CUDA device to use (default = 0) */ + public int gpu_device; + + /** [EXPERIMENTAL] Enable token-level timestamps with DTW (default = false) */ + public CBool dtw_token_timestamps; + + /** [EXPERIMENTAL] Alignment heads preset for DTW */ + public int dtw_aheads_preset; + + /** Number of top layers to use for DTW when using WHISPER_AHEADS_N_TOP_MOST preset */ + public int dtw_n_top; + + public WhisperAheads.ByValue dtw_aheads; + + /** DTW memory size (internal use) */ + public NativeLong dtw_mem_size; + + /** Use GPU for inference */ public void useGpu(boolean enable) { use_gpu = enable ? CBool.TRUE : CBool.FALSE; } + /** Use flash attention */ + public void useFlashAttn(boolean enable) { + flash_attn = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Enable DTW token-level timestamps */ + public void enableDtwTokenTimestamps(boolean enable) { + dtw_token_timestamps = enable ? CBool.TRUE : CBool.FALSE; + } + + /** Set DTW alignment heads preset */ + public void setDtwAheadsPreset(int preset) { + dtw_aheads_preset = preset; + } + @Override protected List getFieldOrder() { - return Arrays.asList("use_gpu"); + return Arrays.asList( + "use_gpu", + "flash_attn", + "gpu_device", + "dtw_token_timestamps", + "dtw_aheads_preset", + "dtw_n_top", + "dtw_aheads", + "dtw_mem_size" + ); + } + + public static class ByValue extends WhisperContextParams implements Structure.ByValue { + public ByValue() { super(); } + public ByValue(Pointer p) { super(p); } } } diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 18c209fc..498ff126 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -5,6 +5,7 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback; import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback; +import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback; import java.util.Arrays; import java.util.List; @@ -16,10 +17,12 @@ import java.util.List; */ public class WhisperFullParams extends Structure { + public WhisperFullParams() { + super(); + } + public WhisperFullParams(Pointer p) { super(p); -// super(p, ALIGN_MSVC); -// super(p, ALIGN_GNUC); } /** Sampling strategy for whisper_full() function. */ @@ -69,10 +72,10 @@ public class WhisperFullParams extends Structure { single_segment = single ? CBool.TRUE : CBool.FALSE; } - /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ + /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ public CBool print_special; - /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ + /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */ public void printSpecial(boolean enable) { print_special = enable ? CBool.TRUE : CBool.FALSE; } @@ -129,6 +132,14 @@ public class WhisperFullParams extends Structure { /** Maximum tokens per segment (0, default = no limit) */ public int max_tokens; + /** [EXPERIMENTAL] Enable debug mode for extra info */ + public CBool debug_mode; + + /** Enable debug mode */ + public void enableDebugMode(boolean enable) { + debug_mode = enable ? CBool.TRUE : CBool.FALSE; + } + /** Overwrite the audio context size (0 = use default). */ public int audio_ctx; @@ -274,6 +285,16 @@ public class WhisperFullParams extends Structure { */ public Pointer encoder_begin_callback_user_data; + /** Callback used to abort GGML computation */ + public Pointer abort_callback; + + /** User data for the abort_callback */ + public Pointer abort_callback_user_data; + + public void setAbortCallback(GgmlAbortCallback callback) { + abort_callback = CallbackReference.getFunctionPointer(callback); + } + /** * Callback by each decoder to filter obtained logits. * WhisperLogitsFilterCallback @@ -310,17 +331,28 @@ public class WhisperFullParams extends Structure { @Override protected List getFieldOrder() { - return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", - "no_context", "single_segment", "no_timestamps", - "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", - "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx", - "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", - "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", - "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", - "new_segment_callback", "new_segment_callback_user_data", + return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", + "offset_ms", "duration_ms", "translate", "no_context", + "no_timestamps", "single_segment", "print_special", + "print_progress", "print_realtime", "print_timestamps", + "token_timestamps", "thold_pt", "thold_ptsum", "max_len", + "split_on_word", "max_tokens", "debug_mode", "audio_ctx", + "tdrz_enable", "suppress_regex", "initial_prompt", + "prompt_tokens", "prompt_n_tokens", "language", "detect_language", + "suppress_blank", "suppress_nst", "temperature", + "max_initial_ts", "length_penalty", "temperature_inc", + "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", + "beam_search", "new_segment_callback", "new_segment_callback_user_data", "progress_callback", "progress_callback_user_data", "encoder_begin_callback", "encoder_begin_callback_user_data", + "abort_callback", "abort_callback_user_data", "logits_filter_callback", "logits_filter_callback_user_data", "grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty"); } + + public static class ByValue extends WhisperFullParams implements Structure.ByValue { + public ByValue() { super(); } + public ByValue(Pointer p) { super(p); } + } + } diff --git a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java index 034726ad..9d63fff3 100644 --- a/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java +++ b/bindings/java/src/test/java/io/github/ggerganov/whispercpp/WhisperCppTest.java @@ -76,7 +76,7 @@ class WhisperCppTest { float[] floats = new float[b.length / 2]; //WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY); - WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); + WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH); params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress)); params.print_progress = CBool.FALSE; //params.initial_prompt = "and so my fellow Americans um, like"; From 21d890d53441a08e36992b0f68221f38b715789d Mon Sep 17 00:00:00 2001 From: Dan Johansson Date: Wed, 26 Mar 2025 15:54:02 +0100 Subject: [PATCH 4/5] whisper : add support for backends with multiple ggml_backend_buffer_type (#2863) * whisper : add support for ggml_backend_buffer_type Signed-off-by: Dan Johansson * fix compile error when building on Ubuntu Signed-off-by: Dan Johansson * remove copyright header from include file Signed-off-by: Dan Johansson --------- Signed-off-by: Dan Johansson --- src/CMakeLists.txt | 1 + src/whisper-arch.h | 141 +++++++++++++++ src/whisper.cpp | 438 +++++++++++++++++++++++++-------------------- 3 files changed, 387 insertions(+), 193 deletions(-) create mode 100644 src/whisper-arch.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 527d38b8..a091e66a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -102,6 +102,7 @@ endif() add_library(whisper ../include/whisper.h + whisper-arch.h whisper.cpp ) diff --git a/src/whisper-arch.h b/src/whisper-arch.h new file mode 100644 index 00000000..ea2cfd60 --- /dev/null +++ b/src/whisper-arch.h @@ -0,0 +1,141 @@ +#pragma once + +#include "ggml.h" + +#include + +enum asr_tensor { + ASR_TENSOR_ENC_POS_EMBD, + ASR_TENSOR_DEC_POS_EMBD, + ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, + ASR_TENSOR_LN_WEIGHT, + ASR_TENSOR_LN_BIAS, + ASR_TENSOR_CONV1_WEIGHT, + ASR_TENSOR_CONV1_BIAS, + ASR_TENSOR_CONV2_WEIGHT, + ASR_TENSOR_CONV2_BIAS, + ASR_TENSOR_LN_POST_WEIGHT, + ASR_TENSOR_LN_POST_BIAS, + ASR_TENSOR_MLP_LN_WEIGHT, + ASR_TENSOR_MLP_LN_BIAS, + ASR_TENSOR_MLP_0_WEIGHT, + ASR_TENSOR_MLP_0_BIAS, + ASR_TENSOR_MLP_2_WEIGHT, + ASR_TENSOR_MLP_2_BIAS, + ASR_TENSOR_ATTN_LN_WEIGHT, + ASR_TENSOR_ATTN_LN_BIAS, + ASR_TENSOR_ATTN_QUERY_WEIGHT, + ASR_TENSOR_ATTN_QUERY_BIAS, + ASR_TENSOR_ATTN_KEY_WEIGHT, + ASR_TENSOR_ATTN_VALUE_WEIGHT, + ASR_TENSOR_ATTN_VALUE_BIAS, + ASR_TENSOR_ATTN_OUT_WEIGHT, + ASR_TENSOR_ATTN_OUT_BIAS, +}; + +enum asr_system { + ASR_SYSTEM_ENCODER, + ASR_SYSTEM_DECODER, + ASR_SYSTEM_CROSS +}; + +static const std::map> ASR_TENSOR_NAMES = { + { + ASR_SYSTEM_ENCODER, + { + {ASR_TENSOR_ENC_POS_EMBD, "encoder.positional_embedding"}, + {ASR_TENSOR_CONV1_WEIGHT, "encoder.conv1.weight"}, + {ASR_TENSOR_CONV1_BIAS, "encoder.conv1.bias"}, + {ASR_TENSOR_CONV2_WEIGHT, "encoder.conv2.weight"}, + {ASR_TENSOR_CONV2_BIAS, "encoder.conv2.bias"}, + {ASR_TENSOR_LN_WEIGHT, "encoder.ln_post.weight"}, + {ASR_TENSOR_LN_POST_BIAS, "encoder.ln_post.bias"}, + {ASR_TENSOR_MLP_LN_WEIGHT, "encoder.blocks.%d.mlp_ln.weight"}, + {ASR_TENSOR_MLP_LN_BIAS, "encoder.blocks.%d.mlp_ln.bias"}, + {ASR_TENSOR_MLP_0_WEIGHT, "encoder.blocks.%d.mlp.0.weight"}, + {ASR_TENSOR_MLP_0_BIAS, "encoder.blocks.%d.mlp.0.bias"}, + {ASR_TENSOR_MLP_2_WEIGHT, "encoder.blocks.%d.mlp.2.weight"}, + {ASR_TENSOR_MLP_2_BIAS, "encoder.blocks.%d.mlp.2.bias"}, + {ASR_TENSOR_ATTN_LN_WEIGHT, "encoder.blocks.%d.attn_ln.weight"}, + {ASR_TENSOR_ATTN_LN_BIAS, "encoder.blocks.%d.attn_ln.bias"}, + {ASR_TENSOR_ATTN_QUERY_WEIGHT, "encoder.blocks.%d.attn.query.weight"}, + {ASR_TENSOR_ATTN_QUERY_BIAS, "encoder.blocks.%d.attn.query.bias"}, + {ASR_TENSOR_ATTN_KEY_WEIGHT, "encoder.blocks.%d.attn.key.weight"}, + {ASR_TENSOR_ATTN_VALUE_WEIGHT, "encoder.blocks.%d.attn.value.weight"}, + {ASR_TENSOR_ATTN_VALUE_BIAS, "encoder.blocks.%d.attn.value.bias"}, + {ASR_TENSOR_ATTN_OUT_WEIGHT, "encoder.blocks.%d.attn.out.weight"}, + {ASR_TENSOR_ATTN_OUT_BIAS, "encoder.blocks.%d.attn.out.bias"}, + }, + }, + { + ASR_SYSTEM_DECODER, + { + {ASR_TENSOR_DEC_POS_EMBD, "decoder.positional_embedding"}, + {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, "decoder.token_embedding.weight"}, + {ASR_TENSOR_LN_WEIGHT, "decoder.ln.weight"}, + {ASR_TENSOR_LN_BIAS, "decoder.ln.bias"}, + + {ASR_TENSOR_MLP_LN_WEIGHT, "decoder.blocks.%d.mlp_ln.weight"}, + {ASR_TENSOR_MLP_LN_BIAS, "decoder.blocks.%d.mlp_ln.bias"}, + {ASR_TENSOR_MLP_0_WEIGHT, "decoder.blocks.%d.mlp.0.weight"}, + {ASR_TENSOR_MLP_0_BIAS, "decoder.blocks.%d.mlp.0.bias"}, + {ASR_TENSOR_MLP_2_WEIGHT, "decoder.blocks.%d.mlp.2.weight"}, + {ASR_TENSOR_MLP_2_BIAS, "decoder.blocks.%d.mlp.2.bias"}, + {ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.attn_ln.weight"}, + {ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.attn_ln.bias"}, + {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.attn.query.weight"}, + {ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.attn.query.bias"}, + {ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.attn.key.weight"}, + {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.attn.value.weight"}, + {ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.attn.value.bias"}, + {ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.attn.out.weight"}, + {ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.attn.out.bias"}, + }, + }, + { + ASR_SYSTEM_CROSS, + { + {ASR_TENSOR_ATTN_LN_WEIGHT, "decoder.blocks.%d.cross_attn_ln.weight"}, + {ASR_TENSOR_ATTN_LN_BIAS, "decoder.blocks.%d.cross_attn_ln.bias"}, + {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.cross_attn.query.weight"}, + {ASR_TENSOR_ATTN_QUERY_BIAS, "decoder.blocks.%d.cross_attn.query.bias"}, + {ASR_TENSOR_ATTN_KEY_WEIGHT, "decoder.blocks.%d.cross_attn.key.weight"}, + {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.cross_attn.value.weight"}, + {ASR_TENSOR_ATTN_VALUE_BIAS, "decoder.blocks.%d.cross_attn.value.bias"}, + {ASR_TENSOR_ATTN_OUT_WEIGHT, "decoder.blocks.%d.cross_attn.out.weight"}, + {ASR_TENSOR_ATTN_OUT_BIAS, "decoder.blocks.%d.cross_attn.out.bias"}, + }, + }, +}; + +static const std::map ASR_TENSOR_INFO = { + {ASR_TENSOR_ENC_POS_EMBD, GGML_OP_ADD}, + {ASR_TENSOR_DEC_POS_EMBD, GGML_OP_GET_ROWS}, + // Note: ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT is also used by GGML_OP_MAT_MUL. Need to figure out a way how to handle + // weight tensors that are used by multiple different operators when extra_buffer_type implementations accelerate + // more than just GGML_OP_MUL_MAT. + {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, GGML_OP_GET_ROWS}, + {ASR_TENSOR_LN_WEIGHT, GGML_OP_MUL}, + {ASR_TENSOR_LN_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_CONV1_WEIGHT, GGML_OP_IM2COL}, + {ASR_TENSOR_CONV1_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_CONV2_WEIGHT, GGML_OP_IM2COL}, + {ASR_TENSOR_CONV2_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_LN_POST_WEIGHT, GGML_OP_MUL}, + {ASR_TENSOR_LN_POST_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_MLP_LN_WEIGHT, GGML_OP_MUL}, + {ASR_TENSOR_MLP_LN_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_MLP_0_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_MLP_0_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_MLP_2_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_MLP_2_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_ATTN_LN_WEIGHT, GGML_OP_MUL}, + {ASR_TENSOR_ATTN_LN_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_ATTN_QUERY_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_ATTN_QUERY_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_ATTN_KEY_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_ATTN_VALUE_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_ATTN_VALUE_BIAS, GGML_OP_ADD}, + {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD}, +}; diff --git a/src/whisper.cpp b/src/whisper.cpp index 547dcf23..c633765e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -1,4 +1,5 @@ #include "whisper.h" +#include "whisper-arch.h" #include "ggml.h" #include "ggml-cpp.h" @@ -18,6 +19,7 @@ #include #define _USE_MATH_DEFINES #include +#include #include #include #include @@ -143,6 +145,21 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + // // ggml helpers // @@ -778,10 +795,10 @@ struct whisper_model { std::vector layers_decoder; // ggml context that contains all the meta information about the model tensors - struct ggml_context * ctx = nullptr; + std::vector ctxs; // the model backend data is read-only and can be shared between processors - ggml_backend_buffer_t buffer = nullptr; + std::vector buffers; // tensors int n_loaded; @@ -1364,28 +1381,109 @@ static std::vector whisper_backend_init(const whisper_context_pa return result; } -static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) { - ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type(); +using buft_list_t = std::vector>; - if (!params.use_gpu) { - return result; - } +static buft_list_t make_buft_list(whisper_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; - int cnt = 0; - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { - if (cnt == 0 || cnt == params.gpu_device) { - result = ggml_backend_dev_buffer_type(dev); - } + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { + if (cnt == 0 || cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } - if (++cnt > params.gpu_device) { - break; + if (++cnt > params.gpu_device) { + break; + } } } } - return result; + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; } // load the model from a ggml file @@ -1594,31 +1692,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const ggml_type wtype = wctx.wtype; const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type - // create the ggml context + const auto & hparams = model.hparams; + + const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; + + const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; + + std::map ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * { + ggml_op op = ASR_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor; + + return tensor; + }; + + + // prepare tensors for the weights { - const auto & hparams = model.hparams; - - const int n_audio_layer = hparams.n_audio_layer; - const int n_text_layer = hparams.n_text_layer; - - const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; - - struct ggml_init_params params = { - /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - model.ctx = ggml_init(params); - if (!model.ctx) { - WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); - return false; - } - } - - // prepare tensors for the weights - { - auto & ctx = model.ctx; + ggml_context * ctx = ggml_init(params); const auto & hparams = model.hparams; @@ -1638,189 +1770,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.layers_decoder.resize(n_text_layer); // encoder - { - model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx)); - model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); - model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state)); + model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state)); - model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); - model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state)); + model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state)); - model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); - // map by name - model.tensors["encoder.positional_embedding"] = model.e_pe; + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; - model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; - model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; - model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i); - model.tensors["encoder.ln_post.weight"] = model.e_ln_w; - model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - for (int i = 0; i < n_audio_layer; ++i) { - auto & layer = model.layers_encoder[i]; + layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); - layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); + layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); - layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); - layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - - layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); - layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); - - // map by name - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; - - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; - model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; - } + layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); } // decoder - { - model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx)); - model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab)); - model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state)); + model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state)); - // map by name - model.tensors["decoder.positional_embedding"] = model.d_pe; + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; - model.tensors["decoder.token_embedding.weight"] = model.d_te; + layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); + layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - model.tensors["decoder.ln.weight"] = model.d_ln_w; - model.tensors["decoder.ln.bias"] = model.d_ln_b; + layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i); + layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i); - for (int i = 0; i < n_text_layer; ++i) { - auto & layer = model.layers_decoder[i]; + layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i); + layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); + layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); - layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); + layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); - layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); - layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); + layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); - layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - - layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); - layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); - - // map by name - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; - - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; - model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; - } + layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i); + layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); } + + ggml_free(ctx); } // allocate tensors in the backend buffers - model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params)); - if (!model.buffer) { - WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__); - return false; - } + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + model.buffers.emplace_back(buf); - size_t size_main = ggml_backend_buffer_get_size(model.buffer); - WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6); + size_t size_main = ggml_backend_buffer_get_size(buf); + WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } // load weights { @@ -1883,11 +1934,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return false; } - //ggml_backend_t backend = wctx.backend; - - //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str()); - - if (ggml_backend_buffer_is_host(model.buffer)) { + if (ggml_backend_buffer_is_host(tensor->buffer)) { // for the CPU and Metal backend, we can read directly into the tensor loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); BYTESWAP_TENSOR(tensor); @@ -1900,7 +1947,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); } - //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6); total_size += ggml_nbytes(tensor); model.n_loaded++; } @@ -1915,7 +1961,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + for (auto & buf : model.buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } wctx.t_load_us = ggml_time_us() - t_start_us; @@ -3806,9 +3854,13 @@ void whisper_free_state(struct whisper_state * state) { void whisper_free(struct whisper_context * ctx) { if (ctx) { - ggml_free(ctx->model.ctx); + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } - ggml_backend_buffer_free(ctx->model.buffer); + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } whisper_free_state(ctx->state); From 206459a80477f00799b383d4c7e81d207c95d00b Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 26 Mar 2025 16:21:07 +0100 Subject: [PATCH 5/5] bindings-go : update Makefile to use cmake (#2952) This commit updates the Makefile to use cmake instead of make to build whisper.cpp. The motivation for this change is that currently the make recipe test will fail with the following error: ```console $ make test Mkdir build Mkdir models Build whisper make[1]: Entering directory '/home/danbev/work/ai/whisper-work' make[1]: *** No rule to make target 'libwhisper.a'. Stop. make[1]: Leaving directory '/home/danbev/work/ai/whisper-work' make: *** [Makefile:33: whisper] Error 2 ``` --- bindings/go/Makefile | 10 ++++++---- bindings/go/whisper.go | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bindings/go/Makefile b/bindings/go/Makefile index ca39de20..edcc0166 100644 --- a/bindings/go/Makefile +++ b/bindings/go/Makefile @@ -11,11 +11,11 @@ UNAME_M := $(shell uname -m) endif GGML_METAL_PATH_RESOURCES := $(abspath ../..) -BUILD_DIR := build +BUILD_DIR := build_go MODELS_DIR := models EXAMPLES_DIR := $(wildcard examples/*) INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include) -LIBRARY_PATH := $(abspath ../..) +LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src)) ifeq ($(GGML_CUDA),1) LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/ @@ -29,8 +29,10 @@ endif all: clean whisper examples whisper: mkdir - @echo Build whisper - @${MAKE} -C ../.. libwhisper.a + cmake -S ../.. -B ../../${BUILD_DIR} \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF + cmake --build ../../${BUILD_DIR} --target whisper test: model-small whisper modtidy ifeq ($(UNAME_S),Darwin) diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 39ec43b4..525b72d2 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -9,7 +9,7 @@ import ( // CGO /* -#cgo LDFLAGS: -lwhisper -lm -lstdc++ -fopenmp +#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp #cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics #include #include