talk-llama : add new example + sync ggml from llama.cpp (#664)

* talk-llama : talk with LLaMA AI

* talk-llama : disable EOS token

* talk-llama : add README instructions

* ggml : fix build in debug
Georgi Gerganov 2023-03-27 21:00:32 +03:00 committed by GitHub
parent 8e361d90d7
commit 4a0deb8b1e
14 changed files with 5061 additions and 528 deletions

.gitignore vendored (3 changes)

@ -18,6 +18,7 @@ build-sanitize-thread/
/stream
/command
/talk
/talk-llama
/bench
arm_neon.h
@ -32,3 +33,5 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
extra/bench-gg.txt
*.mlmodel*

Makefile

@ -36,7 +36,7 @@ LDFLAGS =
# ref: https://github.com/ggerganov/whisper.cpp/issues/37
ifneq ($(wildcard /usr/include/musl/*),)
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
CFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
endif
@ -178,7 +178,7 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
default: main
default: main bench
#
# Build library
@ -197,7 +197,7 @@ libwhisper.so: ggml.o whisper.o
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
clean:
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
rm -f *.o main stream command talk talk-llama bench libwhisper.a libwhisper.so
#
# Examples
@ -212,6 +212,9 @@ main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
./main -h
bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
@ -221,8 +224,8 @@ command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whi
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk-llama $(CC_SDL) $(LDFLAGS)
#
# Audio samples

examples/CMakeLists.txt

@ -63,4 +63,5 @@ else()
add_subdirectory(command)
add_subdirectory(bench)
add_subdirectory(talk)
add_subdirectory(talk-llama)
endif()

examples/talk-llama/.gitignore vendored Normal file (2 changes)

@ -0,0 +1,2 @@
eleven-labs.py
audio.mp3

examples/talk-llama/CMakeLists.txt Normal file

@ -0,0 +1,10 @@
if (WHISPER_SUPPORT_SDL2)
# talk-llama
set(TARGET talk-llama)
add_executable(${TARGET} talk-llama.cpp llama.cpp)
include(DefaultTargetOptions)
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()

examples/talk-llama/README.md Normal file

@ -0,0 +1,32 @@
# talk-llama

Talk with an LLaMA AI in your terminal

[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)

## Building

The `talk-llama` tool depends on the SDL2 library to capture audio from the microphone. You can build it like this:

```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev

# Install SDL2 on Mac OS
brew install sdl2

# Build the "talk-llama" executable
make talk-llama

# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```

- The `-mw` argument specifies the Whisper model that you would like to use. The `base` or `small` models are recommended for a real-time experience.
- The `-ml` argument specifies the LLaMA model that you would like to use. See the instructions in https://github.com/ggerganov/llama.cpp for how to obtain a `ggml`-compatible LLaMA model.

## TTS

For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine you like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use macOS's `say`, but you can use whatever you wish.

examples/talk-llama/llama.cpp Normal file (file diff suppressed because it is too large)

examples/talk-llama/llama.h Normal file (153 changes)

@ -0,0 +1,153 @@
#ifndef LLAMA_H
#define LLAMA_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
#ifdef LLAMA_SHARED
# ifdef _WIN32
# ifdef LLAMA_BUILD
# define LLAMA_API __declspec(dllexport)
# else
# define LLAMA_API __declspec(dllimport)
# endif
# else
# define LLAMA_API __attribute__ ((visibility ("default")))
# endif
#else
# define LLAMA_API
#endif
#define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
#ifdef __cplusplus
extern "C" {
#endif
//
// C interface
//
// TODO: show sample usage
//
struct llama_context;
typedef int llama_token;
typedef struct llama_token_data {
llama_token id; // token id
float p; // probability of the token
float plog; // log probability of the token
} llama_token_data;
typedef void (*llama_progress_callback)(double progress, void *ctx);
struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
int seed; // RNG seed, 0 for random
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback
void * progress_callback_user_data;
};
LLAMA_API struct llama_context_params llama_context_default_params();
// Various functions for loading a ggml llama model.
// Allocate (almost) all memory needed for the model.
// Return NULL on failure
LLAMA_API struct llama_context * llama_init_from_file(
const char * path_model,
struct llama_context_params params);
// Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx);
// TODO: not great API - very likely to change
// Returns 0 on success
LLAMA_API int llama_model_quantize(
const char * fname_inp,
const char * fname_out,
int itype,
int qk);
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls
// Returns 0 on success
LLAMA_API int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
int n_tokens,
int n_past,
int n_threads);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns a negative number on failure - the number of tokens that would have been returned
// TODO: not sure if correct
LLAMA_API int llama_tokenize(
struct llama_context * ctx,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos);
LLAMA_API int llama_n_vocab(struct llama_context * ctx);
LLAMA_API int llama_n_ctx (struct llama_context * ctx);
LLAMA_API int llama_n_embd (struct llama_context * ctx);
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token
// Rows: n_tokens
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos();
LLAMA_API llama_token llama_token_eos();
// TODO: improve the last_n_tokens interface ?
LLAMA_API llama_token llama_sample_top_p_top_k(
struct llama_context * ctx,
const llama_token * last_n_tokens_data,
int last_n_tokens_size,
int top_k,
double top_p,
double temp,
double repeat_penalty);
// Performance information
LLAMA_API void llama_print_timings(struct llama_context * ctx);
LLAMA_API void llama_reset_timings(struct llama_context * ctx);
// Print system information
LLAMA_API const char * llama_print_system_info(void);
#ifdef __cplusplus
}
#endif
#endif
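
The header above notes `// TODO: show sample usage`. As a rough illustration of how the declared C API fits together, here is a minimal sketch; the model path, thread count, and sampling parameters are placeholders and are not taken from this commit:

```cpp
// Hypothetical usage sketch: load a model, feed a prompt, and sample a few tokens.
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_context_params lparams = llama_context_default_params();
    lparams.n_ctx = 512;

    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load the model\n");
        return 1;
    }

    // tokenize the prompt (worst case: one token per character, plus BOS)
    const std::string prompt = " Hello, how are you?";
    std::vector<llama_token> tokens(prompt.size() + 1);
    const int n = llama_tokenize(ctx, prompt.c_str(), tokens.data(), tokens.size(), /*add_bos=*/true);
    if (n < 0) {
        return 1;
    }
    tokens.resize(n);

    // evaluate the prompt, then generate token by token
    int n_past = 0;
    if (llama_eval(ctx, tokens.data(), tokens.size(), n_past, /*n_threads=*/4) != 0) {
        return 1;
    }
    n_past += tokens.size();

    for (int i = 0; i < 32; ++i) {
        // sample from the logits of the last evaluated token
        const llama_token id = llama_sample_top_p_top_k(
                ctx, tokens.data(), tokens.size(),
                /*top_k=*/40, /*top_p=*/0.95, /*temp=*/0.80, /*repeat_penalty=*/1.10);

        if (id == llama_token_eos()) {
            break;
        }

        printf("%s", llama_token_to_str(ctx, id));
        fflush(stdout);

        // feed the sampled token back in for the next step
        tokens.push_back(id);
        if (llama_eval(ctx, &id, 1, n_past, /*n_threads=*/4) != 0) {
            return 1;
        }
        n_past += 1;
    }

    printf("\n");
    llama_print_timings(ctx);
    llama_free(ctx);

    return 0;
}
```

In `talk-llama.cpp` below, the same calls are driven by Whisper transcriptions instead of a fixed prompt, and the EOS token is suppressed by zeroing its logit before sampling.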

examples/talk-llama/speak.sh Executable file (20 changes)

@ -0,0 +1,20 @@
#!/bin/bash
# Usage:
# speak.sh <voice_id> <text-to-speak>
# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"
# for Mac
say "$2"
# Eleven Labs
#
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1

examples/talk-llama/talk-llama.cpp Normal file

@ -0,0 +1,529 @@
// Talk with AI
//
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "llama.h"
#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
// initialize to the number of prompt chars, since n_tokens <= n_prompt_chars
std::vector<llama_token> res(text.size() + (int)add_bos);
int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
assert(n >= 0);
res.resize(n);
return res;
}
// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t voice_ms = 10000;
int32_t capture_id = -1;
int32_t max_tokens = 32;
int32_t audio_ctx = 0;
float vad_thold = 0.6f;
float freq_thold = 100.0f;
bool speed_up = false;
bool translate = false;
bool print_special = false;
bool print_energy = false;
bool no_timestamps = true;
std::string person = "Georgi";
std::string language = "en";
std::string model_wsp = "models/ggml-base.en.bin";
std::string model_llama = "models/ggml-llama-7B.bin";
std::string speak = "./examples/talk/speak.sh";
std::string fname_out;
};
void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
else if (arg == "-vms" || arg == "--voice-ms") { params.voice_ms = std::stoi(argv[++i]); }
else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); }
else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); }
else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); }
else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); }
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; }
else if (arg == "-p" || arg == "--person") { params.person = argv[++i]; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
else if (arg == "-mw" || arg == "--model-whisper") { params.model_wsp = argv[++i]; }
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
}
return true;
}
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -vms N, --voice-ms N [%-7d] voice duration in milliseconds\n", params.voice_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false");
fprintf(stderr, " -p NAME, --person NAME [%-7s] person name (for prompt selection)\n", params.person.c_str());
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -mw FILE, --model-whisper [%-7s] whisper model file\n", params.model_wsp.c_str());
fprintf(stderr, " -mg FILE, --model-llama [%-7s] llama model file\n", params.model_llama.c_str());
fprintf(stderr, " -s FILE, --speak TEXT [%-7s] command for TTS\n", params.speak.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n");
}
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string prompt_text,
float & prob,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
prob = 0.0f;
t_ms = 0;
std::vector<whisper_token> prompt_tokens;
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));
wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}
int prob_n = 0;
std::string result;
const int n_segments = whisper_full_n_segments(ctx);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
result += text;
const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);
prob += token.p;
++prob_n;
}
}
if (prob_n > 0) {
prob /= prob_n;
}
const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
return result;
}
const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
// need to have leading ' '
const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}'s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.
{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4})";
int main(int argc, char ** argv) {
whisper_params params;
if (whisper_params_parse(argc, argv, params) == false) {
return 1;
}
if (whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
exit(0);
}
// whisper init
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
// llama init
auto lparams = llama_context_default_params();
// tune these to your liking
lparams.n_ctx = 512;
lparams.seed = 1;
lparams.f16_kv = true;
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
// print some info about the processing
{
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx_wsp)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
__func__,
params.n_threads,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1);
fprintf(stderr, "\n");
}
// init audio
audio_async audio(30*1000);
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
return 1;
}
audio.resume();
int n_iter = 0;
bool is_running = true;
bool force_speak = false;
float prob0 = 0.0f;
const std::string chat_symb = ":";
const std::string bot_name = "LLaMA";
std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
// construct the initial prompt for LLaMA inference
std::string prompt_llama = k_prompt_llama;
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
{
// get time string
std::string time_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%H:%M", now);
time_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
}
{
// get year string
std::string year_str;
{
time_t t = time(0);
struct tm * now = localtime(&t);
char buf[128];
strftime(buf, sizeof(buf), "%Y", now);
year_str = buf;
}
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
}
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);
// evaluate the initial prompt
auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);
printf("\n");
printf("%s : initializing - please wait ...\n", __func__);
if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
//fprintf(stdout, "\n");
//fprintf(stdout, "%s", prompt_llama.c_str());
//fflush(stdout);
printf("%s : done! start speaking in the microphone\n", __func__);
printf("\n");
printf("%s%s", params.person.c_str(), chat_symb.c_str());
fflush(stdout);
// clear audio buffer
audio.clear();
// text inference variables
const int voice_id = 2;
const int n_keep = embd_inp.size();
const int n_ctx = llama_n_ctx(ctx_llama);
int n_past = n_keep;
int n_prev = 64; // TODO arg
std::vector<llama_token> embd;
// reverse prompts for detecting when it's time to stop speaking
std::vector<std::string> antiprompts = {
params.person + chat_symb,
};
// main loop
while (is_running) {
// handle Ctrl + C
is_running = sdl_poll_events();
if (!is_running) {
break;
}
// delay
std::this_thread::sleep_for(std::chrono::milliseconds(100));
int64_t t_ms = 0;
{
audio.get(2000, pcmf32_cur);
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
//fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);
std::string text_heard;
if (!force_speak) {
text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
}
// remove text between brackets using regex
{
std::regex re("\\[.*?\\]");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove text between parentheses using regex
{
std::regex re("\\(.*?\\)");
text_heard = std::regex_replace(text_heard, re, "");
}
// remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");
// take first line
text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));
// remove leading and trailing whitespace
text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");
const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);
if (text_heard.empty() || tokens.empty() || force_speak) {
//fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
audio.clear();
continue;
}
force_speak = false;
text_heard.insert(0, 1, ' ');
text_heard += "\n" + bot_name + chat_symb;
fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
fflush(stdout);
embd = ::llama_tokenize(ctx_llama, text_heard, false);
// text inference
bool done = false;
std::string text_to_speak;
while (true) {
// predict
if (embd.size() > 0) {
if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep;
// re-insert the last n_prev tokens from embd_inp at the start of embd
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
// printf("%s", llama_token_to_str(ctx_llama, embd[i]));
//}
//printf("'\n");
//printf("\n---\n");
}
if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}
//printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());
embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
n_past += embd.size();
embd.clear();
if (done) break;
{
// out of user input, sample next token
const float top_k = 5;
const float top_p = 0.80f;
const float temp = 0.30f;
const float repeat_penalty = 1.1764f;
const int repeat_last_n = 256;
llama_token id = 0;
{
auto logits = llama_get_logits(ctx_llama);
logits[llama_token_eos()] = 0;
id = llama_sample_top_p_top_k(ctx_llama,
embd_inp.data() + std::max(0, n_past - repeat_last_n),
repeat_last_n, top_k, top_p, temp, repeat_penalty);
}
if (id != llama_token_eos()) {
// add it to the context
embd.push_back(id);
text_to_speak += llama_token_to_str(ctx_llama, id);
printf("%s", llama_token_to_str(ctx_llama, id));
}
}
{
std::string last_output;
for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
}
last_output += llama_token_to_str(ctx_llama, embd[0]);
for (std::string & antiprompt : antiprompts) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
done = true;
text_to_speak = ::replace(text_to_speak, antiprompt, "");
fflush(stdout);
break;
}
}
}
is_running = sdl_poll_events();
if (!is_running) {
break;
}
}
text_to_speak = ::replace(text_to_speak, "\"", "");
system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());
audio.clear();
++n_iter;
}
}
}
audio.pause();
whisper_print_timings(ctx_wsp);
whisper_free(ctx_wsp);
llama_print_timings(ctx_llama);
llama_free(ctx_llama);
return 0;
}

examples/talk/speak.sh

@ -7,7 +7,10 @@
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
#espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
# Mac OS "say" command
say "$2"
# Eleven Labs
#

ggml.c (2872 changes)

File diff suppressed because it is too large

ggml.h (27 changes)

@ -198,6 +198,8 @@ struct ggml_object;
struct ggml_context;
enum ggml_type {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -226,7 +228,9 @@ enum ggml_op {
GGML_OP_STEP,
GGML_OP_RELU,
GGML_OP_GELU,
GGML_OP_SILU,
GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM,
GGML_OP_MUL_MAT,
@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
int ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor);
size_t ggml_type_size (enum ggml_type type);
int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
size_t ggml_element_size(const struct ggml_tensor * tensor);
struct ggml_context * ggml_init(struct ggml_init_params params);
@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
bool ggml_mlock_supported(void);
bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx,
enum ggml_type type,
@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_silu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// normalize along rows
// TODO: eps is hardcoded to 1e-5 for now
struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a);
// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt(
struct ggml_opt_params params,
struct ggml_tensor * f);
//
// quantization
//
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
//
// system info
//
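
For reference on the ops declared above: `ggml_rms_norm` normalizes each row by its root mean square only, whereas `ggml_norm` also subtracts the row mean first. For a row $x$ of length $n$, the RMS variant computes roughly

$$ y_i = \frac{x_i}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}} $$

with a small hardcoded $\epsilon$ for numerical stability.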

whisper.cpp

@ -636,6 +636,8 @@ struct whisper_context {
whisper_model model;
whisper_vocab vocab;
whisper_state * state = nullptr;
std::string path_model; // populated by whisper_init_from_file()
};
template<typename T>
@ -1597,7 +1599,7 @@ static bool whisper_encode_internal(
ggml_repeat(ctx0, layer.mlp_ln_w, cur),
cur),
ggml_repeat(ctx0, layer.mlp_ln_b, cur));
}
}
#ifdef WHISPER_USE_FLASH_FF
wstate.use_buf(ctx0, 0);
@ -1637,7 +1639,7 @@ static bool whisper_encode_internal(
ggml_repeat(ctx0, layer.mlp_1_b, cur),
cur);
#endif
}
}
wstate.use_buf(ctx0, 3);
@ -1841,8 +1843,6 @@ static bool whisper_decode_internal(
// self-attention
{
wstate.use_buf(ctx0, 1);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.attn_q_w,
cur);
@ -1904,8 +1904,6 @@ static bool whisper_decode_internal(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
wstate.use_buf(ctx0, 0);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,
// KQ,
@ -1914,20 +1912,16 @@ static bool whisper_decode_internal(
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
wstate.use_buf(ctx0, 1);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
wstate.use_buf(ctx0, 0);
struct ggml_tensor * V_trans =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3);
wstate.use_buf(ctx0, 1);
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
n_state/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@ -1964,8 +1958,6 @@ static bool whisper_decode_internal(
cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
wstate.use_buf(ctx0, 1);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
ggml_mul(ctx0,
@ -1976,8 +1968,6 @@ static bool whisper_decode_internal(
// cross-attention
{
wstate.use_buf(ctx0, 0);
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
layer.cross_attn_q_w,
cur);
@ -2001,12 +1991,13 @@ static bool whisper_decode_internal(
ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
n_state/n_head, n_head, M);
struct ggml_tensor * V_trans = ggml_permute(ctx0, Vcross, 1, 2, 0, 3);
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
// ------
wstate.use_buf(ctx0, 1);
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
@ -2016,8 +2007,6 @@ static bool whisper_decode_internal(
struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
wstate.use_buf(ctx0, 0);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@ -2030,16 +2019,10 @@ static bool whisper_decode_internal(
// no masking for cross-attention
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
wstate.use_buf(ctx0, 1);
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
wstate.use_buf(ctx0, 0);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
wstate.use_buf(ctx0, 1);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
// cur = KQV_merged.contiguous().view(n_state, N)
@ -2482,7 +2465,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return nullptr;
@ -2503,7 +2485,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
state->logits_id.reserve(ctx->model.hparams.n_vocab);
@ -2554,7 +2535,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
fin->close();
};
return whisper_init_no_state(&loader);
auto ctx = whisper_init_no_state(&loader);
if (ctx) {
ctx->path_model = path_model;
}
return ctx;
}
struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {