mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-23 06:11:15 +02:00
whisper : factor out graph builds
This commit is contained in:
5
Makefile
5
Makefile
@@ -295,6 +295,11 @@ $(info )
|
|||||||
ggml.o: ggml.c ggml.h ggml-cuda.h
|
ggml.o: ggml.c ggml.h ggml-cuda.h
|
||||||
$(CC) $(CFLAGS) -c $< -o $@
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
|
||||||
|
$(CC) $(CFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
WHISPER_OBJ += ggml-alloc.o
|
||||||
|
|
||||||
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
whisper.o: whisper.cpp whisper.h ggml.h ggml-cuda.h
|
||||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||||
|
|
||||||
|
538
whisper.cpp
538
whisper.cpp
@@ -12,6 +12,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#include "ggml-alloc.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@@ -119,9 +120,6 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|||||||
//#define WHISPER_USE_FLASH_FF
|
//#define WHISPER_USE_FLASH_FF
|
||||||
#define WHISPER_MAX_DECODERS 16
|
#define WHISPER_MAX_DECODERS 16
|
||||||
|
|
||||||
#define WHISPER_USE_SCRATCH
|
|
||||||
#define WHISPER_MAX_SCRATCH_BUFFERS 16
|
|
||||||
|
|
||||||
// available whisper models
|
// available whisper models
|
||||||
enum e_model {
|
enum e_model {
|
||||||
MODEL_UNKNOWN,
|
MODEL_UNKNOWN,
|
||||||
@@ -236,38 +234,7 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
|
|||||||
|
|
||||||
static const size_t MB = 1ull*1024*1024;
|
static const size_t MB = 1ull*1024*1024;
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
|
// TODO: avoid using GGUF
|
||||||
{ MODEL_TINY, 62ull*MB },
|
|
||||||
{ MODEL_BASE, 80ull*MB },
|
|
||||||
{ MODEL_SMALL, 120ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 158ull*MB },
|
|
||||||
{ MODEL_LARGE, 198ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
|
||||||
{ MODEL_TINY, 18ull*MB },
|
|
||||||
{ MODEL_BASE, 24ull*MB },
|
|
||||||
{ MODEL_SMALL, 36ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 48ull*MB },
|
|
||||||
{ MODEL_LARGE, 60ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
|
|
||||||
{ MODEL_TINY, 4ull*MB },
|
|
||||||
{ MODEL_BASE, 4ull*MB },
|
|
||||||
{ MODEL_SMALL, 6ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 7ull*MB },
|
|
||||||
{ MODEL_LARGE, 9ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
|
|
||||||
{ MODEL_TINY, 4ull*MB },
|
|
||||||
{ MODEL_BASE, 4ull*MB },
|
|
||||||
{ MODEL_SMALL, 6ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 7ull*MB },
|
|
||||||
{ MODEL_LARGE, 9ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
||||||
{ GGML_TYPE_F32,
|
{ GGML_TYPE_F32,
|
||||||
{
|
{
|
||||||
@@ -334,38 +301,6 @@ static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
|
||||||
{ MODEL_TINY, 3ull*MB },
|
|
||||||
{ MODEL_BASE, 6ull*MB },
|
|
||||||
{ MODEL_SMALL, 16ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 43ull*MB },
|
|
||||||
{ MODEL_LARGE, 71ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
|
|
||||||
{ MODEL_TINY, 9ull*MB },
|
|
||||||
{ MODEL_BASE, 18ull*MB },
|
|
||||||
{ MODEL_SMALL, 53ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 141ull*MB },
|
|
||||||
{ MODEL_LARGE, 235ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
|
|
||||||
{ MODEL_TINY, 30ull*MB },
|
|
||||||
{ MODEL_BASE, 38ull*MB },
|
|
||||||
{ MODEL_SMALL, 56ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 74ull*MB },
|
|
||||||
{ MODEL_LARGE, 94ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
static const std::map<e_model, size_t> MEM_REQ_DECODE = {
|
|
||||||
{ MODEL_TINY, 3ull*MB },
|
|
||||||
{ MODEL_BASE, 5ull*MB },
|
|
||||||
{ MODEL_SMALL, 10ull*MB },
|
|
||||||
{ MODEL_MEDIUM, 18ull*MB },
|
|
||||||
{ MODEL_LARGE, 27ull*MB },
|
|
||||||
};
|
|
||||||
|
|
||||||
struct whisper_mel {
|
struct whisper_mel {
|
||||||
int n_len;
|
int n_len;
|
||||||
int n_len_org;
|
int n_len_org;
|
||||||
@@ -670,10 +605,12 @@ struct whisper_state {
|
|||||||
|
|
||||||
// memory buffers used by encode / decode contexts
|
// memory buffers used by encode / decode contexts
|
||||||
std::vector<uint8_t> buf_compute;
|
std::vector<uint8_t> buf_compute;
|
||||||
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
|
||||||
|
|
||||||
int buf_last = 0;
|
std::vector<uint8_t> buf_encode;
|
||||||
size_t buf_max_size[WHISPER_MAX_SCRATCH_BUFFERS] = { 0 };
|
std::vector<uint8_t> buf_decode;
|
||||||
|
|
||||||
|
ggml_allocr * alloc_encode = NULL;
|
||||||
|
ggml_allocr * alloc_decode = NULL;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
@@ -709,37 +646,6 @@ struct whisper_state {
|
|||||||
|
|
||||||
// [EXPERIMENTAL] speed-up techniques
|
// [EXPERIMENTAL] speed-up techniques
|
||||||
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
int32_t exp_n_audio_ctx = 0; // 0 - use default
|
||||||
|
|
||||||
void use_buf(struct ggml_context * ctx, int i) {
|
|
||||||
#if defined(WHISPER_USE_SCRATCH)
|
|
||||||
size_t last_size = 0;
|
|
||||||
|
|
||||||
if (i == -1) {
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
|
||||||
} else {
|
|
||||||
auto & buf = buf_scratch[i];
|
|
||||||
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (buf_last >= 0) {
|
|
||||||
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
buf_last = i;
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
(void) ctx;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t get_buf_max_mem(int i) const {
|
|
||||||
#if defined(WHISPER_USE_SCRATCH)
|
|
||||||
return buf_max_size[i];
|
|
||||||
#else
|
|
||||||
(void) i;
|
|
||||||
return 0;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct whisper_context {
|
struct whisper_context {
|
||||||
@@ -786,10 +692,17 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|||||||
|
|
||||||
static bool kv_cache_init(
|
static bool kv_cache_init(
|
||||||
const struct whisper_hparams & hparams,
|
const struct whisper_hparams & hparams,
|
||||||
const size_t mem_bytes,
|
|
||||||
struct whisper_kv_cache & cache,
|
struct whisper_kv_cache & cache,
|
||||||
ggml_type wtype,
|
ggml_type wtype,
|
||||||
int n_ctx) {
|
int n_ctx) {
|
||||||
|
const int64_t n_text_state = hparams.n_text_state;
|
||||||
|
const int64_t n_text_layer = hparams.n_text_layer;
|
||||||
|
|
||||||
|
const int64_t n_mem = n_text_layer*n_ctx;
|
||||||
|
const int64_t n_elements = n_text_state*n_mem;
|
||||||
|
|
||||||
|
const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
|
||||||
|
|
||||||
cache.buf.resize(mem_bytes);
|
cache.buf.resize(mem_bytes);
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
@@ -805,12 +718,6 @@ static bool kv_cache_init(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_text_state = hparams.n_text_state;
|
|
||||||
const int n_text_layer = hparams.n_text_layer;
|
|
||||||
|
|
||||||
const int n_mem = n_text_layer*n_ctx;
|
|
||||||
const int n_elements = n_text_state*n_mem;
|
|
||||||
|
|
||||||
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||||
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||||
|
|
||||||
@@ -953,22 +860,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
|
|
||||||
// print memory requirements
|
// print memory requirements
|
||||||
{
|
{
|
||||||
// this is the total memory required to run the inference
|
// TODO
|
||||||
const size_t mem_required =
|
//log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
||||||
MEM_REQ_SCRATCH0.at(model.type) +
|
// mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
||||||
MEM_REQ_SCRATCH1.at(model.type) +
|
|
||||||
MEM_REQ_SCRATCH2.at(model.type) +
|
|
||||||
MEM_REQ_SCRATCH3.at(model.type) +
|
|
||||||
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
|
|
||||||
scale*MEM_REQ_KV_CROSS.at(model.type) +
|
|
||||||
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
|
|
||||||
|
|
||||||
// this is the memory required by one decoder
|
|
||||||
const size_t mem_required_decoder =
|
|
||||||
scale*MEM_REQ_KV_SELF.at(model.type);
|
|
||||||
|
|
||||||
log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__,
|
|
||||||
mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize all memory buffers
|
// initialize all memory buffers
|
||||||
@@ -1477,24 +1371,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the encoder with the given state
|
static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||||
//
|
|
||||||
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
|
||||||
// part of the transformer model and returns the encoded features
|
|
||||||
//
|
|
||||||
// - wctx: the model
|
|
||||||
// - wstate: the state of the encoder
|
|
||||||
// - n_threads: number of threads to use
|
|
||||||
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
|
||||||
//
|
|
||||||
static bool whisper_encode_internal(
|
|
||||||
whisper_context & wctx,
|
whisper_context & wctx,
|
||||||
whisper_state & wstate,
|
whisper_state & wstate,
|
||||||
const int mel_offset,
|
const int mel_offset) {
|
||||||
const int n_threads){
|
|
||||||
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
|
||||||
|
|
||||||
const auto & model = wctx.model;
|
const auto & model = wctx.model;
|
||||||
const auto & mel_inp = wstate.mel;
|
const auto & mel_inp = wstate.mel;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
@@ -1510,12 +1390,12 @@ static bool whisper_encode_internal(
|
|||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
||||||
assert(mel->type == GGML_TYPE_F32);
|
assert(mel->type == GGML_TYPE_F32);
|
||||||
@@ -1550,8 +1430,6 @@ static bool whisper_encode_internal(
|
|||||||
if (!use_coreml && !use_openvino) {
|
if (!use_coreml && !use_openvino) {
|
||||||
// convolution + gelu
|
// convolution + gelu
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0,
|
ggml_repeat(ctx0,
|
||||||
@@ -1561,8 +1439,6 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0,
|
ggml_repeat(ctx0,
|
||||||
@@ -1573,8 +1449,6 @@ static bool whisper_encode_internal(
|
|||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 3);
|
|
||||||
|
|
||||||
// ===================================================================
|
// ===================================================================
|
||||||
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
||||||
//static int iter = -1;
|
//static int iter = -1;
|
||||||
@@ -1608,8 +1482,6 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
||||||
|
|
||||||
// cur = ln_0_w*cur + ln_0_b
|
// cur = ln_0_w*cur + ln_0_b
|
||||||
@@ -1622,8 +1494,6 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
// self-attention
|
// self-attention
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_q_w,
|
layer.attn_q_w,
|
||||||
cur);
|
cur);
|
||||||
@@ -1655,8 +1525,6 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
// ------
|
// ------
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
#ifdef WHISPER_USE_FLASH_ATTN
|
#ifdef WHISPER_USE_FLASH_ATTN
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
@@ -1722,8 +1590,6 @@ static bool whisper_encode_internal(
|
|||||||
#endif
|
#endif
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cpy(ctx0,
|
||||||
KQV_merged,
|
KQV_merged,
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
||||||
@@ -1731,21 +1597,15 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_ln_1_w,
|
layer.attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
||||||
cur);
|
cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 2);
|
|
||||||
|
|
||||||
// add the input
|
// add the input
|
||||||
cur = ggml_add(ctx0, cur, inpL);
|
cur = ggml_add(ctx0, cur, inpL);
|
||||||
|
|
||||||
@@ -1755,12 +1615,8 @@ static bool whisper_encode_internal(
|
|||||||
{
|
{
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
@@ -1770,47 +1626,33 @@ static bool whisper_encode_internal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef WHISPER_USE_FLASH_FF
|
#ifdef WHISPER_USE_FLASH_FF
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_flash_ff(ctx0,
|
cur = ggml_flash_ff(ctx0,
|
||||||
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
||||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
||||||
#else
|
#else
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// fully connected
|
// fully connected
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_0_w,
|
layer.mlp_0_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// GELU activation
|
// GELU activation
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_1_w,
|
layer.mlp_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
||||||
cur);
|
cur);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 3);
|
|
||||||
|
|
||||||
inpL = ggml_add(ctx0, cur, inpFF);
|
inpL = ggml_add(ctx0, cur, inpFF);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1818,12 +1660,8 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// cur = ln_f_g*cur + ln_f_b
|
// cur = ln_f_g*cur + ln_f_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
@@ -1832,22 +1670,10 @@ static bool whisper_encode_internal(
|
|||||||
ggml_repeat(ctx0, model.e_ln_b, cur));
|
ggml_repeat(ctx0, model.e_ln_b, cur));
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, -1);
|
ggml_build_forward_expand (gf, cur);
|
||||||
|
|
||||||
// run the computation
|
|
||||||
{
|
|
||||||
struct ggml_cgraph gf = {};
|
|
||||||
|
|
||||||
ggml_build_forward_expand (&gf, cur);
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
||||||
|
|
||||||
//ggml_graph_print(&gf);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
#ifdef WHISPER_USE_COREML
|
#ifdef WHISPER_USE_COREML
|
||||||
else if (use_coreml) {
|
else if (use_coreml) {
|
||||||
wstate.use_buf(ctx0, -1);
|
|
||||||
|
|
||||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||||
|
|
||||||
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
||||||
@@ -1855,8 +1681,6 @@ static bool whisper_encode_internal(
|
|||||||
#endif
|
#endif
|
||||||
#ifdef WHISPER_USE_OPENVINO
|
#ifdef WHISPER_USE_OPENVINO
|
||||||
else if (use_openvino) {
|
else if (use_openvino) {
|
||||||
wstate.use_buf(ctx0, -1);
|
|
||||||
|
|
||||||
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
|
||||||
|
|
||||||
if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
|
if (!whisper_openvino_encode(wstate.ctx_openvino, mel, cur)) {
|
||||||
@@ -1865,69 +1689,6 @@ static bool whisper_encode_internal(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// cur
|
|
||||||
//{
|
|
||||||
// printf("ne0 = %d\n", cur->ne[0]);
|
|
||||||
// printf("ne1 = %d\n", cur->ne[1]);
|
|
||||||
// for (int i = 0; i < 10; ++i) {
|
|
||||||
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
||||||
// }
|
|
||||||
// printf("... ");
|
|
||||||
// for (int i = cur->ne[0] - 10; i < cur->ne[0]; ++i) {
|
|
||||||
// printf("%8.4f ", ((float *)(cur->data))[i]);
|
|
||||||
// }
|
|
||||||
// printf("\n");
|
|
||||||
//}
|
|
||||||
|
|
||||||
// pre-compute cross-attention memory
|
|
||||||
{
|
|
||||||
struct ggml_cgraph gf = {};
|
|
||||||
|
|
||||||
// TODO: hack to disconnect the encoded features from the previous graph
|
|
||||||
cur->op = GGML_OP_NONE;
|
|
||||||
cur->src[0] = nullptr;
|
|
||||||
cur->src[1] = nullptr;
|
|
||||||
|
|
||||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
|
||||||
auto& layer = model.layers_decoder[il];
|
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
|
||||||
layer.cross_attn_k_w,
|
|
||||||
cur);
|
|
||||||
|
|
||||||
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
|
||||||
layer.cross_attn_v_w,
|
|
||||||
cur);
|
|
||||||
|
|
||||||
Vcross = ggml_add(ctx0,
|
|
||||||
ggml_repeat(ctx0,
|
|
||||||
layer.cross_attn_v_b,
|
|
||||||
Vcross),
|
|
||||||
Vcross);
|
|
||||||
|
|
||||||
wstate.use_buf(ctx0, -1);
|
|
||||||
|
|
||||||
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
|
||||||
|
|
||||||
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
|
||||||
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
|
||||||
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
|
||||||
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
|
||||||
|
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k));
|
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
||||||
//ggml_graph_print(&gf);
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||||
@@ -1939,32 +1700,105 @@ static bool whisper_encode_internal(
|
|||||||
|
|
||||||
ggml_free(ctx0);
|
ggml_free(ctx0);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pre-compute cross-attention memory
|
||||||
|
static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||||
|
whisper_context & wctx,
|
||||||
|
whisper_state & wstate,
|
||||||
|
struct ggml_tensor * embd_enc) {
|
||||||
|
const auto & model = wctx.model;
|
||||||
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
||||||
|
const int n_state = hparams.n_audio_state;
|
||||||
|
const int n_head = hparams.n_audio_head;
|
||||||
|
|
||||||
|
struct ggml_init_params params = {
|
||||||
|
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||||
|
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
|
|
||||||
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur = embd_enc;
|
||||||
|
|
||||||
|
// TODO: hack to disconnect the encoded features from the previous graph
|
||||||
|
cur->op = GGML_OP_NONE;
|
||||||
|
cur->src[0] = nullptr;
|
||||||
|
cur->src[1] = nullptr;
|
||||||
|
|
||||||
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||||
|
auto & layer = model.layers_decoder[il];
|
||||||
|
|
||||||
|
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
|
||||||
|
layer.cross_attn_k_w,
|
||||||
|
cur);
|
||||||
|
|
||||||
|
Kcross = ggml_scale_inplace(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state) / n_head, -0.25)));
|
||||||
|
|
||||||
|
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
|
||||||
|
layer.cross_attn_v_w,
|
||||||
|
cur);
|
||||||
|
|
||||||
|
Vcross = ggml_add(ctx0,
|
||||||
|
ggml_repeat(ctx0,
|
||||||
|
layer.cross_attn_v_b,
|
||||||
|
Vcross),
|
||||||
|
Vcross);
|
||||||
|
|
||||||
|
Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx));
|
||||||
|
|
||||||
|
struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx));
|
||||||
|
struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state,
|
||||||
|
( n_ctx)*ggml_element_size(wstate.kv_cross.v),
|
||||||
|
(il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k));
|
||||||
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v));
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_free(ctx0);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
|
// evaluate the encoder with the given state
|
||||||
|
//
|
||||||
|
// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder
|
||||||
|
// part of the transformer model and returns the encoded features
|
||||||
|
//
|
||||||
|
// - wctx: the model
|
||||||
|
// - wstate: the state of the encoder
|
||||||
|
// - n_threads: number of threads to use
|
||||||
|
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
||||||
|
//
|
||||||
|
static bool whisper_encode_internal(
|
||||||
|
whisper_context & wctx,
|
||||||
|
whisper_state & wstate,
|
||||||
|
const int mel_offset,
|
||||||
|
const int n_threads) {
|
||||||
|
const int64_t t_start_us = ggml_time_us();
|
||||||
|
|
||||||
|
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
wstate.t_encode_us += ggml_time_us() - t_start_us;
|
||||||
wstate.n_encode++;
|
wstate.n_encode++;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the decoder
|
static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||||
//
|
|
||||||
// given text prompt + audio features -> computes the logits for the next token
|
|
||||||
//
|
|
||||||
// - model: the model
|
|
||||||
// - n_threads: number of threads to use
|
|
||||||
// - tokens: text prompt
|
|
||||||
// - n_tokens: number of tokens in the prompt
|
|
||||||
// - n_past: number of past tokens to prefix the prompt with
|
|
||||||
//
|
|
||||||
static bool whisper_decode_internal(
|
|
||||||
whisper_context & wctx,
|
whisper_context & wctx,
|
||||||
whisper_state & wstate,
|
whisper_state & wstate,
|
||||||
whisper_decoder & decoder,
|
whisper_decoder & decoder,
|
||||||
const whisper_token * tokens,
|
const whisper_token * tokens,
|
||||||
const int n_tokens,
|
int n_tokens,
|
||||||
const int n_past,
|
int n_past) {
|
||||||
const int n_threads) {
|
|
||||||
const int64_t t_start_us = ggml_time_us();
|
|
||||||
|
|
||||||
const auto & model = wctx.model;
|
const auto & model = wctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
@@ -1972,10 +1806,6 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
WHISPER_ASSERT(!!kv_self.ctx);
|
WHISPER_ASSERT(!!kv_self.ctx);
|
||||||
|
|
||||||
auto & logits_out = wstate.logits;
|
|
||||||
|
|
||||||
const int n_vocab = hparams.n_vocab;
|
|
||||||
|
|
||||||
const int n_ctx = hparams.n_text_ctx;
|
const int n_ctx = hparams.n_text_ctx;
|
||||||
const int n_state = hparams.n_text_state;
|
const int n_state = hparams.n_text_state;
|
||||||
const int n_head = hparams.n_text_head;
|
const int n_head = hparams.n_text_head;
|
||||||
@@ -1989,12 +1819,12 @@ static bool whisper_decode_internal(
|
|||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||||
/*.no_alloc =*/ false,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_context * ctx0 = ggml_init(params);
|
struct ggml_context * ctx0 = ggml_init(params);
|
||||||
|
|
||||||
struct ggml_cgraph gf = {};
|
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||||
@@ -2004,8 +1834,6 @@ static bool whisper_decode_internal(
|
|||||||
((int32_t *) position->data)[i] = n_past + i;
|
((int32_t *) position->data)[i] = n_past + i;
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 3);
|
|
||||||
|
|
||||||
// token encoding + position encoding
|
// token encoding + position encoding
|
||||||
struct ggml_tensor * cur =
|
struct ggml_tensor * cur =
|
||||||
ggml_add(ctx0,
|
ggml_add(ctx0,
|
||||||
@@ -2019,8 +1847,6 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
cur = ggml_norm(ctx0, inpL, hparams.eps);
|
||||||
|
|
||||||
// cur = ln_0_w*cur + ln_0_b
|
// cur = ln_0_w*cur + ln_0_b
|
||||||
@@ -2071,14 +1897,12 @@ static bool whisper_decode_internal(
|
|||||||
( n_ctx)*ggml_element_size(kv_self.v),
|
( n_ctx)*ggml_element_size(kv_self.v),
|
||||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
|
(il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
|
||||||
|
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------
|
// ------
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
@@ -2093,8 +1917,6 @@ static bool whisper_decode_internal(
|
|||||||
n_state/n_head, n_head, n_past + N),
|
n_state/n_head, n_head, n_past + N),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||||
|
|
||||||
@@ -2126,28 +1948,20 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_ln_1_w,
|
layer.attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
ggml_repeat(ctx0, layer.attn_ln_1_b, cur),
|
||||||
cur);
|
cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 2);
|
|
||||||
|
|
||||||
// add the input
|
// add the input
|
||||||
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL);
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
|
cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
|
||||||
|
|
||||||
// cur = ln_0_w*cur + ln_0_b
|
// cur = ln_0_w*cur + ln_0_b
|
||||||
@@ -2232,21 +2046,15 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
// projection
|
// projection
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.cross_attn_ln_1_w,
|
layer.cross_attn_ln_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
ggml_repeat(ctx0, layer.cross_attn_ln_1_b, cur),
|
||||||
cur);
|
cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 2);
|
|
||||||
|
|
||||||
// add the input
|
// add the input
|
||||||
cur = ggml_add(ctx0, cur, inpCA);
|
cur = ggml_add(ctx0, cur, inpCA);
|
||||||
|
|
||||||
@@ -2256,12 +2064,8 @@ static bool whisper_decode_internal(
|
|||||||
{
|
{
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
cur = ggml_norm(ctx0, inpFF, hparams.eps);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// cur = mlp_ln_w*cur + mlp_ln_b
|
// cur = mlp_ln_w*cur + mlp_ln_b
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
@@ -2270,40 +2074,28 @@ static bool whisper_decode_internal(
|
|||||||
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// fully connected
|
// fully connected
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_0_w,
|
layer.mlp_0_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
ggml_repeat(ctx0, layer.mlp_0_b, cur),
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// GELU activation
|
// GELU activation
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
// projection
|
// projection
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_1_w,
|
layer.mlp_1_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
ggml_repeat(ctx0, layer.mlp_1_b, cur),
|
||||||
cur);
|
cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 3);
|
|
||||||
|
|
||||||
inpL = ggml_add(ctx0, cur, inpFF);
|
inpL = ggml_add(ctx0, cur, inpFF);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2311,12 +2103,8 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
cur = ggml_norm(ctx0, cur, hparams.eps);
|
cur = ggml_norm(ctx0, cur, hparams.eps);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 1);
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0,
|
cur = ggml_add(ctx0,
|
||||||
ggml_mul(ctx0,
|
ggml_mul(ctx0,
|
||||||
ggml_repeat(ctx0, model.d_ln_w, cur),
|
ggml_repeat(ctx0, model.d_ln_w, cur),
|
||||||
@@ -2324,8 +2112,6 @@ static bool whisper_decode_internal(
|
|||||||
ggml_repeat(ctx0, model.d_ln_b, cur));
|
ggml_repeat(ctx0, model.d_ln_b, cur));
|
||||||
}
|
}
|
||||||
|
|
||||||
wstate.use_buf(ctx0, 0);
|
|
||||||
|
|
||||||
// compute logits only for the last token
|
// compute logits only for the last token
|
||||||
// comment this line to compute logits for all N tokens
|
// comment this line to compute logits for all N tokens
|
||||||
// might be useful in the future
|
// might be useful in the future
|
||||||
@@ -2333,39 +2119,63 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
|
||||||
|
|
||||||
wstate.use_buf(ctx0, -1);
|
ggml_build_forward_expand(gf, logits);
|
||||||
|
|
||||||
// run the computation
|
|
||||||
{
|
|
||||||
ggml_build_forward_expand (&gf, logits);
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
|
||||||
}
|
|
||||||
|
|
||||||
// extract logits for all N tokens
|
|
||||||
//logits_out.resize(N*n_vocab);
|
|
||||||
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
|
|
||||||
|
|
||||||
// extract logits only for the last token
|
|
||||||
logits_out.resize(n_vocab);
|
|
||||||
memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
|
||||||
|
|
||||||
if (N > 1) {
|
|
||||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
|
||||||
// ggml_used_mem(ctx0)/1024.0/1024.0,
|
|
||||||
// wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
|
||||||
// wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
|
||||||
// wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
|
||||||
// wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_free(ctx0);
|
ggml_free(ctx0);
|
||||||
|
|
||||||
wstate.t_decode_us += ggml_time_us() - t_start_us;
|
return gf;
|
||||||
wstate.n_decode++;
|
}
|
||||||
|
|
||||||
|
// evaluate the decoder
|
||||||
|
//
|
||||||
|
// given text prompt + audio features -> computes the logits for the next token
|
||||||
|
//
|
||||||
|
// - model: the model
|
||||||
|
// - n_threads: number of threads to use
|
||||||
|
// - tokens: text prompt
|
||||||
|
// - n_tokens: number of tokens in the prompt
|
||||||
|
// - n_past: number of past tokens to prefix the prompt with
|
||||||
|
//
|
||||||
|
static bool whisper_decode_internal(
|
||||||
|
whisper_context & wctx,
|
||||||
|
whisper_state & wstate,
|
||||||
|
whisper_decoder & decoder,
|
||||||
|
const whisper_token * tokens,
|
||||||
|
const int n_tokens,
|
||||||
|
const int n_past,
|
||||||
|
const int n_threads) {
|
||||||
|
//const int64_t t_start_us = ggml_time_us();
|
||||||
|
|
||||||
|
//auto & logits_out = wstate.logits;
|
||||||
|
|
||||||
|
//const int n_vocab = hparams.n_vocab;
|
||||||
|
|
||||||
|
// ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
|
//// extract logits for all N tokens
|
||||||
|
////logits_out.resize(N*n_vocab);
|
||||||
|
////memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
|
||||||
|
|
||||||
|
//// extract logits only for the last token
|
||||||
|
//logits_out.resize(n_vocab);
|
||||||
|
//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
|
||||||
|
|
||||||
|
//if (N > 1) {
|
||||||
|
// //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||||
|
// // ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||||
|
// // wstate.get_buf_max_mem(0)/1024.0/1024.0,
|
||||||
|
// // wstate.get_buf_max_mem(1)/1024.0/1024.0,
|
||||||
|
// // wstate.get_buf_max_mem(2)/1024.0/1024.0,
|
||||||
|
// // wstate.get_buf_max_mem(3)/1024.0/1024.0);
|
||||||
|
//}
|
||||||
|
|
||||||
|
//wstate.t_decode_us += ggml_time_us() - t_start_us;
|
||||||
|
//wstate.n_decode++;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// 500 -> 00:05.000
|
// 500 -> 00:05.000
|
||||||
// 6000 -> 01:00.000
|
// 6000 -> 01:00.000
|
||||||
static std::string to_timestamp(int64_t t, bool comma = false) {
|
static std::string to_timestamp(int64_t t, bool comma = false) {
|
||||||
@@ -2774,9 +2584,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
fill_sin_cos_table();
|
fill_sin_cos_table();
|
||||||
whisper_state * state = new whisper_state;
|
whisper_state * state = new whisper_state;
|
||||||
|
|
||||||
const size_t scale = ctx->model.hparams.ftype ? 1 : 2;
|
if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
||||||
|
|
||||||
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) {
|
|
||||||
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
log("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
delete state;
|
delete state;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@@ -2787,7 +2595,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_CROSS.at(ctx->model.type), state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
||||||
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||||
delete state;
|
delete state;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@@ -2825,12 +2633,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
||||||
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
||||||
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
||||||
state->buf_compute.resize(scale * std::max(MEM_REQ_ENCODE.at(ctx->model.type), MEM_REQ_DECODE.at(ctx->model.type)));
|
|
||||||
|
|
||||||
state->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
|
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||||
state->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
|
|
||||||
state->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
|
static const size_t tensor_alignment = 32;
|
||||||
state->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
|
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
|
||||||
|
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
|
||||||
|
|
||||||
state->rng = std::mt19937(0);
|
state->rng = std::mt19937(0);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user