From d029784fb0914ef975dff47af9bca8bad4fa5408 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Sat, 11 Nov 2023 18:37:14 +0200
Subject: [PATCH] whisper : try to fix the parallel whisper_state functionality

---
 whisper.cpp | 45 ++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 40 insertions(+), 5 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 471d9a85..ccc7aaa8 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -702,6 +702,8 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -1299,7 +1301,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // init backends
+    // init backend
     {
         ggml_backend_t backend_gpu = NULL;
 
@@ -1964,7 +1966,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
 
@@ -1978,7 +1980,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
@@ -1991,7 +1993,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -2382,7 +2384,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
@@ -2825,6 +2827,39 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    // init backend
+    {
+        ggml_backend_t backend_gpu = NULL;
+
+        // initialize the backends
+#ifdef GGML_USE_CUBLAS
+        if (ctx->params.use_gpu) {
+            WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+            backend_gpu = ggml_backend_cuda_init();
+            if (!backend_gpu) {
+                WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+            }
+        }
+#endif
+
+#ifdef GGML_USE_METAL
+        if (ctx->params.use_gpu) {
+            WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+            ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+            backend_gpu = ggml_backend_metal_init();
+            if (!backend_gpu) {
+                WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+            }
+        }
+#endif
+
+        if (backend_gpu) {
+            state->backend = backend_gpu;
+        } else {
+            state->backend = ggml_backend_cpu_init();
+        }
+    }
+
     if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;