// Talk with AI // #include "common.h" #include "common-sdl.h" #include "whisper.h" #include "llama.h" #include #include #include #include #include #include #include #include #include std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars std::vector res(text.size() + (int)add_bos); int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); assert(n >= 0); res.resize(n); return res; } // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); int32_t voice_id = 0; int32_t voice_ms = 10000; int32_t capture_id = -1; int32_t max_tokens = 64; int32_t audio_ctx = 0; int32_t n_parts_llama = -1; float vad_thold = 0.4f; float freq_thold = 100.0f; bool speed_up = false; bool translate = false; bool print_special = false; bool print_energy = false; bool no_timestamps = true; bool verbose_prompt = false; std::string name_ni = "Georgi"; // natural intelligence std::string name_ai = "LLaMA"; // artificial intelligence std::string language = "en"; std::string model_wsp = "models/ggml-base.en.bin"; std::string model_llama = "models/ggml-llama-7B.bin"; std::string speak = "./examples/talk/speak.sh"; std::string prompt = ""; std::string fname_out; }; void whisper_print_usage(int argc, char ** argv, const whisper_params & params); bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { for (int i = 1; i < argc; i++) { std::string arg = argv[i]; if (arg == "-h" || arg == "--help") { whisper_print_usage(argc, argv, params); exit(0); } else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } else if (arg == "-vid" || arg == "--voice-id") { params.voice_id = std::stoi(argv[++i]); } else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); } else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-nni" || arg == "--name-ni") { params.name_ni = argv[++i]; } else if (arg == "-nai" || arg == "--name-ai") { params.name_ai = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; } else if (arg == "--prompt-file") { std::ifstream file(argv[++i]); std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (params.prompt.back() == '\n') { params.prompt.pop_back(); } } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params); exit(0); } } return true; } void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -vid N, --voice-id N [%-7d] voice ID\n", params.voice_id); fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms); fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); fprintf(stderr, " -nni NAME,--name-ni NAME [%-7s] natural intelligence name\n", params.name_ni.c_str()); fprintf(stderr, " -nai NAME,--name-ai NAME [%-7s] artificial intelligence name\n", params.name_ai.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); fprintf(stderr, " -ml FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str()); fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama); fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", ""); fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false"); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, "\n"); } std::string transcribe( whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, const std::string prompt_text, float & prob, int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); prob = 0.0f; t_ms = 0; std::vector prompt_tokens; whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); prompt_tokens.resize(1024); prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size())); wparams.print_progress = false; wparams.print_special = params.print_special; wparams.print_realtime = false; wparams.print_timestamps = !params.no_timestamps; wparams.translate = params.translate; wparams.no_context = true; wparams.single_segment = true; wparams.max_tokens = params.max_tokens; wparams.language = params.language.c_str(); wparams.n_threads = 2; wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; static int iter = params.voice_id; std::this_thread::sleep_for(std::chrono::milliseconds(100*iter)); iter = (iter + 1) % 4; if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { return ""; } int prob_n = 0; std::string result; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); result += text; const int n_tokens = whisper_full_n_tokens(ctx, i); for (int j = 0; j < n_tokens; ++j) { const auto token = whisper_full_get_token_data(ctx, i, j); prob += token.p; ++prob_n; } } if (prob_n > 0) { prob /= prob_n; } const auto t_end = std::chrono::high_resolution_clock::now(); t_ms = std::chrono::duration_cast(t_end - t_start).count(); return result; } const std::vector k_participants = { "LLaMA", "GGaMA", "SSaMA", "RRaMA", }; // homophones const std::map> k_homophones = { { "LLaMA", { "llama", "Llama", "LLAMA", }, }, { "GGaMA", { "gama", "Gama", "GAMA", "gamma", "Gamma", "GAMMA", }, }, { "SSaMA", { "sama", "Sama", "SAMA", "samma", "Samma", "SAMMA", }, }, { "RRaMA", { "rama", "Rama", "RAMA", "ramma", "Ramma", "RAMMA", }, }, }; const std::string k_prompt_whisper = R"(A conversation between {1}, {10}, {11}, {12} and {13}.)"; const std::map k_prompt = { { k_participants.at(0), R"(Text transcript of a never ending dialog, between {1}, {10}, {11}, {12} and {13}. There are no annotations like (30 seconds passed...) or (to himself), just what the participants say aloud to each other. The transcript only includes text, it does not include markup like HTML and Markdown. {10}, {11}, {12} and {13} respond with short and concise answers. {10} is smart, objective, honest and kind. Never fails to give a meaningful and insightful answer and opinion. {1} is leading the conversation and asking the questions. {1}{4} Hello {10}! What is your opinion on the current state of the world? {10}{4} Great question {1}! I think we live in a very interesting time. There are many things to be concerned about, but also many things to be optimistic about. {1}{4} What advice would you give to a young person who is just starting out in life? {10}{4} I would tell them to be patient and to not be afraid to fail. It is important to learn from your mistakes and to keep trying. {1}{4})" }, { k_participants.at(1), R"(Text transcript of a never ending dialog, between {1}, {10}, {11}, {12} and {13}. There are no annotations like (30 seconds passed...) or (to himself), just what the participants say aloud to each other. The transcript only includes text, it does not include markup like HTML and Markdown. {10}, {11}, {12} and {13} respond with short and concise answers. {11} has critical thinking skills, is very knowledgeable and is a good listener. He is very humble and never arrogant. {1} is leading the conversation and asking the questions. {1}{4} Hello {11}! What is your opinion on the current state of the world? {11}{4} The world is about to experience a major change. We are on the verge of a new era. {1}{4} What advice would you give to a young person who is just starting out in life? {11}{4} My advice would be to be open minded and to be willing to learn from others. {1}{4})" }, { k_participants.at(2), R"(Text transcript of a never ending dialog, between {1}, {10}, {11}, {12} and {13}. There are no annotations like (30 seconds passed...) or (to himself), just what the participants say aloud to each other. The transcript only includes text, it does not include markup like HTML and Markdown. {10}, {11}, {12} and {13} respond with short and concise answers. {12} has strong leadership skills, strategic thinking, and innovative ideas. Has the ability to mentor and support young people. {1} is leading the conversation and asking the questions. {1}{4} Hello {12}! What is your opinion on the current state of the world? {12}{4} Our future is bright. We are living in a time of great opportunity. {1}{4} What advice would you give to a young person who is just starting out in life? {12}{4} I would tell them to be brave and to be willing to take risks. {1}{4})" }, { k_participants.at(3), R"(Text transcript of a never ending dialog, between {1}, {10}, {11}, {12} and {13}. There are no annotations like (30 seconds passed...) or (to himself), just what the participants say aloud to each other. The transcript only includes text, it does not include markup like HTML and Markdown. {10}, {11}, {12} and {13} respond with short and concise answers. {13} is rude, arrogant, and has a bad attitude. He is very opinionated and never listens to others. {1} is leading the conversation and asking the questions. {1}{4} Hello {13}! What is your opinion on the current state of the world? {13}{4} The world is a terrible place. It is full of evil and corruption. {1}{4} What advice would you give to a young person who is just starting out in life? {13}{4} I would tell them to be selfish and to never trust anyone. {1}{4})" }, }; int main(int argc, char ** argv) { whisper_params params; if (whisper_params_parse(argc, argv, params) == false) { return 1; } if (whisper_lang_id(params.language.c_str()) == -1) { fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); whisper_print_usage(argc, argv, params); exit(0); } // whisper init struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str()); // llama init auto lparams = llama_context_default_params(); // tune these to your liking lparams.n_ctx = 512; lparams.seed = 1; lparams.f16_kv = true; lparams.n_parts = params.n_parts_llama; struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams); // print some info about the processing { fprintf(stderr, "\n"); if (!whisper_is_multilingual(ctx_wsp)) { if (params.language != "en" || params.translate) { params.language = "en"; params.translate = false; fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n", __func__, params.n_threads, params.language.c_str(), params.translate ? "translate" : "transcribe", params.no_timestamps ? 0 : 1); fprintf(stderr, "\n"); } // init audio audio_async audio(30*1000); if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) { fprintf(stderr, "%s: audio.init() failed!\n", __func__); return 1; } audio.resume(); int n_iter = 0; bool is_running = true; bool force_speak = false; float prob0 = 0.0f; const std::string chat_symb = ":"; const std::string name_ni = params.name_ni; const std::string name_ai = params.name_ai; // the participant that was referenced last std::string name_ref = name_ni; std::vector pcmf32_cur; std::vector pcmf32_prompt; std::string prompt_whisper = k_prompt_whisper; prompt_whisper = ::replace(prompt_whisper, "{1}", name_ni); prompt_whisper = ::replace(prompt_whisper, "{10}", k_participants.at(0)); prompt_whisper = ::replace(prompt_whisper, "{11}", k_participants.at(1)); prompt_whisper = ::replace(prompt_whisper, "{12}", k_participants.at(2)); prompt_whisper = ::replace(prompt_whisper, "{13}", k_participants.at(3)); // construct the initial prompt for LLaMA inference std::string prompt_llama = params.prompt.empty() ? k_prompt.find(name_ai)->second : params.prompt; // need to have leading ' ' prompt_llama.insert(0, 1, ' '); prompt_llama = ::replace(prompt_llama, "{1}", name_ni); prompt_llama = ::replace(prompt_llama, "{10}", k_participants.at(0)); prompt_llama = ::replace(prompt_llama, "{11}", k_participants.at(1)); prompt_llama = ::replace(prompt_llama, "{12}", k_participants.at(2)); prompt_llama = ::replace(prompt_llama, "{13}", k_participants.at(3)); { // get date string std::string date_str; { time_t t = time(0); struct tm * now = localtime(&t); char buf[128]; strftime(buf, sizeof(buf), "%d/%m/%Y", now); date_str = buf; } prompt_llama = ::replace(prompt_llama, "{1}", date_str); } { // get time string std::string time_str; { time_t t = time(0); struct tm * now = localtime(&t); char buf[128]; strftime(buf, sizeof(buf), "%H:%M", now); time_str = buf; } prompt_llama = ::replace(prompt_llama, "{2}", time_str); } { // get year string std::string year_str; { time_t t = time(0); struct tm * now = localtime(&t); char buf[128]; strftime(buf, sizeof(buf), "%Y", now); year_str = buf; } prompt_llama = ::replace(prompt_llama, "{3}", year_str); } prompt_llama = ::replace(prompt_llama, "{4}", chat_symb); // evaluate the initial prompt auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true); printf("\n"); printf("%s : initializing - please wait ...\n", __func__); if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } if (params.verbose_prompt) { fprintf(stdout, "\n"); fprintf(stdout, "%s", prompt_whisper.c_str()); fprintf(stdout, "\n"); fprintf(stdout, "\n"); fprintf(stdout, "%s", prompt_llama.c_str()); fprintf(stdout, "\n"); fprintf(stdout, "\n"); fflush(stdout); } printf("%s : done! start speaking in the microphone\n", __func__); printf("\n"); printf("%s%s", name_ni.c_str(), chat_symb.c_str()); fflush(stdout); // clear audio buffer audio.clear(); // text inference variables const int voice_id = params.voice_id; const int n_keep = embd_inp.size(); const int n_ctx = llama_n_ctx(ctx_llama); int n_past = n_keep; int n_prev = 64; // TODO arg std::vector embd; // reverse prompts for detecting when it's time to stop speaking std::vector antiprompts = { name_ni + chat_symb, }; for (const auto & p : k_participants) { antiprompts.push_back(p + chat_symb); } std::string text_heard_all; // main loop while (is_running) { // handle Ctrl + C is_running = sdl_poll_events(); if (!is_running) { break; } // delay std::this_thread::sleep_for(std::chrono::milliseconds(100)); int64_t t_ms = 0; { audio.get(15000, pcmf32_cur); if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) { //fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); audio.get(params.voice_ms, pcmf32_cur); std::string text_heard; if (!force_speak) { text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms)); } // remove text between brackets using regex { std::regex re("\\[.*?\\]"); text_heard = std::regex_replace(text_heard, re, ""); } // remove text between brackets using regex { std::regex re("\\(.*?\\)"); text_heard = std::regex_replace(text_heard, re, ""); } // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' ' text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), ""); // take first line text_heard = text_heard.substr(0, text_heard.find_first_of('\n')); // remove leading and trailing whitespace text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), ""); text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), ""); const std::vector tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false); if (text_heard.empty() || tokens.empty() || force_speak) { //fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__); audio.clear(); continue; } force_speak = false; if (text_heard[0] != ' ') { text_heard.insert(0, 1, ' '); } // replace homophones for (const auto & homophone : k_homophones) { for (const auto & word : homophone.second) { text_heard = ::replace(text_heard, word, homophone.first); } } // check which participant was mentioned const auto name_ref_old = name_ref; for (const auto & participant : k_participants) { if (participant == name_ref) { continue; } if (text_heard.find(participant) != std::string::npos) { name_ref = participant; break; } } if (name_ref == name_ref_old && name_ref != name_ai) { name_ref = name_ni; } text_heard += "\n" + name_ref + chat_symb; fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m"); fflush(stdout); text_heard_all += text_heard; // keep only last 100 characters if (text_heard_all.size() > 100) { text_heard_all = text_heard_all.substr(text_heard_all.size() - 100); } if (name_ref != name_ai) { } else { // text inference bool done = false; std::string text_to_speak; embd = ::llama_tokenize(ctx_llama, text_heard_all, false); text_heard_all.clear(); while (true) { // predict if (embd.size() > 0) { if (n_past + (int) embd.size() > n_ctx) { n_past = n_keep; // insert n_left/2 tokens at the start of embd from last_n_tokens embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end()); //printf("\n---\n"); //printf("resetting: '"); //for (int i = 0; i < (int) embd.size(); i++) { // printf("%s", llama_token_to_str(ctx_llama, embd[i])); //} //printf("'\n"); //printf("\n---\n"); } if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } } //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size()); embd_inp.insert(embd_inp.end(), embd.begin(), embd.end()); n_past += embd.size(); embd.clear(); if (done) break; { // out of user input, sample next token const float top_k = 5; const float top_p = 0.80f; const float temp = 0.20f; const float repeat_penalty = 1.0764f; const int repeat_last_n = 256; llama_token id = 0; { auto logits = llama_get_logits(ctx_llama); logits[llama_token_eos()] = 0; id = llama_sample_top_p_top_k(ctx_llama, embd_inp.data() + std::max(0, n_past - repeat_last_n), repeat_last_n, top_k, top_p, temp, repeat_penalty); } if (id != llama_token_eos()) { // add it to the context embd.push_back(id); text_to_speak += llama_token_to_str(ctx_llama, id); printf("%s", llama_token_to_str(ctx_llama, id)); } // new line if (id == 13) { } } { std::string last_output; for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) { last_output += llama_token_to_str(ctx_llama, embd_inp[i]); } last_output += llama_token_to_str(ctx_llama, embd[0]); for (const std::string & antiprompt : antiprompts) { if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { done = true; text_to_speak = ::replace(text_to_speak, antiprompt, ""); fflush(stdout); break; } } } is_running = sdl_poll_events(); if (!is_running) { break; } } text_to_speak = ::replace(text_to_speak, "\"", ""); system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str()); } audio.clear(); ++n_iter; } } } audio.pause(); whisper_print_timings(ctx_wsp); whisper_free(ctx_wsp); llama_print_timings(ctx_llama); llama_free(ctx_llama); return 0; }