Compare commits

..

12 Commits

Author SHA1 Message Date
5801b8ac64 cuda : fix HIPBLAS build 2024-06-11 19:13:43 +03:00
99804b0f3e cuda : fix bounds check for src0 rows in MMVQ kernel (#2231)
* cuda : fix bounds check for src0 rows in MMVQ kernel

* Update ggml-cuda/mmvq.cu

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2024-06-11 17:39:01 +03:00
c55964c956 ci : fix CUDA builds (#2232) 2024-06-11 17:21:30 +03:00
20c542c713 whisper : auto-grow working areas for mel_calc_cuda (#2227)
* whisper : auto-grow working areas for mel_calc_cuda, fixes #2226

* whisper : only calculate mel spectrogram on GPU if audio is <= 5 min
2024-06-10 21:51:32 +03:00
c2bdb960cd whisper : free whisper_mel instances (#2220) 2024-06-10 11:00:15 +03:00
87acd6d629 whisper : whisper_state/backend fixes (#2217)
* whisper : fixes

* ci : WHISPER_CUBLAS -> WHISPER_CUDA
2024-06-06 18:51:36 +03:00
f842d31171 whisper : calculate mel spectrogram directly into a ggml_tensor (#2208)
* whisper : calculate mel spectrogram directly into a ggml_tensor

* whisper : remove unused temp buffer from state

* whisper : fix not initializing wstate.embd_enc
2024-06-06 16:20:46 +03:00
ffef323c4c whisper : add CUDA-specific computation mel spectrograms (#2206)
* whisper : use polymorphic class to calculate mel spectrogram

* whisper : add cuda-specific mel spectrogram calculation

* whisper : conditionally compile cufftGetErrorString to avoid warnings

* build : add new files to makefile

* ruby : add new files to conf script

* build : fix typo in makefile

* whisper : suppress cub warning for deprecated C++ std in whisper-mel-cuda
2024-06-04 09:32:23 +03:00
af5833e298 whisper : remove speed_up and phase_vocoder* functions (#2198)
* whisper : fix cast warning

* whisper : remove phase_vocoder functions, ref #2195

* whisper : remove speed_up from whisper_full_params, closes #2195
2024-05-31 11:37:29 +03:00
b87494bb8f readme : add conan badge (#2196)
* Add conan badge

* Fix markdown formating
2024-05-30 15:43:28 +03:00
ad130431aa readme : add install instructions for Conan (#2189) 2024-05-30 15:06:15 +03:00
e130b66642 whisper: use global cache for sin/cos vals and Hann window (#2194)
- also rename Hanning to Hann as it's named after Julius von Hann
 as per Wikipedia
2024-05-29 19:09:21 +03:00
29 changed files with 710 additions and 349 deletions

View File

@ -459,7 +459,7 @@ jobs:
path: build/bin/${{ matrix.build }} path: build/bin/${{ matrix.build }}
windows-cublas: windows-cublas:
runs-on: windows-latest runs-on: windows-2019
strategy: strategy:
matrix: matrix:
@ -498,7 +498,7 @@ jobs:
run: > run: >
cmake -S . -B ./build -A ${{ matrix.arch }} cmake -S . -B ./build -A ${{ matrix.arch }}
-DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-DWHISPER_CUBLAS=${{ matrix.cublas }} -DWHISPER_CUDA=${{ matrix.cublas }}
-DWHISPER_SDL2=${{ matrix.sdl2 }} -DWHISPER_SDL2=${{ matrix.sdl2 }}
- name: Build ${{ matrix.cuda-toolkit }} - name: Build ${{ matrix.cuda-toolkit }}

View File

@ -364,12 +364,12 @@ if (WHISPER_CUDA)
if (WHISPER_STATIC) if (WHISPER_STATIC)
if (WIN32) if (WIN32)
# As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library # As of 12.3.1 CUDA Tookit for Windows does not offer a static cublas library
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt CUDA::cufft)
else () else ()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static CUDA::cufft_static)
endif() endif()
else() else()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cufft)
endif() endif()
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver) set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} CUDA::cuda_driver)
@ -679,6 +679,10 @@ add_library(${TARGET}
whisper.cpp whisper.cpp
) )
if (WHISPER_CUDA)
target_sources(${TARGET} PRIVATE whisper-mel-cuda.cu)
endif()
include_directories ( include_directories (
. .
) )

View File

@ -286,8 +286,8 @@ ifdef WHISPER_CUDA
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lcufft -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
WHISPER_OBJ += ggml-cuda.o WHISPER_OBJ += ggml-cuda.o whisper-mel-cuda.o
WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu)) WHISPER_OBJ += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
NVCC = nvcc NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG) NVCCFLAGS = --forward-unknown-to-host-compiler -arch=$(CUDA_ARCH_FLAG)
@ -297,6 +297,9 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh) ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
whisper-mel-cuda.o: whisper-mel-cuda.cu whisper.h ggml.h ggml-backend.h whisper-mel.hpp whisper-mel-cuda.hpp
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif endif
ifdef WHISPER_HIPBLAS ifdef WHISPER_HIPBLAS
@ -404,7 +407,7 @@ ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o WHISPER_OBJ += ggml.o ggml-alloc.o ggml-backend.o ggml-quants.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h whisper.o: whisper.cpp whisper.h whisper-mel.hpp ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
ifndef WHISPER_COREML ifndef WHISPER_COREML

View File

@ -4,6 +4,7 @@
[![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions) [![Actions Status](https://github.com/ggerganov/whisper.cpp/workflows/CI/badge.svg)](https://github.com/ggerganov/whisper.cpp/actions)
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT)
[![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp)
[![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/)
Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126) Stable: [v1.6.2](https://github.com/ggerganov/whisper.cpp/releases/tag/v1.6.0) / [Roadmap | F.A.Q.](https://github.com/ggerganov/whisper.cpp/discussions/126)
@ -502,6 +503,16 @@ docker run -it --rm \
whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav" whisper.cpp:main "./main -m /models/ggml-base.bin -f ./samples/jfk.wav"
``` ```
## Installing with Conan
You can install pre-built binaries for whisper.cpp or build it from source using [Conan](https://conan.io/). Use the following command:
```
conan install --requires="whisper-cpp/[*]" --build=missing
```
For detailed instructions on how to use Conan, please refer to the [Conan documentation](https://docs.conan.io/2/).
## Limitations ## Limitations
- Inference only - Inference only
@ -710,7 +721,7 @@ The [main](examples/main) example provides support for output of karaoke-style m
currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script. currently pronounced word is highlighted. Use the `-wts` argument and run the generated bash script.
This requires to have `ffmpeg` installed. This requires to have `ffmpeg` installed.
Here are a few *"typical"* examples: Here are a few _"typical"_ examples:
```bash ```bash
./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts ./main -m ./models/ggml-base.en.bin -f ./samples/jfk.wav -owts

View File

@ -68,10 +68,6 @@ func (flags *Flags) GetOut() string {
return strings.ToLower(flags.Lookup("out").Value.String()) return strings.ToLower(flags.Lookup("out").Value.String())
} }
func (flags *Flags) IsSpeedup() bool {
return flags.Lookup("speedup").Value.String() == "true"
}
func (flags *Flags) IsTokens() bool { func (flags *Flags) IsTokens() bool {
return flags.Lookup("tokens").Value.String() == "true" return flags.Lookup("tokens").Value.String() == "true"
} }
@ -111,10 +107,6 @@ func (flags *Flags) SetParams(context whisper.Context) error {
fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration) fmt.Fprintf(flags.Output(), "Setting duration to %v\n", duration)
context.SetDuration(duration) context.SetDuration(duration)
} }
if flags.IsSpeedup() {
fmt.Fprintf(flags.Output(), "Setting speedup to true\n")
context.SetSpeedup(true)
}
if threads := flags.GetThreads(); threads != 0 { if threads := flags.GetThreads(); threads != 0 {
fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads) fmt.Fprintf(flags.Output(), "Setting threads to %d\n", threads)
context.SetThreads(threads) context.SetThreads(threads)
@ -146,7 +138,6 @@ func registerFlags(flag *Flags) {
flag.Duration("offset", 0, "Time offset") flag.Duration("offset", 0, "Time offset")
flag.Duration("duration", 0, "Duration of audio to process") flag.Duration("duration", 0, "Duration of audio to process")
flag.Uint("threads", 0, "Number of threads to use") flag.Uint("threads", 0, "Number of threads to use")
flag.Bool("speedup", false, "Enable speedup")
flag.Uint("max-len", 0, "Maximum segment length in characters") flag.Uint("max-len", 0, "Maximum segment length in characters")
flag.Uint("max-tokens", 0, "Maximum tokens per segment") flag.Uint("max-tokens", 0, "Maximum tokens per segment")
flag.Float64("word-thold", 0, "Maximum segment score") flag.Float64("word-thold", 0, "Maximum segment score")

View File

@ -47,10 +47,6 @@ func (p *Params) SetPrintTimestamps(v bool) {
p.print_timestamps = toBool(v) p.print_timestamps = toBool(v)
} }
func (p *Params) SetSpeedup(v bool) {
p.speed_up = toBool(v)
}
// Set language id // Set language id
func (p *Params) SetLanguage(lang int) error { func (p *Params) SetLanguage(lang int) error {
if lang == -1 { if lang == -1 {
@ -177,9 +173,6 @@ func (p *Params) String() string {
if p.token_timestamps { if p.token_timestamps {
str += " token_timestamps" str += " token_timestamps"
} }
if p.speed_up {
str += " speed_up"
}
return str + ">" return str + ">"
} }

View File

@ -76,11 +76,6 @@ func (context *context) SetTranslate(v bool) {
context.params.SetTranslate(v) context.params.SetTranslate(v)
} }
// Set speedup flag
func (context *context) SetSpeedup(v bool) {
context.params.SetSpeedup(v)
}
func (context *context) SetSplitOnWord(v bool) { func (context *context) SetSplitOnWord(v bool) {
context.params.SetSplitOnWord(v) context.params.SetSplitOnWord(v)
} }

View File

@ -41,7 +41,6 @@ type Context interface {
SetOffset(time.Duration) // Set offset SetOffset(time.Duration) // Set offset
SetDuration(time.Duration) // Set duration SetDuration(time.Duration) // Set duration
SetThreads(uint) // Set number of threads to use SetThreads(uint) // Set number of threads to use
SetSpeedup(bool) // Set speedup flag
SetSplitOnWord(bool) // Set split on word flag SetSplitOnWord(bool) // Set split on word flag
SetTokenThreshold(float32) // Set timestamp token probability threshold SetTokenThreshold(float32) // Set timestamp token probability threshold
SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold

View File

@ -20,7 +20,7 @@ public interface WhisperCppJnaLibrary extends Library {
* @return Whisper context on success, null on failure * @return Whisper context on success, null on failure
*/ */
Pointer whisper_init_from_file(String path_model); Pointer whisper_init_from_file(String path_model);
/** /**
* Provides default params which can be used with `whisper_init_from_file_with_params()` etc. * Provides default params which can be used with `whisper_init_from_file_with_params()` etc.
* Because this function allocates memory for the params, the caller must call either: * Because this function allocates memory for the params, the caller must call either:
@ -304,14 +304,6 @@ public interface WhisperCppJnaLibrary extends Library {
/** Language id associated with the provided state */ /** Language id associated with the provided state */
int whisper_full_lang_id_from_state(Pointer state); int whisper_full_lang_id_from_state(Pointer state);
/**
* Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
* The resulting spectrogram is stored inside the default state of the provided whisper context.
* @return 0 on success
*/
int whisper_pcm_to_mel_phase_vocoder(Pointer ctx, final float[] samples, int n_samples, int n_threads);
int whisper_pcm_to_mel_phase_vocoder_with_state(Pointer ctx, Pointer state, final float[] samples, int n_samples, int n_threads);
/** Get the start time of the specified segment. */ /** Get the start time of the specified segment. */
long whisper_full_get_segment_t0(Pointer ctx, int i_segment); long whisper_full_get_segment_t0(Pointer ctx, int i_segment);

View File

@ -129,14 +129,6 @@ public class WhisperFullParams extends Structure {
/** Maximum tokens per segment (0, default = no limit) */ /** Maximum tokens per segment (0, default = no limit) */
public int max_tokens; public int max_tokens;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public CBool speed_up;
/** Flag to speed up the audio by 2x using Phase Vocoder. (default = false) */
public void speedUp(boolean enable) {
speed_up = enable ? CBool.TRUE : CBool.FALSE;
}
/** Overwrite the audio context size (0 = use default). */ /** Overwrite the audio context size (0 = use default). */
public int audio_ctx; public int audio_ctx;
@ -321,7 +313,7 @@ public class WhisperFullParams extends Structure {
return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate", return Arrays.asList("strategy", "n_threads", "n_max_text_ctx", "offset_ms", "duration_ms", "translate",
"no_context", "single_segment", "no_timestamps", "no_context", "single_segment", "no_timestamps",
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",

View File

@ -1,6 +1,7 @@
require 'mkmf' require 'mkmf'
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper-mel.hpp')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .")
system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .") system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml-impl.h')} .")

View File

@ -311,12 +311,6 @@ static VALUE ruby_whisper_params_get_split_on_word(VALUE self) {
static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, split_on_word, value) BOOL_PARAMS_SETTER(self, split_on_word, value)
} }
static VALUE ruby_whisper_params_get_speed_up(VALUE self) {
BOOL_PARAMS_GETTER(self, speed_up)
}
static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, speed_up, value)
}
static VALUE ruby_whisper_params_get_diarize(VALUE self) { static VALUE ruby_whisper_params_get_diarize(VALUE self) {
ruby_whisper_params *rwp; ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp); Data_Get_Struct(self, ruby_whisper_params, rwp);
@ -408,8 +402,6 @@ void Init_whisper() {
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1);
rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0);
rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1);
rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0);
rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1);

View File

@ -117,13 +117,6 @@ class TestWhisper < Test::Unit::TestCase
assert !@params.split_on_word assert !@params.split_on_word
end end
def test_speed_up
@params.speed_up = true
assert @params.speed_up
@params.speed_up = false
assert !@params.speed_up
end
def test_whisper def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new params = Whisper::Params.new

View File

@ -25,7 +25,6 @@ struct whisper_params {
float entropy_thold = 2.4f; float entropy_thold = 2.4f;
float logprob_thold = -1.0f; float logprob_thold = -1.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool diarize = false; bool diarize = false;
bool output_txt = false; bool output_txt = false;
@ -232,8 +231,6 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.greedy.best_of = params.best_of; wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size; wparams.beam_search.beam_size = params.beam_size;

View File

@ -38,7 +38,6 @@ struct whisper_params {
grammar_parser::parse_state grammar_parsed; grammar_parser::parse_state grammar_parsed;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -76,7 +75,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -115,7 +113,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -165,7 +162,6 @@ std::string transcribe(
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.temperature = 0.4f; wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f; wparams.temperature_inc = 1.0f;
@ -371,7 +367,6 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.prompt_tokens = k_tokens.data(); wparams.prompt_tokens = k_tokens.data();
wparams.prompt_n_tokens = k_tokens.size(); wparams.prompt_n_tokens = k_tokens.size();

View File

@ -185,7 +185,7 @@ private:
// It is assumed that PCM data is normalized to a range from -1 to 1 // It is assumed that PCM data is normalized to a range from -1 to 1
bool write_audio(const float * data, size_t length) { bool write_audio(const float * data, size_t length) {
for (size_t i = 0; i < length; ++i) { for (size_t i = 0; i < length; ++i) {
const int16_t intSample = data[i] * 32767; const int16_t intSample = int16_t(data[i] * 32767);
file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t)); file.write(reinterpret_cast<const char *>(&intSample), sizeof(int16_t));
dataSize += sizeof(int16_t); dataSize += sizeof(int16_t);
} }

View File

@ -26,7 +26,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -70,7 +69,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -102,7 +100,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -184,7 +181,6 @@ json unguided_transcription(struct whisper_context * ctx, audio_async &audio, js
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.suppress_non_speech_tokens = true; wparams.suppress_non_speech_tokens = true;
// run the transformer and a single decoding pass // run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
@ -223,7 +219,6 @@ json guided_transcription(struct whisper_context * ctx, audio_async &audio, cons
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
// TODO: Do some time testing. Does an overly long prompt slow down processing? // TODO: Do some time testing. Does an overly long prompt slow down processing?
// Set up command sets/precompute prompts // Set up command sets/precompute prompts

View File

@ -47,7 +47,6 @@ struct whisper_params {
float temperature = 0.0f; float temperature = 0.0f;
float temperature_inc = 0.2f; float temperature_inc = 0.2f;
bool speed_up = false;
bool debug_mode = false; bool debug_mode = false;
bool translate = false; bool translate = false;
bool detect_language = false; bool detect_language = false;
@ -138,7 +137,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); }
else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -206,7 +204,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature);
fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
@ -1106,7 +1103,6 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word; wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode; wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

View File

@ -61,7 +61,6 @@ struct whisper_params {
float temperature = 0.00f; float temperature = 0.00f;
float temperature_inc = 0.20f; float temperature_inc = 0.20f;
bool speed_up = false;
bool debug_mode = false; bool debug_mode = false;
bool translate = false; bool translate = false;
bool detect_language = false; bool detect_language = false;
@ -112,7 +111,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
@ -159,7 +157,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
// else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
@ -768,7 +765,6 @@ int main(int argc, char ** argv) {
wparams.split_on_word = params.split_on_word; wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.debug_mode = params.debug_mode; wparams.debug_mode = params.debug_mode;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

View File

@ -27,7 +27,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool no_fallback = false; bool no_fallback = false;
bool print_special = false; bool print_special = false;
@ -62,7 +61,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
@ -100,7 +98,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
@ -314,7 +311,6 @@ int main(int argc, char ** argv) {
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.tdrz_enable = params.tinydiarize; // [TDRZ]

View File

@ -59,7 +59,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -100,7 +99,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); } else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -149,7 +147,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers); fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -205,7 +202,6 @@ std::string transcribe(
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";

View File

@ -26,7 +26,6 @@ struct whisper_params {
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -60,7 +59,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
@ -96,7 +94,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -132,7 +129,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return ""; return "";

View File

@ -26,7 +26,6 @@ struct whisper_params {
float grammar_penalty = 100.0f; float grammar_penalty = 100.0f;
bool speed_up = false;
bool translate = false; bool translate = false;
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
@ -57,7 +56,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
@ -89,7 +87,6 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }

View File

@ -75,7 +75,7 @@ static __global__ void mul_mat_vec_q(
tmp[j][i] = warp_reduce_sum(tmp[j][i]); tmp[j][i] = warp_reduce_sum(tmp[j][i]);
} }
if (threadIdx.x < rows_per_cuda_block) { if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
} }
} }

363
whisper-mel-cuda.cu Normal file
View File

@ -0,0 +1,363 @@
#define CUB_IGNORE_DEPRECATED_CPP_DIALECT
#include "whisper-mel-cuda.hpp"
#include "whisper.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cufft.h>
#include <cublas_v2.h>
#include <cuComplex.h>
#include <cub/device/device_reduce.cuh>
#include <device_launch_parameters.h>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4324) // added padding
#endif
#ifndef NDEBUG
# define DO_CHECKS 1
#else
# define DO_CHECKS 0
#endif
namespace {
#if DO_CHECKS
const char* cufftGetErrorString(cufftResult_t res) {
switch (res) {
case CUFFT_SUCCESS: return "The cuFFT operation was successful";
case CUFFT_INVALID_PLAN: return "cuFFT was passed an invalid plan handle";
case CUFFT_ALLOC_FAILED: return "cuFFT failed to allocate GPU or CPU memory";
case CUFFT_INVALID_TYPE: return "No longer used";
case CUFFT_INVALID_VALUE: return "User specified an invalid pointer or parameter";
case CUFFT_INTERNAL_ERROR: return "Driver or internal cuFFT library error";
case CUFFT_EXEC_FAILED: return "Failed to execute an FFT on the GPU";
case CUFFT_SETUP_FAILED: return "The cuFFT library failed to initialize";
case CUFFT_INVALID_SIZE: return "User specified an invalid transform size";
case CUFFT_UNALIGNED_DATA: return "No longer used";
case CUFFT_INCOMPLETE_PARAMETER_LIST: return "Missing parameters in call";
case CUFFT_INVALID_DEVICE: return "Execution of a plan was on different GPU than plan creation";
case CUFFT_PARSE_ERROR: return "Internal plan database error";
case CUFFT_NO_WORKSPACE: return "No workspace has been provided prior to plan execution";
case CUFFT_NOT_IMPLEMENTED: return "Function does not implement functionality for parameters given.";
case CUFFT_LICENSE_ERROR: return "Used in previous versions.";
case CUFFT_NOT_SUPPORTED: return "Operation is not supported for parameters given.";
default: return "Unknown error";
}
}
# define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
fprintf(stderr, "%s %s:%d - %s\n", #err, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
#else
# define CUDA_CHECK_GEN(err, success, error_fn) err
#endif
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublasGetStatusString)
#define CUFFT_CHECK(err) CUDA_CHECK_GEN(err, CUFFT_SUCCESS, cufftGetErrorString)
__global__ void k_fill_stft_input(
const float * padded_samples,
const int n_frames,
const float * hann_window,
float * stft_in
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT) return;
auto line = padded_samples + y * WHISPER_HOP_LENGTH;
auto outLine = stft_in + y * WHISPER_N_FFT;
outLine[x] = line[x] * hann_window[x];
}
__global__ void k_calc_magnitudes(
const cuComplex* stft_out,
const int n_frames,
float * magnitudes
) {
auto y = blockIdx.y * blockDim.y + threadIdx.y;
// if (y >= n_frames) return;
auto x = blockIdx.x * blockDim.x + threadIdx.x;
// if (x >= WHISPER_N_FFT_HALF) return;
auto idx = y * WHISPER_N_FFT_HALF + x;
auto r = stft_out[idx].x;
auto i = stft_out[idx].y;
magnitudes[idx] = r * r + i * i;
}
__global__ void k_calc_log_mel(
const float * mel_data,
const int n_mel,
const float * max_val,
float * log_mel
) {
auto x = blockIdx.x * blockDim.x + threadIdx.x;
if (x >= n_mel) return;
float val = mel_data[x];
constexpr float e = 1e-10f;
if (val < e) val = e;
val = log10(val);
const float max = log10(*max_val) - 8.f;
if (val < max) val = max;
log_mel[x] = (val + 4) / 4;
}
void fill_stft_input(
const float * padded_samples,
int n_frames,
const float * hann_window,
float * stft_in,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT, 1);
dim3 grid(1, n_frames);
k_fill_stft_input<<<grid, block, 0, stream>>>(padded_samples, n_frames, hann_window, stft_in);
}
void calc_magnitudes(
const cuComplex* stft_out,
int n_frames,
float * magnitudes,
cudaStream_t stream
) {
dim3 block(WHISPER_N_FFT_HALF, 1);
dim3 grid(1, n_frames);
k_calc_magnitudes<<<grid, block, 0, stream>>>(stft_out, n_frames, magnitudes);
}
constexpr auto LOG_MEL_PREFIX_SIZE = 256;
void calc_log_mel(
const float * mel_data,
int n_mel,
void * tempStorage,
int tempStorageSize,
float * log_mel,
cudaStream_t stream
) {
float * max_val = reinterpret_cast<float *>(tempStorage);
void * maxTemp = reinterpret_cast<char*>(tempStorage) + LOG_MEL_PREFIX_SIZE;
size_t nbytes = size_t(tempStorageSize - LOG_MEL_PREFIX_SIZE);
cub::DeviceReduce::Max(maxTemp, nbytes, mel_data, max_val, n_mel, stream);
int block = 256;
int grid = (n_mel + block - 1) / block;
k_calc_log_mel<<<grid, block, 0, stream>>>(mel_data, n_mel, max_val, log_mel);
}
class mel_calc_cuda : public whisper_mel_calc {
const int m_n_mel;
ggml_backend_t m_backend = nullptr;
cudaStream_t m_stream = nullptr;
cublasHandle_t m_cublas_handle = nullptr;
float * m_hann_window = nullptr;
float * m_filters = nullptr;
// max samples for which we have allocated memory for the temp working areas below (cufft, log_mel)
int m_n_max_samples = 0;
size_t m_cufft_workspace_size = 0;
void * m_cufft_workspace = nullptr;
size_t m_log_mel_temp_storage_size = 0;
void * m_log_mel_temp_storage = nullptr;
public:
mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
: m_n_mel(filters.n_mel)
, m_backend(backend)
{
if (filters.n_fft != WHISPER_N_FFT_HALF) {
throw std::invalid_argument("MelFilters n_frames must be WHISPER_N_FFT_HALF");
}
assert(filters.data.size() == filters.n_mel * WHISPER_N_FFT_HALF);
CUDA_CHECK(cudaStreamCreate(&m_stream));
CUBLAS_CHECK(cublasCreate(&m_cublas_handle));
CUBLAS_CHECK(cublasSetMathMode(m_cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
CUBLAS_CHECK(cublasSetStream(m_cublas_handle, m_stream));
// create Hann window
{
auto hw = whisper_mel_calc::hann_window();
CUDA_CHECK(cudaMallocAsync(&m_hann_window, hw.len * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_hann_window, hw.data, hw.len * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// fill filters
{
auto& f = filters.data;
CUDA_CHECK(cudaMallocAsync(&m_filters, f.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(m_filters, f.data(), f.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
}
// preallocate working areas enough for the most common cases (<= 30s)
ensure_working_areas(WHISPER_N_SAMPLES);
}
~mel_calc_cuda() {
CUDA_CHECK(cudaStreamSynchronize(m_stream));
CUDA_CHECK(cudaStreamDestroy(m_stream));
CUDA_CHECK(cudaFree(m_hann_window));
CUDA_CHECK(cudaFree(m_cufft_workspace));
CUDA_CHECK(cudaFree(m_filters));
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
}
void ensure_working_areas(int n_samples) {
if (n_samples <= m_n_max_samples) {
return;
}
const auto max_padded_samples = n_samples + WHISPER_N_SAMPLES + WHISPER_N_FFT;
const auto max_frames = 1 + (max_padded_samples - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
// cufft workspace
{
if (m_cufft_workspace) {
CUDA_CHECK(cudaFree(m_cufft_workspace));
m_cufft_workspace_size = 0;
m_cufft_workspace = nullptr;
}
CUFFT_CHECK(cufftEstimate1d(WHISPER_N_FFT, CUFFT_R2C, max_frames, &m_cufft_workspace_size));
CUDA_CHECK(cudaMallocAsync(&m_cufft_workspace, m_cufft_workspace_size, m_stream));
}
// device reduce working area
{
if (m_log_mel_temp_storage) {
CUDA_CHECK(cudaFree(m_log_mel_temp_storage));
m_log_mel_temp_storage_size = 0;
m_log_mel_temp_storage = nullptr;
}
const auto max_mels = 160;
size_t nbytes = 0;
float* temp = nullptr;
cub::DeviceReduce::Max(nullptr, nbytes, temp, temp, max_frames * max_mels);
m_log_mel_temp_storage_size = nbytes + LOG_MEL_PREFIX_SIZE;
CUDA_CHECK(cudaMallocAsync(&m_log_mel_temp_storage, m_log_mel_temp_storage_size, m_stream));
}
m_n_max_samples = n_samples;
}
virtual whisper_mel calculate(whisper_span<const float> samples, int /*n_threads*/) override {
ensure_working_areas(samples.len);
const size_t mirror_pad = WHISPER_N_FFT / 2;
const size_t padded_size = samples.len + WHISPER_N_SAMPLES + WHISPER_N_FFT;
// pad
std::vector<float> padded_samples(padded_size);
std::reverse_copy(samples.data + 1, samples.data + 1 + mirror_pad, padded_samples.begin()); // reflect
std::copy(samples.data, samples.data + samples.len, padded_samples.begin() + mirror_pad); // copy
// fill the rest of the data
// it should canonically be mirrored at the end as well,
// but we just assume the last MEL_FRAME_SIZE/2 samples are zeros
std::fill(padded_samples.begin() + mirror_pad + samples.len, padded_samples.end(), 0.f);
const auto n_frames = 1 + (padded_samples.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
float * cu_padded_samples = nullptr;
CUDA_CHECK(cudaMallocAsync(&cu_padded_samples, padded_samples.size() * sizeof(float), m_stream));
CUDA_CHECK(cudaMemcpyAsync(cu_padded_samples, padded_samples.data(), padded_samples.size() * sizeof(float), cudaMemcpyHostToDevice, m_stream));
float * stft_in = nullptr; // contiguous buffer for stft input
CUDA_CHECK(cudaMallocAsync(&stft_in, n_frames * WHISPER_N_FFT * sizeof(float), m_stream));
fill_stft_input(cu_padded_samples, int(n_frames), m_hann_window, stft_in, m_stream);
cufftComplex* stft_out;
CUDA_CHECK(cudaMallocAsync(&stft_out, n_frames * WHISPER_N_FFT_HALF * sizeof(cufftComplex), m_stream));
cufftHandle plan;
CUFFT_CHECK(cufftCreate(&plan));
CUFFT_CHECK(cufftSetAutoAllocation(plan, 0));
{
size_t waSize;
CUFFT_CHECK(cufftMakePlan1d(plan, WHISPER_N_FFT, CUFFT_R2C, int(n_frames), &waSize));
assert(waSize <= m_cufft_workspace_size);
CUFFT_CHECK(cufftSetWorkArea(plan, m_cufft_workspace));
CUFFT_CHECK(cufftSetStream(plan, m_stream));
}
CUFFT_CHECK(cufftExecR2C(plan, stft_in, stft_out));
const auto n_mag_frames = n_frames - 1; // drop last frame
float * magnitudes;
CUDA_CHECK(cudaMallocAsync(&magnitudes, n_mag_frames * WHISPER_N_FFT_HALF * sizeof(float), m_stream));
calc_magnitudes(stft_out, int(n_mag_frames), magnitudes, m_stream);
float * mel_data = nullptr;
CUDA_CHECK(cudaMallocAsync(&mel_data, m_n_mel * n_mag_frames * sizeof(float), m_stream));
const float fone = 1.0f, fzero = 0.0f;
CUBLAS_CHECK(cublasSgemm(m_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
int(n_mag_frames), m_n_mel, WHISPER_N_FFT_HALF,
&fone,
magnitudes, WHISPER_N_FFT_HALF,
m_filters, WHISPER_N_FFT_HALF,
&fzero,
mel_data, int(n_mag_frames)));
whisper_mel ret;
// Calculate semi-padded sample length to ensure compatibility
int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
calc_log_mel(
mel_data, int(m_n_mel * n_mag_frames),
m_log_mel_temp_storage , int(m_log_mel_temp_storage_size),
log_mels, m_stream);
CUDA_CHECK(cudaStreamSynchronize(m_stream));
// cleanup
CUFFT_CHECK(cufftDestroy(plan));
CUDA_CHECK(cudaFreeAsync(mel_data, m_stream));
CUDA_CHECK(cudaFreeAsync(magnitudes, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_out, m_stream));
CUDA_CHECK(cudaFreeAsync(stft_in, m_stream));
CUDA_CHECK(cudaFreeAsync(cu_padded_samples, m_stream));
return ret;
}
};
}
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters) {
if (filters.n_fft != WHISPER_N_FFT_HALF) {
return nullptr;
}
return new mel_calc_cuda(backend, filters);
}

3
whisper-mel-cuda.hpp Normal file
View File

@ -0,0 +1,3 @@
#include "whisper-mel.hpp"
whisper_mel_calc * whisper_mel_calc_create_cuda(ggml_backend_t backend, const whisper_filters & filters);

34
whisper-mel.hpp Normal file
View File

@ -0,0 +1,34 @@
#pragma once
#include "ggml-backend.h"
#include <vector>
struct whisper_mel {
int n_len_org = 0;
ggml_context * ctx = nullptr;
ggml_tensor * tensor = nullptr;
ggml_backend_buffer_t buffer = nullptr;
};
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
void whisper_mel_free(whisper_mel & mel);
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
template <typename T>
struct whisper_span {
T * data;
int len;
};
struct whisper_mel_calc {
virtual ~whisper_mel_calc();
virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) = 0;
static whisper_span<const float> hann_window();
};

View File

@ -10,6 +10,7 @@
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
#include "ggml-cuda.h" #include "ggml-cuda.h"
#include "whisper-mel-cuda.hpp"
#endif #endif
#ifdef GGML_USE_SYCL #ifdef GGML_USE_SYCL
@ -24,6 +25,8 @@
#include "ggml-alloc.h" #include "ggml-alloc.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "whisper-mel.hpp"
#include <atomic> #include <atomic>
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
@ -380,21 +383,6 @@ static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head); static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
struct whisper_mel {
int n_len;
int n_len_org;
int n_mel;
std::vector<float> data;
};
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
std::vector<float> data;
};
struct whisper_vocab { struct whisper_vocab {
using id = int32_t; using id = int32_t;
using token = std::string; using token = std::string;
@ -813,6 +801,8 @@ struct whisper_state {
whisper_kv_cache kv_pad; whisper_kv_cache kv_pad;
whisper_mel mel; whisper_mel mel;
whisper_mel_calc * mel_calc = nullptr;
whisper_mel_calc * mel_calc_fallback = nullptr;
whisper_batch batch; whisper_batch batch;
@ -833,7 +823,6 @@ struct whisper_state {
struct ggml_tensor * embd_enc = nullptr; struct ggml_tensor * embd_enc = nullptr;
// helpers for GPU offloading // helpers for GPU offloading
std::vector<float> inp_mel;
std::vector<float> inp_mask; std::vector<float> inp_mask;
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
@ -904,7 +893,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
BYTESWAP_VALUE(dest); BYTESWAP_VALUE(dest);
} }
static bool kv_cache_init( static bool whisper_kv_cache_init(
struct whisper_kv_cache & cache, struct whisper_kv_cache & cache,
ggml_backend_t backend, ggml_backend_t backend,
ggml_type wtype, ggml_type wtype,
@ -947,7 +936,7 @@ static bool kv_cache_init(
return true; return true;
} }
static void kv_cache_free(struct whisper_kv_cache & cache) { static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
ggml_free(cache.ctx); ggml_free(cache.ctx);
ggml_backend_buffer_free(cache.buffer); ggml_backend_buffer_free(cache.buffer);
cache.ctx = nullptr; cache.ctx = nullptr;
@ -1261,9 +1250,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
} }
#endif #endif
GGML_UNUSED(params);
if (backend_gpu) { if (backend_gpu) {
return backend_gpu; return backend_gpu;
} }
return ggml_backend_cpu_init(); return ggml_backend_cpu_init();
} }
@ -1825,7 +1817,8 @@ static bool whisper_encode_external(const whisper_state & wstate) {
static struct ggml_cgraph * whisper_build_graph_conv( static struct ggml_cgraph * whisper_build_graph_conv(
whisper_context & wctx, whisper_context & wctx,
whisper_state & wstate) { whisper_state & wstate,
const int mel_offset) {
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@ -1844,9 +1837,32 @@ static struct ggml_cgraph * whisper_build_graph_conv(
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); ggml_tensor * mel_inp = wstate.mel.tensor;
ggml_set_name(mel, "mel"); ggml_tensor * mel;
ggml_set_input(mel); if (mel_inp) {
const int n_len = int(mel_inp->ne[0]);
const int out_s = 2 * n_ctx;
const int i0 = std::min(mel_offset, n_len);
const int i1 = std::min(mel_offset + out_s, n_len);
const int mel_s = i1 - i0;
assert(mel_inp->type == GGML_TYPE_F32);
assert(mel_inp->ne[1] == n_mels);
ggml_tensor * cur = ggml_view_2d(ctx0, mel_inp, out_s, n_mels, mel_inp->nb[1], ggml_row_size(mel_inp->type, i0));
if (mel_s < out_s) {
mel = ggml_pad(ctx0, cur, out_s - mel_s, 0, 0, 0);
}
else {
mel = ggml_cont(ctx0, cur);
}
}
else {
// just create some tensor so that the graph/buffer size estimation is correct
mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels);
}
ggml_set_name(mel, "mel"); // used with external encoding
struct ggml_tensor * cur = nullptr; struct ggml_tensor * cur = nullptr;
@ -2228,45 +2244,21 @@ static bool whisper_encode_internal(
{ {
auto & alloc = wstate.alloc_conv.alloc; auto & alloc = wstate.alloc_conv.alloc;
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate); ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate, mel_offset);
if (!ggml_gallocr_alloc_graph(alloc, gf)) { if (!ggml_gallocr_alloc_graph(alloc, gf)) {
// should never happen as we pre-allocate the memory // should never happen as we pre-allocate the memory
return false; return false;
} }
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
return false;
// set the input
{
const auto & mel_inp = wstate.mel;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
assert(mel->type == GGML_TYPE_F32);
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
wstate.inp_mel.resize(ggml_nelements(mel));
float * dst = wstate.inp_mel.data();
memset(dst, 0, ggml_nbytes(mel));
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
for (int j = 0; j < mel_inp.n_mel; ++j) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
} }
if (!whisper_encode_external(wstate)) { if (whisper_encode_external(wstate)) {
if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
return false; assert(mel->ne[1] == wctx.model.hparams.n_mels);
} GGML_UNUSED(mel);
} else {
#if defined(WHISPER_USE_COREML) #if defined(WHISPER_USE_COREML)
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
#elif defined(WHISPER_USE_OPENVINO) #elif defined(WHISPER_USE_OPENVINO)
@ -2857,20 +2849,70 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
} }
#define SIN_COS_N_COUNT WHISPER_N_FFT #define SIN_COS_N_COUNT WHISPER_N_FFT
static float sin_vals[SIN_COS_N_COUNT]; namespace {
static float cos_vals[SIN_COS_N_COUNT]; struct whisper_global_cache {
// In FFT, we frequently use sine and cosine operations with the same values.
// We can use precalculated values to speed up the process.
float sin_vals[SIN_COS_N_COUNT];
float cos_vals[SIN_COS_N_COUNT];
// In FFT, we frequently use sine and cosine operations with the same values. // Hann window (Use cosf to eliminate difference)
// We can use precalculated values to speed up the process. // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
static void fill_sin_cos_table() { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
static bool is_filled = false; float hann_window[WHISPER_N_FFT];
if (is_filled) return;
for (int i = 0; i < SIN_COS_N_COUNT; i++) { whisper_global_cache() {
double theta = (2*M_PI*i)/SIN_COS_N_COUNT; fill_sin_cos_table();
sin_vals[i] = sinf(theta); fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
cos_vals[i] = cosf(theta);
} }
is_filled = true;
void fill_sin_cos_table() {
for (int i = 0; i < SIN_COS_N_COUNT; i++) {
double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
sin_vals[i] = sinf(theta);
cos_vals[i] = cosf(theta);
}
}
void fill_hann_window(int length, bool periodic, float * output) {
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
}
}
} global_cache;
}
// Mel spectrogram
void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
mel.n_len_org = n_len_org;
assert(!mel.ctx);
mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
auto alloc = ggml_tallocr_new(mel.buffer);
ggml_tallocr_alloc(&alloc, mel.tensor);
}
void whisper_mel_free(whisper_mel & mel) {
ggml_free(mel.ctx);
ggml_backend_buffer_free(mel.buffer);
mel.n_len_org = 0;
mel.ctx = nullptr;
mel.tensor = nullptr;
mel.buffer = nullptr;
}
whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
whisper_span<const float> whisper_mel_calc::hann_window() {
return {global_cache.hann_window, WHISPER_N_FFT};
} }
// naive Discrete Fourier Transform // naive Discrete Fourier Transform
@ -2888,8 +2930,8 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
for (int n = 0; n < N; n++) { for (int n = 0; n < N; n++) {
int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
re += in[n]*cos_vals[idx]; // cos(t) re += in[n]*global_cache.cos_vals[idx]; // cos(t)
im -= in[n]*sin_vals[idx]; // sin(t) im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
} }
out[k*2 + 0] = re; out[k*2 + 0] = re;
@ -2940,8 +2982,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
const int sin_cos_step = SIN_COS_N_COUNT / N; const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < N/2; k++) { for (int k = 0; k < N/2; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = cos_vals[idx]; // cos(t) float re = global_cache.cos_vals[idx]; // cos(t)
float im = -sin_vals[idx]; // sin(t) float im = -global_cache.sin_vals[idx]; // sin(t)
float re_odd = odd_fft[2*k + 0]; float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1]; float im_odd = odd_fft[2*k + 1];
@ -2954,24 +2996,20 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
} }
} }
static bool hann_window(int length, bool periodic, std::vector<float> & output) { namespace {
if (output.size() < static_cast<size_t>(length)) {
output.resize(length);
}
int offset = -1;
if (periodic) {
offset = 0;
}
for (int i = 0; i < length; i++) {
output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset)));
}
return true; struct whisper_mel_data {
} int n_len;
int n_len_org;
int n_mel;
float * data;
};
static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> & hann, const std::vector<float> & samples, void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
int n_samples, int frame_size, int frame_step, int n_threads, int n_samples, int n_threads,
const whisper_filters & filters, whisper_mel & mel) { const whisper_filters & filters, whisper_mel_data & mel) {
const auto frame_size = WHISPER_N_FFT;
const auto frame_step = WHISPER_HOP_LENGTH;
std::vector<float> fft_in(frame_size, 0.0); std::vector<float> fft_in(frame_size, 0.0);
std::vector<float> fft_out(2 * frame_size); std::vector<float> fft_out(2 * frame_size);
int n_fft = filters.n_fft; int n_fft = filters.n_fft;
@ -2984,7 +3022,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
const int offset = i * frame_step; const int offset = i * frame_step;
// apply Hanning window (~10% faster) // apply Hann window (~10% faster)
for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
fft_in[j] = hann[j] * samples[offset + j]; fft_in[j] = hann[j] * samples[offset + j];
} }
@ -3036,101 +3074,109 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
} }
} }
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 struct mel_calc_cpu : public whisper_mel_calc {
static bool log_mel_spectrogram( ggml_backend_t m_backend;
whisper_state & wstate, const whisper_filters & m_filters;
const float * samples, mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}
const int n_samples,
const int /*sample_rate*/,
const int frame_size,
const int frame_step,
const int n_mel,
const int n_threads,
const whisper_filters & filters,
const bool debug,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();
// Hanning window (Use cosf to eliminate difference) // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
// ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html whisper_mel calculate(whisper_span<const float> ssamples, int n_threads) override {
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 // Hann window
std::vector<float> hann; const float * hann = global_cache.hann_window;
hann_window(frame_size, true, hann);
// Calculate the length of padding
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
int64_t stage_2_pad = WHISPER_N_FFT / 2;
// Calculate the length of padding const int n_samples = int(ssamples.len);
int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; const float * samples = ssamples.data;
int64_t stage_2_pad = frame_size / 2;
// Initialize a vector and copy data from C array to it. // Initialize a vector and copy data from C array to it.
std::vector<float> samples_padded; std::vector<float> samples_padded;
samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
// pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
// reflective pad 200 samples at the beginning of audio // reflective pad 200 samples at the beginning of audio
std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
mel.n_mel = n_mel; whisper_mel_data mel;
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 mel.n_mel = m_filters.n_mel;
// Calculate number of frames + remove the last frame // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
mel.n_len = (samples_padded.size() - frame_size) / frame_step; // Calculate number of frames + remove the last frame
// Calculate semi-padded sample length to ensure compatibility mel.n_len = (samples_padded.size() - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; // Calculate semi-padded sample length to ensure compatibility
mel.data.resize(mel.n_mel * mel.n_len); mel.n_len_org = 1 + (n_samples + stage_2_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
std::vector<float> host_mel_data;
{ whisper_mel ret;
std::vector<std::thread> workers(n_threads - 1); whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
for (int iw = 0; iw < n_threads - 1; ++iw) { if (ggml_backend_buffer_is_host(ret.buffer)) {
workers[iw] = std::thread( mel.data = reinterpret_cast<float*>(ret.tensor->data);
log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, } else {
n_samples + stage_2_pad, frame_size, frame_step, n_threads, host_mel_data.resize(mel.n_len * mel.n_mel);
std::cref(filters), std::ref(mel)); mel.data = host_mel_data.data();
} }
// main thread {
log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); std::vector<std::thread> workers(n_threads - 1);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw] = std::thread(
log_mel_spectrogram_worker_thread, iw + 1, hann, samples_padded,
n_samples + stage_2_pad, n_threads,
std::cref(m_filters), std::ref(mel));
}
for (int iw = 0; iw < n_threads - 1; ++iw) { // main thread
workers[iw].join(); log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, n_threads, m_filters, mel);
for (int iw = 0; iw < n_threads - 1; ++iw) {
workers[iw].join();
}
} }
// clamping and normalization
double mmax = -1e20;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] > mmax) {
mmax = mel.data[i];
}
}
mmax -= 8.0;
for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
if (!host_mel_data.empty()) {
// the ret buffer is not host-accessible so we used this temporary buffer and now we need to upload it
ggml_backend_tensor_set(ret.tensor, host_mel_data.data(), 0, ggml_nbytes(ret.tensor));
}
return ret;
} }
};
}
// clamping and normalization whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters) {
double mmax = -1e20; #if defined(GGML_USE_CUDA) && !defined(GGML_USE_HIPBLAS)
for (int i = 0; i < mel.n_mel*mel.n_len; i++) { if (ggml_backend_is_cuda(backend)) {
if (mel.data[i] > mmax) { auto ret = whisper_mel_calc_create_cuda(backend, filters);
mmax = mel.data[i]; // run a warmup to avoid the first kernel launch overhead (thus we get the best perf even on the first run)
} const float warmup[256] = {0};
} ret->calculate({warmup, 256}, 1);
return ret;
mmax -= 8.0; } else
#endif
for (int i = 0; i < mel.n_mel*mel.n_len; i++) { return new mel_calc_cpu(backend, filters);
if (mel.data[i] < mmax) {
mel.data[i] = mmax;
}
mel.data[i] = (mel.data[i] + 4.0)/4.0;
}
wstate.t_mel_us += ggml_time_us() - t_start_us;
// Dump log_mel_spectrogram
if (debug) {
std::ofstream outFile("log_mel_spectrogram.json");
outFile << "[";
for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
outFile << mel.data[i] << ", ";
}
outFile << mel.data[mel.data.size() - 1] << "]";
outFile.close();
}
return true;
} }
// split text into tokens // split text into tokens
@ -3246,8 +3292,6 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) {
#endif #endif
struct whisper_state * whisper_init_state(whisper_context * ctx) { struct whisper_state * whisper_init_state(whisper_context * ctx) {
fill_sin_cos_table();
whisper_state * state = new whisper_state; whisper_state * state = new whisper_state;
state->backend = whisper_backend_init(ctx->params); state->backend = whisper_backend_init(ctx->params);
@ -3257,15 +3301,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
return nullptr; return nullptr;
} }
state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
// in theory, there can be a case where this is not enough, but in practice it should always be enough // in theory, there can be a case where this is not enough, but in practice it should always be enough
const int factor = 3; const int factor = 3;
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer, ctx->model.hparams.n_text_layer,
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3275,11 +3321,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
} }
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
ctx->model.hparams.n_text_state, ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer, ctx->model.hparams.n_text_layer,
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3289,11 +3335,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
} }
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
ctx->model.hparams.n_audio_state, ctx->model.hparams.n_audio_state,
1, 1,
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
} }
@ -3305,7 +3351,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// [EXPERIMENTAL] Token-level timestamps with DTW // [EXPERIMENTAL] Token-level timestamps with DTW
if (ctx->params.dtw_token_timestamps) { if (ctx->params.dtw_token_timestamps) {
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__); WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
whisper_free_state(state); whisper_free_state(state);
return nullptr; return nullptr;
@ -3348,9 +3394,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// conv allocator // conv allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
[&]() { [&]() {
return whisper_build_graph_conv(*ctx, *state); return whisper_build_graph_conv(*ctx, *state, 0);
}); });
if (!ok) { if (!ok) {
@ -3364,7 +3410,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// encoder allocator // encoder allocator
if (!whisper_encode_external(*state)) { if (!whisper_encode_external(*state)) {
bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
[&]() { [&]() {
return whisper_build_graph_encoder(*ctx, *state); return whisper_build_graph_encoder(*ctx, *state);
}); });
@ -3380,7 +3426,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// cross allocator // cross allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
[&]() { [&]() {
return whisper_build_graph_cross(*ctx, *state); return whisper_build_graph_cross(*ctx, *state);
}); });
@ -3396,7 +3442,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
// decoder allocator // decoder allocator
{ {
bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend, bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
[&]() { [&]() {
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
@ -3668,9 +3714,16 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa
void whisper_free_state(struct whisper_state * state) { void whisper_free_state(struct whisper_state * state) {
if (state) { if (state) {
kv_cache_free(state->kv_self); whisper_kv_cache_free(state->kv_self);
kv_cache_free(state->kv_cross); whisper_kv_cache_free(state->kv_cross);
kv_cache_free(state->kv_pad); whisper_kv_cache_free(state->kv_pad);
whisper_mel_free(state->mel);
delete state->mel_calc;
state->mel_calc = nullptr;
delete state->mel_calc_fallback;
state->mel_calc_fallback = nullptr;
#ifdef WHISPER_USE_COREML #ifdef WHISPER_USE_COREML
if (state->ctx_coreml != nullptr) { if (state->ctx_coreml != nullptr) {
@ -3729,11 +3782,37 @@ void whisper_free_params(struct whisper_full_params * params) {
} }
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { const int64_t t_start_us = ggml_time_us();
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1; whisper_mel_free(state->mel);
if (n_samples <= 5 * 60 * WHISPER_SAMPLE_RATE) {
// calculate mel spectrogram for lengths up to 5 minutes on the most optimal mel calculator
state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
} else {
// calcuate mel spectrogram for longer audios on the CPU
// 1. gpu calculations may use hundreds of megabytes of memory for longer audios so we're being conservative
// with our gpu demands
// 2. the time to transcribe audios this long will be dominated by the decoding time, so the mel calculation
// taking longer is not a major concern
if (!state->mel_calc_fallback) {
state->mel_calc_fallback = new mel_calc_cpu(state->backend, ctx->model.filters);
}
state->mel = state->mel_calc_fallback->calculate({samples, n_samples}, n_threads);
} }
state->t_mel_us += ggml_time_us() - t_start_us;
// Dump log_mel_spectrogram
//{
// auto& mel = state->mel;
// std::ofstream outFile("log_mel_spectrogram.json");
// outFile << "[";
// for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
// outFile << mel.data[i] << ", ";
// }
// outFile << mel.data[mel.data.size() - 1] << "]";
// outFile.close();
//}
return 0; return 0;
} }
@ -3741,30 +3820,6 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads);
} }
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
return 0;
}
// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads);
}
// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
// TODO
// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
// TODO
// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
// TODO
int whisper_set_mel_with_state( int whisper_set_mel_with_state(
struct whisper_context * ctx, struct whisper_context * ctx,
struct whisper_state * state, struct whisper_state * state,
@ -3776,12 +3831,10 @@ int whisper_set_mel_with_state(
return -1; return -1;
} }
state->mel.n_len = n_len; whisper_mel_free(state->mel);
state->mel.n_len_org = n_len; whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
state->mel.n_mel = n_mel;
state->mel.data.resize(n_len*n_mel); ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));
memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float));
return 0; return 0;
} }
@ -4665,7 +4718,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.split_on_word =*/ false, /*.split_on_word =*/ false,
/*.max_tokens =*/ 0, /*.max_tokens =*/ 0,
/*.speed_up =*/ false,
/*.debug_mode =*/ false, /*.debug_mode =*/ false,
/*.audio_ctx =*/ 0, /*.audio_ctx =*/ 0,
@ -5339,15 +5391,9 @@ int whisper_full_with_state(
if (n_samples > 0) { if (n_samples > 0) {
// compute log mel spectrogram // compute log mel spectrogram
if (params.speed_up) { if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
// TODO: Replace PV with more advanced algorithm
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -1; return -2;
} else {
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -2;
}
} }
} }
@ -5384,7 +5430,7 @@ int whisper_full_with_state(
// if length of spectrogram is less than 1.0s (100 frames), then return // if length of spectrogram is less than 1.0s (100 frames), then return
// basically don't process anything that is less than 1.0s // basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { if (seek_end < seek_start + 100) {
WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
return 0; return 0;
} }
@ -6096,8 +6142,8 @@ int whisper_full_with_state(
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) { if (!text.empty()) {
const auto tt0 = params.speed_up ? 2*t0 : t0; const auto tt0 = t0;
const auto tt1 = params.speed_up ? 2*t1 : t1; const auto tt1 = t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
@ -6143,8 +6189,8 @@ int whisper_full_with_state(
if (!text.empty()) { if (!text.empty()) {
const auto t1 = seek + seek_delta; const auto t1 = seek + seek_delta;
const auto tt0 = params.speed_up ? 2*t0 : t0; const auto tt0 = t0;
const auto tt1 = params.speed_up ? 2*t1 : t1; const auto tt1 = t1;
if (params.print_realtime) { if (params.print_realtime) {
if (params.print_timestamps) { if (params.print_timestamps) {
@ -7235,7 +7281,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
// operation (after median filter) // operation (after median filter)
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_norm(gctx, w, 1e-9); w = ggml_norm(gctx, w, 1e-9f);
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3); w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
// Pass median filter - this is done over AUDIO_TOKENS dimension. // Pass median filter - this is done over AUDIO_TOKENS dimension.

View File

@ -31,8 +31,10 @@
#define WHISPER_SAMPLE_RATE 16000 #define WHISPER_SAMPLE_RATE 16000
#define WHISPER_N_FFT 400 #define WHISPER_N_FFT 400
#define WHISPER_N_FFT_HALF (WHISPER_N_FFT / 2 + 1)
#define WHISPER_HOP_LENGTH 160 #define WHISPER_HOP_LENGTH 160
#define WHISPER_CHUNK_SIZE 30 #define WHISPER_CHUNK_SIZE 30
#define WHISPER_N_SAMPLES (WHISPER_SAMPLE_RATE * WHISPER_CHUNK_SIZE)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
@ -266,22 +268,6 @@ extern "C" {
int n_samples, int n_samples,
int n_threads); int n_threads);
// Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2.
// The resulting spectrogram is stored inside the default state of the provided whisper context.
// Returns 0 on success
WHISPER_API int whisper_pcm_to_mel_phase_vocoder(
struct whisper_context * ctx,
const float * samples,
int n_samples,
int n_threads);
WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
const float * samples,
int n_samples,
int n_threads);
// This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80 // n_mel must be 80
@ -499,7 +485,6 @@ extern "C" {
// [EXPERIMENTAL] speed-up techniques // [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output // note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
int audio_ctx; // overwrite the audio context size (0 = use default) int audio_ctx; // overwrite the audio context size (0 = use default)