From c058aaf22ed7bd1a21097a0f121aa826e988f1dc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 11 Nov 2022 22:33:10 +0200 Subject: [PATCH] stream : partial encoder experiments --- examples/stream/stream.cpp | 6 ++-- whisper.cpp | 67 ++++++++++++++++++++++++++------------ whisper.h | 3 ++ 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 718c8151..3c2f8612 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -221,6 +221,7 @@ int main(int argc, char ** argv) { const int n_samples = (params.step_ms/1000.0)*WHISPER_SAMPLE_RATE; const int n_samples_len = (params.length_ms/1000.0)*WHISPER_SAMPLE_RATE; const int n_samples_30s = 30*WHISPER_SAMPLE_RATE; + const int n_samples_keep = 0.2*WHISPER_SAMPLE_RATE; std::vector pcmf32(n_samples_30s, 0.0f); std::vector pcmf32_old; @@ -303,7 +304,7 @@ int main(int argc, char ** argv) { //const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_30s/30 - n_samples_new)); // take up to params.length_ms audio from previous iteration - const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_len - n_samples_new)); + const int n_samples_take = std::min((int) pcmf32_old.size(), std::max(0, n_samples_keep + n_samples_len - n_samples_new)); //printf("processing: take = %d, new = %d, old = %d\n", n_samples_take, n_samples_new, (int) pcmf32_old.size()); @@ -379,7 +380,8 @@ int main(int argc, char ** argv) { if ((n_iter % n_new_line) == 0) { printf("\n"); - pcmf32_old.clear(); + // keep part of the audio for next iteration to try to mitigate word boundary issues + pcmf32_old = std::vector(pcmf32.end() - n_samples_keep, pcmf32.end()); } } } diff --git a/whisper.cpp b/whisper.cpp index a8b9e714..7c4a1d4c 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -613,7 +613,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_ctx = hparams.n_text_ctx; const int n_text_state = hparams.n_text_state; const int n_text_layer = hparams.n_text_layer; @@ -748,7 +748,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_ctx = hparams.n_text_ctx; const int n_text_state = hparams.n_text_state; const int n_text_layer = hparams.n_text_layer; @@ -967,13 +967,16 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx // key/value memory for the cross-attention layer { - const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements); + + //memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + //memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); } const size_t memory_size = @@ -1076,13 +1079,11 @@ static bool whisper_encode( const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; - const int n_ctx = hparams.n_audio_ctx; + const int n_ctx = WHISPER_EXPERIMENT_AUDIO_CTX; const int n_state = hparams.n_audio_state; const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; - const int N = n_ctx; - const int n_mels = hparams.n_mels; assert(mel_inp.n_mel == n_mels); @@ -1132,7 +1133,24 @@ static bool whisper_encode( cur = ggml_gelu(ctx0, cur); } - cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + + cur = ggml_add(ctx0, e_pe, ggml_transpose(ctx0, cur)); struct ggml_tensor * inpL = cur; @@ -1198,14 +1216,14 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_cpy(ctxL, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctxL, ggml_cpy(ctxL, Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * V = @@ -1213,9 +1231,9 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, - n_state/n_head, n_head, N), + n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, N, n_state/n_head, n_head) + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head) ); struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); @@ -1224,14 +1242,14 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_cpy(ctxL, Qcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F32, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); struct ggml_tensor * K = ggml_permute(ctxL, ggml_cpy(ctxL, Kcur, - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), 0, 2, 1, 3); // K * Q @@ -1249,7 +1267,7 @@ static bool whisper_encode( // ggml_permute(ctxL, // ggml_cpy(ctxL, // Vcur, - // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)), // 1, 2, 0, 3); //struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); @@ -1259,9 +1277,9 @@ static bool whisper_encode( ggml_permute(ctxL, ggml_reshape_3d(ctxL, Vcur, - n_state/n_head, n_head, N), + n_state/n_head, n_head, n_ctx), 0, 2, 1, 3), - ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, N, n_head) + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head) ); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max); @@ -1271,7 +1289,7 @@ static bool whisper_encode( cur = ggml_cpy(ctxL, KQV_merged, - ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, N)); + ggml_new_tensor_2d(ctxL, GGML_TYPE_F32, n_state, n_ctx)); } // projection @@ -1425,6 +1443,8 @@ static bool whisper_encode( Vcross), Vcross); + //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); @@ -1474,7 +1494,8 @@ static bool whisper_decode( const int n_layer = hparams.n_text_layer; const int N = n_tokens; - const int M = hparams.n_audio_ctx; + //const int M = hparams.n_audio_ctx; + const int M = WHISPER_EXPERIMENT_AUDIO_CTX; struct ggml_init_params params = { .mem_size = wctx.buf_compute.size(), @@ -2662,7 +2683,7 @@ int whisper_full( //} // end of text token - if (token.id == whisper_token_eot(ctx)) { + if (token.id == whisper_token_eot(ctx) || (i > WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT)) { if (result_len == 0) { if (seek + seek_delta + 100 >= seek_end) { result_len = i + 1; @@ -2671,6 +2692,12 @@ int whisper_full( fprintf(stderr, "\n%s: failed to generate timestamp token - this should not happen\n\n", __func__); } } + + // TODO: TMP TO MAKE STREAM WORK ON RPI4 === + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + // ========================================= + break; } @@ -2850,7 +2877,7 @@ int whisper_full_parallel( // key/value memory for the cross-attention layer { - const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; diff --git a/whisper.h b/whisper.h index ea677eaf..769ae643 100644 --- a/whisper.h +++ b/whisper.h @@ -24,6 +24,9 @@ #define WHISPER_HOP_LENGTH 160 #define WHISPER_CHUNK_SIZE 30 +#define WHISPER_EXPERIMENT_AUDIO_CTX 512 +#define WHISPER_EXPERIMENT_MAX_TOKENS_PER_SEGMENT 32 + #ifdef __cplusplus extern "C" { #endif