diff --git a/whisper.cpp b/whisper.cpp index d8e7f370..68626c97 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2886,7 +2886,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { #define WHISPER_METAL_CHECK_BUF(result) \ if (!(result)) { \ - log("%s: failed to add buffer\n", __func__); \ + log("%s: failed to add metal buffer\n", __func__); \ delete state; \ return nullptr; \ } @@ -4425,6 +4425,21 @@ int whisper_full_with_state( decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); + + // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 +#ifdef GGML_USE_METAL +#define WHISPER_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + return 0; \ + } + + const std::string kv_name = "kv_self_" + std::to_string(j); + auto & kv_self = decoder.kv_self; + + WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); +#undef WHISPER_METAL_CHECK_BUF +#endif } }