mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-08 18:14:38 +02:00
whisper : add kv_pad
This commit is contained in:
parent
7c94a11162
commit
2877b026cf
110
whisper.cpp
110
whisper.cpp
@ -148,7 +148,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
//#define WHISPER_USE_FLASH_ATTN
|
#define WHISPER_USE_FLASH_ATTN
|
||||||
//#define WHISPER_USE_FLASH_FF
|
//#define WHISPER_USE_FLASH_FF
|
||||||
#define WHISPER_MAX_DECODERS 8
|
#define WHISPER_MAX_DECODERS 8
|
||||||
#define WHISPER_MAX_NODES 4096
|
#define WHISPER_MAX_NODES 4096
|
||||||
@ -810,6 +810,9 @@ struct whisper_state {
|
|||||||
// shared between all decoders
|
// shared between all decoders
|
||||||
whisper_kv_cache kv_cross;
|
whisper_kv_cache kv_cross;
|
||||||
|
|
||||||
|
// padded buffer for flash-attention
|
||||||
|
whisper_kv_cache kv_pad;
|
||||||
|
|
||||||
whisper_mel mel;
|
whisper_mel mel;
|
||||||
|
|
||||||
whisper_batch batch;
|
whisper_batch batch;
|
||||||
@ -903,14 +906,12 @@ static void read_safe(whisper_model_loader * loader, T & dest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool kv_cache_init(
|
static bool kv_cache_init(
|
||||||
const struct whisper_hparams & hparams,
|
|
||||||
struct whisper_kv_cache & cache,
|
struct whisper_kv_cache & cache,
|
||||||
ggml_backend_t backend,
|
ggml_backend_t backend,
|
||||||
ggml_type wtype,
|
ggml_type wtype,
|
||||||
|
int64_t n_text_state,
|
||||||
|
int64_t n_text_layer,
|
||||||
int n_ctx) {
|
int n_ctx) {
|
||||||
const int64_t n_text_state = hparams.n_text_state;
|
|
||||||
const int64_t n_text_layer = hparams.n_text_layer;
|
|
||||||
|
|
||||||
const int64_t n_mem = n_text_layer*n_ctx;
|
const int64_t n_mem = n_text_layer*n_ctx;
|
||||||
const int64_t n_elements = n_text_state*n_mem;
|
const int64_t n_elements = n_text_state*n_mem;
|
||||||
|
|
||||||
@ -942,6 +943,8 @@ static bool kv_cache_init(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_backend_buffer_clear(cache.buffer, 0);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1873,6 +1876,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
const int n_head = hparams.n_audio_head;
|
const int n_head = hparams.n_audio_head;
|
||||||
const int n_layer = hparams.n_audio_layer;
|
const int n_layer = hparams.n_audio_layer;
|
||||||
|
|
||||||
|
const int n_state_head = n_state/n_head;
|
||||||
|
|
||||||
|
auto & kv_pad = wstate.kv_pad;
|
||||||
|
|
||||||
|
WHISPER_ASSERT(!!kv_pad.ctx);
|
||||||
|
|
||||||
|
const int n_ctx_pad = GGML_PAD(n_ctx, 256);
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
/*.mem_size =*/ wstate.alloc_encode.meta.size(),
|
||||||
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
/*.mem_buffer =*/ wstate.alloc_encode.meta.data(),
|
||||||
@ -1885,7 +1896,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
|
||||||
|
|
||||||
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
|
const float KQscale = 1.0f/sqrtf(float(n_state_head));
|
||||||
|
|
||||||
// ===================================================================
|
// ===================================================================
|
||||||
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
// NOTE: experimenting with partial evaluation of the encoder (ignore)
|
||||||
@ -1935,14 +1946,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b);
|
||||||
|
|
||||||
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state)/n_head, -0.25));
|
//Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25));
|
||||||
|
|
||||||
// note: no bias for Key
|
// note: no bias for Key
|
||||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_k_w,
|
layer.attn_k_w,
|
||||||
cur);
|
cur);
|
||||||
|
|
||||||
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state)/n_head, -0.25));
|
//Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25));
|
||||||
|
|
||||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
|
||||||
layer.attn_v_w,
|
layer.attn_v_w,
|
||||||
@ -1956,31 +1967,31 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
Qcur,
|
Qcur,
|
||||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)),
|
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
#ifdef WHISPER_USE_FLASH_ATTN
|
#ifdef WHISPER_USE_FLASH_ATTN
|
||||||
struct ggml_tensor * Kpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
struct ggml_tensor * Kpad = ggml_reshape_3d(ctx0, kv_pad.k, n_state_head, n_ctx_pad, n_head);
|
||||||
|
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0, Kcur, n_state/n_head, n_head, n_ctx),
|
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
||||||
0, 2, 1, 3),
|
0, 2, 1, 3),
|
||||||
ggml_view_3d(ctx0,
|
ggml_view_3d(ctx0,
|
||||||
Kpad,
|
Kpad,
|
||||||
n_state/n_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
|
n_state_head, n_ctx, n_head, Kpad->nb[1], Kpad->nb[2], 0));
|
||||||
|
|
||||||
struct ggml_tensor * Vpad = ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, GGML_PAD(n_ctx, 256), n_head);
|
struct ggml_tensor * Vpad = ggml_reshape_3d(ctx0, kv_pad.v, n_state_head, n_ctx_pad, n_head);
|
||||||
|
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0, Vcur, n_state/n_head, n_head, n_ctx),
|
ggml_reshape_3d(ctx0, Vcur, n_state_head, n_head, n_ctx),
|
||||||
0, 2, 1, 3),
|
0, 2, 1, 3),
|
||||||
ggml_view_3d(ctx0,
|
ggml_view_3d(ctx0,
|
||||||
Vpad,
|
Vpad,
|
||||||
n_state/n_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
|
n_state_head, n_ctx, n_head, Vpad->nb[1], Vpad->nb[2], 0));
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, K);
|
ggml_build_forward_expand(gf, K);
|
||||||
ggml_build_forward_expand(gf, V);
|
ggml_build_forward_expand(gf, V);
|
||||||
@ -1993,7 +2004,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_cpy(ctx0,
|
||||||
Kcur,
|
Kcur,
|
||||||
ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)),
|
ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
@ -2006,9 +2017,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0,
|
ggml_reshape_3d(ctx0,
|
||||||
Vcur,
|
Vcur,
|
||||||
n_state/n_head, n_head, n_ctx),
|
n_state_head, n_head, n_ctx),
|
||||||
1, 2, 0, 3),
|
1, 2, 0, 3),
|
||||||
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
|
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
|
||||||
);
|
);
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
@ -2117,6 +2128,8 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||||||
const int n_state = hparams.n_audio_state;
|
const int n_state = hparams.n_audio_state;
|
||||||
const int n_head = hparams.n_audio_head;
|
const int n_head = hparams.n_audio_head;
|
||||||
|
|
||||||
|
const int n_state_head = n_state/n_head;
|
||||||
|
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params = {
|
||||||
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
/*.mem_size =*/ wstate.alloc_cross.meta.size(),
|
||||||
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
/*.mem_buffer =*/ wstate.alloc_cross.meta.data(),
|
||||||
@ -2129,7 +2142,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
|
|||||||
|
|
||||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||||
|
|
||||||
const float Kscale = pow(float(n_state) / n_head, -0.25);
|
const float Kscale = pow(float(n_state_head), -0.25);
|
||||||
|
|
||||||
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
|
||||||
auto & layer = model.layers_decoder[il];
|
auto & layer = model.layers_decoder[il];
|
||||||
@ -2295,11 +2308,13 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
const int n_head = hparams.n_text_head;
|
const int n_head = hparams.n_text_head;
|
||||||
const int n_layer = hparams.n_text_layer;
|
const int n_layer = hparams.n_text_layer;
|
||||||
|
|
||||||
|
const int n_state_head = n_state/n_head;
|
||||||
|
|
||||||
const int n_tokens = batch.n_tokens;
|
const int n_tokens = batch.n_tokens;
|
||||||
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
|
||||||
|
|
||||||
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
const int32_t n_kv = worst_case ? n_ctx : kv_self.n;
|
||||||
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
const int32_t kv_head = worst_case ? n_ctx - n_tokens : kv_self.head;
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
//WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
|
||||||
|
|
||||||
@ -2321,7 +2336,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
ggml_set_name(position, "position");
|
ggml_set_name(position, "position");
|
||||||
ggml_set_input(position);
|
ggml_set_input(position);
|
||||||
|
|
||||||
const float KQscale = pow(float(n_state)/n_head, -0.25);
|
const float KQscale = pow(float(n_state_head), -0.25);
|
||||||
|
|
||||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||||
ggml_set_name(KQ_mask, "KQ_mask");
|
ggml_set_name(KQ_mask, "KQ_mask");
|
||||||
@ -2397,14 +2412,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_view_3d(ctx0, kv_self.k,
|
ggml_view_3d(ctx0, kv_self.k,
|
||||||
n_state/n_head, n_kv, n_head,
|
n_state_head, n_kv, n_head,
|
||||||
ggml_element_size(kv_self.k)*n_state,
|
ggml_element_size(kv_self.k)*n_state,
|
||||||
ggml_element_size(kv_self.k)*n_state/n_head,
|
ggml_element_size(kv_self.k)*n_state_head,
|
||||||
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
@ -2414,9 +2429,9 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
|
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
ggml_view_3d(ctx0, kv_self.v,
|
ggml_view_3d(ctx0, kv_self.v,
|
||||||
n_kv, n_state/n_head, n_head,
|
n_kv, n_state_head, n_head,
|
||||||
n_ctx*ggml_element_size(kv_self.v),
|
n_ctx*ggml_element_size(kv_self.v),
|
||||||
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
|
n_ctx*ggml_element_size(kv_self.v)*n_state_head,
|
||||||
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
|
n_ctx*ggml_element_size(kv_self.v)*n_state*il);
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
@ -2469,33 +2484,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
// Kcross is already scaled
|
// Kcross is already scaled
|
||||||
struct ggml_tensor * Kcross =
|
struct ggml_tensor * Kcross =
|
||||||
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
ggml_view_3d(ctx0, wstate.kv_cross.k,
|
||||||
n_state/n_head, n_audio_ctx, n_head,
|
n_state_head, n_audio_ctx, n_head,
|
||||||
ggml_element_size(wstate.kv_cross.k)*n_state,
|
ggml_element_size(wstate.kv_cross.k)*n_state,
|
||||||
ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
|
ggml_element_size(wstate.kv_cross.k)*n_state_head,
|
||||||
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
|
||||||
|
|
||||||
//struct ggml_tensor * Vcross =
|
//struct ggml_tensor * Vcross =
|
||||||
// ggml_reshape_3d(ctx0,
|
// ggml_reshape_3d(ctx0,
|
||||||
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
// ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
|
||||||
// n_state/n_head, n_head, n_audio_ctx);
|
// n_state_head, n_head, n_audio_ctx);
|
||||||
|
|
||||||
//struct ggml_tensor * V_trans =
|
//struct ggml_tensor * V_trans =
|
||||||
// ggml_cpy(ctx0,
|
// ggml_cpy(ctx0,
|
||||||
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
// ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
|
||||||
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
|
// ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state_head, n_head));
|
||||||
|
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
ggml_view_3d(ctx0, wstate.kv_cross.v,
|
||||||
n_audio_ctx, n_state/n_head, n_head,
|
n_audio_ctx, n_state_head, n_head,
|
||||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
|
||||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
|
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head,
|
||||||
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
|
||||||
|
|
||||||
// ------
|
// ------
|
||||||
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
|
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
@ -2504,7 +2519,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
//struct ggml_tensor * KQ_scaled =
|
//struct ggml_tensor * KQ_scaled =
|
||||||
// ggml_scale(ctx0,
|
// ggml_scale(ctx0,
|
||||||
// KQ,
|
// KQ,
|
||||||
// ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
|
// ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state_head)))
|
||||||
// );
|
// );
|
||||||
|
|
||||||
// no masking for cross-attention
|
// no masking for cross-attention
|
||||||
@ -3187,7 +3202,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
||||||
const int factor = 3;
|
const int factor = 3;
|
||||||
|
|
||||||
if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
|
if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype,
|
||||||
|
ctx->model.hparams.n_text_state,
|
||||||
|
ctx->model.hparams.n_text_layer,
|
||||||
|
ctx->model.hparams.n_text_ctx*factor)) {
|
||||||
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
whisper_free_state(state);
|
whisper_free_state(state);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -3198,7 +3216,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype,
|
||||||
|
ctx->model.hparams.n_text_state,
|
||||||
|
ctx->model.hparams.n_text_layer,
|
||||||
|
ctx->model.hparams.n_audio_ctx)) {
|
||||||
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||||
whisper_free_state(state);
|
whisper_free_state(state);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -3209,6 +3230,20 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype,
|
||||||
|
ctx->model.hparams.n_audio_state,
|
||||||
|
1,
|
||||||
|
GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) {
|
||||||
|
WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
|
whisper_free_state(state);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v);
|
||||||
|
WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6);
|
||||||
|
}
|
||||||
|
|
||||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||||
if (ctx->params.dtw_token_timestamps) {
|
if (ctx->params.dtw_token_timestamps) {
|
||||||
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) {
|
||||||
@ -3565,6 +3600,7 @@ void whisper_free_state(struct whisper_state * state) {
|
|||||||
if (state) {
|
if (state) {
|
||||||
kv_cache_free(state->kv_self);
|
kv_cache_free(state->kv_self);
|
||||||
kv_cache_free(state->kv_cross);
|
kv_cache_free(state->kv_cross);
|
||||||
|
kv_cache_free(state->kv_pad);
|
||||||
|
|
||||||
#ifdef WHISPER_USE_COREML
|
#ifdef WHISPER_USE_COREML
|
||||||
if (state->ctx_coreml != nullptr) {
|
if (state->ctx_coreml != nullptr) {
|
||||||
|
Loading…
Reference in New Issue
Block a user