diff --git a/whisper.cpp b/whisper.cpp index b6300d5f..a049ec61 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1,10 +1,15 @@ #include "whisper.h" + #ifdef WHISPER_USE_COREML #include "coreml/whisper-encoder.h" #endif #ifdef GGML_USE_METAL -# include "ggml-metal.h" +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" #endif #ifdef WHISPER_USE_OPENVINO @@ -13,6 +18,7 @@ #include "ggml.h" #include "ggml-alloc.h" +#include "ggml-backend.h" #include #include @@ -97,10 +103,32 @@ static void byteswap_tensor(ggml_tensor * tensor) { #define BYTESWAP_TENSOR(t) do {} while (0) #endif +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define WHISPER_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +WHISPER_ATTRIBUTE_FORMAT(2, 3) +static void whisper_log_internal (ggml_log_level level, const char* format, ...); +static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + #define WHISPER_ASSERT(x) \ do { \ if (!(x)) { \ - log("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ abort(); \ } \ } while (0) @@ -554,12 +582,16 @@ struct whisper_kv_cache { struct ggml_context * ctx; - // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init) - std::vector buf; + ggml_backend_buffer_t buffer; int n; // number of tokens currently in the cache }; +struct whisper_model_data { + ggml_backend_buffer_t buffer_conv; // TODO: tmp until GPU support for conv + ggml_backend_buffer_t buffer_main; +}; + struct whisper_model { e_model type = MODEL_UNKNOWN; @@ -597,8 +629,8 @@ struct whisper_model { // context struct ggml_context * ctx; - // the model memory buffer is read-only and can be shared between processors - std::vector * buf; + // the model backend data is read-only and can be shared between processors + struct whisper_model_data * data; // tensors int n_loaded; @@ -663,32 +695,31 @@ struct whisper_allocr { ggml_allocr * alloc = nullptr; std::vector meta; - std::vector data; + + ggml_backend_buffer_t buffer; }; static size_t whisper_allocr_size(struct whisper_allocr & allocr) { - return allocr.meta.size() + allocr.data.size(); + return allocr.meta.size() + ggml_backend_buffer_get_size(allocr.buffer); } // measure the memory usage of a graph and prepare the allocr's internal data buffer -static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::function && get_graph) { - const int tensor_alignment = 32; +static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function && get_graph) { + auto & alloc = allocr.alloc; + auto & meta = allocr.meta; + auto & buffer = allocr.buffer; - auto & alloc = allocr.alloc; - auto & meta = allocr.meta; - auto & data = allocr.data; + const int tensor_alignment = ggml_backend_get_alignment(backend); + alloc = ggml_allocr_new_measure(tensor_alignment); meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); - alloc = ggml_allocr_new_measure(tensor_alignment); - const size_t 
alloc_size = ggml_allocr_alloc_graph(alloc, get_graph()) + tensor_alignment; ggml_allocr_free(alloc); - data.resize(alloc_size); - - alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment); + buffer = ggml_backend_alloc_buffer(backend, alloc_size); + alloc = ggml_allocr_new_from_buffer(buffer); } static void whisper_allocr_free(struct whisper_allocr & allocr) { @@ -722,9 +753,6 @@ struct whisper_state { // buffer for swapping KV caches between decoders during beam-search std::vector kv_swap_bufs; - // reusable buffer for `struct ggml_graph_plan.work_data` - std::vector work_buffer; - // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -737,6 +765,8 @@ struct whisper_state { struct ggml_tensor * embd_conv = nullptr; struct ggml_tensor * embd_enc = nullptr; + std::vector inp_mel; + // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -751,22 +781,21 @@ struct whisper_state { int lang_id = 0; // english by default std::string path_model; // populated by whisper_init_from_file_with_params() + #ifdef WHISPER_USE_COREML whisper_coreml_context * ctx_coreml = nullptr; #endif -#ifdef GGML_USE_METAL - ggml_metal_context * ctx_metal = nullptr; -#endif - #ifdef WHISPER_USE_OPENVINO whisper_openvino_context * ctx_openvino = nullptr; #endif // [EXPERIMENTAL] token-level timestamps data - int64_t t_beg = 0; + int64_t t_beg = 0; int64_t t_last = 0; + whisper_token tid_last; + std::vector energy; // PCM signal energy // [EXPERIMENTAL] speed-up techniques @@ -780,35 +809,39 @@ struct whisper_context { ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX) ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16) + whisper_context_params params; + whisper_model model; whisper_vocab vocab; + whisper_state * state = nullptr; + ggml_backend_t backend_cpu = nullptr; + ggml_backend_t backend_gpu = nullptr; + std::string path_model; // populated by whisper_init_from_file_with_params() - whisper_context_params params; + + ggml_backend_t backend_kv() const { + return backend_gpu ? backend_gpu : backend_cpu; + } + + // TODO: always on CPU until we have a GPU support for conv + ggml_backend_t backend_conv() const { + return backend_cpu; + } + + ggml_backend_t backend_main() const { + return backend_gpu ? backend_gpu : backend_cpu; + } }; -static void whisper_default_log(const char * text) { - fprintf(stderr, "%s", text); -} +struct whisper_global { + // We save the log callback globally + ggml_log_callback log_callback = whisper_log_callback_default; + void * log_callback_user_data = nullptr; +}; -static whisper_log_callback whisper_log = whisper_default_log; - -#ifdef __GNUC__ -#ifdef __MINGW32__ -__attribute__((gnu_format(printf, 1, 2))) -#else -__attribute__((format(printf, 1, 2))) -#endif -#endif -static void log(const char * fmt, ...) 
{ - if (!whisper_log) return; - char buf[1024]; - va_list args; - va_start(args, fmt); - vsnprintf(buf, sizeof(buf), fmt, args); - whisper_log(buf); -} +static whisper_global g_state; template static void read_safe(whisper_model_loader * loader, T & dest) { @@ -819,6 +852,7 @@ static void read_safe(whisper_model_loader * loader, T & dest) { static bool kv_cache_init( const struct whisper_hparams & hparams, struct whisper_kv_cache & cache, + ggml_backend_t backend, ggml_type wtype, int n_ctx) { const int64_t n_text_state = hparams.n_text_state; @@ -827,30 +861,41 @@ static bool kv_cache_init( const int64_t n_mem = n_text_layer*n_ctx; const int64_t n_elements = n_text_state*n_mem; - const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead()); - - cache.buf.resize(mem_bytes); - struct ggml_init_params params = { - /*.mem_size =*/ cache.buf.size(), - /*.mem_buffer =*/ cache.buf.data(), - /*.no_alloc =*/ false, + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, }; cache.ctx = ggml_init(params); if (!cache.ctx) { - log("%s: failed to allocate memory for kv cache\n", __func__); + WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); return false; } cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v); + + cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes); + + // allocate the tensors into the backend buffer + { + ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer); + + ggml_allocr_alloc(alloc, cache.k); + ggml_allocr_alloc(alloc, cache.v); + + ggml_allocr_free(alloc); + } + return true; } -static bool kv_cache_reinit(struct whisper_kv_cache & cache) { +// TODO: remove after batched decoding +static bool kv_cache_reinit(struct whisper_kv_cache & cache, ggml_backend_t backend) { WHISPER_ASSERT(cache.ctx); const int n_elements = ggml_nelements(cache.k); @@ -859,24 +904,36 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) { const ggml_type wtype = cache.k->type; WHISPER_ASSERT(wtype == cache.v->type); - WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype)); - struct ggml_init_params params = { - /*.mem_size =*/ cache.buf.size(), - /*.mem_buffer =*/ cache.buf.data(), - /*.no_alloc =*/ false, + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, }; cache.ctx = ggml_init(params); if (!cache.ctx) { - log("%s: failed to allocate memory for kv cache\n", __func__); + WHISPER_LOG_ERROR("%s: failed to allocate memory for kv cache\n", __func__); return false; } cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + const size_t mem_bytes = ggml_nbytes(cache.k) + ggml_nbytes(cache.v); + + cache.buffer = ggml_backend_alloc_buffer(backend, mem_bytes); + + // allocate the tensors into the backend buffer + { + ggml_allocr * alloc = ggml_allocr_new_from_buffer(cache.buffer); + + ggml_allocr_alloc(alloc, cache.k); + ggml_allocr_alloc(alloc, cache.v); + + ggml_allocr_free(alloc); + } + return true; } @@ -899,7 +956,7 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { // see the convert-pt-to-ggml.py script for details // static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { - log("%s: loading model\n", __func__); + WHISPER_LOG_INFO("%s: loading model\n", __func__); const 
int64_t t_start_us = ggml_time_us(); @@ -913,7 +970,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con uint32_t magic; read_safe(loader, magic); if (magic != GGML_FILE_MAGIC) { - log("%s: invalid model data (bad magic)\n", __func__); + WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); return false; } } @@ -970,41 +1027,23 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // in order to save memory and also to speed up the computation wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); if (wctx.wtype == GGML_TYPE_COUNT) { - log("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); return false; } - const size_t scale = model.hparams.ftype ? 1 : 2; - - log("%s: n_vocab = %d\n", __func__, hparams.n_vocab); - log("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); - log("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); - log("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); - log("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); - log("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); - log("%s: n_text_state = %d\n", __func__, hparams.n_text_state); - log("%s: n_text_head = %d\n", __func__, hparams.n_text_head); - log("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); - log("%s: n_mels = %d\n", __func__, hparams.n_mels); - log("%s: ftype = %d\n", __func__, model.hparams.ftype); - log("%s: qntvr = %d\n", __func__, qntvr); - log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); - - // print memory requirements - { - // TODO - //log("%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, - // mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); - } - - // initialize all memory buffers - // always have at least one decoder - - wctx.model.buf = new std::vector(); - wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type)); - - // we skip initialization of the state until it is needed - // because it might be that state will always be provided externally. 
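Note on the removed `wctx.model.buf` / `MEM_REQ_MODEL` pre-allocation: later in this patch the weights are instead created in a metadata-only ggml context and placed into ggml-backend buffers. The following is a minimal sketch of that flow, assuming only the ggml-alloc/ggml-backend calls that already appear in this diff and the headers included at the top of whisper.cpp; the helper name, the single tensor and its size are purely illustrative, not part of the real loader.

    // hypothetical helper mirroring the loader flow introduced by this patch
    static void example_load_weight(ggml_backend_t backend, const float * host_data, int64_t n) {
        // metadata-only context: no tensor data is allocated here (no_alloc = true)
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n); // illustrative weight

        // size the backend buffer from the tensor itself, instead of hand-written ctx_size math
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, ggml_nbytes(w) + ggml_tensor_overhead());

        // place the tensor inside the backend buffer
        ggml_allocr * alloc = ggml_allocr_new_from_buffer(buf);
        ggml_allocr_alloc(alloc, w);
        ggml_allocr_free(alloc);

        if (ggml_backend_is_cpu(backend)) {
            // host-visible memory: write straight into the tensor
            memcpy(w->data, host_data, ggml_nbytes(w));
        } else {
            // device memory: copy through the backend API
            ggml_backend_tensor_set(w, host_data, 0, ggml_nbytes(w));
        }
    }

Because the buffer size is derived from the tensors themselves, the hand-maintained per-tensor ctx_size arithmetic deleted further down in this diff is no longer needed.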
+ WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype); + WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); } // load mel filters @@ -1025,7 +1064,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, n_vocab); //if (n_vocab != model.hparams.n_vocab) { - // log("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); // return false; //} @@ -1045,7 +1084,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word.assign(&tmp[0], tmp.size()); } else { // seems like we have an empty-string token in multi-language models (i = 50256) - //log("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); word = ""; } @@ -1073,7 +1112,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } if (n_vocab < model.hparams.n_vocab) { - log("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); for (int i = n_vocab; i < model.hparams.n_vocab; i++) { if (i > vocab.token_beg) { word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; @@ -1099,140 +1138,35 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - log("%s: n_langs = %d\n", __func__, vocab.num_languages()); + WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages()); } - size_t ctx_size = 0; - const ggml_type wtype = wctx.wtype; const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? 
GGML_TYPE_F32 : GGML_TYPE_F16; // conv type + // create the ggml context { const auto & hparams = model.hparams; - const int n_vocab = hparams.n_vocab; - - const int n_audio_ctx = hparams.n_audio_ctx; - const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; + const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; - const int n_mels = hparams.n_mels; - - // encoder - { - ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe; - - ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(vtype); // e_conv_1_w - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b - - ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(vtype); // e_conv_2_w - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b - - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w; - ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b; - } - - // decoder - { - ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe; - - ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te; - - ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w; - ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b; - } - - // encoder layers - { - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w - ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b - } - - // decoder layers - { - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w - ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b - - ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b - - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // 
attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b - // - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w - ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b - - ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w - ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b - } - - ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*512; // object overhead - - log("%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); - } - - // create the ggml context - { struct ggml_init_params params = { - /*.mem_size =*/ wctx.model.buf->size(), - /*.mem_buffer =*/ wctx.model.buf->data(), - /*.no_alloc =*/ false, + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, }; model.ctx = ggml_init(params); if (!model.ctx) { - log("%s: ggml_init() failed\n", __func__); + WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); return false; } } - // prepare memory for the weights + // prepare tensors for the weights { auto & ctx = model.ctx; @@ -1428,12 +1362,74 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + // init backends + { + model.data = new whisper_model_data; + + ggml_backend_t backend_gpu = NULL; + + // initialize the backends +#ifdef GGML_USE_CUBLAS + if (wctx.params.use_gpu > 0) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = ggml_backend_cuda_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (wctx.params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + backend_gpu = ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + wctx.backend_gpu = backend_gpu; + } else { + wctx.backend_gpu = nullptr; + } + + // always add the CPU backend as a fallback + wctx.backend_cpu = ggml_backend_cpu_init(); + } + + { + size_t size_conv = 0; + size_t size_main = 0; + + for (const auto & t : model.tensors) { + if (t.first.find("conv") != std::string::npos) { + size_conv += 
ggml_nbytes(t.second) + ggml_tensor_overhead(); + } else { + size_main += ggml_nbytes(t.second) + ggml_tensor_overhead(); + } + } + + model.data->buffer_conv = ggml_backend_alloc_buffer(wctx.backend_conv(), size_conv); + model.data->buffer_main = ggml_backend_alloc_buffer(wctx.backend_main(), size_main); + + WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_conv()), size_conv / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: %8s buffer size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend_main()), size_main / 1024.0 / 1024.0); + } + + ggml_allocr * alloc_conv = ggml_allocr_new_from_buffer(model.data->buffer_conv); + ggml_allocr * alloc_main = ggml_allocr_new_from_buffer(model.data->buffer_main); + // load weights { size_t total_size = 0; model.n_loaded = 0; + std::vector read_buf; + while (true) { int32_t n_dims; int32_t length; @@ -1460,20 +1456,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con name.assign(&tmp[0], tmp.size()); if (model.tensors.find(name) == model.tensors.end()) { - log("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); return false; } auto tensor = model.tensors[name.data()]; if (ggml_nelements(tensor) != nelements) { - log("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - log("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); return false; } if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - log("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); return false; } @@ -1481,29 +1477,52 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const size_t bpe = ggml_type_size(ggml_type(ttype)); if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - log("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); return false; } - loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); - BYTESWAP_TENSOR(tensor); + const bool is_conv = name.find("conv") != std::string::npos; + + ggml_allocr * alloc = is_conv ? alloc_conv : alloc_main; + ggml_backend * backend = is_conv ? 
wctx.backend_conv() : wctx.backend_main(); + + ggml_allocr_alloc(alloc, tensor); + //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str()); + + if (ggml_backend_is_cpu(backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(backend) +#endif + ) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + loader->read(loader->context, read_buf.data(), read_buf.size()); + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1024.0/1024.0); total_size += ggml_nbytes(tensor); model.n_loaded++; } - log("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1024.0/1024.0); if (model.n_loaded == 0) { - log("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); } else if (model.n_loaded != (int) model.tensors.size()) { - log("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return false; } } + ggml_allocr_free(alloc_conv); + ggml_allocr_free(alloc_main); + wctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -1559,7 +1578,9 @@ static struct ggml_cgraph * whisper_build_graph_conv( if (!ggml_allocr_is_measure(alloc)) { assert(mel_inp.n_mel == n_mels); - float * dst = (float *) mel->data; + wstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = wstate.inp_mel.data(); memset(dst, 0, ggml_nbytes(mel)); const int i0 = std::min(mel_offset, mel_inp.n_len); @@ -1570,6 +1591,8 @@ static struct ggml_cgraph * whisper_build_graph_conv( dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; } } + + ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); } struct ggml_tensor * cur = nullptr; @@ -1596,6 +1619,7 @@ static struct ggml_cgraph * whisper_build_graph_conv( cur = ggml_gelu(ctx0, cur); } + ggml_set_name(cur, "embd_conv"); wstate.embd_conv = cur; } else { #ifdef WHISPER_USE_COREML @@ -1615,6 +1639,7 @@ static struct ggml_cgraph * whisper_build_graph_conv( } #endif + ggml_set_name(cur, "embd_enc"); wstate.embd_enc = cur; } @@ -1652,10 +1677,16 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_allocr_alloc(alloc, KQscale); if (!ggml_allocr_is_measure(alloc)) { - ggml_set_f32(KQscale, 1.0f/sqrt(float(n_state)/n_head)); + const float val = 1.0f/sqrtf(float(n_state)/n_head); + ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } - struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); + struct ggml_tensor * cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_state); + ggml_allocr_alloc(alloc, cur); + + if (!ggml_allocr_is_measure(alloc)) { + ggml_backend_tensor_copy(wstate.embd_conv, cur); + } // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) @@ 
-1675,7 +1706,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder( const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); - cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); // =================================================================== @@ -1903,7 +1933,8 @@ static struct ggml_cgraph * whisper_build_graph_cross( ggml_allocr_alloc(alloc, Kscale); if (!ggml_allocr_is_measure(alloc)) { - ggml_set_f32(Kscale, pow(float(n_state) / n_head, -0.25)); + const float val = pow(float(n_state) / n_head, -0.25); + ggml_backend_tensor_set(Kscale, &val, 0, sizeof(float)); } for (int il = 0; il < model.hparams.n_text_layer; ++il) { @@ -1974,7 +2005,15 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); + if (ggml_backend_is_cpu(wctx.backend_conv())) { + ggml_backend_cpu_set_n_threads(wctx.backend_conv(), n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend_conv())) { + ggml_backend_metal_set_n_cb(wctx.backend_conv(), n_threads); + } +#endif + ggml_backend_graph_compute(wctx.backend_conv(), gf); } } @@ -1988,16 +2027,15 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); -#ifdef GGML_USE_METAL - if (wstate.ctx_metal) { - ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); + if (ggml_backend_is_cpu(wctx.backend_main())) { + ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend_main())) { + ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads); } -#else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif + ggml_backend_graph_compute(wctx.backend_main(), gf); } // cross @@ -2010,20 +2048,17 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); -#ifdef GGML_USE_METAL - if (wstate.ctx_metal) { - ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); + if (ggml_backend_is_cpu(wctx.backend_main())) { + ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend_main())) { + ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads); } -#else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif + ggml_backend_graph_compute(wctx.backend_main(), gf); } - // ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); - wstate.t_encode_us += ggml_time_us() - t_start_us; wstate.n_encode++; @@ -2070,7 +2105,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_allocr_alloc(alloc, embd); if (!ggml_allocr_is_measure(alloc)) { - memcpy(embd->data, tokens, N*ggml_element_size(embd)); + ggml_backend_tensor_set(embd, tokens, 0, N*ggml_element_size(embd)); } struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); @@ -2078,7 +2113,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( if 
(!ggml_allocr_is_measure(alloc)) { for (int i = 0; i < N; ++i) { - ((int32_t *) position->data)[i] = n_past + i; + const int32_t val = n_past + i; + ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); } } @@ -2086,7 +2122,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_allocr_alloc(alloc, KQscale); if (!ggml_allocr_is_measure(alloc)) { - ggml_set_f32(KQscale, pow(float(n_state)/n_head, -0.25)); + const float val = pow(float(n_state)/n_head, -0.25); + ggml_backend_tensor_set(KQscale, &val, 0, sizeof(float)); } // token encoding + position encoding @@ -2410,16 +2447,15 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; -#ifdef GGML_USE_METAL - if (wstate.ctx_metal) { - ggml_metal_set_n_cb (wstate.ctx_metal, n_threads); - ggml_metal_graph_compute(wstate.ctx_metal, gf); - } else { - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); + if (ggml_backend_is_cpu(wctx.backend_main())) { + ggml_backend_cpu_set_n_threads(wctx.backend_main(), n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend_main())) { + ggml_backend_metal_set_n_cb(wctx.backend_main(), n_threads); } -#else - ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads, abort_callback, abort_callback_data); #endif + ggml_backend_graph_compute(wctx.backend_main(), gf); } // extract logits for all N tokens @@ -2428,7 +2464,8 @@ static bool whisper_decode_internal( // extract logits only for the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab); + ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab); if (n_tokens > 1) { //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, @@ -2794,7 +2831,7 @@ static std::vector tokenize(const whisper_vocab & vocab, cons --j; } if (!found) { - log("unknown token\n"); + WHISPER_LOG_ERROR("unknown token\n"); ++i; } } @@ -2857,45 +2894,46 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { struct whisper_state * whisper_init_state(whisper_context * ctx) { fill_sin_cos_table(); + whisper_state * state = new whisper_state; - if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->itype, ctx->model.hparams.n_text_ctx)) { - log("%s: kv_cache_init() failed for self-attention cache\n", __func__); + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend_kv(), ctx->itype, ctx->model.hparams.n_text_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->decoders[0].kv_self.k) + ggml_nbytes(state->decoders[0].kv_self.v); - log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { - log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->backend_kv(), ctx->itype, ctx->model.hparams.n_audio_ctx)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; return nullptr; } { const size_t memory_size = ggml_nbytes(state->kv_cross.k) + 
ggml_nbytes(state->kv_cross.v); - log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } #ifdef WHISPER_USE_COREML const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); - log("%s: first run on a device may take a while ...\n", __func__); + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); if (!state->ctx_coreml) { - log("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK delete state; return nullptr; #endif } else { - log("%s: Core ML model loaded\n", __func__); + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); } #endif @@ -2912,37 +2950,37 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // conv allocator { - whisper_allocr_graph_init(state->alloc_conv, + whisper_allocr_graph_init(state->alloc_conv, ctx->backend_conv(), [&]() { return whisper_build_graph_conv(*ctx, *state, 0); }); - log("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1024.0 / 1024.0); } // encoder allocator if (!whisper_encode_external(*state)) { - whisper_allocr_graph_init(state->alloc_encode, + whisper_allocr_graph_init(state->alloc_encode, ctx->backend_main(), [&]() { return whisper_build_graph_encoder(*ctx, *state); }); - log("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1024.0 / 1024.0); } // cross allocator { - whisper_allocr_graph_init(state->alloc_cross, + whisper_allocr_graph_init(state->alloc_cross, ctx->backend_main(), [&]() { return whisper_build_graph_cross(*ctx, *state); }); - log("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1024.0 / 1024.0); } // decoder allocator { - whisper_allocr_graph_init(state->alloc_decode, + whisper_allocr_graph_init(state->alloc_decode, ctx->backend_main(), [&]() { const auto & hparams = ctx->model.hparams; @@ -2953,70 +2991,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { return whisper_build_graph_decoder(*ctx, *state, state->decoders[0], nullptr, n_tokens, n_past); }); - log("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1024.0 / 1024.0); } -#ifdef GGML_USE_METAL - if (ctx->params.use_gpu) { - state->ctx_metal = ggml_metal_init(1); - if (!state->ctx_metal) { - log("%s: ggml_metal_init() failed\n", __func__); - delete state; - return nullptr; - } - } - - if (state->ctx_metal) { - log("%s: Metal context initialized\n", __func__); - - // this allocates all 
Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - // TODO: add mmap support - //if (params.use_mmap) { - // data_ptr = ctx->model.mapping->addr; - // data_size = ctx->model.mapping->size; - //} else { - // data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - // data_size = ggml_get_mem_size (ctx->model.ctx); - //} - - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); - - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - - log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); - -#define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - delete state; \ - return nullptr; \ - } - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); -#undef WHISPER_METAL_CHECK_BUF - - } -#endif - state->rng = std::mt19937(0); return state; @@ -3036,7 +3013,7 @@ int whisper_ctx_init_openvino_encoder( return 1; #else if (!model_path && ctx->path_model.empty()) { - log("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); + WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); return 1; } @@ -3056,15 +3033,15 @@ int whisper_ctx_init_openvino_encoder( path_cache = cache_dir; } - log("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); - log("%s: first run on a device may take a while ...\n", __func__); + WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); if (!ctx->state->ctx_openvino) { - log("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); + WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); return 1; } else { - 
log("%s: OpenVINO model loaded\n", __func__); + WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__); } return 0; @@ -3079,11 +3056,11 @@ struct whisper_context_params whisper_context_default_params() { } struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { - log("%s: loading model from '%s'\n", __func__, path_model); + WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); auto fin = std::ifstream(path_model, std::ios::binary); if (!fin) { - log("%s: failed to open '%s'\n", __func__, path_model); + WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); return nullptr; } @@ -3125,7 +3102,7 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; - log("%s: loading model from buffer\n", __func__); + WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__); whisper_model_loader loader = {}; @@ -3161,7 +3138,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); - log("%s: failed to load model\n", __func__); + WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; return nullptr; } @@ -3256,13 +3233,6 @@ void whisper_free_state(struct whisper_state * state) } #endif -#ifdef GGML_USE_METAL - if (state->ctx_metal) { - ggml_metal_free(state->ctx_metal); - state->ctx_metal = nullptr; - } -#endif - #ifdef WHISPER_USE_OPENVINO if (state->ctx_openvino != nullptr) { whisper_openvino_free(state->ctx_openvino); @@ -3271,9 +3241,9 @@ void whisper_free_state(struct whisper_state * state) #endif whisper_allocr_free(state->alloc_conv); - whisper_allocr_free(state->alloc_decode); - whisper_allocr_free(state->alloc_cross); whisper_allocr_free(state->alloc_encode); + whisper_allocr_free(state->alloc_cross); + whisper_allocr_free(state->alloc_decode); delete state; } @@ -3284,8 +3254,8 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.ctx) { ggml_free(ctx->model.ctx); } - if (ctx->model.buf) { - delete ctx->model.buf; + if (ctx->model.data) { + delete ctx->model.data; } whisper_free_state(ctx->state); @@ -3308,7 +3278,7 @@ void whisper_free_params(struct whisper_full_params * params) { int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - log("%s: failed to compute mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3322,7 +3292,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { - log("%s: failed to compute mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); 
return -1; } @@ -3350,7 +3320,7 @@ int whisper_set_mel_with_state( int n_len, int n_mel) { if (n_mel != ctx->model.filters.n_mel) { - log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); return -1; } @@ -3374,7 +3344,7 @@ int whisper_set_mel( int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3383,7 +3353,7 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3394,7 +3364,7 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state const int selected_decoder_id = 0; if (!whisper_decode_internal(*ctx, *state, state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3406,12 +3376,12 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i const int selected_decoder_id = 0; if (ctx->state == nullptr) { - log("%s: ERROR state was not loaded.\n", __func__); + WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); return false; } if (!whisper_decode_internal(*ctx, *ctx->state, ctx->state->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads, nullptr, nullptr)) { - log("%s: failed to eval\n", __func__); + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3422,7 +3392,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to const auto res = tokenize(ctx->vocab, text); if (n_max_tokens < (int) res.size()) { - log("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); return -1; } @@ -3450,7 +3420,7 @@ int whisper_lang_id(const char * lang) { } } - log("%s: unknown language '%s'\n", __func__, lang); + WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang); return -1; } return g_lang.at(lang).first; @@ -3463,7 +3433,7 @@ const char * whisper_lang_str(int id) { } } - log("%s: unknown language id %d\n", __func__, id); + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); return nullptr; } @@ -3476,25 +3446,25 @@ int whisper_lang_auto_detect_with_state( const int seek = offset_ms/10; if (seek < 0) { - log("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); return -1; } if (seek >= state->mel.n_len_org) { - log("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); return -2; } // run the encoder if (whisper_encode_with_state(ctx, state, seek, 
n_threads) != 0) { - log("%s: failed to encode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); return -6; } const std::vector prompt = { whisper_token_sot(ctx) }; if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { - log("%s: failed to decode\n", __func__); + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } @@ -3694,8 +3664,8 @@ whisper_token whisper_token_transcribe(struct whisper_context * ctx) { void whisper_print_timings(struct whisper_context * ctx) { const int64_t t_end_us = ggml_time_us(); - log("\n"); - log("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + WHISPER_LOG_INFO("\n"); + WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); if (ctx->state != nullptr) { const int32_t n_sample = std::max(1, ctx->state->n_sample); @@ -3703,14 +3673,14 @@ void whisper_print_timings(struct whisper_context * ctx) { const int32_t n_decode = std::max(1, ctx->state->n_decode); const int32_t n_prompt = std::max(1, ctx->state->n_prompt); - log("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); - log("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); - log("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); - log("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); - log("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); - log("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } - log("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); } void whisper_reset_timings(struct whisper_context * ctx) { @@ -4055,7 +4025,7 @@ static void whisper_process_logits( const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; - //log("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", 
last_was_timestamp, penultimate_was_timestamp); if (last_was_timestamp) { if (penultimate_was_timestamp) { @@ -4131,7 +4101,7 @@ static void whisper_process_logits( const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); - //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { for (int i = 0; i < vocab.token_beg; ++i) { @@ -4494,11 +4464,11 @@ int whisper_full_with_state( // compute log mel spectrogram if (params.speed_up) { // TODO: Replace PV with more advanced algorithm - log("%s: failed to compute log mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -1; } else { if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { - log("%s: failed to compute log mel spectrogram\n", __func__); + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -2; } } @@ -4510,13 +4480,13 @@ int whisper_full_with_state( const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); if (lang_id < 0) { - log("%s: failed to auto-detect language\n", __func__); + WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); return -3; } state->lang_id = lang_id; params.language = whisper_lang_str(lang_id); - log("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); if (params.detect_language) { return 0; } @@ -4574,8 +4544,8 @@ int whisper_full_with_state( if (decoder.kv_self.ctx == nullptr) { decoder.kv_self = state->decoders[0].kv_self; - if (!kv_cache_reinit(decoder.kv_self)) { - log("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); + if (!kv_cache_reinit(decoder.kv_self, ctx->backend_kv())) { + WHISPER_LOG_ERROR("%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; } @@ -4586,23 +4556,6 @@ int whisper_full_with_state( decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); - - // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 -#ifdef GGML_USE_METAL - if (state->ctx_metal) { -#define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - return 0; \ - } - - const std::string kv_name = "kv_self_" + std::to_string(j); - auto & kv_self = decoder.kv_self; - - WHISPER_METAL_CHECK_BUF(ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); -#undef WHISPER_METAL_CHECK_BUF - } -#endif } } @@ -4636,7 +4589,7 @@ int whisper_full_with_state( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { - log("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } state->exp_n_audio_ctx = params.audio_ctx; @@ -4661,7 +4614,7 @@ int 
@@ -4661,7 +4614,7 @@ int whisper_full_with_state(
         // distilled models require the "no_timestamps" token
         // TODO: add input parameter (#1229)
         if (is_distil) {
-            log("%s: using distilled model - forcing no_timestamps\n", __func__);
+            WHISPER_LOG_WARN("%s: using distilled model - forcing no_timestamps\n", __func__);
             prompt_init.push_back(whisper_token_not(ctx));
         }
     }
@@ -4698,14 +4651,14 @@ int whisper_full_with_state(
 
         if (params.encoder_begin_callback) {
             if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) {
-                log("%s: encoder_begin_callback returned false - aborting\n", __func__);
+                WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__);
                 break;
             }
         }
 
         // encode audio features starting at offset seek
         if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to encode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to encode\n", __func__);
             return -6;
         }
@@ -4788,7 +4741,7 @@ int whisper_full_with_state(
         WHISPER_PRINT_DEBUG("\n\n");
 
         if (!whisper_decode_internal(*ctx, *state, state->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-            log("%s: failed to decode\n", __func__);
+            WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
             return -7;
         }
@@ -5012,7 +4965,7 @@ int whisper_full_with_state(
                 //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta);
 
                 if (!whisper_decode_internal(*ctx, *state, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
-                    log("%s: failed to decode\n", __func__);
+                    WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
                     return -8;
                 }
@@ -5338,12 +5291,12 @@ int whisper_full_parallel(
     ctx->state->t_decode_us /= n_processors;
 
     // print information about the audio boundaries
-    log("\n");
-    log("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
+    WHISPER_LOG_WARN("\n");
+    WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors);
     for (int i = 0; i < n_processors - 1; ++i) {
-        log("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
+        WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str());
     }
-    log("%s: the transcription quality may be degraded near these boundaries\n", __func__);
+    WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
 
     return ret;
 }
@@ -5708,7 +5661,7 @@ static void whisper_exp_compute_token_level_timestamps(
     const int n_samples = state.energy.size();
 
     if (n_samples == 0) {
-        log("%s: no signal data available\n", __func__);
+        WHISPER_LOG_ERROR("%s: no signal data available\n", __func__);
        return;
    }
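
// Note: the hunk around line -4698 shows encoder_begin_callback aborting a run
// when it returns false. A minimal cancellation hook built on that behavior,
// sketched under the assumption of the whisper_encoder_begin_callback signature
// from whisper.h; the atomic flag and its wiring are illustrative.

#include <atomic>
#include "whisper.h"

static std::atomic<bool> g_cancel{false};

static bool encoder_begin_cb(struct whisper_context * /*ctx*/,
                             struct whisper_state   * /*state*/,
                             void * user_data) {
    // returning false takes the "encoder_begin_callback returned false - aborting" path
    return !static_cast<std::atomic<bool> *>(user_data)->load();
}

// usage sketch:
//   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//   wparams.encoder_begin_callback           = encoder_begin_cb;
//   wparams.encoder_begin_callback_user_data = &g_cancel;
//   // from another thread: g_cancel = true;
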
@@ -5929,6 +5882,38 @@ static void whisper_exp_compute_token_level_timestamps(
     //}
 }
 
-void whisper_set_log_callback(whisper_log_callback callback) {
-    whisper_log = callback;
+void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
+    g_state.log_callback_user_data = user_data;
+}
+
+static void whisper_log_internal_v(ggml_log_level level, const char * format, va_list args) {
+    va_list args_copy;
+    va_copy(args_copy, args);
+    char buffer[128];
+    int len = vsnprintf(buffer, 128, format, args);
+    if (len < 128) {
+        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+    } else {
+        char* buffer2 = new char[len+1];
+        vsnprintf(buffer2, len+1, format, args_copy);
+        buffer2[len] = 0;
+        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        delete[] buffer2;
+    }
+    va_end(args_copy);
+}
+
+static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    whisper_log_internal_v(level, format, args);
+    va_end(args);
+}
+
+static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
 }
diff --git a/whisper.h b/whisper.h
index ed1612b4..0ea5237e 100644
--- a/whisper.h
+++ b/whisper.h
@@ -1,6 +1,8 @@
 #ifndef WHISPER_H
 #define WHISPER_H
 
+#include "ggml.h"
+
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
@@ -110,15 +112,15 @@ extern "C" {
     // Various functions for loading a ggml whisper model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     // These are the same as the above, but the internal state of the context is not allocated automatically
     // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523)
-    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);
-    WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state  (const char * path_model,              struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size,    struct whisper_context_params params);
+    WHISPER_API struct whisper_context * whisper_init_with_params_no_state            (struct whisper_model_loader * loader, struct whisper_context_params params);
 
     WHISPER_DEPRECATED(
         WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model),
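
// Note: the header hunk above only re-aligns the declarations, but the
// "_no_state" family it touches is worth a usage reminder. A sketch of the
// caller-managed-state path, assuming 16 kHz mono float samples in `pcm`;
// error handling is omitted and the model path is illustrative.

#include <vector>
#include "whisper.h"

static void transcribe_with_own_state(const std::vector<float> & pcm) {
    struct whisper_context_params cparams = whisper_context_default_params();

    struct whisper_context * wctx   = whisper_init_from_file_with_params_no_state("ggml-base.en.bin", cparams);
    struct whisper_state   * wstate = whisper_init_state(wctx); // caller-allocated state (#523)

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    whisper_full_with_state(wctx, wstate, wparams, pcm.data(), (int) pcm.size());

    whisper_free_state(wstate);
    whisper_free(wctx);
}
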
@@ -570,8 +572,7 @@ extern "C" {
 
     // Control logging output; default behavior is to print to stderr
-    typedef void (*whisper_log_callback)(const char * line);
-    WHISPER_API void whisper_set_log_callback(whisper_log_callback callback);
+    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
 
 #ifdef __cplusplus
 }
 #endif
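
// Note: migration sketch for the logging API change above. The old per-line
// callback is replaced by a ggml-style callback that also carries a log level
// and a user_data pointer; passing nullptr restores the default stderr callback
// (see whisper_log_set in the whisper.cpp hunk). The filtering policy here is
// illustrative only.

#include <cstdio>
#include "whisper.h"

static void my_log_cb(enum ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    if (level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_WARN) {
        fputs(text, stderr); // drop INFO chatter (timings, auto-detected language, ...)
    }
}

// before: whisper_set_log_callback(my_old_cb);
// after:  whisper_log_set(my_log_cb, nullptr);   // call before whisper_init_*()
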