whisper : fixes

This commit is contained in:
Georgi Gerganov 2023-11-11 17:39:30 +02:00
parent b618229340
commit fc8565d0e2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -651,12 +651,11 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, ...)
auto & meta = allocr.meta;
auto & buffer = allocr.buffer;
-    const int tensor_alignment = ggml_backend_get_alignment(backend);
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    alloc = ggml_allocr_new_measure_from_backend(backend);
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment;
+    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, get_graph());
ggml_allocr_free(alloc);
@@ -1299,7 +1298,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx)
// initialize the backends
#ifdef GGML_USE_CUBLAS
-    if (wctx.params.use_gpu > 0) {
+    if (wctx.params.use_gpu) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
backend_gpu = ggml_backend_cuda_init();
if (!backend_gpu) {