From ea9f206f18d86c4eb357db9fdc52e4d9dc24435e Mon Sep 17 00:00:00 2001 From: matteng1 <31434228+matteng1@users.noreply.github.com> Date: Mon, 26 May 2025 07:57:39 +0200 Subject: [PATCH] talk-llama : fix for swedish umlauts + expose model inference settings in talk-llama.cpp (#3187) Quick fix for not removing swedish umlauts. * Update talk-llama.cpp Expose model inference settings to user instead of hard coding them. Same defaults as previous defaults. * Update examples/talk-llama/talk-llama.cpp Co-authored-by: Georgi Gerganov --- examples/talk-llama/talk-llama.cpp | 39 ++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp index 9097c491..17ae1c95 100644 --- a/examples/talk-llama/talk-llama.cpp +++ b/examples/talk-llama/talk-llama.cpp @@ -60,7 +60,13 @@ struct whisper_params { int32_t max_tokens = 32; int32_t audio_ctx = 0; int32_t n_gpu_layers = 999; - + int32_t seed = 0; + int32_t top_k = 5; + int32_t min_keep = 1; + float top_p = 0.80f; + float min_p = 0.01f; + float temp = 0.30f; + float vad_thold = 0.6f; float freq_thold = 100.0f; @@ -102,6 +108,12 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ngl" || arg == "--n-gpu-layers") { params.n_gpu_layers = std::stoi(argv[++i]); } + else if (arg == "--seed") { params.seed = std::stoi(argv[++i]); } + else if (arg == "--top-k") { params.top_k = std::stoi(argv[++i]); } + else if (arg == "--min-keep") { params.min_keep = std::stoul(argv[++i]);} + else if (arg == "--top-p") { params.top_p = std::stof(argv[++i]); } + else if (arg == "--min-p") { params.min_p = std::stof(argv[++i]); } + else if (arg == "--temp") { params.temp = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } @@ -150,6 +162,12 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ngl N, --n-gpu-layers N [%-7d] number of layers to store in VRAM\n", params.n_gpu_layers); + fprintf(stderr, " --seed N [%-7d] seed sampling\n", params.seed); + fprintf(stderr, " --top-k N [%-7d] top-k sampling (0 = disabled)\n", params.top_k); + fprintf(stderr, " --min-keep N [%-7d] minimum number of tokens to keep\n", params.min_keep); + fprintf(stderr, " --top-p N [%-7.2f] top-p sampling\n", params.top_p); + fprintf(stderr, " --min-p N [%-7.2f] min-p sampling\n", params.min_p); + fprintf(stderr, " --temp N [%-7.2f] temperature\n", params.temp); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); @@ -409,21 +427,16 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(llama_n_ctx(ctx_llama), 0, 1); // init sampler - const float top_k = 5; - const float top_p = 0.80f; - const float temp = 0.30f; - - const int seed = 0; - auto sparams = llama_sampler_chain_default_params(); llama_sampler * smpl = llama_sampler_chain_init(sparams); - if (temp > 0.0f) { - llama_sampler_chain_add(smpl, llama_sampler_init_top_k(top_k)); - llama_sampler_chain_add(smpl, llama_sampler_init_top_p(top_p, 1)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp (temp)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed)); + if (params.temp > 0.0f) { + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.top_p, params.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_min_p (params.min_p, params.min_keep)); } else { llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); } @@ -615,7 +628,7 @@ int main(int argc, char ** argv) { } // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' ' - text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), ""); + text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9åäöÅÄÖ\\.,\\?!\\s\\:\\'\\-]"), ""); // take first line text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));