From 66fc4010c2d9cb0628bd9939c1332515fe368169 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 26 Mar 2023 18:06:31 +0300
Subject: [PATCH] WIP

---
 examples/talk.llama/.gitignore     |   1 +
 examples/talk.llama/CMakeLists.txt |   5 +-
 examples/talk.llama/speak.sh       |  10 +-
 examples/talk.llama/talk-llama.cpp | 229 ++++++++++++++++++++++-------
 whisper.cpp                        |  30 ++--
 5 files changed, 198 insertions(+), 77 deletions(-)

diff --git a/examples/talk.llama/.gitignore b/examples/talk.llama/.gitignore
index 67403ae5..6b780a24 100644
--- a/examples/talk.llama/.gitignore
+++ b/examples/talk.llama/.gitignore
@@ -1 +1,2 @@
 eleven-labs.py
+audio.mp3

diff --git a/examples/talk.llama/CMakeLists.txt b/examples/talk.llama/CMakeLists.txt
index 1ee9253b..c278deb8 100644
--- a/examples/talk.llama/CMakeLists.txt
+++ b/examples/talk.llama/CMakeLists.txt
@@ -4,10 +4,9 @@ if (WHISPER_SUPPORT_SDL2)
 
     # TODO: this is temporary
     # need to export ggml symbols for MSVC, but too lazy ..
-    add_executable(${TARGET} talk-llama.cpp llama.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
+    add_executable(${TARGET} talk-llama.cpp llama.cpp)
 
     include(DefaultTargetOptions)
 
-    target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS} ../../)
-    target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
+    target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
 endif ()

diff --git a/examples/talk.llama/speak.sh b/examples/talk.llama/speak.sh
index 1d227097..86d48843 100755
--- a/examples/talk.llama/speak.sh
+++ b/examples/talk.llama/speak.sh
@@ -7,11 +7,11 @@
 #  Mac OS: brew install espeak
 #  Linux: apt-get install espeak
 #
-espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
+#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
 
 # Eleven Labs
 #
-#wd=$(dirname $0)
-#script=$wd/eleven-labs.py
-#python3 $script $1 "$2"
-#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3
+wd=$(dirname $0)
+script=$wd/eleven-labs.py
+python3 $script $1 "$2"
+ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3

diff --git a/examples/talk.llama/talk-llama.cpp b/examples/talk.llama/talk-llama.cpp
index 260d8de6..2e4dc51f 100644
--- a/examples/talk.llama/talk-llama.cpp
+++ b/examples/talk.llama/talk-llama.cpp
@@ -168,13 +168,29 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con
 }
 
 // need to have leading ' '
-const std::string k_prompt = R"( Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+//const std::string k_prompt = R"( Transcript of a dialog, where {1} interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer {1}'s requests immediately and with precision.
+//
+//{0}: Hello, Bob.
+//{1}: Hello {0}. How may I help you today?
+//{0}:)";
 
-User: Hello, Bob.
-Bob: Hello {0}. How may I help you today?
-User: Please tell me the largest city in Europe.
-Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
-User:)";
+const std::string k_prompt = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
+{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
+There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
+The transcript only includes text, it does not include markup like HTML and Markdown.
+{1} responds with short and concise answers.
+
+{0}{4} Hello, {1}!
+{1}{4} Hello {0}! How may I help you today?
+{0}{4} What time is it?
+{1}{4} It is {2} o'clock.
+{0}{4} What year is it?
+{1}{4} We are in {3}.
+{0}{4} What is a cat?
+{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
+{0}{4} Name a color.
+{1}{4} Blue
+{0}{4})";
 
 int main(int argc, char ** argv) {
     whisper_params params;
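The new k_prompt carries five placeholders that are filled in at startup: {0} is the user name, {1} the bot name, {2} the current time, {3} the current year, and {4} the chat separator. The ::replace helper that performs the substitution is defined elsewhere in the example sources, not in this patch; a minimal sketch of the assumed behavior (replace every occurrence, scanning left to right):

    // assumed semantics of ::replace -- not part of this patch
    static std::string replace(const std::string & s, const std::string & from, const std::string & to) {
        std::string result = s;
        size_t pos = 0;
        while ((pos = result.find(from, pos)) != std::string::npos) {
            result.replace(pos, from.length(), to);
            pos += to.length();
        }
        return result;
    }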
@@ -198,7 +214,7 @@ int main(int argc, char ** argv) {
     auto lparams = llama_context_default_params();
 
     lparams.n_ctx     = 512;
-    lparams.n_parts   = 1; // TODO fix
+    lparams.n_parts   = 2; // TODO fix
     lparams.seed      = 1; // TODO fix
     lparams.f16_kv    = true;
 
@@ -242,24 +258,76 @@ int main(int argc, char ** argv) {
 
     float prob0 = 0.0f;
 
+    const std::string chat_symb = ":";
+    const std::string bot_name  = "LLAMA";
+
     std::vector<float> pcmf32_cur;
     std::vector<float> pcmf32_prompt;
 
-    std::string prompt_org = ::replace(k_prompt, "{0}", params.person);
+    std::string prompt_org = k_prompt;
+    prompt_org = ::replace(prompt_org, "{0}", params.person);
+    prompt_org = ::replace(prompt_org, "{1}", bot_name);
+
+    {
+        // get time string
+        std::string time_str;
+        {
+            time_t t = time(0);
+            struct tm * now = localtime(&t);
+            char buf[128];
+            strftime(buf, sizeof(buf), "%H:%M", now);
+            time_str = buf;
+        }
+        prompt_org = ::replace(prompt_org, "{2}", time_str);
+    }
+
+    {
+        // get year string
+        std::string year_str;
+        {
+            time_t t = time(0);
+            struct tm * now = localtime(&t);
+            char buf[128];
+            strftime(buf, sizeof(buf), "%Y", now);
+            year_str = buf;
+        }
+        prompt_org = ::replace(prompt_org, "{3}", year_str);
+    }
+
+    prompt_org = ::replace(prompt_org, "{4}", chat_symb);
+
     auto embd_inp = ::llama_tokenize(ctx_llama, prompt_org, true);
 
     const int n_ctx = llama_n_ctx(ctx_llama);
 
-    fprintf(stdout, "\n");
-    fprintf(stdout, "%s", prompt_org.c_str());
-    fflush(stdout);
-
-    //if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
-    //    fprintf(stderr, "%s : failed to eval\n", __func__);
-    //    return 1;
-    //}
+    printf("\n");
+    printf("%s : initializing - please wait ...\n", __func__);
 
-    const int n_init   = embd_inp.size();
-    const int voice_id = rand()%6;
+    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
+        fprintf(stderr, "%s : failed to eval\n", __func__);
+        return 1;
+    }
+
+    //fprintf(stdout, "\n");
+    //fprintf(stdout, "%s", prompt_org.c_str());
+    //fflush(stdout);
+
+    printf("%s : done! start speaking in the microphone\n", __func__);
+    printf("\n");
+    printf("%s%s", params.person.c_str(), chat_symb.c_str());
+    fflush(stdout);
+
+    const int n_keep   = embd_inp.size();
+    const int voice_id = 2;
+
+    int n_past = n_keep;
+    int n_prev = 64; // TODO arg
+
+    std::vector<llama_token> embd;
+
+    std::vector<std::string> antiprompts = {
+        params.person + chat_symb,
+    };
 
     // main loop
     while (is_running) {
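The hunk above tokenizes the fully substituted prompt and evaluates it once up front, so the model's KV cache is primed before the first user utterance; n_keep records the prompt length so this prefix can be preserved across later context resets. The creation of ctx_llama itself is outside the hunk; with the lparams configured above it presumably follows the llama.cpp API of this period, roughly as sketched below (the model path field name is an assumption):

    // sketch only -- the actual initialization lives elsewhere in talk-llama.cpp
    llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
    if (ctx_llama == nullptr) {
        fprintf(stderr, "%s: failed to load the LLaMA model\n", __func__);
        return 1;
    }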
@@ -279,7 +347,7 @@ int main(int argc, char ** argv) {
         audio.get(2000, pcmf32_cur);
 
         if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
-            fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
+            //fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
 
             audio.get(params.voice_ms, pcmf32_cur);
 
@@ -314,7 +382,7 @@ int main(int argc, char ** argv) {
             const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
 
             if (text_heard.empty() || tokens.empty() || force_speak) {
-                fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
+                //fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
 
                 audio.clear();
 
                 continue;
             }
 
@@ -322,53 +390,106 @@ int main(int argc, char ** argv) {
 
             force_speak = false;
 
-            fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", text_heard.c_str(), "\033[0m", (int) t_ms);
+            text_heard.insert(0, 1, ' ');
+            text_heard += "\n" + bot_name + chat_symb;
+            fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
+            fflush(stdout);
 
+            embd = ::llama_tokenize(ctx_llama, text_heard, false);
+
+            // text inference
+            bool done = false;
             std::string text_to_speak;
-            //std::string prompt_base = gpt2_get_prompt(ctx_gpt);
+            while (true) {
+                // predict
+                if (embd.size() > 0) {
+                    if (n_past + (int) embd.size() > n_ctx) {
+                        n_past = n_keep;
 
-            //std::string text_to_speak;
+                        // context overflow: keep the n_keep prompt tokens and re-evaluate
+                        // the last n_prev dialog tokens from embd_inp on top of them
+                        embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
 
-            //{
-            //    prompt_base += "B: " + text_heard + "\n";
+                        //printf("\n---\n");
+                        //printf("resetting: '");
+                        //for (int i = 0; i < (int) embd.size(); i++) {
+                        //    printf("%s", llama_token_to_str(ctx_llama, embd[i]));
+                        //}
+                        //printf("'\n");
+                        //printf("\n---\n");
+                    }
 
-            //    std::string prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
+                    if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
+                        fprintf(stderr, "%s : failed to eval\n", __func__);
+                        return 1;
+                    }
+                }
 
-            //    text_to_speak = gpt2_gen_text(ctx_gpt, prompt.c_str(), params.max_tokens);
-            //    text_to_speak = std::regex_replace(text_to_speak, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
-            //    text_to_speak = text_to_speak.substr(0, text_to_speak.find_first_of('\n'));
+                //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
 
-            //    // remove first 2 lines of base prompt
-            //    if (n_iter > 4) {
-            //        {
-            //            const size_t pos = prompt_base.find_first_of('\n');
-            //            if (pos != std::string::npos) {
-            //                prompt_base = prompt_base.substr(pos + 1);
-            //            }
-            //        }
-            //        {
-            //            const size_t pos = prompt_base.find_first_of('\n');
-            //            if (pos != std::string::npos) {
-            //                prompt_base = prompt_base.substr(pos + 1);
-            //            }
-            //        }
-            //    }
+                embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
+                n_past += embd.size();
+                embd.clear();
 
-            //    prompt_base += "A:" + text_to_speak + "\n";
+                if (done) break;
 
-            //    {
-            //        prompt = ::replace(::replace(k_prompt, "{0}", params.person), "{1}", prompt_base);
+                {
+                    // out of user input, sample next token
+                    const float top_k = 5;
+                    const float top_p = 0.80f;
+                    const float temp  = 0.30f;
+                    const float repeat_penalty = 1.1764f;
 
-            //        printf("===============\n");
-            //        printf("prompt:\n");
-            //        printf("%s\n", prompt.c_str());
-            //        printf("===============\n");
-            //    }
-            //}
+                    const int repeat_last_n = 256;
 
-            //gpt2_set_prompt(ctx_gpt, prompt_base.c_str());
+                    llama_token id = 0;
+
+                    {
+                        //auto logits = llama_get_logits(ctx_llama);
+                        //logits[llama_token_eos()] = 0;
+
+                        id = llama_sample_top_p_top_k(ctx_llama,
+                                embd_inp.data() + std::max(0, n_past - repeat_last_n),
+                                repeat_last_n, top_k, top_p, temp, repeat_penalty);
+                    }
+
+                    if (id != llama_token_eos()) {
+                        // add it to the context
+                        embd.push_back(id);
+
+                        text_to_speak += llama_token_to_str(ctx_llama, id);
+
+                        printf("%s", llama_token_to_str(ctx_llama, id));
+                    } else {
+                        // TODO
+                        printf("EOS TOKEN - SHOULD NOT HAPPEN\n");
+                        exit(0);
+                    }
+                }
+
+                {
+                    std::string last_output;
+                    for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
+                        last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
+                    }
+                    last_output += llama_token_to_str(ctx_llama, embd[0]);
+
+                    for (std::string & antiprompt : antiprompts) {
+                        if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+                            done = true;
+                            text_to_speak = ::replace(text_to_speak, antiprompt, "");
+                            fflush(stdout);
+                            break;
+                        }
+                    }
+                }
+
+                is_running = sdl_poll_events();
+
+                if (!is_running) {
+                    break;
+                }
+            }
 
-            text_to_speak = ::replace(text_to_speak, params.person + ": ", "");
             system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
 
             audio.clear();
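Two details of the generation loop above are worth spelling out. First, the rolling context: with n_ctx = 512, once n_past plus the pending tokens would overflow the window, n_past is reset to n_keep (the length of the initial prompt) and only the last n_prev = 64 dialog tokens from embd_inp are re-evaluated on top of the kept prefix, so the instruction header is never evicted from the KV cache. Second, the antiprompt check calls last_output.find() with a start offset of last_output.length() - antiprompt.length(); if last_output were ever shorter than the antiprompt, the unsigned subtraction wraps and find() simply returns npos, which is safe but easy to misread. An explicit suffix test (hypothetical helper, not in the patch) states the intent more directly:

    // explicit suffix match instead of find() with a computed start offset
    static bool ends_with(const std::string & text, const std::string & suffix) {
        return text.length() >= suffix.length() &&
               text.compare(text.length() - suffix.length(), suffix.length(), suffix) == 0;
    }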
diff --git a/whisper.cpp b/whisper.cpp
index 1ad0ce54..75c1e260 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -2514,21 +2514,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-#ifdef WHISPER_USE_COREML
-    const auto path_coreml = whisper_get_coreml_path(ctx->path_model);
-
-    fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
-    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
-
-    state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
-    if (!state->ctx_coreml) {
-        fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
-        return nullptr;
-    }
-
-    fprintf(stderr, "%s: Core ML model loaded\n", __func__);
-#endif
-
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
@@ -2548,6 +2533,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->rng = std::mt19937(0);
 
+#ifdef WHISPER_USE_COREML
+    const auto path_coreml = whisper_get_coreml_path(ctx->path_model);
+
+    fprintf(stderr, "%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str());
+    fprintf(stderr, "%s: first run on a device may take a while ...\n", __func__);
+
+    state->ctx_coreml = whisper_coreml_init(path_coreml.c_str());
+    if (!state->ctx_coreml) {
+        fprintf(stderr, "%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str());
+        return nullptr;
+    }
+
+    fprintf(stderr, "%s: Core ML model loaded\n", __func__);
+#endif
+
     return state;
 }
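The whisper.cpp change is a pure reorder: the WHISPER_USE_COREML block moves from the top of whisper_init_state() to its end, so the (potentially slow) Core ML model load now happens only after the logits and token buffers, the decoder buffers, and the RNG have been set up. The function's contract is unchanged: it still returns nullptr when the Core ML model cannot be loaded. A sketch of the calling side, assuming a whisper_context * obtained elsewhere:

    // usage is unaffected by the reorder -- the usual null check suffices
    whisper_state * state = whisper_init_state(ctx);
    if (state == nullptr) {
        fprintf(stderr, "failed to init whisper state (is the Core ML model present?)\n");
        return 1;
    }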