mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-19 16:06:25 +02:00
whisper : factor out graph builds
This commit is contained in:
5
Makefile
5
Makefile
@@ -295,6 +295,11 @@ $(info )
|
||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
|
||||
$(CC) $(CFLAGS) -c $< -o $@
|
||||
|
||||
WHISPER_OBJ += ggml-alloc.o
|
||||
|
||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
||||
|
546
whisper.cpp
546
whisper.cpp
@@ -12,6 +12,7 @@
|
||||
#endif
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-alloc.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -119,9 +120,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
||||
//#define WHISPER_USE_FLASH_FF
|
||||
#define WHISPER_MAX_DECODERS 16
|
||||
|
||||
#define WHISPER_USE_SCRATCH
|
||||
#define WHISPER_MAX_SCRATCH_BUFFERS 16
|
||||
|
||||
// available whisper models
|
||||
enum e_model {
|
||||
MODEL_UNKNOWN,
|
||||
@@ -236,38 +234,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
||||
|
||||
static const size_t MB = 1ull*1024*1024;
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
|
||||
{ MODEL_TINY, 62ull*MB },
|
||||
{ MODEL_BASE, 80ull*MB },
|
||||
{ MODEL_SMALL, 120ull*MB },
|
||||
{ MODEL_MEDIUM, 158ull*MB },
|
||||
{ MODEL_LARGE, 198ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
||||
{ MODEL_TINY, 18ull*MB },
|
||||
{ MODEL_BASE, 24ull*MB },
|
||||
{ MODEL_SMALL, 36ull*MB },
|
||||
{ MODEL_MEDIUM, 48ull*MB },
|
||||
{ MODEL_LARGE, 60ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
|
||||
{ MODEL_TINY, 4ull*MB },
|
||||
{ MODEL_BASE, 4ull*MB },
|
||||
{ MODEL_SMALL, 6ull*MB },
|
||||
{ MODEL_MEDIUM, 7ull*MB },
|
||||
{ MODEL_LARGE, 9ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
||||
{ MODEL_TINY, 4ull*MB },
|
||||
{ MODEL_BASE, 4ull*MB },
|
||||
{ MODEL_SMALL, 6ull*MB },
|
||||
{ MODEL_MEDIUM, 7ull*MB },
|
||||
{ MODEL_LARGE, 9ull*MB },
|
||||
};
|
||||
|
||||
// TODO: avoid using GGUF
|
||||
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
||||
{ GGML_TYPE_F32,
|
||||
{
|
||||
@@ -334,38 +301,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
||||
},
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
||||
{ MODEL_TINY, 3ull*MB },
|
||||
{ MODEL_BASE, 6ull*MB },
|
||||
{ MODEL_SMALL, 16ull*MB },
|
||||
{ MODEL_MEDIUM, 43ull*MB },
|
||||
{ MODEL_LARGE, 71ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
||||
{ MODEL_TINY, 9ull*MB },
|
||||
{ MODEL_BASE, 18ull*MB },
|
||||
{ MODEL_SMALL, 53ull*MB },
|
||||
{ MODEL_MEDIUM, 141ull*MB },
|
||||
{ MODEL_LARGE, 235ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
||||
{ MODEL_TINY, 30ull*MB },
|
||||
{ MODEL_BASE, 38ull*MB },
|
||||
{ MODEL_SMALL, 56ull*MB },
|
||||
{ MODEL_MEDIUM, 74ull*MB },
|
||||
{ MODEL_LARGE, 94ull*MB },
|
||||
};
|
||||
|
||||
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
||||
{ MODEL_TINY, 3ull*MB },
|
||||
{ MODEL_BASE, 5ull*MB },
|
||||
{ MODEL_SMALL, 10ull*MB },
|
||||
{ MODEL_MEDIUM, 18ull*MB },
|
||||
{ MODEL_LARGE, 27ull*MB },
|
||||
};
|
||||
|
||||
struct whisper_mel {
|
||||
int n_len;
|
||||
int n_len_org;
|
||||
@@ -670,10 +605,12 @@ struct whisper_state {
|
||||
|
||||
// memory buffers used by encode / decode contexts
|
||||
std::vector<uint8_t> buf_compute;
|
||||
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
||||
|
||||
int buf_last = 0;
|
||||
size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
|
||||
std::vector<uint8_t> buf_encode;
|
||||
std::vector<uint8_t> buf_decode;
|
||||
|
||||
ggml_allocr * alloc_encode = NULL;
|
||||
ggml_allocr * alloc_decode = NULL;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
std::vector<float> logits;
|
||||
@@ -709,37 +646,6 @@ struct whisper_state {
|
||||
|
||||
// [EXPERIMENTAL] speed-up techniques
|
||||
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
||||
|
||||
void use_buf(struct ggml_context * ctx, int i) {
|
||||
#if defined(WHISPER_USE_SCRATCH)
|
||||
size_t last_size = 0;
|
||||
|
||||
if (i == -1) {
|
||||
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
||||
} else {
|
||||
auto & buf = buf_scratch[i];
|
||||
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
||||
}
|
||||
|
||||
if (buf_last >= 0) {
|
||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
||||
}
|
||||
|
||||
buf_last = i;
|
||||
#else
|
||||
(void) i;
|
||||
(void) ctx;
|
||||
#endif
|
||||
}
|
||||
|
||||
size_t get_buf_max_mem(int i) const {
|
||||
#if defined(WHISPER_USE_SCRATCH)
|
||||
return buf_max_size[i];
|
||||
#else
|
||||
(void) i;
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
struct whisper_context {
|
||||
@@ -786,10 +692,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
||||
|
||||
static bool kv_cache_init(
|
||||
const struct whisper_hparams & hparams,
|
||||
const size_t mem_bytes,
|
||||
struct whisper_kv_cache & cache,
|
||||
ggml_type wtype,
|
||||
int n_ctx) {
|
||||
const int64_t n_text_state = hparams.n_text_state;
|
||||
const int64_t n_text_layer = hparams.n_text_layer;
|
||||
|
||||
const int64_t n_mem = n_text_layer*n_ctx;
|
||||
const int64_t n_elements = n_text_state*n_mem;
|
||||
|
||||
const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
|
||||
|
||||
cache.buf.resize(mem_bytes);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
@@ -805,12 +718,6 @@ static bool kv_cache_init(
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_text_state = hparams.n_text_state;
|
||||
const int n_text_layer = hparams.n_text_layer;
|
||||
|
||||
const int n_mem = n_text_layer*n_ctx;
|
||||
const int n_elements = n_text_state*n_mem;
|
||||
|
||||
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
|
||||
@@ -953,22 +860,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
|
||||
// print memory requirements
|
||||
{
|
||||
// this is the total memory required to run the inference
|
||||
const size_t mem_required =
|
||||
MEM_REQ_SCRATCH0.at(model.type) +
|
||||
MEM_REQ_SCRATCH1.at(model.type) +
|
||||
MEM_REQ_SCRATCH2.at(model.type) +
|
||||
MEM_REQ_SCRATCH3.at(model.type) +
|
||||
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
||||
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
||||
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
||||
|
||||
// this is the memory required by one decoder
|
||||
const size_t mem_required_decoder =
|
||||
scale*MEM_REQ_KV_SELF.at(model.type);
|
||||
|
||||
log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
||||
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
||||
// TODO
|
||||
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
||||
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
// initialize all memory buffers
|
||||
@@ -1477,24 +1371,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
||||
return true;
|
||||
}
|
||||
|
||||
// evaluate the encoder with the given state
|
||||
//
|
||||
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
||||
// part of the transformer model and returns the encoded features
|
||||
//
|
||||
// - wctx: the model
|
||||
// - wstate: the state of the encoder
|
||||
// - n_threads: number of threads to use
|
||||
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
||||
//
|
||||
static bool whisper_encode_internal(
|
||||
static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
const int mel_offset,
|
||||
const int n_threads){
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
const int mel_offset) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & mel_inp = wstate.mel;
|
||||
const auto & hparams = model.hparams;
|
||||
@@ -1510,12 +1390,12 @@ static bool whisper_encode_internal(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
||||
assert(mel->type == GGML_TYPE_F32);
|
||||
@@ -1550,8 +1430,6 @@ static bool whisper_encode_internal(
|
||||
if (!use_coreml && !use_openvino) {
|
||||
// convolution + gelu
|
||||
{
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
@@ -1561,8 +1439,6 @@ static bool whisper_encode_internal(
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
@@ -1573,8 +1449,6 @@ static bool whisper_encode_internal(
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 3);
|
||||
|
||||
// ===================================================================
|
||||
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
||||
//static int iter = -1;
|
||||
@@ -1608,8 +1482,6 @@ static bool whisper_encode_internal(
|
||||
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
||||
|
||||
// cur = ln_0_w*cur + ln_0_b
|
||||
@@ -1622,8 +1494,6 @@ static bool whisper_encode_internal(
|
||||
|
||||
// self-attention
|
||||
{
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||
layer.attn_q_w,
|
||||
cur);
|
||||
@@ -1655,8 +1525,6 @@ static bool whisper_encode_internal(
|
||||
|
||||
// ------
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
#ifdef WHISPER_USE_FLASH_ATTN
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
@@ -1722,8 +1590,6 @@ static bool whisper_encode_internal(
|
||||
#endif
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
||||
@@ -1731,21 +1597,15 @@ static bool whisper_encode_internal(
|
||||
|
||||
// projection
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 2);
|
||||
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
@@ -1755,12 +1615,8 @@ static bool whisper_encode_internal(
|
||||
{
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
@@ -1770,47 +1626,33 @@ static bool whisper_encode_internal(
|
||||
}
|
||||
|
||||
#ifdef WHISPER_USE_FLASH_FF
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_flash_ff(ctx0,
|
||||
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||
#else
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
// fully connected
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_0_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
// GELU activation
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// projection
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_1_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
||||
cur);
|
||||
#endif
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 3);
|
||||
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
@@ -1818,12 +1660,8 @@ static bool whisper_encode_internal(
|
||||
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// cur = ln_f_g*cur + ln_f_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
@@ -1832,22 +1670,10 @@ static bool whisper_encode_internal(
|
||||
ggml_repeat(ctx0, model.e_ln_b, cur));
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, -1);
|
||||
|
||||
// run the computation
|
||||
{
|
||||
struct ggml_cgraph gf = {};
|
||||
|
||||
ggml_build_forward_expand (&gf, cur);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
//ggml_graph_print(&gf);
|
||||
}
|
||||
ggml_build_forward_expand (gf, cur);
|
||||
}
|
||||
#ifdef WHISPER_USE_COREML
|
||||
else if (use_coreml) {
|
||||
wstate.use_buf(ctx0, -1);
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||
|
||||
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
||||
@@ -1855,8 +1681,6 @@ static bool whisper_encode_internal(
|
||||
#endif
|
||||
#ifdef WHISPER_USE_OPENVINO
|
||||
else if (use_openvino) {
|
||||
wstate.use_buf(ctx0, -1);
|
||||
|
||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||
|
||||
if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
|
||||
@@ -1865,69 +1689,6 @@ static bool whisper_encode_internal(
|
||||
}
|
||||
#endif
|
||||
|
||||
// cur
|
||||
//{
|
||||
// printf("ne0 = %d\n", cur->ne[0]);
|
||||
// printf("ne1 = %d\n", cur->ne[1]);
|
||||
// for (int i = 0; i < 10; ++i) {
|
||||
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
||||
// }
|
||||
// printf("... ");
|
||||
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
||||
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
||||
// }
|
||||
// printf("\n");
|
||||
//}
|
||||
|
||||
// pre-compute cross-attention memory
|
||||
{
|
||||
struct ggml_cgraph gf = {};
|
||||
|
||||
// TODO: hack to disconnect the encoded features from the previous graph
|
||||
cur->op = GGML_OP_NONE;
|
||||
cur->src[0] = nullptr;
|
||||
cur->src[1] = nullptr;
|
||||
|
||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||
auto& layer = model.layers_decoder[il];
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_k_w,
|
||||
cur);
|
||||
|
||||
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_v_w,
|
||||
cur);
|
||||
|
||||
Vcross = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.cross_attn_v_b,
|
||||
Vcross),
|
||||
Vcross);
|
||||
|
||||
wstate.use_buf(ctx0, -1);
|
||||
|
||||
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
||||
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
||||
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
||||
}
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
//ggml_graph_print(&gf);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||
@@ -1939,32 +1700,105 @@ static bool whisper_encode_internal(
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
// pre-compute cross-attention memory
|
||||
static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
struct ggml_tensor * embd_enc) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
||||
const int n_state = hparams.n_audio_state;
|
||||
const int n_head = hparams.n_audio_head;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur = embd_enc;
|
||||
|
||||
// TODO: hack to disconnect the encoded features from the previous graph
|
||||
cur->op = GGML_OP_NONE;
|
||||
cur->src[0] = nullptr;
|
||||
cur->src[1] = nullptr;
|
||||
|
||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||
auto & layer = model.layers_decoder[il];
|
||||
|
||||
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_k_w,
|
||||
cur);
|
||||
|
||||
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
||||
|
||||
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_v_w,
|
||||
cur);
|
||||
|
||||
Vcross = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
layer.cross_attn_v_b,
|
||||
Vcross),
|
||||
Vcross);
|
||||
|
||||
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
||||
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
||||
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
||||
}
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
// evaluate the encoder with the given state
|
||||
//
|
||||
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
||||
// part of the transformer model and returns the encoded features
|
||||
//
|
||||
// - wctx: the model
|
||||
// - wstate: the state of the encoder
|
||||
// - n_threads: number of threads to use
|
||||
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
||||
//
|
||||
static bool whisper_encode_internal(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
const int mel_offset,
|
||||
const int n_threads) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_encode++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// evaluate the decoder
|
||||
//
|
||||
// given text prompt + audio features -> computes the logits for the next token
|
||||
//
|
||||
// - model: the model
|
||||
// - n_threads: number of threads to use
|
||||
// - tokens: text prompt
|
||||
// - n_tokens: number of tokens in the prompt
|
||||
// - n_past: number of past tokens to prefix the prompt with
|
||||
//
|
||||
static bool whisper_decode_internal(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
whisper_decoder & decoder,
|
||||
const whisper_token * tokens,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const int n_threads) {
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
whisper_decoder & decoder,
|
||||
const whisper_token * tokens,
|
||||
int n_tokens,
|
||||
int n_past) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@@ -1972,10 +1806,6 @@ static bool whisper_decode_internal(
|
||||
|
||||
WHISPER_ASSERT(!!kv_self.ctx);
|
||||
|
||||
auto & logits_out = wstate.logits;
|
||||
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
|
||||
const int n_ctx = hparams.n_text_ctx;
|
||||
const int n_state = hparams.n_text_state;
|
||||
const int n_head = hparams.n_text_head;
|
||||
@@ -1989,12 +1819,12 @@ static bool whisper_decode_internal(
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.no_alloc =*/ false,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_cgraph gf = {};
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||
@@ -2004,8 +1834,6 @@ static bool whisper_decode_internal(
|
||||
((int32_t *) position->data)[i] = n_past + i;
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 3);
|
||||
|
||||
// token encoding + position encoding
|
||||
struct ggml_tensor * cur =
|
||||
ggml_add(ctx0,
|
||||
@@ -2019,8 +1847,6 @@ static bool whisper_decode_internal(
|
||||
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
||||
|
||||
// cur = ln_0_w*cur + ln_0_b
|
||||
@@ -2071,14 +1897,12 @@ static bool whisper_decode_internal(
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
|
||||
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
// ------
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_cpy(ctx0,
|
||||
@@ -2093,8 +1917,6 @@ static bool whisper_decode_internal(
|
||||
n_state/n_head, n_head, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
@@ -2126,28 +1948,20 @@ static bool whisper_decode_internal(
|
||||
|
||||
// projection
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 2);
|
||||
|
||||
// add the input
|
||||
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
||||
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
|
||||
|
||||
// cur = ln_0_w*cur + ln_0_b
|
||||
@@ -2232,21 +2046,15 @@ static bool whisper_decode_internal(
|
||||
|
||||
// projection
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.cross_attn_ln_1_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 2);
|
||||
|
||||
// add the input
|
||||
cur = ggml_add(ctx0, cur, inpCA);
|
||||
|
||||
@@ -2256,12 +2064,8 @@ static bool whisper_decode_internal(
|
||||
{
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
@@ -2270,40 +2074,28 @@ static bool whisper_decode_internal(
|
||||
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
// fully connected
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_0_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
// GELU activation
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
// projection
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
layer.mlp_1_w,
|
||||
cur);
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
||||
cur);
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 3);
|
||||
|
||||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
@@ -2311,12 +2103,8 @@ static bool whisper_decode_internal(
|
||||
|
||||
// norm
|
||||
{
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||
|
||||
wstate.use_buf(ctx0, 1);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.d_ln_w, cur),
|
||||
@@ -2324,8 +2112,6 @@ static bool whisper_decode_internal(
|
||||
ggml_repeat(ctx0, model.d_ln_b, cur));
|
||||
}
|
||||
|
||||
wstate.use_buf(ctx0, 0);
|
||||
|
||||
// compute logits only for the last token
|
||||
// comment this line to compute logits for all N tokens
|
||||
// might be useful in the future
|
||||
@@ -2333,39 +2119,63 @@ static bool whisper_decode_internal(
|
||||
|
||||
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
||||
|
||||
wstate.use_buf(ctx0, -1);
|
||||
|
||||
// run the computation
|
||||
{
|
||||
ggml_build_forward_expand (&gf, logits);
|
||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
}
|
||||
|
||||
// extract logits for all N tokens
|
||||
//logits_out.resize(N*n_vocab);
|
||||
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
|
||||
|
||||
// extract logits only for the last token
|
||||
logits_out.resize(n_vocab);
|
||||
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
||||
|
||||
if (N > 1) {
|
||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
||||
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
||||
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
||||
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
||||
}
|
||||
ggml_build_forward_expand(gf, logits);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
||||
wstate.n_decode++;
|
||||
return gf;
|
||||
}
|
||||
|
||||
// evaluate the decoder
|
||||
//
|
||||
// given text prompt + audio features -> computes the logits for the next token
|
||||
//
|
||||
// - model: the model
|
||||
// - n_threads: number of threads to use
|
||||
// - tokens: text prompt
|
||||
// - n_tokens: number of tokens in the prompt
|
||||
// - n_past: number of past tokens to prefix the prompt with
|
||||
//
|
||||
static bool whisper_decode_internal(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
whisper_decoder & decoder,
|
||||
const whisper_token * tokens,
|
||||
const int n_tokens,
|
||||
const int n_past,
|
||||
const int n_threads) {
|
||||
//const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
//auto & logits_out = wstate.logits;
|
||||
|
||||
//const int n_vocab = hparams.n_vocab;
|
||||
|
||||
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||
|
||||
//// extract logits for all N tokens
|
||||
////logits_out.resize(N*n_vocab);
|
||||
////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
|
||||
|
||||
//// extract logits only for the last token
|
||||
//logits_out.resize(n_vocab);
|
||||
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
||||
|
||||
//if (N > 1) {
|
||||
// //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||
// // ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||
// // wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
||||
// // wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
||||
// // wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
||||
// // wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
||||
//}
|
||||
|
||||
//wstate.t_decode_us += ggml_time_us() - t_start_us;
|
||||
//wstate.n_decode++;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
static std::string to_timestamp(int64_t t, bool comma = false) {
|
||||
@@ -2774,9 +2584,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
fill_sin_cos_table();
|
||||
whisper_state * state = new whisper_state;
|
||||
|
||||
const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
|
||||
|
||||
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
||||
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
||||
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
delete state;
|
||||
return nullptr;
|
||||
@@ -2787,7 +2595,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
||||
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
||||
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||
delete state;
|
||||
return nullptr;
|
||||
@@ -2825,12 +2633,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
||||
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
|
||||
|
||||
state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
|
||||
state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
|
||||
state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
|
||||
state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
|
||||
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
static const size_t tensor_alignment = 32;
|
||||
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
|
||||
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
state->rng = std::mt19937(0);
|
||||
|
||||
|
Reference in New Issue
Block a user