talk-llama : add alpaca support (#668)

This commit is contained in:
Evan Jones 2023-03-29 16:01:14 -04:00 committed by GitHub
parent 42c6855103
commit a47e812a54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 6 deletions

View File

@ -0,0 +1,23 @@
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Write a text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
### Response:
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4}

View File

@ -33,6 +33,8 @@ struct whisper_params {
int32_t max_tokens = 32; int32_t max_tokens = 32;
int32_t audio_ctx = 0; int32_t audio_ctx = 0;
int32_t n_parts_llama = -1;
float vad_thold = 0.6f; float vad_thold = 0.6f;
float freq_thold = 100.0f; float freq_thold = 100.0f;
@ -41,12 +43,14 @@ struct whisper_params {
bool print_special = false; bool print_special = false;
bool print_energy = false; bool print_energy = false;
bool no_timestamps = true; bool no_timestamps = true;
bool verbose_prompt = false;
std::string person = "Georgi"; std::string person = "Georgi";
std::string language = "en"; std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin"; std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin"; std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk/speak.sh"; std::string speak = "./examples/talk/speak.sh";
std::string prompt = "";
std::string fname_out; std::string fname_out;
}; };
@ -67,15 +71,24 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "--n-parts-llama") { params.n_parts_llama = std::stoi(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "--verbose-prompt") { params.verbose_prompt = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; } else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; } else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; } else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; } else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') {
params.prompt.pop_back();
}
}
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else { else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
@ -108,7 +121,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str()); fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -mg FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str()); fprintf(stderr, " -mg FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " --n-parts-llama N [%-7d] num parts in llama model file\n", params.n_parts_llama);
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str()); fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --verbose-prompt [%-7s] print prompt at start\n", params.verbose_prompt ? "true" : "false");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -183,8 +199,7 @@ std::string transcribe(
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)"; const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
// need to have leading ' ' const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision. {1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other. There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown. The transcript only includes text, it does not include markup like HTML and Markdown.
@ -227,6 +242,7 @@ int main(int argc, char ** argv) {
lparams.n_ctx = 512; lparams.n_ctx = 512;
lparams.seed = 1; lparams.seed = 1;
lparams.f16_kv = true; lparams.f16_kv = true;
lparams.n_parts = params.n_parts_llama;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams); struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
@ -278,7 +294,10 @@ int main(int argc, char ** argv) {
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name); const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
// construct the initial prompt for LLaMA inference // construct the initial prompt for LLaMA inference
std::string prompt_llama = k_prompt_llama; std::string prompt_llama = params.prompt.empty() ? k_prompt_llama : params.prompt;
// need to have leading ' '
prompt_llama.insert(0, 1, ' ');
prompt_llama = ::replace(prompt_llama, "{0}", params.person); prompt_llama = ::replace(prompt_llama, "{0}", params.person);
prompt_llama = ::replace(prompt_llama, "{1}", bot_name); prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
@ -323,9 +342,11 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
//fprintf(stdout, "\n"); if (params.verbose_prompt) {
//fprintf(stdout, "%s", prompt_llama.c_str()); fprintf(stdout, "\n");
//fflush(stdout); fprintf(stdout, "%s", prompt_llama.c_str());
fflush(stdout);
}
printf("%s : done! start speaking in the microphone\n", __func__); printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n"); printf("\n");