diff --git a/coreml/whisper-encoder-impl.h b/coreml/whisper-encoder-impl.h
index ecb61555..7b83cd90 100644
--- a/coreml/whisper-encoder-impl.h
+++ b/coreml/whisper-encoder-impl.h
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 /**
     Make a prediction using the convenience interface
-    @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
+    @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_encoder_implOutput
 */
diff --git a/coreml/whisper-encoder.h b/coreml/whisper-encoder.h
index 84bbe416..508df7c1 100644
--- a/coreml/whisper-encoder.h
+++ b/coreml/whisper-encoder.h
@@ -3,6 +3,8 @@
 // Code is derived from the work of Github user @wangchou
 // ref: https://github.com/wangchou/callCoreMLFromCpp
 
+#include <stdint.h>
+
 #if __cplusplus
 extern "C" {
 #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
 
 void whisper_coreml_encode(
         const whisper_coreml_context * ctx,
+                             int64_t   n_ctx,
+                             int64_t   n_mel,
                                float * mel,
                                float * out);
 
diff --git a/coreml/whisper-encoder.mm b/coreml/whisper-encoder.mm
index 499edaed..8e93f180 100644
--- a/coreml/whisper-encoder.mm
+++ b/coreml/whisper-encoder.mm
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
 
 void whisper_coreml_encode(
         const whisper_coreml_context * ctx,
+                             int64_t   n_ctx,
+                             int64_t   n_mel,
                                float * mel,
                                float * out) {
     MLMultiArray * inMultiArray = [
         [MLMultiArray alloc] initWithDataPointer: mel
-                                           shape: @[@1, @80, @3000]
+                                           shape: @[@1, @(n_mel), @(n_ctx)]
                                         dataType: MLMultiArrayDataTypeFloat32
-                                         strides: @[@(240000), @(3000), @1]
+                                         strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
                                      deallocator: nil
                                            error: nil
     ];
diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 8c41ef54..af971cab 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -248,7 +248,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (whisper_lang_id(params.language.c_str()) == -1) {
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
         fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
         whisper_print_usage(argc, argv, params);
         exit(0);
diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py
index adbbd109..7e09f5ba 100644
--- a/models/convert-whisper-to-coreml.py
+++ b/models/convert-whisper-to-coreml.py
@@ -252,7 +252,7 @@ class WhisperANE(Whisper):
 def convert_encoder(hparams, model, quantize=False):
     model.eval()
 
-    input_shape = (1, 80, 3000)
+    input_shape = (1, hparams.n_mels, 3000)
     input_data = torch.randn(input_shape)
     traced_model = torch.jit.trace(model, input_data)
 
@@ -302,7 +302,7 @@ if __name__ == "__main__":
     parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
     args = parser.parse_args()
 
-    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
+    if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
         raise ValueError("Invalid model name")
 
     whisper = load_model(args.model).cpu()
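Aside, not part of the patch: the stride change in coreml/whisper-encoder.mm above follows from the dense row-major layout of the mel buffer, where each dimension's stride is the product of the extents of all dimensions to its right. That is why the literal strides @[@(240000), @(3000), @1] for shape (1, 80, 3000) generalize to @[@(n_ctx*n_mel), @(n_ctx), @1] for shape (1, n_mel, n_ctx). A minimal standalone C++ sketch, using the previously hardcoded extents as example values:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Row-major (C-contiguous) strides: each dimension's stride is the
    // product of the extents of all dimensions to its right.
    static std::vector<int64_t> contiguous_strides(const std::vector<int64_t> & shape) {
        std::vector<int64_t> strides(shape.size(), 1);
        for (int i = (int) shape.size() - 2; i >= 0; --i) {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        return strides;
    }

    int main() {
        const int64_t n_mel = 80;   // mel bins; the constant this patch stops hardcoding
        const int64_t n_ctx = 3000; // audio frames: 30 s at 100 frames per second

        // shape (1, n_mel, n_ctx) -> strides (n_mel*n_ctx, n_ctx, 1)
        for (const int64_t s : contiguous_strides({1, n_mel, n_ctx})) {
            printf("%lld ", (long long) s);
        }
        printf("\n"); // prints: 240000 3000 1
        return 0;
    }

For the old shape this reproduces exactly the constants the patch removes; for any other n_mel or frame count it yields the values the new code now computes inline.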
diff --git a/models/convert-whisper-to-openvino.py b/models/convert-whisper-to-openvino.py
index 6b3d3966..88e03ff7 100644
--- a/models/convert-whisper-to-openvino.py
+++ b/models/convert-whisper-to-openvino.py
@@ -9,7 +9,7 @@ import shutil
 def convert_encoder(hparams, encoder, mname):
     encoder.eval()
 
-    mel = torch.zeros((1, 80, 3000))
+    mel = torch.zeros((1, hparams.n_mels, 3000))
 
     onnx_folder=os.path.join(os.path.dirname(__file__),"onnx_encoder")
 
diff --git a/whisper.cpp b/whisper.cpp
index e2f7eb2a..12ba855a 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1639,7 +1639,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
         ggml_allocr_alloc(alloc, cur);
 
         if (!ggml_allocr_is_measure(alloc)) {
-            whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
+            whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
         }
 #endif
 #ifdef WHISPER_USE_OPENVINO
@@ -3708,6 +3708,7 @@ void whisper_print_timings(struct whisper_context * ctx) {
 void whisper_reset_timings(struct whisper_context * ctx) {
     ctx->t_start_us = ggml_time_us();
     if (ctx->state != nullptr) {
+        ctx->state->t_mel_us = 0;
         ctx->state->t_sample_us = 0;
         ctx->state->t_encode_us = 0;
         ctx->state->t_decode_us = 0;
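A note on the whisper.cpp call site above: ggml stores tensor extents with ne[0] as the innermost (fastest-varying) dimension, so for the mel tensor ne[0] is the number of audio frames and ne[1] the number of mel bins. That is why mel->ne[0] feeds the n_ctx parameter and mel->ne[1] feeds n_mel, despite the reversed-looking order. A self-contained C++ sketch of this mapping; fake_tensor is a hypothetical stand-in for ggml_tensor, not the real struct:

    #include <cassert>
    #include <cstdint>

    // Illustrative stand-in for ggml_tensor; only the extents matter here.
    // In ggml, ne[0] is the innermost (fastest-varying) dimension.
    struct fake_tensor {
        int64_t ne[4];
    };

    int main() {
        fake_tensor mel = {{3000, 80, 1, 1}}; // 3000 frames x 80 mel bins

        // Mapping used by the updated whisper_coreml_encode() call:
        const int64_t n_ctx = mel.ne[0]; // frames   -> first new argument
        const int64_t n_mel = mel.ne[1]; // mel bins -> second new argument

        assert(n_ctx == 3000 && n_mel == 80);
        (void) n_ctx; (void) n_mel; // keep -Wunused quiet when NDEBUG is set
        return 0;
    }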