diff --git a/whisper.cpp b/whisper.cpp
index c0e91152..1054a28b 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -406,6 +406,59 @@ struct whisper_segment {
     bool speaker_turn_next;
 };
 
+struct whisper_batch {
+    int32_t n_tokens;
+
+    whisper_token  *  token;
+    whisper_pos    *  pos;
+    int32_t        *  n_seq_id;
+    whisper_seq_id ** seq_id;
+    int8_t         *  logits;
+};
+
+static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) {
+    whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, };
+
+    batch.token    = (whisper_token *  ) malloc(sizeof(whisper_token)    * n_tokens);
+
+    batch.pos      = (whisper_pos *    ) malloc(sizeof(whisper_pos)      * n_tokens);
+    batch.n_seq_id = (int32_t *        ) malloc(sizeof(int32_t)          * n_tokens);
+    batch.seq_id   = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * n_tokens);
+    for (int i = 0; i < n_tokens; ++i) {
+        batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max);
+    }
+    batch.logits   = (int8_t *         ) malloc(sizeof(int8_t)           * n_tokens);
+
+    return batch;
+}
+
+static void whisper_batch_free(struct whisper_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past) {
+    batch.n_tokens = n_tokens;
+    for (int i = 0; i < n_tokens; ++i) {
+        if (tokens) {
+            batch.token[i] = tokens[i];
+        }
+        batch.pos     [i]    = n_past + i;
+        batch.n_seq_id[i]    = 1;
+        batch.seq_id  [i][0] = 0;
+        batch.logits  [i]    = 0;
+    }
+    batch.logits[n_tokens - 1] = 1;
+}
+
 // medium
 // hparams: {
 // 'n_mels': 80,
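Note: whisper_batch mirrors llama_batch from llama.cpp — one entry per token, carrying the token id, its position, the sequence ids it belongs to, and a flag for whether its logits are needed. Later in this patch, whisper_init_state allocates one batch with n_text_ctx tokens and WHISPER_MAX_DECODERS sequences and reuses it for every decode. A minimal lifetime sketch using only the helpers above (internal functions, shown purely for illustration; the capacity values are examples, not fixed constants):

    // allocate once with worst-case capacity, reuse across decode calls
    whisper_batch batch = whisper_batch_init(/*n_tokens*/ 448, /*n_seq_max*/ 1);

    // legacy single-stream fill: positions n_past .. n_past + n - 1, everything
    // on sequence 0, logits requested only for the last token
    whisper_batch_prep_legacy(batch, tokens, n, n_past);

    // ... decode ...

    whisper_batch_free(batch);

One caveat: whisper_batch_free walks batch.n_tokens inner seq_id arrays, so if n_tokens is smaller than the allocated capacity at teardown, the remaining inner arrays appear to be leaked.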
@@ -523,15 +576,31 @@ struct whisper_layer_decoder {
     struct ggml_tensor * mlp_1_b;
 };
 
+struct whisper_kv_cell {
+    whisper_pos pos = -1;
+
+    std::set<whisper_seq_id> seq_id;
+
+    bool has_seq_id(const whisper_seq_id & id) const {
+        return seq_id.find(id) != seq_id.end();
+    }
+};
+
 struct whisper_kv_cache {
+    uint32_t head = 0;
+    uint32_t size = 0;
+
+    // computed before each graph build
+    uint32_t n = 0;
+
+    std::vector<whisper_kv_cell> cells;
+
     struct ggml_tensor * k;
     struct ggml_tensor * v;
 
     struct ggml_context * ctx;
 
     ggml_backend_buffer_t buffer;
-
-    int n; // number of tokens currently in the cache
 };
 
 struct whisper_model {
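Note: a cell is free while pos < 0; otherwise it holds the K/V entry for one token position on behalf of every sequence in seq_id, which is what lets several decoding streams share a single physical cache. A toy illustration with hypothetical values:

    whisper_kv_cell cell;
    cell.pos = 42;
    cell.seq_id.insert(0);
    cell.seq_id.insert(3);   // one K/V entry, visible to sequences 0 and 3

    const bool visible = cell.has_seq_id(3) && cell.pos <= 100;   // true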
@@ -723,6 +792,8 @@ struct whisper_state {
     whisper_kv_cache kv_cross;
     whisper_mel mel;
 
+    whisper_batch batch;
+
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
     // buffer for swapping KV caches between decoders during beam-search
@@ -742,8 +813,9 @@ struct whisper_state {
     struct ggml_tensor * embd_conv = nullptr;
     struct ggml_tensor * embd_enc  = nullptr;
 
-    // helper for GPU offloading
+    // helpers for GPU offloading
    std::vector<float> inp_mel;
+    std::vector<float> inp_mask;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
@@ -831,6 +903,12 @@ static bool kv_cache_init(
         /*.no_alloc   =*/ true,
     };
 
+    cache.head = 0;
+    cache.size = n_ctx;
+
+    cache.cells.clear();
+    cache.cells.resize(n_ctx);
+
     cache.ctx = ggml_init(params);
 
     if (!cache.ctx) {
@@ -858,6 +936,14 @@ static bool kv_cache_init(
     return true;
 }
 
+static void kv_cache_free(struct whisper_kv_cache & cache) {
+    if (cache.ctx) {
+        ggml_free(cache.ctx);
+        ggml_backend_buffer_free(cache.buffer);
+        cache.ctx = nullptr;
+    }
+}
+
 // TODO: remove after batched decoding
 static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) {
     WHISPER_ASSERT(cache.ctx);
@@ -901,11 +987,91 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t back
     return true;
 }
 
-static void kv_cache_free(struct whisper_kv_cache & cache) {
-    if (cache.ctx) {
-        ggml_free(cache.ctx);
-        ggml_backend_buffer_free(cache.buffer);
-        cache.ctx = nullptr;
+static bool whisper_kv_cache_find_slot(
+           struct whisper_kv_cache & cache,
+        const struct whisper_batch & batch) {
+    const uint32_t n_ctx    = cache.size;
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens > n_ctx) {
+        WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx);
+        return false;
+    }
+
+    uint32_t n_tested = 0;
+
+    while (true) {
+        if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
+            cache.head = 0;
+            continue;
+        }
+
+        bool found = true;
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            if (cache.cells[cache.head + i].pos >= 0) {
+                found = false;
+                cache.head += i + 1;
+                n_tested   += i + 1;
+                break;
+            }
+        }
+
+        if (found) {
+            break;
+        }
+
+        if (n_tested >= n_ctx) {
+            //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+            return false;
+        }
+    }
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        cache.cells[cache.head + i].pos = batch.pos[i];
+
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]);
+        }
+    }
+
+    return true;
+}
+
+// find how many cells are currently in use
+static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
+        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
+            return i + 1;
+        }
+    }
+
+    return 0;
+}
+
+static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) {
+    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
+        cache.cells[i].pos = -1;
+        cache.cells[i].seq_id.clear();
+    }
+    cache.head = 0;
+}
+
+static void whisper_kv_cache_seq_cp(
+        struct whisper_kv_cache & cache,
+                 whisper_seq_id   seq_id_src,
+                 whisper_seq_id   seq_id_dst,
+                    whisper_pos   p0,
+                    whisper_pos   p1) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max();
+
+    cache.head = 0;
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.cells[i].seq_id.insert(seq_id_dst);
+        }
     }
 }
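Note: whisper_kv_cache_find_slot is a first-fit ring scan — it advances head until it finds n_tokens consecutive free cells, wrapping around at most once (n_tested guards against looping forever on a full cache). whisper_kv_cache_seq_cp is the cheap alternative to the memcpy-based KV swapping that the beam-search code currently does: "copying" a sequence only inserts the destination id into the matching cells, so no tensor data moves. A sketch of the intended use (assuming the cache already holds positions 0..9 of sequence 0):

    // fork sequence 1 from the first 10 positions of sequence 0:
    // O(cache.size) set insertions, zero bytes of K/V copied
    whisper_kv_cache_seq_cp(cache, /*seq_id_src*/ 0, /*seq_id_dst*/ 1, /*p0*/ 0, /*p1*/ 10);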
@@ -2032,25 +2198,29 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         whisper_context & wctx,
           whisper_state & wstate,
         whisper_decoder & decoder,
-    const whisper_token * tokens,
-                    int   n_tokens,
-                    int   n_past) {
+    const whisper_batch & batch) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
+    // TODO: move to wstate
     auto & kv_self = decoder.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
 
+    ggml_allocr * alloc = wstate.alloc_decode.alloc;
+
     const int n_ctx   = hparams.n_text_ctx;
     const int n_state = hparams.n_text_state;
     const int n_head  = hparams.n_text_head;
     const int n_layer = hparams.n_text_layer;
 
-    const int N = n_tokens;
-    const int M = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
+    const int n_tokens    = batch.n_tokens;
+    const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    const int32_t n_kv    = ggml_allocr_is_measure(alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head = ggml_allocr_is_measure(alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx);
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ wstate.alloc_decode.meta.size(),
@@ -2062,21 +2232,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
 
-    ggml_allocr * alloc = wstate.alloc_decode.alloc;
-
-    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, embd);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd));
+        ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd));
     }
 
-    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
     ggml_allocr_alloc(alloc, position);
 
     if (!ggml_allocr_is_measure(alloc)) {
-        for (int i = 0; i < N; ++i) {
-            const int32_t val = n_past + i;
+        for (int i = 0; i < n_tokens; ++i) {
+            const int32_t val = batch.pos[i];
             ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t));
         }
     }
@@ -2089,6 +2257,31 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float));
     }
 
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    ggml_allocr_alloc(alloc, KQ_mask);
+
+    if (!ggml_allocr_is_measure(alloc)) {
+        wstate.inp_mask.resize(n_kv*n_tokens);
+
+        float * data = wstate.inp_mask.data();
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const whisper_pos    pos    = batch.pos[j];
+                const whisper_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+
+        ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float));
+    }
+
     // token encoding + position encoding
     struct ggml_tensor * cur =
         ggml_add(ctx0,
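Note: this host-built mask is what replaces ggml_diag_mask_inf further down. Entry (j, i) stays 0 when KV cell i belongs to token j's sequence and is at or before token j's position, and becomes -INFINITY otherwise, so a single ggml_add before the soft-max enforces causality and sequence isolation at once. Schematically:

    // mask[j*n_kv + i] = 0          if cells[i].has_seq_id(seq_id[j]) && cells[i].pos <= pos[j]
    //                  = -INFINITY  otherwise (the soft-max then assigns it zero attention)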
@@ -2141,12 +2334,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                     Vcur,
                     layer.attn_v_b);
 
-            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N));
+            Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens));
 
-            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past));
-            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state,
+            struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head));
+            struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state,
                     (   n_ctx)*ggml_element_size(kv_self.v),
-                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v));
+                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v));
 
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
@@ -2156,12 +2349,12 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         struct ggml_tensor * Q =
             ggml_permute(ctx0,
-                    ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                    ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                     0, 2, 1, 3);
 
         struct ggml_tensor * K =
             ggml_view_3d(ctx0, kv_self.k,
-                    n_state/n_head, n_past + N, n_head,
+                    n_state/n_head, n_kv, n_head,
                     ggml_element_size(kv_self.k)*n_state,
                     ggml_element_size(kv_self.k)*n_state/n_head,
                     ggml_element_size(kv_self.k)*n_state*n_ctx*il);
@@ -2171,13 +2364,14 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
 
-        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+        //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
+        struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask);
 
         struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
         struct ggml_tensor * V =
             ggml_view_3d(ctx0, kv_self.v,
-                    n_past + N, n_state/n_head, n_head,
+                    n_kv, n_state/n_head, n_head,
                     n_ctx*ggml_element_size(kv_self.v),
                     n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
                     il*n_ctx*ggml_element_size(kv_self.v)*n_state);
@@ -2188,7 +2382,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
         cur = ggml_cpy(ctx0,
                 KQV_merged,
-                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
     }
 
     // projection
@@ -2232,33 +2426,33 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
                 ggml_view_3d(ctx0, wstate.kv_cross.k,
-                        n_state/n_head, M, n_head,
+                        n_state/n_head, n_audio_ctx, n_head,
                         ggml_element_size(wstate.kv_cross.k)*n_state,
                         ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
-                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
+                        ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
-            //            ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
-            //            n_state/n_head, n_head, M);
+            //            ggml_view_1d(ctx0, wstate.kv_cross.v, n_audio_ctx*n_state, il*n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state),
+            //            n_state/n_head, n_head, n_audio_ctx);
 
             //struct ggml_tensor * V_trans =
             //    ggml_cpy(ctx0,
             //            ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
-            //            ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
+            //            ggml_new_tensor_3d(ctx0, Vcross->type, n_audio_ctx, n_state/n_head, n_head));
 
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, wstate.kv_cross.v,
-                        M, n_state/n_head, n_head,
-                        M*ggml_element_size(wstate.kv_cross.v),
-                        M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
-                        il*M*ggml_element_size(wstate.kv_cross.v)*n_state);
+                        n_audio_ctx, n_state/n_head, n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v),
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state/n_head,
+                        n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il);
 
             // ------
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, n_tokens),
                         0, 2, 1, 3);
 
             // K * Q
@@ -2279,10 +2473,10 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
-            // cur = KQV_merged.contiguous().view(n_state, N)
+            // cur = KQV_merged.contiguous().view(n_state, n_tokens)
             cur = ggml_cpy(ctx0,
                     KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, N));
+                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
         }
 
         // projection
@@ -2354,7 +2548,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     }
 
     // compute logits only for the last token
-    // comment this line to compute logits for all N tokens
+    // comment this line to compute logits for all n_tokens
     // might be useful in the future
     cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
 
@@ -2381,9 +2575,7 @@ static bool whisper_decode_internal(
         whisper_context & wctx,
          whisper_state & wstate,
        whisper_decoder & decoder,
-    const whisper_token * tokens,
-              const int   n_tokens,
-              const int   n_past,
+    const whisper_batch & batch,
              const int   n_threads,
     whisper_abort_callback abort_callback,
                   void * abort_callback_data) {
@@ -2398,13 +2590,21 @@ static bool whisper_decode_internal(
 
     struct ggml_tensor * logits;
 
+    auto & kv_self = decoder.kv_self;
+
+    if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+        return false;
+    }
+
+    kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+
     // decoder
     {
         auto & alloc = wstate.alloc_decode.alloc;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, tokens, n_tokens, n_past);
+        ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, decoder, batch);
 
         ggml_allocr_alloc_graph(alloc, gf);
 
@@ -2423,7 +2623,7 @@ static bool whisper_decode_internal(
     //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
     ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
 
-    if (n_tokens > 1) {
+    if (batch.n_tokens > 1) {
         //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
         //        ggml_used_mem(ctx0)/1024.0/1024.0,
         //        wstate.get_buf_max_mem(0)/1024.0/1024.0,
@@ -2432,7 +2632,7 @@ static bool whisper_decode_internal(
         //        wstate.get_buf_max_mem(3)/1024.0/1024.0);
     }
 
-    if (n_tokens == 1) {
+    if (batch.n_tokens == 1) {
         wstate.t_decode_us += ggml_time_us() - t_start_us;
         wstate.n_decode++;
     } else {
@@ -2443,7 +2643,6 @@ static bool whisper_decode_internal(
     return !(abort_callback && abort_callback(abort_callback_data));
 }
 
-
 // 500 -> 00:05.000
 // 6000 -> 01:00.000
 static std::string to_timestamp(int64_t t, bool comma = false) {
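Note: kv_self.n — the number of KV cells the graph attends over — is recomputed before every decode as the highest cell in use, clamped to the range [32, n_text_ctx]. For example, 7 occupied cells give n_kv = max(32, 7) = 32, while 300 occupied cells give n_kv = 300. The floor of 32 presumably keeps very small decodes from producing degenerate view shapes; the measure pass plans for the full n_ctx, so any runtime value fits in the preallocated buffers.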
@@ -2899,6 +3098,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
 
+    state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
+
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
@@ -2946,7 +3147,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         const int n_tokens = hparams.n_text_ctx;
         const int n_past   = 0;
 
-        return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past);
+        whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past);
+
+        return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], state->batch);
     });
 
     WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0);
@@ -3203,6 +3406,8 @@ void whisper_free_state(struct whisper_state * state)
     }
 #endif
 
+    whisper_batch_free(state->batch);
+
     whisper_allocr_free(state->alloc_conv);
     whisper_allocr_free(state->alloc_encode);
     whisper_allocr_free(state->alloc_cross);
@@ -3331,7 +3536,9 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int selected_decoder_id = 0;
 
-    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past);
+
+    if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -3348,7 +3555,9 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
         return false;
     }
 
-    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) {
+    whisper_batch_prep_legacy(ctx->state->batch, tokens, n_tokens, n_past);
+
+    if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], ctx->state->batch, n_threads, nullptr, nullptr)) {
         WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -5216,7 +5425,11 @@ int whisper_full_with_state(
             }
             WHISPER_PRINT_DEBUG("\n\n");
 
-            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+            whisper_kv_cache_clear(state->decoders[0].kv_self);
+
+            whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0);
+
+            if (!whisper_decode_internal(*ctx, *state, state->decoders[0], state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                 WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                 return -7;
             }
@@ -5449,7 +5662,9 @@ int whisper_full_with_state(
 
                     //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
-                    if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
+                    whisper_batch_prep_legacy(state->batch, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n);
+
+                    if (!whisper_decode_internal(*ctx, *state, decoder, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
                         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                         return -8;
                     }
diff --git a/whisper.h b/whisper.h
index 50f84a82..84540989 100644
--- a/whisper.h
+++ b/whisper.h
@@ -78,7 +78,9 @@ extern "C" {
     struct whisper_state;
     struct whisper_full_params;
 
-    typedef int whisper_token;
+    typedef int32_t whisper_pos;
+    typedef int32_t whisper_token;
+    typedef int32_t whisper_seq_id;
 
     struct whisper_context_params {
         bool use_gpu;
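Note: the public API is untouched — whisper_decode and whisper_decode_with_state keep their (tokens, n_tokens, n_past) signatures and are adapted internally through whisper_batch_prep_legacy, and whisper_token keeps its 32-bit width (now spelled int32_t, alongside the new whisper_pos and whisper_seq_id). A hypothetical caller compiles and behaves exactly as before:

    // token values below are placeholders, not real vocabulary ids
    std::vector<whisper_token> prompt = { 1, 2, 3 };

    if (whisper_decode(ctx, prompt.data(), (int) prompt.size(), /*n_past*/ 0, /*n_threads*/ 4) != 0) {
        fprintf(stderr, "whisper_decode failed\n");
    }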