whisper.cpp/src/whisper.cpp
Daniel Bevenius e41bc5c61a
vad : add initial Voice Activity Detection (VAD) support (#3065)
* vad : add initial Voice Activity Detection (VAD) support

This commit adds support for Voice Activity Detection (VAD). When enabled,
this feature processes the audio input and detects speech segments.
This information is then used to reduce the number of samples that need
to be processed by whisper_full.

Resolves: https://github.com/ggml-org/whisper.cpp/issues/3003

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-05-12 16:10:11 +02:00

8917 lines
318 KiB
C++

#include "whisper.h"
#include "whisper-arch.h"
#include "ggml.h"
#include "ggml-cpp.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#ifdef WHISPER_USE_COREML
#include "coreml/whisper-encoder.h"
#endif
#ifdef WHISPER_USE_OPENVINO
#include "openvino/whisper-openvino-encoder.h"
#endif
#include <atomic>
#include <algorithm>
#include <cassert>
#include <cfloat>
#define _USE_MATH_DEFINES
#include <cmath>
#include <climits>
#include <codecvt>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <functional>
#include <map>
#include <mutex>
#include <random>
#include <regex>
#include <set>
#include <string>
#include <thread>
#include <vector>
#if defined(WHISPER_BIG_ENDIAN)
// Return a copy of `value` whose bytes are in reversed order.
template<typename T>
static T byteswap(T value) {
    T result;
    const char * src = reinterpret_cast<const char *>(&value);
    char       * dst = reinterpret_cast<char *>(&result);
    const int n_bytes = sizeof(T);
    // fill the destination front-to-back while reading the source back-to-front
    for (int k = 0; k < n_bytes; ++k) {
        dst[k] = src[n_bytes - 1 - k];
    }
    return result;
}
// Byte-swap every element of `tensor` in place, treating its data as an array of T.
template<typename T>
static void byteswap_tensor_data(ggml_tensor * tensor) {
    T * datum = reinterpret_cast<T *>(tensor->data);
    // query the element count once and index with a 64-bit counter so that the
    // comparison is not done against a narrower `int` (ggml_nelements is declared in ggml.h)
    const int64_t n_elements = ggml_nelements(tensor);
    for (int64_t i = 0; i < n_elements; i++) {
        datum[i] = byteswap(datum[i]);
    }
}
// Byte-swap the data of `tensor` in place, dispatching on its element type.
// Types without a dedicated case are left untouched.
static void byteswap_tensor(ggml_tensor * tensor) {
    switch (tensor->type) {
        case GGML_TYPE_I16: byteswap_tensor_data<int16_t>(tensor);     return;
        case GGML_TYPE_F16: byteswap_tensor_data<ggml_fp16_t>(tensor); return;
        case GGML_TYPE_I32: byteswap_tensor_data<int32_t>(tensor);     return;
        case GGML_TYPE_F32: byteswap_tensor_data<float>(tensor);       return;
        default:
            // GGML_TYPE_I8 (and anything else): nothing to swap
            return;
    }
}
// Byte-swapping helper macros (WHISPER_BIG_ENDIAN builds):
// BYTESWAP_VALUE(d)   - replace `d` with its byte-swapped value
#define BYTESWAP_VALUE(d) d = byteswap(d)
// BYTESWAP_FILTERS(f) - byte-swap every element of `f.data`
#define BYTESWAP_FILTERS(f) \
do { \
for (auto & datum : f.data) { \
datum = byteswap(datum); \
} \
} while (0)
// BYTESWAP_TENSOR(t)  - byte-swap the tensor data via byteswap_tensor()
#define BYTESWAP_TENSOR(t) \
do { \
byteswap_tensor(t); \
} while (0)
#else
// WHISPER_BIG_ENDIAN is not defined: all three helpers expand to empty statements
#define BYTESWAP_VALUE(d) do {} while (0)
#define BYTESWAP_FILTERS(f) do {} while (0)
#define BYTESWAP_TENSOR(t) do {} while (0)
#endif
// WHISPER_ATTRIBUTE_FORMAT(...): on GCC-compatible compilers this expands to the `format`
// function attribute so that printf-style arguments can be checked by the compiler.
// MinGW builds use the `gnu_printf` archetype, other GCC-compatible builds use `printf`;
// on all other compilers the macro expands to nothing.
#ifdef __GNUC__
#ifdef __MINGW32__
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define WHISPER_ATTRIBUTE_FORMAT(...)
#endif
//
// logging
//
// printf-style logging entry point: parameter 2 is the format string and the
// variadic arguments start at parameter 3 (hence the (2, 3) format annotation)
WHISPER_ATTRIBUTE_FORMAT(2, 3)
static void whisper_log_internal (ggml_log_level level, const char * format, ...);
// default log callback; installed as the initial value of whisper_global::log_callback
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data);
// convenience wrappers that forward to whisper_log_internal() with a fixed log level
#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
// define this to enable verbose trace logging - useful for debugging purposes
//#define WHISPER_DEBUG
// WHISPER_LOG_DEBUG expands to nothing unless WHISPER_DEBUG is defined,
// so its arguments are not evaluated in regular builds
#if defined(WHISPER_DEBUG)
#define WHISPER_LOG_DEBUG(...) whisper_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
#else
#define WHISPER_LOG_DEBUG(...)
#endif
// WHISPER_ASSERT(x): if `x` evaluates to false, log the file, line and the stringified
// expression at error level and abort() the process.
// The check is not tied to NDEBUG - it is performed in every build configuration.
#define WHISPER_ASSERT(x) \
do { \
if (!(x)) { \
WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
// number of entries in the whisper_state::decoders array
#define WHISPER_MAX_DECODERS 8
// graph size passed to ggml_backend_sched_new() and used to size the scheduler meta buffer
// (see whisper_sched_graph_init)
#define WHISPER_MAX_NODES 4096
// printf-style formatting that returns the result as a std::string.
// The argument list is traversed twice: once to measure the output and once to render it.
static std::string format(const char * fmt, ...) {
    va_list args_measure;
    va_list args_render;
    va_start(args_measure, fmt);
    va_copy(args_render, args_measure);

    // first pass: number of characters needed, not counting the terminating NUL
    const int n_chars = vsnprintf(NULL, 0, fmt, args_measure);
    GGML_ASSERT(n_chars >= 0 && n_chars < INT_MAX); // NOLINT

    // second pass: render into a buffer that has room for the terminator
    std::vector<char> storage(n_chars + 1);
    const int n_written = vsnprintf(storage.data(), n_chars + 1, fmt, args_render);
    GGML_ASSERT(n_written == n_chars);

    va_end(args_render);
    va_end(args_measure);

    return std::string(storage.data(), n_chars);
}
//
// ggml helpers
//
static bool ggml_graph_compute_helper(
struct ggml_cgraph * graph,
int n_threads,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get()));
auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback");
if (set_abort_callback_fn) {
set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data);
}
auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (ggml_backend_set_n_threads_fn) {
ggml_backend_set_n_threads_fn(backend.get(), n_threads);
}
return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS;
}
// Compute `graph` through the backend scheduler `sched`.
// Before computing, the thread count is applied to every scheduler backend that exposes
// the "ggml_backend_set_n_threads" entry point.
// The scheduler is reset when `sched_reset` is true, and always when the computation fails.
// Returns true when the computation reports GGML_STATUS_SUCCESS.
static bool ggml_graph_compute_helper(
        ggml_backend_sched_t sched,
        struct ggml_cgraph * graph,
        int n_threads,
        bool sched_reset = true) {
    const int n_backends = ggml_backend_sched_get_n_backends(sched);
    for (int ib = 0; ib < n_backends; ++ib) {
        ggml_backend_t     be  = ggml_backend_sched_get_backend(sched, ib);
        ggml_backend_dev_t dev = ggml_backend_get_device(be);

        // a backend without a device has no registry to query
        ggml_backend_reg_t reg = nullptr;
        if (dev) {
            reg = ggml_backend_dev_backend_reg(dev);
        }

        auto * set_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_threads != nullptr) {
            set_threads(be, n_threads);
        }
    }

    const bool ok = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;

    if (sched_reset || !ok) {
        ggml_backend_sched_reset(sched);
    }

    return ok;
}
// Call ggml_backend_load_all() at most once, guarded by std::call_once.
// Compiles to an empty function unless GGML_BACKEND_DL is defined.
static void whisper_load_backends() {
#ifdef GGML_BACKEND_DL
    static std::once_flag backends_loaded;
    std::call_once(backends_loaded, ggml_backend_load_all);
#endif
}
// TODO: move these functions to ggml-base with support for ggml-backend?
// Set every element of the contiguous F32 tensor `t` to `v` and return `t`.
static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) {
    GGML_ASSERT(t->type == GGML_TYPE_F32);
    GGML_ASSERT(ggml_is_contiguous(t));

    float * values = (float *) t->data;
    const size_t count = ggml_nelements(t);
    std::fill_n(values, count, v);

    return t;
}
// Set every element of the contiguous I32 tensor `t` to `v` and return `t`.
static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) {
    GGML_ASSERT(t->type == GGML_TYPE_I32);
    GGML_ASSERT(ggml_is_contiguous(t));

    int32_t * values = (int32_t *) t->data;
    const size_t count = ggml_nelements(t);
    std::fill_n(values, count, v);

    return t;
}
// Read the F32 element at index (i0, i1, i2, i3); the address is computed from the
// per-dimension byte strides stored in t->nb.
static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    GGML_ASSERT(t->type == GGML_TYPE_F32);

    const char  * base = (const char *) t->data;
    const float * elem = (const float *) (base + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);

    return *elem;
}
// Write `v` to the F32 element at index (i0, i1, i2, i3); the address is computed from
// the per-dimension byte strides stored in t->nb.
static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) {
    GGML_ASSERT(t->type == GGML_TYPE_F32);

    char  * base = (char *) t->data;
    float * elem = (float *) (base + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);

    *elem = v;
}
// Read the I32 element at index (i0, i1, i2, i3); the address is computed from the
// per-dimension byte strides stored in t->nb.
static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    GGML_ASSERT(t->type == GGML_TYPE_I32);

    const char    * base = (const char *) t->data;
    const int32_t * elem = (const int32_t *) (base + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);

    return *elem;
}
// Write `v` to the I32 element at index (i0, i1, i2, i3); the address is computed from
// the per-dimension byte strides stored in t->nb.
static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) {
    GGML_ASSERT(t->type == GGML_TYPE_I32);

    char    * base = (char *) t->data;
    int32_t * elem = (int32_t *) (base + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);

    *elem = v;
}
// Matrix multiplication for tensors whose dimension 0 is not a multiple of "pad".
//
// The product
//
//   Z = X @ Y
//
// is rewritten as the sum of two products:
//
//   Z = (X_0 @ Y_0) + (X_1 @ Y_1)
//
// X_0 and Y_0 are views covering the leading part of dimension 0 whose length is a
// multiple of "pad"; X_1 and Y_1 are views of the remainder. The remainder views are
// small matrices that can be handled by more general-purpose kernels.
//
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
    // the split is only applied when dimension 0 holds at least this many "pad"-sized
    // chunks - otherwise there is not much to gain from it
    const int min_chunks = 8;

    const bool is_aligned = x->ne[0] % pad == 0;
    if (is_aligned || x->ne[0] / pad < min_chunks) {
        return ggml_mul_mat(ctx, x, y);
    }

    // leading (aligned) part and trailing remainder of x
    struct ggml_tensor * x_head = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
    struct ggml_tensor * x_tail = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_head->ne[0]*x_head->nb[0]);

    // leading (aligned) part and trailing remainder of y
    struct ggml_tensor * y_head = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
    struct ggml_tensor * y_tail = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_head->ne[0]*y_head->nb[0]);

    struct ggml_tensor * z_head = ggml_mul_mat(ctx, x_head, y_head);
    struct ggml_tensor * z_tail = ggml_mul_mat(ctx, x_tail, y_tail);

    return ggml_add(ctx, z_head, z_tail);
}
// TODO: check if other platforms can benefit from this optimization
// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly
// NOTE: when GGML_USE_METAL is defined, every use of ggml_mul_mat below this point
//       is rewritten by the preprocessor to call ggml_mul_mat_pad instead
#if defined(GGML_USE_METAL)
#define ggml_mul_mat ggml_mul_mat_pad
#endif
// available whisper models
// (MODEL_UNKNOWN is the default value of whisper_model::type;
//  printable names for all enumerators are stored in g_model_name)
enum e_model {
MODEL_UNKNOWN,
MODEL_TINY,
MODEL_BASE,
MODEL_SMALL,
MODEL_MEDIUM,
MODEL_LARGE,
};
// printable name for each e_model enumerator
static const std::map<e_model, std::string> g_model_name = {
{ MODEL_UNKNOWN, "unknown" },
{ MODEL_TINY, "tiny" },
{ MODEL_BASE, "base" },
{ MODEL_SMALL, "small" },
{ MODEL_MEDIUM, "medium" },
{ MODEL_LARGE, "large" },
};
// supported languages: short language code -> { numeric language id, full language name }
// the ids are consecutive, starting at 0 ("en") and ending at 99 ("yue")
static const std::map<std::string, std::pair<int, std::string>> g_lang = {
{ "en", { 0, "english", } },
{ "zh", { 1, "chinese", } },
{ "de", { 2, "german", } },
{ "es", { 3, "spanish", } },
{ "ru", { 4, "russian", } },
{ "ko", { 5, "korean", } },
{ "fr", { 6, "french", } },
{ "ja", { 7, "japanese", } },
{ "pt", { 8, "portuguese", } },
{ "tr", { 9, "turkish", } },
{ "pl", { 10, "polish", } },
{ "ca", { 11, "catalan", } },
{ "nl", { 12, "dutch", } },
{ "ar", { 13, "arabic", } },
{ "sv", { 14, "swedish", } },
{ "it", { 15, "italian", } },
{ "id", { 16, "indonesian", } },
{ "hi", { 17, "hindi", } },
{ "fi", { 18, "finnish", } },
{ "vi", { 19, "vietnamese", } },
{ "he", { 20, "hebrew", } },
{ "uk", { 21, "ukrainian", } },
{ "el", { 22, "greek", } },
{ "ms", { 23, "malay", } },
{ "cs", { 24, "czech", } },
{ "ro", { 25, "romanian", } },
{ "da", { 26, "danish", } },
{ "hu", { 27, "hungarian", } },
{ "ta", { 28, "tamil", } },
{ "no", { 29, "norwegian", } },
{ "th", { 30, "thai", } },
{ "ur", { 31, "urdu", } },
{ "hr", { 32, "croatian", } },
{ "bg", { 33, "bulgarian", } },
{ "lt", { 34, "lithuanian", } },
{ "la", { 35, "latin", } },
{ "mi", { 36, "maori", } },
{ "ml", { 37, "malayalam", } },
{ "cy", { 38, "welsh", } },
{ "sk", { 39, "slovak", } },
{ "te", { 40, "telugu", } },
{ "fa", { 41, "persian", } },
{ "lv", { 42, "latvian", } },
{ "bn", { 43, "bengali", } },
{ "sr", { 44, "serbian", } },
{ "az", { 45, "azerbaijani", } },
{ "sl", { 46, "slovenian", } },
{ "kn", { 47, "kannada", } },
{ "et", { 48, "estonian", } },
{ "mk", { 49, "macedonian", } },
{ "br", { 50, "breton", } },
{ "eu", { 51, "basque", } },
{ "is", { 52, "icelandic", } },
{ "hy", { 53, "armenian", } },
{ "ne", { 54, "nepali", } },
{ "mn", { 55, "mongolian", } },
{ "bs", { 56, "bosnian", } },
{ "kk", { 57, "kazakh", } },
{ "sq", { 58, "albanian", } },
{ "sw", { 59, "swahili", } },
{ "gl", { 60, "galician", } },
{ "mr", { 61, "marathi", } },
{ "pa", { 62, "punjabi", } },
{ "si", { 63, "sinhala", } },
{ "km", { 64, "khmer", } },
{ "sn", { 65, "shona", } },
{ "yo", { 66, "yoruba", } },
{ "so", { 67, "somali", } },
{ "af", { 68, "afrikaans", } },
{ "oc", { 69, "occitan", } },
{ "ka", { 70, "georgian", } },
{ "be", { 71, "belarusian", } },
{ "tg", { 72, "tajik", } },
{ "sd", { 73, "sindhi", } },
{ "gu", { 74, "gujarati", } },
{ "am", { 75, "amharic", } },
{ "yi", { 76, "yiddish", } },
{ "lo", { 77, "lao", } },
{ "uz", { 78, "uzbek", } },
{ "fo", { 79, "faroese", } },
{ "ht", { 80, "haitian creole", } },
{ "ps", { 81, "pashto", } },
{ "tk", { 82, "turkmen", } },
{ "nn", { 83, "nynorsk", } },
{ "mt", { 84, "maltese", } },
{ "sa", { 85, "sanskrit", } },
{ "lb", { 86, "luxembourgish", } },
{ "my", { 87, "myanmar", } },
{ "bo", { 88, "tibetan", } },
{ "tl", { 89, "tagalog", } },
{ "mg", { 90, "malagasy", } },
{ "as", { 91, "assamese", } },
{ "tt", { 92, "tatar", } },
{ "haw", { 93, "hawaiian", } },
{ "ln", { 94, "lingala", } },
{ "ha", { 95, "hausa", } },
{ "ba", { 96, "bashkir", } },
{ "jw", { 97, "javanese", } },
{ "su", { 98, "sundanese", } },
{ "yue", { 99, "cantonese", } },
};
// [EXPERIMENTAL] Token-level timestamps with DTW
// One table of alignment heads per model preset.
// NOTE(review): whisper_ahead is declared in whisper.h; each entry is presumably
//               { text layer index, head index } - confirm against the declaration
static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} };
static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} };
static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} };
static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} };
static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} };
static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} };
static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} };
static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} };
static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} };
static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} };
static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} };
static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} };
// alignment-heads preset -> { number of entries, pointer to the preset's table }
// the count in each row equals the number of entries in the corresponding g_aheads_* array
// and has to be kept in sync by hand when a table is edited
static const std::map<whisper_alignment_heads_preset, whisper_aheads> g_aheads {
{ WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } },
{ WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } },
{ WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } },
{ WHISPER_AHEADS_BASE, { 8, g_aheads_base } },
{ WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } },
{ WHISPER_AHEADS_SMALL, { 10, g_aheads_small } },
{ WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } },
{ WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } },
{ WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } },
{ WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } },
{ WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } },
{ WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } },
};
// forward declaration - the definition is not part of this section of the file
// NOTE(review): presumably returns the alignment-head indices configured for text layer `il` - confirm at the definition
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head);
// mel spectrogram data (stored in whisper_state::mel)
// NOTE(review): the field meanings below are inferred from the names - confirm against the code that fills this struct
struct whisper_mel {
// presumably the number of frames stored in `data`
int n_len;
// presumably the number of frames of the original input (before any padding) - TODO confirm
int n_len_org;
// presumably the number of mel bands
int n_mel;
// spectrogram values
std::vector<float> data;
};
// mel filter bank (stored in whisper_model::filters)
// NOTE(review): n_mel / n_fft are presumably the dimensions of `data` - confirm against the loader
struct whisper_filters {
int32_t n_mel;
int32_t n_fft;
// filter coefficients; byte-swapped element by element by BYTESWAP_FILTERS on big-endian builds
std::vector<float> data;
};
// Vocabulary: token <-> id lookup tables plus the ids of the special tokens.
struct whisper_vocab {
    using id    = int32_t;
    using token = std::string;

    int n_vocab = 51864;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;

    // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349
    id token_eot = 50256;
    id token_sot = 50257;

    // task tokens (used only for multilingual models)
    id token_translate  = 50357;
    id token_transcribe = 50358;

    // other special tokens
    id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn
    id token_prev = 50360;
    id token_nosp = 50361;
    id token_not  = 50362; // no timestamps
    id token_beg  = 50363; // begin timestamps

    // a vocabulary with at least 51865 entries is treated as multilingual
    bool is_multilingual() const {
        const bool multilingual = n_vocab >= 51865;
        return multilingual;
    }

    // number of languages derived from the vocabulary size:
    // n_vocab - 51765, minus one more for multilingual vocabularies
    int num_languages() const {
        int n_lang = n_vocab - 51765;
        if (is_multilingual()) {
            n_lang -= 1;
        }
        return n_lang;
    }
};
// one entry of whisper_state::result_all
// NOTE(review): the field meanings below are inferred from the names - confirm against the code that produces segments
struct whisper_segment {
// presumably the start / end timestamps of the segment
int64_t t0;
int64_t t1;
// segment text
std::string text;
float no_speech_prob;
// per-token data of the segment
std::vector<whisper_token_data> tokens;
bool speaker_turn_next;
};
// batch of tokens handed to the decoder
// the arrays are allocated by whisper_batch_init() and released by whisper_batch_free()
struct whisper_batch {
// number of tokens currently stored in the batch
int32_t n_tokens;
// token ids
whisper_token * token;
// position of each token
whisper_pos * pos;
int32_t * n_seq_id; // always 1, here for consistency with llama.cpp
whisper_seq_id ** seq_id; // null terminated
// per-token flag; whisper_batch_prep_legacy() sets it to 1 only for the last token
int8_t * logits;
};
static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens));
batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens));
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens));
batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1));
for (int i = 0; i < n_tokens; ++i) {
batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
}
batch.seq_id[n_tokens] = nullptr;
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
static void whisper_batch_free(struct whisper_batch batch) {
if (batch.token) free(batch.token);
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
for (int i = 0; batch.seq_id[i]; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
}
// Fill `batch` with `n_tokens` consecutive tokens that all belong to sequence `seq_id`.
//
//   tokens   - token ids to copy; may be null, in which case batch.token is left untouched
//   n_tokens - number of tokens to set up
//   n_past   - position of the first token; token i gets position n_past + i
//   seq_id   - the single sequence id assigned to every token
//
// Only the last token is flagged in batch.logits.
static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) {
    batch.n_tokens = n_tokens;

    for (int i = 0; i < n_tokens; ++i) {
        if (tokens) {
            batch.token[i] = tokens[i];
        }
        batch.pos     [i]    = n_past + i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq_id;
        batch.logits  [i]    = 0;
    }

    // an empty batch has no last token - do not write before the start of the array
    if (n_tokens > 0) {
        batch.logits[n_tokens - 1] = 1;
    }
}
// minimal stand-in for std::pair (reason: std::pair is very slow)
template<typename A, typename B>
struct whisper_pair {
    A first;
    B second;

    // no arguments: both members are value-initialized
    whisper_pair() : first(A()), second(B()) {}

    // two arguments: the members are copied from `a` and `b`
    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
};
// ggml_backend_sched wrapper for whisper usage
struct whisper_sched {
// scheduler handle; created in whisper_sched_graph_init()
ggml_backend_sched_t sched = nullptr;
// byte buffer sized in whisper_sched_graph_init() to hold WHISPER_MAX_NODES tensor overheads plus one graph overhead
std::vector<uint8_t> meta;
};
static size_t whisper_sched_size(struct whisper_sched & allocr) {
size_t size = allocr.meta.size();
for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) {
ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i);
size += ggml_backend_sched_get_buffer_size(allocr.sched, backend);
}
return size;
}
// Create the scheduler for `allocr`, size its meta buffer and allocate the graph
// produced by `get_graph` once, so that the compute buffers get their final size.
// Returns false (after logging an error) when the graph cannot be allocated.
static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) {
    allocr.sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false);

    // the meta buffer is resized before the graph is built
    allocr.meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());

    // since there are dependencies between the different graphs,
    // we need to allocate them instead of only reserving to get the correct compute buffer size
    struct ggml_cgraph * gf = get_graph();
    const bool allocated = ggml_backend_sched_alloc_graph(allocr.sched, gf);

    if (allocated) {
        ggml_backend_sched_reset(allocr.sched);
        return true;
    }

    // failed to allocate the compute buffer
    WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
    return false;
}
// medium
// hparams: {
// 'n_mels': 80,
// 'n_vocab': 51864,
// 'n_audio_ctx': 1500,
// 'n_audio_state': 1024,
// 'n_audio_head': 16,
// 'n_audio_layer': 24,
// 'n_text_ctx': 448,
// 'n_text_state': 1024,
// 'n_text_head': 16,
// 'n_text_layer': 24
// }
//
// default hparams (Whisper tiny)
// the n_audio_* fields describe the audio encoder, the n_text_* fields the text decoder
struct whisper_hparams {
int32_t n_vocab = 51864;
int32_t n_audio_ctx = 1500;
int32_t n_audio_state = 384;
int32_t n_audio_head = 6;
int32_t n_audio_layer = 4;
int32_t n_text_ctx = 448;
int32_t n_text_state = 384;
int32_t n_text_head = 6;
int32_t n_text_layer = 4;
int32_t n_mels = 80;
// NOTE(review): presumably the weight/file type id read from the model file - confirm against the loader
int32_t ftype = 1;
// NOTE(review): presumably the epsilon used by the normalization layers - confirm at the use sites
float eps = 1e-5f;
};
// audio encoding layer
// tensors of a single encoder block; `_w` fields are weights and `_b` fields are biases
// (there is no bias tensor for the attention key projection)
struct whisper_layer_encoder {
// encoder.blocks.*.attn_ln
struct ggml_tensor * attn_ln_0_w;
struct ggml_tensor * attn_ln_0_b;
// encoder.blocks.*.attn.out
struct ggml_tensor * attn_ln_1_w;
struct ggml_tensor * attn_ln_1_b;
// encoder.blocks.*.attn.query
struct ggml_tensor * attn_q_w;
struct ggml_tensor * attn_q_b;
// encoder.blocks.*.attn.key
struct ggml_tensor * attn_k_w;
// encoder.blocks.*.attn.value
struct ggml_tensor * attn_v_w;
struct ggml_tensor * attn_v_b;
// encoder.blocks.*.mlp_ln
struct ggml_tensor * mlp_ln_w;
struct ggml_tensor * mlp_ln_b;
// encoder.blocks.*.mlp.0
struct ggml_tensor * mlp_0_w;
struct ggml_tensor * mlp_0_b;
// encoder.blocks.*.mlp.2
struct ggml_tensor * mlp_1_w;
struct ggml_tensor * mlp_1_b;
};
// token decoding layer
// tensors of a single decoder block: self-attention, cross-attention and MLP
// `_w` fields are weights and `_b` fields are biases
// (there is no bias tensor for the self- and cross-attention key projections)
struct whisper_layer_decoder {
// decoder.blocks.*.attn_ln
struct ggml_tensor * attn_ln_0_w;
struct ggml_tensor * attn_ln_0_b;
// decoder.blocks.*.attn.out
struct ggml_tensor * attn_ln_1_w;
struct ggml_tensor * attn_ln_1_b;
// decoder.blocks.*.attn.query
struct ggml_tensor * attn_q_w;
struct ggml_tensor * attn_q_b;
// decoder.blocks.*.attn.key
struct ggml_tensor * attn_k_w;
// decoder.blocks.*.attn.value
struct ggml_tensor * attn_v_w;
struct ggml_tensor * attn_v_b;
// decoder.blocks.*.cross_attn_ln
struct ggml_tensor * cross_attn_ln_0_w;
struct ggml_tensor * cross_attn_ln_0_b;
// decoder.blocks.*.cross_attn.out
struct ggml_tensor * cross_attn_ln_1_w;
struct ggml_tensor * cross_attn_ln_1_b;
// decoder.blocks.*.cross_attn.query
struct ggml_tensor * cross_attn_q_w;
struct ggml_tensor * cross_attn_q_b;
// decoder.blocks.*.cross_attn.key
struct ggml_tensor * cross_attn_k_w;
// decoder.blocks.*.cross_attn.value
struct ggml_tensor * cross_attn_v_w;
struct ggml_tensor * cross_attn_v_b;
// decoder.blocks.*.mlp_ln
struct ggml_tensor * mlp_ln_w;
struct ggml_tensor * mlp_ln_b;
// decoder.blocks.*.mlp.0
struct ggml_tensor * mlp_0_w;
struct ggml_tensor * mlp_0_b;
// decoder.blocks.*.mlp.2
struct ggml_tensor * mlp_1_w;
struct ggml_tensor * mlp_1_b;
};
struct whisper_kv_cell {
whisper_pos pos = -1;
std::set<whisper_seq_id> seq_id;
bool has_seq_id(const whisper_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
};
// KV cache: per-cell bookkeeping plus the backing key/value tensors
struct whisper_kv_cache {
// index of the cell at which whisper_kv_cache_find_slot() starts searching
uint32_t head = 0;
// number of cells (set to n_ctx by whisper_kv_cache_init)
uint32_t size = 0;
// computed before each graph build
uint32_t n = 0;
std::vector<whisper_kv_cell> cells;
// key / value tensors; both are created as 1D tensors in whisper_kv_cache_init()
struct ggml_tensor * k;
struct ggml_tensor * v;
// backend buffer that holds the data of k and v
ggml_backend_buffer_t buffer = nullptr;
// memory for the ggml context in which the two tensors are created
std::vector<uint8_t> ctx_buf;
};
// model: hyper-parameters, mel filters and the weight tensors
struct whisper_model {
e_model type = MODEL_UNKNOWN;
whisper_hparams hparams;
whisper_filters filters;
// encoder.positional_embedding
struct ggml_tensor * e_pe;
// encoder.conv1
struct ggml_tensor * e_conv_1_w;
struct ggml_tensor * e_conv_1_b;
// encoder.conv2
struct ggml_tensor * e_conv_2_w;
struct ggml_tensor * e_conv_2_b;
// encoder.ln_post
struct ggml_tensor * e_ln_w;
struct ggml_tensor * e_ln_b;
// decoder.positional_embedding
struct ggml_tensor * d_pe;
// decoder.token_embedding
struct ggml_tensor * d_te;
// decoder.ln
struct ggml_tensor * d_ln_w;
struct ggml_tensor * d_ln_b;
// per-block tensors of the encoder and the decoder
std::vector<whisper_layer_encoder> layers_encoder;
std::vector<whisper_layer_decoder> layers_decoder;
// ggml context that contains all the meta information about the model tensors
std::vector<ggml_context *> ctxs;
// the model backend data is read-only and can be shared between processors
std::vector<ggml_backend_buffer_t> buffers;
// tensors
// NOTE(review): n_loaded is presumably the number of tensors read by the loader - confirm
int n_loaded;
// tensor name -> tensor
std::map<std::string, struct ggml_tensor *> tensors;
};
// decoding state of a UTF-8 sequence that has only been partially received
struct whisper_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};
// grammar state used to constrain the generated tokens (one instance per whisper_decoder)
struct whisper_grammar {
// grammar rules; each rule is a sequence of grammar elements
/*const*/ std::vector<std::vector<whisper_grammar_element>> rules;
// parse stacks; the entries are pointers to grammar elements
std::vector<std::vector<const whisper_grammar_element *>> stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
whisper_partial_utf8 partial_utf8;
};
// candidate token considered by the grammar
struct whisper_grammar_candidate {
// token id of the candidate
whisper_token id;
// code points of the candidate
// NOTE(review): presumably a 0-terminated array - confirm at the use sites
const uint32_t * code_points;
// partial UTF-8 state associated with the candidate
whisper_partial_utf8 partial_utf8;
};
// token sequence generated by a decoder, together with its statistics
struct whisper_sequence {
std::vector<whisper_token_data> tokens;
// the accumulated transcription in the current iteration (used to truncate the tokens array)
int result_len;
double sum_logprobs_all; // the sum of the log probabilities of the tokens
double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens)
double avg_logprobs; // the average log probability of the tokens
double entropy; // the entropy of the tokens
double score; // likelihood rank score
};
// TAGS: WHISPER_DECODER_INIT
// state of a single decoder; whisper_state holds WHISPER_MAX_DECODERS of these
struct whisper_decoder {
// the currently generated sequence of tokens
whisper_sequence sequence;
// grammar parse state of generated sequence of tokens
whisper_grammar grammar;
int i_batch; // the index of the token in the current batch
int seek_delta; // the window shift found so far based on the decoded timestamp tokens
bool failed; // has the current segment failed to decode?
bool completed; // has the decoder completed the current segment?
bool has_ts; // have we already sampled a non-beg timestamp token for the current segment?
// new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
std::vector<float> probs;
std::vector<float> logits;
std::vector<float> logprobs;
// work container used to avoid memory allocations
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
mutable std::mt19937 rng; // used for sampling at t > 0.0
};
// [EXPERIMENTAL] Token-level timestamps with DTW
// alignment-head mask tensors together with the ggml context and backend buffer associated with them
struct whisper_aheads_masks {
std::vector<struct ggml_tensor *> m; // One mask per text layer.
struct ggml_context * ctx = nullptr;
ggml_backend_buffer_t buffer = nullptr;
};
// mutable inference state: timing statistics, KV caches, schedulers, intermediate
// buffers and results (whisper_context refers to one through its `state` pointer)
struct whisper_state {
// accumulated time per stage (the `_us` suffix suggests microseconds)
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
int64_t t_decode_us = 0;
int64_t t_batchd_us = 0;
int64_t t_prompt_us = 0;
int64_t t_mel_us = 0;
int32_t n_sample = 0; // number of tokens sampled
int32_t n_encode = 0; // number of encoder calls
int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation)
int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding)
int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding)
int32_t n_fail_p = 0; // number of logprob threshold failures
int32_t n_fail_h = 0; // number of entropy threshold failures
// number of decoders for which we have constructed the KV cache
int32_t kv_self_n_dec = 0;
// unified self-attention KV cache for all decoders
whisper_kv_cache kv_self;
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
// padded buffer for flash-attention
whisper_kv_cache kv_pad;
whisper_mel mel;
whisper_batch batch;
whisper_decoder decoders[WHISPER_MAX_DECODERS];
std::vector<ggml_backend_t> backends;
// one scheduler wrapper per graph (conv, encode, cross, decode)
// - stores meta info about the intermediate tensors into the `meta` buffers
whisper_sched sched_conv;
whisper_sched sched_encode;
whisper_sched sched_cross;
whisper_sched sched_decode;
// result of the encoder
struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr;
// helpers for GPU offloading
std::vector<float> inp_mel;
std::vector<float> inp_mask;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
#ifdef WHISPER_USE_COREML
whisper_coreml_context * ctx_coreml = nullptr;
#endif
#ifdef WHISPER_USE_OPENVINO
whisper_openvino_context * ctx_openvino = nullptr;
#endif
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
float no_speech_prob = 0.0f;
// [EXPERIMENTAL] Token-level timestamps with DTW
whisper_aheads_masks aheads_masks;
ggml_tensor * aheads_cross_QKs = nullptr;
std::vector<float> aheads_cross_QKs_data;
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
// one speech segment found by Voice Activity Detection
// NOTE(review): presumably orig_* are times in the original audio and vad_* the matching
//               times in the VAD-filtered audio; units are not visible here - confirm
struct vad_segment_info {
float orig_start;
float orig_end;
float vad_start;
float vad_end;
};
std::vector<vad_segment_info> vad_segments;
bool has_vad_segments = false;
};
// top-level context: model, vocabulary, context parameters and a pointer to the inference state
struct whisper_context {
// timing values (the `_us` suffix suggests microseconds)
int64_t t_load_us = 0;
int64_t t_start_us = 0;
ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX)
ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16)
whisper_context_params params;
whisper_model model;
whisper_vocab vocab;
// inference state; null until one is assigned
whisper_state * state = nullptr;
std::string path_model; // populated by whisper_init_from_file_with_params()
};
// process-wide settings shared by all contexts
struct whisper_global {
// We save the log callback globally
ggml_log_callback log_callback = whisper_log_callback_default;
// opaque pointer stored next to the callback
void * log_callback_user_data = nullptr;
};
// the single global instance
static whisper_global g_state;
// Read sizeof(T) bytes from `loader` into `dest`, then apply BYTESWAP_VALUE to it
// (a no-op unless WHISPER_BIG_ENDIAN is defined).
// The return value of the loader's read callback is not inspected.
template<typename T>
static void read_safe(whisper_model_loader * loader, T & dest) {
    void * const out = &dest;
    loader->read(loader->context, out, sizeof(T));

    BYTESWAP_VALUE(dest);
}
// Initialize a KV cache with `n_ctx` cells.
//
//   cache        - cache to initialize; head, size, cells, k, v, buffer and ctx_buf are overwritten
//   backend      - backend on which the K and V tensors are allocated
//   wtype        - element type of the K and V tensors
//   n_text_state - multiplied with n_text_layer*n_ctx to get the element count of each tensor
//   n_text_layer - see n_text_state
//   n_ctx        - number of cells
//
// Returns false (after logging an error) when the ggml context or the backend buffer
// cannot be created.
static bool whisper_kv_cache_init(
        struct whisper_kv_cache & cache,
        ggml_backend_t backend,
        ggml_type wtype,
        int64_t n_text_state,
        int64_t n_text_layer,
        int n_ctx) {
    const int64_t n_mem      = n_text_layer*n_ctx;
    const int64_t n_elements = n_text_state*n_mem;

    // the context only has to hold the metadata of two tensors (K and V)
    cache.ctx_buf.resize(2*ggml_tensor_overhead());

    struct ggml_init_params params = {
        /*.mem_size   =*/ cache.ctx_buf.size(),
        /*.mem_buffer =*/ cache.ctx_buf.data(),
        /*.no_alloc   =*/ true,
    };

    cache.head = 0;
    cache.size = n_ctx;

    cache.cells.clear();
    cache.cells.resize(n_ctx);

    struct ggml_context * ctx = ggml_init(params);

    if (!ctx) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__);
        return false;
    }

    cache.k = ggml_new_tensor_1d(ctx, wtype, n_elements);
    cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements);

    cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
    if (!cache.buffer) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__);
        // release the context on the error path as well, so that it is not leaked
        ggml_free(ctx);
        return false;
    }

    ggml_backend_buffer_clear(cache.buffer, 0);

    // the context itself is no longer needed; the memory it was given is cache.ctx_buf,
    // which stays owned by the cache
    ggml_free(ctx);

    return true;
}
// Release the backend buffer of the cache.
// The handle is reset afterwards so that the cache does not keep a dangling pointer.
static void whisper_kv_cache_free(struct whisper_kv_cache & cache) {
    ggml_backend_buffer_free(cache.buffer);
    cache.buffer = nullptr;
}
static bool whisper_kv_cache_find_slot(
struct whisper_kv_cache & cache,
const struct whisper_batch & batch) {
const uint32_t n_ctx = cache.size;
const uint32_t n_tokens = batch.n_tokens;
if (n_tokens > n_ctx) {
WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
return false;
}
uint32_t n_tested = 0;
while (true) {
if (cache.head + n_tokens > n_ctx) {
n_tested += n_ctx - cache.head;
cache.head = 0;
continue;
}
bool found = true;
for (uint32_t i = 0; i < n_tokens; i++) {
if (cache.cells[cache.head + i].pos >= 0) {
found = false;
cache.head += i + 1;
n_tested += i + 1;
break;
}
}
if (found) {
break;
}
if (n_tested >= n_ctx) {
//WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
return false;
}
}
for (uint32_t i = 0; i < n_tokens; i++) {
cache.cells[cache.head + i].pos = batch.pos[i];
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
}
}
return true;
}
// find how many cells are currently in use
static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
for (uint32_t i = cache.size - 1; i > 0; --i) {
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
return i + 1;
}
}
return 1;
}
// Reset the KV cache to its empty state: every cell loses its position and its
// sequence ids, the search head goes back to 0 and the backing buffer is
// filled with zeros.
static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
    const int32_t n_cells = (int32_t) cache.size;

    for (int32_t idx = 0; idx < n_cells; ++idx) {
        auto & cell = cache.cells[idx];

        cell.pos = -1;
        cell.seq_id.clear();
    }

    cache.head = 0;

    ggml_backend_buffer_clear(cache.buffer, 0);
}
// Remove a sequence from the cells whose position lies in [p0, p1).
//
//   seq_id < 0 : every sequence is removed from the matching cells
//   p0 < 0     : the range starts at position 0
//   p1 < 0     : the range has no upper bound
//
// A cell that ends up without any sequence is marked free (pos = -1). When at
// least one cell was freed, cache.head is moved to the first such cell so the
// next slot search can start there.
static void whisper_kv_cache_seq_rm(
        struct whisper_kv_cache & cache,
        whisper_seq_id seq_id,
        whisper_pos p0,
        whisper_pos p1) {
    if (p0 < 0) {
        p0 = 0;
    }
    if (p1 < 0) {
        p1 = std::numeric_limits<whisper_pos>::max();
    }

    // cache.size is used as the "nothing freed yet" marker
    uint32_t first_freed = cache.size;

    for (uint32_t idx = 0; idx < cache.size; ++idx) {
        auto & cell = cache.cells[idx];

        const bool in_range = cell.pos >= p0 && cell.pos < p1;
        if (!in_range) {
            continue;
        }

        if (seq_id < 0) {
            cell.seq_id.clear();
        } else {
            if (!cell.has_seq_id(seq_id)) {
                continue;
            }
            cell.seq_id.erase(seq_id);
        }

        if (!cell.seq_id.empty()) {
            continue; // still referenced by another sequence
        }

        cell.pos = -1;
        if (first_freed == cache.size) {
            first_freed = idx;
        }
    }

    // If we freed up a slot, set head to it so searching can start there.
    if (first_freed != cache.size) {
        cache.head = first_freed;
    }
}
static void whisper_kv_cache_seq_cp(
struct whisper_kv_cache & cache,
whisper_seq_id seq_id_src,
whisper_seq_id seq_id_dst,
whisper_pos p0,
whisper_pos p1) {
if (p0 < 0) p0 = 0;
if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
cache.head = 0;
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.insert(seq_id_dst);
}
}
}
// Padding value for the KV cache.
//
// Returns 1 unless both flash attention and GPU usage are enabled in the
// context parameters. With both enabled the value depends on the build:
// 32 when GGML_USE_METAL is defined, otherwise 256 when GGML_USE_CUDA is
// defined, otherwise 1.
static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) {
    const bool gpu_flash_attn = wctx.params.flash_attn && wctx.params.use_gpu;
    if (!gpu_flash_attn) {
        return 1u;
    }

#if defined(GGML_USE_METAL)
    return 32u;
#elif defined(GGML_USE_CUDA)
    return 256u;
#else
    return 1u;
#endif
}
// [EXPERIMENTAL] Token-level timestamps with DTW
static bool aheads_masks_init(
const whisper_context_params & cparams,
const whisper_hparams & hparams,
struct whisper_aheads_masks & aheads_masks,
ggml_backend_t backend) {
const int32_t n_text_layer = hparams.n_text_layer;
const int32_t n_head = hparams.n_text_head;
// Sanity checks
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
return false;
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
return false;
}
} else {
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
if (aheads.n_heads == 0) {
WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
return false;
}
if (aheads.heads == NULL) {
WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
return false;
}
}
for (size_t i = 0; i < aheads.n_heads; ++i) {
if (aheads.heads[i].n_text_layer >= n_text_layer) {
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer);
return false;
}
if (aheads.heads[i].n_text_layer < 0) {
WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__);
return false;
}
if (aheads.heads[i].n_head >= n_head) {
WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head);
return false;
}
if (aheads.heads[i].n_head < 0) {
WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__);
return false;
}
}
}
struct ggml_init_params params = {
/*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
aheads_masks.ctx = ggml_init(params);
if (!aheads_masks.ctx) {
WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__);
return false;
}
for (int64_t il = 0; il < n_text_layer; ++il) {
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
if (!aheads.empty()) {
aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size()));
} else {
aheads_masks.m.push_back(nullptr);
}
}
aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend);
if (!aheads_masks.buffer) {
WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__);
return false;
}
// Set data on mask tensors
// Since this must be backend agnostic, we write our desired values on mask_data,
// and send it to backend with ggml_backend_tensor_set.
// Each mask in N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment
// heads. Each row of the mask "marks" one alignment head. E.g. if some text layer
// has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask
// should read:
// 1 0 0 0 0 0 0 0 0 0
// 0 0 0 0 0 1 0 0 0 0
// 0 0 0 0 0 0 1 0 0 0
std::vector<float> mask_data;
for (int64_t il = 0; il < n_text_layer; ++il) {
if (aheads_masks.m[il] != nullptr) {
auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head);
size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1];
size_t data_size_bytes = data_size * sizeof(float);
mask_data.resize(data_size);
std::fill(mask_data.begin(), mask_data.end(), 0);
for (size_t ih = 0; ih < aheads.size(); ++ih) {
size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0]));
mask_data[pos] = 1.0f;
}
ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes);
}
}
if (aheads_masks.m.empty()) {
WHISPER_LOG_ERROR("%s: \n", __func__);
return false;
}
return true;
}
// Release the ggml context and the backend buffer of the alignment-head masks.
//
// All handles are reset afterwards: the context and buffer pointers become
// nullptr and the list of mask tensors is emptied, because those tensors were
// created in the context that has just been freed.
static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) {
    ggml_free(aheads_masks.ctx);
    ggml_backend_buffer_free(aheads_masks.buffer);

    aheads_masks.ctx    = nullptr;
    aheads_masks.buffer = nullptr;
    aheads_masks.m.clear();
}
// Total size in bytes of all alignment-head mask tensors; layers without a
// mask (nullptr entries) contribute nothing.
static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
    size_t total = 0;

    for (const auto & mask : aheads_masks.m) {
        if (mask == nullptr) {
            continue;
        }
        total += ggml_nbytes(mask);
    }

    return total;
}
// Pick a GPU device and initialize a backend for it.
//
// When params.use_gpu is set, the registered devices are scanned in order: the
// first GPU is taken as a fallback and is replaced by the GPU whose ordinal
// (counting GPU devices only) equals params.gpu_device, if there is one.
//
// Returns nullptr when no device was selected (GPU usage disabled or no GPU
// device registered) or when the backend failed to initialize.
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

    whisper_load_backends();

    ggml_backend_dev_t selected = nullptr;

    if (params.use_gpu) {
        int n_gpu_seen = 0;

        const size_t n_dev = ggml_backend_dev_count();
        for (size_t idx = 0; idx < n_dev; ++idx) {
            ggml_backend_dev_t candidate = ggml_backend_dev_get(idx);

            if (ggml_backend_dev_type(candidate) != GGML_BACKEND_DEVICE_TYPE_GPU) {
                continue;
            }

            if (n_gpu_seen == 0 || n_gpu_seen == params.gpu_device) {
                selected = candidate;
            }

            // no need to look further once the requested ordinal has been passed
            if (++n_gpu_seen > params.gpu_device) {
                break;
            }
        }
    }

    if (selected == nullptr) {
        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
        return nullptr;
    }

    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(selected));

    ggml_backend_t backend = ggml_backend_dev_init(selected, nullptr);
    if (backend == nullptr) {
        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(selected));
    }

    return backend;
}
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
std::vector<ggml_backend_t> result;
ggml_backend_t backend_gpu = whisper_backend_init_gpu(params);
if (backend_gpu) {
result.push_back(backend_gpu);
}
// ACCEL backends
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
continue;
}
result.push_back(backend);
}
}
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (backend_cpu == nullptr) {
throw std::runtime_error("failed to initialize CPU backend");
}
result.push_back(backend_cpu);
return result;
}
// list of (device, buffer type) pairs, ordered by priority
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;

// Build the list of buffer types that may hold model weights.
//
// Prio order: GPU -> CPU Extra -> CPU
//
// The GPU entries follow the same device selection rule as the GPU backend
// initialization: the first GPU and the GPU whose ordinal equals
// params.gpu_device are considered. The "CPU Extra" entries are obtained
// through the optional "ggml_backend_dev_get_extra_bufts" entry point of the
// CPU backend registry.
//
// Throws std::runtime_error when no CPU device is registered.
static buft_list_t make_buft_list(whisper_context_params & params) {
    buft_list_t buft_list;

    // GPU
    if (params.use_gpu) {
        int cnt = 0;
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
                if (cnt == 0 || cnt == params.gpu_device) {
                    auto * buft = ggml_backend_dev_buffer_type(dev);
                    if (buft) {
                        buft_list.emplace_back(dev, buft);
                    }
                }

                if (++cnt > params.gpu_device) {
                    break;
                }
            }
        }
    }

    // CPU Extra
    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    if (cpu_dev == nullptr) {
        // without this check the registry lookup below would be handed a null device
        throw std::runtime_error("no CPU backend found");
    }

    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
    auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
    if (get_extra_bufts_fn) {
        // the returned array is walked until its null terminator
        ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev);
        while (extra_bufts && *extra_bufts) {
            buft_list.emplace_back(cpu_dev, *extra_bufts);
            ++extra_bufts;
        }
    }

    // CPU
    buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type());

    return buft_list;
}
// Check whether weight tensor w can live in buffer type buft of device dev when
// it is used by operation op.
//
// GPU devices and the default CPU buffer type are accepted for every operator.
// For any other buffer type only GGML_OP_MUL_MAT is considered: a throw-away
// matrix multiplication is built around w and the device is asked whether it
// supports it. All other operators are reported as unsupported.
//
// Throws std::runtime_error when the temporary ggml context cannot be created.
static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
    // GPU and default CPU backend support all operators
    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
        return true;
    }
    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type()) {
        return true;
    }

    // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
    if (op != GGML_OP_MUL_MAT) {
        return false;
    }

    // metadata-only context: room for the second operand and the result tensor
    ggml_init_params params = {
        /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx_ptr { ggml_init(params) };
    if (!ctx_ptr) {
        throw std::runtime_error("failed to create ggml context");
    }
    ggml_context * ctx = ctx_ptr.get();

    const int64_t n_ctx = hparams.n_audio_ctx;

    ggml_tensor * b         = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
    ggml_tensor * op_tensor = ggml_mul_mat(ctx, w, b);

    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
    GGML_ASSERT(w->buffer == nullptr);
    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);

    const bool supported = ggml_backend_dev_supports_op(dev, op_tensor);

    // detach the dummy buffer again so w is left as it was found
    ggml_backend_buffer_free(w->buffer);
    w->buffer = nullptr;

    return supported;
}
// Return the first buffer type of buft_list (which is expected to be ordered by
// priority) that supports weight tensor w for operation op, or nullptr when
// none of them does.
//
// The list is taken by const reference: the function only reads it and is
// invoked once per weight tensor, so copying the vector on each call would be
// wasted work.
static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, const buft_list_t & buft_list) {
    GGML_ASSERT(!buft_list.empty());

    for (const auto & p : buft_list) {
        ggml_backend_dev_t         dev  = p.first;
        ggml_backend_buffer_type_t buft = p.second;

        if (weight_buft_supported(hparams, w, op, buft, dev)) {
            return buft;
        }
    }

    return nullptr;
}
// load the model from a ggml file
//
// file format:
//
//   - hparams
//   - pre-computed mel filters
//   - vocab
//   - weights
//
// see the convert-pt-to-ggml.py script for details
//
// The data is pulled through the callbacks of `loader` and stored in
// wctx.model / wctx.vocab; the load time is recorded in wctx.t_load_us.
//
// Returns false (after logging the reason) when the file content is invalid:
// bad magic, bad ftype, unknown tensor, malformed tensor header, shape or size
// mismatch, missing tensors, or a failed buffer allocation.
// Throws std::runtime_error when a per-buffer-type ggml context cannot be
// created or when no compatible buffer type exists for a tensor.
static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) {
    WHISPER_LOG_INFO("%s: loading model\n", __func__);

    const int64_t t_start_us = ggml_time_us();

    wctx.t_start_us = t_start_us;

    auto & model = wctx.model;
    auto & vocab = wctx.vocab;

    // verify magic
    {
        uint32_t magic;
        read_safe(loader, magic);
        if (magic != GGML_FILE_MAGIC) {
            WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
            return false;
        }
    }

    //load hparams
    {
        auto & hparams = model.hparams;

        read_safe(loader, hparams.n_vocab);
        read_safe(loader, hparams.n_audio_ctx);
        read_safe(loader, hparams.n_audio_state);
        read_safe(loader, hparams.n_audio_head);
        read_safe(loader, hparams.n_audio_layer);
        read_safe(loader, hparams.n_text_ctx);
        read_safe(loader, hparams.n_text_state);
        read_safe(loader, hparams.n_text_head);
        read_safe(loader, hparams.n_text_layer);
        read_safe(loader, hparams.n_mels);
        read_safe(loader, hparams.ftype);

        assert(hparams.n_text_state == hparams.n_audio_state);

        // the model type is derived from the number of encoder layers
        std::string mver = "";

        if (hparams.n_audio_layer == 4) {
            model.type = e_model::MODEL_TINY;
        }

        if (hparams.n_audio_layer == 6) {
            model.type = e_model::MODEL_BASE;
        }

        if (hparams.n_audio_layer == 12) {
            model.type = e_model::MODEL_SMALL;
        }

        if (hparams.n_audio_layer == 24) {
            model.type = e_model::MODEL_MEDIUM;
        }

        if (hparams.n_audio_layer == 32) {
            model.type = e_model::MODEL_LARGE;

            if (hparams.n_vocab == 51866) {
                mver = " v3";
            }
        }

        // ftype packs the quantization version on top of the actual file type
        const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR;

        hparams.ftype %= GGML_QNT_VERSION_FACTOR;

        // for the big tensors, we have the option to store the data in 16-bit floats or quantized
        // in order to save memory and also to speed up the computation
        wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype));
        if (wctx.wtype == GGML_TYPE_COUNT) {
            WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype);
            return false;
        }

        WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
        WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
        WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
        WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
        WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
        WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
        WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state);
        WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head);
        WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
        WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels);
        WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype);
        WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr);
        WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str());
    }

    // load mel filters
    {
        auto & filters = wctx.model.filters;

        read_safe(loader, filters.n_mel);
        read_safe(loader, filters.n_fft);

        filters.data.resize(filters.n_mel * filters.n_fft);
        loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float));
        BYTESWAP_FILTERS(filters);
    }

    // load vocab
    {
        int32_t n_vocab = 0;
        read_safe(loader, n_vocab);

        //if (n_vocab != model.hparams.n_vocab) {
        //    WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n",
        //            __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
        //    return false;
        //}

        std::string word;
        std::vector<char> tmp;

        tmp.reserve(128);

        for (int i = 0; i < n_vocab; i++) {
            uint32_t len;
            read_safe(loader, len);

            if (len > 0) {
                tmp.resize(len);
                loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer
                word.assign(&tmp[0], tmp.size());
            } else {
                // seems like we have an empty-string token in multi-language models (i = 50256)
                //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i);
                word = "";
            }

            vocab.token_to_id[word] = i;
            vocab.id_to_token[i] = word;

            //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
        }

        vocab.n_vocab = model.hparams.n_vocab;
        if (vocab.is_multilingual()) {
            vocab.token_eot++;
            vocab.token_sot++;

            // account for variable number of language tokens
            const int dt = vocab.num_languages() - 98;

            vocab.token_translate  += dt;
            vocab.token_transcribe += dt;
            vocab.token_solm       += dt;
            vocab.token_prev       += dt;
            vocab.token_nosp       += dt;
            vocab.token_not        += dt;
            vocab.token_beg        += dt;
        }

        // the file may list fewer tokens than the model uses: synthesize names for the rest
        if (n_vocab < model.hparams.n_vocab) {
            WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
            for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
                if (i > vocab.token_beg) {
                    word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
                } else if (i == vocab.token_eot) {
                    word = "[_EOT_]";
                } else if (i == vocab.token_sot) {
                    word = "[_SOT_]";
                } else if (i == vocab.token_translate) {
                    word = "[_TRANSLATE_]";
                } else if (i == vocab.token_transcribe) {
                    word = "[_TRANSCRIBE_]";
                } else if (i == vocab.token_solm) {
                    word = "[_SOLM_]";
                } else if (i == vocab.token_prev) {
                    word = "[_PREV_]";
                } else if (i == vocab.token_nosp) {
                    word = "[_NOSP_]";
                } else if (i == vocab.token_not) {
                    word = "[_NOT_]";
                } else if (i == vocab.token_beg) {
                    word = "[_BEG_]";
                } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
                    word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
                } else {
                    word = "[_extra_token_" + std::to_string(i) + "]";
                }
                vocab.token_to_id[word] = i;
                vocab.id_to_token[i] = word;
            }
        }

        WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages());
    }

    const ggml_type wtype = wctx.wtype;
    const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16; // conv type

    const auto & hparams = model.hparams;

    const int n_audio_layer = hparams.n_audio_layer;
    const int n_text_layer  = hparams.n_text_layer;

    const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer;

    // one metadata-only ggml context per buffer type, created on first use
    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
    auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
        auto it = ctx_map.find(buft);
        if (it == ctx_map.end()) {
            ggml_init_params params = {
                /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
                /*.mem_buffer =*/ nullptr,
                /*.no_alloc   =*/ true,
            };

            ggml_context * ctx = ggml_init(params);
            if (!ctx) {
                throw std::runtime_error("failed to create ggml context");
            }

            ctx_map[buft] = ctx;
            model.ctxs.emplace_back(ctx);

            return ctx;
        }

        return it->second;
    };

    // Create a list of available bufts, in priority order
    buft_list_t buft_list = make_buft_list(wctx.params);

    // pick a buffer type for the tensor described by `meta`, duplicate it into the
    // context of that buffer type and register it in model.tensors under its name
    auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * {
        ggml_op op = ASR_TENSOR_INFO.at(type);
        ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
        if (!buft) {
            throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type)));
        }

        ggml_context * ctx = get_ctx(buft);
        ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);

        model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor;

        return tensor;
    };

    // prepare tensors for the weights
    {
        ggml_init_params params = {
            /*.mem_size   =*/ n_tensors * ggml_tensor_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };

        // temporary context for the tensor templates; owned by a smart pointer so it
        // is released at the end of this scope even if create_tensor throws
        ggml_context_ptr ctx_meta { ggml_init(params) };
        if (!ctx_meta) {
            WHISPER_LOG_ERROR("%s: failed to create ggml context for the tensor metadata\n", __func__);
            return false;
        }
        ggml_context * ctx = ctx_meta.get();

        const int n_vocab = hparams.n_vocab;

        const int n_audio_ctx   = hparams.n_audio_ctx;
        const int n_audio_state = hparams.n_audio_state;

        const int n_text_ctx   = hparams.n_text_ctx;
        const int n_text_state = hparams.n_text_state;

        const int n_mels = hparams.n_mels;

        model.layers_encoder.resize(n_audio_layer);
        model.layers_decoder.resize(n_text_layer);

        // encoder
        model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));

        model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
        model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

        model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
        model.e_conv_2_b = create_tensor(ASR_TENSOR_CONV2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

        model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));
        model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state));

        for (int i = 0; i < n_audio_layer; ++i) {
            auto & layer = model.layers_encoder[i];

            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i);
            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i);

            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i);
            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);

            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);

            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i);
            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i);
        }

        // decoder
        model.d_pe = create_tensor(ASR_TENSOR_DEC_POS_EMBD, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx));

        model.d_te = create_tensor(ASR_TENSOR_DEC_TOKEN_EMBD_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab));

        model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));
        model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state));

        for (int i = 0; i < n_text_layer; ++i) {
            auto & layer = model.layers_decoder[i];

            layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
            layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i);
            layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i);

            layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i);
            layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
            layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

            layer.attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.cross_attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
            layer.cross_attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.cross_attn_q_w = create_tensor(ASR_TENSOR_ATTN_QUERY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.cross_attn_q_b = create_tensor(ASR_TENSOR_ATTN_QUERY_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.cross_attn_k_w = create_tensor(ASR_TENSOR_ATTN_KEY_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);

            layer.cross_attn_v_w = create_tensor(ASR_TENSOR_ATTN_VALUE_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.cross_attn_v_b = create_tensor(ASR_TENSOR_ATTN_VALUE_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);

            layer.cross_attn_ln_1_w = create_tensor(ASR_TENSOR_ATTN_OUT_WEIGHT, ASR_SYSTEM_CROSS, ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state), i);
            layer.cross_attn_ln_1_b = create_tensor(ASR_TENSOR_ATTN_OUT_BIAS, ASR_SYSTEM_CROSS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i);
        }
    }

    // allocate tensors in the backend buffers
    for (auto & p : ctx_map) {
        ggml_backend_buffer_type_t buft = p.first;
        ggml_context * ctx = p.second;
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (!buf) {
            // without a buffer the tensors of this context have no storage and the
            // weight loading below would dereference a null tensor->buffer
            WHISPER_LOG_ERROR("%s: failed to allocate the backend buffer for the model tensors\n", __func__);
            return false;
        }

        model.buffers.emplace_back(buf);

        size_t size_main = ggml_backend_buffer_get_size(buf);
        WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
    }

    // load weights
    {
        size_t total_size = 0;

        model.n_loaded = 0;

        std::vector<char> read_buf;

        while (true) {
            int32_t n_dims = 0;
            int32_t length = 0;
            int32_t ttype  = 0;

            read_safe(loader, n_dims);
            read_safe(loader, length);
            read_safe(loader, ttype);

            if (loader->eof(loader->context)) {
                break;
            }

            // the header comes from the file: reject values that would overflow ne[]
            // or request a name buffer of negative size
            if (n_dims < 0 || n_dims > 4) {
                WHISPER_LOG_ERROR("%s: invalid tensor header in model file (n_dims = %d)\n", __func__, n_dims);
                return false;
            }

            if (length < 0) {
                WHISPER_LOG_ERROR("%s: invalid tensor header in model file (name length = %d)\n", __func__, length);
                return false;
            }

            int32_t nelements = 1;
            int32_t ne[4] = { 1, 1, 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                read_safe(loader, ne[i]);
                nelements *= ne[i];
            }

            // read the tensor name straight into the string
            std::string name(static_cast<size_t>(length), '\0');
            loader->read(loader->context, &name[0], name.size());

            const auto it_tensor = model.tensors.find(name);
            if (it_tensor == model.tensors.end()) {
                WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
                return false;
            }

            auto tensor = it_tensor->second;

            if (ggml_nelements(tensor) != nelements) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
                WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
                        __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                return false;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
                // "got" is the shape stored in the file, "expected" the shape of the model tensor
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
                        __func__, name.data(), ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
                return false;
            }

            // the type id is used to look up the element size, so it has to be a valid ggml type
            if (ttype < 0 || ttype >= GGML_TYPE_COUNT) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has invalid type %d in model file\n", __func__, name.data(), ttype);
                return false;
            }

            const size_t bpe = ggml_type_size(ggml_type(ttype));

            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), nelements*bpe, ggml_nbytes(tensor));
                return false;
            }

            if (ggml_backend_buffer_is_host(tensor->buffer)) {
                // for the CPU and Metal backend, we can read directly into the tensor
                loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
                BYTESWAP_TENSOR(tensor);
            } else {
                // read into a temporary buffer first, then copy to device memory
                read_buf.resize(ggml_nbytes(tensor));

                loader->read(loader->context, read_buf.data(), read_buf.size());

                ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
            }

            total_size += ggml_nbytes(tensor);
            model.n_loaded++;
        }

        WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);

        if (model.n_loaded == 0) {
            WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
        } else if (model.n_loaded != (int) model.tensors.size()) {
            WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
            return false;
        }
    }

    for (auto & buf : model.buffers) {
        ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
    }

    wctx.t_load_us = ggml_time_us() - t_start_us;

    return true;
}
// Tell whether the encoder runs outside of ggml: true when the build has
// Core ML support and the state holds a Core ML context, or when the build has
// OpenVINO support and the state holds an OpenVINO context.
static bool whisper_encode_external(const whisper_state & wstate) {
    GGML_UNUSED(wstate);

#ifdef WHISPER_USE_COREML
    if (wstate.ctx_coreml != nullptr) {
        return true;
    }
#endif

#ifdef WHISPER_USE_OPENVINO
    if (wstate.ctx_openvino != nullptr) {
        return true;
    }
#endif

    return false;
}
// Build the compute graph for the convolutional front-end of the encoder.
//
// The graph input is the "mel" tensor of shape [2*n_ctx, n_mels]. With the
// built-in encoder the graph applies two 1D convolutions, each followed by a
// bias add and GELU, and the result is stored in wstate.embd_conv. With an
// external encoder no convolution is built: an input tensor "embd_enc" of shape
// [n_state, n_ctx] is created for the external encoder to fill and is stored in
// wstate.embd_enc.
//
// The graph is created inside wstate.sched_conv.meta, which is why it can be
// returned after the ggml context has been freed.
static struct ggml_cgraph * whisper_build_graph_conv(
        whisper_context & wctx,
          whisper_state & wstate) {
    const auto & model   = wctx.model;
    const auto & hparams = model.hparams;

    const int n_ctx   = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
    const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state);

    const int n_mels = hparams.n_mels;

    struct ggml_init_params params = {
        /*.mem_size   =*/ wstate.sched_conv.meta.size(),
        /*.mem_buffer =*/ wstate.sched_conv.meta.data(),
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx0 = ggml_init(params);

    ggml_cgraph * graph = ggml_new_graph(ctx0);

    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
    ggml_set_name(mel, "mel");
    ggml_set_input(mel);

    struct ggml_tensor * out = nullptr;

    if (whisper_encode_external(wstate)) {
        ggml_build_forward_expand(graph, mel);

        out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
        ggml_set_input(out); // the external encoder will write into this tensor
        ggml_set_name(out, "embd_enc");

        wstate.embd_enc = out;
    } else {
        // first convolution + gelu
        struct ggml_tensor * conv_1 = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
        conv_1 = ggml_add(ctx0, conv_1, model.e_conv_1_b);
        conv_1 = ggml_gelu(ctx0, conv_1);

        // second convolution + gelu
        struct ggml_tensor * conv_2 = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, conv_1, 2, 1);
        conv_2 = ggml_add(ctx0, conv_2, model.e_conv_2_b);
        conv_2 = ggml_gelu(ctx0, conv_2);

        out = conv_2;
        ggml_set_name(out, "embd_conv");

        wstate.embd_conv = out;
    }

    ggml_set_output(out);

    ggml_build_forward_expand(graph, out);

    ggml_free(ctx0);

    return graph;
}
// build the compute graph for the transformer encoder
//
// takes wstate.embd_conv as input, adds the positional embedding and runs the result through the
// n_audio_layer encoder layers (self-attention + feed-forward, each followed by a residual add)
// and a final layer norm; the resulting tensor is stored in wstate.embd_enc
//
// - wctx:   the model (weights, hyper-parameters, flash_attn / itype settings)
// - wstate: the state; provides the graph meta buffer (sched_encode.meta) and the kv_pad scratch cache
//
// the context is created with no_alloc = true, so only tensor metadata is built here
// NOTE(review): the returned graph is placed in wstate.sched_encode.meta, which presumably keeps it
//               usable after ggml_free(ctx0) - confirm against the ggml_init/ggml_free semantics
static struct ggml_cgraph * whisper_build_graph_encoder(
whisper_context & wctx,
whisper_state & wstate) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
// an experimental audio context override (exp_n_audio_ctx > 0) takes precedence over the model value
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
const int n_layer = hparams.n_audio_layer;
// state size per attention head
const int n_state_head = n_state/n_head;
auto & kv_pad = wstate.kv_pad;
WHISPER_ASSERT(!!kv_pad.buffer);
// n_ctx padded via GGML_PAD(.., 256) - used for the K/V views in the flash-attn path
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.sched_encode.meta.size(),
/*.mem_buffer =*/ wstate.sched_encode.meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
// the encoder input is a view of the conv graph output
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
// attention scale: 1/sqrt(n_state_head)
const float KQscale = 1.0f/sqrtf(float(n_state_head));
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
//const int n_iter = 1500/n_ctx;
//iter = (iter + 1) % n_iter;
//if (iter == 0) {
// memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k));
// memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v));
//}
// iter is never modified in this function, so e_pe_offset below always evaluates to 0
static int iter = 0;
const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe);
const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter;
// use only the first n_ctx rows of the positional embedding
struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset);
cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur)));
// ===================================================================
// original:
//cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));
struct ggml_tensor * inpL = cur;
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_encoder[il];
// norm
{
cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, layer.attn_ln_0_w),
layer.attn_ln_0_b);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
layer.attn_k_w,
cur);
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
layer.attn_v_w,
cur);
Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b);
// ------
// split Q into heads: [n_state_head, n_head, n_ctx] with dims 1 and 2 swapped by the permute
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
0, 2, 1, 3);
if (wctx.params.flash_attn) {
// copy K and V of this layer to the start of the kv_pad buffers; the same buffers are reused by every layer
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0)));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0)));
// view the buffers per head with n_ctx_pad rows
// NOTE(review): rows in [n_ctx, n_ctx_pad) are not written here - presumably kv_pad is cleared elsewhere; confirm
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_pad.k,
n_state_head, n_ctx_pad, n_head,
ggml_element_size(kv_pad.k)*n_state,
ggml_element_size(kv_pad.k)*n_state_head,
0);
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_pad.v,
n_state_head, n_ctx_pad, n_head,
ggml_element_size(kv_pad.v)*n_state,
ggml_element_size(kv_pad.v)*n_state_head,
0);
cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
} else {
// K is cast to wctx.itype and split into heads the same way as Q
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_cast(ctx0,
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
wctx.itype),
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// softmax with the scale applied inside the op; no mask is used in the encoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
struct ggml_tensor * V =
ggml_cast(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
Vcur,
n_state_head, n_head, n_ctx),
1, 2, 0, 3),
wctx.itype);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
// merge the heads back into a contiguous [n_state, n_ctx] tensor
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
}
}
// projection
{
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
cur = ggml_add(ctx0, cur, layer.attn_ln_1_b);
}
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, layer.mlp_ln_w),
layer.mlp_ln_b);
}
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
cur = ggml_add(ctx0, cur, layer.mlp_0_b);
// GELU activation
cur = ggml_gelu(ctx0, cur);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
}
// residual add; the result is the input of the next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
cur = inpL;
// norm
{
cur = ggml_norm(ctx0, cur, hparams.eps);
// cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.e_ln_w),
model.e_ln_b);
}
ggml_build_forward_expand(gf, cur);
// the encoder output is read later by whisper_build_graph_cross
wstate.embd_enc = cur;
//ggml_graph_print(gf);
////////////////////////////////////////////////////////////////////////////
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1e6,
// wstate.get_buf_max_mem(0)/1e6,
// wstate.get_buf_max_mem(1)/1e6,
// wstate.get_buf_max_mem(2)/1e6,
// wstate.get_buf_max_mem(3)/1e6);
ggml_free(ctx0);
return gf;
}
// pre-compute cross-attention memory
//
// for every decoder layer, projects the encoder output (wstate.embd_enc) with the layer's
// cross-attention K and V weights and copies the results into wstate.kv_cross.k / wstate.kv_cross.v
//
// - wctx:   the model (decoder layer weights, flash_attn setting)
// - wstate: the state; provides embd_enc, the kv_cross cache and the graph meta buffer (sched_cross.meta)
//
// the cache layout depends on wctx.params.flash_attn:
// - flash-attn:    each layer occupies n_ctx_pad rows of n_state elements, V is stored like K
// - no flash-attn: each layer occupies n_ctx rows of n_state elements, V is stored transposed
static struct ggml_cgraph * whisper_build_graph_cross(
whisper_context & wctx,
whisper_state & wstate) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
// state size per attention head
const int n_state_head = n_state/n_head;
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.sched_cross.meta.size(),
/*.mem_buffer =*/ wstate.sched_cross.meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
// the input is a view of the encoder output
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
// the cached keys are pre-scaled by n_state_head^-0.25
const float Kscale = pow(float(n_state_head), -0.25);
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
// note: no bias is added to the cross-attention keys
struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale(ctx0, Kcross, Kscale);
struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);
Vcross = ggml_add(ctx0,
Vcross,
layer.cross_attn_v_b);
// destination views into the cross-attention cache for layer il
struct ggml_tensor * k;
struct ggml_tensor * v;
if (wctx.params.flash_attn) {
// layers are spaced n_ctx_pad rows apart, only the first n_ctx rows of each layer are written
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad));
v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx,
(ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad));
} else {
// V is stored transposed: n_state rows of n_ctx elements per layer
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx,
(ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
}
// the copies are the only outputs of this graph
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
}
//ggml_graph_print(gf);
ggml_free(ctx0);
return gf;
}
// evaluate the encoder with the given state
//
// given audio recording (more specifically, its log mel spectrogram in wstate.mel), runs forward pass of
// the encoder part of the transformer model; the results are kept in the state (wstate)
//
// three stages are run in order:
//   1. conv    - or, when whisper_encode_external(wstate) is true, the Core ML / OpenVINO encoder
//   2. encoder - skipped when an external encoder is used
//   3. cross   - pre-computation of the cross-attention K/V cache
//
// - wctx:                the model
// - wstate:              the state of the encoder
// - mel_offset:          offset in the mel spectrogram (i.e. audio offset)
// - n_threads:           number of threads to use
// - abort_callback:      optional callback, invoked once after all stages have completed
// - abort_callback_data: user data passed to abort_callback
//
// returns false if allocating or computing any of the graphs fails, or if abort_callback returns true
//
static bool whisper_encode_internal(
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
const int n_threads,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
// conv
{
auto & sched = wstate.sched_conv.sched;
ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate);
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
// should never happen as we pre-allocate the memory
return false;
}
struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel");
// set the input
{
const auto & mel_inp = wstate.mel;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
assert(mel->type == GGML_TYPE_F32);
assert(mel_inp.n_mel == wctx.model.hparams.n_mels);
// stage the input window in a host buffer, zero-filled so that frames past the end of the audio stay 0
wstate.inp_mel.resize(ggml_nelements(mel));
float * dst = wstate.inp_mel.data();
memset(dst, 0, ggml_nbytes(mel));
// [i0, i1) is the window of up to 2*n_ctx mel frames starting at mel_offset, clamped to the available frames
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
// copy row by row: the source rows are n_len frames long, the destination rows 2*n_ctx frames
for (int j = 0; j < mel_inp.n_mel; ++j) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float));
}
if (!whisper_encode_external(wstate)) {
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
return false;
}
} else {
// the external encoder reads the mel input and writes directly into wstate.embd_enc
#if defined(WHISPER_USE_COREML)
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data);
#elif defined(WHISPER_USE_OPENVINO)
whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc);
#endif
}
}
// encoder
if (!whisper_encode_external(wstate)) {
auto & sched = wstate.sched_encode.sched;
ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate);
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
// should never happen as we pre-allocate the memory
return false;
}
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
return false;
}
}
// cross
{
auto & sched = wstate.sched_cross.sched;
ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
// should never happen as we pre-allocate the memory
return false;
}
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
return false;
}
}
// timing statistics are only updated when all stages succeeded
wstate.t_encode_us += ggml_time_us() - t_start_us;
wstate.n_encode++;
// the abort callback is only consulted here, after the work has been done
return !(abort_callback && abort_callback(abort_callback_data));
}
// build the compute graph for the text decoder
//
// - wctx:   the model
// - wstate: the state; provides the self-attention KV cache (kv_self), the cross-attention cache (kv_cross)
//           and the graph meta buffer (sched_decode.meta)
// - batch:  the batch to decode; only batch.n_tokens is read here - the "embd", "position" and "KQ_mask"
//           input tensors are created by this function but not filled
// - save_alignment_heads_QKs: when set (and wctx.params.dtw_token_timestamps is enabled), the cross-attention
//           weights of the alignment heads are added to the graph and stored in wstate.aheads_cross_QKs
// - worst_case: build the graph for n_kv = kv_self.size with the batch placed at the end of the cache,
//           instead of using the current kv_self.n / kv_self.head
//
// the logits tensor is the last tensor expanded into the graph
static struct ggml_cgraph * whisper_build_graph_decoder(
whisper_context & wctx,
whisper_state & wstate,
const whisper_batch & batch,
bool save_alignment_heads_QKs,
bool worst_case) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
auto & kv_self = wstate.kv_self;
WHISPER_ASSERT(!!kv_self.buffer);
// note: here n_ctx is the size of the self-attention KV cache, not the text context of the model
const int n_ctx = kv_self.size;
const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head;
const int n_layer = hparams.n_text_layer;
// state size per attention head
const int n_state_head = n_state/n_head;
const int n_tokens = batch.n_tokens;
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256);
// n_kv:    number of cache cells that are attended to
// kv_head: first cache cell that the new tokens are written to
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
struct ggml_init_params params = {
/*.mem_size =*/ wstate.sched_decode.meta.size(),
/*.mem_buffer =*/ wstate.sched_decode.meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
// graph inputs - looked up by name and filled by the caller after the graph has been allocated
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_name(embd, "embd");
ggml_set_input(embd);
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_name(position, "position");
ggml_set_input(position);
// n_state_head^-0.25: in self-attention it is applied to both Q and K, so their product is scaled by n_state_head^-0.5
const float KQscale = pow(float(n_state_head), -0.25);
// the mask has n_kv columns and the number of rows is n_tokens padded to GGML_KQ_MASK_PAD
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_set_input(KQ_mask);
// F16 copy of the mask, used by the flash-attn path
struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16);
// token encoding + position encoding
struct ggml_tensor * cur =
ggml_add(ctx0,
ggml_get_rows(ctx0, model.d_te, embd),
ggml_get_rows(ctx0, model.d_pe, position));
struct ggml_tensor * inpL = cur;
// [EXPERIMENTAL] Token-level timestamps with DTW
// accumulates the masked cross-attention weights of all layers that have an alignment heads mask
struct ggml_tensor * aheads_cross_QKs = nullptr;
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_decoder[il];
// norm
{
cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
cur,
layer.attn_ln_0_w),
layer.attn_ln_0_b);
}
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
Qcur = ggml_add(ctx0,
Qcur,
layer.attn_q_b);
Qcur = ggml_scale(ctx0, Qcur, KQscale);
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
layer.attn_k_w,
cur);
Kcur = ggml_scale(ctx0, Kcur, KQscale);
// store key and value to memory
{
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
layer.attn_v_w,
cur);
Vcur = ggml_add(ctx0,
Vcur,
layer.attn_v_b);
// destination views into the self-attention cache: layer il, cells [kv_head, kv_head + n_tokens)
struct ggml_tensor * k;
struct ggml_tensor * v;
if (wctx.params.flash_attn) {
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state,
(ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head));
} else {
// without flash-attn V is stored transposed: n_state rows of n_ctx elements per layer
Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state,
(ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
// ------
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);
// view of the first n_kv cached keys of layer il, split into heads
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_state_head, n_kv, n_head,
ggml_element_size(kv_self.k)*n_state,
ggml_element_size(kv_self.k)*n_state_head,
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
if (wctx.params.flash_attn) {
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_state_head, n_kv, n_head,
ggml_element_size(kv_self.v)*n_state,
ggml_element_size(kv_self.v)*n_state_head,
ggml_element_size(kv_self.v)*n_state*n_ctx*il);
// scale is 1.0f because Q and K have already been scaled above
cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f);
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
} else {
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
// scale is 1.0f because Q and K have already been scaled above
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f);
// view of the transposed V cache of layer il
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_kv, n_state_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
}
}
// projection
{
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
cur = ggml_add(ctx0,
cur,
layer.attn_ln_1_b);
}
// add the input
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
// norm
{
cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
cur,
layer.cross_attn_ln_0_w),
layer.cross_attn_ln_0_b);
}
// cross-attention
{
// note: Q is not scaled here - KQscale is passed to the attention ops below instead
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.cross_attn_q_w,
cur);
Qcur = ggml_add(ctx0,
Qcur,
layer.cross_attn_q_b);
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
0, 2, 1, 3);
// the views below must match the cache layout written by whisper_build_graph_cross
if (wctx.params.flash_attn) {
// flash-attn layout: layers are n_audio_ctx_pad rows apart
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_audio_ctx_pad, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state_head,
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il);
struct ggml_tensor * Vcross =
ggml_view_3d(ctx0, wstate.kv_cross.v,
n_state_head, n_audio_ctx_pad, n_head,
ggml_element_size(wstate.kv_cross.v)*n_state,
ggml_element_size(wstate.kv_cross.v)*n_state_head,
ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il);
cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f);
cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens);
} else {
// non flash-attn layout: layers are n_audio_ctx rows apart and V is stored transposed
struct ggml_tensor * Kcross =
ggml_view_3d(ctx0, wstate.kv_cross.k,
n_state_head, n_audio_ctx, n_head,
ggml_element_size(wstate.kv_cross.k)*n_state,
ggml_element_size(wstate.kv_cross.k)*n_state_head,
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
struct ggml_tensor * Vcross =
ggml_view_3d(ctx0, wstate.kv_cross.v,
n_audio_ctx, n_state_head, n_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
// ------
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
// [EXPERIMENTAL] Token-level timestamps with DTW
// note: the attention weights are only collected in this (non flash-attn) branch
if (wctx.params.dtw_token_timestamps) {
if (wstate.aheads_masks.m[il] != nullptr) {
// flatten the first two dims so that the heads can be selected with a matrix multiplication by the layer mask
struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]);
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs);
aheads_KQs = ggml_transpose(ctx0, aheads_KQs);
aheads_KQs = ggml_cont(ctx0, aheads_KQs);
aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]);
// append the result of this layer along dim 2
if (aheads_cross_QKs == NULL) {
aheads_cross_QKs = aheads_KQs;
} else {
aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2);
}
}
}
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
}
}
// projection
{
cur = ggml_mul_mat(ctx0,
layer.cross_attn_ln_1_w,
cur);
cur = ggml_add(ctx0,
cur,
layer.cross_attn_ln_1_b);
}
// add the input
cur = ggml_add(ctx0, cur, inpCA);
struct ggml_tensor * inpFF = cur;
// feed-forward network
{
// norm
{
cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
cur,
layer.mlp_ln_w),
layer.mlp_ln_b);
}
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
cur = ggml_add(ctx0,
cur,
layer.mlp_0_b);
// GELU activation
cur = ggml_gelu(ctx0, cur);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
cur = ggml_add(ctx0,
cur,
layer.mlp_1_b);
}
// residual add; the result is the input of the next layer
inpL = ggml_add(ctx0, cur, inpFF);
}
cur = inpL;
// norm
{
cur = ggml_norm(ctx0, cur, hparams.eps);
// cur = d_ln_w*cur + d_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
cur,
model.d_ln_w),
model.d_ln_b);
}
// compute logits only for the last token
// comment this line to compute logits for all n_tokens
// might be useful in the future
//cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
// the token embedding matrix (model.d_te) is reused as the output projection
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
// [EXPERIMENTAL] Token-level timestamps with DTW
if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) {
aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs);
aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs);
if (save_alignment_heads_QKs) {
ggml_build_forward_expand(gf, aheads_cross_QKs);
wstate.aheads_cross_QKs = aheads_cross_QKs;
}
}
// the logits are expanded last - whisper_decode_internal fetches them with ggml_graph_node(gf, -1)
ggml_build_forward_expand(gf, logits);
ggml_free(ctx0);
return gf;
}
// evaluate the decoder
//
// given text prompt + audio features -> computes the logits for the tokens of the batch
//
// - wctx:                     the model
// - wstate:                   the state; provides the KV caches and receives the logits (wstate.logits)
// - batch:                    the tokens to decode, with their positions, sequence ids and logits flags
// - n_threads:                number of threads to use
// - save_alignment_heads_QKs: forwarded to whisper_build_graph_decoder
// - abort_callback:           optional callback, invoked once after the decoding has completed
// - abort_callback_data:      user data passed to abort_callback
//
// returns false if no KV cache slot is found for the batch, if allocating or computing the graph fails,
// or if abort_callback returns true
//
static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
const whisper_batch & batch,
const int n_threads,
bool save_alignment_heads_QKs,
ggml_abort_callback abort_callback,
void * abort_callback_data) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab;
const int n_tokens = batch.n_tokens;
auto & logits_out = wstate.logits;
struct ggml_tensor * logits;
// find KV slot for the batch
{
auto & kv_self = wstate.kv_self;
if (!whisper_kv_cache_find_slot(kv_self, batch)) {
return false;
}
// number of cache cells to attend to: the highest used cell padded to a multiple of pad,
// at least pad and at most the cache size
const uint32_t pad = whisper_kv_cache_get_padding(wctx);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad)));
//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
}
// decoder
{
auto & sched = wstate.sched_decode.sched;
ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false);
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
// should never happen as we pre-allocate the memory
return false;
}
// set the inputs
{
struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "embd");
ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
}
{
// the positions are uploaded one element at a time, converted to int32_t
struct ggml_tensor * position = ggml_graph_get_tensor(gf, "position");
for (int i = 0; i < n_tokens; ++i) {
const int32_t val = batch.pos[i];
ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
}
}
{
struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask");
auto & kv_self = wstate.kv_self;
const int32_t n_kv = kv_self.n;
// build the mask in a host buffer: 0 = attend, -INFINITY = masked
wstate.inp_mask.resize(ggml_nelements(KQ_mask));
float * data = wstate.inp_mask.data();
memset(data, 0, ggml_nbytes(KQ_mask));
// note: the loop runs for h == 0 only, so the h-dependent term of the index is always 0
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const whisper_pos pos = batch.pos[j];
const whisper_seq_id seq_id = batch.seq_id[j][0];
// token j must not see cells of other sequences or cells with a later position
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
}
}
// the padding rows (n_tokens up to the padded row count) are fully masked
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
}
ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
}
// whisper_build_graph_decoder expands the logits last
// NOTE(review): index -1 presumably selects the last node of the graph - confirm against ggml_graph_node
logits = ggml_graph_node(gf, -1);
if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
return false;
}
}
// copy the logits to the host, but only for the tokens that requested them (batch.logits[i] != 0);
// the rows of the other tokens in logits_out are left untouched by this loop
logits_out.resize(n_tokens*n_vocab);
for (int i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
}
if (batch.n_tokens > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1e6,
// wstate.get_buf_max_mem(0)/1e6,
// wstate.get_buf_max_mem(1)/1e6,
// wstate.get_buf_max_mem(2)/1e6,
// wstate.get_buf_max_mem(3)/1e6);
}
// timing statistics by batch size: 1 token -> decode, 2..15 tokens -> batched decode, 16+ tokens -> prompt
if (batch.n_tokens == 1) {
wstate.t_decode_us += ggml_time_us() - t_start_us;
wstate.n_decode++;
} else if (batch.n_tokens < 16) {
wstate.t_batchd_us += ggml_time_us() - t_start_us;
wstate.n_batchd += n_tokens;
} else {
wstate.t_prompt_us += ggml_time_us() - t_start_us;
wstate.n_prompt += n_tokens;
}
// the abort callback is only consulted here, after the work has been done
return !(abort_callback && abort_callback(abort_callback_data));
}
// convert a time given in units of 10 ms to the text form "HH:MM:SS.mmm"
// (the separator before the milliseconds is ',' instead of '.' when comma is true)
//
//  500 -> 00:00:05.000
// 6000 -> 00:01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
    const int64_t ms_per_sec  = 1000;
    const int64_t ms_per_min  = 1000 * 60;
    const int64_t ms_per_hour = 1000 * 60 * 60;

    const int64_t total_ms = t * 10;

    // split the total into its components, largest unit first
    const int64_t hours   = total_ms / ms_per_hour;
    const int64_t minutes = (total_ms % ms_per_hour) / ms_per_min;
    const int64_t seconds = (total_ms % ms_per_min) / ms_per_sec;
    const int64_t millis  = total_ms % ms_per_sec;

    const char * separator = comma ? "," : ".";

    char text[32];
    snprintf(text, sizeof(text), "%02d:%02d:%02d%s%03d", (int) hours, (int) minutes, (int) seconds, separator, (int) millis);

    return std::string(text);
}
#define SIN_COS_N_COUNT WHISPER_N_FFT

namespace {

// lookup tables shared by the DFT/FFT and the mel spectrogram code
// they are filled once, when the single global_cache instance is constructed
struct whisper_global_cache {
    // sin/cos of 2*pi*i/SIN_COS_N_COUNT for i in [0, SIN_COS_N_COUNT)
    // the FFT evaluates sine and cosine at these same angles over and over,
    // so looking them up is cheaper than recomputing them
    float sin_vals[SIN_COS_N_COUNT];
    float cos_vals[SIN_COS_N_COUNT];

    // Hann window (cosf is used to eliminate differences)
    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
    float hann_window[WHISPER_N_FFT];

    whisper_global_cache() {
        fill_sin_cos_table();
        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
    }

    // tabulate one full period of sine and cosine
    void fill_sin_cos_table() {
        for (int idx = 0; idx < SIN_COS_N_COUNT; ++idx) {
            const double angle = (2 * M_PI * idx) / SIN_COS_N_COUNT;

            sin_vals[idx] = sinf(angle);
            cos_vals[idx] = cosf(angle);
        }
    }

    // write a Hann window of the given length to output
    // periodic windows divide by length, symmetric ones by length - 1
    void fill_hann_window(int length, bool periodic, float * output) {
        const int denom = periodic ? length : length - 1;

        for (int n = 0; n < length; ++n) {
            output[n] = 0.5 * (1.0 - cosf((2.0 * M_PI * n) / denom));
        }
    }
} global_cache;

}
// naive O(N^2) Discrete Fourier Transform
// - in:  N real-valued samples
// - out: N complex values, stored interleaved as (re, im) pairs -> 2*N floats
//
// sine and cosine are taken from the global lookup tables, which hold SIN_COS_N_COUNT
// entries per period, so the table is sampled with a stride of SIN_COS_N_COUNT / N
static void dft(const float* in, int N, float* out) {
    const int table_stride = SIN_COS_N_COUNT / N;

    for (int bin = 0; bin < N; ++bin) {
        float sum_re = 0;
        float sum_im = 0;

        for (int s = 0; s < N; ++s) {
            // table index of the angle t = 2*M_PI*bin*s/N
            const int t = (bin * s * table_stride) % (SIN_COS_N_COUNT);

            sum_re += in[s]*global_cache.cos_vals[t]; // cos(t)
            sum_im -= in[s]*global_cache.sin_vals[t]; // sin(t)
        }

        float * dst = out + bin*2;
        dst[0] = sum_re;
        dst[1] = sum_im;
    }
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
//
// - in:  N real samples; the memory following them (from in + N on) is overwritten as scratch space
//        for the even/odd split, so the buffer has to be larger than N floats
// - N:   transform size; sizes that are odd (other than 1) are handled by dft()
// - out: receives N complex values as interleaved (re, im) pairs in the first 2*N floats; the memory
//        from out + 2*N on is used as scratch space for the two half-size transforms
//
// note: the caller in this file allocates 2*frame_size floats for in and 8*frame_size floats for out
static void fft(float* in, int N, float* out) {
// a single sample is its own transform
if (N == 1) {
out[0] = in[0];
out[1] = 0;
return;
}
const int half_N = N / 2;
// N is odd -> cannot be split in two halves, use the naive transform
if (N - half_N*2 == 1) {
dft(in, N, out);
return;
}
// gather the even-indexed samples right after the input and transform them
float* even = in + N;
for (int i = 0; i < half_N; ++i) {
even[i]= in[2*i];
}
float* even_fft = out + 2 * N;
fft(even, half_N, even_fft);
// the odd-indexed samples reuse the buffer of the even ones, which are not needed anymore at this point
float* odd = even;
for (int i = 0; i < half_N; ++i) {
odd[i] = in[2*i + 1];
}
// a half-size transform takes 2*half_N == N floats, so the odd result starts N floats after the even one
float* odd_fft = even_fft + N;
fft(odd, half_N, odd_fft);
// combine the two halves: out[k] = even[k] + w*odd[k], out[k + N/2] = even[k] - w*odd[k]
// with the twiddle factor w = cos(t) - i*sin(t), taken from the global lookup tables
const int sin_cos_step = SIN_COS_N_COUNT / N;
for (int k = 0; k < half_N; k++) {
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
float re = global_cache.cos_vals[idx]; // cos(t)
float im = -global_cache.sin_vals[idx]; // -sin(t)
float re_odd = odd_fft[2*k + 0];
float im_odd = odd_fft[2*k + 1];
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
}
}
// worker of log_mel_spectrogram(): computes the log-mel values of the frames
// ith, ith + n_threads, ith + 2*n_threads, ... and writes them to mel.data
//
// - ith:        index of this worker, also the first frame it processes
// - hann:       window that is applied to every frame (at least frame_size values are read)
// - samples:    the padded audio samples
// - n_samples:  number of samples that are considered to hold audio
// - frame_size: number of samples per frame, also the FFT size
// - frame_step: distance between the starts of two consecutive frames
// - n_threads:  total number of workers, i.e. the frame stride of this worker
// - filters:    the mel filter bank, n_fft weights per mel band
// - mel:        output; n_mel, n_len and the size of data must already be set
static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector<float> & samples,
                                              int n_samples, int frame_size, int frame_step, int n_threads,
                                              const whisper_filters & filters, whisper_mel & mel) {
    // both buffers are larger than a single frame because fft() uses the extra space as scratch
    std::vector<float> fft_in(frame_size * 2, 0.0);
    std::vector<float> fft_out(frame_size * 2 * 2 * 2);

    const int n_fft = filters.n_fft;

    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
    assert(n_fft == 1 + (frame_size / 2));

    // frames starting past n_samples contain only zeros and are handled after the main loop
    const int n_frames_fft = std::min(n_samples / frame_step + 1, mel.n_len);

    int frame = ith;

    while (frame < n_frames_fft) {
        const int start = frame * frame_step;
        const int avail = n_samples - start;

        // apply Hann window (~10% faster)
        const int n_windowed = std::min(frame_size, avail);
        for (int s = 0; s < n_windowed; s++) {
            fft_in[s] = hann[s] * samples[start + s];
        }

        // fill the rest with zeros
        if (avail < frame_size) {
            std::fill(fft_in.begin() + avail, fft_in.end(), 0.0);
        }

        // FFT
        fft(fft_in.data(), frame_size, fft_out.data());

        // Calculate modulus^2 of complex numbers
        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
        for (int b = 0; b < n_fft; b++) {
            const float re = fft_out[2 * b + 0];
            const float im = fft_out[2 * b + 1];
            fft_out[b] = (re * re + im * im);
        }

        // mel spectrogram: weighted sum of the power spectrum for every mel band
        for (int band = 0; band < mel.n_mel; band++) {
            const int row = band * n_fft;

            double acc = 0.0;

            // unroll loop (suggested by GH user @lunixbochs)
            int k = 0;
            for (k = 0; k < n_fft - 3; k += 4) {
                acc +=
                        fft_out[k + 0] * filters.data[row + k + 0] +
                        fft_out[k + 1] * filters.data[row + k + 1] +
                        fft_out[k + 2] * filters.data[row + k + 2] +
                        fft_out[k + 3] * filters.data[row + k + 3];
            }

            // handle n_fft remainder
            for (; k < n_fft; k++) {
                acc += fft_out[k] * filters.data[row + k];
            }

            acc = log10(std::max(acc, 1e-10));

            mel.data[band * mel.n_len + frame] = acc;
        }

        frame += n_threads;
    }

    // Otherwise fft_out are all zero
    const double log_floor = log10(1e-10);

    while (frame < mel.n_len) {
        for (int band = 0; band < mel.n_mel; band++) {
            mel.data[band * mel.n_len + frame] = log_floor;
        }
        frame += n_threads;
    }
}
// compute the normalized log-mel spectrogram of the given audio
//
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
//
// - wstate:      the state; only the mel timing statistic (t_mel_us) is updated
// - samples:     the audio samples
// - n_samples:   number of audio samples
// - sample_rate: unused
// - frame_size:  number of samples per frame, must be WHISPER_N_FFT
// - frame_step:  distance between the starts of two consecutive frames
// - n_mel:       number of mel bands
// - n_threads:   number of threads to use, values below 1 are treated as 1
// - filters:     the mel filter bank
// - debug:       when set, the result is also dumped to "log_mel_spectrogram.json"
// - mel:         output; n_mel, n_len, n_len_org and data are overwritten
//
// the audio is padded with frame_size/2 reflected samples at the start and with 30 seconds plus
// frame_size/2 samples of zeros at the end
//
// always returns true
static bool log_mel_spectrogram(
        whisper_state & wstate,
        const float * samples,
        const int n_samples,
        const int /*sample_rate*/,
        const int frame_size,
        const int frame_step,
        const int n_mel,
        const int n_threads,
        const whisper_filters & filters,
        const bool debug,
        whisper_mel & mel) {
    const int64_t t_start_us = ggml_time_us();

    // Hann window
    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
    const float * hann = global_cache.hann_window;

    // Calculate the length of padding
    const int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
    const int64_t stage_2_pad = frame_size / 2;

    // Initialize a vector and copy data from C array to it.
    // note: resize() value-initializes the elements, so everything that is not written below is zero
    std::vector<float> samples_padded;
    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);

    // pad 30 seconds of zeros at the end of audio (480,000 samples) + another stage_2_pad zeros after them
    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);

    // reflective pad at the beginning of audio: samples_padded[stage_2_pad - m] = samples[m] for m >= 1
    // the reflection needs samples[1 .. stage_2_pad]; for shorter inputs only the available samples are
    // reflected and the rest of the leading pad stays zero, instead of reading past the end of the input
    const int64_t n_reflect = std::min<int64_t>(stage_2_pad, std::max<int64_t>(n_samples - 1, 0));
    if (n_reflect > 0) {
        std::reverse_copy(samples + 1, samples + 1 + n_reflect, samples_padded.begin() + (stage_2_pad - n_reflect));
    }

    mel.n_mel = n_mel;
    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
    // Calculate number of frames + remove the last frame
    mel.n_len = (samples_padded.size() - frame_size) / frame_step;
    // Calculate semi-padded sample length to ensure compatibility
    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
    mel.data.resize(mel.n_mel * mel.n_len);

    // the frames are interleaved across the workers; a non-positive thread count would give a negative
    // vector size and a non-advancing frame stride, so at least one worker (this thread) is always used
    const int n_workers = std::max(1, n_threads);

    {
        std::vector<std::thread> workers(n_workers - 1);
        for (int iw = 0; iw < n_workers - 1; ++iw) {
            workers[iw] = std::thread(
                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
                    n_samples + stage_2_pad, frame_size, frame_step, n_workers,
                    std::cref(filters), std::ref(mel));
        }

        // main thread
        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_workers, filters, mel);

        for (int iw = 0; iw < n_workers - 1; ++iw) {
            workers[iw].join();
        }
    }

    // clamping and normalization: values more than 8.0 below the maximum are raised to (max - 8.0),
    // then everything is mapped with (x + 4)/4
    double mmax = -1e20;
    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] > mmax) {
            mmax = mel.data[i];
        }
    }

    mmax -= 8.0;

    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] < mmax) {
            mel.data[i] = mmax;
        }

        mel.data[i] = (mel.data[i] + 4.0)/4.0;
    }

    wstate.t_mel_us += ggml_time_us() - t_start_us;

    // Dump log_mel_spectrogram
    if (debug) {
        std::ofstream outFile("log_mel_spectrogram.json");
        outFile << "[";
        // the separator is written before every element but the first, which also works for empty data
        for (uint64_t i = 0; i < mel.data.size(); i++) {
            if (i > 0) {
                outFile << ", ";
            }
            outFile << mel.data[i];
        }
        outFile << "]";
        outFile.close();
    }

    return true;
}
// split text into tokens
//
// the text is first split into words with the GPT-2 pre-tokenization pattern, then every word is
// covered greedily from left to right with the longest vocabulary entries that match
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
//
// Regex (Python):
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
//
// Regex (C++):
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
//
// - vocab: the vocabulary, only token_to_id is used
// - text:  the text to tokenize
//
// returns the token ids in text order; a character at which no vocabulary entry starts is skipped
// after logging an error
//
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;

    // first split the text into words
    {
        const std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";

        const std::regex re(pat);

        // iterate over the matches in place instead of copying the remaining suffix of the text after
        // every match, which was quadratic in the length of the text
        // the pattern has no capture groups, never matches the empty string and does not look behind,
        // so the sequence of matches is the same as with repeated searches on the suffix
        const std::sregex_iterator matches_end;
        for (std::sregex_iterator it(text.begin(), text.end(), re); it != matches_end; ++it) {
            words.push_back(it->str());
        }
    }

    // find the longest tokens that form the words:
    std::vector<whisper_vocab::id> tokens;
    for (const auto & word : words) {
        if (word.empty()) continue;

        int i = 0;
        int n = word.size();
        while (i < n) {
            // try the longest candidate starting at i first and shorten it until it is in the vocabulary
            int j = n;
            bool found = false;
            while (j > i) {
                auto sub = word.substr(i, j-i);
                auto it = vocab.token_to_id.find(sub);
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    i = j;
                    found = true;
                    break;
                }
                --j;
            }
            if (!found) {
                // not even the single character at i is known - skip it
                WHISPER_LOG_ERROR("unknown token\n");
                ++i;
            }
        }
    }

    return tokens;
}
//
// interface implementation
//
#ifdef WHISPER_USE_COREML
// replace .bin with -encoder.mlmodelc
//
// The extension is stripped only when the last '.' lies in the final path component,
// so a dot inside a directory name (e.g. "./models/ggml-base") no longer truncates the path.
// A trailing quantization suffix of the form "-qx_x" is removed as well.
static std::string whisper_get_coreml_path_encoder(std::string path_bin) {
    const auto sep = path_bin.find_last_of("/\\");

    auto pos = path_bin.rfind('.');
    if (pos != std::string::npos && (sep == std::string::npos || pos > sep)) {
        path_bin = path_bin.substr(0, pos);
    }

    // match "-qx_x"
    pos = path_bin.rfind('-');
    if (pos != std::string::npos) {
        auto sub = path_bin.substr(pos);
        if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_') {
            path_bin = path_bin.substr(0, pos);
        }
    }

    path_bin += "-encoder.mlmodelc";

    return path_bin;
}
#endif
#ifdef WHISPER_USE_OPENVINO
// replace .bin with -encoder-openvino.xml
//
// The extension is stripped only when the last '.' lies in the final path component,
// so a dot inside a directory name does not truncate the path.
static std::string whisper_openvino_get_path_encoder(std::string path_bin) {
    const auto sep = path_bin.find_last_of("/\\");
    const auto pos = path_bin.rfind('.');
    if (pos != std::string::npos && (sep == std::string::npos || pos > sep)) {
        path_bin = path_bin.substr(0, pos);
    }

    path_bin += "-encoder-openvino.xml";

    return path_bin;
}
// replace .bin with -encoder-openvino-cache
//
// The extension is stripped only when the last '.' lies in the final path component,
// so a dot inside a directory name does not truncate the path.
static std::string whisper_openvino_get_path_cache(std::string path_bin) {
    const auto sep = path_bin.find_last_of("/\\");
    const auto pos = path_bin.rfind('.');
    if (pos != std::string::npos && (sep == std::string::npos || pos > sep)) {
        path_bin = path_bin.substr(0, pos);
    }

    path_bin += "-encoder-openvino-cache";

    return path_bin;
}
#endif
// Allocates and initializes a new whisper_state for the given context.
//
// Sets up, in order: the backends, the self/cross/pad KV caches, the optional DTW
// alignment-heads masks, the optional Core ML encoder, the logits/batch/decoder buffers
// and the conv/encode/cross/decode graph schedulers.
//
// Returns the new state, or nullptr when any step fails; in that case the partially
// initialized state is released with whisper_free_state() before returning.
struct whisper_state * whisper_init_state(whisper_context * ctx) {
    whisper_state * state = new whisper_state;

    state->backends = whisper_backend_init(ctx->params);
    if (state->backends.empty()) {
        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
        whisper_free_state(state);
        return nullptr;
    }

    // at this point, we don't know yet how many decoders will be used
    // later during decoding, if more decoders are used, we will recreate the KV cache respectively
    state->kv_self_n_dec = 1;
    if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
                ctx->model.hparams.n_text_state,
                ctx->model.hparams.n_text_layer,
                GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
        whisper_free_state(state);
        return nullptr;
    }

    {
        const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v);
        WHISPER_LOG_INFO("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1e6);
    }

    if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype,
                ctx->model.hparams.n_text_state,
                ctx->model.hparams.n_text_layer,
                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__);
        whisper_free_state(state);
        return nullptr;
    }

    {
        const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v);
        WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
    }

    if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype,
                ctx->model.hparams.n_audio_state,
                1,
                GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
        // this is the pad cache, not the self-attention cache initialized above
        WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for pad cache\n", __func__);
        whisper_free_state(state);
        return nullptr;
    }

    {
        const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
        WHISPER_LOG_INFO("%s: kv pad  size  = %7.2f MB\n", __func__, memory_size / 1e6);
    }

    // [EXPERIMENTAL] Token-level timestamps with DTW
    if (ctx->params.dtw_token_timestamps) {
        if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) {
            WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__);
            whisper_free_state(state);
            return nullptr;
        }
        const size_t memory_size = aheads_masks_nbytes(state->aheads_masks);
        // size_t must be printed with %zu (%ld mismatches where long is 32-bit)
        WHISPER_LOG_INFO("%s: alignment heads masks size = %zu B\n", __func__, memory_size);
    }

#ifdef WHISPER_USE_COREML
    const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);

    WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

    state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
    if (!state->ctx_coreml) {
        WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
#ifndef WHISPER_COREML_ALLOW_FALLBACK
        whisper_free_state(state);
        return nullptr;
#endif
    } else {
        WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__);
    }
#endif

    state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);

    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);

    // TAGS: WHISPER_DECODER_INIT
    state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);

    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);

    state->decoders[0].rng = std::mt19937(0);

    // conv allocator
    {
        bool ok = whisper_sched_graph_init(state->sched_conv, state->backends,
                [&]() {
                    return whisper_build_graph_conv(*ctx, *state);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__);
            whisper_free_state(state);
            return nullptr;
        }

        WHISPER_LOG_INFO("%s: compute buffer (conv)   = %7.2f MB\n", __func__, whisper_sched_size(state->sched_conv) / 1e6);
    }

    // encoder allocator
    if (!whisper_encode_external(*state)) {
        bool ok = whisper_sched_graph_init(state->sched_encode, state->backends,
                [&]() {
                    return whisper_build_graph_encoder(*ctx, *state);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__);
            whisper_free_state(state);
            return nullptr;
        }

        WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_encode) / 1e6);
    }

    // cross allocator
    {
        bool ok = whisper_sched_graph_init(state->sched_cross, state->backends,
                [&]() {
                    return whisper_build_graph_cross(*ctx, *state);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__);
            whisper_free_state(state);
            return nullptr;
        }

        WHISPER_LOG_INFO("%s: compute buffer (cross)  = %7.2f MB\n", __func__, whisper_sched_size(state->sched_cross) / 1e6);
    }

    // decoder allocator
    {
        bool ok = whisper_sched_graph_init(state->sched_decode, state->backends,
                [&]() {
                    const auto & hparams = ctx->model.hparams;

                    // TODO: make sure this is the worst-case scenario
                    const int n_tokens = hparams.n_text_ctx;
                    const int n_past   = 0;

                    whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0);

                    return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__);
            whisper_free_state(state);
            return nullptr;
        }

        WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_sched_size(state->sched_decode) / 1e6);
    }

    return state;
}
// Initializes the OpenVINO encoder for the given state.
//
// Without WHISPER_USE_OPENVINO the arguments are ignored and 1 is returned.
// Otherwise a null model_path is derived from ctx->path_model, a null cache_dir is
// derived from ctx->path_model as well, and the result of whisper_openvino_init() is
// stored in state->ctx_openvino.
//
// Returns 0 on success, 1 on failure.
int whisper_ctx_init_openvino_encoder_with_state(
        struct whisper_context * ctx,
        struct whisper_state * state,
        const char * model_path,
        const char * device,
        const char * cache_dir) {
#ifndef WHISPER_USE_OPENVINO
    (void) ctx;
    (void) state;
    (void) model_path;
    (void) device;
    (void) cache_dir;

    return 1;
#else
    if (!model_path && ctx->path_model.empty()) {
        WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__);
        return 1;
    }

    // if model_path is not set, attempt to find it in the same directory as ggml-<model>.bin model
    const std::string path_encoder = model_path
        ? std::string(model_path)
        : whisper_openvino_get_path_encoder(ctx->path_model);

    // if cache_dir is not set, set it as a dir residing next to ggml-<model>.bin
    const std::string path_cache = cache_dir
        ? std::string(cache_dir)
        : whisper_openvino_get_path_cache(ctx->path_model);

    WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str());
    WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__);

    state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str());
    if (state->ctx_openvino == nullptr) {
        WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str());
        return 1;
    }

    WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__);

    return 0;
#endif
}
// Same as whisper_ctx_init_openvino_encoder_with_state(), using the state stored in ctx.
int whisper_ctx_init_openvino_encoder(
        struct whisper_context * ctx,
        const char * model_path,
        const char * device,
        const char * cache_dir) {
    struct whisper_state * state = ctx->state;

    return whisper_ctx_init_openvino_encoder_with_state(ctx, state, model_path, device, cache_dir);
}
// Returns the default context parameters: GPU enabled on device 0, flash attention off,
// DTW token timestamps off with no alignment heads, and a 128 MiB DTW memory size.
struct whisper_context_params whisper_context_default_params() {
    struct whisper_context_params result = {};

    result.use_gpu              = true;
    result.flash_attn           = false;
    result.gpu_device           = 0;

    result.dtw_token_timestamps = false;
    result.dtw_aheads_preset    = WHISPER_AHEADS_NONE;
    result.dtw_n_top            = -1;
    result.dtw_aheads.n_heads   = 0;
    result.dtw_aheads.heads     = NULL;
    result.dtw_mem_size         = 1024*1024*128;

    return result;
}
// Loads a model from the file at path_model and returns a context without a state.
//
// The file is opened in binary mode (through a UTF-16 path on MSVC) and wrapped in a
// whisper_model_loader that is passed to whisper_init_with_params_no_state().
// On success the path is stored in ctx->path_model.
//
// Returns nullptr when the file cannot be opened or the model fails to load.
struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) {
    WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model);
#ifdef _MSC_VER
    // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    std::wstring path_model_wide = converter.from_bytes(path_model);
    auto fin = std::ifstream(path_model_wide, std::ios::binary);
#else
    auto fin = std::ifstream(path_model, std::ios::binary);
#endif
    if (!fin) {
        WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model);
        return nullptr;
    }

    whisper_model_loader loader = {};

    loader.context = &fin;

    loader.read = [](void * ctx, void * output, size_t read_size) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->read((char *)output, read_size);
        // report the bytes actually extracted, so a short read is not reported as complete
        return static_cast<size_t>(fin->gcount());
    };

    loader.eof = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        return fin->eof();
    };

    loader.close = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->close();
    };

    auto ctx = whisper_init_with_params_no_state(&loader, params);

    if (ctx) {
        ctx->path_model = path_model;
    }

    return ctx;
}
// Loads a model from an in-memory buffer and returns a context without a state.
//
// The buffer is wrapped in a whisper_model_loader whose read callback copies at most the
// bytes remaining in the buffer and advances an internal offset.
struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) {
    struct buf_context {
        uint8_t* buffer;
        size_t   size;
        size_t   current_offset;
    };

    buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 };

    WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__);

    whisper_model_loader loader = {};

    loader.context = &ctx;

    loader.read = [](void * ctx, void * output, size_t read_size) {
        buf_context * buf = reinterpret_cast<buf_context *>(ctx);

        // clamp against the remaining bytes; unlike "current_offset + read_size < size"
        // this cannot overflow for a very large read_size
        const size_t remaining    = buf->size - buf->current_offset;
        const size_t size_to_copy = std::min(read_size, remaining);

        memcpy(output, buf->buffer + buf->current_offset, size_to_copy);
        buf->current_offset += size_to_copy;

        return size_to_copy;
    };

    loader.eof = [](void * ctx) {
        buf_context * buf = reinterpret_cast<buf_context *>(ctx);

        return buf->current_offset >= buf->size;
    };

    loader.close = [](void * /*ctx*/) { };

    return whisper_init_with_params_no_state(&loader, params);
}
// Creates a context from a model loader without allocating a state.
//
// DTW token timestamps are disabled (with a warning) when flash attention is requested.
// The loader is closed after the load attempt whether or not it succeeded.
//
// Returns the new context, or nullptr when whisper_model_load() fails.
struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
    ggml_time_init();

    if (params.flash_attn && params.dtw_token_timestamps) {
        WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__);
        params.dtw_token_timestamps = false;
    }

    WHISPER_LOG_INFO("%s: use gpu    = %d\n", __func__, params.use_gpu);
    WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
    WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
    WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
    WHISPER_LOG_INFO("%s: devices    = %zu\n", __func__, ggml_backend_dev_count());
    WHISPER_LOG_INFO("%s: backends   = %zu\n", __func__, ggml_backend_reg_count());

    whisper_context * ctx = new whisper_context;
    ctx->params = params;

    const bool loaded = whisper_model_load(loader, *ctx);

    // the loader is no longer needed once the load attempt has finished
    loader->close(loader->context);

    if (loaded) {
        return ctx;
    }

    WHISPER_LOG_ERROR("%s: failed to load model\n", __func__);
    delete ctx;

    return nullptr;
}
// Loads a model from a file and attaches a freshly initialized state to the context.
// Returns nullptr when loading the model or initializing the state fails.
struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) {
    whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params);
    if (ctx == nullptr) {
        return nullptr;
    }

    ctx->state = whisper_init_state(ctx);
    if (ctx->state != nullptr) {
        return ctx;
    }

    // state creation failed: release the context that was already loaded
    whisper_free(ctx);
    return nullptr;
}
// Loads a model from a memory buffer and attaches a freshly initialized state to the context.
// Returns nullptr when loading the model or initializing the state fails.
struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) {
    whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params);
    if (ctx == nullptr) {
        return nullptr;
    }

    ctx->state = whisper_init_state(ctx);
    if (ctx->state != nullptr) {
        return ctx;
    }

    // state creation failed: release the context that was already loaded
    whisper_free(ctx);
    return nullptr;
}
// Loads a model through a custom loader and attaches a freshly initialized state to the context.
// Returns nullptr when loading the model or initializing the state fails.
struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) {
    whisper_context * ctx = whisper_init_with_params_no_state(loader, params);
    if (ctx == nullptr) {
        return nullptr;
    }

    ctx->state = whisper_init_state(ctx);
    if (ctx->state != nullptr) {
        return ctx;
    }

    // state creation failed: release the context that was already loaded
    whisper_free(ctx);
    return nullptr;
}
struct whisper_context * whisper_init_from_file(const char * path_model) {
return whisper_init_from_file_with_params(path_model, whisper_context_default_params());
}
struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) {
return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params());
}
struct whisper_context * whisper_init(struct whisper_model_loader * loader) {
return whisper_init_with_params(loader, whisper_context_default_params());
}
struct whisper_context * whisper_init_from_file_no_state(const char * path_model) {
return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params());
}
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params());
}
struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) {
return whisper_init_with_params_no_state(loader, whisper_context_default_params());
}
// Releases everything held by a state and deletes the state itself.
// A null state is accepted and ignored.
void whisper_free_state(struct whisper_state * state) {
    if (state == nullptr) {
        return;
    }

    whisper_kv_cache_free(state->kv_self);
    whisper_kv_cache_free(state->kv_cross);
    whisper_kv_cache_free(state->kv_pad);

#ifdef WHISPER_USE_COREML
    if (state->ctx_coreml != nullptr) {
        whisper_coreml_free(state->ctx_coreml);
        state->ctx_coreml = nullptr;
    }
#endif

#ifdef WHISPER_USE_OPENVINO
    if (state->ctx_openvino != nullptr) {
        whisper_openvino_free(state->ctx_openvino);
        state->ctx_openvino = nullptr;
    }
#endif

    whisper_batch_free(state->batch);

    ggml_backend_sched_free(state->sched_conv.sched);
    ggml_backend_sched_free(state->sched_encode.sched);
    ggml_backend_sched_free(state->sched_cross.sched);
    ggml_backend_sched_free(state->sched_decode.sched);

    for (auto & be : state->backends) {
        ggml_backend_free(be);
    }

    // [EXPERIMENTAL] Token-level timestamps with DTW
    aheads_masks_free(state->aheads_masks);

    delete state;
}
void whisper_free(struct whisper_context * ctx) {
if (ctx) {
for (ggml_context * context : ctx->model.ctxs) {
ggml_free(context);
}
for (ggml_backend_buffer_t buf : ctx->model.buffers) {
ggml_backend_buffer_free(buf);
}
whisper_free_state(ctx->state);
delete ctx;
}
}
// Deletes a heap-allocated whisper_context_params; a null pointer is a no-op.
void whisper_free_context_params(struct whisper_context_params * params) {
    // deleting a null pointer has no effect, so no explicit check is needed
    delete params;
}
// Deletes a heap-allocated whisper_full_params; a null pointer is a no-op.
void whisper_free_params(struct whisper_full_params * params) {
    // deleting a null pointer has no effect, so no explicit check is needed
    delete params;
}
// Computes the log mel spectrogram of the given samples into state->mel.
// Returns 0 on success, -1 when log_mel_spectrogram() fails.
int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
    const bool ok = log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel);
    if (ok) {
        return 0;
    }

    WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
    return -1;
}
// Same as whisper_pcm_to_mel_with_state(), using the state stored in ctx.
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
    struct whisper_state * state = ctx->state;
    return whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, n_threads);
}
// Copies a caller-provided mel spectrogram of n_len*n_mel floats into state->mel.
//
// n_mel must equal ctx->model.filters.n_mel and n_len must not be negative.
// Returns 0 on success, -1 on invalid arguments.
int whisper_set_mel_with_state(
        struct whisper_context * ctx,
          struct whisper_state * state,
                   const float * data,
                           int   n_len,
                           int   n_mel) {
    if (n_mel != ctx->model.filters.n_mel) {
        WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel);
        return -1;
    }

    // a negative length would turn into a huge unsigned size in resize()/memcpy()
    if (n_len < 0) {
        WHISPER_LOG_ERROR("%s: invalid number of frames: %d\n", __func__, n_len);
        return -1;
    }

    // multiply in size_t so that long inputs cannot overflow int
    const size_t n_values = static_cast<size_t>(n_len) * static_cast<size_t>(n_mel);

    state->mel.n_len     = n_len;
    state->mel.n_len_org = n_len;
    state->mel.n_mel     = n_mel;

    state->mel.data.resize(n_values);
    if (n_values > 0) {
        memcpy(state->mel.data.data(), data, n_values*sizeof(float));
    }

    return 0;
}
// Same as whisper_set_mel_with_state(), using the state stored in ctx.
int whisper_set_mel(
        struct whisper_context * ctx,
                   const float * data,
                           int   n_len,
                           int   n_mel) {
    struct whisper_state * state = ctx->state;
    return whisper_set_mel_with_state(ctx, state, data, n_len, n_mel);
}
// Runs the encoder on the given state starting at the given mel offset.
// Returns 0 on success, -1 when whisper_encode_internal() fails.
int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) {
    const bool ok = whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr);
    if (ok) {
        return 0;
    }

    WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
    return -1;
}
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) {
WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
return -1;
}
return 0;
}
// Runs the decoder on the given tokens with the given state.
//
// The batch is prepared from the tokens and the self-attention KV cache of sequence 0 is
// cleared from position n_past onwards before decoding.
// Returns 0 on success, 1 when whisper_decode_internal() fails.
int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);

    whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);

    const bool ok = whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr);
    if (ok) {
        return 0;
    }

    WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
    return 1;
}
// Same as whisper_decode_with_state(), using the state stored in ctx.
// Returns -1 when the context has no state.
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
    struct whisper_state * state = ctx->state;

    if (state != nullptr) {
        return whisper_decode_with_state(ctx, state, tokens, n_tokens, n_past, n_threads);
    }

    WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__);
    return -1;
}
// Tokenizes text with the context vocabulary and writes the ids to tokens.
//
// Returns the number of tokens written. When more than n_max_tokens tokens are needed,
// nothing is written and the negated required count is returned.
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
    const auto ids = tokenize(ctx->vocab, text);

    const int n_tokens = (int) ids.size();

    if (n_max_tokens < n_tokens) {
        WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, n_tokens, n_max_tokens);
        return -n_tokens;
    }

    int n_written = 0;
    for (const auto id : ids) {
        tokens[n_written++] = id;
    }

    return n_tokens;
}
int whisper_token_count(struct whisper_context * ctx, const char * text) {
return -whisper_tokenize(ctx, text, NULL, 0);
}
// Returns the largest language id found in g_lang (0 when no id is larger).
int whisper_lang_max_id(void) {
    int largest = 0;

    for (auto it = g_lang.begin(); it != g_lang.end(); ++it) {
        const int id = it->second.first;
        if (id > largest) {
            largest = id;
        }
    }

    return largest;
}
// Returns the id of a language given either its g_lang key or the second string stored
// with each entry. Returns -1 (after logging an error) when neither matches.
int whisper_lang_id(const char * lang) {
    // first try the key of the table
    const auto found = g_lang.find(lang);
    if (found != g_lang.end()) {
        return found->second.first;
    }

    // then fall back to the name stored alongside each id
    for (const auto & entry : g_lang) {
        if (entry.second.second == lang) {
            return entry.second.first;
        }
    }

    WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang);
    return -1;
}
// Returns the g_lang key of the entry with the given id,
// or nullptr (after logging an error) when the id is unknown.
const char * whisper_lang_str(int id) {
    for (auto it = g_lang.begin(); it != g_lang.end(); ++it) {
        if (it->second.first != id) {
            continue;
        }
        return it->first.c_str();
    }

    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
    return nullptr;
}
// Returns the name string stored with the g_lang entry of the given id,
// or nullptr (after logging an error) when the id is unknown.
const char * whisper_lang_str_full(int id) {
    for (auto it = g_lang.begin(); it != g_lang.end(); ++it) {
        if (it->second.first != id) {
            continue;
        }
        return it->second.second.c_str();
    }

    WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id);
    return nullptr;
}
int whisper_lang_auto_detect_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
int offset_ms,
int n_threads,
float * lang_probs) {
const int seek = offset_ms/10;
if (seek < 0) {
WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms);
return -1;
}
if (seek >= state->mel.n_len_org) {
WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10);
return -2;
}
// run the encoder
if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
return -6;
}
const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) };
if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -7;
}
auto & logits_id = state->decoders[0].logits_id;
logits_id.clear();
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
logits_id.emplace_back(state->logits[token_lang], kv.second.first);
}
// sort descending
{
using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) {
return a.first > b.first;
});
}
// softmax
{
const auto max = logits_id[0].first;
double sum = 0.0f;
for (auto & kv : logits_id) {
kv.first = exp(kv.first - max);
sum += kv.first;
}
for (auto & kv : logits_id) {
kv.first /= sum;
}
}
{
for (const auto & prob : logits_id) {
if (lang_probs) {
lang_probs[prob.second] = prob.first;
}
//printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
}
}
return logits_id[0].second;
}
// Same as whisper_lang_auto_detect_with_state(), using the state stored in ctx.
int whisper_lang_auto_detect(
        struct whisper_context * ctx,
                           int   offset_ms,
                           int   n_threads,
                         float * lang_probs) {
    struct whisper_state * state = ctx->state;
    return whisper_lang_auto_detect_with_state(ctx, state, offset_ms, n_threads, lang_probs);
}
// Returns the n_vocab hyperparameter of the loaded model.
int whisper_model_n_vocab(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_vocab;
}
// Returns the n_audio_ctx hyperparameter of the loaded model.
int whisper_model_n_audio_ctx(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_audio_ctx;
}
// Returns the n_audio_state hyperparameter of the loaded model.
int whisper_model_n_audio_state(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_audio_state;
}
// Returns the n_audio_head hyperparameter of the loaded model.
int whisper_model_n_audio_head(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_audio_head;
}
// Returns the n_audio_layer hyperparameter of the loaded model.
int whisper_model_n_audio_layer(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_audio_layer;
}
// Returns the n_text_ctx hyperparameter of the loaded model.
int whisper_model_n_text_ctx(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_ctx;
}
// Returns the n_text_state hyperparameter of the loaded model.
int whisper_model_n_text_state(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_state;
}
// Returns the n_text_head hyperparameter of the loaded model.
int whisper_model_n_text_head(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_head;
}
// Returns the n_text_layer hyperparameter of the loaded model.
int whisper_model_n_text_layer(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_layer;
}
// Returns the n_mels hyperparameter of the loaded model.
int whisper_model_n_mels(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_mels;
}
// Returns the ftype hyperparameter of the loaded model.
int whisper_model_ftype(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.ftype;
}
// Returns the model type stored in the context, converted to int.
int whisper_model_type(struct whisper_context * ctx) {
    const int type_id = ctx->model.type;
    return type_id;
}
// Returns a human-readable name for the model type: "tiny", "base", "small", "medium"
// or "large"; any other type yields "unknown".
const char *whisper_model_type_readable(struct whisper_context * ctx) {
    const char * name = "unknown";

    switch (ctx->model.type) {
        case e_model::MODEL_TINY:   name = "tiny";   break;
        case e_model::MODEL_BASE:   name = "base";   break;
        case e_model::MODEL_SMALL:  name = "small";  break;
        case e_model::MODEL_MEDIUM: name = "medium"; break;
        case e_model::MODEL_LARGE:  name = "large";  break;
        default:                                     break;
    }

    return name;
}
// Returns mel.n_len_org of the given state.
int whisper_n_len_from_state(struct whisper_state * state) {
    const auto & mel = state->mel;
    return mel.n_len_org;
}
// Returns mel.n_len_org of the state stored in ctx.
int whisper_n_len(struct whisper_context * ctx) {
    const auto & mel = ctx->state->mel;
    return mel.n_len_org;
}
// Returns the n_vocab value of the context vocabulary.
int whisper_n_vocab(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.n_vocab;
}
// Returns the n_text_ctx hyperparameter of the loaded model.
int whisper_n_text_ctx(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_ctx;
}
// Returns the n_audio_ctx hyperparameter of the loaded model.
int whisper_n_audio_ctx(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_audio_ctx;
}
// Returns 1 when the context vocabulary reports itself as multilingual, 0 otherwise.
int whisper_is_multilingual(struct whisper_context * ctx) {
    if (ctx->vocab.is_multilingual()) {
        return 1;
    }
    return 0;
}
// Returns a pointer to the logits buffer of the state stored in ctx.
float * whisper_get_logits(struct whisper_context * ctx) {
    auto & logits = ctx->state->logits;
    return logits.data();
}
// Returns a pointer to the logits buffer of the given state.
float * whisper_get_logits_from_state(struct whisper_state * state) {
    auto & logits = state->logits;
    return logits.data();
}
// Returns the text of a token id, looked up with at() in the vocabulary's id_to_token table.
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
    const auto & text = ctx->vocab.id_to_token.at(token);
    return text.c_str();
}
// Returns the vocabulary's token_eot id.
whisper_token whisper_token_eot(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_eot;
}
// Returns the vocabulary's token_sot id.
whisper_token whisper_token_sot(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_sot;
}
// Returns the vocabulary's token_solm id.
whisper_token whisper_token_solm(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_solm;
}
// Returns the vocabulary's token_prev id.
whisper_token whisper_token_prev(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_prev;
}
// Returns the vocabulary's token_nosp id.
whisper_token whisper_token_nosp(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_nosp;
}
// Returns the vocabulary's token_not id.
whisper_token whisper_token_not(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_not;
}
// Returns the vocabulary's token_beg id.
whisper_token whisper_token_beg(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_beg;
}
// Returns the token id of a language: the start-of-transcript token id plus 1 plus lang_id.
whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) {
    const whisper_token sot = whisper_token_sot(ctx);
    return sot + 1 + lang_id;
}
// Returns the vocabulary's token_translate id.
whisper_token whisper_token_translate(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_translate;
}
// Returns the vocabulary's token_transcribe id.
whisper_token whisper_token_transcribe(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_transcribe;
}
struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) {
if (ctx->state == nullptr) {
return nullptr;
}
whisper_timings * timings = new whisper_timings;
timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample);
timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode);
timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode);
timings->batchd_ms = 1e-3f * ctx->state->t_batchd_us / std::max(1, ctx->state->n_batchd);
timings->prompt_ms = 1e-3f * ctx->state->t_prompt_us / std::max(1, ctx->state->n_prompt);
return timings;
}
void whisper_print_timings(struct whisper_context * ctx) {
const int64_t t_end_us = ggml_time_us();
WHISPER_LOG_INFO("\n");
WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
if (ctx->state != nullptr) {
const int32_t n_sample = std::max(1, ctx->state->n_sample);
const int32_t n_encode = std::max(1, ctx->state->n_encode);
const int32_t n_decode = std::max(1, ctx->state->n_decode);
const int32_t n_batchd = std::max(1, ctx->state->n_batchd);
const int32_t n_prompt = std::max(1, ctx->state->n_prompt);
WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h);
WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f);
WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample);
WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode);
WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode);
WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd);
WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt);
}
WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}
void whisper_reset_timings(struct whisper_context * ctx) {
ctx->t_start_us = ggml_time_us();
if (ctx->state != nullptr) {
ctx->state->t_mel_us = 0;
ctx->state->t_sample_us = 0;
ctx->state->t_encode_us = 0;
ctx->state->t_decode_us = 0;
ctx->state->t_batchd_us = 0;
ctx->state->t_prompt_us = 0;
ctx->state->n_sample = 0;
ctx->state->n_encode = 0;
ctx->state->n_decode = 0;
ctx->state->n_batchd = 0;
ctx->state->n_prompt = 0;
}
}
// Returns 1 when the library was built with WHISPER_USE_COREML, 0 otherwise.
static int whisper_has_coreml(void) {
#ifdef WHISPER_USE_COREML
    const int enabled = 1;
#else
    const int enabled = 0;
#endif
    return enabled;
}
// Returns 1 when the library was built with WHISPER_USE_OPENVINO, 0 otherwise.
static int whisper_has_openvino(void) {
#ifdef WHISPER_USE_OPENVINO
    const int enabled = 1;
#else
    const int enabled = 0;
#endif
    return enabled;
}
const char * whisper_print_system_info(void) {
static std::string s;
whisper_load_backends();
s = "";
s += "WHISPER : ";
s += "COREML = " + std::to_string(whisper_has_coreml()) + " | ";
s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | ";
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i);
auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
if (get_features_fn) {
ggml_backend_feature * features = get_features_fn(reg);
s += ggml_backend_reg_name(reg);
s += " : ";
for (; features->name; features++) {
s += features->name;
s += " = ";
s += features->value;
s += " | ";
}
}
}
return s.c_str();
}
//////////////////////////////////
// Voice Activity Detection (VAD)
//////////////////////////////////
// Hyperparameters of the VAD model: encoder layout, LSTM sizes and final conv sizes.
// NOTE(review): the three pointer members presumably reference arrays with one entry per
// encoder layer (n_encoder_layers); allocation and ownership are not visible here - confirm
// against the model loader.
struct whisper_vad_hparams {
int32_t n_encoder_layers;
int32_t * encoder_in_channels;
int32_t * encoder_out_channels;
int32_t * kernel_sizes;
int32_t lstm_input_size;
int32_t lstm_hidden_size;
int32_t final_conv_in;
int32_t final_conv_out;
};
// Weights and metadata of a loaded VAD model.
// The bracketed shapes next to the tensors are the ones noted by the original author.
struct whisper_vad_model {
std::string type;
std::string version;
whisper_vad_hparams hparams;
struct ggml_tensor * stft_forward_basis; // [256, 1, 258]
// Encoder tensors - 4 convolutional layers
struct ggml_tensor * encoder_0_weight; // [3, 129, 128]
struct ggml_tensor * encoder_0_bias; // [128]
// Second encoder layer
struct ggml_tensor * encoder_1_weight; // [3, 128, 64]
struct ggml_tensor * encoder_1_bias; // [64]
// Third encoder layer
struct ggml_tensor * encoder_2_weight; // [3, 64, 64]
struct ggml_tensor * encoder_2_bias; // [64]
// Fourth encoder layer
struct ggml_tensor * encoder_3_weight; // [3, 64, 128]
struct ggml_tensor * encoder_3_bias; // [128]
// LSTM decoder tensors
struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden
struct ggml_tensor * lstm_ih_bias; // [512]
struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden
struct ggml_tensor * lstm_hh_bias; // [512]
// Final conv layer
struct ggml_tensor * final_conv_weight; // [128]
struct ggml_tensor * final_conv_bias; // [1]
// ggml contexts
std::vector<ggml_context *> ctxs;
// buffer for the model tensors
std::vector<ggml_backend_buffer_t> buffers;
// tensors
// NOTE(review): n_loaded has no initializer here; presumably set by the loader - confirm
int n_loaded;
// tensors keyed by a string (presumably the tensor name - confirm against the loader)
std::map<std::string, struct ggml_tensor *> tensors;
};
// A single time range produced by the VAD, expressed in seconds.
struct whisper_vad_segment {
float start; // Start time in seconds
float end; // End time in seconds
};
// Container for a list of VAD segments.
struct whisper_vad_segments {
// the segments themselves
std::vector<whisper_vad_segment> data;
};
// Runtime context of the VAD: loaded model, backends, scheduler and working buffers.
//
// Every scalar and pointer member has a default initializer, so a default-constructed
// context never exposes indeterminate values before the init code fills it in.
struct whisper_vad_context {
    int64_t t_vad_us = 0;

    int n_window  = 0;
    int n_context = 0;
    int n_threads = 0;

    std::vector<ggml_backend_t> backends;
    ggml_backend_buffer_t       buffer = nullptr;
    whisper_context_params      params;
    std::vector<uint8_t>        ctx_buf;
    whisper_sched               sched;

    whisper_vad_model model;
    std::string       path_model;

    // NOTE(review): presumably the LSTM hidden and cell state tensors - confirm against the graph builder
    struct ggml_tensor * h_state = nullptr;
    struct ggml_tensor * c_state = nullptr;

    std::vector<float> probs;
};
struct whisper_vad_context_params whisper_vad_default_context_params(void) {
whisper_vad_context_params result = {
/*.n_thread = */ 4,
/*.use_gpu = */ false,
/*.gpu_device = */ 0,
};
return result;
}
struct whisper_vad_params whisper_vad_default_params(void) {
whisper_vad_params result = {
/* threshold = */ 0.5f,
/* min_speech_duration_ms = */ 250,
/* min_silence_duration_ms = */ 100,
/* max_speech_duration_s = */ FLT_MAX,
/* speech_pad_ms = */ 30,
/* samples_overlap = */ 0.1,
};
return result;
}
// Checks whether a weight tensor w used by operation op can live in buffer type buft on
// device dev.
//
// GPU devices and the default CPU buffer type are accepted for every operator. For any
// other combination only GGML_OP_MUL_MAT is considered: a dummy matmul is built and the
// device is asked whether it supports it while w temporarily points at a zero-sized
// buffer of the candidate type.
//
// Throws std::runtime_error when the temporary ggml context cannot be created.
static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
    // GPU and default CPU backend support all operators
    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
        return true;
    }
    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type()) {
        return true;
    }

    // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
    if (op != GGML_OP_MUL_MAT) {
        return false;
    }

    ggml_init_params params = {
        /*.mem_size   =*/ 2 * ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx_ptr { ggml_init(params) };
    if (!ctx_ptr) {
        throw std::runtime_error("failed to create ggml context");
    }
    ggml_context * ctx = ctx_ptr.get();

    const int64_t n_ctx = hparams.lstm_hidden_size;
    ggml_tensor * b         = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
    ggml_tensor * op_tensor = ggml_mul_mat(ctx, w, b);

    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
    GGML_ASSERT(w->buffer == nullptr);
    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
    const bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
    ggml_backend_buffer_free(w->buffer);
    w->buffer = nullptr;

    return op_supported;
}
// Returns the first buffer type in `buft_list` (in list order) for which
// weight_buft_supported() accepts tensor `w` with operation `op`, or nullptr
// when no entry qualifies. The list must not be empty.
static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) {
    GGML_ASSERT(!buft_list.empty());

    // Each entry is a (device, buffer type) pair.
    const auto match = std::find_if(buft_list.begin(), buft_list.end(),
        [&](const buft_list_t::value_type & entry) {
            return weight_buft_supported(hparams, w, op, entry.second, entry.first);
        });

    return match != buft_list.end() ? match->second : nullptr;
}
// Adds the STFT front-end to the graph and returns the magnitude spectrum.
//
// The input is reflect-padded by 64 samples on each side and convolved with
// the precomputed forward basis (stride = hparams.lstm_input_size). The first
// half of the basis outputs is taken as the real part and the second half as
// the imaginary part; the result is sqrt(real^2 + imag^2).
static ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0,
        const whisper_vad_model & model, ggml_tensor * cur) {
    ggml_tensor * basis = model.stft_forward_basis;

    // Reflective padding followed by the basis convolution.
    ggml_tensor * padded   = ggml_pad_reflect_1d(ctx0, cur, 64, 64);
    ggml_tensor * spectrum = ggml_conv_1d(ctx0, basis, padded, model.hparams.lstm_input_size, 0, 1);

    // Number of rows that make up each of the real / imaginary halves.
    const int    n_half    = basis->ne[2] / 2;
    const size_t row_bytes = spectrum->nb[1];

    // Both views take 4 elements along the first dimension (hard-coded).
    ggml_tensor * re = ggml_view_2d(ctx0, spectrum, 4, n_half, row_bytes, 0);
    ggml_tensor * im = ggml_view_2d(ctx0, spectrum, 4, n_half, row_bytes, n_half * row_bytes);

    // |z| = sqrt(re^2 + im^2)
    ggml_tensor * re_sq = ggml_mul(ctx0, re, re);
    ggml_tensor * im_sq = ggml_mul(ctx0, im, im);
    ggml_tensor * power = ggml_add(ctx0, re_sq, im_sq);

    return ggml_sqrt(ctx0, power);
}
// Adds the four-stage convolutional encoder to the graph. Every stage is
// Conv1D (padding 1, dilation 1) + bias + ReLU; the stages differ only in
// their weights, stride and number of output channels.
static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0,
        const whisper_vad_model & model, ggml_tensor * cur) {
    // One encoder stage. The bias is reshaped to [1, n_out, 1] so that it
    // broadcasts over the convolution output.
    const auto stage = [ctx0](ggml_tensor * inp, ggml_tensor * weight, ggml_tensor * bias, int stride, int n_out) {
        ggml_tensor * out = ggml_conv_1d(ctx0, weight, inp, stride, 1, 1);
        out = ggml_add(ctx0, out, ggml_reshape_3d(ctx0, bias, 1, n_out, 1));
        return ggml_relu(ctx0, out);
    };

    cur = stage(cur, model.encoder_0_weight, model.encoder_0_bias, 1, 128); // stride 1 -> 128 channels
    cur = stage(cur, model.encoder_1_weight, model.encoder_1_bias, 2,  64); // stride 2 ->  64 channels
    cur = stage(cur, model.encoder_2_weight, model.encoder_2_bias, 2,  64); // stride 2 ->  64 channels
    cur = stage(cur, model.encoder_3_weight, model.encoder_3_bias, 1, 128); // stride 1 -> 128 channels

    return cur;
}
// Adds a single LSTM step to the graph and returns the new hidden state.
//
// The gate pre-activations are W_ih * x + b_ih + W_hh * h + b_hh, laid out as
// four consecutive slices of `lstm_hidden_size` elements: input, forget, cell
// and output gate. The updated cell and hidden states are copied back into
// vctx.c_state / vctx.h_state; those copies are expanded into `gf` here so
// that the state persists across graph evaluations.
static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0,
        const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) {
    const whisper_vad_model & model = vctx.model;
    const int n_hidden = model.hparams.lstm_hidden_size;

    ggml_tensor * x_t = ggml_transpose(ctx0, cur);

    // Input-to-hidden contribution.
    ggml_tensor * ih = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t);
    ih = ggml_add(ctx0, ih, model.lstm_ih_bias);

    // Hidden-to-hidden contribution.
    ggml_tensor * hh = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state);
    hh = ggml_add(ctx0, hh, model.lstm_hh_bias);

    // Pre-activations of all four gates.
    ggml_tensor * gates = ggml_add(ctx0, ih, hh);

    // Byte size of one gate slice (n_hidden elements of the gate tensor type).
    const size_t gate_bytes = ggml_row_size(gates->type, n_hidden);
    const auto gate_slice = [&](int idx) {
        return ggml_view_1d(ctx0, gates, n_hidden, idx * gate_bytes);
    };

    ggml_tensor * i_t = ggml_sigmoid(ctx0, gate_slice(0)); // input gate
    ggml_tensor * f_t = ggml_sigmoid(ctx0, gate_slice(1)); // forget gate
    ggml_tensor * g_t = ggml_tanh   (ctx0, gate_slice(2)); // cell gate (tanh, not sigmoid)
    ggml_tensor * o_t = ggml_sigmoid(ctx0, gate_slice(3)); // output gate

    // c' = f * c + i * g
    ggml_tensor * kept  = ggml_mul(ctx0, f_t, vctx.c_state);
    ggml_tensor * added = ggml_mul(ctx0, i_t, g_t);
    ggml_tensor * c_out = ggml_add(ctx0, kept, added);
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state));

    // h' = o * tanh(c')
    ggml_tensor * h_out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out));
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_out, vctx.h_state));

    return h_out;
}
// Builds the compute graph that maps one window of audio (input tensor
// "frame", n_window samples) to a speech probability (output tensor "prob").
//
// The graph metadata is placed in vctx.sched.meta, so the returned graph stays
// usable after the temporary ggml context has been freed.
static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) {
    const whisper_vad_model & model = vctx.model;
    auto & meta = vctx.sched.meta;

    struct ggml_init_params init_params = {
        /*.mem_size   =*/ meta.size(),
        /*.mem_buffer =*/ meta.data(),
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx0  = ggml_init(init_params);
    ggml_cgraph         * graph = ggml_new_graph(ctx0);

    // Graph input: a single window of samples.
    struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1);
    ggml_set_name(frame, "frame");
    ggml_set_input(frame);

    struct ggml_tensor * out = whisper_vad_build_stft_layer(ctx0, model, frame);
    out = whisper_vad_build_encoder_layer(ctx0, model, out);

    // Extract the first element of the first dimension
    // (equivalent to pytorch's [:, :, 0])
    out = ggml_view_2d(ctx0, out, 1, 128, out->nb[1], 0);

    out = whisper_vad_build_lstm_layer(ctx0, vctx, out, graph);

    // Output head: ReLU -> Conv1D -> bias -> sigmoid.
    out = ggml_relu(ctx0, out);
    out = ggml_conv_1d(ctx0, model.final_conv_weight, out, 1, 0, 1);
    out = ggml_add(ctx0, out, model.final_conv_bias);
    out = ggml_sigmoid(ctx0, out);
    ggml_set_name(out, "prob");
    ggml_set_output(out);

    ggml_build_forward_expand(graph, out);

    ggml_free(ctx0);

    return graph;
}
// Finishes the setup of a VAD context whose model has already been loaded:
// creates the backends, allocates the persistent LSTM hidden/cell state
// tensors and initializes the graph scheduler.
//
// Returns false (after logging) on failure. The caller owns `vctx` and is
// expected to release it with whisper_vad_free() in that case.
static bool whisper_vad_init_context(whisper_vad_context * vctx) {
    auto whisper_context_params = whisper_context_default_params();
    // TODO: GPU VAD is forced disabled until the performance is improved
    //whisper_context_params.use_gpu = vctx->params.use_gpu;
    whisper_context_params.use_gpu    = false;
    whisper_context_params.gpu_device = vctx->params.gpu_device;

    vctx->backends = whisper_backend_init(whisper_context_params);
    if (vctx->backends.empty()) {
        WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__);
        return false;
    }

    const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size;

    // Metadata for exactly two tensors (h_state and c_state) lives in ctx_buf.
    vctx->ctx_buf.resize(2u*ggml_tensor_overhead());

    struct ggml_init_params params = {
        /*.mem_size   =*/ vctx->ctx_buf.size(),
        /*.mem_buffer =*/ vctx->ctx_buf.data(),
        /*.no_alloc   =*/ true,
    };

    ggml_context * ctx = ggml_init(params);
    if (!ctx) {
        WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__);
        return false;
    }

    // LSTM Hidden state
    vctx->h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
    ggml_set_name(vctx->h_state, "h_state");

    // LSTM Cell state
    vctx->c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, lstm_hidden_size);
    ggml_set_name(vctx->c_state, "c_state");

    vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]);

    // The context object is only needed to create and allocate the tensors.
    // Their metadata lives in vctx->ctx_buf (caller-provided memory), so the
    // context itself can be released here; previously it was never freed.
    ggml_free(ctx);

    if (!vctx->buffer) {
        WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__);
        return false;
    }

    {
        bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends,
                [&]() {
                    return whisper_vad_build_graph(*vctx);
                });

        if (!ok) {
            WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__);
            return false;
        }

        WHISPER_LOG_INFO("%s: compute buffer (VAD)   = %7.2f MB\n", __func__, whisper_sched_size(vctx->sched) / 1e6);
    }

    return true;
}
// Creates a VAD context from the model file at `path_model` (UTF-8 encoded).
//
// Returns nullptr if the path is null, the file cannot be opened, or the model
// fails to load. On success the caller owns the context and must release it
// with whisper_vad_free().
struct whisper_vad_context * whisper_vad_init_from_file_with_params(
        const char * path_model,
        struct whisper_vad_context_params params) {
    // A null path must not reach the "%s" format arguments or std::ifstream.
    if (path_model == nullptr) {
        WHISPER_LOG_ERROR("%s: VAD model path is null\n", __func__);
        return nullptr;
    }

    WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model);
#ifdef _MSC_VER
    // On MSVC the UTF-8 path is converted to a wide string before opening.
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    std::wstring path_model_wide = converter.from_bytes(path_model);
    std::ifstream fin(path_model_wide, std::ios::binary);
#else
    std::ifstream fin(path_model, std::ios::binary);
#endif
    if (!fin) {
        WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model);
        return nullptr;
    }

    // Adapt the file stream to the generic loader interface.
    whisper_model_loader loader = {};
    loader.context = &fin;

    loader.read = [](void * ctx, void * output, size_t read_size) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->read((char *)output, read_size);
        return read_size;
    };

    loader.eof = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        return fin->eof();
    };

    loader.close = [](void * ctx) {
        std::ifstream * fin = (std::ifstream*)ctx;
        fin->close();
    };

    auto ctx = whisper_vad_init_with_params(&loader, params);
    if (!ctx) {
        // Nothing to release here: no context was handed back to us.
        return nullptr;
    }

    ctx->path_model = path_model;
    return ctx;
}
struct whisper_vad_context * whisper_vad_init_with_params(
struct whisper_model_loader * loader,
struct whisper_vad_context_params params) {
// Read the VAD model
{
uint32_t magic;
read_safe(loader, magic);
if (magic != GGML_FILE_MAGIC) {
WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__);
return nullptr;
}
}
whisper_vad_context * vctx = new whisper_vad_context;
vctx->n_threads = params.n_threads;
vctx->params.use_gpu = params.use_gpu;
vctx->params.gpu_device = params.gpu_device;
auto & model = vctx->model;
auto & hparams = model.hparams;
// load model context params.
{
int32_t str_len;
read_safe(loader, str_len);
std::vector<char> buffer(str_len + 1, 0);
loader->read(loader->context, buffer.data(), str_len);
std::string model_type(buffer.data(), str_len);
model.type = model_type;
WHISPER_LOG_INFO("%s: model type: %s\n", __func__, model.type.c_str());
int32_t major, minor, patch;
read_safe(loader, major);
read_safe(loader, minor);
read_safe(loader, patch);
std::string version_str = std::to_string(major) + "." +
std::to_string(minor) + "." +
std::to_string(patch);
model.version = version_str;
WHISPER_LOG_INFO("%s: model version: %s\n", __func__, model.version.c_str());
read_safe(loader, vctx->n_window);
read_safe(loader, vctx->n_context);
}
// load model hyper params (hparams).
{
read_safe(loader, hparams.n_encoder_layers);
hparams.encoder_in_channels = new int32_t[hparams.n_encoder_layers];
hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers];
hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers];
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
read_safe(loader, hparams.encoder_in_channels[i]);
read_safe(loader, hparams.encoder_out_channels[i]);
read_safe(loader, hparams.kernel_sizes[i]);
}
read_safe(loader, hparams.lstm_input_size);
read_safe(loader, hparams.lstm_hidden_size);
read_safe(loader, hparams.final_conv_in);
read_safe(loader, hparams.final_conv_out);
WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers);
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]);
}
for (int32_t i = 0; i < hparams.n_encoder_layers; i++) {
WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]);
}
WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size);
WHISPER_LOG_INFO("%s: lstm_hidden_size = %d\n", __func__, hparams.lstm_hidden_size);
WHISPER_LOG_INFO("%s: final_conv_in = %d\n", __func__, hparams.final_conv_in);
WHISPER_LOG_INFO("%s: final_conv_out = %d\n", __func__, hparams.final_conv_out);
}
// 1 STFT tensor, 4*2 encoder tensors, 4 LSTM tensors, 2 final output tensors
const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1;
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
if (!ctx) {
throw std::runtime_error("failed to create ggml context");
}
ctx_map[buft] = ctx;
model.ctxs.emplace_back(ctx);
return ctx;
}
return it->second;
};
whisper_context_params wparams = whisper_context_default_params();
wparams.use_gpu = params.use_gpu;
wparams.gpu_device = params.gpu_device;
buft_list_t buft_list = make_buft_list(wparams);
auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * {
ggml_op op = VAD_TENSOR_OPS.at(type);
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list);
if (!buft) {
throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type)));
}
ggml_context * ctx = get_ctx(buft);
ggml_tensor * tensor = ggml_dup_tensor(ctx, meta);
model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor;
return tensor;
};
// create tensors
{
ggml_init_params params = {
/*.mem_size =*/ n_tensors * ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
ggml_context * ctx = ggml_init(params);
const auto & hparams = model.hparams;
// SFTF precomputed basis matrix
model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS,
ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258));
model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT,
ggml_new_tensor_3d(
ctx,
GGML_TYPE_F16,
hparams.kernel_sizes[0],
hparams.encoder_in_channels[0],
hparams.encoder_out_channels[0]
));
model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0]));
model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT,
ggml_new_tensor_3d(
ctx,
GGML_TYPE_F16,
hparams.kernel_sizes[1],
hparams.encoder_in_channels[1],
hparams.encoder_out_channels[1]
));
model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1]));
model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT,
ggml_new_tensor_3d(
ctx,
GGML_TYPE_F16,
hparams.kernel_sizes[2],
hparams.encoder_in_channels[2],
hparams.encoder_out_channels[2]
));
model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2]));
model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT,
ggml_new_tensor_3d(
ctx,
GGML_TYPE_F16,
hparams.kernel_sizes[3],
hparams.encoder_in_channels[3],
hparams.encoder_out_channels[3]
));
model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3]));
// Hidden State dimension (input gate, forget gate, cell gate, output gate)
const int hstate_dim = hparams.lstm_hidden_size * 4;
// LSTM weights - input to hidden
model.lstm_ih_weight = create_tensor(
VAD_TENSOR_LSTM_WEIGHT_IH,
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
);
model.lstm_ih_bias = create_tensor(
VAD_TENSOR_LSTM_BIAS_IH,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
);
// LSTM weights - hidden to hidden
model.lstm_hh_weight = create_tensor(
VAD_TENSOR_LSTM_WEIGHT_HH,
ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)
);
model.lstm_hh_bias = create_tensor(
VAD_TENSOR_LSTM_BIAS_HH,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)
);
// Final conv layer weight
model.final_conv_weight = create_tensor(
VAD_TENSOR_FINAL_CONV_WEIGHT,
ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)
);
model.final_conv_bias = create_tensor(
VAD_TENSOR_FINAL_CONV_BIAS,
ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)
);
ggml_free(ctx);
}
// allocate tensors in the backend buffers
for (auto & p : ctx_map) {
ggml_backend_buffer_type_t buft = p.first;
ggml_context * ctx = p.second;
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf) {
model.buffers.emplace_back(buf);
size_t size_main = ggml_backend_buffer_get_size(buf);
WHISPER_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6);
}
}
// load weights
{
size_t total_size = 0;
model.n_loaded = 0;
std::vector<char> read_buf;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ttype;
read_safe(loader, n_dims);
read_safe(loader, length);
read_safe(loader, ttype);
if (loader->eof(loader->context)) {
break;
}
int32_t nelements = 1;
int32_t ne[4] = { 1, 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
read_safe(loader, ne[i]);
nelements *= ne[i];
}
std::string name;
std::vector<char> tmp(length);
loader->read(loader->context, &tmp[0], tmp.size());
name.assign(&tmp[0], tmp.size());
if (model.tensors.find(name) == model.tensors.end()) {
WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data());
return nullptr;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n",
__func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]);
return nullptr;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]);
return nullptr;
}
const size_t bpe = ggml_type_size(ggml_type(ttype));
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return nullptr;
}
if (ggml_backend_buffer_is_host(tensor->buffer)) {
// for the CPU and Metal backend, we can read directly into the tensor
loader->read(loader->context, tensor->data, ggml_nbytes(tensor));
BYTESWAP_TENSOR(tensor);
} else {
// read into a temporary buffer first, then copy to device memory
read_buf.resize(ggml_nbytes(tensor));
loader->read(loader->context, read_buf.data(), read_buf.size());
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));
}
total_size += ggml_nbytes(tensor);
model.n_loaded++;
}
WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6);
if (model.n_loaded == 0) {
WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
} else if (model.n_loaded != (int) model.tensors.size()) {
WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
return nullptr;
}
}
if (!whisper_vad_init_context(vctx)) {
whisper_vad_free(vctx);
return nullptr;
}
return vctx;
}
bool whisper_vad_detect_speech(
struct whisper_vad_context * vctx,
const float * samples,
int n_samples) {
int n_chunks = n_samples / vctx->n_window;
if (n_samples % vctx->n_window != 0) {
n_chunks += 1; // Add one more chunk for remaining samples.
}
WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples);
WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks);
// Reset LSTM hidden/cell states
ggml_backend_buffer_clear(vctx->buffer, 0);
vctx->probs.resize(n_chunks);
WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks);
std::vector<float> window(vctx->n_window, 0.0f);
auto & sched = vctx->sched.sched;
ggml_cgraph * gf = whisper_vad_build_graph(*vctx);
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__);
return false;
}
struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame");
struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob");
// we are going to reuse the graph multiple times for each chunk
const int64_t t_start_vad_us = ggml_time_us();
for (int i = 0; i < n_chunks; i++) {
const int idx_start = i * vctx->n_window;
const int idx_end = std::min(idx_start + vctx->n_window, n_samples);
const int chunk_len = idx_end - idx_start;
if (chunk_len < vctx->n_window) {
WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window);
std::vector<float> partial_chunk(vctx->n_window, 0.0f);
std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin());
// Copy the zero-padded chunk to the window.
const int samples_to_copy_max = vctx->n_window;
const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size());
std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin());
if (samples_to_copy_cur < samples_to_copy_max) {
std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f);
}
} else {
// Copy current frame samples to the window.
const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window);
std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin());
}
// Set the frame tensor data with the samples.
ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float));
// do not reset the scheduler - we will reuse the graph in the next chunk
if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) {
WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__);
break;
}
// Get the probability for this chunk.
ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float));
//WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]);
}
vctx->t_vad_us += ggml_time_us() - t_start_vad_us;
WHISPER_LOG_INFO("%s: vad time = %.2f ms processing %d samples\n", __func__, 1e-3f * vctx->t_vad_us, n_samples);
ggml_backend_sched_reset(sched);
return true;
}
// Number of speech segments stored in `segments`.
int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) {
    const size_t count = segments->data.size();
    return static_cast<int>(count);
}
float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) {
return segments->data[i_segment].start;
}
float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) {
return segments->data[i_segment].end;
}
// Number of speech probabilities currently stored in the context.
int whisper_vad_n_probs(struct whisper_vad_context * vctx) {
    const size_t count = vctx->probs.size();
    return static_cast<int>(count);
}
// Pointer to the stored speech probabilities (whisper_vad_n_probs() entries).
// The storage belongs to the context.
float * whisper_vad_probs(struct whisper_vad_context * vctx) {
    std::vector<float> & stored = vctx->probs;
    return stored.data();
}
// Converts the per-window speech probabilities stored in `vctx` into a list
// of speech segments (times in seconds).
//
// Steps:
//   1. scan the probabilities with a hysteresis threshold (`threshold` to
//      enter speech, `threshold - 0.15` to leave it), honouring the minimum
//      silence / minimum speech / maximum speech durations from `params`
//   2. merge segments separated by less than 200 ms
//   3. drop segments shorter than the minimum speech duration
//   4. pad the segments by `speech_pad_ms`, splitting short gaps evenly
//
// Returns a newly allocated object that the caller releases with
// whisper_vad_free_segments(), or nullptr if the result cannot be allocated.
struct whisper_vad_segments * whisper_vad_segments_from_probs(
        struct whisper_vad_context * vctx,
        whisper_vad_params params) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx));

    const int     n_probs = whisper_vad_n_probs(vctx);
    const float * probs   = whisper_vad_probs(vctx);

    const float threshold               = params.threshold;
    const int   min_speech_duration_ms  = params.min_speech_duration_ms;
    const int   min_silence_duration_ms = params.min_silence_duration_ms;
    const float max_speech_duration_s   = params.max_speech_duration_s;
    const int   speech_pad_ms           = params.speech_pad_ms;
    const int   n_window                = vctx->n_window;
    const int   sample_rate             = WHISPER_SAMPLE_RATE;
    const int   min_silence_samples     = sample_rate * min_silence_duration_ms / 1000;
    const int   audio_length_samples    = n_probs * n_window;

    // Min number of samples to be considered valid speech.
    const int min_speech_samples = sample_rate * min_speech_duration_ms / 1000;
    const int speech_pad_samples = sample_rate * speech_pad_ms / 1000;

    // Max number of samples that a speech segment can contain before it is
    // split into multiple segments.
    int max_speech_samples;
    if (!(max_speech_duration_s > 0.0f) || max_speech_duration_s > 100000.0f) {
        // Non-positive (or NaN) and very large limits mean "no limit".
        max_speech_samples = INT_MAX / 2;
    } else {
        // Multiply before truncating so that fractional seconds are kept.
        // Truncating the duration to whole seconds first turned e.g. 0.5 s
        // into 0 and thereby disabled the limit.
        const int64_t temp = (int64_t) ((double) sample_rate * (double) max_speech_duration_s) - n_window - 2 * speech_pad_samples;
        max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int) temp;
        if (max_speech_samples < 0) {
            max_speech_samples = INT_MAX / 2;
        }
    }

    // Detect silence period that exceeds this value, then that location (sample)
    // is marked as a potential place where the segment could be split if
    // max_speech_samples is reached. The value 98 was taken from the original
    // silero-vad python implementation:
    // https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291
    const int min_silence_samples_at_max_speech = sample_rate * 98 / 1000;

    // Calculate lower threshold for detecting end of speech segments.
    float neg_threshold = threshold - 0.15f;
    if (neg_threshold < 0.01f) {
        neg_threshold = 0.01f;
    }

    struct speech_segment_t {
        int start;
        int end;
    };

    std::vector<speech_segment_t> speeches;
    speeches.reserve(256);

    bool is_speech_segment = false;
    int  temp_end          = 0; // start of the silence currently being measured (0 = none)
    int  prev_end          = 0; // candidate split point for max-speech handling (0 = none)
    int  next_start        = 0; // where speech resumed after prev_end
    int  curr_speech_start = 0;
    bool has_curr_speech   = false;

    for (int i = 0; i < n_probs; i++) {
        float curr_prob   = probs[i];
        int   curr_sample = n_window * i;

        // Reset temp_end when we get back to speech
        if ((curr_prob >= threshold) && temp_end) {
            temp_end = 0;
            if (next_start < prev_end) {
                next_start = curr_sample;
            }
        }

        // Start a new speech segment when probability exceeds threshold and not already in speech
        if ((curr_prob >= threshold) && !is_speech_segment) {
            is_speech_segment = true;
            curr_speech_start = curr_sample;
            has_curr_speech   = true;
            continue;
        }

        // Handle maximum speech duration
        if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) {
            if (prev_end) {
                speeches.push_back({ curr_speech_start, prev_end });
                has_curr_speech = true;

                if (next_start < prev_end) { // Previously reached silence and is still not speech
                    is_speech_segment = false;
                    has_curr_speech   = false;
                } else {
                    curr_speech_start = next_start;
                }
                prev_end = next_start = temp_end = 0;
            } else {
                speeches.push_back({ curr_speech_start, curr_sample });

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech   = false;
                continue;
            }
        }

        // Handle silence after speech
        if ((curr_prob < neg_threshold) && is_speech_segment) {
            if (!temp_end) {
                temp_end = curr_sample;
            }

            // Track potential segment ends for max_speech handling
            if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) {
                prev_end = temp_end;
            }

            // Check if silence is long enough to end the segment
            if ((curr_sample - temp_end) < min_silence_samples) {
                continue;
            } else {
                // End the segment if it's long enough
                if ((temp_end - curr_speech_start) > min_speech_samples) {
                    speeches.push_back({ curr_speech_start, temp_end });
                }

                prev_end = next_start = temp_end = 0;
                is_speech_segment = false;
                has_curr_speech   = false;
                continue;
            }
        }
    }

    // Handle the case if we're still in a speech segment at the end
    if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) {
        speeches.push_back({ curr_speech_start, audio_length_samples });
    }

    // Merge adjacent segments with small gaps in between (post-processing)
    if (speeches.size() > 1) {
        // Maximum gap allowed for merging (200ms converted to samples)
        const int max_merge_gap_samples = sample_rate * 200 / 1000;

        int merged_count = 0;
        int i = 0;
        while (i < (int) speeches.size() - 1) {
            if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) {
                // Merge by extending the current segment to the end of the next
                // one, then re-examine the same index against its new neighbour.
                speeches[i].end = speeches[i+1].end;
                speeches.erase(speeches.begin() + i + 1);
                merged_count++;
            } else {
                i++;
            }
        }
        WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n",
                __func__, merged_count, (int) speeches.size());
    }

    // Double-check for minimum speech duration
    {
        int i = 0;
        while (i < (int) speeches.size()) {
            if (speeches[i].end - speeches[i].start < min_speech_samples) {
                WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n",
                        __func__, i, speeches[i].end - speeches[i].start);
                speeches.erase(speeches.begin() + i);
            } else {
                i++;
            }
        }
    }

    WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size());

    // Allocate the result up front. `new` reports failure by throwing, so both
    // allocations are covered by the same handler.
    whisper_vad_segments * vad_segments = nullptr;
    try {
        vad_segments = new whisper_vad_segments;
        vad_segments->data.resize(speeches.size());
    } catch (const std::bad_alloc &) {
        delete vad_segments;
        WHISPER_LOG_ERROR("%s: failed to allocate memory for final segments\n", __func__);
        return nullptr;
    }
    std::vector<whisper_vad_segment> & segments = vad_segments->data;

    // Apply padding to segments and copy to final segments
    for (int i = 0; i < (int) speeches.size(); i++) {
        // Apply padding to the start of the first segment
        if (i == 0) {
            speeches[i].start = std::max(0, speeches[i].start - speech_pad_samples);
        }

        // Handle spacing between segments
        if (i < (int) speeches.size() - 1) {
            const int silence_duration = speeches[i+1].start - speeches[i].end;

            if (silence_duration < 2 * speech_pad_samples) {
                // If segments are close, split the difference
                speeches[i].end     += silence_duration / 2;
                speeches[i+1].start  = std::max(0, speeches[i+1].start - silence_duration / 2);
            } else {
                // Otherwise, apply full padding to both
                speeches[i].end     = std::min(speeches[i].end + speech_pad_samples, audio_length_samples);
                speeches[i+1].start = std::max(0, speeches[i+1].start - speech_pad_samples);
            }
        } else {
            // Apply padding to the end of the last segment
            speeches[i].end = std::min(speeches[i].end + speech_pad_samples, audio_length_samples);
        }

        // Convert from samples to seconds and copy to final segments
        segments[i].start = (float)speeches[i].start / sample_rate;
        segments[i].end   = (float)speeches[i].end   / sample_rate;

        WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n",
                __func__, i, segments[i].start, segments[i].end, segments[i].end - segments[i].start);
    }

    return vad_segments;
}
// Convenience wrapper: runs speech detection on `samples` and converts the
// resulting probabilities into speech segments. Returns nullptr when the
// detection step fails.
struct whisper_vad_segments * whisper_vad_segments_from_samples(
        whisper_vad_context * vctx,
        whisper_vad_params params,
        const float * samples,
        int n_samples) {
    WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples);

    const bool detected = whisper_vad_detect_speech(vctx, samples, n_samples);
    if (detected) {
        return whisper_vad_segments_from_probs(vctx, params);
    }

    WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__);
    return nullptr;
}
// Releases a VAD context and everything it owns. Passing nullptr is a no-op.
void whisper_vad_free(whisper_vad_context * ctx) {
    if (!ctx) {
        return;
    }

    // Model tensor metadata contexts and the buffers that hold the weights.
    for (ggml_context * context : ctx->model.ctxs) {
        ggml_free(context);
    }

    for (ggml_backend_buffer_t buf : ctx->model.buffers) {
        ggml_backend_buffer_free(buf);
    }

    // Buffer that holds the LSTM hidden/cell state; it was previously leaked.
    // It is released before the backends it was allocated from.
    if (ctx->buffer) {
        ggml_backend_buffer_free(ctx->buffer);
        ctx->buffer = nullptr;
    }

    ggml_backend_sched_free(ctx->sched.sched);

    for (auto & backend : ctx->backends) {
        ggml_backend_free(backend);
    }

    // Per-layer hyper parameter arrays, allocated with new[] while loading the
    // model in whisper_vad_init_with_params().
    delete[] ctx->model.hparams.encoder_in_channels;
    delete[] ctx->model.hparams.encoder_out_channels;
    delete[] ctx->model.hparams.kernel_sizes;

    delete ctx;
}
// Releases a segment list returned by the whisper_vad_segments_from_*
// functions. Passing nullptr is a no-op.
void whisper_vad_free_segments(whisper_vad_segments * segments) {
    // Deleting a null pointer is well defined and does nothing.
    delete segments;
}
//////////////////////////////////
// Grammar - ported from llama.cpp
//////////////////////////////////
// Decodes a UTF-8 string that may begin in the middle of a sequence (described
// by `partial_start`) and may end in an incomplete one. The returned code point
// list always carries a terminating 0 so it can be used as a pointer. The
// returned partial state holds the value accumulated so far and the number of
// continuation bytes still missing; n_remain == -1 signals an invalid sequence.
static std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8(
        const char         * src,
        whisper_partial_utf8 partial_start) {
    // Sequence length keyed by the high nibble of the lead byte. The zero
    // entries are continuation bytes, which cannot start a sequence.
    static const int seq_len_by_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };

    std::vector<uint32_t> decoded;
    const char * cursor  = src;
    uint32_t     acc     = partial_start.value;
    int          pending = partial_start.n_remain;

    // Finish the sequence carried over from the previous call, if applicable.
    for (; *cursor != 0 && pending > 0; ++cursor, --pending) {
        const uint8_t cont = static_cast<uint8_t>(*cursor);
        if ((cont >> 6) != 2) {
            // Not a continuation byte: invalid sequence, abort.
            decoded.push_back(0);
            return std::make_pair(std::move(decoded), whisper_partial_utf8{ 0, -1 });
        }
        acc = (acc << 6) + (cont & 0x3F);
    }

    if (partial_start.n_remain > 0 && pending == 0) {
        decoded.push_back(acc);
    }

    // Decode the remaining sequences; the last one may be incomplete.
    while (*cursor != 0) {
        const uint8_t lead = static_cast<uint8_t>(*cursor);
        pending = seq_len_by_nibble[lead >> 4] - 1;

        if (pending < 0) {
            // A continuation byte where a lead byte was expected: abort.
            decoded.clear();
            decoded.push_back(0);
            return std::make_pair(std::move(decoded), whisper_partial_utf8{ 0, pending });
        }

        // Keep the payload bits of the lead byte.
        const uint8_t mask = (1 << (7 - pending)) - 1;
        acc = lead & mask;
        ++cursor;

        for (; *cursor != 0 && pending > 0; ++cursor, --pending) {
            acc = (acc << 6) + (static_cast<uint8_t>(*cursor) & 0x3F);
        }

        if (pending == 0) {
            decoded.push_back(acc);
        }
    }

    decoded.push_back(0);
    return std::make_pair(std::move(decoded), whisper_partial_utf8{ acc, pending });
}
// returns true iff pos points to the end of one of the definitions of a rule,
// i.e. its type is WHISPER_GRETYPE_END or WHISPER_GRETYPE_ALT
static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) {
    const auto type = pos->type;
    return type == WHISPER_GRETYPE_END || type == WHISPER_GRETYPE_ALT;
}
// returns true iff chr satisfies the char range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
//
// the second member of the result points at the first element after the char range
static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char(
        const whisper_grammar_element * pos,
        const uint32_t                  chr) {
    const bool positive = pos->type == WHISPER_GRETYPE_CHAR;
    WHISPER_ASSERT(positive || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT

    bool hit = false;

    do {
        const bool is_range = pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER;

        if (is_range) {
            // inclusive range, e.g. [a-z]
            if (pos->value <= chr && chr <= pos[1].value) {
                hit = true;
            }
        } else if (pos->value == chr) {
            // exact char match, e.g. [a] or "a"
            hit = true;
        }

        // a range occupies two elements, a single char one
        pos += is_range ? 2 : 1;
    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);

    // for an inverse range the result is true when nothing matched
    return std::make_pair(hit == positive, pos);
}
// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
// range at pos (regular or inverse range)
// asserts that pos is pointing to a char range element
static bool whisper_grammar_match_partial_char(
        const whisper_grammar_element * pos,
        const whisper_partial_utf8      partial_utf8) {
    const bool positive = pos->type == WHISPER_GRETYPE_CHAR;
    WHISPER_ASSERT(positive || pos->type == WHISPER_GRETYPE_CHAR_NOT);

    const uint32_t prefix   = partial_utf8.value;
    const int      n_remain = partial_utf8.n_remain;

    // invalid sequence
    if (n_remain < 0) {
        return false;
    }
    // 7-bit char split across 2 bytes (overlong)
    if (n_remain == 1 && prefix < 2) {
        return false;
    }

    // range of possible code points this partial UTF-8 sequence could complete to
    const int shift = n_remain * 6;

    uint32_t       low  = prefix << shift;
    const uint32_t high = low | ((1 << shift) - 1);

    if (low == 0) {
        // raise the lower bound for 3- and 4-byte sequences whose prefix bits are all zero
        switch (n_remain) {
            case 2:  low = 1 << 11; break;
            case 3:  low = 1 << 16; break;
            default: break;
        }
    }

    do {
        if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) {
            // inclusive range, e.g. [a-z]: does it overlap [low, high]?
            const bool overlaps = pos->value <= high && low <= pos[1].value;
            if (overlaps) {
                return positive;
            }
            pos += 2;
        } else {
            // exact char match, e.g. [a] or "a": does it fall inside [low, high]?
            const bool inside = low <= pos->value && pos->value <= high;
            if (inside) {
                return positive;
            }
            pos += 1;
        }
    } while (pos->type == WHISPER_GRETYPE_CHAR_ALT);

    return !positive;
}
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
//
// rules      - rule definitions, indexed by rule id
// stack      - stack to expand; its top element decides what happens
// new_stacks - receives the resulting stacks (appended)
static void whisper_grammar_advance_stack(
        const std::vector<std::vector<whisper_grammar_element>>   & rules,
        const std::vector<const whisper_grammar_element *>        & stack,
        std::vector<std::vector<const whisper_grammar_element *>> & new_stacks) {
    if (stack.empty()) {
        new_stacks.emplace_back();
        return;
    }

    const whisper_grammar_element * top = stack.back();

    if (top->type == WHISPER_GRETYPE_CHAR || top->type == WHISPER_GRETYPE_CHAR_NOT) {
        // already positioned at a terminal: keep the stack as is
        new_stacks.push_back(stack);
        return;
    }

    if (top->type != WHISPER_GRETYPE_RULE_REF) {
        // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range
        // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
        // those
        WHISPER_ASSERT(false);
        return;
    }

    // rule reference: expand every alternate of the referenced rule
    const size_t rule_id = static_cast<size_t>(top->value);

    const whisper_grammar_element * follow     = top + 1;
    const bool                      has_follow = !whisper_grammar_is_end_of_sequence(follow);

    const whisper_grammar_element * alt = rules[rule_id].data();

    for (;;) {
        // init new stack without the top
        std::vector<const whisper_grammar_element *> expanded(stack.begin(), stack.end() - 1);

        if (has_follow) {
            // if this rule ref is followed by another element, add that to stack
            expanded.push_back(follow);
        }
        if (!whisper_grammar_is_end_of_sequence(alt)) {
            // if alternate is nonempty, add to stack
            expanded.push_back(alt);
        }
        whisper_grammar_advance_stack(rules, expanded, new_stacks);

        // scan to end of alternate def
        while (!whisper_grammar_is_end_of_sequence(alt)) {
            ++alt;
        }

        if (alt->type != WHISPER_GRETYPE_ALT) {
            break;
        }

        // there's another alternate def of this rule to process
        ++alt;
    }
}
// takes a set of possible pushdown stacks on a grammar, which are required to
// be positioned at a character range (see `whisper_grammar_advance_stack`), and
// produces the N possible stacks if the given char is accepted at those
// positions
static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept(
        const std::vector<std::vector<whisper_grammar_element>>         & rules,
        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
        const uint32_t                                                    chr) {
    std::vector<std::vector<const whisper_grammar_element *>> accepted;

    for (const auto & stack : stacks) {
        // empty stacks have no position to match against
        if (stack.empty()) {
            continue;
        }

        const auto outcome = whisper_grammar_match_char(stack.back(), chr);
        if (!outcome.first) {
            continue;
        }

        // update top of stack to next element, if any
        const whisper_grammar_element * next = outcome.second;

        std::vector<const whisper_grammar_element *> advanced(stack.begin(), stack.end() - 1);
        if (!whisper_grammar_is_end_of_sequence(next)) {
            advanced.push_back(next);
        }

        whisper_grammar_advance_stack(rules, advanced, accepted);
    }

    return accepted;
}
// forward declaration: whisper_grammar_reject_candidates_for_stack (defined next) calls this
// function, and this function's definition in turn calls
// whisper_grammar_reject_candidates_for_stack once per stack
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
const std::vector<std::vector<whisper_grammar_element>> & rules,
const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
const std::vector<whisper_grammar_candidate> & candidates);
// returns the subset of `candidates` that cannot be accepted from the given stack
//
// each candidate is checked one code point at a time: candidates whose current code point
// matches the top of the stack are advanced and checked recursively against the stacks that
// follow; the rest are rejected right away
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack(
        const std::vector<std::vector<whisper_grammar_element>> & rules,
        const std::vector<const whisper_grammar_element *>      & stack,
        const std::vector<whisper_grammar_candidate>            & candidates) {
    std::vector<whisper_grammar_candidate> rejects;

    if (stack.empty()) {
        // nothing more may follow: only fully consumed candidates without a pending partial
        // sequence are kept
        for (const auto & cand : candidates) {
            const bool fully_consumed = *cand.code_points == 0 && cand.partial_utf8.n_remain == 0;
            if (!fully_consumed) {
                rejects.push_back(cand);
            }
        }
        return rejects;
    }

    const whisper_grammar_element * top = stack.back();

    std::vector<whisper_grammar_candidate> survivors;

    for (const auto & cand : candidates) {
        const uint32_t code_point = *cand.code_points;

        if (code_point != 0) {
            if (whisper_grammar_match_char(top, code_point).first) {
                // matched: continue with the next code point of this candidate
                survivors.push_back({ cand.id, cand.code_points + 1, cand.partial_utf8 });
            } else {
                rejects.push_back(cand);
            }
            continue;
        }

        // reached end of full codepoints in token, reject iff it ended in a partial sequence
        // that cannot satisfy this position in grammar
        if (cand.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(top, cand.partial_utf8)) {
            rejects.push_back(cand);
        }
    }

    // only the position after the char range is needed here, the match result is ignored
    const auto * after_range = whisper_grammar_match_char(top, 0).second;

    // update top of stack to next element, if any
    std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
    if (!whisper_grammar_is_end_of_sequence(after_range)) {
        stack_after.push_back(after_range);
    }

    std::vector<std::vector<const whisper_grammar_element *>> next_stacks;
    whisper_grammar_advance_stack(rules, stack_after, next_stacks);

    const auto deeper_rejects = whisper_grammar_reject_candidates(rules, next_stacks, survivors);
    for (const auto & cand : deeper_rejects) {
        // step the code point pointer back to where it was on entry to this call
        rejects.push_back({ cand.id, cand.code_points - 1, cand.partial_utf8 });
    }

    return rejects;
}
// returns the candidates that are rejected by every one of the given stacks
//
// the reject list of the first stack is narrowed down by feeding it through each remaining
// stack; with no candidates or no stacks the result is empty
static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates(
        const std::vector<std::vector<whisper_grammar_element>>         & rules,
        const std::vector<std::vector<const whisper_grammar_element *>> & stacks,
        const std::vector<whisper_grammar_candidate>                    & candidates) {
    if (stacks.empty() || candidates.empty()) {
        return {};
    }

    auto it = stacks.begin();

    auto rejects = whisper_grammar_reject_candidates_for_stack(rules, *it, candidates);

    for (++it; it != stacks.end(); ++it) {
        rejects = whisper_grammar_reject_candidates_for_stack(rules, *it, rejects);
    }

    return rejects;
}
static struct whisper_grammar whisper_grammar_init(
const whisper_grammar_element ** rules,
size_t n_rules,
size_t i_start_rule) {
const whisper_grammar_element * pos;
// copy rule definitions into vectors
std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules);
for (size_t i = 0; i < n_rules; i++) {
for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) {
vec_rules[i].push_back(*pos);
}
vec_rules[i].push_back({WHISPER_GRETYPE_END, 0});
}
// loop over alternates of start rule to build initial stacks
std::vector<std::vector<const whisper_grammar_element *>> stacks;
pos = rules[i_start_rule];
do {
std::vector<const whisper_grammar_element *> stack;
if (!whisper_grammar_is_end_of_sequence(pos)) {
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
whisper_grammar_advance_stack(vec_rules, stack, stacks);
while (!whisper_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
}
if (pos->type == WHISPER_GRETYPE_ALT) {
// there's another alternate def of this rule to process
pos++;
} else {
break;
}
} while (true);
return { std::move(vec_rules), std::move(stacks), {} };
}
// penalizes the logits of all tokens below EOT that the grammar cannot accept next
//
// ctx     - supplies the vocabulary (token id -> text)
// params  - params.grammar_penalty is subtracted from the logit of every rejected token
// logits  - modified in place
// grammar - current grammar state; nothing is done when it has no rules or no stacks
static void whisper_suppress_invalid_grammar(
        whisper_context           & ctx,
        const whisper_full_params & params,
        std::vector<float>        & logits,
        const whisper_grammar     & grammar) {
    if (grammar.rules.empty() || grammar.stacks.empty()) {
        return;
    }

    //bool allow_eot = false;
    //for (const auto & stack : grammar.stacks) {
    //    if (stack.empty()) {
    //        allow_eot = true;
    //        break;
    //    }
    //}

    const whisper_token eot = whisper_token_eot(&ctx);

    std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded;
    std::vector<whisper_grammar_candidate>                              candidates_grammar;

    // at most one entry per token id below EOT is added: allocate once up front instead of
    // growing both vectors repeatedly (candidates_grammar keeps pointers into the code point
    // buffers owned by candidates_decoded)
    if (eot > 0) {
        candidates_decoded.reserve(static_cast<size_t>(eot));
        candidates_grammar.reserve(static_cast<size_t>(eot));
    }

    for (whisper_token id = 0; id < eot; ++id) {
        const std::string & text = ctx.vocab.id_to_token[id];
        if (!text.empty()) {
            candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8));
            candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second });
        }
    }

    const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);

    for (const auto & reject : rejects) {
        logits[reject.id] -= params.grammar_penalty;
    }

    // when the grammar allows a continuation, we penalize the end-of-text token
    //if (!allow_eot) {
    //    logits[eot] -= params.grammar_penalty;
    //}
    //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size());
}
// advances the grammar state by the text of an accepted token
//
// tokens whose text starts with "[_" are skipped; for all others every decoded code point is
// fed through whisper_grammar_accept and the trailing partial UTF-8 state is stored
static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) {
    const bool inactive = grammar.rules.empty() || grammar.stacks.empty();
    if (inactive) {
        return;
    }

    //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str());

    const std::string & text = ctx.vocab.id_to_token[token];

    if (text.compare(0, 2, "[_") == 0) {
        // fprintf(stderr, " (skipped)\n");
        return;
    }
    // fprintf(stderr, "\n");

    // Note terminating 0 in decoded string
    const auto   decoded     = decode_utf8(text.c_str(), grammar.partial_utf8);
    const auto & code_points = decoded.first;

    // the last element is the terminator and is not fed to the grammar
    for (size_t i = 0; i + 1 < code_points.size(); ++i) {
        grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, code_points[i]);
    }

    grammar.partial_utf8 = decoded.second;
}
//////////////
// END grammar
//////////////
////////////////////////////////////////////////////////////////////////////
// returns a heap-allocated copy of whisper_context_default_params();
// the object is created with `new` and is not freed here
struct whisper_context_params * whisper_context_default_params_by_ref(void) {
    return new whisper_context_params(whisper_context_default_params());
}
// returns a heap-allocated copy of whisper_full_default_params(strategy);
// the object is created with `new` and is not freed here
struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) {
    return new whisper_full_params(whisper_full_default_params(strategy));
}
// returns the default decoding parameters for the given sampling strategy
//
// all fields get fixed defaults; afterwards the sub-struct that belongs to `strategy`
// (greedy.best_of or beam_search.beam_size/patience) is filled in, the other one keeps -1
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
    // std::thread::hardware_concurrency() is allowed to return 0 when the value cannot be
    // determined - make sure the default never drops below one thread
    const int n_threads_default = std::max(1, std::min(4, (int32_t) std::thread::hardware_concurrency()));

    struct whisper_full_params result = {
        /*.strategy          =*/ strategy,

        /*.n_threads         =*/ n_threads_default,
        /*.n_max_text_ctx    =*/ 16384,
        /*.offset_ms         =*/ 0,
        /*.duration_ms       =*/ 0,

        /*.translate         =*/ false,
        /*.no_context        =*/ true,
        /*.no_timestamps     =*/ false,
        /*.single_segment    =*/ false,
        /*.print_special     =*/ false,
        /*.print_progress    =*/ true,
        /*.print_realtime    =*/ false,
        /*.print_timestamps  =*/ true,

        /*.token_timestamps  =*/ false,
        /*.thold_pt          =*/ 0.01f,
        /*.thold_ptsum       =*/ 0.01f,
        /*.max_len           =*/ 0,
        /*.split_on_word     =*/ false,
        /*.max_tokens        =*/ 0,

        /*.debug_mode        =*/ false,
        /*.audio_ctx         =*/ 0,

        /*.tdrz_enable       =*/ false,

        /*.suppress_regex    =*/ nullptr,

        /*.initial_prompt    =*/ nullptr,
        /*.prompt_tokens     =*/ nullptr,
        /*.prompt_n_tokens   =*/ 0,

        /*.language          =*/ "en",
        /*.detect_language   =*/ false,

        /*.suppress_blank    =*/ true,
        /*.suppress_nst      =*/ false,

        /*.temperature       =*/  0.0f,
        /*.max_initial_ts    =*/  1.0f,
        /*.length_penalty    =*/ -1.0f,

        /*.temperature_inc   =*/  0.2f,
        /*.entropy_thold     =*/  2.4f,
        /*.logprob_thold     =*/ -1.0f,
        /*.no_speech_thold   =*/  0.6f,

        /*.greedy            =*/ {
            /*.best_of   =*/ -1,
        },

        /*.beam_search      =*/ {
            /*.beam_size =*/ -1,

            /*.patience  =*/ -1.0f,
        },

        /*.new_segment_callback           =*/ nullptr,
        /*.new_segment_callback_user_data =*/ nullptr,

        /*.progress_callback           =*/ nullptr,
        /*.progress_callback_user_data =*/ nullptr,

        /*.encoder_begin_callback           =*/ nullptr,
        /*.encoder_begin_callback_user_data =*/ nullptr,

        /*.abort_callback           =*/ nullptr,
        /*.abort_callback_user_data =*/ nullptr,

        /*.logits_filter_callback           =*/ nullptr,
        /*.logits_filter_callback_user_data =*/ nullptr,

        /*.grammar_rules   =*/ nullptr,
        /*.n_grammar_rules =*/ 0,
        /*.i_start_rule    =*/ 0,
        /*.grammar_penalty =*/ 100.0f,

        /*.vad            =*/ false,
        /*.vad_model_path =*/ nullptr,

        /*.vad_params     =*/ whisper_vad_default_params(),
    };

    // strategy-specific defaults
    switch (strategy) {
        case WHISPER_SAMPLING_GREEDY:
            {
                result.greedy = {
                    /*.best_of   =*/ 5,
                };
            } break;
        case WHISPER_SAMPLING_BEAM_SEARCH:
            {
                result.beam_search = {
                    /*.beam_size =*/ 5,

                    /*.patience  =*/ -1.0f,
                };
            } break;
    }

    return result;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context & ctx,
struct whisper_state & state,
int i_segment,
float thold_pt,
float thold_ptsum);
// decides whether a segment may be split right before the token text `txt`:
// always when word splitting is disabled, otherwise only when the text starts with a space
static inline bool should_split_on_word(const char * txt, bool split_on_word) {
    return !split_on_word || txt[0] == ' ';
}
static void whisper_exp_compute_token_level_timestamps_dtw(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
int i_segment,
size_t n_segments,
int seek,
int n_frames,
int medfilt_width,
int n_threads);
// wrap the last segment to max_len characters
// returns the number of new segments
//
// tokens with id >= EOT do not contribute to the text; a split is only made in front of a
// token for which should_split_on_word() allows it, and never in front of the first token
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
    // working copy of the segment that is currently being wrapped
    auto pending = state.result_all.back();

    int n_segments = 1;
    int n_chars    = 0;

    std::string line;

    for (int i = 0; i < (int) pending.tokens.size(); i++) {
        const auto & token = pending.tokens[i];
        if (token.id >= whisper_token_eot(&ctx)) {
            continue;
        }

        const auto txt = whisper_token_to_str(&ctx, token.id);
        const int  len = strlen(txt);

        const bool must_split = n_chars + len > max_len && i > 0 && should_split_on_word(txt, split_on_word);
        if (!must_split) {
            n_chars += len;
            line    += txt;
            continue;
        }

        // close the current segment right before token i
        {
            auto & head = state.result_all.back();

            head.text              = std::move(line);
            head.t1                = token.t0;
            head.tokens.resize(i);
            head.speaker_turn_next = false;
        }

        state.result_all.push_back({});

        // the new segment takes over tokens [i, end] and the end time of the old one
        {
            auto & tail = state.result_all.back();

            tail.t0 = token.t0;
            tail.t1 = pending.t1;

            tail.tokens.insert(
                tail.tokens.end(),
                pending.tokens.begin() + i,
                pending.tokens.end());

            tail.speaker_turn_next = pending.speaker_turn_next;
        }

        n_chars = 0;
        line    = "";

        // continue wrapping with the freshly created segment, starting at its first token
        pending = state.result_all.back();
        i = -1;

        n_segments++;
    }

    state.result_all.back().text = std::move(line);

    return n_segments;
}
// texts of the tokens that are suppressed when non-speech token suppression is enabled
// (symbols, bracket/quote runs and musical notation)
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
//
// note: no entry may be empty - every entry is also looked up with a leading space, so an
//       empty entry would resolve to the plain " " token
static const std::vector<std::string> non_speech_tokens = {
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
};
// log-softmax over the first n_logits entries of `logits`, written to `logprobs`
//
// entries equal to -INFINITY are left out of the normalization and stay -INFINITY;
// the maximum of the whole `logits` vector is subtracted before exponentiation
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
        const int                  n_logits,
        std::vector<float>       & logprobs) {
    const float peak = *std::max_element(logits.begin(), logits.end());

    float sum_exp = 0.0f;
    for (int i = 0; i < n_logits; ++i) {
        const float v = logits[i];
        if (v > -INFINITY) {
            sum_exp += expf(v - peak);
        }
    }

    const float log_norm = logf(sum_exp) + peak;

    for (int i = 0; i < n_logits; ++i) {
        const float v = logits[i];
        logprobs[i] = (v > -INFINITY) ? v - log_norm : -INFINITY;
    }
}
// converts log-probabilities to probabilities for the first n_logits entries
//
// entries whose logit is -INFINITY get probability 0, all others exp(logprob)
static void whisper_compute_probs(
        const std::vector<float> & logits,
        const int                  n_logits,
        const std::vector<float> & logprobs,
        std::vector<float>       & probs) {
    for (int i = 0; i < n_logits; ++i) {
        probs[i] = (logits[i] == -INFINITY) ? 0.0f : expf(logprobs[i]);
    }
}
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
// TODO: optimize
//
// ctx         - supplies the vocabulary and the model hyper-parameters
// state       - state.logits holds the raw logits; the row at decoder.i_batch is copied
// decoder     - decoder.logits, decoder.logprobs and decoder.probs are overwritten
// params      - decoding options that select which filters are applied
// temperature - when > 0, the copied logits are divided by it before filtering
static void whisper_process_logits(
              struct whisper_context & ctx,
               struct whisper_state  & state,
              struct whisper_decoder & decoder,
    const struct whisper_full_params   params,
                               float   temperature) {
    const auto & vocab      = ctx.vocab;
    const auto & tokens_cur = decoder.sequence.tokens;

    const bool is_initial = tokens_cur.size() == 0;
    const int  n_logits   = vocab.id_to_token.size();

    WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab);

    // extract the logits for the last token
    // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
    auto & probs    = decoder.probs;
    auto & logits   = decoder.logits;
    auto & logprobs = decoder.logprobs;
    {
        logits.resize(n_logits);
        memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));

        if (temperature > 0.0f) {
            for (int i = 0; i < n_logits; i++) {
                logits[i] /= temperature;
            }
        }

        // will be populated a bit later
        probs.resize(n_logits);
        logprobs.resize(n_logits);
    }

    // apply logit filters here
    // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493
    {
        // suppress blank
        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390
        if (params.suppress_blank) {
            if (is_initial) {
                logits[vocab.token_eot]           = -INFINITY;
                logits[vocab.token_to_id.at(" ")] = -INFINITY;
            }
        }

        // suppress <|notimestamps|> token
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412
        logits[vocab.token_not] = -INFINITY;
        if (params.no_timestamps) {
            for (int i = vocab.token_beg; i < n_logits; ++i) {
                logits[i] = -INFINITY;
            }
        }

        // suppress sot and nosp tokens
        logits[vocab.token_sot]  = -INFINITY;
        logits[vocab.token_nosp] = -INFINITY;

        // [TDRZ] when tinydiarize is disabled, suppress solm token
        if (params.tdrz_enable == false) {
            logits[vocab.token_solm] = -INFINITY;
        }

        // suppress task tokens and the prev token
        logits[vocab.token_translate]  = -INFINITY;
        logits[vocab.token_transcribe] = -INFINITY;
        logits[vocab.token_prev]       = -INFINITY;

        // suppress lang tokens
        for (size_t i = 0; i < g_lang.size(); ++i) {
            logits[whisper_token_lang(&ctx, i)] = -INFINITY;
        }

        if (params.logits_filter_callback) {
            params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
        }

        // suppress any tokens matching a regular expression
        // ref: https://github.com/openai/whisper/discussions/1041
        // note: the regex object is built from params.suppress_regex on every call
        if (params.suppress_regex != nullptr) {
            std::regex re(params.suppress_regex);
            // iterate by reference - no need to copy each (text, id) pair
            for (const auto & token_id : vocab.token_to_id) {
                if (std::regex_match(token_id.first, re)) {
                    logits[token_id.second] = -INFINITY;
                }
            }
        }

        // suppress non-speech tokens
        // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
        if (params.suppress_nst) {
            // suppresses the token with the given text, if the vocabulary has one (single lookup)
            const auto suppress_if_present = [&](const std::string & text) {
                const auto it = vocab.token_to_id.find(text);
                if (it != vocab.token_to_id.end()) {
                    logits[it->second] = -INFINITY;
                }
            };

            // each entry is suppressed both as-is and with a leading space
            for (const std::string & token : non_speech_tokens) {
                suppress_if_present(token);
                suppress_if_present(" " + token);
            }

            // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
            suppress_if_present(" -");
            suppress_if_present(" '");
        }

        // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
        {
            const bool last_was_timestamp        = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg;
            const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg;

            //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp);

            if (last_was_timestamp) {
                if (penultimate_was_timestamp) {
                    for (int i = vocab.token_beg; i < n_logits; ++i) {
                        logits[i] = -INFINITY;
                    }
                } else {
                    for (int i = 0; i < vocab.token_eot; ++i) {
                        logits[i] = -INFINITY;
                    }
                }
            }
        }

        // the initial timestamp cannot be larger than max_initial_ts
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429
        if (is_initial && params.max_initial_ts > 0.0f) {
            const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx;
            const int   tid0      = std::round(params.max_initial_ts/precision);

            for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) {
                logits[i] = -INFINITY;
            }
        }

        // condition timestamp tokens to be increasing
        // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556
        if (decoder.has_ts) {
            const int tid0 = decoder.seek_delta/2;

            for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) {
                logits[i] = -INFINITY;
            }
        }

        // populate the logprobs array (log_softmax)
        whisper_compute_logprobs(logits, n_logits, logprobs);

        // if sum of probability over timestamps is above any other token, sample timestamp
        // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
        {
            // logsumexp over timestamps
            float timestamp_logprob = -INFINITY;
            {
                float logsumexp = 0.0f;
                const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
                for (int i = vocab.token_beg; i < n_logits; ++i) {
                    if (logprobs[i] > -INFINITY) {
                        logsumexp += expf(logprobs[i] - logprob_max);
                    }
                }
                if (logsumexp > 0.0f) {
                    timestamp_logprob = logf(logsumexp) + logprob_max;
                }
            }

            const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg);

            //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);

            if (timestamp_logprob > max_text_token_logprob) {
                for (int i = 0; i < vocab.token_beg; ++i) {
                    logits[i]   = -INFINITY;
                    logprobs[i] = -INFINITY;
                }
            } else {
                if (params.n_grammar_rules > 0) {
                    whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);

                    // the grammar penalty changed the logits - recompute the logprobs (log_softmax)
                    whisper_compute_logprobs(logits, n_logits, logprobs);
                }
            }
        }
    }

    // compute probs
    whisper_compute_probs(logits, n_logits, logprobs, probs);

#if 0
    // print first 100 logits - token string : logit
    //for (int i = 0; i < 10; i++) {
    //    const auto token   = vocab.id_to_token.at(i);
    //    const auto prob    = probs[i];
    //    const auto logit   = logits[i];
    //    const auto logprob = logprobs[i];
    //    printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob);
    //}

    // print sorted
    {
        std::vector<std::pair<float, int>> pairs;

        for (int i = 0; i < n_logits; ++i) {
            pairs.push_back(std::make_pair(probs[i], i));
        }

        std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) {
            return a.first > b.first;
        });

        for (int i = 0; i < 10; i++) {
            const auto token   = vocab.id_to_token.at(pairs[i].second);
            const auto prob    = pairs[i].first;
            const auto logit   = logits[pairs[i].second];
            const auto logprob = logprobs[pairs[i].second];
            printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str());
        }

        printf("----------------\n");
    }

    // "And", "and", " And", " and"
    //printf("logits[\"and\"]  = %f\n", logits[vocab.token_to_id.at("and")]);
    //printf("logits[\"And\"]  = %f\n", logits[vocab.token_to_id.at("And")]);
    //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]);
    //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]);
    //printf("logits[\" so\"]  = %f\n", logits[vocab.token_to_id.at(" so")]);

    //printf("logprobs[\"and\"]  = %f\n", logprobs[vocab.token_to_id.at("and")]);
    //printf("logprobs[\"And\"]  = %f\n", logprobs[vocab.token_to_id.at("And")]);
    //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
    //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
    //printf("logprobs[\" so\"]  = %f\n", logprobs[vocab.token_to_id.at(" so")]);

    //printf("probs[\"and\"]  = %f\n", probs[vocab.token_to_id.at("and")]);
    //printf("probs[\"And\"]  = %f\n", probs[vocab.token_to_id.at("And")]);
    //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]);
    //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]);
    //printf("probs[\" so\"]  = %f\n", probs[vocab.token_to_id.at(" so")]);
#endif
}
// returns true iff both sequences have the same number of tokens and the same token ids
static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
    const size_t n_tokens = a.tokens.size();
    if (n_tokens != b.tokens.size()) {
        return false;
    }

    // sequences are more likely to diverge at the end, so compare back to front
    for (size_t k = n_tokens; k-- > 0; ) {
        if (a.tokens[k].id != b.tokens[k].id) {
            return false;
        }
    }

    return true;
}
// picks the next token from decoder.probs
//
// best == true  : the token with the highest probability is taken
// best == false : the token is drawn from the distribution given by decoder.probs
//
// tid/pt/ptsum of the result describe the most probable token in [token_beg, n_vocab);
// when the picked token itself lies in that range, tid and pt are set from it instead
static whisper_token_data whisper_sample_token(
            whisper_context & ctx,
      const whisper_decoder & decoder,
                       bool   best) {
    whisper_token_data result = {
        0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f,
    };

    const auto & vocab    = ctx.vocab;
    const auto & probs    = decoder.probs;
    const auto & logprobs = decoder.logprobs;

    const int n_logits = vocab.n_vocab;

    // statistics over the tokens starting at token_beg
    {
        double total = 0.0;
        double peak  = 0.0;

        for (int i = vocab.token_beg; i < n_logits; i++) {
            const auto p = probs[i];
            if (p != -INFINITY) {
                total += p;
                if (peak < p) {
                    peak       = p;
                    result.tid = i;
                }
            }
        }

        result.pt    = peak/(total + 1e-10);
        result.ptsum = total;
    }

    if (!best) {
        std::discrete_distribution<> dist(probs.begin(), probs.end());

        result.id   = dist(decoder.rng);
        result.p    = probs[result.id];
        result.plog = logprobs[result.id];
    } else {
        for (int i = 0; i < n_logits; ++i) {
            if (result.p < probs[i]) {
                result.id   = i;
                result.p    = probs[i];
                result.plog = logprobs[i];
            }
        }
    }

    if (result.id >= vocab.token_beg) {
        result.tid = result.id;
        result.pt  = result.p;
    }

    return result;
}
// draws k tokens from the distribution given by decoder.probs
//
// decoder.logits_id is refilled with (logit, id) pairs and partially sorted so that its first
// k entries are the largest logits in descending order; every returned token carries the same
// tid/pt/ptsum statistics over [token_beg, n_vocab), unless the drawn token itself lies in
// that range, in which case tid and pt are set from it
static std::vector<whisper_token_data> whisper_sample_token_topk(
            whisper_context & ctx,
            whisper_decoder & decoder,
                        int   k) {
    const auto & vocab    = ctx.vocab;
    const auto & probs    = decoder.probs;
    const auto & logits   = decoder.logits;
    const auto & logprobs = decoder.logprobs;

    const int n_logits = vocab.n_vocab;

    auto & logits_id = decoder.logits_id;

    logits_id.resize(n_logits);
    for (int i = 0; i < n_logits; ++i) {
        auto & entry = logits_id[i];
        entry.first  = logits[i];
        entry.second = i;
    }

    {
        using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;

        const auto by_logit_desc = [](const pair_type & a, const pair_type & b) {
            return a.first > b.first;
        };

        std::partial_sort(logits_id.begin(), logits_id.begin() + k, logits_id.end(), by_logit_desc);
    }

    whisper_token tid = vocab.token_beg;

    float pt    = 0.0;
    float ptsum = 0.0;

    // statistics over the tokens starting at token_beg
    {
        double total = 0.0;
        double peak  = 0.0;

        for (int i = vocab.token_beg; i < n_logits; i++) {
            const auto p = probs[i];
            if (p != -INFINITY) {
                total += p;
                if (peak < p) {
                    peak = p;
                    tid  = i;
                }
            }
        }

        pt    = peak/(total + 1e-10);
        ptsum = total;
    }

    std::vector<whisper_token_data> result;
    result.reserve(k);

    std::discrete_distribution<> dist(probs.begin(), probs.end());

    for (int i = 0; i < k; ++i) {
        const auto id = dist(decoder.rng);
        //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);

        whisper_token_data sampled = { id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, };

        if (sampled.id >= vocab.token_beg) {
            sampled.tid = sampled.id;
            sampled.pt  = sampled.p;
        }

        result.push_back(sampled);
    }

    return result;
}
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
//
// fills in sum_logprobs, avg_logprobs, score and entropy of `sequence` from its first
// result_len tokens; a sequence with result_len == 0 is left untouched
static void whisper_sequence_score(
        const struct whisper_full_params & params,
        whisper_sequence & sequence) {
    const int n_result = sequence.result_len;
    if (n_result == 0) {
        return;
    }

    double sum_plog = 0.0f;
    for (int i = 0; i < n_result; ++i) {
        sum_plog += sequence.tokens[i].plog;
    }

    sequence.sum_logprobs = sum_plog;
    sequence.avg_logprobs = sum_plog/n_result;

    // without a positive length penalty the score is normalized by the plain length
    double penalty = n_result;
    if (params.length_penalty > 0.0f) {
        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
    }

    sequence.score = sum_plog/penalty;

    // compute the entropy of the sequence of the last 32 tokens
    {
        const int n = 32;

        const int i_first = std::max(0, n_result - n);
        const int cnt     = n_result - i_first;

        std::map<whisper_token, int> token_counts;
        for (int i = i_first; i < n_result; ++i) {
            token_counts[sequence.tokens[i].id]++;
        }

        double entropy = 0.0f;
        for (const auto & kv : token_counts) {
            const auto p = kv.second/(double)cnt;
            entropy -= p*log(p);

            //WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
        }

        sequence.entropy = entropy;
    }
}
static bool whisper_vad(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples,
std::vector<float> & filtered_samples,
int & filtered_n_samples) {
WHISPER_LOG_INFO("%s: VAD is enabled, processing speach segments only\n", __func__);
filtered_n_samples = 0;
struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params();
struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params);
if (vctx == nullptr) {
WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__);
return false;
}
const whisper_vad_params & vad_params = params.vad_params;
whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples);
if (vad_segments->data.size() > 0) {
state->has_vad_segments = true;
ctx->state->vad_segments.clear();
ctx->state->vad_segments.reserve(vad_segments->data.size());
WHISPER_LOG_INFO("%s: detected %d speech segments\n", __func__, (int)vad_segments->data.size());
float overlap_seconds = vad_params.samples_overlap;
int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE;
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
if (i < (int)vad_segments->data.size() - 1) {
segment_end_samples += overlap_samples;
}
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
filtered_n_samples += (segment_end_samples - segment_start_samples);
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
__func__, i, vad_segments->data[i].start,
vad_segments->data[i].end + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0),
(vad_segments->data[i].end - vad_segments->data[i].start) +
(i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0));
}
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
int total_samples_needed = filtered_n_samples + total_silence_samples;
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
try {
filtered_samples.resize(total_samples_needed);
} catch (const std::bad_alloc & /* e */) {
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
whisper_vad_free_segments(vad_segments);
whisper_vad_free(vctx);
return false;
}
int offset = 0;
for (int i = 0; i < (int)vad_segments->data.size(); i++) {
int segment_start_samples = vad_segments->data[i].start * WHISPER_SAMPLE_RATE;
int segment_end_samples = vad_segments->data[i].end * WHISPER_SAMPLE_RATE;
if (i < (int)vad_segments->data.size() - 1) {
segment_end_samples += overlap_samples;
}
segment_start_samples = std::min(segment_start_samples, n_samples - 1);
segment_end_samples = std::min(segment_end_samples, n_samples);
int segment_length = segment_end_samples - segment_start_samples;
if (segment_length > 0) {
whisper_state::vad_segment_info segment;
segment.orig_start = vad_segments->data[i].start;
segment.orig_end = vad_segments->data[i].end;
segment.vad_start = offset / (float)WHISPER_SAMPLE_RATE;
segment.vad_end = (offset + segment_length) / (float)WHISPER_SAMPLE_RATE;
WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n",
__func__, segment.orig_start, segment.orig_end, segment.vad_start, segment.vad_end);
ctx->state->vad_segments.push_back(segment);
// Copy this speech segment
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
offset += segment_length;
// Add silence after this segment (except after the last segment)
if (i < (int)vad_segments->data.size() - 1) {
// Fill with zeros (silence)
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
offset += silence_samples;
}
}
}
filtered_n_samples = offset;
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
}
return true;
}
// Run the full pipeline on `samples` using the supplied `state`:
//   (optional) VAD filtering -> log mel spectrogram -> (optional) language auto-detection ->
//   encode/decode loop over the audio with temperature fallback.
// The produced segments are stored in state->result_all (cleared on entry).
//
// Return values visible in this function:
//    0 - success (also returned when only language detection was requested, when the input
//        is shorter than 100 ms, or when encoder_begin_callback aborted the main loop)
//   -1 - VAD failed
//   -2 - failed to compute the log mel spectrogram
//   -3 - failed to auto-detect the language
//   -4 - more decoders requested than WHISPER_MAX_DECODERS
//   -5 - params.audio_ctx is larger than the model's audio context
//   -6 - the encoder failed
//   -7 - failed to (re)create the self-attention KV cache
//   -8 - failed to decode the initial prompt
//   -9 - failed to decode while sampling
int whisper_full_with_state(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
const float * samples,
int n_samples) {
// clear old results
auto & result_all = state->result_all;
result_all.clear();
// by default the caller's buffer is processed; with VAD enabled the speech-only
// buffer produced by whisper_vad() (owned by vad_samples) is processed instead
const float * process_samples = samples;
int n_process_samples = n_samples;
std::vector<float> vad_samples;
if (params.vad) {
// NOTE(review): whisper_vad() logs an almost identical message on entry - one of the two looks redundant
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
int vad_n_samples;
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
return -1;
}
process_samples = vad_samples.data();
n_process_samples = vad_n_samples;
}
// the spectrogram is only computed when there is something to process
// (n_process_samples can be 0, e.g. when VAD did not report any samples)
if (n_process_samples > 0) {
// compute log mel spectrogram
if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
return -2;
}
}
// auto-detect language if not specified
if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) {
std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f);
const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data());
if (lang_id < 0) {
WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__);
return -3;
}
state->lang_id = lang_id;
params.language = whisper_lang_str(lang_id);
WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]);
// detect_language: the caller only wants the language - stop before any decoding
if (params.detect_language) {
return 0;
}
}
if (params.token_timestamps) {
state->t_beg = 0;
state->t_last = 0;
state->tid_last = 0;
// NOTE(review): the energy is computed from the original `samples`, not from `process_samples`;
// with VAD enabled the decoded audio is the filtered buffer - confirm this is intended
if (n_samples > 0) {
state->energy = get_signal_energy(samples, n_samples, 32);
}
}
// seek positions are expressed in units of 10 ms (offset_ms/10, duration_ms/10)
const int seek_start = params.offset_ms/10;
const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10;
// if length of spectrogram is less than 100ms (10 frames), then return
// basically don't process anything that is less than 100ms
// ref: https://github.com/ggml-org/whisper.cpp/issues/2065
const int delta_min = 10;
if (seek_end < seek_start + delta_min) {
WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10);
return 0;
}
// a set of temperatures to use
// [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ]
std::vector<float> temperatures;
if (params.temperature_inc > 0.0f) {
for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) {
temperatures.push_back(t);
}
} else {
// no increment configured -> no fallback, a single decoding pass at params.temperature
temperatures.push_back(params.temperature);
}
// initialize the decoders
// n_decoders is the maximum number of decoders any temperature pass may need:
// best_of for greedy sampling, max(best_of, beam_size) for beam search
int n_decoders = 1;
switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
{
n_decoders = params.greedy.best_of;
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size);
} break;
};
n_decoders = std::max(1, n_decoders);
if (n_decoders > WHISPER_MAX_DECODERS) {
WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS);
return -4;
}
// TAGS: WHISPER_DECODER_INIT
// size the buffers of the additional decoders; decoder 0 is not touched here
// (presumably it is set up together with the state - verify), each decoder gets its own RNG seeded with its index
for (int j = 1; j < n_decoders; j++) {
auto & decoder = state->decoders[j];
decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity());
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
decoder.rng = std::mt19937(j);
}
// the accumulated text context so far
auto & prompt_past = state->prompt_past;
if (params.no_context) {
prompt_past.clear();
}
// prepare prompt
{
std::vector<whisper_token> prompt_tokens;
// initial prompt
// explicit prompt_tokens take precedence; the textual initial_prompt is only tokenized when no tokens were given
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(1024);
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
// a negative result is interpreted as minus the number of tokens required: grow the buffer and retry
if (n_needed < 0) {
prompt_tokens.resize(-n_needed);
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
}
prompt_tokens.resize(n_needed);
// params.prompt_tokens now points into the local vector above - it must not be dereferenced after this block
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}
// prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// parse tokens from the pointer
for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]);
}
// the tokens were appended at the back - rotate them to the front of prompt_past
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
}
// overwrite audio_ctx, max allowed is hparams.n_audio_ctx
if (params.audio_ctx > whisper_n_audio_ctx(ctx)) {
WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx));
return -5;
}
state->exp_n_audio_ctx = params.audio_ctx;
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
if (params.translate) {
prompt_init.push_back(whisper_token_translate(ctx));
} else {
prompt_init.push_back(whisper_token_transcribe(ctx));
}
}
// first release distilled models require the "no_timestamps" token
{
// heuristic: 2 text layers and a vocabulary size different from 51866
const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866;
if (is_distil && !params.no_timestamps) {
WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__);
params.no_timestamps = true;
}
}
if (params.no_timestamps) {
prompt_init.push_back(whisper_token_not(ctx));
}
// current position in the audio (same 10 ms units as seek_start/seek_end)
int seek = seek_start;
std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
// snapshot of a decoder's state extended by one candidate token (beam search only)
struct beam_candidate {
int decoder_idx;
int seek_delta;
bool has_ts;
whisper_sequence sequence;
whisper_grammar grammar;
};
// per-decoder candidate lists (filled by the sampling threads) and their concatenation
std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
std::vector<beam_candidate> beam_candidates;
// main loop
// each iteration encodes the audio window starting at `seek`, decodes it and advances `seek` by the best decoder's seek_delta
while (true) {
if (params.progress_callback) {
// seek_end - seek_start >= delta_min here (checked above), so the division is safe
const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
params.progress_callback(
ctx, state, progress_cur, params.progress_callback_user_data);
}
// if only 100ms left, then stop
if (seek + delta_min >= seek_end) {
break;
}
if (params.encoder_begin_callback) {
// a false result stops the processing; the function still returns 0 with the segments produced so far
if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
break;
}
}
// encode audio features starting at offset seek
if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
return -6;
}
// if there is a very short audio segment left to process, we remove any past prompt since it tends
// to confuse the decoder and often make it repeat or hallucinate stuff
// (500 seek units = 5 s at 10 ms per unit)
if (seek > seek_start && seek + 500 >= seek_end) {
prompt_past.clear();
}
int best_decoder_id = 0;
// temperature fallback: try each temperature in turn until the decoding is accepted
for (int it = 0; it < (int) temperatures.size(); ++it) {
const float t_cur = temperatures[it];
// number of decoders for this pass: greedy uses 1 decoder at t == 0 and best_of otherwise;
// beam search uses beam_size at t == 0 and best_of otherwise
int n_decoders_cur = 1;
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
if (t_cur > 0.0f) {
n_decoders_cur = params.greedy.best_of;
}
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
if (t_cur > 0.0f) {
n_decoders_cur = params.greedy.best_of;
} else {
n_decoders_cur = params.beam_search.beam_size;
}
} break;
};
n_decoders_cur = std::max(1, n_decoders_cur);
WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur);
// TAGS: WHISPER_DECODER_INIT
// reset the per-pass state of every decoder that takes part in this pass
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
decoder.sequence.tokens.clear();
decoder.sequence.result_len = 0;
decoder.sequence.sum_logprobs_all = 0.0;
decoder.sequence.sum_logprobs = -INFINITY;
decoder.sequence.avg_logprobs = -INFINITY;
decoder.sequence.entropy = 0.0;
decoder.sequence.score = -INFINITY;
// default: advance by a whole chunk unless a timestamp token says otherwise
decoder.seek_delta = 100*WHISPER_CHUNK_SIZE;
decoder.failed = false;
decoder.completed = false;
decoder.has_ts = false;
if (params.grammar_rules != nullptr) {
decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule);
} else {
decoder.grammar = {};
}
}
// init prompt and kv cache for the current iteration
// TODO: do not recompute the prompt if it is the same as previous time
{
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
// (only for temperatures below 0.5, and limited to half of the text context)
if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
}
// init new transcription with sot, language (opt) and task tokens
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
// print the prompt
WHISPER_LOG_DEBUG("\n\n");
for (int i = 0; i < (int) prompt.size(); i++) {
WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
}
WHISPER_LOG_DEBUG("\n\n");
// recreate the KV cache if the number of decoders has changed
// (more precisely: only when more decoders are needed than the cache was created for)
if (state->kv_self_n_dec < n_decoders_cur) {
WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
whisper_kv_cache_free(state->kv_self);
// overallocate to workaround KV cache fragmentation issues
const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
ctx->model.hparams.n_text_state,
ctx->model.hparams.n_text_layer,
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
// NOTE(review): this frees the state passed in by the caller (possibly ctx->state);
// the caller must not use or free `state` again after a -7 result - confirm the ownership here
whisper_free_state(state);
return -7;
}
state->kv_self_n_dec = n_decoders_cur;
}
whisper_kv_cache_clear(state->kv_self);
// evaluate the whole prompt in a single batch
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -8;
}
// Calculate no_speech probability after first decode.
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
{
const int n_logits = ctx->vocab.id_to_token.size();
std::vector<float> logprobs(n_logits);
std::vector<float> probs(n_logits);
whisper_compute_logprobs(state->logits, n_logits, logprobs);
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
}
{
const int64_t t_start_sample_us = ggml_time_us();
// decoder 0 processes the logits of the last prompt token ...
state->decoders[0].i_batch = prompt.size() - 1;
whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
// ... and the other decoders start from copies of its KV cache sequence and probabilities
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1);
memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0]));
memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0]));
memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0]));
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
// token generation: at most n_max tokens are sampled for the current audio window
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
const int64_t t_start_sample_us = ggml_time_us();
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
for (auto & bc : bc_per_dec) {
bc.clear();
}
}
// sampling
// TODO: avoid memory allocations, optimize, avoid threads?
{
// work distribution: each worker repeatedly claims the next decoder index from the atomic counter
std::atomic<int> j_cur(0);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
if (j >= n_decoders_cur) {
break;
}
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
// (effectively) zero temperature -> the last argument of whisper_sample_token is true, otherwise false
if (t_cur < 1e-6f) {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
} else {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
}
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
// every sampled token yields one candidate: a copy of the decoder's current sequence plus that token
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
bc_per_dec[j].back().sequence.tokens.push_back(token);
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
}
} break;
};
}
};
// NOTE(review): n_threads is assumed to be >= 1 here (a non-positive params.n_threads would
// make the vector size below negative) - confirm that it is validated by the callers
const int n_threads = std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();
} else {
// the calling thread takes part in the work as well, hence n_threads - 1 extra threads
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
process();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
}
// merge the per-decoder candidates into a single list (stays empty for greedy sampling)
beam_candidates.clear();
for (const auto & bc : bc_per_dec) {
beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
if (!bc.empty()) {
state->n_sample += 1;
}
}
// for beam-search, choose the top candidates and update the KV caches
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
// best cumulative log-probability first; ties are resolved by the lower decoder index
std::sort(
beam_candidates.begin(),
beam_candidates.end(),
[](const beam_candidate & a, const beam_candidate & b) {
if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) {
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
}
return a.decoder_idx < b.decoder_idx;
});
uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
// wrap around when there are fewer candidates than active decoders
if (cur_c >= beam_candidates.size()) {
cur_c = 0;
}
auto & cur = beam_candidates[cur_c++];
// skip the following candidates that hold the same token sequence as the one just taken (except at the first step, i == 0)
while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
++cur_c;
}
decoder.seek_delta = cur.seek_delta;
decoder.has_ts = cur.has_ts;
decoder.sequence = cur.sequence;
decoder.grammar = cur.grammar;
// stage the source decoder's KV sequence in a temporary slot (WHISPER_MAX_DECODERS + j), because
// sequence j itself may still be needed as the source for another decoder
whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);
WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
}
// move the staged sequences into their final slots and release the temporary ones
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1);
whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1);
whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1);
}
}
// update the decoder state
// - check if the sequence is completed
// - check if the sequence is failed
// - update sliding window based on timestamp tokens
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
auto & has_ts = decoder.has_ts;
auto & failed = decoder.failed;
auto & completed = decoder.completed;
auto & seek_delta = decoder.seek_delta;
auto & result_len = decoder.sequence.result_len;
{
const auto & token = decoder.sequence.tokens.back();
// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
// the offset of the timestamp token from token_beg is doubled to obtain seek units
const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx));
// do not allow to go back in time
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
failed = true; // TODO: maybe this is not a failure ?
continue;
}
seek_delta = seek_delta_new;
result_len = i + 1;
has_ts = true;
}
whisper_grammar_accept_token(*ctx, decoder.grammar, token.id);
#ifdef WHISPER_DEBUG
{
const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
__func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
}
#endif
// end of segment
if (token.id == whisper_token_eot(ctx) || // end of text token
(params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached
(has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms)
) {
// no timestamp token was produced: accept the tokens only if the end of the audio has been reached, otherwise fail
if (result_len == 0 && !params.no_timestamps) {
if (seek + seek_delta + delta_min >= seek_end) {
result_len = i + 1;
} else {
WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
failed = true;
continue;
}
}
// without timestamps (or in single-segment mode) keep everything and advance by a whole chunk
if (params.single_segment || params.no_timestamps) {
result_len = i + 1;
seek_delta = 100*WHISPER_CHUNK_SIZE;
}
WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j);
completed = true;
continue;
}
// TESTS: if no tensors are loaded, it means we are running tests
if (ctx->model.n_loaded == 0) {
seek_delta = 100*WHISPER_CHUNK_SIZE;
completed = true;
continue;
}
}
// sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
failed = true;
continue;
}
}
// check if all decoders have finished (i.e. completed or failed)
{
bool completed_all = true;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
completed_all = false;
}
if (completed_all) {
break;
}
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
// obtain logits for the next token
{
// build a batch with one entry per active decoder: its last sampled token at position n_past in sequence j
auto & batch = state->batch;
batch.n_tokens = 0;
const int n_past = prompt.size() + i;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
//WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
// remember which batch entry belongs to this decoder
decoder.i_batch = batch.n_tokens;
batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id;
batch.pos [batch.n_tokens] = n_past;
batch.n_seq_id[batch.n_tokens] = 1;
batch.seq_id [batch.n_tokens][0] = j;
batch.logits [batch.n_tokens] = 1;
batch.n_tokens++;
}
// at least one decoder is still active, otherwise the completed_all check above would have left the loop
assert(batch.n_tokens > 0);
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
return -9;
}
const int64_t t_start_sample_us = ggml_time_us();
// TODO: avoid memory allocations, optimize, avoid threads?
{
// same work distribution as for the sampling above: the active decoders process their logits in parallel
std::atomic<int> j_cur(0);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
if (j >= n_decoders_cur) {
break;
}
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
continue;
}
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
}
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();
} else {
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
process();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
}
// rank the resulting sequences and select the best one
{
double best_score = -INFINITY;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.failed) {
continue;
}
// drop the tokens sampled after the last accepted position
decoder.sequence.tokens.resize(decoder.sequence.result_len);
whisper_sequence_score(params, decoder.sequence);
WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
__func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
// a long sequence with low token entropy is considered degenerate (repetitive) and rejected
if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) {
WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
__func__, j, decoder.sequence.entropy, params.entropy_thold);
decoder.failed = true;
state->n_fail_h++;
continue;
}
if (best_score < decoder.sequence.score) {
best_score = decoder.sequence.score;
best_decoder_id = j;
}
}
WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
}
bool success = true;
// was the decoding successful for the current temperature?
// do fallback only if:
// - we are not at the last temperature
if (it != (int) temperatures.size() - 1) {
const auto & decoder = state->decoders[best_decoder_id];
// retry with the next temperature when the best decoder failed, or when its average
// log-probability is below the threshold while the audio is not classified as silence
if (decoder.failed ||
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
success = false;
state->n_fail_p++;
}
}
if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
break;
}
WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
}
// output results through a user-provided callback
{
const auto & best_decoder = state->decoders[best_decoder_id];
auto seek_delta = best_decoder.seek_delta;
const auto result_len = best_decoder.sequence.result_len;
const auto & tokens_cur = best_decoder.sequence.tokens;
// [EXPERIMENTAL] Token-level timestamps with DTW
const auto n_segments_before = state->result_all.size();
// the window is treated as silence when the no-speech probability is above its threshold
// and the average log-probability of the best sequence is below its threshold
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
best_decoder.sequence.avg_logprobs < params.logprob_thold);
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
// update prompt_past
// = the past-text part of the prompt that was just used (if any) followed by the newly accepted tokens
prompt_past.clear();
if (prompt.front() == whisper_token_prev(ctx)) {
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}
for (int i = 0; i < result_len && !is_no_speech; ++i) {
prompt_past.push_back(tokens_cur[i].id);
}
// split the accepted tokens into segments at the timestamp tokens and store them in result_all
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
// i0 - index of the first token of the segment being built, t0 - its start time (in seek units)
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
std::string text;
bool speaker_turn_next = false;
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
// tokens with id >= eot are only added to the text when print_special is set
if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
// [TDRZ] record if speaker turn was predicted after current segment
if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) {
speaker_turn_next = true;
}
// a timestamp token closes the current segment (unless single_segment is requested)
if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
const auto tt0 = t0;
const auto tt1 = t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
//printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
// number of segments reported to the callback; wrapping the segment to max_len can change it
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
}
}
// with DTW timestamps enabled the callback is invoked later, after the DTW pass below
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
}
}
text = "";
// skip over the run of consecutive timestamp tokens; i ends on the last one of the run
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
i++;
}
i--;
t0 = t1;
i0 = i + 1;
speaker_turn_next = false;
}
}
// remaining text that was not closed by a timestamp token: the segment ends at seek + seek_delta
if (!text.empty()) {
const auto t1 = seek + seek_delta;
const auto tt0 = t0;
const auto tt1 = t1;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
}
}
if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) {
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
}
}
}
// FIXME: will timestamp offsets be correct?
// [EXPERIMENTAL] Token-level timestamps with DTW
{
// number of segments added for the current audio window
const int n_segments = state->result_all.size() - n_segments_before;
if (ctx->params.dtw_token_timestamps && n_segments) {
const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek);
whisper_exp_compute_token_level_timestamps_dtw(
ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads);
if (params.new_segment_callback) {
// NOTE(review): the loop starts at the index of the first new segment but is bounded by the
// *count* of new segments, so it does not run once result_all.size() >= 2*n_segments; also, the
// third callback argument is n_new (a count) at the other call sites but an index here - verify
for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) {
params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data);
}
}
}
}
// ref: https://github.com/ggml-org/whisper.cpp/pull/2629
// the sequence ends with a single timestamp token that is preceded by a non-timestamp token
const bool single_timestamp_ending = tokens_cur.size() > 1 &&
tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) &&
tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx);
if (single_timestamp_ending) {
WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n");
seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100);
}
// update audio window
seek += seek_delta;
WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta);
}
}
return 0;
}
// Convenience entry point: runs whisper_full_with_state() on the state stored in
// the context (ctx->state) and forwards its return code unchanged.
int whisper_full(
        struct whisper_context * ctx,
        struct whisper_full_params params,
        const float * samples,
        int n_samples) {
    struct whisper_state * default_state = ctx->state;

    const int status = whisper_full_with_state(ctx, default_state, params, samples, n_samples);

    return status;
}
int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
int ret = 0;
// prepare separate states for each thread
std::vector<whisper_state*> states;
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std::vector<std::thread> workers(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
// create a new state for each thread
states.push_back(whisper_init_state(ctx));
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
auto params_cur = params;
params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
params_cur.progress_callback = nullptr;
params_cur.progress_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
auto params_cur = params;
// We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk.
params_cur.print_realtime = false;
// Run the first transformation using default state but only for the first chunk.
ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {
workers[i].join();
}
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
// combine results into result_state->result_all from all other states
for (int i = 0; i < n_processors - 1; ++i) {
auto& results_i = states[i]->result_all;
for (auto& result : results_i) {
// correct the segment timestamp taking into account the offset
result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (!ctx->state->result_all.empty()) {
result.t0 = std::max(result.t0, ctx->state->result_all.back().t1);
}
ctx->state->result_all.push_back(std::move(result));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data);
}
}
ctx->state->t_mel_us += states[i]->t_mel_us;
ctx->state->t_sample_us += states[i]->t_sample_us;
ctx->state->t_encode_us += states[i]->t_encode_us;
ctx->state->t_decode_us += states[i]->t_decode_us;
ctx->state->t_batchd_us += states[i]->t_batchd_us;
ctx->state->t_prompt_us += states[i]->t_prompt_us;
ctx->state->n_sample += states[i]->n_sample;
ctx->state->n_encode += states[i]->n_encode;
ctx->state->n_decode += states[i]->n_decode;
ctx->state->n_batchd += states[i]->n_batchd;
ctx->state->n_prompt += states[i]->n_prompt;
whisper_free_state(states[i]);
}
// average the timings
ctx->state->t_mel_us /= n_processors;
ctx->state->t_sample_us /= n_processors;
ctx->state->t_encode_us /= n_processors;
ctx->state->t_decode_us /= n_processors;
// print information about the audio boundaries
WHISPER_LOG_WARN("\n");
WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
for (int i = 0; i < n_processors - 1; ++i) {
WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
}
WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
return ret;
}
// Number of segments currently stored in state->result_all.
int whisper_full_n_segments_from_state(struct whisper_state * state) {
    const auto & segments = state->result_all;

    return segments.size();
}
// Number of segments currently stored in the context's own state (ctx->state->result_all).
int whisper_full_n_segments(struct whisper_context * ctx) {
    const auto & segments = ctx->state->result_all;

    return segments.size();
}
// Language id stored in the given state (state->lang_id).
int whisper_full_lang_id_from_state(struct whisper_state * state) {
    const auto id = state->lang_id;

    return id;
}
// Language id stored in the context's own state (ctx->state->lang_id).
int whisper_full_lang_id(struct whisper_context * ctx) {
    const auto id = ctx->state->lang_id;

    return id;
}
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
// If VAD wasn't used, return the original timestamp
if (!state->has_vad_segments || state->vad_segments.empty()) {
return state->result_all[i_segment].t0;
}
// Get the start timestamp produced by whisper_full. whisper_full processes
// only the speech segments in this case so we need to map these timestamps
// back to the original audio.
float t0 = state->result_all[i_segment].t0 / 100.0f;
// Find which VAD segment this timestamp belongs.
// TODO(danbev) This could be optimized by using a binary search if the number
// of segments exceed a certain limit. Also we might be able to assume that
// the access pattern is sequential and optimized for that too.
for (size_t i = 0; i < state->vad_segments.size(); i++) {
const auto & segment = state->vad_segments[i];
// Check if the timestamp falls within this segment.
if (t0 >= segment.vad_start && t0 <= segment.vad_end) {
float proportion = 0.0f;
if (segment.vad_end > segment.vad_start) {
proportion = (t0 - segment.vad_start) / (segment.vad_end - segment.vad_start);
}
float orig_t0 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
return (int64_t)(orig_t0 * 100);
}
}
// Check if the timestamp falls between two segments.
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
const auto & curr = state->vad_segments[i];
const auto & next = state->vad_segments[i + 1];
if (t0 > curr.vad_end && t0 < next.vad_start) {
// Calculate how far we are through the gap as a proportion
float gap_proportion = 0.0f;
if (next.vad_start > curr.vad_end) {
gap_proportion = (t0 - curr.vad_end) / (next.vad_start - curr.vad_end);
}
float orig_t0 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
return (int64_t)(orig_t0 * 100);
}
}
// Handle the case where the timestamp is after the last segment.
if (t0 > state->vad_segments.back().vad_end) {
// For timestamps after the last segment, add the extra time to the end of the last segment
const auto& last = state->vad_segments.back();
// Calculate how far beyond the last segment
float extra_time = t0 - last.vad_end;
// Add this extra time to the original end time
float orig_t0 = last.orig_end + extra_time;
return (int64_t)(orig_t0 * 100);
}
WHISPER_LOG_WARN("%s: Could not map t0 = %f to a VAD segment\n", __func__, t0);
return t0;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return whisper_full_get_segment_t0_from_state(ctx->state, i_segment);
}
int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) {
// If VAD wasn't used, return the original timestamp
if (!state->has_vad_segments || state->vad_segments.empty()) {
return state->result_all[i_segment].t1;
}
// Get the end timestamp produced by whisper_full. whisper_full processes
// only the speech segments in this case so we need to map these timestamps
// back to the original audio.
float t1 = state->result_all[i_segment].t1 / 100.0f;
// Find which VAD segment this timestamp belongs.
// TODO(danbev) This could be optimized by using a binary search if the number
// of segments exceed a certain limit. Also we might be able to assume that
// the access pattern is sequential and optimized for that too.
for (size_t i = 0; i < state->vad_segments.size(); i++) {
const auto& segment = state->vad_segments[i];
// Check if the timestamp falls within this segment.
if (t1 >= segment.vad_start && t1 <= segment.vad_end) {
// Calculate the proportion through the filtered segment.
float proportion = 0.0f;
if (segment.vad_end > segment.vad_start) {
proportion = (t1 - segment.vad_start) / (segment.vad_end - segment.vad_start);
}
float orig_t1 = segment.orig_start + proportion * (segment.orig_end - segment.orig_start);
return (int64_t)(orig_t1 * 100);
}
}
// Check if the timestamp falls between two segments.
for (size_t i = 0; i < state->vad_segments.size() - 1; i++) {
const auto & curr = state->vad_segments[i];
const auto & next = state->vad_segments[i + 1];
if (t1 > curr.vad_end && t1 < next.vad_start) {
// Calculate how far we are through the gap as a proportion
float gap_proportion = 0.0f;
if (next.vad_start > curr.vad_end) {
gap_proportion = (t1 - curr.vad_end) / (next.vad_start - curr.vad_end);
}
// Map to the corresponding position in the original gap
float orig_t1 = curr.orig_end + gap_proportion * (next.orig_start - curr.orig_end);
return (int64_t)(orig_t1 * 100);
}
}
// Handle the case where the timestamp is after the last segment
if (t1 > state->vad_segments.back().vad_end) {
// For the last segment, use the end of the last VAD segment
const auto& last = state->vad_segments.back();
// Calculate how far beyond the last segment
float extra_time = t1 - last.vad_end;
// Add this extra time to the original end time
float orig_t1 = last.orig_end + extra_time;
return (int64_t)(orig_t1 * 100);
}
WHISPER_LOG_WARN("%s: Could not map t1 = %f to a VAD segment\n", __func__, t1);
return t1;
}
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
return whisper_full_get_segment_t1_from_state(ctx->state, i_segment);
}
// Value of the speaker_turn_next flag stored with segment `i_segment` of `state`.
bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];

    return segment.speaker_turn_next;
}
// Value of the speaker_turn_next flag stored with segment `i_segment` of the context's own state.
bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) {
    const auto & segment = ctx->state->result_all[i_segment];

    return segment.speaker_turn_next;
}
// Text of segment `i_segment` of `state`. The pointer refers to the string stored inside the segment.
const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];

    return segment.text.c_str();
}
// Text of segment `i_segment` of the context's own state. The pointer refers to the string stored inside the segment.
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
    const auto & segment = ctx->state->result_all[i_segment];

    return segment.text.c_str();
}
// Number of tokens stored with segment `i_segment` of `state`.
int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) {
    const auto & tokens = state->result_all[i_segment].tokens;

    return tokens.size();
}
// Number of tokens stored with segment `i_segment` of the context's own state.
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
    const auto & tokens = ctx->state->result_all[i_segment].tokens;

    return tokens.size();
}
// Vocabulary text (ctx->vocab.id_to_token) of token `i_token` in segment `i_segment` of `state`.
const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) {
    const auto & token = state->result_all[i_segment].tokens[i_token];

    return ctx->vocab.id_to_token[token.id].c_str();
}
// Vocabulary text (ctx->vocab.id_to_token) of token `i_token` in segment `i_segment` of the context's own state.
const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
    const auto & token = ctx->state->result_all[i_segment].tokens[i_token];

    return ctx->vocab.id_to_token[token.id].c_str();
}
// Id of token `i_token` in segment `i_segment` of `state`.
whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) {
    const auto & token = state->result_all[i_segment].tokens[i_token];

    return token.id;
}
// Id of token `i_token` in segment `i_segment` of the context's own state.
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
    const auto & token = ctx->state->result_all[i_segment].tokens[i_token];

    return token.id;
}
// Copy of the data stored for token `i_token` in segment `i_segment` of `state`.
struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) {
    const auto & tokens = state->result_all[i_segment].tokens;

    return tokens[i_token];
}
// Copy of the data stored for token `i_token` in segment `i_segment` of the context's own state.
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
    const auto & tokens = ctx->state->result_all[i_segment].tokens;

    return tokens[i_token];
}
// Value of the `p` field of token `i_token` in segment `i_segment` of `state`.
float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) {
    const auto & token = state->result_all[i_segment].tokens[i_token];

    return token.p;
}
// Value of the `p` field of token `i_token` in segment `i_segment` of the context's own state.
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
    const auto & token = ctx->state->result_all[i_segment].tokens[i_token];

    return token.p;
}
// Value of the no_speech_prob field of segment `i_segment` of the context's own state.
float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) {
    const auto & segment = ctx->state->result_all[i_segment];

    return segment.no_speech_prob;
}
// Value of the no_speech_prob field of segment `i_segment` of `state`.
float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];

    return segment.no_speech_prob;
}
// =================================================================================================
//
// Temporary interface needed for exposing ggml interface
// Will be removed in the future when ggml becomes a separate library
//
// Runs the memcpy benchmark and writes its report to stderr. Always returns 0.
WHISPER_API int whisper_bench_memcpy(int n_threads) {
    const char * report = whisper_bench_memcpy_str(n_threads);

    fputs(report, stderr);

    return 0;
}
WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
static std::string s;
s = "";
char strbuf[256];
ggml_time_init();
size_t n = 20;
size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations
// 1GB array
const size_t size = arr*1e6;
double sum = 0.0;
// heat-up
{
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us();
memcpy(dst, src, size);
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
src[rand() % size] = rand() % 256;
}
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
// single-thread
{
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
for (size_t i = 0; i < n; i++) {
const int64_t t0 = ggml_time_us();
memcpy(dst, src, size);
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
src[rand() % size] = rand() % 256;
}
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9));
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
// multi-thread
for (int32_t k = 1; k <= n_threads; k++) {
char * src = (char *) malloc(size);
char * dst = (char *) malloc(size);
for (size_t i = 0; i < size; i++) src[i] = i;
memcpy(dst, src, size); // heat-up
double tsum = 0.0;
auto helper = [&](int th) {
const int64_t i0 = (th + 0)*size/k;
const int64_t i1 = (th + 1)*size/k;
for (size_t i = 0; i < n; i++) {
memcpy(dst + i0, src + i0, i1 - i0);
src[i0 + rand() % (i1 - i0)] = rand() % 256;
};
};
const int64_t t0 = ggml_time_us();
std::vector<std::thread> threads(k - 1);
for (int32_t th = 0; th < k - 1; ++th) {
threads[th] = std::thread(helper, th);
}
helper(k - 1);
for (int32_t th = 0; th < k - 1; ++th) {
threads[th].join();
}
const int64_t t1 = ggml_time_us();
tsum += (t1 - t0)*1e-6;
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
s += strbuf;
// needed to prevent the compiler from optimizing the memcpy away
{
for (size_t i = 0; i < size; i++) sum += dst[i];
}
free(src);
free(dst);
}
snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum);
s += strbuf;
return s.c_str();
}
// Runs the ggml_mul_mat benchmark and writes its report to stderr. Always returns 0.
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
    const char * report = whisper_bench_ggml_mul_mat_str(n_threads);

    fputs(report, stderr);

    return 0;
}
// Benchmarks ggml_mul_mat() on square N x N matrices for several sizes and weight types and returns a text
// report with the measured GFLOPS and the number of timed runs per combination.
//
// For every combination the graph is computed once as heat-up and then up to `n_max` more times; timing stops
// early once more than one second has been accumulated over at least 3 runs.
//
// The report is stored in a function-local static string, so the returned pointer stays valid until the next call.
WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
    whisper_load_backends();

    static std::string s;
    s = "";

    char strbuf[256];

    ggml_time_init();

    const int n_max = 128;

    const std::vector<size_t> sizes = { 64, 128, 256, 512, 1024, 2048, 4096, };

    const size_t N_max = sizes.back();

    // a: N*N*sizeof(float)
    // b: N*N*sizeof(float)
    // c: N*N*sizeof(float)
    // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
    std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead());

    // put a bunch of random data in the buffer
    for (size_t i = 0; i < buf.size(); i++) buf[i] = i;

    // weight types in the order they are measured; the index doubles as the slot in the result arrays below
    const int n_types = 7;
    const ggml_type wtypes[n_types] = {
        GGML_TYPE_Q4_0,
        GGML_TYPE_Q4_1,
        GGML_TYPE_Q5_0,
        GGML_TYPE_Q5_1,
        GGML_TYPE_Q8_0,
        GGML_TYPE_F16,
        GGML_TYPE_F32,
    };

    for (const size_t N : sizes) {
        int    n_runs[n_types] = { 0, 0, 0, 0, 0, 0, 0 };
        double gflops[n_types] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }; // GFLOPS/s

        for (int k = 0; k < n_types; ++k) {
            // every measurement builds a fresh context on top of the shared buffer
            struct ggml_init_params gparams = {
                /*.mem_size   =*/ buf.size(),
                /*.mem_buffer =*/ buf.data(),
                /*.no_alloc   =*/ false,
            };

            struct ggml_context * ctx0 = ggml_init(gparams);

            struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtypes[k],     N, N);
            struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N);

            struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b);

            struct ggml_cgraph * gf = ggml_new_graph(ctx0);

            ggml_build_forward_expand(gf, c);

            double elapsed = 0.0;
            int    runs    = 0;

            // heat-up
            ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);

            for (int i = 0; i < n_max; ++i) {
                const int64_t t0 = ggml_time_us();

                ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr);

                const int64_t t1 = ggml_time_us();

                elapsed += (t1 - t0)*1e-6;
                runs++;

                if (elapsed > 1.0 && runs >= 3) {
                    break;
                }
            }

            ggml_free(ctx0);

            n_runs[k] = runs;
            gflops[k] = ((2.0*N*N*N*runs)/elapsed)*1e-9;
        }

        // Q4_0 | Q4_1
        snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs)\n",
                N, N, gflops[0], n_runs[0], gflops[1], n_runs[1]);
        s += strbuf;

        // Q5_0 | Q5_1 | Q8_0
        snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n",
                N, N, gflops[2], n_runs[2], gflops[3], n_runs[3], gflops[4], n_runs[4]);
        s += strbuf;

        // F16 | F32
        snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16  %7.1f GFLOPS (%3d runs) | F32  %7.1f GFLOPS (%3d runs)\n",
                N, N, gflops[5], n_runs[5], gflops[6], n_runs[6]);
        s += strbuf;
    }

    return s.c_str();
}
// =================================================================================================
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
// Converts a timestamp `t` (in units of 1/100 of WHISPER_SAMPLE_RATE samples) to a sample index, clamped to
// [0, n_samples - 1]. Returns 0 when n_samples <= 0.
static int timestamp_to_sample(int64_t t, int n_samples) {
    // clamp while still in 64 bits: narrowing to int first would let a large `t` wrap around before the clamp
    const int64_t i_sample = (t*WHISPER_SAMPLE_RATE)/100;
    const int64_t i_last   = (int64_t) n_samples - 1;

    return (int) std::max<int64_t>(0, std::min<int64_t>(i_last, i_sample));
}
// Converts a sample index to a timestamp: 100*i_sample/WHISPER_SAMPLE_RATE, evaluated in 64 bits.
static int64_t sample_to_timestamp(int i_sample) {
    const long long scaled = 100ll*i_sample;

    return scaled/WHISPER_SAMPLE_RATE;
}
// heuristic cost of pronouncing `text`: the sum of a fixed weight per character
//   space -> 0.01, comma -> 2, '.', '!', '?' and digits -> 3, anything else -> 1
// obviously, can be improved
static float voice_length(const std::string & text) {
    const auto weight = [](char ch) -> float {
        switch (ch) {
            case ' ': return 0.01f;
            case ',': return 2.00f;
            case '.':
            case '!':
            case '?': return 3.00f;
            default:  break;
        }

        const bool is_digit = ch >= '0' && ch <= '9';

        return is_digit ? 3.00f : 1.00f;
    };

    float total = 0.0f;

    for (size_t i = 0; i < text.size(); ++i) {
        total += weight(text[i]);
    }

    return total;
}
// moving average of |signal| over a window of 2*n_samples_per_half_window + 1 samples centered on each sample
// samples outside of the signal contribute nothing, but the divisor is always the full window size
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
    const int half   = n_samples_per_half_window;
    const int window = 2*half + 1;

    std::vector<float> energy(n_samples);

    for (int center = 0; center < n_samples; center++) {
        // part of the window that lies inside the signal
        const int first = std::max(center - half, 0);
        const int last  = std::min(center + half, n_samples - 1);

        float acc = 0;
        for (int k = first; k <= last; k++) {
            acc += fabs(signal[k]);
        }

        energy[center] = acc/window;
    }

    return energy;
}
// Fills in t0/t1 for every token of segment `i_segment` in state.result_all.
//
// Steps, in order:
//  1. tokens whose timestamp-token probabilities pass the thresholds (pt > thold_pt, ptsum > thold_ptsum)
//     get the time derived from their `tid`
//  2. tokens that still have no end time are spread over the interval they share, proportionally to
//     voice_length() of their text
//  3. overlaps between neighbouring tokens are removed
//  4. the boundaries of each text token are expanded or contracted based on the signal energy in state.energy
//
// state.t_beg, state.t_last and state.tid_last are read and updated by this function.
// If state.energy is empty, an error is logged and nothing else happens.
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context & ctx,
struct whisper_state & state,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = state.result_all[i_segment];
auto & tokens = segment.tokens;
const int n_samples = state.energy.size();
if (n_samples == 0) {
WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
return;
}
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int n = tokens.size();
// nothing to do for an empty segment
if (n == 0) {
return;
}
// a single token simply spans the whole segment
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
// from here on n >= 2, which the indexing with j + 1 and n - 2 below relies on
auto & t_beg = state.t_beg;
auto & t_last = state.t_last;
auto & tid_last = state.tid_last;
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
// a segment that starts with the "begin" timestamp token resets the running state,
// otherwise the first token starts where the previous call ended (t_last)
if (token.id == whisper_token_beg(&ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(&ctx);
} else {
tokens[j ].t0 = t_last;
}
}
// time derived from the token's tid: 2 time units per timestamp-token step after whisper_token_beg
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx));
// note: `token` refers to tokens[j], so the first five assignments write the values back unchanged
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id));
// accept the derived time only if both probabilities pass their thresholds, the tid moves forward
// and the time does not exceed the end of the segment
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
// pin the end of the segment: the last token is a zero-length token at t1
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
// (a negative t1 is what marks a token as "unknown" here)
{
int p0 = 0;
int p1 = 0;
while (true) {
// advance p1 to the first token at or after p0 that has a known end time
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
//printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1);
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
// continue with the token after the interval that was just handled
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
// NOTE(review): this copies a negative (unknown) t1 into the next token's t0 - confirm that this is intended
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
// a token must not start before the previous one ends, nor end before it starts
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
// half window, in samples, that is added on both sides of the token when computing its energy threshold
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
// only tokens with an id below whisper_token_eot are adjusted
if (tokens[j].id >= whisper_token_eot(&ctx)) {
continue;
}
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += state.energy[k];
}
// threshold: half of the mean energy over [ss0, ss1)
const float thold = 0.5*sum/ns;
// start of the token
{
int k = s0;
if (state.energy[k] > thold && j > 0) {
// energy is above the threshold at the start: move the start earlier while it stays above,
// but never before the end of the previous token
while (k > 0 && state.energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
// otherwise move the start later while the energy is below the threshold, up to s1
while (state.energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
// end of the token
{
int k = s1;
if (state.energy[k] > thold) {
// energy is above the threshold at the end: move the end later while it stays above,
// but never past the start of the next token
while (k < n_samples - 1 && state.energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
// otherwise move the end earlier while the energy is below the threshold, down to s0
while (state.energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(&ctx)) {
// continue;
// }
//}
}
//
// token level timestamps - dtw version
//
// Returns the indices of the alignment heads to use for text layer `il`, as selected by cparams.dtw_aheads_preset:
//   WHISPER_AHEADS_NONE       -> no heads
//   WHISPER_AHEADS_N_TOP_MOST -> all heads of the last cparams.dtw_n_top layers
//   WHISPER_AHEADS_CUSTOM     -> the heads listed for this layer in cparams.dtw_aheads
//   any other preset          -> the heads listed for this layer in g_aheads
//
// n_text_layer -> total text layers on model
// n_head       -> total heads per text layer on model
static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) {
    std::vector<uint32_t> selected;

    const auto preset = cparams.dtw_aheads_preset;

    if (preset == WHISPER_AHEADS_NONE) {
        return selected;
    }

    if (preset == WHISPER_AHEADS_N_TOP_MOST) {
        const bool is_top_layer = il >= n_text_layer - cparams.dtw_n_top;
        if (is_top_layer) {
            for (int32_t h = 0; h < n_head; ++h) {
                selected.push_back(h);
            }
        }
        return selected;
    }

    // explicit list of (layer, head) pairs: either user supplied or one of the built-in presets
    const auto aheads = preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(preset);

    for (size_t idx = 0; idx < aheads.n_heads; ++idx) {
        const auto & entry = aheads.heads[idx];
        if (entry.n_text_layer == il) {
            selected.push_back(entry.n_head);
        }
    }

    return selected;
}
// dtw + backtrace to return found path
// based on
// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
//
// `x` must be a 2D tensor (asserted); its values are read with whisper_get_f32_nd(). With N = x->ne[0] and
// M = x->ne[1], the function fills an (N + 1) x (M + 1) cost matrix, follows the recorded steps back from (N, M)
// to (0, 0) and returns a new I32 tensor with ne = {2, path length}, allocated in `ctx`:
//   r[0, k] = index along the first dimension of x
//   r[1, k] = index along the second dimension of x
// ordered from the start of the path to its end.
static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
WHISPER_ASSERT(ggml_n_dims(x) == 2);
int64_t N = x->ne[0];
int64_t M = x->ne[1];
// cost[i, j]  : lowest accumulated cost found for reaching cell (i, j); index 0 is the border row/column
// trace[i, j] : step that led to (i, j): 0 = from (i-1, j-1), 1 = from (i-1, j), 2 = from (i, j-1)
struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1);
struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1);
cost = whisper_set_f32(cost, INFINITY);
trace = whisper_set_i32(trace, -1);
whisper_set_f32_nd(cost, 0, 0, 0, 0, 0.0);
// dtw
// supposedly can be optimized by computing diagonals in parallel ?
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most.
for (int64_t j = 1; j < M + 1; ++j) {
for (int64_t i = 1; i < N + 1; ++i) {
float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);
float c;
int32_t t;
// pick the cheapest predecessor; on ties the last branch (step 2) is taken
if (c0 < c1 && c0 < c2) {
c = c0;
t = 0;
} else if (c1 < c0 && c1 < c2) {
c = c1;
t = 1;
} else {
c = c2;
t = 2;
}
c = whisper_get_f32_nd(x, i - 1, j - 1, 0, 0) + c;
whisper_set_f32_nd(cost, i, j, 0, 0, c);
whisper_set_i32_nd(trace, i, j, 0, 0, t);
}
}
// Backtrace
// bt is filled from its last row towards the first one, one row per visited cell
// NOTE(review): bt has room for N + M - 1 rows and bt_row_idx is not checked against 0 - confirm that the path
// can never be longer than that
const int64_t BT_MAX_ROWS = N + M - 1;
struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
// the border cells always step along the border, so the walk below ends at (0, 0)
// trace[0, :] = 2;
for (int64_t i = 0; i < M + 1; ++i)
whisper_set_i32_nd(trace, 0, i, 0, 0, 2);
//trace[:, 0] = 1;
for (int64_t i = 0; i < N + 1; ++i)
whisper_set_i32_nd(trace, i, 0, 0, 0, 1);
int bt_row_idx = BT_MAX_ROWS - 1;
int64_t i = N;
int64_t j = M;
while (i > 0 || j > 0) {
// store the cell as indices into x (cost/trace are shifted by one because of the border)
whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
--bt_row_idx;
int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
if (t == 0) {
--i;
--j;
} else if (t == 1) {
--i;
} else if (t == 2) {
--j;
} else {
// any other value means the trace was never written for this cell
WHISPER_ASSERT(0);
}
}
// FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs)
// Clip + transpose
// This might not be entirely necessary for our case, but leaving it for now so output matrix
// is identical to dtw on openAI timing.py
// only the rows after bt_row_idx were written; copy them, transposed, into the result
const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
for (int64_t i = 0; i < 2; ++i) {
for (int64_t j = 0; j < result_n_cols; ++j) {
int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
whisper_set_i32_nd(r, i, j, 0, 0, v);
}
}
return r;
}
// User data handed to median_filter() through the `userdata` pointer of ggml_map_custom1().
struct median_filter_user_data {
// number of values in the median window; median_filter() asserts that it is odd
int filter_width;
};
// Custom ggml op: writes to `dst` the median filter of `a` along its third dimension (ne[2]).
//
// `userdata` points to a median_filter_user_data that provides the (odd) window width. Positions of the window
// that fall outside of the tensor are mirrored back into it ("reflect" padding). `a` must be a 3D F32 tensor
// with ne[2] larger than the window width. All the work is done by the thread with ith == 0.
static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
    if (ith != 0) {
        return;
    }

    const auto * opts = (median_filter_user_data *) userdata;
    const int filter_width = opts->filter_width;

    WHISPER_ASSERT(filter_width < a->ne[2]);
    WHISPER_ASSERT(filter_width % 2);
    WHISPER_ASSERT(ggml_n_dims(a) == 3);
    WHISPER_ASSERT(a->type == GGML_TYPE_F32);

    const int64_t len  = a->ne[2];
    const int64_t half = filter_width/2;

    // "reflect" padding
    const auto reflect = [len](int64_t pos) -> int64_t {
        if (pos < 0) {
            return -pos;
        }
        if (pos >= len) {
            return 2*(len - 1) - pos;
        }
        return pos;
    };

    // scratch buffer for one window, reused for every output element
    std::vector<float> window;
    window.reserve(filter_width);

    for (int64_t i = 0; i < a->ne[0]; ++i) {
        for (int64_t j = 0; j < a->ne[1]; ++j) {
            for (int64_t k = 0; k < len; ++k) {
                window.clear();

                for (int64_t pos = k - half; pos <= k + half; ++pos) {
                    window.push_back(whisper_get_f32_nd(a, i, j, reflect(pos), 0));
                }

                std::sort(window.begin(), window.end());

                const float median = window[window.size()/2];
                whisper_set_f32_nd(dst, i, j, k, 0, median);
            }
        }
    }
}
static void whisper_exp_compute_token_level_timestamps_dtw(
struct whisper_context * ctx,
struct whisper_state * state,
struct whisper_full_params params,
int i_segment,
size_t n_segments,
int seek,
int n_frames,
int medfilt_width,
int n_threads)
{
const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
WHISPER_ASSERT(medfilt_width % 2);
WHISPER_ASSERT(n_frames <= n_audio_ctx * 2);
WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE);
// FIXME: Allocating mem everytime we call this func
// Our ggml buffer should be pre-allocated somewhere during init and reused
// when we call this function
struct ggml_init_params gparams = {
/*.mem_size =*/ ctx->params.dtw_mem_size,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
struct ggml_context * gctx = ggml_init(gparams);
// Build token sequence that will be passed to decoder
// sot + [lang] + text result + eot
std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
if (whisper_is_multilingual(ctx)) {
const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id;
tokens.push_back(whisper_token_lang(ctx, lang_id));
}
const size_t sot_sequence_length = tokens.size();
tokens.push_back(whisper_token_not(ctx));
for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
auto & segment = state->result_all[i];
for (auto &t: segment.tokens) {
// Only text tokens
if (t.id < whisper_token_eot(ctx)) {
tokens.push_back(t.id);
}
}
}
tokens.push_back(whisper_token_eot(ctx));
// Get result tokens, pass then along to decoder to get cross attention QKs
// used in timestamping
// Decoder already returns only alignment head QKs, already concatenated in
// one tensor.
whisper_kv_cache_clear(state->kv_self);
whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
WHISPER_LOG_INFO("DECODER FAILED\n");
WHISPER_ASSERT(0);
}
WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);
const auto n_audio_tokens = n_frames/2;
WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
const auto n_tokens = state->aheads_cross_QKs->ne[0];
const auto n_heads = state->aheads_cross_QKs->ne[2];
// Copy data from decoder buffer to a local CPU tensor, discarding unused audio
// tokens (i.e. discarding rows at the end of tensor)
// IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
auto & data = state->aheads_cross_QKs_data;
data.resize(n_tokens * n_audio_ctx * n_heads);
ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);
for (int k = 0; k < n_heads; ++k) {
for (int j = 0; j < n_audio_tokens; ++j) {
memcpy(
(char *) w->data + j * w->nb[1] + k * w->nb[2],
data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
n_tokens * sizeof(float)
);
}
}
// Normalize - in original OpenAI code, this is done over dim=-2. In this case,
// we already permuted N_TOKENS dimension to columns on last loop, becase ggml_norm
// operates over columns. Afterwards, permute to a shape that facilitates mean
// operation (after median filter)
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_norm(gctx, w, 1e-9f);
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
// Pass median filter - this is done over AUDIO_TOKENS dimension.
// IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
// OUT: Same dims
median_filter_user_data mf_user_data = {medfilt_width};
w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data);
// Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT
// IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_mean(gctx, w);
w = ggml_scale(gctx, w, -1.0);
w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
// Remove SOT sequence and EOT
// Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS
w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);
// Compute
struct ggml_cgraph * gf = ggml_new_graph(gctx);
ggml_build_forward_expand(gf, w);
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
ggml_backend_graph_compute(backend.get(), gf);
ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
// Place timestamps on segments
int32_t last_v = 0;
auto seg_i = state->result_all.begin() + i_segment;
auto tok_i = seg_i->tokens.begin();
for (int i = 0; i < alignment->ne[1]; ++i) {
int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
if (v != last_v) {
int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
last_v = v;
// Skip non-text tokens
while (!(tok_i->id < whisper_token_eot(ctx))) {
++tok_i;
if (tok_i == seg_i->tokens.end()) {
++seg_i;
tok_i = seg_i->tokens.begin();
}
}
tok_i->t_dtw = timestamp;
++tok_i;
if (tok_i == seg_i->tokens.end()) {
++seg_i;
tok_i = seg_i->tokens.begin();
}
}
}
// Print DTW timestamps
/*for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
auto & segment = state->result_all[i];
for (auto &t: segment.tokens) {
const char * tok = whisper_token_to_str(ctx, t.id);
fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100);
}
fprintf(stderr, "\n");
}*/
ggml_free(gctx);
}
// Install the logging callback used by whisper and forward it to ggml.
//   log_callback - callback to install; a null pointer selects
//                  whisper_log_callback_default
//   user_data    - opaque pointer stored alongside the callback and handed
//                  back to it on every log call
void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
    // A null callback means "use the built-in default"
    if (log_callback == nullptr) {
        log_callback = whisper_log_callback_default;
    }

    g_state.log_callback_user_data = user_data;
    g_state.log_callback           = log_callback;

    // Register the same callback/user_data pair with ggml
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
}
// Format a printf-style message and pass it to the currently installed log
// callback (g_state.log_callback).
//   level  - severity forwarded unchanged to the callback
//   format - printf-style format string, followed by its arguments
// Messages that fit in a 1024-byte stack buffer are formatted in place; longer
// ones are formatted a second time into a heap buffer of the exact size.
// If formatting fails (vsnprintf returns a negative value) nothing is logged.
GGML_ATTRIBUTE_FORMAT(2, 3)
static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
    va_list args;
    va_start(args, format);

    // A va_list is indeterminate after being passed to vsnprintf, so keep an
    // untouched copy for the (possible) second formatting pass.
    va_list args_copy;
    va_copy(args_copy, args);

    char buffer[1024];
    const int len = vsnprintf(buffer, sizeof(buffer), format, args);

    if (len < 0) {
        // Encoding/format error: buffer contents are unspecified, do not log them
    } else if (len < (int) sizeof(buffer)) {
        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
    } else {
        // Output was truncated: len is the full length excluding the terminator
        std::vector<char> buffer2((size_t) len + 1);
        vsnprintf(buffer2.data(), buffer2.size(), format, args_copy);
        buffer2[len] = 0;
        g_state.log_callback(level, buffer2.data(), g_state.log_callback_user_data);
    }

    va_end(args_copy);
    va_end(args);
}
// Default log callback: writes the message to stderr and flushes immediately.
//   level     - message severity; unless WHISPER_DEBUG is defined, messages at
//               GGML_LOG_LEVEL_DEBUG are dropped
//   text      - NUL-terminated message, written as-is (no newline is appended)
//   user_data - unused
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;

#ifdef WHISPER_DEBUG
    // Debug builds print every level, so `level` is otherwise unused here
    (void) level;
#else
    if (level == GGML_LOG_LEVEL_DEBUG) {
        return;
    }
#endif

    fputs(text, stderr);
    fflush(stderr);
}