whisper.cpp/whisper.cpp
Georgi Gerganov d5afebd37c
whisper : token-level timestamp refactoring (#49, #120)
This turned out pretty well overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitle types. This
means that you can now specify the maximum length of the generated
lines: simply provide the "-ml" argument with the maximum line length in
number of characters.
2022-11-02 21:45:54 +02:00

3250 lines
109 KiB
C++

#define WHISPER_BUILD
#include "whisper.h"
#include "ggml.h"
#include <algorithm>
#include <cassert>
#define _USE_MATH_DEFINES
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <vector>
#define USE_FLASH_ATTN
//#define USE_FLASH_FF
// available whisper models
// Model size class. whisper_model_load() derives it from hparams.n_audio_layer and uses it
// as the key into the MEM_REQ_* buffer-size tables (MODEL_UNKNOWN has no entry there).
enum e_model {
    MODEL_UNKNOWN = 0,
    MODEL_TINY    = 1,
    MODEL_BASE    = 2,
    MODEL_SMALL   = 3,
    MODEL_MEDIUM  = 4,
    MODEL_LARGE   = 5,
};
// Languages known to Whisper: language code -> { language id, English language name }.
// The ids are consecutive, from 0 ("en", english) to 98 ("su", sundanese).
static const std::map<std::string, std::pair<int, std::string>> g_lang = {
    { "en",  { 0,  "english",         } },
    { "zh",  { 1,  "chinese",         } },
    { "de",  { 2,  "german",          } },
    { "es",  { 3,  "spanish",         } },
    { "ru",  { 4,  "russian",         } },
    { "ko",  { 5,  "korean",          } },
    { "fr",  { 6,  "french",          } },
    { "ja",  { 7,  "japanese",        } },
    { "pt",  { 8,  "portuguese",      } },
    { "tr",  { 9,  "turkish",         } },
    { "pl",  { 10, "polish",          } },
    { "ca",  { 11, "catalan",         } },
    { "nl",  { 12, "dutch",           } },
    { "ar",  { 13, "arabic",          } },
    { "sv",  { 14, "swedish",         } },
    { "it",  { 15, "italian",         } },
    { "id",  { 16, "indonesian",      } },
    { "hi",  { 17, "hindi",           } },
    { "fi",  { 18, "finnish",         } },
    { "vi",  { 19, "vietnamese",      } },
    { "iw",  { 20, "hebrew",          } },
    { "uk",  { 21, "ukrainian",       } },
    { "el",  { 22, "greek",           } },
    { "ms",  { 23, "malay",           } },
    { "cs",  { 24, "czech",           } },
    { "ro",  { 25, "romanian",        } },
    { "da",  { 26, "danish",          } },
    { "hu",  { 27, "hungarian",       } },
    { "ta",  { 28, "tamil",           } },
    { "no",  { 29, "norwegian",       } },
    { "th",  { 30, "thai",            } },
    { "ur",  { 31, "urdu",            } },
    { "hr",  { 32, "croatian",        } },
    { "bg",  { 33, "bulgarian",       } },
    { "lt",  { 34, "lithuanian",      } },
    { "la",  { 35, "latin",           } },
    { "mi",  { 36, "maori",           } },
    { "ml",  { 37, "malayalam",       } },
    { "cy",  { 38, "welsh",           } },
    { "sk",  { 39, "slovak",          } },
    { "te",  { 40, "telugu",          } },
    { "fa",  { 41, "persian",         } },
    { "lv",  { 42, "latvian",         } },
    { "bn",  { 43, "bengali",         } },
    { "sr",  { 44, "serbian",         } },
    { "az",  { 45, "azerbaijani",     } },
    { "sl",  { 46, "slovenian",       } },
    { "kn",  { 47, "kannada",         } },
    { "et",  { 48, "estonian",        } },
    { "mk",  { 49, "macedonian",      } },
    { "br",  { 50, "breton",          } },
    { "eu",  { 51, "basque",          } },
    { "is",  { 52, "icelandic",       } },
    { "hy",  { 53, "armenian",        } },
    { "ne",  { 54, "nepali",          } },
    { "mn",  { 55, "mongolian",       } },
    { "bs",  { 56, "bosnian",         } },
    { "kk",  { 57, "kazakh",          } },
    { "sq",  { 58, "albanian",        } },
    { "sw",  { 59, "swahili",         } },
    { "gl",  { 60, "galician",        } },
    { "mr",  { 61, "marathi",         } },
    { "pa",  { 62, "punjabi",         } },
    { "si",  { 63, "sinhala",         } },
    { "km",  { 64, "khmer",           } },
    { "sn",  { 65, "shona",           } },
    { "yo",  { 66, "yoruba",          } },
    { "so",  { 67, "somali",          } },
    { "af",  { 68, "afrikaans",       } },
    { "oc",  { 69, "occitan",         } },
    { "ka",  { 70, "georgian",        } },
    { "be",  { 71, "belarusian",      } },
    { "tg",  { 72, "tajik",           } },
    { "sd",  { 73, "sindhi",          } },
    { "gu",  { 74, "gujarati",        } },
    { "am",  { 75, "amharic",         } },
    { "yi",  { 76, "yiddish",         } },
    { "lo",  { 77, "lao",             } },
    { "uz",  { 78, "uzbek",           } },
    { "fo",  { 79, "faroese",         } },
    { "ht",  { 80, "haitian creole",  } },
    { "ps",  { 81, "pashto",          } },
    { "tk",  { 82, "turkmen",         } },
    { "nn",  { 83, "nynorsk",         } },
    { "mt",  { 84, "maltese",         } },
    { "sa",  { 85, "sanskrit",        } },
    { "lb",  { 86, "luxembourgish",   } },
    { "my",  { 87, "myanmar",         } },
    { "bo",  { 88, "tibetan",         } },
    { "tl",  { 89, "tagalog",         } },
    { "mg",  { 90, "malagasy",        } },
    { "as",  { 91, "assamese",        } },
    { "tt",  { 92, "tatar",           } },
    { "haw", { 93, "hawaiian",        } },
    { "ln",  { 94, "lingala",         } },
    { "ha",  { 95, "hausa",           } },
    { "ba",  { 96, "bashkir",         } },
    { "jw",  { 97, "javanese",        } },
    { "su",  { 98, "sundanese",       } },
};
// bytes per mebibyte
static const size_t MB = 1024*1024;

// Per-model-type buffer sizes used by whisper_model_load() to size the whisper_context buffers.
// MODEL_UNKNOWN has no entry in any of these tables, so looking it up with std::map::at()
// throws std::out_of_range.

// size of wctx.buf_model: the buffer backing the ggml context that holds the model weights
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
    { MODEL_TINY,     74ull*MB },
    { MODEL_BASE,    142ull*MB },
    { MODEL_SMALL,   466ull*MB },
    { MODEL_MEDIUM, 1464ull*MB },
    { MODEL_LARGE,  2952ull*MB },
};

// size of wctx.buf_memory: the buffer backing the ggml context that holds the key/value memory
static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
    { MODEL_TINY,     12ull*MB },
    { MODEL_BASE,     24ull*MB },
    { MODEL_SMALL,    70ull*MB },
    { MODEL_MEDIUM,  184ull*MB },
    { MODEL_LARGE,   306ull*MB },
};

// wctx.buf_compute is sized to max(MEM_REQ_ENCODE, MEM_REQ_DECODE) for the model type
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
    { MODEL_TINY,     80ull*MB },
    { MODEL_BASE,    128ull*MB },
    { MODEL_SMALL,   300ull*MB },
    { MODEL_MEDIUM,  680ull*MB },
    { MODEL_LARGE,  1100ull*MB },
};

// wctx.buf_compute_layer is sized to max(MEM_REQ_ENCODE_LAYER, MEM_REQ_DECODE_LAYER)
static const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
    { MODEL_TINY,    104ull*MB },
    { MODEL_BASE,    138ull*MB },
    { MODEL_SMALL,   208ull*MB },
    { MODEL_MEDIUM,  280ull*MB },
    { MODEL_LARGE,   354ull*MB },
};

// see MEM_REQ_ENCODE (buf_compute)
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
    { MODEL_TINY,    200ull*MB },
    { MODEL_BASE,    202ull*MB },
    { MODEL_SMALL,   204ull*MB },
    { MODEL_MEDIUM,  206ull*MB },
    { MODEL_LARGE,   208ull*MB },
};

// see MEM_REQ_ENCODE_LAYER (buf_compute_layer)
static const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
    { MODEL_TINY,     32ull*MB },
    { MODEL_BASE,     44ull*MB },
    { MODEL_SMALL,    64ull*MB },
    { MODEL_MEDIUM,   84ull*MB },
    { MODEL_LARGE,   110ull*MB },
};
// Mel spectrogram of the input audio (the encoder input, see whisper_encode).
// NOTE(review): presumably n_len frames x n_mel mel bins stored in `data`; the exact
// layout is defined by the code that fills it, which is not visible here - confirm.
struct whisper_mel {
    int n_len;
    int n_mel;

    std::vector<float> data;
};
// Pre-computed mel filterbank read from the model file by whisper_model_load():
// data holds n_mel * n_fft floats.
struct whisper_filters {
    int32_t n_mel;
    int32_t n_fft;

    std::vector<float> data;
};
// Token vocabulary plus the ids of Whisper's special tokens.
//
// The special-token ids are the defaults for an English-only vocabulary. When the loaded
// vocabulary is multilingual, whisper_model_load() increments token_eot, token_sot,
// token_prev, token_solm, token_not and token_beg by one. The static task ids are not shifted.
struct whisper_vocab {
    using id    = int32_t;
    using token = std::string;

    int n_vocab = 51864;

    std::map<token, id> token_to_id; // text  -> token id
    std::map<id, token> id_to_token; // token id -> text

    id token_eot  = 50256; // end of transcript
    id token_sot  = 50257; // start of transcript
    id token_prev = 50360;
    id token_solm = 50361; // meaning not documented here
    id token_not  = 50362; // no timestamps
    id token_beg  = 50363; // first timestamp token

    // available tasks
    static const id token_translate  = 50358;
    static const id token_transcribe = 50359;

    // true iff the vocabulary has the multilingual size (51865 entries)
    bool is_multilingual() const { return 51865 == n_vocab; }
};
// One transcribed segment, as stored in whisper_context::result_all:
// time span [t0, t1], the segment text, and the per-token data.
// NOTE(review): the time unit of t0/t1 is defined by the code that fills result_all,
// which is not visible here - confirm before relying on it.
struct whisper_segment {
    int64_t t0;
    int64_t t1;

    std::string text;

    std::vector<whisper_token_data> tokens;
};
// Reference hyper-parameters of the "medium" model:
// hparams: {
//   'n_mels': 80,
//   'n_vocab': 51864,      (51864 is the non-multilingual size, see whisper_vocab::is_multilingual)
//   'n_audio_ctx': 1500,
//   'n_audio_state': 1024,
//   'n_audio_head': 16,
//   'n_audio_layer': 24,
//   'n_text_ctx': 448,
//   'n_text_state': 1024,
//   'n_text_head': 16,
//   'n_text_layer': 24
// }
//
// Model hyper-parameters, read in this field order from the head of the model file by
// whisper_model_load(). The loader classifies the model size from n_audio_layer alone.
//
// default hparams (Whisper tiny)
struct whisper_hparams {
    int32_t n_vocab       = 51864;
    int32_t n_audio_ctx   = 1500;
    int32_t n_audio_state = 384;
    int32_t n_audio_head  = 6;
    int32_t n_audio_layer = 4;
    int32_t n_text_ctx    = 448;
    int32_t n_text_state  = 384;
    int32_t n_text_head   = 6;
    int32_t n_text_layer  = 4;
    int32_t n_mels        = 80;
    int32_t f16           = 1; // non-zero: the big weight matrices are stored as F16, otherwise F32
};
// audio encoding layer
//
// Weights of one encoder block. whisper_model_load() allocates one per audio layer and maps
// each field to the checkpoint tensor name shown in the comment above it.
// The attention key has only a weight: no "attn.key.bias" tensor is registered.
struct whisper_layer_encoder {
    // encoder.blocks.*.attn_ln
    struct ggml_tensor * attn_ln_0_w;
    struct ggml_tensor * attn_ln_0_b;

    // encoder.blocks.*.attn.out (attention output weights, despite the "ln" in the field name)
    struct ggml_tensor * attn_ln_1_w;
    struct ggml_tensor * attn_ln_1_b;

    // encoder.blocks.*.attn.query
    struct ggml_tensor * attn_q_w;
    struct ggml_tensor * attn_q_b;

    // encoder.blocks.*.attn.key
    struct ggml_tensor * attn_k_w;

    // encoder.blocks.*.attn.value
    struct ggml_tensor * attn_v_w;
    struct ggml_tensor * attn_v_b;

    // encoder.blocks.*.mlp_ln
    struct ggml_tensor * mlp_ln_w;
    struct ggml_tensor * mlp_ln_b;

    // encoder.blocks.*.mlp.0
    struct ggml_tensor * mlp_0_w;
    struct ggml_tensor * mlp_0_b;

    // encoder.blocks.*.mlp.2 (stored as mlp_1: the second weight matrix of the MLP)
    struct ggml_tensor * mlp_1_w;
    struct ggml_tensor * mlp_1_b;
};
// token decoding layer
//
// Weights of one decoder block: self-attention, cross-attention and MLP.
// whisper_model_load() allocates one per text layer and maps each field to the checkpoint
// tensor name shown in the comment above it. Neither attn.key nor cross_attn.key has a
// bias tensor registered.
struct whisper_layer_decoder {
    // decoder.blocks.*.attn_ln
    struct ggml_tensor * attn_ln_0_w;
    struct ggml_tensor * attn_ln_0_b;

    // decoder.blocks.*.attn.out (attention output weights, despite the "ln" in the field name)
    struct ggml_tensor * attn_ln_1_w;
    struct ggml_tensor * attn_ln_1_b;

    // decoder.blocks.*.attn.query
    struct ggml_tensor * attn_q_w;
    struct ggml_tensor * attn_q_b;

    // decoder.blocks.*.attn.key
    struct ggml_tensor * attn_k_w;

    // decoder.blocks.*.attn.value
    struct ggml_tensor * attn_v_w;
    struct ggml_tensor * attn_v_b;

    // decoder.blocks.*.cross_attn_ln
    struct ggml_tensor * cross_attn_ln_0_w;
    struct ggml_tensor * cross_attn_ln_0_b;

    // decoder.blocks.*.cross_attn.out (cross-attention output weights)
    struct ggml_tensor * cross_attn_ln_1_w;
    struct ggml_tensor * cross_attn_ln_1_b;

    // decoder.blocks.*.cross_attn.query
    struct ggml_tensor * cross_attn_q_w;
    struct ggml_tensor * cross_attn_q_b;

    // decoder.blocks.*.cross_attn.key
    struct ggml_tensor * cross_attn_k_w;

    // decoder.blocks.*.cross_attn.value
    struct ggml_tensor * cross_attn_v_w;
    struct ggml_tensor * cross_attn_v_b;

    // decoder.blocks.*.mlp_ln
    struct ggml_tensor * mlp_ln_w;
    struct ggml_tensor * mlp_ln_b;

    // decoder.blocks.*.mlp.0
    struct ggml_tensor * mlp_0_w;
    struct ggml_tensor * mlp_0_b;

    // decoder.blocks.*.mlp.2 (stored as mlp_1: the second weight matrix of the MLP)
    struct ggml_tensor * mlp_1_w;
    struct ggml_tensor * mlp_1_b;
};
// All weights of a loaded Whisper model plus its key/value memory tensors.
//
// whisper_model_load() creates the weight tensors inside `ctx` and the key/value memory
// tensors inside `ctx_mem`, and registers every weight in `tensors` under its checkpoint
// name. Pointers and counters default to nullptr / 0 so that a model whose loading failed
// part-way never exposes indeterminate values.
struct whisper_model {
    e_model type = MODEL_UNKNOWN;

    whisper_hparams hparams;
    whisper_filters filters;

    // encoder.positional_embedding
    struct ggml_tensor * e_pe = nullptr;

    // encoder.conv1
    struct ggml_tensor * e_conv_1_w = nullptr;
    struct ggml_tensor * e_conv_1_b = nullptr;

    // encoder.conv2
    struct ggml_tensor * e_conv_2_w = nullptr;
    struct ggml_tensor * e_conv_2_b = nullptr;

    // encoder.ln_post
    struct ggml_tensor * e_ln_w = nullptr;
    struct ggml_tensor * e_ln_b = nullptr;

    // decoder.positional_embedding
    struct ggml_tensor * d_pe = nullptr; // DD

    // decoder.token_embedding
    struct ggml_tensor * d_te = nullptr; // DD

    // decoder.ln
    struct ggml_tensor * d_ln_w = nullptr; // DD
    struct ggml_tensor * d_ln_b = nullptr; // DD

    std::vector<whisper_layer_encoder> layers_encoder;
    std::vector<whisper_layer_decoder> layers_decoder;

    // key + value memory (self-attention and cross-attention)
    struct ggml_tensor * memory_k       = nullptr;
    struct ggml_tensor * memory_v       = nullptr;
    struct ggml_tensor * memory_cross_k = nullptr;
    struct ggml_tensor * memory_cross_v = nullptr;

    // context
    struct ggml_context * ctx     = nullptr; // holds the weight tensors
    struct ggml_context * ctx_mem = nullptr; // holds the key/value memory tensors

    // tensors
    int n_loaded = 0; // number of weight tensors read from the model file
    std::map<std::string, struct ggml_tensor *> tensors; // checkpoint name -> weight tensor
};
// Complete state of one Whisper instance: timings, buffers, model, vocabulary, mel input,
// decoder outputs and accumulated results.
//
// buf_model and the experimental timestamp fields used to be left uninitialized.
// A context whose model loading failed before buf_model was allocated therefore held
// an indeterminate owning pointer. Everything now has a defined default.
struct whisper_context {
    int64_t t_load_us   = 0;
    int64_t t_mel_us    = 0;
    int64_t t_sample_us = 0;
    int64_t t_encode_us = 0;
    int64_t t_decode_us = 0;
    int64_t t_start_us  = 0;

    // The model buffer is read-only and can be shared between processors.
    // Allocated with `new` by whisper_model_load(); nullptr until then.
    // NOTE(review): presumably released by whisper_free() - not visible here, confirm.
    std::vector<uint8_t> * buf_model = nullptr;

    std::vector<uint8_t> buf_memory;        // backs the key/value memory ggml context
    std::vector<uint8_t> buf_compute;
    std::vector<uint8_t> buf_compute_layer;

    whisper_model model;
    whisper_vocab vocab;

    whisper_mel mel;

    std::vector<float> probs;
    std::vector<float> logits;

    std::vector<whisper_segment> result_all;

    std::vector<whisper_token> prompt_past;

    // [EXPERIMENTAL] token-level timestamps data
    int64_t t_beg  = 0;
    int64_t t_last = 0;
    whisper_token tid_last = 0;
    std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
//
// file format:
//
// - hparams
// - pre-computed mel filters
// - vocab
// - weights
//
// see the convert-pt-to-ggml.py script for details
//
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
auto & vocab = wctx.vocab;
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false;
}
// verify magic
{
uint32_t magic;
fin.read((char *) &magic, sizeof(magic));
if (magic != 0x67676d6c) {
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
}
//load hparams
{
auto & hparams = model.hparams;
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
fin.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx));
fin.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state));
fin.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head));
fin.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer));
fin.read((char *) &hparams.n_text_ctx, sizeof(hparams.n_text_ctx));
fin.read((char *) &hparams.n_text_state, sizeof(hparams.n_text_state));
fin.read((char *) &hparams.n_text_head, sizeof(hparams.n_text_head));
fin.read((char *) &hparams.n_text_layer, sizeof(hparams.n_text_layer));
fin.read((char *) &hparams.n_mels, sizeof(hparams.n_mels));
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
assert(hparams.n_text_state == hparams.n_audio_state);
if (hparams.n_audio_layer == 4) {
model.type = e_model::MODEL_TINY;
}
if (hparams.n_audio_layer == 6) {
model.type = e_model::MODEL_BASE;
}
if (hparams.n_audio_layer == 12) {
model.type = e_model::MODEL_SMALL;
}
if (hparams.n_audio_layer == 24) {
model.type = e_model::MODEL_MEDIUM;
}
if (hparams.n_audio_layer == 32) {
model.type = e_model::MODEL_LARGE;
}
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
fprintf(stderr, "%s: n_audio_head = %d\n", __func__, hparams.n_audio_head);
fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer);
fprintf(stderr, "%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx);
fprintf(stderr, "%s: n_text_state = %d\n", __func__, hparams.n_text_state);
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
// this is the total memory required to run the inference
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();
fprintf(stderr, "%s: mem_required = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
}
// load mel filters
{
auto & filters = wctx.model.filters;
fin.read((char *) &filters.n_mel, sizeof(filters.n_mel));
fin.read((char *) &filters.n_fft, sizeof(filters.n_fft));
filters.data.resize(filters.n_mel * filters.n_fft);
fin.read((char *) filters.data.data(), filters.data.size() * sizeof(float));
}
// load vocab
{
int32_t n_vocab = 0;
fin.read((char *) &n_vocab, sizeof(n_vocab));
//if (n_vocab != model.hparams.n_vocab) {
// fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
// __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
// return false;
//}
std::string word;
for (int i = 0; i < n_vocab; i++) {
uint32_t len;
fin.read((char *) &len, sizeof(len));
word.resize(len);
fin.read((char *) word.data(), len);
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
//printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
}
vocab.n_vocab = model.hparams.n_vocab;
if (vocab.is_multilingual()) {
vocab.token_eot++;
vocab.token_sot++;
vocab.token_prev++;
vocab.token_solm++;
vocab.token_not++;
vocab.token_beg++;
}
if (n_vocab < model.hparams.n_vocab) {
fprintf(stderr, "%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);
for (int i = n_vocab; i < model.hparams.n_vocab; i++) {
if (i > vocab.token_beg) {
word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]";
} else if (i == vocab.token_eot) {
word = "[_EOT_]";
} else if (i == vocab.token_sot) {
word = "[_SOT_]";
} else if (i == vocab.token_prev) {
word = "[_PREV_]";
} else if (i == vocab.token_not) {
word = "[_NOT_]";
} else if (i == vocab.token_beg) {
word = "[_BEG_]";
} else {
word = "[_extra_token_" + std::to_string(i) + "]";
}
vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word;
}
}
}
// for the big tensors, we have the option to store the data in 16-bit floats
// in order to save memory and also to speed up the computation
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
size_t ctx_size = 0;
size_t ctx_mem_size = 0;
{
const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab;
const int n_audio_ctx = hparams.n_audio_ctx;
const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer;
const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_mels = hparams.n_mels;
// encoder
{
// TODO: F16 .. maybe not?
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
}
// decoder
{
// TODO: F16 .. maybe not?
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
}
// encoder layers
{
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
}
// decoder layers
{
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
//
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
}
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_k
ctx_mem_size += n_text_layer*n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_v
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_k
ctx_mem_size += n_text_layer*n_audio_ctx*n_text_state*ggml_type_size(GGML_TYPE_F16); // memory_cross_v
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
}
// create the ggml context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_model->size(),
.mem_buffer = wctx.buf_model->data(),
};
model.ctx = ggml_init(params);
if (!model.ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// prepare memory for the weights
{
auto & ctx = model.ctx;
const auto & hparams = model.hparams;
const int n_vocab = hparams.n_vocab;
const int n_audio_ctx = hparams.n_audio_ctx;
const int n_audio_state = hparams.n_audio_state;
const int n_audio_layer = hparams.n_audio_layer;
const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_mels = hparams.n_mels;
model.layers_encoder.resize(n_audio_layer);
model.layers_decoder.resize(n_text_layer);
// encoder
{
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
// map by name
model.tensors["encoder.positional_embedding"] = model.e_pe;
model.tensors["encoder.conv1.weight"] = model.e_conv_1_w;
model.tensors["encoder.conv1.bias"] = model.e_conv_1_b;
model.tensors["encoder.conv2.weight"] = model.e_conv_2_w;
model.tensors["encoder.conv2.bias"] = model.e_conv_2_b;
model.tensors["encoder.ln_post.weight"] = model.e_ln_w;
model.tensors["encoder.ln_post.bias"] = model.e_ln_b;
for (int i = 0; i < n_audio_layer; ++i) {
auto & layer = model.layers_encoder[i];
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state);
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state);
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state);
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state);
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
// map by name
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
}
}
// decoder
{
model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx);
model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab);
model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
// map by name
model.tensors["decoder.positional_embedding"] = model.d_pe;
model.tensors["decoder.token_embedding.weight"] = model.d_te;
model.tensors["decoder.ln.weight"] = model.d_ln_w;
model.tensors["decoder.ln.bias"] = model.d_ln_b;
for (int i = 0; i < n_text_layer; ++i) {
auto & layer = model.layers_decoder[i];
layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state);
layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state);
layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state);
layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state);
layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state);
// map by name
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w;
model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b;
}
}
}
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = wctx.buf_memory.size(),
.mem_buffer = wctx.buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// key + value memory
{
auto & ctx = model.ctx_mem;
const auto & hparams = model.hparams;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;
// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
// key/value memory for the cross-attention layer
{
const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
fprintf(stderr, "%s: memory size = %8.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
// load weights
{
size_t total_size = 0;
model.n_loaded = 0;
while (true) {
int32_t n_dims;
int32_t length;
int32_t ftype;
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
if (fin.eof()) {
break;
}
int32_t nelements = 1;
int32_t ne[3] = { 1, 1, 1 };
for (int i = 0; i < n_dims; ++i) {
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
nelements *= ne[i];
}
std::string name(length, 0);
fin.read(&name[0], length);
if (model.tensors.find(name.data()) == model.tensors.end()) {
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
return false;
}
auto tensor = model.tensors[name.data()];
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
return false;
}
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) {
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n",
__func__, name.data(), tensor->ne[0], tensor->ne[1], tensor->ne[2], ne[0], ne[1], ne[2]);
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
if (nelements*bpe != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
}
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
total_size += ggml_nbytes(tensor);
model.n_loaded++;
}
fprintf(stderr, "%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
if (model.n_loaded == 0) {
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
} else if (model.n_loaded != (int) model.tensors.size()) {
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
return false;
}
}
fin.close();
return true;
}
// evaluate the encoder
//
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
// part of the transformer model; the encoded features are not returned but projected with each decoder
// layer's cross-attention weights and stored in model.memory_cross_k / model.memory_cross_v
//
// - wctx:       the whisper context (model weights, input mel spectrogram in wctx.mel, compute buffers)
// - n_threads:  number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset), must be non-negative
//
// Returns false for a negative mel_offset or when a ggml context cannot be created, true otherwise.
//
static bool whisper_encode(
        whisper_context & wctx,
        const int n_threads,
        const int mel_offset) {
    const auto & model   = wctx.model;
    const auto & mel_inp = wctx.mel;
    const auto & hparams = model.hparams;

    const int n_ctx   = hparams.n_audio_ctx;
    const int n_state = hparams.n_audio_state;
    const int n_head  = hparams.n_audio_head;
    const int n_layer = hparams.n_audio_layer;

    const int N = n_ctx;

    const int n_mels = hparams.n_mels;
    assert(mel_inp.n_mel == n_mels);

    // a negative offset would make the copy below read before the start of every mel row
    if (mel_offset < 0) {
        fprintf(stderr, "%s: invalid mel offset %d - must be non-negative\n", __func__, mel_offset);
        return false;
    }

    struct ggml_init_params params = {
        .mem_size   = wctx.buf_compute.size(),
        .mem_buffer = wctx.buf_compute.data(),
    };

    struct ggml_context * ctx0 = ggml_init(params);
    if (!ctx0) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        return false;
    }

    // input window: 2*n_ctx mel frames starting at mel_offset, zero-padded past the end of the spectrogram
    struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
    assert(mel->type == GGML_TYPE_F32);
    {
        float * dst = (float *) mel->data;
        memset(dst, 0, ggml_nbytes(mel));

        const int i0 = std::min(mel_offset, mel_inp.n_len);
        const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);

        for (int j = 0; j < mel_inp.n_mel; ++j) {
            for (int i = i0; i < i1; ++i) {
                dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
            }
        }
    }

    struct ggml_tensor * cur;

    // convolution + gelu
    {
        cur = ggml_conv_1d_1s(ctx0, model.e_conv_1_w, mel);
        cur = ggml_add(ctx0,
                ggml_repeat(ctx0,
                    model.e_conv_1_b,
                    cur),
                cur);

        cur = ggml_gelu(ctx0, cur);

        cur = ggml_conv_1d_2s(ctx0, model.e_conv_2_w, cur);
        cur = ggml_add(ctx0,
                ggml_repeat(ctx0,
                    model.e_conv_2_b,
                    cur),
                cur);

        cur = ggml_gelu(ctx0, cur);
    }

    cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur));

    struct ggml_tensor * inpL = cur;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers_encoder[il];

        // create separate context for each layer to reduce memory usage

        struct ggml_init_params paramsL = {
            .mem_size   = wctx.buf_compute_layer.size(),
            .mem_buffer = wctx.buf_compute_layer.data(),
        };

        struct ggml_context * ctxL = ggml_init(paramsL);
        if (!ctxL) {
            fprintf(stderr, "%s: ggml_init() failed for layer %d\n", __func__, il);
            ggml_free(ctx0);
            return false;
        }

        // norm
        {
            cur = ggml_norm(ctxL, inpL);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctxL,
                    ggml_mul(ctxL,
                        ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
        }

        // self-attention
        {
            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
                    layer.attn_q_w,
                    cur);

            Qcur = ggml_add(ctxL,
                    ggml_repeat(ctxL,
                        layer.attn_q_b,
                        Qcur),
                    Qcur);

            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

            // note: no bias for Key
            struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                    layer.attn_k_w,
                    cur);

            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

            struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                    layer.attn_v_w,
                    cur);

            Vcur = ggml_add(ctxL,
                    ggml_repeat(ctxL,
                        layer.attn_v_b,
                        Vcur),
                    Vcur);

            // ------

#ifdef USE_FLASH_ATTN
            struct ggml_tensor * Q =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Qcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Kcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * V =
                ggml_cpy(ctxL,
                        ggml_permute(ctxL,
                            ggml_reshape_3d(ctxL,
                                Vcur,
                                n_state/n_head, n_head, N),
                            1, 2, 0, 3),
                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head)
                        );

            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
#else
            struct ggml_tensor * Q =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Qcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Kcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);

            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctxL,
                        KQ,
                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
                        );

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_scaled);

            //struct ggml_tensor * V_trans =
            //    ggml_permute(ctxL,
            //            ggml_cpy(ctxL,
            //                Vcur,
            //                ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
            //            1, 2, 0, 3);

            //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);

            struct ggml_tensor * V =
                ggml_cpy(ctxL,
                        ggml_permute(ctxL,
                            ggml_reshape_3d(ctxL,
                                Vcur,
                                n_state/n_head, n_head, N),
                            0, 2, 1, 3),
                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head)
                        );

            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
#endif

            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

            cur = ggml_cpy(ctxL,
                    KQV_merged,
                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
        }

        // projection
        {
            cur = ggml_mul_mat(ctxL,
                    layer.attn_ln_1_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
                    cur);
        }

        // add the input
        cur = ggml_add(ctxL, cur, inpL);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                cur = ggml_norm(ctxL, inpFF);

                // cur = mlp_ln_w*cur + mlp_ln_b
                cur = ggml_add(ctxL,
                        ggml_mul(ctxL,
                            ggml_repeat(ctxL, layer.mlp_ln_w, cur),
                            cur),
                        ggml_repeat(ctxL, layer.mlp_ln_b, cur));
            }

#ifdef USE_FLASH_FF
            cur = ggml_flash_ff(ctxL,
                    ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
                    layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
            // fully connected
            cur = ggml_mul_mat(ctxL,
                    layer.mlp_0_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.mlp_0_b, cur),
                    cur);

            // GELU activation
            cur = ggml_gelu(ctxL, cur);

            // projection
            cur = ggml_mul_mat(ctxL,
                    layer.mlp_1_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.mlp_1_b, cur),
                    cur);
#endif
        }

        // output from this layer
        struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);

        {
            struct ggml_cgraph gf = {};
            gf.n_threads = n_threads;

            ggml_build_forward_expand(&gf, inpO);
            ggml_graph_compute (ctxL, &gf);

            //ggml_graph_print(&gf);
        }

        // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
        // input for next layer (inpO -> inpL)
        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
        inpL->op = GGML_OP_NONE;
        inpL->src0 = NULL;
        inpL->src1 = NULL;

        //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);

        ggml_free(ctxL);
    }

    cur = inpL;

    // norm
    {
        cur = ggml_norm(ctx0, cur);

        // cur = ln_f_g*cur + ln_f_b
        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.e_ln_w, cur),
                    cur),
                ggml_repeat(ctx0, model.e_ln_b, cur));
    }

    // run the computation
    {
        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        ggml_build_forward_expand(&gf, cur);
        ggml_graph_compute (ctx0, &gf);

        //ggml_graph_print(&gf);
    }

    // cur
    //{
    //    printf("ne0 = %d\n", cur->ne[0]);
    //    printf("ne1 = %d\n", cur->ne[1]);
    //    for (int i = 0; i < 10; ++i) {
    //        printf("%8.4f ", ((float *)(cur->data))[i]);
    //    }
    //    printf("... ");
    //    for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
    //        printf("%8.4f ", ((float *)(cur->data))[i]);
    //    }
    //    printf("\n");
    //}

    // pre-compute cross-attention memory
    {
        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        // TODO: hack to disconnect the encoded features from the previous graph
        cur->op = GGML_OP_NONE;
        cur->src0 = NULL;
        cur->src1 = NULL;

        for (int il = 0; il < model.hparams.n_text_layer; ++il) {
            auto & layer = model.layers_decoder[il];

            struct ggml_tensor * Kcross = ggml_mul_mat(ctx0,
                    layer.cross_attn_k_w,
                    cur);

            // the cross-attention keys are pre-scaled here, so the decoder does not scale them again
            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

            struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                    layer.cross_attn_v_w,
                    cur);

            Vcross = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        layer.cross_attn_v_b,
                        Vcross),
                    Vcross);

            // layer il owns the slice [il*n_ctx*n_state, (il + 1)*n_ctx*n_state) of the cross-attention memory
            struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx));
            struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx));

            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
        }

        ggml_graph_compute(ctx0, &gf);
    }

    ////////////////////////////////////////////////////////////////////////////

    //printf("%s: used_mem = %f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0);

    ggml_free(ctx0);

    return true;
}
// evaluate the decoder
//
// given text prompt + audio features -> predicts the probabilities for the next token
//
// - wctx:      the whisper context (model, key/value memory, output buffers)
// - n_threads: number of threads to use
// - tokens:    text prompt
// - n_tokens:  number of tokens in the prompt
// - n_past:    number of past tokens to prefix the prompt with; the new tokens take positions
//              n_past .. n_past + n_tokens - 1 and their keys/values are written at that offset of the
//              self-attention memory
//
// On success wctx.logits and wctx.probs each hold n_tokens rows of n_vocab values (one row per input token)
// and true is returned. Returns false, without touching any memory, when the arguments would index past the
// n_text_ctx positions of the self-attention memory, or when a ggml context cannot be created.
//
static bool whisper_decode(
        whisper_context & wctx,
        const int n_threads,
        const whisper_token * tokens,
        const int n_tokens,
        const int n_past) {
    const auto & model   = wctx.model;
    const auto & hparams = model.hparams;

    auto & logits_out = wctx.logits;
    auto & probs_out  = wctx.probs;

    const int n_vocab = hparams.n_vocab;

    const int n_ctx   = hparams.n_text_ctx;
    const int n_state = hparams.n_text_state;
    const int n_head  = hparams.n_text_head;
    const int n_layer = hparams.n_text_layer;

    const int N = n_tokens;
    const int M = hparams.n_audio_ctx;

    // each layer owns n_ctx positions of memory_k / memory_v: writing beyond that would overwrite the next
    // layer's cache (or run past the buffer for the last layer). The positional embedding lookup (model.d_pe)
    // is presumably sized for n_ctx positions as well.
    if (tokens == nullptr || n_tokens <= 0 || n_past < 0 || n_past + n_tokens > n_ctx) {
        fprintf(stderr, "%s: invalid arguments: n_tokens = %d, n_past = %d (n_text_ctx = %d)\n", __func__, n_tokens, n_past, n_ctx);
        return false;
    }

    struct ggml_init_params params = {
        .mem_size   = wctx.buf_compute.size(),
        .mem_buffer = wctx.buf_compute.data(),
    };

    struct ggml_context * ctx0 = ggml_init(params);
    if (!ctx0) {
        fprintf(stderr, "%s: ggml_init() failed\n", __func__);
        return false;
    }

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, tokens, N*ggml_element_size(embd));

    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    for (int i = 0; i < N; ++i) {
        ((int32_t *) position->data)[i] = n_past + i;
    }

    // token encoding + position encoding
    struct ggml_tensor * cur =
        ggml_add(ctx0,
                ggml_get_rows(ctx0, model.d_te, embd),
                ggml_get_rows(ctx0, model.d_pe, position));

    struct ggml_tensor * inpL = cur;

    for (int il = 0; il < n_layer; ++il) {
        const auto & layer = model.layers_decoder[il];

        struct ggml_init_params paramsL = {
            .mem_size   = wctx.buf_compute_layer.size(),
            .mem_buffer = wctx.buf_compute_layer.data(),
        };

        struct ggml_context * ctxL = ggml_init(paramsL);
        if (!ctxL) {
            fprintf(stderr, "%s: ggml_init() failed for layer %d\n", __func__, il);
            ggml_free(ctx0);
            return false;
        }

        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        // norm
        {
            cur = ggml_norm(ctxL, inpL);

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctxL,
                    ggml_mul(ctxL,
                        ggml_repeat(ctxL, layer.attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctxL, layer.attn_ln_0_b, cur));
        }

        // self-attention
        {
            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
                    layer.attn_q_w,
                    cur);

            Qcur = ggml_add(ctxL,
                    ggml_repeat(ctxL,
                        layer.attn_q_b,
                        Qcur),
                    Qcur);

            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

            // note: no bias for Key
            struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                    layer.attn_k_w,
                    cur);

            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

            struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                    layer.attn_v_w,
                    cur);

            Vcur = ggml_add(ctxL,
                    ggml_repeat(ctxL,
                        layer.attn_v_b,
                        Vcur),
                    Vcur);

            // store key and value to memory
            {
                struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past));

                ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v));
            }

            // ------

            struct ggml_tensor * Q =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Qcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctxL,
                        ggml_reshape_3d(ctxL,
                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
                            n_state/n_head, n_head, n_past + N),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);

            //struct ggml_tensor * KQ_scaled =
            //    ggml_scale(ctxL,
            //            KQ,
            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
            //            );

            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);

            struct ggml_tensor * V_trans =
                ggml_permute(ctxL,
                        ggml_reshape_3d(ctxL,
                            ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
                            n_state/n_head, n_head, n_past + N),
                        1, 2, 0, 3);

            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);

            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

            cur = ggml_cpy(ctxL,
                    KQV_merged,
                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
        }

        {
            cur = ggml_mul_mat(ctxL,
                    layer.attn_ln_1_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.attn_ln_1_b, cur),
                    cur);
        }

        // add the input
        struct ggml_tensor * inpCA = ggml_add(ctxL, cur, inpL);

        // norm
        {
            cur = ggml_norm(ctxL, inpCA); // note: we use inpCA here

            // cur = ln_0_w*cur + ln_0_b
            cur = ggml_add(ctxL,
                    ggml_mul(ctxL,
                        ggml_repeat(ctxL, layer.cross_attn_ln_0_w, cur),
                        cur),
                    ggml_repeat(ctxL, layer.cross_attn_ln_0_b, cur));
        }

        // cross-attention
        {
            struct ggml_tensor * Qcur = ggml_mul_mat(ctxL,
                    layer.cross_attn_q_w,
                    cur);

            Qcur = ggml_add(ctxL,
                    ggml_repeat(ctxL,
                        layer.cross_attn_q_b,
                        Qcur),
                    Qcur);

            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

            // Kcross is already scaled
            struct ggml_tensor * Kcross =
                ggml_reshape_3d(ctxL,
                        ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state),
                        n_state/n_head, n_head, M);

            struct ggml_tensor * Vcross =
                ggml_reshape_3d(ctxL,
                        ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, il*M*ggml_element_size(model.memory_cross_v)*n_state),
                        n_state/n_head, n_head, M);

            // ------

            struct ggml_tensor * Q =
                ggml_permute(ctxL,
                        ggml_cpy(ctxL,
                            Qcur,
                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)),
                        0, 2, 1, 3);

            struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);

            //struct ggml_tensor * KQ_scaled =
            //    ggml_scale(ctxL,
            //            KQ,
            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
            //            );

            // no masking for cross-attention
            //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);

            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ);

            struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);

            struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);

            struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_state, N)
            cur = ggml_cpy(ctxL,
                    KQV_merged,
                    ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N));
        }

        // projection
        {
            cur = ggml_mul_mat(ctxL,
                    layer.cross_attn_ln_1_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.cross_attn_ln_1_b, cur),
                    cur);
        }

        // add the input
        cur = ggml_add(ctxL, cur, inpCA);

        struct ggml_tensor * inpFF = cur;

        // feed-forward network
        {
            // norm
            {
                cur = ggml_norm(ctxL, inpFF);

                // cur = mlp_ln_w*cur + mlp_ln_b
                cur = ggml_add(ctxL,
                        ggml_mul(ctxL,
                            ggml_repeat(ctxL, layer.mlp_ln_w, cur),
                            cur),
                        ggml_repeat(ctxL, layer.mlp_ln_b, cur));
            }

            // fully connected
            cur = ggml_mul_mat(ctxL,
                    layer.mlp_0_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.mlp_0_b, cur),
                    cur);

            // GELU activation
            cur = ggml_gelu(ctxL, cur);

            // projection
            cur = ggml_mul_mat(ctxL,
                    layer.mlp_1_w,
                    cur);

            cur = ggml_add(ctxL,
                    ggml_repeat(ctxL, layer.mlp_1_b, cur),
                    cur);
        }

        // output from this layer
        struct ggml_tensor * inpO = ggml_add(ctxL, cur, inpFF);

        {
            ggml_build_forward_expand(&gf, inpO);
            ggml_graph_compute (ctxL, &gf);

            //ggml_graph_print(&gf);
        }

        // TODO: this is a hack to have per-layer computation graphs - need to come up with something better
        // input for next layer (inpO -> inpL)
        memcpy(inpL->data, inpO->data, ggml_nbytes(inpL));
        inpL->op = GGML_OP_NONE;
        inpL->src0 = NULL;
        inpL->src1 = NULL;

        if (N > 1) {
            //printf("%s: - used_mem(%d) = %f MB\n", __func__, il, ggml_used_mem(ctxL)/1024.0/1024.0);
        }

        ggml_free(ctxL);
    }

    cur = inpL;

    // norm
    {
        cur = ggml_norm(ctx0, cur);

        cur = ggml_add(ctx0,
                ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.d_ln_w, cur),
                    cur),
                ggml_repeat(ctx0, model.d_ln_b, cur));
    }

    // the output projection reuses the token embedding matrix
    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

    // logits -> probs
    cur = ggml_dup(ctx0, logits);
    cur = ggml_soft_max(ctx0, cur); // in-place

    // run the computation
    {
        struct ggml_cgraph gf = {};
        gf.n_threads = n_threads;

        ggml_build_forward_expand(&gf, cur);
        ggml_graph_compute (ctx0, &gf);
    }

    logits_out.resize(N*n_vocab);
    memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);

    probs_out.resize(N*n_vocab);
    memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab);

    if (N > 1) {
        //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N;
        //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token);
        //printf("%s:         max mem = %f MB\n", __func__, mem_per_token*model.hparams.n_text_ctx);
    }

    ggml_free(ctx0);

    return true;
}
// the most basic sampling scheme - select the top token
//
// probs: probability of every vocabulary entry (vocab.id_to_token.size() values)
//
// Besides the selected token (id, p) the result carries timestamp information:
//   tid   - the most probable token with id >= vocab.token_beg (timestamp token)
//   pt    - probability of tid relative to the sum over all timestamp tokens
//   ptsum - sum of the probabilities of all timestamp tokens
// If the timestamp tokens together are more probable than the best text token, only timestamp tokens remain
// eligible (see the reference link below). The sot / solm / not tokens are skipped when a lower-ranked
// candidate among the top 4 exists.
static whisper_token_data whisper_sample_best(
        const whisper_vocab & vocab,
        const float * probs) {
    whisper_token_data result = {
        0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
    };

    int n_logits = vocab.id_to_token.size();

    std::vector<std::pair<double, whisper_vocab::id>> probs_id;
    probs_id.reserve(n_logits);

    for (int i = 0; i < n_logits; i++) {
        probs_id.push_back(std::make_pair(probs[i], i));
    }

    if (probs_id.empty()) {
        // nothing to sample from (e.g. an empty test vocabulary) - avoid indexing an empty vector below
        return result;
    }

    // text tokens are [0, n_text), timestamp tokens are [n_text, n_logits); clamped so that a vocabulary
    // smaller than token_beg cannot cause out-of-bounds reads
    const int n_text = std::max(0, std::min<int>(vocab.token_beg, n_logits));

    {
        double sum_ts =  0.0;
        double max_ts = -1.0;
        double max_tx = -1.0;

        for (int i = 0; i < n_text; i++) {
            max_tx = std::max(max_tx, probs_id[i].first);
        }

        for (int i = n_text; i < n_logits; i++) {
            sum_ts += probs_id[i].first;
            if (probs_id[i].first > max_ts) {
                max_ts = probs_id[i].first;
                result.tid = probs_id[i].second;
            }
        }

        // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a
        // timestamp token
        if (sum_ts > max_tx) {
            // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
            for (int i = 0; i < n_text; i++) {
                probs_id[i].first = -INFINITY;
            }
        }

        result.pt = max_ts/(sum_ts + 1e-10);
        result.ptsum = sum_ts;
    }

    // find the top K tokens - partial_sort requires middle <= end, so K is capped by the vocabulary size
    const int top_k = std::min(4, n_logits);

    std::partial_sort(
            probs_id.begin(),
            probs_id.begin() + top_k, probs_id.end(),
            [](const std::pair<double, whisper_vocab::id> & a, const std::pair<double, whisper_vocab::id> & b) {
        return a.first > b.first;
    });

    probs_id.resize(top_k);

    //printf("\n");
    //for (int i = 0; i < (int) probs_id.size(); i++) {
    //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
    //}

    // skip control tokens that must not be emitted, as long as another candidate is left
    int res = 0;
    while ((probs_id[res].second == vocab.token_sot ||
            probs_id[res].second == vocab.token_solm ||
            probs_id[res].second == vocab.token_not) &&
            res < (int) probs_id.size() - 1) {
        res++;
    }

    result.id = probs_id[res].second;
    result.p  = probs_id[res].first;

    return result;
}
// samples only from the timestamps tokens
//
// Returns the id of the most probable token among the ids (vocab.token_beg, vocab.id_to_token.size()),
// with vocab.token_beg itself excluded. Ties resolve to the lowest id. If that range is empty (e.g. a tiny
// test vocabulary), vocab.token_beg is returned.
static whisper_vocab::id whisper_sample_timestamp(
        const whisper_vocab & vocab,
        const float * probs) {
    const int n_logits = vocab.id_to_token.size();

    // a single argmax pass is all that is needed: no copy of the probabilities, and no partial_sort whose
    // fixed "top 10" bound is undefined behavior when fewer candidates exist
    int best = -1;
    for (int i = vocab.token_beg + 1; i < n_logits; i++) {
        if (best < 0 || probs[i] > probs[best]) {
            best = i;
        }
    }

    if (best < 0) {
        return vocab.token_beg;
    }

    return best;
}
// Formats a time given in units of 10 ms as "HH:MM:SS.mmm".
// When `comma` is true, ',' separates the seconds from the milliseconds instead of '.'.
//
//    500 -> 00:00:05.000
//   6000 -> 00:01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
    constexpr int64_t ms_per_sec  = 1000;
    constexpr int64_t ms_per_min  = 60*ms_per_sec;
    constexpr int64_t ms_per_hour = 60*ms_per_min;

    int64_t rest_ms = t*10;

    const int64_t hours = rest_ms/ms_per_hour;
    rest_ms -= hours*ms_per_hour;

    const int64_t minutes = rest_ms/ms_per_min;
    rest_ms -= minutes*ms_per_min;

    const int64_t seconds = rest_ms/ms_per_sec;
    rest_ms -= seconds*ms_per_sec;

    const char * separator = comma ? "," : ".";

    char out[32];
    snprintf(out, sizeof(out), "%02d:%02d:%02d%s%03d", (int) hours, (int) minutes, (int) seconds, separator, (int) rest_ms);

    return std::string(out);
}
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
//
// out is resized to 2*N floats holding interleaved (re, im) pairs:
//   out[2*k + 0] + i*out[2*k + 1] = sum_n in[n]*exp(-2*pi*i*k*n/N)
// Runs in O(N^2).
static void dft(const std::vector<float> & in, std::vector<float> & out) {
    const int N = in.size();

    out.resize(N*2);

    for (int k = 0; k < N; k++) {
        float re = 0;
        float im = 0;

        // exp(-2*pi*i*k*n/N) is periodic in k*n with period N, so track (k*n) mod N incrementally:
        // this keeps the angle in [0, 2*pi) - no int overflow of k*n for large N and less precision lost
        // when the angle is narrowed to float
        int kn = 0;
        for (int n = 0; n < N; n++) {
            const float angle = 2*M_PI*kn/N;
            re += in[n]*cos(angle);
            im -= in[n]*sin(angle);

            kn += k;
            if (kn >= N) {
                kn -= N;
            }
        }

        out[k*2 + 0] = re;
        out[k*2 + 1] = im;
    }
}
// Cooley-Tukey FFT
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
//
// out is resized to 2*N floats of interleaved (re, im) pairs, the same layout as dft(). Even sizes are
// split recursively into even/odd-indexed halves; odd sizes fall back to the O(N^2) dft().
static void fft(const std::vector<float> & in, std::vector<float> & out) {
    out.resize(in.size()*2);

    int N = in.size();

    // an empty input has an empty transform - without this check N == 0 counts as "even" and the
    // recursion below would call fft() on empty halves forever
    if (N == 0) {
        return;
    }

    if (N == 1) {
        out[0] = in[0];
        out[1] = 0;
        return;
    }

    if (N%2 == 1) {
        dft(in, out);
        return;
    }

    std::vector<float> even;
    std::vector<float> odd;

    even.reserve(N/2);
    odd.reserve(N/2);

    // N is even here, so i + 1 < N
    for (int i = 0; i < N; i += 2) {
        even.push_back(in[i + 0]);
        odd.push_back(in[i + 1]);
    }

    std::vector<float> even_fft;
    std::vector<float> odd_fft;

    fft(even, even_fft);
    fft(odd, odd_fft);

    // butterfly: combine the half-size transforms with the twiddle factor exp(-2*pi*i*k/N)
    for (int k = 0; k < N/2; k++) {
        float theta = 2*M_PI*k/N;

        float re = cos(theta);
        float im = -sin(theta);

        float re_odd = odd_fft[2*k + 0];
        float im_odd = odd_fft[2*k + 1];

        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;

        out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
        out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
    }
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
//
// Computes the log mel spectrogram of `samples` into `mel`:
//   - mel.n_len = n_samples/fft_step frames; frame i covers fft_size samples starting at i*fft_step (samples
//     past the end count as zero), multiplied by a Hann window
//   - squared magnitude via fft(), with bins [1, fft_size/2) also accumulating their mirrored bin
//   - n_mel filter outputs from `filters` (n_mel rows of 1 + fft_size/2 weights), log10 floored at 1e-10
//   - finally clamped to (max - 8) and rescaled as (x + 4)/4
// mel.data is laid out as [n_mel][n_len]. Frames are distributed over max(1, n_threads) worker threads.
// sample_rate is currently unused. Always returns true.
static bool log_mel_spectrogram(
        const float * samples,
        const int n_samples,
        const int sample_rate,
        const int fft_size,
        const int fft_step,
        const int n_mel,
        const int n_threads,
        const whisper_filters & filters,
        whisper_mel & mel) {
    // Hanning window
    std::vector<float> hann(fft_size);
    for (int i = 0; i < fft_size; i++) {
        hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size)));
    }

    mel.n_mel = n_mel;
    mel.n_len = (n_samples)/fft_step;
    mel.data.resize(mel.n_mel*mel.n_len);

    const int n_fft = 1 + fft_size/2;

    //printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
    //printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);

    // with n_threads == 0 no frame would be computed (silently leaving an all-zero spectrogram), and a negative
    // value would size the workers vector below from a huge unsigned count - always run at least one worker
    const int n_workers = std::max(1, n_threads);

    std::vector<std::thread> workers(n_workers);
    for (int iw = 0; iw < n_workers; ++iw) {
        workers[iw] = std::thread([&](int ith) {
            std::vector<float> fft_in(fft_size, 0.0f);
            std::vector<float> fft_out(2*fft_size);

            // frames are interleaved across workers: worker ith handles frames ith, ith + n_workers, ...
            for (int i = ith; i < mel.n_len; i += n_workers) {
                const int offset = i*fft_step;

                // apply Hanning window
                for (int j = 0; j < fft_size; j++) {
                    if (offset + j < n_samples) {
                        fft_in[j] = hann[j]*samples[offset + j];
                    } else {
                        fft_in[j] = 0.0;
                    }
                }

                // FFT -> mag^2 (computed in place: entry j only reads entries 2*j and 2*j + 1 >= j)
                fft(fft_in, fft_out);

                for (int j = 0; j < fft_size; j++) {
                    fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
                }

                // fold the mirrored half of the spectrum onto bins [1, fft_size/2)
                for (int j = 1; j < fft_size/2; j++) {
                    fft_out[j] += fft_out[fft_size - j];
                }

                // mel spectrogram
                for (int j = 0; j < mel.n_mel; j++) {
                    double sum = 0.0;

                    for (int k = 0; k < n_fft; k++) {
                        sum += fft_out[k]*filters.data[j*n_fft + k];
                    }
                    if (sum < 1e-10) {
                        sum = 1e-10;
                    }

                    sum = log10(sum);

                    mel.data[j*mel.n_len + i] = sum;
                }
            }
        }, iw);
    }

    for (int iw = 0; iw < n_workers; ++iw) {
        workers[iw].join();
    }

    // clamping and normalization
    double mmax = -1e20;
    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] > mmax) {
            mmax = mel.data[i];
        }
    }
    //printf("%s: max = %f\n", __func__, mmax);

    mmax -= 8.0;

    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
        if (mel.data[i] < mmax) {
            mel.data[i] = mmax;
        }

        mel.data[i] = (mel.data[i] + 4.0)/4.0;
    }

    return true;
}
//
// interface implementation
//
struct whisper_context * whisper_init(const char * path_model) {
ggml_time_init();
whisper_context * ctx = new whisper_context;
const int64_t t_start_us = ggml_time_us();
ctx->t_start_us = t_start_us;
if (!whisper_model_load(path_model, *ctx)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, path_model);
return NULL;
}
ctx->t_load_us = ggml_time_us() - t_start_us;
return ctx;
}
void whisper_free(struct whisper_context * ctx) {
if (ctx) {
if (ctx->buf_model) {
delete ctx->buf_model;
}
delete ctx;
}
}
// Convert raw PCM samples into the context's log mel spectrogram (ctx->mel), using the library's fixed
// sample rate / FFT size / hop length / mel band count. On success the elapsed time is stored in
// ctx->t_mel_us and 0 is returned; -1 is returned on failure.
int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    const bool computed = log_mel_spectrogram(
            samples, n_samples,
            WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL,
            n_threads, ctx->model.filters, ctx->mel);

    if (computed) {
        ctx->t_mel_us = ggml_time_us() - t_start_us;
        return 0;
    }

    fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__);
    return -1;
}
// Replace the context's mel spectrogram with caller-provided data: n_mel rows of n_len values each
// (n_len*n_mel floats, copied). n_mel must equal WHISPER_N_MEL.
// Returns 0 on success, -1 on invalid arguments (the context is left unchanged in that case).
int whisper_set_mel(
        struct whisper_context * ctx,
        const float * data,
        int n_len,
        int n_mel) {
    if (n_mel != WHISPER_N_MEL) {
        fprintf(stderr, "%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL);
        return -1;
    }

    // a negative length would be converted to a huge size by resize() below (throwing out of a C API),
    // and a NULL buffer cannot be copied from
    if (n_len < 0 || (n_len > 0 && data == nullptr)) {
        fprintf(stderr, "%s: invalid mel data: n_len = %d, data = %p\n", __func__, n_len, (const void *) data);
        return -1;
    }

    const size_t n_values = (size_t) n_len*n_mel;

    ctx->mel.n_len = n_len;
    ctx->mel.n_mel = n_mel;

    ctx->mel.data.resize(n_values);
    if (n_values > 0) {
        memcpy(ctx->mel.data.data(), data, n_values*sizeof(float));
    }

    return 0;
}
// Run the encoder on ctx->mel starting at mel frame `offset`. On success the elapsed time is added to
// ctx->t_encode_us and 0 is returned; -1 is returned on failure.
int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    if (whisper_encode(*ctx, n_threads, offset)) {
        ctx->t_encode_us += ggml_time_us() - t_start_us;
        return 0;
    }

    fprintf(stderr, "%s: failed to eval\n", __func__);
    return -1;
}
// Run the decoder on `n_tokens` tokens placed after `n_past` previously decoded tokens. On success the
// elapsed time is added to ctx->t_decode_us and 0 is returned. Failure is reported as 1 (note: unlike
// whisper_encode / whisper_pcm_to_mel, which use -1).
int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
    const int64_t t_start_us = ggml_time_us();

    if (whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
        ctx->t_decode_us += ggml_time_us() - t_start_us;
        return 0;
    }

    fprintf(stderr, "%s: failed to eval\n", __func__);
    return 1;
}
// Greedily sample the next token from the last row of ctx->probs (the probabilities of the last token
// passed to the most recent decode). Sampling time is added to ctx->t_sample_us.
// NOTE(review): assumes ctx->probs holds at least n_vocab values, i.e. a decode has already run.
struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
    const int64_t t_start_sample_us = ggml_time_us();

    // TODO: simplify
    const size_t last_row = ctx->probs.size() - ctx->vocab.n_vocab;
    const float * probs_last = ctx->probs.data() + last_row;

    auto sampled = whisper_sample_best(ctx->vocab, probs_last);

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

    return sampled;
}
// Sample the most probable timestamp token from the last row of ctx->probs. Sampling time is added to
// ctx->t_sample_us.
// NOTE(review): assumes ctx->probs holds at least n_vocab values, i.e. a decode has already run.
whisper_token whisper_sample_timestamp(struct whisper_context * ctx) {
    const int64_t t_start_sample_us = ggml_time_us();

    // TODO: simplify
    const size_t last_row = ctx->probs.size() - ctx->vocab.n_vocab;
    const float * probs_last = ctx->probs.data() + last_row;

    auto sampled = whisper_sample_timestamp(ctx->vocab, probs_last);

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;

    return sampled;
}
// Map a language code (e.g. "en", "de") to its numeric id from g_lang.
// Returns -1 (and logs to stderr) for an unknown or NULL code.
int whisper_lang_id(const char * lang) {
    if (lang == nullptr) {
        fprintf(stderr, "%s: language is NULL\n", __func__);
        return -1;
    }

    // single lookup instead of count() followed by at()
    const auto it = g_lang.find(lang);
    if (it == g_lang.end()) {
        fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
        return -1;
    }

    return it->second.first;
}
// Length (number of time steps, mel.n_len) of the mel spectrogram currently held by the context.
int whisper_n_len(struct whisper_context * ctx) {
    const auto & mel = ctx->mel;
    return mel.n_len;
}

// Vocabulary size of the loaded model.
int whisper_n_vocab(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.n_vocab;
}

// Text context size (n_text_ctx) of the loaded model.
int whisper_n_text_ctx(struct whisper_context * ctx) {
    const auto & hparams = ctx->model.hparams;
    return hparams.n_text_ctx;
}

// 1 if the loaded vocabulary is multilingual, 0 otherwise.
int whisper_is_multilingual(struct whisper_context * ctx) {
    if (ctx->vocab.is_multilingual()) {
        return 1;
    }
    return 0;
}

// Pointer to the probabilities produced by the most recent decode (one row of n_vocab values per decoded
// token); owned by the context.
float * whisper_get_probs(struct whisper_context * ctx) {
    float * probs = ctx->probs.data();
    return probs;
}
// Text of `token`; the returned pointer is owned by the context's vocabulary.
// The lookup uses .at(), so an unknown token id throws std::out_of_range.
const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) {
    const auto & text = ctx->vocab.id_to_token.at(token);
    return text.c_str();
}

// Special token ids of the loaded vocabulary.

whisper_token whisper_token_eot(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_eot;
}

whisper_token whisper_token_sot(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_sot;
}

whisper_token whisper_token_prev(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_prev;
}

whisper_token whisper_token_solm(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_solm;
}

whisper_token whisper_token_not(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_not;
}

whisper_token whisper_token_beg(struct whisper_context * ctx) {
    const auto & vocab = ctx->vocab;
    return vocab.token_beg;
}

// Task tokens - static members of whisper_vocab, so no context is needed.

whisper_token whisper_token_translate() {
    return whisper_vocab::token_translate;
}

whisper_token whisper_token_transcribe() {
    return whisper_vocab::token_transcribe;
}
// Print the accumulated load / mel / sample / encode / decode times and the total wall-clock time since
// whisper_init() to stderr. Encode and decode times are also shown divided by the number of audio / text
// layers respectively.
void whisper_print_timings(struct whisper_context * ctx) {
    const int64_t t_end_us = ggml_time_us();

    const auto & hparams = ctx->model.hparams;

    const float load_ms   = ctx->t_load_us/1000.0f;
    const float mel_ms    = ctx->t_mel_us/1000.0f;
    const float sample_ms = ctx->t_sample_us/1000.0f;
    const float encode_ms = ctx->t_encode_us/1000.0f;
    const float decode_ms = ctx->t_decode_us/1000.0f;
    const float total_ms  = (t_end_us - ctx->t_start_us)/1000.0f;

    fprintf(stderr, "\n");
    fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, load_ms);
    fprintf(stderr, "%s: mel time = %8.2f ms\n", __func__, mel_ms);
    fprintf(stderr, "%s: sample time = %8.2f ms\n", __func__, sample_ms);
    fprintf(stderr, "%s: encode time = %8.2f ms / %.2f ms per layer\n", __func__, encode_ms, encode_ms/hparams.n_audio_layer);
    fprintf(stderr, "%s: decode time = %8.2f ms / %.2f ms per layer\n", __func__, decode_ms, decode_ms/hparams.n_text_layer);
    fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, total_ms);
}
////////////////////////////////////////////////////////////////////////////
struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) {
struct whisper_full_params result;
switch (strategy) {
case WHISPER_SAMPLING_GREEDY:
{
result = {
/*.strategy =*/ WHISPER_SAMPLING_GREEDY,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ 0,
},
/*.beam_search =*/ {
/*.n_past =*/ -1,
/*.beam_width =*/ -1,
/*.n_best =*/ -1,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result = {
/*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH,
/*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()),
/*.n_max_text_ctx =*/ 16384,
/*.offset_ms =*/ 0,
/*.translate =*/ false,
/*.no_context =*/ false,
/*.print_special_tokens =*/ false,
/*.print_progress =*/ true,
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
/*.n_past =*/ -1,
},
/*.beam_search =*/ {
/*.n_past =*/ 0,
/*.beam_width =*/ 10,
/*.n_best =*/ 5,
},
/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,
};
} break;
}
return result;
}
// forward declarations
//
// Both helpers are defined further below, in the experimental token-level timestamp section,
// but are already needed by whisper_full().

// Sum of |signal| over a window of +/- n_samples_per_half_window samples around each sample,
// divided by the full window length (2*n_samples_per_half_window + 1). Returns n_samples values.
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
// Estimate the t0/t1 of every token of segment i_segment in ctx->result_all, using the sampled
// timestamp tokens (filtered by thold_pt / thold_ptsum) and ctx->energy.
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// Wrap the last segment of ctx->result_all so that the text of every resulting segment is at
// most max_len characters (bytes, as counted by strlen), splitting only at token boundaries.
//
// The time of each split is the t0 of the first token moved to the new segment, so token-level
// timestamps must already have been computed for the segment. Special tokens (id >= eot)
// contribute no text and stay in whichever piece they fall into. A single text token longer
// than max_len is never split; it simply ends up in a piece of its own.
//
// Returns the number of segments the last segment has been turned into (1 if it was not split).
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
    // work on a copy: result_all.back() is rewritten and result_all grows while we iterate
    auto segment = ctx->result_all.back();

    int res = 1;
    int acc = 0; // length of the text accumulated in the current piece

    std::string text;

    for (int i = 0; i < (int) segment.tokens.size(); i++) {
        const auto & token = segment.tokens[i];
        if (token.id >= whisper_token_eot(ctx)) {
            continue;
        }

        const auto txt = whisper_token_to_str(ctx, token.id);
        const int  cur = strlen(txt);

        // split before this token when it would overflow the line - but only if the current
        // piece already holds some text. Checking `i > 0` instead emits a piece with empty text
        // when the first text token follows leading special tokens and is longer than max_len.
        if (acc + cur > max_len && acc > 0) {
            const auto t_split = token.t0;

            // close the current piece with tokens [0, i)
            ctx->result_all.back().text = std::move(text);
            ctx->result_all.back().t1 = t_split;
            ctx->result_all.back().tokens.resize(i);

            // open a new piece holding tokens [i, end)
            ctx->result_all.push_back({});
            ctx->result_all.back().t0 = t_split;
            ctx->result_all.back().t1 = segment.t1;

            ctx->result_all.back().tokens.insert(
                    ctx->result_all.back().tokens.end(),
                    segment.tokens.begin() + i,
                    segment.tokens.end());

            acc = 0;
            text = "";

            // restart the scan at the first token of the new piece
            segment = ctx->result_all.back();
            i = -1;

            res++;
        } else {
            acc += cur;
            text += txt;
        }
    }

    ctx->result_all.back().text = std::move(text);

    return res;
}
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples) {
// clear old results
auto & result_all = ctx->result_all;
result_all.clear();
// compute log mel spectrogram
if (whisper_pcm_to_mel(ctx, samples, n_samples, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to compute log mel spectrogram\n", __func__);
return -1;
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
const int seek_start = params.offset_ms/10;
// if length of spectrogram is less than 1s (100 samples), then return
// basically don't process anything that is less than 1s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (whisper_n_len(ctx) < 100 + seek_start) {
return 0;
}
// the accumulated text context so far
auto & prompt_past = ctx->prompt_past;
if (params.no_context) {
prompt_past.clear();
}
// these tokens determine the task that will be performed
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
if (whisper_is_multilingual(ctx)) {
prompt_init.push_back(whisper_token_sot(ctx) + 1 + whisper_lang_id(params.language));
if (params.translate) {
prompt_init.push_back(whisper_token_translate());
} else {
prompt_init.push_back(whisper_token_transcribe());
}
}
int progress_prev = 0;
int progress_step = 5;
std::vector<whisper_token_data> tokens_cur;
tokens_cur.reserve(whisper_n_text_ctx(ctx));
std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
// main loop
int seek = seek_start;
while (true) {
int progress_cur = (100*seek)/whisper_n_len(ctx);
while (progress_cur >= progress_prev + progress_step) {
progress_prev += progress_step;
if (params.print_progress) {
fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev);
}
}
if (seek + 100 >= whisper_n_len(ctx)) {
break;
}
// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
return 7;
}
int n_past = 0;
prompt.clear();
// if we have already generated some text, use it as a prompt to condition the next generation
if (prompt_past.size() > 0) {
int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
prompt = { whisper_token_prev(ctx) };
prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
prompt_past.clear();
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
}
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
bool done = false;
int seek_delta = 100*WHISPER_CHUNK_SIZE;
// print the prompt
//printf("\n\n");
//for (int i = 0; i < prompt.size(); i++) {
// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str());
//}
//printf("\n\n");
// the accumulated transcription in the current interation
int result_len = 0;
tokens_cur.clear();
for (int i = 0; i < whisper_n_text_ctx(ctx)/2 - 4; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
return 8;
}
n_past += prompt.size();
prompt.clear();
// very basic greedy sampling strategy:
//
// - always take the most probable token
//
// more sophisticated sampling strategies could be implemented here, but we keep it simple
// feel free to experiment!
//
{
auto token = whisper_sample_best(ctx);
if (i == 0) {
token.tid = whisper_token_beg(ctx);
}
// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
seek_delta = 2*(token.id - whisper_token_beg(ctx));
result_len = i + 1;
}
// add it to the context
prompt.push_back(token.id);
tokens_cur.push_back(token);
//{
// const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
// printf("%s: %10s %6.3f '%s'\n", __func__, tt.c_str(), token.pt, ctx->vocab.id_to_token[token.id].c_str());
//}
// end of text token
if (token.id == whisper_token_eot(ctx)) {
if (result_len == 0) {
if (seek + seek_delta + 100 >= whisper_n_len(ctx)) {
result_len = i + 1;
} else {
// TODO: figure out how to resolve this
fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__);
}
}
break;
}
// TESTS: if no tensors are loaded, it means we are running tests
if (ctx->model.n_loaded == 0) {
seek_delta = 100*WHISPER_CHUNK_SIZE;
break;
}
}
if (done) {
break;
}
}
// shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
prompt_past.push_back(r.id);
}
// store the text from this iteration
if (tokens_cur.size() > 0) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
std::string text = "";
for (int i = 0; i < (int) tokens_cur.size(); i++) {
//printf("%s: %18s %6.3f %18s %6.3f\n", __func__,
// ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p,
// ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt);
if (params.print_special_tokens == false && tokens_cur[i].id >= whisper_token_eot(ctx)) {
} else {
text += whisper_token_to_str(ctx, tokens_cur[i].id);
}
if (tokens_cur[i].id > whisper_token_beg(ctx)) {
const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx));
if (!text.empty()) {
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
text = "";
while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) {
i++;
}
i--;
t0 = t1;
i0 = i + 1;
}
}
if (!text.empty()) {
const auto t1 = seek + seek_delta;
if (params.print_realtime) {
if (params.print_timestamps) {
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text.c_str());
} else {
printf("%s", text.c_str());
fflush(stdout);
}
}
result_all.push_back({ t0, t1, text, {} });
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
seek += seek_delta;
}
return 0;
}
int whisper_full_parallel(
struct whisper_context * ctx,
struct whisper_full_params params,
const float * samples,
int n_samples,
const int n_processors) {
if (n_processors == 1) {
return whisper_full(ctx, params, samples, n_samples);
}
int ret = 0;
// prepare separate contexts for each thread
std::vector<struct whisper_context> ctxs(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
ctxs[i] = *ctx;
auto & model = ctxs[i].model;
// create the ggml memory context
{
struct ggml_init_params params = {
.mem_size = ctxs[i].buf_memory.size(),
.mem_buffer = ctxs[i].buf_memory.data(),
};
model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}
// separate key + value memory for each processor
{
auto & ctx = model.ctx_mem;
const auto & hparams = model.hparams;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;
// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
// key/value memory for the cross-attention layer
{
const int n_audio_ctx = hparams.n_audio_ctx;
const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
}
const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
}
}
const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000;
const int n_samples_per_processor = (n_samples - offset_samples)/n_processors;
// the calling thread will process the first chunk
// while the other threads will process the remaining chunks
std::vector<std::thread> workers(n_processors - 1);
for (int i = 0; i < n_processors - 1; ++i) {
const int start_samples = offset_samples + (i + 1)*n_samples_per_processor;
const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor;
auto params_cur = params;
params_cur.offset_ms = 0;
params_cur.print_progress = false;
params_cur.print_realtime = false;
params_cur.new_segment_callback = nullptr;
params_cur.new_segment_callback_user_data = nullptr;
workers[i] = std::thread(whisper_full, &ctxs[i], std::move(params_cur), samples + start_samples, n_samples_cur);
}
{
auto params_cur = params;
ret = whisper_full(ctx, std::move(params_cur), samples, offset_samples + n_samples_per_processor);
}
for (int i = 0; i < n_processors - 1; ++i) {
workers[i].join();
}
const int64_t offset_t = (int64_t) params.offset_ms/10.0;
// combine results into ctx->result_all
for (int i = 0; i < n_processors - 1; ++i) {
auto & results_i = ctxs[i].result_all;
for (int j = 0; j < (int) results_i.size(); ++j) {
// correct the segment timestamp taking into account the offset
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
// make sure that segments are not overlapping
if (ctx->result_all.size() > 0) {
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
}
ctx->result_all.push_back(std::move(results_i[j]));
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
}
}
ctx->t_mel_us += ctxs[i].t_mel_us;
ctx->t_sample_us += ctxs[i].t_sample_us;
ctx->t_encode_us += ctxs[i].t_encode_us;
ctx->t_decode_us += ctxs[i].t_decode_us;
}
// average the timings
ctx->t_mel_us /= n_processors;
ctx->t_sample_us /= n_processors;
ctx->t_encode_us /= n_processors;
ctx->t_decode_us /= n_processors;
// print information about the audio boundaries
fprintf(stderr, "\n");
fprintf(stderr, "%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
for (int i = 0; i < n_processors - 1; ++i) {
fprintf(stderr, "%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
}
fprintf(stderr, "%s: the transcription quality may be degraded near these boundaries\n", __func__);
return ret;
}
int whisper_full_n_segments(struct whisper_context * ctx) {
return ctx->result_all.size();
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t0;
}
int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].t1;
}
const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].text.c_str();
}
int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) {
return ctx->result_all[i_segment].tokens.size();
}
const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->vocab.id_to_token[ctx->result_all[i_segment].tokens[i_token].id].c_str();
}
whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].id;
}
struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token];
}
float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) {
return ctx->result_all[i_segment].tokens[i_token].p;
}
// Return a one-line summary of the CPU/BLAS features reported by ggml, in the form
// "AVX2 = <0|1> | AVX512 = <0|1> | ... | BLAS = <0|1> | ".
//
// The returned pointer refers to a function-local static string that is rebuilt on every
// call, so it is invalidated by the next call and concurrent calls race on it.
const char * whisper_print_system_info() {
    struct cpu_feature {
        const char * label;
        int          value;
    };

    const cpu_feature features[] = {
        { "AVX2 = ",      ggml_cpu_has_avx2()      },
        { "AVX512 = ",    ggml_cpu_has_avx512()    },
        { "NEON = ",      ggml_cpu_has_neon()      },
        { "FP16_VA = ",   ggml_cpu_has_fp16_va()   },
        { "WASM_SIMD = ", ggml_cpu_has_wasm_simd() },
        { "BLAS = ",      ggml_cpu_has_blas()      },
    };

    static std::string s;

    s.clear();
    for (const auto & f : features) {
        s += f.label;
        s += std::to_string(f.value);
        s += " | ";
    }

    return s.c_str();
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
// Convert a timestamp t (10 ms units) to a sample index at WHISPER_SAMPLE_RATE, clamped to
// [0, n_samples - 1] (0 when n_samples <= 0).
static int timestamp_to_sample(int64_t t, int n_samples) {
    // clamp in 64-bit before narrowing to int, so out-of-range timestamps saturate instead of
    // wrapping around
    const int64_t sample = (t*WHISPER_SAMPLE_RATE)/100;
    return (int) std::max<int64_t>(0, std::min<int64_t>((int64_t) n_samples - 1, sample));
}
// Convert a sample index at WHISPER_SAMPLE_RATE to a timestamp in 10 ms units.
static int64_t sample_to_timestamp(int i_sample) {
    // widen before multiplying: 100*i_sample overflows int once i_sample exceeds INT_MAX/100
    // (about 21.4M samples), which long recordings reach
    return (100*(int64_t) i_sample)/WHISPER_SAMPLE_RATE;
}
// Heuristic cost that grows with the time `text` presumably takes to pronounce; used to
// split an interval between tokens proportionally. Obviously, can be improved.
//
// Per byte: space 0.01, ',' 2, '.' / '!' / '?' 3, ASCII digit 3, anything else 1.
static float voice_length(const std::string & text) {
    float res = 0.0f;

    for (const char c : text) {
        switch (c) {
            case ' ':
                res += 0.01f;
                break;
            case ',':
                res += 2.00f;
                break;
            case '.':
            case '!':
            case '?':
                res += 3.00f;
                break;
            default:
                // digits are usually spoken as whole words
                res += (c >= '0' && c <= '9') ? 3.00f : 1.00f;
                break;
        }
    }

    return res;
}
// Smoothed signal energy: for every sample i, the sum of |signal[k]| over the samples k that
// lie within +/- n_samples_per_half_window of i and inside [0, n_samples), divided by the full
// window length 2*n_samples_per_half_window + 1 (so values near the edges are attenuated).
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
    const int half = n_samples_per_half_window;

    std::vector<float> energy(n_samples);

    for (int i = 0; i < n_samples; i++) {
        // window clipped to the valid sample range, visited in increasing order
        const int first = std::max(0, i - half);
        const int last  = std::min(n_samples - 1, i + half);

        float acc = 0;
        for (int k = first; k <= last; k++) {
            acc += fabs(signal[k]);
        }

        energy[i] = acc/(2*half + 1);
    }

    return energy;
}
// Estimate the start/end time (t0/t1, in the same 10 ms units as the segment timestamps) of
// every token of segment i_segment in ctx->result_all. Experimental.
//
// Stages:
//   1. tokens whose sampled timestamp token is confident (pt > thold_pt and
//      ptsum > thold_ptsum), increasing w.r.t. the previous accepted one and not past the
//      segment end are anchored at that timestamp,
//   2. runs of tokens whose t1 is still unset (negative) are filled by splitting the interval
//      proportionally to voice_length() of each token's text,
//   3. overlaps between consecutive tokens are removed,
//   4. the boundaries of every text token (id < eot) are moved to where ctx->energy crosses
//      half of its local average (a crude voice-activity detection).
//
// ctx->t_beg, ctx->t_last and ctx->tid_last carry state from one segment to the next.
// ctx->energy must have been filled by whisper_full() (params.token_timestamps).
static void whisper_exp_compute_token_level_timestamps(
        struct whisper_context * ctx,
        int i_segment,
        float thold_pt,
        float thold_ptsum) {
    auto & segment = ctx->result_all[i_segment];
    auto & tokens  = segment.tokens;

    const int n_samples = ctx->energy.size();

    if (n_samples == 0) {
        fprintf(stderr, "%s: no signal data available\n", __func__);
        return;
    }

    const int64_t t0 = segment.t0;
    const int64_t t1 = segment.t1;

    const int n = tokens.size();

    if (n == 0) {
        return;
    }

    if (n == 1) {
        tokens[0].t0 = t0;
        tokens[0].t1 = t1;

        return;
    }

    auto & t_beg    = ctx->t_beg;
    auto & t_last   = ctx->t_last;
    auto & tid_last = ctx->tid_last;

    // stage 1: anchor tokens at their confident timestamp predictions
    for (int j = 0; j < n; ++j) {
        auto & token = tokens[j];

        if (j == 0) {
            if (token.id == whisper_token_beg(ctx)) {
                tokens[j    ].t0 = t0;
                tokens[j    ].t1 = t0;
                tokens[j + 1].t0 = t0;

                t_beg    = t0;
                t_last   = t0;
                tid_last = whisper_token_beg(ctx);
            } else {
                tokens[j    ].t0 = t_last;
            }
        }

        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));

        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));

        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
            if (j > 0) {
                tokens[j - 1].t1 = tt;
            }
            tokens[j].t0 = tt;
            tid_last = token.tid;
        }
    }

    tokens[n - 2].t1 = t1;
    tokens[n - 1].t0 = t1;
    tokens[n - 1].t1 = t1;

    t_last = t1;

    // stage 2: find intervals of tokens with unknown timestamps and
    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
    {
        int p0 = 0;
        int p1 = 0;
        while (true) {
            while (p1 < n && tokens[p1].t1 < 0) {
                p1++;
            }

            if (p1 >= n) {
                p1--;
            }

            if (p1 > p0) {
                double psum = 0.0;
                for (int j = p0; j <= p1; j++) {
                    psum += tokens[j].vlen;
                }

                //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);

                const double dt = tokens[p1].t1 - tokens[p0].t0;

                // split the time proportionally to the voice length
                for (int j = p0 + 1; j <= p1; j++) {
                    const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;

                    tokens[j - 1].t1 = ct;
                    tokens[j    ].t0 = ct;
                }
            }

            p1++;
            p0 = p1;
            if (p1 >= n) {
                break;
            }
        }
    }

    // stage 3: fix up (just in case)
    for (int j = 0; j < n - 1; j++) {
        if (tokens[j].t1 < 0) {
            tokens[j + 1].t0 = tokens[j].t1;
        }

        if (j > 0) {
            if (tokens[j - 1].t1 > tokens[j].t0) {
                tokens[j].t0 = tokens[j - 1].t1;
                tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
            }
        }
    }

    // stage 4: VAD
    // expand or contract tokens based on voice activity
    {
        const int hw = WHISPER_SAMPLE_RATE/8;

        for (int j = 0; j < n; j++) {
            if (tokens[j].id >= whisper_token_eot(ctx)) {
                continue;
            }

            int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
            int s1 = timestamp_to_sample(tokens[j].t1, n_samples);

            // threshold: half the average energy around the token (+/- hw samples)
            const int ss0 = std::max(s0 - hw, 0);
            const int ss1 = std::min(s1 + hw, n_samples);

            const int ns = ss1 - ss0;

            float sum = 0.0f;

            for (int k = ss0; k < ss1; k++) {
                sum += ctx->energy[k];
            }

            const float thold = 0.5*sum/ns;

            // start: extend backwards over voiced samples (never before the previous token's
            // end), or skip forward over silence
            {
                int k = s0;
                if (ctx->energy[k] > thold && j > 0) {
                    while (k > 0 && ctx->energy[k] > thold) {
                        k--;
                    }
                    tokens[j].t0 = sample_to_timestamp(k);
                    if (tokens[j].t0 < tokens[j - 1].t1) {
                        tokens[j].t0 = tokens[j - 1].t1;
                    } else {
                        s0 = k;
                    }
                } else {
                    while (ctx->energy[k] < thold && k < s1) {
                        k++;
                    }
                    s0 = k;
                    tokens[j].t0 = sample_to_timestamp(k);
                }
            }

            // end: extend forwards over voiced samples (never past the next token's start),
            // or pull back over silence (never before the new start)
            {
                int k = s1;
                if (ctx->energy[k] > thold) {
                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
                        k++;
                    }
                    tokens[j].t1 = sample_to_timestamp(k);

                    // j indexes tokens, so the bound is the token count n (not the window length ns)
                    if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
                        tokens[j].t1 = tokens[j + 1].t0;
                    }
                } else {
                    while (ctx->energy[k] < thold && k > s0) {
                        k--;
                    }
                    tokens[j].t1 = sample_to_timestamp(k);
                }
            }
        }
    }

    // fixed token expand (optional)
    //{
    //    const int t_expand = 0;

    //    for (int j = 0; j < n; j++) {
    //        if (j > 0) {
    //            tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
    //        }
    //        if (j < n - 1) {
    //            tokens[j].t1 = tokens[j].t1 + t_expand;
    //        }
    //    }
    //}

    // debug info
    //for (int j = 0; j < n; ++j) {
    //    const auto & token = tokens[j];
    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));

    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
    //        continue;
    //    }
    //}
}