From eb68324c86306e4c7a332e125adedd6551776109 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 13 Jan 2025 13:11:37 +0200 Subject: [PATCH] whisper : fix gpu device selection (#2728) --- src/whisper.cpp | 48 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 35 insertions(+), 13 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index f90d3c1a..11077d5b 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) { static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) { ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; if (params.use_gpu) { for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { - WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); - ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); - if (!result) { - WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) { + if (cnt == 0 || cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; } - return result; } } } - return nullptr; + if (dev == nullptr) { + WHISPER_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; } static std::vector whisper_backend_init(const whisper_context_params & params) { @@ -1283,20 +1298,27 @@ static std::vector whisper_backend_init(const whisper_context_pa } static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) { + ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type(); + if (!params.use_gpu) { - return ggml_backend_cpu_buffer_type(); + return result; } - // if we have a GPU device - use it + int cnt = 0; for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i); if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { - WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); - return ggml_backend_dev_buffer_type(dev); + if (cnt == 0 || cnt == params.gpu_device) { + result = ggml_backend_dev_buffer_type(dev); + } + + if (++cnt > params.gpu_device) { + break; + } } } - return ggml_backend_cpu_buffer_type(); + return result; } // load the model from a ggml file