diff --git a/.gitignore b/.gitignore index d5c4b0ca..9ff35d00 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ .DS_Store build/ +build-coreml/ build-em/ build-debug/ build-release/ diff --git a/ggml-metal.h b/ggml-metal.h index 096b844e..be2731f8 100644 --- a/ggml-metal.h +++ b/ggml-metal.h @@ -26,7 +26,7 @@ #include // max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 16 +#define GGML_METAL_MAX_BUFFERS 64 #define GGML_METAL_MAX_COMMAND_BUFFERS 32 struct ggml_tensor; diff --git a/ggml-metal.m b/ggml-metal.m index 148c12b1..6293908c 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -479,6 +479,10 @@ static id ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru const int64_t tsize = ggml_nbytes(t); + if (t->buffer && t->buffer->backend && t->buffer->backend->context) { + ctx = t->buffer->backend->context; + } + // find the view that contains the tensor fully for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; diff --git a/whisper.cpp b/whisper.cpp index 471d9a85..244cfeb1 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -649,7 +649,6 @@ static size_t whisper_allocr_size(struct whisper_allocr & allocr) { static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function && get_graph) { auto & alloc = allocr.alloc; auto & meta = allocr.meta; - auto & buffer = allocr.buffer; alloc = ggml_allocr_new_measure_from_backend(backend); @@ -659,6 +658,11 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backe } static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) { + if (allocr.alloc == nullptr) { + // this can be null if we use external encoder like CoreML or OpenVINO + return; + } + auto & alloc = allocr.alloc; auto & buffer = allocr.buffer; @@ -702,6 +706,8 @@ struct whisper_state { // buffer for swapping KV caches between decoders during beam-search std::vector kv_swap_bufs; + ggml_backend_t backend = nullptr; + // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -881,6 +887,37 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { } } +static ggml_backend_t whisper_backend_init(const whisper_context_params & params) { + ggml_backend_t backend_gpu = NULL; + + // initialize the backends +#ifdef GGML_USE_CUBLAS + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = ggml_backend_cuda_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + backend_gpu = ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + return backend_gpu; + } + return ggml_backend_cpu_init(); +} + // load the model from a ggml file // // file format: @@ -1299,38 +1336,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // init backends - { - ggml_backend_t backend_gpu = NULL; - - // initialize the backends -#ifdef GGML_USE_CUBLAS - if (wctx.params.use_gpu) { - WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); - backend_gpu = ggml_backend_cuda_init(); - if (!backend_gpu) { - WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); - } - } -#endif - -#ifdef GGML_USE_METAL - if (wctx.params.use_gpu) { - WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); - ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); - backend_gpu = ggml_backend_metal_init(); - if (!backend_gpu) { - WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); - } - } -#endif - - if (backend_gpu) { - wctx.backend = backend_gpu; - } else { - wctx.backend = ggml_backend_cpu_init(); - } - } + wctx.backend = whisper_backend_init(wctx.params); { size_t size_main = 0; @@ -1964,7 +1970,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } } @@ -1978,7 +1984,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // cross @@ -1991,7 +1997,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } wstate.t_encode_us += ggml_time_us() - t_start_us; @@ -2382,7 +2388,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // extract logits for all N tokens @@ -2825,6 +2831,8 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; + state->backend = whisper_backend_init(ctx->params); + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; @@ -2922,9 +2930,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); } - whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_conv, ctx->backend); whisper_allocr_graph_realloc(state->alloc_encode, ctx->backend); - whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend); + whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend); whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend); state->rng = std::mt19937(0); @@ -3178,6 +3186,8 @@ void whisper_free_state(struct whisper_state * state) whisper_allocr_free(state->alloc_cross); whisper_allocr_free(state->alloc_decode); + ggml_backend_free(state->backend); + delete state; } }