whisper : factor out graph builds

This commit is contained in:
Georgi Gerganov
2023-09-10 19:23:06 +03:00
parent fbc3f8033e
commit 949ab6328d
2 changed files with 182 additions and 369 deletions

View File

@@ -295,6 +295,11 @@ $(info )
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@
ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
$(CC) $(CFLAGS) -c $< -o $@
WHISPER_OBJ += ggml-alloc.o
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View File

@@ -12,6 +12,7 @@
#endif
#include "ggml.h"
#include "ggml-alloc.h"
#include <algorithm>
#include <cassert>
@@ -119,9 +120,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
//#define WHISPER_USE_FLASH_FF
#define WHISPER_MAX_DECODERS 16
#define WHISPER_USE_SCRATCH
#define WHISPER_MAX_SCRATCH_BUFFERS 16
// available whisper models
enum e_model {
MODEL_UNKNOWN,
@@ -236,38 +234,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
static const size_t MB = 1ull*1024*1024;
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
{ MODEL_TINY, 62ull*MB },
{ MODEL_BASE, 80ull*MB },
{ MODEL_SMALL, 120ull*MB },
{ MODEL_MEDIUM, 158ull*MB },
{ MODEL_LARGE, 198ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
{ MODEL_TINY, 18ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 36ull*MB },
{ MODEL_MEDIUM, 48ull*MB },
{ MODEL_LARGE, 60ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
{ MODEL_TINY, 4ull*MB },
{ MODEL_BASE, 4ull*MB },
{ MODEL_SMALL, 6ull*MB },
{ MODEL_MEDIUM, 7ull*MB },
{ MODEL_LARGE, 9ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
{ MODEL_TINY, 4ull*MB },
{ MODEL_BASE, 4ull*MB },
{ MODEL_SMALL, 6ull*MB },
{ MODEL_MEDIUM, 7ull*MB },
{ MODEL_LARGE, 9ull*MB },
};
// TODO: avoid using GGUF
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
{ GGML_TYPE_F32,
{
@@ -334,38 +301,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
},
};
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
{ MODEL_TINY, 3ull*MB },
{ MODEL_BASE, 6ull*MB },
{ MODEL_SMALL, 16ull*MB },
{ MODEL_MEDIUM, 43ull*MB },
{ MODEL_LARGE, 71ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
{ MODEL_TINY, 9ull*MB },
{ MODEL_BASE, 18ull*MB },
{ MODEL_SMALL, 53ull*MB },
{ MODEL_MEDIUM, 141ull*MB },
{ MODEL_LARGE, 235ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
{ MODEL_TINY, 30ull*MB },
{ MODEL_BASE, 38ull*MB },
{ MODEL_SMALL, 56ull*MB },
{ MODEL_MEDIUM, 74ull*MB },
{ MODEL_LARGE, 94ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
{ MODEL_TINY, 3ull*MB },
{ MODEL_BASE, 5ull*MB },
{ MODEL_SMALL, 10ull*MB },
{ MODEL_MEDIUM, 18ull*MB },
{ MODEL_LARGE, 27ull*MB },
};
struct whisper_mel {
int n_len;
int n_len_org;
@@ -670,10 +605,12 @@ struct whisper_state {
// memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
int buf_last = 0;
size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
std::vector<uint8_t> buf_encode;
std::vector<uint8_t> buf_decode;
ggml_allocr * alloc_encode = NULL;
ggml_allocr * alloc_decode = NULL;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@@ -709,37 +646,6 @@ struct whisper_state {
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx = 0; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
size_t last_size = 0;
if (i == -1) {
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
} else {
auto & buf = buf_scratch[i];
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
}
if (buf_last >= 0) {
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
}
buf_last = i;
#else
(void) i;
(void) ctx;
#endif
}
size_t get_buf_max_mem(int i) const {
#if defined(WHISPER_USE_SCRATCH)
return buf_max_size[i];
#else
(void) i;
return 0;
#endif
}
};
struct whisper_context {
@@ -786,10 +692,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
static bool kv_cache_init(
const struct whisper_hparams & hparams,
const size_t mem_bytes,
struct whisper_kv_cache & cache,
ggml_type wtype,
int n_ctx) {
const int64_t n_text_state = hparams.n_text_state;
const int64_t n_text_layer = hparams.n_text_layer;
const int64_t n_mem = n_text_layer*n_ctx;
const int64_t n_elements = n_text_state*n_mem;
const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
cache.buf.resize(mem_bytes);
struct ggml_init_params params = {
@@ -805,12 +718,6 @@ static bool kv_cache_init(
return false;
}
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_mem = n_text_layer*n_ctx;
const int n_elements = n_text_state*n_mem;
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
@@ -953,22 +860,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// print memory requirements
{
// this is the total memory required to run the inference
const size_t mem_required =
MEM_REQ_SCRATCH0.at(model.type) +
MEM_REQ_SCRATCH1.at(model.type) +
MEM_REQ_SCRATCH2.at(model.type) +
MEM_REQ_SCRATCH3.at(model.type) +
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
// this is the memory required by one decoder
const size_t mem_required_decoder =
scale*MEM_REQ_KV_SELF.at(model.type);
log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
// TODO
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
}
// initialize all memory buffers
@@ -1477,24 +1371,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return true;
}
// evaluate the encoder with the given state
//
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
// part of the transformer model and returns the encoded features
//
// - wctx: the model
// - wstate: the state of the encoder
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
static bool whisper_encode_internal(
static struct ggml_cgraph * whisper_build_graph_encoder(
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
const int n_threads){
const int64_t t_start_us = ggml_time_us();
const int mel_offset) {
const auto & model = wctx.model;
const auto & mel_inp = wstate.mel;
const auto & hparams = model.hparams;
@@ -1510,12 +1390,12 @@ static bool whisper_encode_internal(
struct ggml_init_params params = {
/*.mem_size =*/ wstate.buf_compute.size(),
/*.mem_buffer =*/ wstate.buf_compute.data(),
/*.no_alloc =*/ false,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
wstate.use_buf(ctx0, 0);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
assert(mel->type == GGML_TYPE_F32);
@@ -1550,8 +1430,6 @@ static bool whisper_encode_internal(
if (!use_coreml && !use_openvino) {
// convolution + gelu
{
wstate.use_buf(ctx0, 1);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
@@ -1561,8 +1439,6 @@ static bool whisper_encode_internal(
cur = ggml_gelu(ctx0, cur);
wstate.use_buf(ctx0, 0);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
@@ -1573,8 +1449,6 @@ static bool whisper_encode_internal(
cur = ggml_gelu(ctx0, cur);
}
wstate.use_buf(ctx0, 3);
// ===================================================================
// NOTE: experimenting with partial evaluation of the encoder (ignore)
//static int iter = -1;
@@ -1608,8 +1482,6 @@ static bool whisper_encode_internal(
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
@@ -1622,8 +1494,6 @@ static bool whisper_encode_internal(
// self-attention
{
wstate.use_buf(ctx0, 1);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
@@ -1655,8 +1525,6 @@ static bool whisper_encode_internal(
// ------
wstate.use_buf(ctx0, 0);
#ifdef WHISPER_USE_FLASH_ATTN
struct ggml_tensor * Q =
ggml_permute(ctx0,
@@ -1722,8 +1590,6 @@ static bool whisper_encode_internal(
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
wstate.use_buf(ctx0, 1);
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
@@ -1731,21 +1597,15 @@ static bool whisper_encode_internal(
// projection
{
wstate.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
wstate.use_buf(ctx0, 2);
// add the input
cur = ggml_add(ctx0, cur, inpL);
@@ -1755,12 +1615,8 @@ static bool whisper_encode_internal(
{
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpFF, hparams.eps);
wstate.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@@ -1770,47 +1626,33 @@ static bool whisper_encode_internal(
}
#ifdef WHISPER_USE_FLASH_FF
wstate.use_buf(ctx0, 0);
cur = ggml_flash_ff(ctx0,
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
wstate.use_buf(ctx0, 0);
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
wstate.use_buf(ctx0, 0);
// GELU activation
cur = ggml_gelu(ctx0, cur);
wstate.use_buf(ctx0, 1);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
wstate.use_buf(ctx0, 0);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
#endif
}
wstate.use_buf(ctx0, 3);
inpL = ggml_add(ctx0, cur, inpFF);
}
@@ -1818,12 +1660,8 @@ static bool whisper_encode_internal(
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, cur, hparams.eps);
wstate.use_buf(ctx0, 1);
// cur = ln_f_g*cur + ln_f_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@@ -1832,22 +1670,10 @@ static bool whisper_encode_internal(
ggml_repeat(ctx0, model.e_ln_b, cur));
}
wstate.use_buf(ctx0, -1);
// run the computation
{
struct ggml_cgraph gf = {};
ggml_build_forward_expand (&gf, cur);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//ggml_graph_print(&gf);
}
ggml_build_forward_expand (gf, cur);
}
#ifdef WHISPER_USE_COREML
else if (use_coreml) {
wstate.use_buf(ctx0, -1);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
@@ -1855,8 +1681,6 @@ static bool whisper_encode_internal(
#endif
#ifdef WHISPER_USE_OPENVINO
else if (use_openvino) {
wstate.use_buf(ctx0, -1);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
@@ -1865,69 +1689,6 @@ static bool whisper_encode_internal(
}
#endif
// cur
//{
// printf("ne0 = %d\n", cur->ne[0]);
// printf("ne1 = %d\n", cur->ne[1]);
// for (int i = 0; i < 10; ++i) {
// printf("%8.4f ", ((float *)(cur->data))[i]);
// }
// printf("... ");
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
// printf("%8.4f ", ((float *)(cur->data))[i]);
// }
// printf("\n");
//}
// pre-compute cross-attention memory
{
struct ggml_cgraph gf = {};
// TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE;
cur->src[0] = nullptr;
cur->src[1] = nullptr;
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto& layer = model.layers_decoder[il];
wstate.use_buf(ctx0, 0);
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
wstate.use_buf(ctx0, 1);
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);
Vcross = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.cross_attn_v_b,
Vcross),
Vcross);
wstate.use_buf(ctx0, -1);
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
}
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//ggml_graph_print(&gf);
}
////////////////////////////////////////////////////////////////////////////
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -1939,32 +1700,105 @@ static bool whisper_encode_internal(
ggml_free(ctx0);
return gf;
}
// pre-compute cross-attention memory
static struct ggml_cgraph * whisper_build_graph_encoder_post(
whisper_context & wctx,
whisper_state & wstate,
struct ggml_tensor * embd_enc) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
const int n_state = hparams.n_audio_state;
const int n_head = hparams.n_audio_head;
struct ggml_init_params params = {
/*.mem_size =*/ wstate.buf_compute.size(),
/*.mem_buffer =*/ wstate.buf_compute.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur = embd_enc;
// TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE;
cur->src[0] = nullptr;
cur->src[1] = nullptr;
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
layer.cross_attn_v_w,
cur);
Vcross = ggml_add(ctx0,
ggml_repeat(ctx0,
layer.cross_attn_v_b,
Vcross),
Vcross);
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
}
ggml_free(ctx0);
return gf;
}
// evaluate the encoder with the given state
//
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
// part of the transformer model and returns the encoded features
//
// - wctx: the model
// - wstate: the state of the encoder
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
static bool whisper_encode_internal(
whisper_context & wctx,
whisper_state & wstate,
const int mel_offset,
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
wstate.t_encode_us += ggml_time_us() - t_start_us;
wstate.n_encode++;
return true;
}
// evaluate the decoder
//
// given text prompt + audio features -> computes the logits for the next token
//
// - model: the model
// - n_threads: number of threads to use
// - tokens: text prompt
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_token * tokens,
const int n_tokens,
const int n_past,
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
static struct ggml_cgraph * whisper_build_graph_decoder(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_token * tokens,
int n_tokens,
int n_past) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
@@ -1972,10 +1806,6 @@ static bool whisper_decode_internal(
WHISPER_ASSERT(!!kv_self.ctx);
auto & logits_out = wstate.logits;
const int n_vocab = hparams.n_vocab;
const int n_ctx = hparams.n_text_ctx;
const int n_state = hparams.n_text_state;
const int n_head = hparams.n_text_head;
@@ -1989,12 +1819,12 @@ static bool whisper_decode_internal(
struct ggml_init_params params = {
/*.mem_size =*/ wstate.buf_compute.size(),
/*.mem_buffer =*/ wstate.buf_compute.data(),
/*.no_alloc =*/ false,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, tokens, N*ggml_element_size(embd));
@@ -2004,8 +1834,6 @@ static bool whisper_decode_internal(
((int32_t *) position->data)[i] = n_past + i;
}
wstate.use_buf(ctx0, 3);
// token encoding + position encoding
struct ggml_tensor * cur =
ggml_add(ctx0,
@@ -2019,8 +1847,6 @@ static bool whisper_decode_internal(
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
@@ -2071,14 +1897,12 @@ static bool whisper_decode_internal(
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
// ------
wstate.use_buf(ctx0, 0);
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
@@ -2093,8 +1917,6 @@ static bool whisper_decode_internal(
n_state/n_head, n_head, n_past + N),
0, 2, 1, 3);
wstate.use_buf(ctx0, 1);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2126,28 +1948,20 @@ static bool whisper_decode_internal(
// projection
{
wstate.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.attn_ln_1_w,
cur);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
cur);
}
wstate.use_buf(ctx0, 2);
// add the input
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
// cur = ln_0_w*cur + ln_0_b
@@ -2232,21 +2046,15 @@ static bool whisper_decode_internal(
// projection
{
wstate.use_buf(ctx0, 0);
cur = ggml_mul_mat(ctx0,
layer.cross_attn_ln_1_w,
cur);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
cur);
}
wstate.use_buf(ctx0, 2);
// add the input
cur = ggml_add(ctx0, cur, inpCA);
@@ -2256,12 +2064,8 @@ static bool whisper_decode_internal(
{
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, inpFF, hparams.eps);
wstate.use_buf(ctx0, 1);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@@ -2270,40 +2074,28 @@ static bool whisper_decode_internal(
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
}
wstate.use_buf(ctx0, 0);
// fully connected
cur = ggml_mul_mat(ctx0,
layer.mlp_0_w,
cur);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_0_b, cur),
cur);
wstate.use_buf(ctx0, 0);
// GELU activation
cur = ggml_gelu(ctx0, cur);
wstate.use_buf(ctx0, 1);
// projection
cur = ggml_mul_mat(ctx0,
layer.mlp_1_w,
cur);
wstate.use_buf(ctx0, 0);
cur = ggml_add(ctx0,
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
}
wstate.use_buf(ctx0, 3);
inpL = ggml_add(ctx0, cur, inpFF);
}
@@ -2311,12 +2103,8 @@ static bool whisper_decode_internal(
// norm
{
wstate.use_buf(ctx0, 0);
cur = ggml_norm(ctx0, cur, hparams.eps);
wstate.use_buf(ctx0, 1);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.d_ln_w, cur),
@@ -2324,8 +2112,6 @@ static bool whisper_decode_internal(
ggml_repeat(ctx0, model.d_ln_b, cur));
}
wstate.use_buf(ctx0, 0);
// compute logits only for the last token
// comment this line to compute logits for all N tokens
// might be useful in the future
@@ -2333,39 +2119,63 @@ static bool whisper_decode_internal(
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
wstate.use_buf(ctx0, -1);
// run the computation
{
ggml_build_forward_expand (&gf, logits);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
// extract logits for all N tokens
//logits_out.resize(N*n_vocab);
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
// extract logits only for the last token
logits_out.resize(n_vocab);
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
if (N > 1) {
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// ggml_used_mem(ctx0)/1024.0/1024.0,
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
}
ggml_build_forward_expand(gf, logits);
ggml_free(ctx0);
wstate.t_decode_us += ggml_time_us() - t_start_us;
wstate.n_decode++;
return gf;
}
// evaluate the decoder
//
// given text prompt + audio features -> computes the logits for the next token
//
// - model: the model
// - n_threads: number of threads to use
// - tokens: text prompt
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
static bool whisper_decode_internal(
whisper_context & wctx,
whisper_state & wstate,
whisper_decoder & decoder,
const whisper_token * tokens,
const int n_tokens,
const int n_past,
const int n_threads) {
//const int64_t t_start_us = ggml_time_us();
//auto & logits_out = wstate.logits;
//const int n_vocab = hparams.n_vocab;
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//// extract logits for all N tokens
////logits_out.resize(N*n_vocab);
////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
//// extract logits only for the last token
//logits_out.resize(n_vocab);
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
//if (N > 1) {
// //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
// // ggml_used_mem(ctx0)/1024.0/1024.0,
// // wstate.get_buf_max_mem(0)/1024.0/1024.0,
// // wstate.get_buf_max_mem(1)/1024.0/1024.0,
// // wstate.get_buf_max_mem(2)/1024.0/1024.0,
// // wstate.get_buf_max_mem(3)/1024.0/1024.0);
//}
//wstate.t_decode_us += ggml_time_us() - t_start_us;
//wstate.n_decode++;
return true;
}
// 500 -> 00:05.000
// 6000 -> 01:00.000
static std::string to_timestamp(int64_t t, bool comma = false) {
@@ -2774,9 +2584,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
fill_sin_cos_table();
whisper_state * state = new whisper_state;
const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
delete state;
return nullptr;
@@ -2787,7 +2595,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
delete state;
return nullptr;
@@ -2825,12 +2633,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
static const size_t tensor_alignment = 32;
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
state->rng = std::mt19937(0);