metal : add multi-decoder support

This commit is contained in:
Georgi Gerganov
2023-09-12 19:33:29 +03:00
parent fbc9ddc582
commit 79a88057bd

View File

@ -2886,7 +2886,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
#define WHISPER_METAL_CHECK_BUF(result) \
if (!(result)) { \
log("%s: failed to add buffer\n", __func__); \
log("%s: failed to add metal buffer\n", __func__); \
delete state; \
return nullptr; \
}
@ -4425,6 +4425,21 @@ int whisper_full_with_state(
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
// TODO: not very clean - look for a better way and potentially merging with the init of decoder 0
#ifdef GGML_USE_METAL
#define WHISPER_METAL_CHECK_BUF(result) \
if (!(result)) { \
log("%s: failed to add metal buffer\n", __func__); \
return 0; \
}
const std::string kv_name = "kv_self_" + std::to_string(j);
auto & kv_self = decoder.kv_self;
WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0));
#undef WHISPER_METAL_CHECK_BUF
#endif
}
}