diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e9bf9c28..2095e70d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -498,7 +498,7 @@ jobs:
         run: >
           cmake -S . -B ./build -A ${{ matrix.arch }}
           -DCMAKE_BUILD_TYPE=${{ matrix.build }}
-          -DWHISPER_CUBLAS=${{ matrix.cublas }}
+          -DWHISPER_CUDA=${{ matrix.cublas }}
           -DWHISPER_SDL2=${{ matrix.sdl2 }}

       - name: Build ${{ matrix.cuda-toolkit }}
diff --git a/whisper-mel-cuda.cu b/whisper-mel-cuda.cu
index 3f3e3158..9a6f1093 100644
--- a/whisper-mel-cuda.cu
+++ b/whisper-mel-cuda.cu
@@ -194,7 +194,7 @@ class mel_calc_cuda : public whisper_mel_calc {
     size_t m_log_mel_temp_storage_size = 0;
     void * m_log_mel_temp_storage = nullptr;
 public:
-    mel_calc_cuda(ggml_backend_t backend, const whisper_filters& filters)
+    mel_calc_cuda(ggml_backend_t backend, const whisper_filters & filters)
         : m_n_mel(filters.n_mel)
         , m_backend(backend)
     {
@@ -305,7 +305,7 @@ public:
         whisper_mel ret;
         // Calculate semi-padded sample length to ensure compatibility
         int n_len_org = 1 + int(samples.len + mirror_pad - WHISPER_N_FFT) / WHISPER_HOP_LENGTH;
-        ret.init(m_backend, int(n_mag_frames), n_len_org, m_n_mel);
+        whisper_mel_init(ret, m_backend, int(n_mag_frames), n_len_org, m_n_mel);
         assert(ggml_nbytes(ret.tensor) == m_n_mel * n_mag_frames * sizeof(float));
         float* log_mels = reinterpret_cast<float*>(ret.tensor->data);
diff --git a/whisper-mel.hpp b/whisper-mel.hpp
index e52b804d..1a54a23c 100644
--- a/whisper-mel.hpp
+++ b/whisper-mel.hpp
@@ -5,23 +5,15 @@
 struct whisper_mel {
     int n_len_org = 0;

-    ggml_tensor * tensor = nullptr;
     ggml_context * ctx = nullptr;
+    ggml_tensor * tensor = nullptr;
     ggml_backend_buffer_t buffer = nullptr;
-
-    whisper_mel() = default;
-    ~whisper_mel();
-
-    whisper_mel(const whisper_mel &) = delete;
-    whisper_mel & operator=(const whisper_mel &) = delete;
-    whisper_mel(whisper_mel &&) noexcept;
-    whisper_mel & operator=(whisper_mel &&) noexcept;
-
-    void init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
-    void reset();
-    void take(whisper_mel & other) noexcept;
 };

+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel);
+
+void whisper_mel_free(whisper_mel & mel);
+
 struct whisper_filters {
     int32_t n_mel;
     int32_t n_fft;
@@ -40,6 +32,3 @@ struct whisper_mel_calc {
     virtual whisper_mel calculate(whisper_span<const float> samples, int n_threads) const = 0;
     static whisper_span<const float> hann_window();
 };
-
-// returns a new pointer which needs to be freed with delete
-whisper_mel_calc * whisper_mel_calc_create(ggml_backend_t backend, const whisper_filters & filters);
diff --git a/whisper.cpp b/whisper.cpp
index dfbcc9d3..e8a13208 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -801,6 +801,7 @@ struct whisper_state {
     whisper_kv_cache kv_pad;

     whisper_mel mel;
+    whisper_mel_calc * mel_calc = nullptr;

     whisper_batch batch;
@@ -870,8 +871,6 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;

-    whisper_mel_calc * mel_calc = nullptr;
-
     whisper_state * state = nullptr;

     ggml_backend_t backend = nullptr;
@@ -893,7 +892,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
     BYTESWAP_VALUE(dest);
 }

-static bool kv_cache_init(
+static bool whisper_kv_cache_init(
         struct whisper_kv_cache & cache,
         ggml_backend_t backend,
         ggml_type wtype,
@@ -936,7 +935,7 @@
     return true;
 }

-static void kv_cache_free(struct whisper_kv_cache & cache) {
+static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
     ggml_free(cache.ctx);
     ggml_backend_buffer_free(cache.buffer);
     cache.ctx = nullptr;
@@ -1250,9 +1249,12 @@ static ggml_backend_t whisper_backend_init(const whisper_context_params & params
     }
 #endif

+    GGML_UNUSED(params);
+
     if (backend_gpu) {
         return backend_gpu;
     }
+
     return ggml_backend_cpu_init();
 }
@@ -2885,52 +2887,25 @@ struct whisper_global_cache {

 // Mel spectrogram

-whisper_mel::~whisper_mel() {
-    reset();
+void whisper_mel_init(whisper_mel & mel, ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
+    WHISPER_LOG_INFO("%s: n_len = %d, n_len_org = %d, n_mel = %d\n", __func__, n_len, n_len_org, n_mel);
+    mel.n_len_org = n_len_org;
+    assert(!mel.ctx);
+    mel.ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
+    mel.tensor = ggml_new_tensor_2d(mel.ctx, GGML_TYPE_F32, n_len, n_mel);
+    mel.buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(mel.tensor) + ggml_backend_get_alignment(backend));
+    auto alloc = ggml_tallocr_new(mel.buffer);
+    ggml_tallocr_alloc(&alloc, mel.tensor);
 }

-whisper_mel::whisper_mel(whisper_mel && other) noexcept {
-    take(other);
-}
+void whisper_mel_free(whisper_mel & mel) {
+    ggml_free(mel.ctx);
+    ggml_backend_buffer_free(mel.buffer);

-whisper_mel & whisper_mel::operator=(whisper_mel && other) noexcept {
-    if (this != &other) {
-        reset();
-        take(other);
-    }
-    return *this;
-}
-
-void whisper_mel::init(ggml_backend_t backend, int n_len, int n_len_org, int n_mel) {
-    this->n_len_org = n_len_org;
-    assert(!ctx);
-    ctx = ggml_init({ggml_tensor_overhead(), nullptr, true});
-    tensor = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_len, n_mel);
-    buffer = ggml_backend_alloc_buffer(backend, ggml_nbytes(tensor) + ggml_backend_get_alignment(backend));
-    auto alloc = ggml_tallocr_new(buffer);
-    ggml_tallocr_alloc(&alloc, tensor);
-}
-
-void whisper_mel::reset() {
-    ggml_free(ctx);
-    ggml_backend_buffer_free(buffer);
-
-    n_len_org = 0;
-    tensor = nullptr;
-    ctx = nullptr;
-    buffer = nullptr;
-}
-
-void whisper_mel::take(whisper_mel & other) noexcept {
-    n_len_org = other.n_len_org;
-    tensor = other.tensor;
-    ctx = other.ctx;
-    buffer = other.buffer;
-
-    other.n_len_org = 0;
-    other.tensor = nullptr;
-    other.ctx = nullptr;
-    other.buffer = nullptr;
+    mel.n_len_org = 0;
+    mel.ctx = nullptr;
+    mel.tensor = nullptr;
+    mel.buffer = nullptr;
 }

 whisper_mel_calc::~whisper_mel_calc() = default; // export vtable
@@ -3026,7 +3001,7 @@ struct whisper_mel_data {
     int n_len;
     int n_len_org;
     int n_mel;
-    float* data;
+    float * data;
 };

 void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
@@ -3100,7 +3075,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
 struct mel_calc_cpu : public whisper_mel_calc {
     ggml_backend_t m_backend;
-    const whisper_filters& m_filters;
+    const whisper_filters & m_filters;

     mel_calc_cpu(ggml_backend_t backend, const whisper_filters & filters) : m_backend(backend), m_filters(filters) {}

     // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
@@ -3137,7 +3112,7 @@ struct mel_calc_cpu : public whisper_mel_calc {
         std::vector<float> host_mel_data;

         whisper_mel ret;
-        ret.init(m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
+        whisper_mel_init(ret, m_backend, mel.n_len, mel.n_len_org, mel.n_mel);
         if (ggml_backend_buffer_is_host(ret.buffer)) {
             mel.data = reinterpret_cast<float *>(ret.tensor->data);
         } else {
@@ -3325,15 +3300,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         return nullptr;
     }

+    state->mel_calc = whisper_mel_calc_create(state->backend, ctx->model.filters);
+
     // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
     // in theory, there can be a case where this is not enough, but in practice it should always be enough
     const int factor = 3;

-    if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_self, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3343,11 +3320,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
     }

-    if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_cross, state->backend, ctx->itype,
                 ctx->model.hparams.n_text_state,
                 ctx->model.hparams.n_text_layer,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3357,11 +3334,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
     }

-    if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
+    if (!whisper_kv_cache_init(state->kv_pad, state->backend, ctx->itype,
                 ctx->model.hparams.n_audio_state,
                 1,
                 GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
-        WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
+        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
         whisper_free_state(state);
         return nullptr;
     }
@@ -3373,7 +3350,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

     // [EXPERIMENTAL] Token-level timestamps with DTW
     if (ctx->params.dtw_token_timestamps) {
-        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
+        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backend)) {
            WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
            whisper_free_state(state);
            return nullptr;
@@ -3416,7 +3393,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // conv allocator
    {
-        bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_conv, state->backend,
                [&]() {
                    return whisper_build_graph_conv(*ctx, *state, 0);
                });
@@ -3432,7 +3409,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // encoder allocator
    if (!whisper_encode_external(*state)) {
-        bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_encode, state->backend,
                [&]() {
                    return whisper_build_graph_encoder(*ctx, *state);
                });
@@ -3448,7 +3425,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // cross allocator
    {
-        bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_cross, state->backend,
                [&]() {
                    return whisper_build_graph_cross(*ctx, *state);
                });
@@ -3464,7 +3441,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
    // decoder allocator
    {
-        bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend,
+        bool ok = whisper_allocr_graph_init(state->alloc_decode, state->backend,
                [&]() {
                    const auto & hparams = ctx->model.hparams;
@@ -3660,8 +3637,6 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
         return nullptr;
     }

-    ctx->mel_calc = whisper_mel_calc_create(ctx->backend, ctx->model.filters);
-
     loader->close(loader->context);

     return ctx;
@@ -3738,9 +3713,14 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa

 void whisper_free_state(struct whisper_state * state) {
     if (state) {
-        kv_cache_free(state->kv_self);
-        kv_cache_free(state->kv_cross);
-        kv_cache_free(state->kv_pad);
+        whisper_kv_cache_free(state->kv_self);
+        whisper_kv_cache_free(state->kv_cross);
+        whisper_kv_cache_free(state->kv_pad);
+
+        whisper_mel_free(state->mel);
+
+        delete state->mel_calc;
+        state->mel_calc = nullptr;

 #ifdef WHISPER_USE_COREML
         if (state->ctx_coreml != nullptr) {
@@ -3782,8 +3762,6 @@ void whisper_free(struct whisper_context * ctx) {

         ggml_backend_free(ctx->backend);

-        delete ctx->mel_calc;
-        ctx->mel_calc = nullptr;
         delete ctx;
     }
 }
@@ -3800,9 +3778,11 @@ void whisper_free_params(struct whisper_full_params * params) {
     }
 }

-int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
+int whisper_pcm_to_mel_with_state(struct whisper_context * /*ctx*/, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
-    state->mel = ctx->mel_calc->calculate({samples, n_samples}, n_threads);
+
+    state->mel = state->mel_calc->calculate({samples, n_samples}, n_threads);
+
     state->t_mel_us += ggml_time_us() - t_start_us;

     // Dump log_mel_spectrogram
@@ -3834,8 +3814,9 @@ int whisper_set_mel_with_state(
         return -1;
     }

-    state->mel.reset();
-    state->mel.init(ctx->backend, n_len, n_len, n_mel);
+    whisper_mel_free(state->mel);
+    whisper_mel_init(state->mel, ctx->backend, n_len, n_len, n_mel);
+
     ggml_backend_tensor_set(state->mel.tensor, data, 0, ggml_nbytes(state->mel.tensor));

     return 0;
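
Note on usage (not part of the patch): since `whisper_mel` loses its destructor and move semantics here, every owner must pair `whisper_mel_init` with `whisper_mel_free` explicitly, and assignments such as `state->mel = state->mel_calc->calculate(...)` become shallow member copies. A minimal sketch of the resulting lifecycle, assuming a CPU ggml backend; `example_mel_lifecycle` and the dimension values are illustrative, not taken from the patch:

```cpp
// Sketch of the explicit init/free pairing this patch introduces.
// Assumes "whisper-mel.hpp" and the ggml headers are on the include path.
#include "whisper-mel.hpp"
#include "ggml-backend.h"

void example_mel_lifecycle() {   // hypothetical helper, for illustration only
    ggml_backend_t backend = ggml_backend_cpu_init();

    whisper_mel mel;             // plain struct now: fields default to nullptr
    whisper_mel_init(mel, backend, /*n_len=*/3000, /*n_len_org=*/3000, /*n_mel=*/80);

    // ... fill mel.tensor via ggml_backend_tensor_set() and run the encoder ...

    whisper_mel_free(mel);       // caller owns the ggml context and buffer
    ggml_backend_free(backend);
}
```

This mirrors the ownership scheme in the patch itself: `whisper_init_state` allocates `state->mel_calc`, `whisper_set_mel_with_state` frees the previous mel before initializing its replacement, and `whisper_free_state` releases both the mel and the calculator exactly once.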