Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-06-03 00:15:40 +02:00)

Merge branch 'ggerganov:master' into feat/specifyStrategy

This commit is contained in: commit 524140c8fb
@@ -427,7 +427,8 @@ For detailed instructions on how to use Conan, please refer to the [Conan docume
 This is a naive example of performing real-time inference on audio from your microphone.
 The [stream](examples/stream) tool samples the audio every half a second and runs the transcription continuously.
 More info is available in [issue #10](https://github.com/ggerganov/whisper.cpp/issues/10).
+You will need to have [sdl2](https://wiki.libsdl.org/SDL2/Installation) installed for it to work properly.
 
 ```bash
 cmake -B build -DWHISPER_SDL2=ON
@@ -11,11 +11,11 @@ UNAME_M := $(shell uname -m)
 endif
 
 GGML_METAL_PATH_RESOURCES := $(abspath ../..)
-BUILD_DIR := build
+BUILD_DIR := build_go
 MODELS_DIR := models
 EXAMPLES_DIR := $(wildcard examples/*)
 INCLUDE_PATH := $(abspath ../../include):$(abspath ../../ggml/include)
-LIBRARY_PATH := $(abspath ../..)
+LIBRARY_PATH := $(abspath ../../${BUILD_DIR}/src:$(abspath ../../${BUILD_DIR}/ggml/src))
 
 ifeq ($(GGML_CUDA),1)
 	LIBRARY_PATH := $(LIBRARY_PATH):$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib/
@@ -29,8 +29,10 @@ endif
 all: clean whisper examples
 
 whisper: mkdir
-	@echo Build whisper
-	@${MAKE} -C ../.. libwhisper.a
+	cmake -S ../.. -B ../../${BUILD_DIR} \
+		-DCMAKE_BUILD_TYPE=Release \
+		-DBUILD_SHARED_LIBS=OFF
+	cmake --build ../../${BUILD_DIR} --target whisper
 
 test: model-small whisper modtidy
 ifeq ($(UNAME_S),Darwin)
@@ -9,7 +9,7 @@ import (
 // CGO
 
 /*
-#cgo LDFLAGS: -lwhisper -lm -lstdc++ -fopenmp
+#cgo LDFLAGS: -lwhisper -lggml -lggml-base -lggml-cpu -lm -lstdc++ -fopenmp
 #cgo darwin LDFLAGS: -framework Accelerate -framework Metal -framework Foundation -framework CoreGraphics
 #include <whisper.h>
 #include <stdlib.h>
@@ -25,13 +25,13 @@ sourceSets {
 }
 
 tasks.register('copyLibwhisperDynlib', Copy) {
-    from '../../build'
-    include 'libwhisper.dynlib'
+    from '../../build/src'
+    include 'libwhisper.dylib'
     into 'build/generated/resources/main/darwin'
 }
 
 tasks.register('copyLibwhisperSo', Copy) {
-    from '../../build'
+    from '../../build/src'
     include 'libwhisper.so'
     into 'build/generated/resources/main/linux-x86-64'
 }
@@ -55,7 +55,12 @@ java {
     withJavadocJar()
 }
 
+sourcesJar() {
+    dependsOn copyLibs
+}
+
 jar {
+    dependsOn copyLibs
     exclude '**/whisper_java.exp', '**/whisper_java.lib'
 }
 
@@ -67,6 +72,9 @@ tasks.withType(Test) {
     useJUnitPlatform()
 }
 
+test.dependsOn copyLibs
+processResources.dependsOn copyLibs
+
 dependencies {
     implementation "net.java.dev.jna:jna:5.13.0"
     testImplementation "org.junit.jupiter:junit-jupiter:5.9.2"
bindings/java/gradlew (vendored): Normal file → Executable file, 0 lines changed
@@ -0,0 +1,24 @@
+package io.github.ggerganov.whispercpp;
+
+/**
+ * Presets for alignment heads in DTW token timestamps
+ */
+public class WhisperConstants {
+    // Alignment heads presets
+    public static final int WHISPER_AHEADS_NONE = 0;
+    public static final int WHISPER_AHEADS_TINY_EN = 1;
+    public static final int WHISPER_AHEADS_TINY = 2;
+    public static final int WHISPER_AHEADS_BASE_EN = 3;
+    public static final int WHISPER_AHEADS_BASE = 4;
+    public static final int WHISPER_AHEADS_SMALL_EN = 5;
+    public static final int WHISPER_AHEADS_SMALL = 6;
+    public static final int WHISPER_AHEADS_MEDIUM_EN = 7;
+    public static final int WHISPER_AHEADS_MEDIUM = 8;
+    public static final int WHISPER_AHEADS_LARGE_V1 = 9;
+    public static final int WHISPER_AHEADS_LARGE_V2 = 10;
+    public static final int WHISPER_AHEADS_LARGE_V3 = 11;
+    public static final int WHISPER_AHEADS_LARGE_V3_TURBO = 12;
+    public static final int WHISPER_AHEADS_CUSTOM = 13;
+    public static final int WHISPER_AHEADS_N_TOP_MOST = 14;
+    public static final int WHISPER_AHEADS_COUNT = 15;
+}
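These constants mirror the alignment-heads preset values in whisper.h. A minimal sketch of how a preset might be combined with the DTW setters that this PR adds to WhisperContextParams further down; the WhisperCpp construction and the "base.en" model name are illustrative, not part of this diff:

```java
import io.github.ggerganov.whispercpp.WhisperConstants;
import io.github.ggerganov.whispercpp.WhisperCpp;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;

public class DtwPresetSketch {
    public static void main(String[] args) throws Exception {
        try (WhisperCpp whisper = new WhisperCpp()) {
            // getContextDefaultParams() now returns a ByValue instance (see WhisperCpp changes below)
            WhisperContextParams.ByValue params = whisper.getContextDefaultParams();
            // enable experimental DTW token-level timestamps and pick the preset
            // matching the model being loaded ("base.en" is a placeholder)
            params.enableDtwTokenTimestamps(true);
            params.setDtwAheadsPreset(WhisperConstants.WHISPER_AHEADS_BASE_EN);
            whisper.initContext("base.en", params);
        }
    }
}
```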
@@ -1,7 +1,9 @@
 package io.github.ggerganov.whispercpp;
 
+import com.sun.jna.NativeLong;
 import com.sun.jna.Structure;
 import com.sun.jna.ptr.PointerByReference;
+import com.sun.jna.Pointer;
 import io.github.ggerganov.whispercpp.ggml.GgmlType;
 import io.github.ggerganov.whispercpp.WhisperModel;
 import io.github.ggerganov.whispercpp.params.WhisperContextParams;
@@ -9,33 +11,26 @@ import io.github.ggerganov.whispercpp.params.WhisperContextParams;
 import java.util.List;
 
 public class WhisperContext extends Structure {
-    int t_load_us = 0;
-    int t_start_us = 0;
+    public NativeLong t_load_us;
+    public NativeLong t_start_us;
 
     /** weight type (FP32 / FP16 / QX) */
-    GgmlType wtype = GgmlType.GGML_TYPE_F16;
+    public GgmlType wtype = GgmlType.GGML_TYPE_F16;
     /** intermediate type (FP32 or FP16) */
-    GgmlType itype = GgmlType.GGML_TYPE_F16;
+    public GgmlType itype = GgmlType.GGML_TYPE_F16;
 
-    // WhisperModel model;
-    public PointerByReference model;
-    // whisper_vocab vocab;
-    public PointerByReference vocab;
-    // whisper_state * state = nullptr;
-    public PointerByReference state;
+    public WhisperContextParams.ByValue params;
+
+    public Pointer model;
+    public Pointer vocab;
+    public Pointer state;
 
     /** populated by whisper_init_from_file_with_params() */
-    String path_model;
-    WhisperContextParams params;
+    public Pointer path_model;
 
-    // public static class ByReference extends WhisperContext implements Structure.ByReference {
-    // }
-    //
-    // public static class ByValue extends WhisperContext implements Structure.ByValue {
-    // }
-    //
-    // @Override
-    // protected List<String> getFieldOrder() {
-    //     return List.of("t_load_us", "t_start_us", "wtype", "itype", "model", "vocab", "state", "path_model");
-    // }
+    @Override
+    protected List<String> getFieldOrder() {
+        return List.of("t_load_us", "t_start_us", "wtype", "itype",
+                "params", "model", "vocab", "state", "path_model");
+    }
 }
@@ -43,11 +43,11 @@ public class WhisperCpp implements AutoCloseable {
      * @param modelPath - absolute path, or just the name (eg: "base", "base-en" or "base.en")
      * @param params - params to use when initialising the context
      */
-    public void initContext(String modelPath, WhisperContextParams params) throws FileNotFoundException {
+    public void initContext(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {
        initContextImpl(modelPath, params);
    }

-    private void initContextImpl(String modelPath, WhisperContextParams params) throws FileNotFoundException {
+    private void initContextImpl(String modelPath, WhisperContextParams.ByValue params) throws FileNotFoundException {
        if (ctx != null) {
            lib.whisper_free(ctx);
        }
@@ -69,15 +69,13 @@ public class WhisperCpp implements AutoCloseable {
 
     /**
      * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
-     * Because this function allocates memory for the params, the caller must call either:
-     * - call `whisper_free_context_params()`
-     * - `Native.free(Pointer.nativeValue(pointer));`
+     * Returns a ByValue instance to ensure proper parameter passing to native code.
      */
-    public WhisperContextParams getContextDefaultParams() {
-        paramsPointer = lib.whisper_context_default_params_by_ref();
-        WhisperContextParams params = new WhisperContextParams(paramsPointer);
-        params.read();
-        return params;
+    public WhisperContextParams.ByValue getContextDefaultParams() {
+        WhisperContextParams.ByValue valueParams = new WhisperContextParams.ByValue(
+                lib.whisper_context_default_params_by_ref());
+        valueParams.read();
+        return valueParams;
     }
 
     /**
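A sketch of the updated lifecycle (the model path is a placeholder): the by-ref pointer returned by the native library is wrapped once, read() copies the native fields into Java, and the structure is then passed by value, so the manual free calls mentioned in the old javadoc drop out of the flow:

```java
import io.github.ggerganov.whispercpp.WhisperCpp;
import io.github.ggerganov.whispercpp.params.WhisperContextParams;

public class InitSketch {
    public static void main(String[] args) throws Exception {
        try (WhisperCpp whisper = new WhisperCpp()) {
            WhisperContextParams.ByValue params = whisper.getContextDefaultParams();
            params.useGpu(true); // CBool-backed setter on WhisperContextParams
            whisper.initContext("/path/to/ggml-base.en.bin", params);
        } // close() frees the native context
    }
}
```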
@@ -88,7 +86,7 @@ public class WhisperCpp implements AutoCloseable {
      *
      * @param strategy - GREEDY
      */
-    public WhisperFullParams getFullDefaultParams(WhisperSamplingStrategy strategy) {
+    public WhisperFullParams.ByValue getFullDefaultParams(WhisperSamplingStrategy strategy) {
        Pointer pointer;

        // whisper_full_default_params_by_ref allocates memory which we need to delete, so only create max 1 pointer for each strategy.
@@ -104,7 +102,7 @@ public class WhisperCpp implements AutoCloseable {
            pointer = beamParamsPointer;
        }

-        WhisperFullParams params = new WhisperFullParams(pointer);
+        WhisperFullParams.ByValue params = new WhisperFullParams.ByValue(pointer);
        params.read();
        return params;
    }
@@ -138,15 +136,21 @@ public class WhisperCpp implements AutoCloseable {
     }
 
     /**
      * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text.
      * Not thread safe for same context
      * Uses the specified decoding strategy to obtain the text.
      */
-    public String fullTranscribe(WhisperFullParams whisperParams, float[] audioData) throws IOException {
+    public String fullTranscribe(WhisperFullParams.ByValue whisperParams, float[] audioData) throws IOException {
        if (ctx == null) {
            throw new IllegalStateException("Model not initialised");
        }

+        /*
+        WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
+                lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
+        valueParams.read();
+        */
+
        if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
            throw new IOException("Failed to process audio");
        }
|
|||||||
|
|
||||||
return str.toString().trim();
|
return str.toString().trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
public List<WhisperSegment> fullTranscribeWithTime(WhisperFullParams whisperParams, float[] audioData) throws IOException {
|
||||||
if (ctx == null) {
|
if (ctx == null) {
|
||||||
throw new IllegalStateException("Model not initialised");
|
throw new IllegalStateException("Model not initialised");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lib.whisper_full(ctx, whisperParams, audioData, audioData.length) != 0) {
|
WhisperFullParams.ByValue valueParams = new WhisperFullParams.ByValue(
|
||||||
|
lib.whisper_full_default_params_by_ref(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH.ordinal()));
|
||||||
|
valueParams.read();
|
||||||
|
|
||||||
|
if (lib.whisper_full(ctx, valueParams, audioData, audioData.length) != 0) {
|
||||||
throw new IOException("Failed to process audio");
|
throw new IOException("Failed to process audio");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -38,7 +38,7 @@ public interface WhisperCppJnaLibrary extends Library {
      * @param params Pointer to whisper_context_params
      * @return Whisper context on success, null on failure
      */
-    Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams params);
+    Pointer whisper_init_from_file_with_params(String path_model, WhisperContextParams.ByValue params);
 
     /**
      * Allocate (almost) all memory needed for the model by loading from a buffer.
@@ -180,12 +180,12 @@ public interface WhisperCppJnaLibrary extends Library {
     /**
      * @return the id of the specified language, returns -1 if not found.
      * Examples:
      *   "de" -> 2
      *   "german" -> 2
      */
     int whisper_lang_id(String lang);
 
     /** @return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found */
     String whisper_lang_str(int id);
 
     /**
@@ -268,20 +268,21 @@ public interface WhisperCppJnaLibrary extends Library {
     void whisper_free_params(Pointer params);
 
     /**
      * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
      * Not thread safe for same context
      * Uses the specified decoding strategy to obtain the text.
      */
-    int whisper_full(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples);
+    int whisper_full(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples);
 
-    int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
+    public int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams.ByValue params, float[] samples, int n_samples);
+    //int whisper_full_with_state(Pointer ctx, Pointer state, WhisperFullParams params, final float[] samples, int n_samples);
 
     // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
     // Result is stored in the default state of the context
     // Not thread safe if executed in parallel on the same context.
     // It seems this approach can offer some speedup in some cases.
     // However, the transcription accuracy can be worse at the beginning and end of each chunk.
-    int whisper_full_parallel(Pointer ctx, WhisperFullParams params, final float[] samples, int n_samples, int n_processors);
+    int whisper_full_parallel(Pointer ctx, WhisperFullParams.ByValue params, final float[] samples, int n_samples, int n_processors);
 
     /**
      * Number of generated text segments.
@@ -0,0 +1,17 @@
+package io.github.ggerganov.whispercpp.callbacks;
+
+import com.sun.jna.Callback;
+
+/**
+ * Callback for aborting GGML computation
+ * Maps to the C typedef: bool (*ggml_abort_callback)(void * data)
+ */
+public interface GgmlAbortCallback extends Callback {
+    /**
+     * Return true to abort the computation, false to continue
+     *
+     * @param data User data passed to the callback
+     * @return true to abort, false to continue
+     */
+    boolean invoke(com.sun.jna.Pointer data);
+}
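A sketch of an implementation that stops decoding after a time budget; the deadline logic is invented for illustration, and note that with JNA the callback object must stay strongly referenced for as long as native code may invoke it:

```java
import com.sun.jna.Pointer;
import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback;

public class DeadlineAbortCallback implements GgmlAbortCallback {
    private final long deadlineMs;

    public DeadlineAbortCallback(long budgetMs) {
        this.deadlineMs = System.currentTimeMillis() + budgetMs;
    }

    @Override
    public boolean invoke(Pointer data) {
        // returning true aborts the GGML graph computation
        return System.currentTimeMillis() > deadlineMs;
    }
}
```

It would be registered through the setAbortCallback(...) helper that this PR adds to WhisperFullParams below.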
@@ -0,0 +1,30 @@
+package io.github.ggerganov.whispercpp.params;
+import com.sun.jna.*;
+import java.util.Arrays;
+import java.util.List;
+
+public class WhisperAhead extends Structure {
+
+    public int n_text_layer;
+
+    public int n_head;
+
+    public WhisperAhead() {
+        super();
+    }
+
+    public WhisperAhead(int textLayer, int head) {
+        super();
+        this.n_text_layer = textLayer;
+        this.n_head = head;
+    }
+
+    @Override
+    protected List<String> getFieldOrder() {
+        return Arrays.asList("n_text_layer", "n_head");
+    }
+
+    public static class ByReference extends WhisperAhead implements Structure.ByReference {}
+
+    public static class ByValue extends WhisperAhead implements Structure.ByValue {}
+}
@@ -0,0 +1,41 @@
+package io.github.ggerganov.whispercpp.params;
+import com.sun.jna.*;
+import java.util.Arrays;
+import java.util.List;
+
+public class WhisperAheads extends Structure {
+    public NativeLong n_heads;
+
+    public Pointer heads;
+
+    public WhisperAheads() {
+        super();
+    }
+
+    /**
+     * Create alignment heads from an array of WhisperAhead objects
+     */
+    public void setHeads(WhisperAhead[] aheadsArray) {
+        this.n_heads = new NativeLong(aheadsArray.length);
+
+        int structSize = aheadsArray[0].size();
+        Memory mem = new Memory(structSize * aheadsArray.length);
+
+        for (int i = 0; i < aheadsArray.length; i++) {
+            aheadsArray[i].write();
+            byte[] buffer = aheadsArray[i].getPointer().getByteArray(0, structSize);
+            mem.write(i * structSize, buffer, 0, buffer.length);
+        }
+
+        this.heads = mem;
+    }
+
+    @Override
+    protected List<String> getFieldOrder() {
+        return Arrays.asList("n_heads", "heads");
+    }
+
+    public static class ByReference extends WhisperAheads implements Structure.ByReference {}
+
+    public static class ByValue extends WhisperAheads implements Structure.ByValue {}
+}
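setHeads packs each WhisperAhead into one contiguous native Memory block, which is the flattened array layout the C-side heads pointer expects. A sketch of building a custom alignment-head list; the layer/head indices are made up for illustration:

```java
import io.github.ggerganov.whispercpp.WhisperConstants;
import io.github.ggerganov.whispercpp.params.WhisperAhead;
import io.github.ggerganov.whispercpp.params.WhisperAheads;

public class CustomAheadsSketch {
    public static void main(String[] args) {
        WhisperAheads.ByValue aheads = new WhisperAheads.ByValue();
        aheads.setHeads(new WhisperAhead[] {
                new WhisperAhead(3, 1), // (n_text_layer, n_head), illustrative values
                new WhisperAhead(4, 2),
        });
        // with a custom list, a context would use WHISPER_AHEADS_CUSTOM and
        // assign this structure to WhisperContextParams.dtw_aheads (shown below)
        int preset = WhisperConstants.WHISPER_AHEADS_CUSTOM;
        System.out.println("preset id: " + preset + ", n_heads: " + aheads.n_heads);
    }
}
```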
@@ -1,7 +1,5 @@
 package io.github.ggerganov.whispercpp.params;
-
 import com.sun.jna.*;
-
 import java.util.Arrays;
 import java.util.List;
 
@@ -11,21 +9,73 @@ import java.util.List;
  * whisper_context_default_params()
  */
 public class WhisperContextParams extends Structure {
 
     public WhisperContextParams(Pointer p) {
         super(p);
     }
 
-    /** Use GPU for inference Number (default = true) */
+    public WhisperContextParams() {
+        super();
+    }
+
+    /** Use GPU for inference (default = true) */
     public CBool use_gpu;
 
-    /** Use GPU for inference Number (default = true) */
+    /** Use flash attention (default = false) */
+    public CBool flash_attn;
+
+    /** CUDA device to use (default = 0) */
+    public int gpu_device;
+
+    /** [EXPERIMENTAL] Enable token-level timestamps with DTW (default = false) */
+    public CBool dtw_token_timestamps;
+
+    /** [EXPERIMENTAL] Alignment heads preset for DTW */
+    public int dtw_aheads_preset;
+
+    /** Number of top layers to use for DTW when using WHISPER_AHEADS_N_TOP_MOST preset */
+    public int dtw_n_top;
+
+    public WhisperAheads.ByValue dtw_aheads;
+
+    /** DTW memory size (internal use) */
+    public NativeLong dtw_mem_size;
+
+    /** Use GPU for inference */
     public void useGpu(boolean enable) {
         use_gpu = enable ? CBool.TRUE : CBool.FALSE;
     }
 
+    /** Use flash attention */
+    public void useFlashAttn(boolean enable) {
+        flash_attn = enable ? CBool.TRUE : CBool.FALSE;
+    }
+
+    /** Enable DTW token-level timestamps */
+    public void enableDtwTokenTimestamps(boolean enable) {
+        dtw_token_timestamps = enable ? CBool.TRUE : CBool.FALSE;
+    }
+
+    /** Set DTW alignment heads preset */
+    public void setDtwAheadsPreset(int preset) {
+        dtw_aheads_preset = preset;
+    }
+
     @Override
     protected List<String> getFieldOrder() {
-        return Arrays.asList("use_gpu");
+        return Arrays.asList(
+                "use_gpu",
+                "flash_attn",
+                "gpu_device",
+                "dtw_token_timestamps",
+                "dtw_aheads_preset",
+                "dtw_n_top",
+                "dtw_aheads",
+                "dtw_mem_size"
+        );
     }
+
+    public static class ByValue extends WhisperContextParams implements Structure.ByValue {
+        public ByValue() { super(); }
+        public ByValue(Pointer p) { super(p); }
+    }
+
 }
@@ -5,6 +5,7 @@ import io.github.ggerganov.whispercpp.callbacks.WhisperEncoderBeginCallback;
 import io.github.ggerganov.whispercpp.callbacks.WhisperLogitsFilterCallback;
 import io.github.ggerganov.whispercpp.callbacks.WhisperNewSegmentCallback;
 import io.github.ggerganov.whispercpp.callbacks.WhisperProgressCallback;
+import io.github.ggerganov.whispercpp.callbacks.GgmlAbortCallback;
 
 import java.util.Arrays;
 import java.util.List;
@@ -16,10 +17,12 @@ import java.util.List;
  */
 public class WhisperFullParams extends Structure {
 
+    public WhisperFullParams() {
+        super();
+    }
+
     public WhisperFullParams(Pointer p) {
         super(p);
-        // super(p, ALIGN_MSVC);
-        // super(p, ALIGN_GNUC);
     }
 
     /** Sampling strategy for whisper_full() function. */
@@ -69,10 +72,10 @@ public class WhisperFullParams extends Structure {
         single_segment = single ? CBool.TRUE : CBool.FALSE;
     }
 
     /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
     public CBool print_special;
 
     /** Flag to print special tokens (e.g., <SOT>, <EOT>, <BEG>, etc.). (default = false) */
     public void printSpecial(boolean enable) {
         print_special = enable ? CBool.TRUE : CBool.FALSE;
     }
@@ -129,6 +132,14 @@ public class WhisperFullParams extends Structure {
     /** Maximum tokens per segment (0, default = no limit) */
     public int max_tokens;
 
+    /** [EXPERIMENTAL] Enable debug mode for extra info */
+    public CBool debug_mode;
+
+    /** Enable debug mode */
+    public void enableDebugMode(boolean enable) {
+        debug_mode = enable ? CBool.TRUE : CBool.FALSE;
+    }
+
     /** Overwrite the audio context size (0 = use default). */
     public int audio_ctx;
 
@@ -274,6 +285,16 @@ public class WhisperFullParams extends Structure {
      */
     public Pointer encoder_begin_callback_user_data;
 
+    /** Callback used to abort GGML computation */
+    public Pointer abort_callback;
+
+    /** User data for the abort_callback */
+    public Pointer abort_callback_user_data;
+
+    public void setAbortCallback(GgmlAbortCallback callback) {
+        abort_callback = CallbackReference.getFunctionPointer(callback);
+    }
+
     /**
      * Callback by each decoder to filter obtained logits.
      * WhisperLogitsFilterCallback
@@ -310,17 +331,28 @@ public class WhisperFullParams extends Structure {
 
     @Override
     protected List<String> getFieldOrder() {
-        return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
-                "no_context", "single_segment", "no_timestamps",
-                "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
-                "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
-                "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
-                "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
-                "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
-                "new_segment_callback", "new_segment_callback_user_data",
+        return Arrays.asList("strategy", "n_threads", "n_max_text_ctx",
+                "offset_ms", "duration_ms", "translate", "no_context",
+                "no_timestamps", "single_segment", "print_special",
+                "print_progress", "print_realtime", "print_timestamps",
+                "token_timestamps", "thold_pt", "thold_ptsum", "max_len",
+                "split_on_word", "max_tokens", "debug_mode", "audio_ctx",
+                "tdrz_enable", "suppress_regex", "initial_prompt",
+                "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
+                "suppress_blank", "suppress_nst", "temperature",
+                "max_initial_ts", "length_penalty", "temperature_inc",
+                "entropy_thold", "logprob_thold", "no_speech_thold", "greedy",
+                "beam_search", "new_segment_callback", "new_segment_callback_user_data",
                 "progress_callback", "progress_callback_user_data",
                 "encoder_begin_callback", "encoder_begin_callback_user_data",
+                "abort_callback", "abort_callback_user_data",
                 "logits_filter_callback", "logits_filter_callback_user_data",
                 "grammar_rules", "n_grammar_rules", "i_start_rule", "grammar_penalty");
     }
+
+    public static class ByValue extends WhisperFullParams implements Structure.ByValue {
+        public ByValue() { super(); }
+        public ByValue(Pointer p) { super(p); }
+    }
+
 }
@@ -76,7 +76,7 @@ class WhisperCppTest {
         float[] floats = new float[b.length / 2];
 
         //WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_GREEDY);
-        WhisperFullParams params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
+        WhisperFullParams.ByValue params = whisper.getFullDefaultParams(WhisperSamplingStrategy.WHISPER_SAMPLING_BEAM_SEARCH);
         params.setProgressCallback((ctx, state, progress, user_data) -> System.out.println("progress: " + progress));
         params.print_progress = CBool.FALSE;
         //params.initial_prompt = "and so my fellow Americans um, like";
@@ -33,6 +33,9 @@ mkdir build-em && cd build-em
 emcmake cmake .. && make -j
 
 # run test
+node ../tests/test-whisper.js
+
+# For Node.js versions prior to v16.4.0, experimental features need to be enabled:
 node --experimental-wasm-threads --experimental-wasm-simd ../tests/test-whisper.js
 
 # publish npm package
@@ -102,6 +102,7 @@ endif()
 
 add_library(whisper
     ../include/whisper.h
+    whisper-arch.h
     whisper.cpp
     )
 
src/whisper-arch.h (new file, 141 lines)
@@ -0,0 +1,141 @@
+#pragma once
+
+#include "ggml.h"
+
+#include <map>
+
+enum asr_tensor {
+    ASR_TENSOR_ENC_POS_EMBD,
+    ASR_TENSOR_DEC_POS_EMBD,
+    ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT,
+    ASR_TENSOR_LN_WEIGHT,
+    ASR_TENSOR_LN_BIAS,
+    ASR_TENSOR_CONV1_WEIGHT,
+    ASR_TENSOR_CONV1_BIAS,
+    ASR_TENSOR_CONV2_WEIGHT,
+    ASR_TENSOR_CONV2_BIAS,
+    ASR_TENSOR_LN_POST_WEIGHT,
+    ASR_TENSOR_LN_POST_BIAS,
+    ASR_TENSOR_MLP_LN_WEIGHT,
+    ASR_TENSOR_MLP_LN_BIAS,
+    ASR_TENSOR_MLP_0_WEIGHT,
+    ASR_TENSOR_MLP_0_BIAS,
+    ASR_TENSOR_MLP_2_WEIGHT,
+    ASR_TENSOR_MLP_2_BIAS,
+    ASR_TENSOR_ATTN_LN_WEIGHT,
+    ASR_TENSOR_ATTN_LN_BIAS,
+    ASR_TENSOR_ATTN_QUERY_WEIGHT,
+    ASR_TENSOR_ATTN_QUERY_BIAS,
+    ASR_TENSOR_ATTN_KEY_WEIGHT,
+    ASR_TENSOR_ATTN_VALUE_WEIGHT,
+    ASR_TENSOR_ATTN_VALUE_BIAS,
+    ASR_TENSOR_ATTN_OUT_WEIGHT,
+    ASR_TENSOR_ATTN_OUT_BIAS,
+};
+
+enum asr_system {
+    ASR_SYSTEM_ENCODER,
+    ASR_SYSTEM_DECODER,
+    ASR_SYSTEM_CROSS
+};
+
+static const std::map<asr_system, std::map<asr_tensor, const char *>> ASR_TENSOR_NAMES = {
+    {
+        ASR_SYSTEM_ENCODER,
+        {
+            {ASR_TENSOR_ENC_POS_EMBD,      "encoder.positional_embedding"},
+            {ASR_TENSOR_CONV1_WEIGHT,      "encoder.conv1.weight"},
+            {ASR_TENSOR_CONV1_BIAS,        "encoder.conv1.bias"},
+            {ASR_TENSOR_CONV2_WEIGHT,      "encoder.conv2.weight"},
+            {ASR_TENSOR_CONV2_BIAS,        "encoder.conv2.bias"},
+            {ASR_TENSOR_LN_WEIGHT,         "encoder.ln_post.weight"},
+            {ASR_TENSOR_LN_POST_BIAS,      "encoder.ln_post.bias"},
+            {ASR_TENSOR_MLP_LN_WEIGHT,     "encoder.blocks.%d.mlp_ln.weight"},
+            {ASR_TENSOR_MLP_LN_BIAS,       "encoder.blocks.%d.mlp_ln.bias"},
+            {ASR_TENSOR_MLP_0_WEIGHT,      "encoder.blocks.%d.mlp.0.weight"},
+            {ASR_TENSOR_MLP_0_BIAS,        "encoder.blocks.%d.mlp.0.bias"},
+            {ASR_TENSOR_MLP_2_WEIGHT,      "encoder.blocks.%d.mlp.2.weight"},
+            {ASR_TENSOR_MLP_2_BIAS,        "encoder.blocks.%d.mlp.2.bias"},
+            {ASR_TENSOR_ATTN_LN_WEIGHT,    "encoder.blocks.%d.attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS,      "encoder.blocks.%d.attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "encoder.blocks.%d.attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS,   "encoder.blocks.%d.attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT,   "encoder.blocks.%d.attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "encoder.blocks.%d.attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS,   "encoder.blocks.%d.attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT,   "encoder.blocks.%d.attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS,     "encoder.blocks.%d.attn.out.bias"},
+        },
+    },
+    {
+        ASR_SYSTEM_DECODER,
+        {
+            {ASR_TENSOR_DEC_POS_EMBD,          "decoder.positional_embedding"},
+            {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, "decoder.token_embedding.weight"},
+            {ASR_TENSOR_LN_WEIGHT,             "decoder.ln.weight"},
+            {ASR_TENSOR_LN_BIAS,               "decoder.ln.bias"},
+
+            {ASR_TENSOR_MLP_LN_WEIGHT,     "decoder.blocks.%d.mlp_ln.weight"},
+            {ASR_TENSOR_MLP_LN_BIAS,       "decoder.blocks.%d.mlp_ln.bias"},
+            {ASR_TENSOR_MLP_0_WEIGHT,      "decoder.blocks.%d.mlp.0.weight"},
+            {ASR_TENSOR_MLP_0_BIAS,        "decoder.blocks.%d.mlp.0.bias"},
+            {ASR_TENSOR_MLP_2_WEIGHT,      "decoder.blocks.%d.mlp.2.weight"},
+            {ASR_TENSOR_MLP_2_BIAS,        "decoder.blocks.%d.mlp.2.bias"},
+            {ASR_TENSOR_ATTN_LN_WEIGHT,    "decoder.blocks.%d.attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS,      "decoder.blocks.%d.attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS,   "decoder.blocks.%d.attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT,   "decoder.blocks.%d.attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS,   "decoder.blocks.%d.attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT,   "decoder.blocks.%d.attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS,     "decoder.blocks.%d.attn.out.bias"},
+        },
+    },
+    {
+        ASR_SYSTEM_CROSS,
+        {
+            {ASR_TENSOR_ATTN_LN_WEIGHT,    "decoder.blocks.%d.cross_attn_ln.weight"},
+            {ASR_TENSOR_ATTN_LN_BIAS,      "decoder.blocks.%d.cross_attn_ln.bias"},
+            {ASR_TENSOR_ATTN_QUERY_WEIGHT, "decoder.blocks.%d.cross_attn.query.weight"},
+            {ASR_TENSOR_ATTN_QUERY_BIAS,   "decoder.blocks.%d.cross_attn.query.bias"},
+            {ASR_TENSOR_ATTN_KEY_WEIGHT,   "decoder.blocks.%d.cross_attn.key.weight"},
+            {ASR_TENSOR_ATTN_VALUE_WEIGHT, "decoder.blocks.%d.cross_attn.value.weight"},
+            {ASR_TENSOR_ATTN_VALUE_BIAS,   "decoder.blocks.%d.cross_attn.value.bias"},
+            {ASR_TENSOR_ATTN_OUT_WEIGHT,   "decoder.blocks.%d.cross_attn.out.weight"},
+            {ASR_TENSOR_ATTN_OUT_BIAS,     "decoder.blocks.%d.cross_attn.out.bias"},
+        },
+    },
+};
+
+static const std::map<asr_tensor, ggml_op> ASR_TENSOR_INFO = {
+    {ASR_TENSOR_ENC_POS_EMBD, GGML_OP_ADD},
+    {ASR_TENSOR_DEC_POS_EMBD, GGML_OP_GET_ROWS},
+    // Note: ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT is also used by GGML_OP_MAT_MUL. Need to figure out a way how to handle
+    // weight tensors that are used by multiple different operators when extra_buffer_type implementations accelerate
+    // more than just GGML_OP_MUL_MAT.
+    {ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, GGML_OP_GET_ROWS},
+    {ASR_TENSOR_LN_WEIGHT, GGML_OP_MUL},
+    {ASR_TENSOR_LN_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_CONV1_WEIGHT, GGML_OP_IM2COL},
+    {ASR_TENSOR_CONV1_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_CONV2_WEIGHT, GGML_OP_IM2COL},
+    {ASR_TENSOR_CONV2_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_LN_POST_WEIGHT, GGML_OP_MUL},
+    {ASR_TENSOR_LN_POST_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_MLP_LN_WEIGHT, GGML_OP_MUL},
+    {ASR_TENSOR_MLP_LN_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_MLP_0_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_MLP_0_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_MLP_2_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_MLP_2_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_LN_WEIGHT, GGML_OP_MUL},
+    {ASR_TENSOR_ATTN_LN_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_QUERY_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_QUERY_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_KEY_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_VALUE_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_VALUE_BIAS, GGML_OP_ADD},
+    {ASR_TENSOR_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT},
+    {ASR_TENSOR_ATTN_OUT_BIAS, GGML_OP_ADD},
+};
src/whisper.cpp (438 lines changed)
@@ -1,4 +1,5 @@
 #include "whisper.h"
+#include "whisper-arch.h"
 
 #include "ggml.h"
 #include "ggml-cpp.h"
@@ -18,6 +19,7 @@
 #include <cassert>
 #define _USE_MATH_DEFINES
 #include <cmath>
+#include <climits>
 #include <codecvt>
 #include <cstdarg>
 #include <cstdio>
@@ -143,6 +145,21 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 #define WHISPER_MAX_DECODERS 8
 #define WHISPER_MAX_NODES 4096
 
+static std::string format(const char * fmt, ...) {
+    va_list ap;
+    va_list ap2;
+    va_start(ap, fmt);
+    va_copy(ap2, ap);
+    int size = vsnprintf(NULL, 0, fmt, ap);
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    std::vector<char> buf(size + 1);
+    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
+    GGML_ASSERT(size2 == size);
+    va_end(ap2);
+    va_end(ap);
+    return std::string(buf.data(), size);
+}
+
 //
 // ggml helpers
 //
@@ -778,10 +795,10 @@ struct whisper_model {
     std::vector<whisper_layer_decoder> layers_decoder;
 
     // ggml context that contains all the meta information about the model tensors
-    struct ggml_context * ctx = nullptr;
+    std::vector<ggml_context *> ctxs;
 
     // the model backend data is read-only and can be shared between processors
-    ggml_backend_buffer_t buffer = nullptr;
+    std::vector<ggml_backend_buffer_t> buffers;
 
     // tensors
     int n_loaded;
@@ -1364,28 +1381,109 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
     return result;
 }
 
-static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
-    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();
-
-    if (!params.use_gpu) {
-        return result;
-    }
-
-    int cnt = 0;
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-            if (cnt == 0 || cnt == params.gpu_device) {
-                result = ggml_backend_dev_buffer_type(dev);
-            }
-
-            if (++cnt > params.gpu_device) {
-                break;
-            }
-        }
-    }
-
-    return result;
-}
+using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
+
+static buft_list_t make_buft_list(whisper_context_params & params) {
+    // Prio order: GPU -> CPU Extra -> CPU
+    buft_list_t buft_list;
+
+    // GPU
+    if (params.use_gpu) {
+        int cnt = 0;
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+                if (cnt == 0 || cnt == params.gpu_device) {
+                    auto * buft = ggml_backend_dev_buffer_type(dev);
+                    if (buft) {
+                        buft_list.emplace_back(dev, buft);
+                    }
+                }
+
+                if (++cnt > params.gpu_device) {
+                    break;
+                }
+            }
+        }
+    }
+
+    // CPU Extra
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+    auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+    if (get_extra_bufts_fn) {
+        ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
+        while (extra_bufts && *extra_bufts) {
+            buft_list.emplace_back(cpu_dev, *extra_bufts);
+            ++extra_bufts;
+        }
+    }
+
+    // CPU
+    buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());
+
+    return buft_list;
+}
+
+static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
+    bool op_supported = true;
+
+    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU ||
+        (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) {
+        // GPU and default CPU backend support all operators
+        op_supported = true;
+    } else {
+        switch (op) {
+            // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
+            case GGML_OP_MUL_MAT: {
+                ggml_init_params params = {
+                    /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc   =*/ true,
+                };
+
+                ggml_context_ptr ctx_ptr { ggml_init(params) };
+                if (!ctx_ptr) {
+                    throw std::runtime_error("failed to create ggml context");
+                }
+                ggml_context * ctx = ctx_ptr.get();
+
+                ggml_tensor * op_tensor = nullptr;
+
+                int64_t n_ctx = hparams.n_audio_ctx;
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
+                op_tensor = ggml_mul_mat(ctx, w, b);
+
+                // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
+                GGML_ASSERT(w->buffer == nullptr);
+                w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
+                op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+                ggml_backend_buffer_free(w->buffer);
+                w->buffer = nullptr;
+                break;
+            }
+            default: {
+                op_supported = false;
+                break;
+            }
+        };
+    }
+
+    return op_supported;
+}
+
+static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
+    GGML_ASSERT(!buft_list.empty());
+    for (const auto & p : buft_list) {
+        ggml_backend_dev_t dev = p.first;
+        ggml_backend_buffer_type_t buft = p.second;
+        if (weight_buft_supported(hparams, w, op, buft, dev)) {
+            return buft;
+        }
+    }
+
+    return nullptr;
+}
 
 // load the model from a ggml file
@@ -1594,31 +1692,65 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     const ggml_type wtype = wctx.wtype;
     const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type
 
-    // create the ggml context
+    const auto & hparams = model.hparams;
+
+    const int n_audio_layer = hparams.n_audio_layer;
+    const int n_text_layer  = hparams.n_text_layer;
+
+    const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
+
+    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
+    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
+                /*.mem_buffer =*/ nullptr,
+                /*.no_alloc   =*/ true,
+            };
+
+            ggml_context * ctx = ggml_init(params);
+            if (!ctx) {
+                throw std::runtime_error("failed to create ggml context");
+            }
+
+            ctx_map[buft] = ctx;
+            model.ctxs.emplace_back(ctx);
+
+            return ctx;
+        }
+
+        return it->second;
+    };
+
+    // Create a list of available bufts, in priority order
+    buft_list_t buft_list = make_buft_list(wctx.params);
+
+    auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
+        ggml_op op = ASR_TENSOR_INFO.at(type);
+        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
+        if (!buft) {
+            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
+        }
+
+        ggml_context * ctx = get_ctx(buft);
+        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
+
+        model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;
+
+        return tensor;
+    };
+
+    // prepare tensors for the weights
     {
-        const auto & hparams = model.hparams;
-
-        const int n_audio_layer = hparams.n_audio_layer;
-        const int n_text_layer  = hparams.n_text_layer;
-
-        const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;
-
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+        ggml_init_params params = {
+            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
             /*.mem_buffer =*/ nullptr,
             /*.no_alloc   =*/ true,
         };
 
-        model.ctx = ggml_init(params);
-        if (!model.ctx) {
-            WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__);
-            return false;
-        }
-    }
-
-    // prepare tensors for the weights
-    {
-        auto & ctx = model.ctx;
-
+        ggml_context * ctx = ggml_init(params);
+
         const auto & hparams = model.hparams;
@ -1638,189 +1770,108 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
model.layers_decoder.resize(n_text_layer);
|
model.layers_decoder.resize(n_text_layer);
|
||||||
|
|
||||||
// encoder
|
// encoder
|
||||||
{
|
model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));
|
||||||
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
|
|
||||||
|
|
||||||
model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state);
|
model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
|
||||||
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
||||||
|
|
||||||
model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state);
|
model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
|
||||||
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
|
model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));
|
||||||
|
|
||||||
```diff
-            model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-            model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-            // map by name
-            model.tensors["encoder.positional_embedding"] = model.e_pe;
-
-            model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
-            model.tensors["encoder.conv1.bias"]   = model.e_conv_1_b;
-
-            model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
-            model.tensors["encoder.conv2.bias"]   = model.e_conv_2_b;
-
-            model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
-            model.tensors["encoder.ln_post.bias"]   = model.e_ln_b;
+            model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT,    ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
+            model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));

             for (int i = 0; i < n_audio_layer; ++i) {
                 auto & layer = model.layers_encoder[i];

-                layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
-                layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
-
-                layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
-                layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-                layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-
-                layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
-                layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
-
-                // map by name
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
+                layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+                layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+
+                layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
+                layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);
+
+                layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
+                layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+
+                layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+                layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+
+                layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+                layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+
+                layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+
+                layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+                layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
+
+                layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
+                layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS,   ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
             }
         }
```
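The hunk above replaces the hand-written `model.tensors[...]` string map with typed `(tensor, system, layer)` identifiers. The `create_tensor` helper itself is defined outside this hunk. Below is a minimal, self-contained sketch of what such a wrapper can look like: the `asr_tensor`/`asr_system` identifiers and the call signature are taken from the hunk, while `tensor_name` and the global registry are illustrative assumptions.

```cpp
#include <map>
#include <string>

#include "ggml.h"

// illustrative subset of the typed ids used in the hunk above
enum asr_tensor { ASR_TENSOR_MLP_LN_WEIGHT };
enum asr_system { ASR_SYSTEM_ENCODER, ASR_SYSTEM_DECODER, ASR_SYSTEM_CROSS };

static std::map<std::string, ggml_tensor *> g_tensors;

// hypothetical: derive the canonical checkpoint name from the typed ids,
// e.g. (ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, 3) -> "encoder.blocks.3.mlp_ln.weight"
static std::string tensor_name(asr_tensor type, asr_system system, int layer) {
    const char * sys    = (system == ASR_SYSTEM_ENCODER) ? "encoder" : "decoder";
    const char * suffix = (type == ASR_TENSOR_MLP_LN_WEIGHT) ? "mlp_ln.weight" : "unknown";
    if (layer < 0) {
        return std::string(sys) + "." + suffix;
    }
    return std::string(sys) + ".blocks." + std::to_string(layer) + "." + suffix;
}

// tag the freshly created tensor and register it for the loader
static ggml_tensor * create_tensor(asr_tensor type, asr_system system, ggml_tensor * t, int layer = -1) {
    const std::string name = tensor_name(type, system, layer);
    ggml_set_name(t, name.c_str());
    g_tensors[name] = t;
    return t;
}
```

The practical gain is that the field-to-checkpoint-name mapping lives in one table instead of being repeated per layer in both the encoder and decoder loops.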
```diff
         // decoder
         {
-            model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
-
-            model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
-
-            model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-            model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-            // map by name
-            model.tensors["decoder.positional_embedding"] = model.d_pe;
-
-            model.tensors["decoder.token_embedding.weight"] = model.d_te;
-
-            model.tensors["decoder.ln.weight"] = model.d_ln_w;
-            model.tensors["decoder.ln.bias"]   = model.d_ln_b;
+            model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD,          ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));
+
+            model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));
+
+            model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
+            model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));

             for (int i = 0; i < n_text_layer; ++i) {
                 auto & layer = model.layers_decoder[i];

-                layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
-                layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
-
-                layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
-                layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
-                layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-                layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-
-                layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
-                layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
-
-                // map by name
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"]   = layer.mlp_ln_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"]   = layer.mlp_0_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"]   = layer.mlp_1_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"]   = layer.attn_ln_0_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"]   = layer.attn_q_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"]   = layer.attn_v_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"]   = layer.attn_ln_1_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"]   = layer.cross_attn_ln_0_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"]   = layer.cross_attn_q_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"]   = layer.cross_attn_v_b;
-
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
-                model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"]   = layer.cross_attn_ln_1_b;
+                layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+                layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
+                layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);
+
+                layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
+                layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+                layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+
+                layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS,   ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+                layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS,   ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS,   ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+
+                layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS,   ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
+
+                layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
+                layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS,   ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
             }
         }
+
+        ggml_free(ctx);
     }
```
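The decoder hunk applies the same pattern, with the cross-attention tensors tagged `ASR_SYSTEM_CROSS`; as before, the key projections carry no bias term. The allocation loop that follows iterates a `ctx_map` from buffer type to `ggml_context`, which is populated outside this hunk. A sketch of how such a map is commonly built, one no-alloc (metadata-only) context per buffer type, is below; the `ctx_for` helper name and the sizing are assumptions:

```cpp
#include <map>

#include "ggml.h"
#include "ggml-backend.h"

// return the metadata context for a buffer type, creating it on first use
static ggml_context * ctx_for(std::map<ggml_backend_buffer_type_t, ggml_context *> & ctx_map,
                              ggml_backend_buffer_type_t buft) {
    auto it = ctx_map.find(buft);
    if (it != ctx_map.end()) {
        return it->second;
    }

    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*1024, // room for ~1024 tensor headers (illustrative)
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // tensor data is allocated later, per buffer type
    };

    ggml_context * ctx = ggml_init(params);
    ctx_map[buft] = ctx;
    return ctx;
}
```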
```diff
         // allocate tensors in the backend buffers
-        model.buffer = ggml_backend_alloc_ctx_tensors_from_buft(model.ctx, whisper_default_buffer_type(wctx.params));
-        if (!model.buffer) {
-            WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__);
-            return false;
-        }
+        for (auto & p : ctx_map) {
+            ggml_backend_buffer_type_t buft = p.first;
+            ggml_context * ctx = p.second;
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+            if (buf) {
+                model.buffers.emplace_back(buf);

-        size_t size_main = ggml_backend_buffer_get_size(model.buffer);
-        WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(model.buffer), size_main / 1e6);
+                size_t size_main = ggml_backend_buffer_get_size(buf);
+                WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
+            }
+        }
```
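`ggml_backend_alloc_ctx_tensors_from_buft` backs every tensor of one context with a single buffer of the given type, so after this change the model can own several buffers (for example one in VRAM and one in host memory) instead of the single `model.buffer` it had before. A standalone usage sketch, using the CPU buffer type purely for illustration:

```cpp
#include <cstdio>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // create metadata only; data comes from the backend buffer
    };
    ggml_context * ctx = ggml_init(params);

    ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);

    ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);

    printf("%s: total size = %.2f MB\n", ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1e6);

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    return 0;
}
```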
```diff
         // load weights
         {
@@ -1883,11 +1934,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 return false;
             }

-            //ggml_backend_t backend = wctx.backend;
-
-            //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str());
-
-            if (ggml_backend_buffer_is_host(model.buffer)) {
+            if (ggml_backend_buffer_is_host(tensor->buffer)) {
                 // for the CPU and Metal backend, we can read directly into the tensor
                 loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                 BYTESWAP_TENSOR(tensor);
@@ -1900,7 +1947,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
             }

-            //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6);
             total_size += ggml_nbytes(tensor);
             model.n_loaded++;
         }
@@ -1915,7 +1961,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
             }
         }

-        ggml_backend_buffer_set_usage(model.buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        for (auto & buf : model.buffers) {
+            ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        }

         wctx.t_load_us = ggml_time_us() - t_start_us;
```
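For context, the branch these hunks touch reads as follows after the change: the host path reads straight into `tensor->data`, while the non-host path stages through `read_buf` and uploads with `ggml_backend_tensor_set`. This is a reconstruction from the context lines (`loader`, `read_buf`, and `BYTESWAP_TENSOR` come from the surrounding function):

```cpp
if (ggml_backend_buffer_is_host(tensor->buffer)) {
    // for the CPU and Metal backend, we can read directly into the tensor
    loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
    BYTESWAP_TENSOR(tensor);
} else {
    // for other backends: stage in host memory, then copy into the backend buffer
    read_buf.resize(ggml_nbytes(tensor));

    loader->read(loader->context, read_buf.data(), read_buf.size());

    ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}
```

The key fix is checking `tensor->buffer` rather than a single `model.buffer`, since tensors may now live in different buffers with different host visibility.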
```diff
@@ -3806,9 +3854,13 @@ void whisper_free_state(struct whisper_state * state) {

 void whisper_free(struct whisper_context * ctx) {
     if (ctx) {
-        ggml_free(ctx->model.ctx);
+        for (ggml_context * context : ctx->model.ctxs) {
+            ggml_free(context);
+        }

-        ggml_backend_buffer_free(ctx->model.buffer);
+        for (ggml_backend_buffer_t buf : ctx->model.buffers) {
+            ggml_backend_buffer_free(buf);
+        }

         whisper_free_state(ctx->state);
```
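The teardown mirrors the loader: the model now owns a vector of contexts and a vector of buffers, and both are released in full. The implied change to the model struct is sketched below; the field names `ctxs` and `buffers` are taken from the loops above, everything else is elided:

```cpp
#include <vector>

#include "ggml.h"
#include "ggml-backend.h"

struct whisper_model {
    // ...

    // one metadata context and one backend buffer per buffer type,
    // replacing the previous single `ctx` / `buffer` pair
    std::vector<ggml_context *>        ctxs;
    std::vector<ggml_backend_buffer_t> buffers;
};
```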