mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-04 16:20:42 +02:00
This commit updates the command.wasm example by adding a server.py script to make it easy to start a local http server to try out the example, updates the build instructions, and also addresses some of the compiler warnings that were being generated.
* emscripten : fix TOTAL_STACK for wasm
This commit moves the TOTAL_STACK setting from the compile flags to the
linker flags. This is because the TOTAL_STACK setting is a linker
setting.
The motivation for this change is that currently the following warnings
are generated when building:
```console
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'TOTAL_STACK' [-Wunused-command-line-argument]
```
* examples : suppress C++17 deprecation warning for std::codecvt_utf8
This commit suppresses the C++17 deprecation warning for
std::codecvt_utf8 similar to what is done in
examples/talk-llama/unicode.cpp.
The motivation for this change is to suppress these warnings:
```console
/Users/danbev/work/ai/whisper-work/examples/common.cpp:251:31: warning: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]
251 | std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/codecvt:193:28: note: 'codecvt_utf8<wchar_t>' has been explicitly marked deprecated here
193 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 codecvt_utf8 : public __codecvt_utf8<_Elem> {
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
723 | # define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
688 | # define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
| ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:251:10: warning: 'wstring_convert<std::codecvt_utf8<wchar_t>>' is deprecated [-Wdeprecated-declarations]
251 | std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/locale:3145:28: note: 'wstring_convert<std::codecvt_utf8<wchar_t>>' has been explicitly marked deprecated here
3145 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 wstring_convert {
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
723 | # define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
688 | # define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
| ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:257:31: warning: 'codecvt_utf8<wchar_t>' is deprecated [-Wdeprecated-declarations]
257 | std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/codecvt:193:28: note: 'codecvt_utf8<wchar_t>' has been explicitly marked deprecated here
193 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 codecvt_utf8 : public __codecvt_utf8<_Elem> {
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
723 | # define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
688 | # define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
| ^
/Users/danbev/work/ai/whisper-work/examples/common.cpp:257:10: warning: 'wstring_convert<std::codecvt_utf8<wchar_t>>' is deprecated [-Wdeprecated-declarations]
257 | std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/locale:3145:28: note: 'wstring_convert<std::codecvt_utf8<wchar_t>>' has been explicitly marked deprecated here
3145 | class _LIBCPP_TEMPLATE_VIS _LIBCPP_DEPRECATED_IN_CXX17 wstring_convert {
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:723:41: note: expanded from macro '_LIBCPP_DEPRECATED_IN_CXX17'
723 | # define _LIBCPP_DEPRECATED_IN_CXX17 _LIBCPP_DEPRECATED
| ^
/Users/danbev/work/wasm/emsdk/upstream/emscripten/cache/sysroot/include/c++/v1/__config:688:49: note: expanded from macro '_LIBCPP_DEPRECATED'
688 | # define _LIBCPP_DEPRECATED __attribute__((__deprecated__))
| ^
4 warnings generated.
```
* ggml : suppress double-promotion warning in GGML_F16x4_REDUCE
This commit adds a cast to `ggml_float` in the `GGML_F16x4_REDUCE` macro
to suppress a double-promotion warning.
Currently the following warning is generated when compiling the
command.wasm example:
```console
/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:1592:5: warning: implicit conversion increases floating-point precision: 'float' to 'ggml_float' (aka 'double') [-Wdouble-promotion]
1592 | GGML_F16_VEC_REDUCE(sumf, sum);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/danbev/work/ai/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:932:37: note: expanded from macro 'GGML_F16_VEC_REDUCE'
932 | #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
| ^
/Users/danbev/work/ai/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:920:44: note: expanded from macro 'GGML_F16x4_REDUCE'
918 | res = wasm_f32x4_extract_lane(x[0], 0) + \
| ~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
919 | wasm_f32x4_extract_lane(x[0], 1) + \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
920 | wasm_f32x4_extract_lane(x[0], 2) + \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
921 | wasm_f32x4_extract_lane(x[0], 3); \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:1640:9: warning: implicit conversion increases floating-point precision: 'float' to 'ggml_float' (aka 'double') [-Wdouble-promotion]
1640 | GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/Users/danbev/work/ai/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:932:37: note: expanded from macro 'GGML_F16_VEC_REDUCE'
932 | #define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE
| ^
/Users/danbev/work/ai/whisper-work/ggml/src/ggml-cpu/ggml-cpu.c:920:44: note: expanded from macro 'GGML_F16x4_REDUCE'
918 | res = wasm_f32x4_extract_lane(x[0], 0) + \
| ~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
919 | wasm_f32x4_extract_lane(x[0], 1) + \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
920 | wasm_f32x4_extract_lane(x[0], 2) + \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~
921 | wasm_f32x4_extract_lane(x[0], 3); \
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2 warnings generated.
```
wasm_f32x4_extract_lane returns a 32-bit float and this is what the
addition is performed on. But there is an implicit conversion from
32-bit float to 64-bit double when the result is assigned to `res`,
which is of type `ggml_float`. My understanding here is that this is
intentional and adding a cast to `ggml_float` should suppress the
warning.
* emscripten : add -Wno-deprecated for emscripten
This commit adds -Wno-deprecated to the CMAKE_CXX_FLAGS for emscripten
builds.
The motivation for this is that currently there are a number of warnings
generated like the following:
```console
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
warning: JS library symbol '$print' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
warning: JS library symbol '$printErr' is deprecated. Please open a bug if you have a continuing need for this symbol [-Wdeprecated]
em++: warning: warnings in JS library compilation [-Wjs-compiler]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
em++: warning: linker setting ignored during compilation: 'ENVIRONMENT' [-Wunused-command-line-argument]
```
The downside of this is that we might miss other deprecation warnings
in the future, so I'm not sure if this is acceptable. But it makes the
wasm examples cleaner without the warnings.
* examples : fix tautological-compare warning in stb_vorbis.c [no ci]
This commit applies a fix to address a tautological-compare warning
in stb_vorbis.c.
The motivation for this is that currently the following warning is
generated when compiling the command.wasm example:
```console
/Users/danbev/work/ai/whisper-work/examples/stb_vorbis.c:1404:75: warning: pointer comparison always evaluates to false [-Wtautological-compare]
1404 | if (f->stream_start + loc >= f->stream_end || f->stream_start + loc < f->stream_start) {
| ^
1 warning generated.
```
This fix was taken from an open pull request on the stb repository
that addresses this issue:
https://github.com/nothings/stb/pull/1746
* squash! examples : update command.wasm instructions [no ci]
This commit adds a Python script to serve the wasm examples built
in the `build-em` directory. Initially I thought that it would be enough
to start a simple python server but I did not notice that there was an
error in the browser console when I did that:
```console
command.js:1 Uncaught (in promise) DataCloneError: Failed to execute 'postMessage' on 'Worker': SharedArrayBuffer transfer requires self.crossOriginIsolated.
at command.js:1:1206224
at new Promise (<anonymous>)
at loadWasmModuleToWorker (command.js:1:1204981)
at Array.map (<anonymous>)
at Object.loadWasmModuleToAllWorkers (command.js:1:1206428)
at command.js:1:1204318
at callRuntimeCallbacks (command.js:1:1202062)
at preRun (command.js:1:6136)
at run (command.js:1:1294094)
at removeRunDependency (command.js:1:7046)
```
We need a few cross-origin isolation headers (COOP/COEP) to be set so that
the page is crossOriginIsolated. To make this easy for users, a Python script
is added to the examples directory. It should be able to serve all the wasm
examples provided they have been built. command.wasm's README.md is updated
to reflect this change.
* examples : remove unused functions
This commit removes the unused functions convert_to_utf8 and
convert_to_wstring from examples/common.cpp.
* Revert "examples : fix tautological-compare warning in stb_vorbis.c [no ci]"
This reverts commit 8e3c47d961.
We should not make this change here and instead when the upstream PR is
merged we can sync with it.
Refs: https://github.com/ggerganov/whisper.cpp/issues/2784
680 lines
23 KiB
C++
680 lines
23 KiB
C++
#define _USE_MATH_DEFINES // for M_PI
|
|
|
|
#include "common.h"
|
|
|
|
#include <cmath>
|
|
#include <codecvt>
|
|
#include <cstring>
|
|
#include <fstream>
|
|
#include <locale>
|
|
#include <regex>
|
|
#include <sstream>
|
|
|
|
#if defined(_MSC_VER)
|
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
|
#endif
|
|
|
|
// Function to check if the next argument exists
|
|
static std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) {
|
|
if (i + 1 < argc && argv[i + 1][0] != '-') {
|
|
return argv[++i];
|
|
} else {
|
|
fprintf(stderr, "error: %s requires one argument.\n", flag.c_str());
|
|
gpt_print_usage(argc, argv, params);
|
|
exit(0);
|
|
}
|
|
}
|
|
|
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|
for (int i = 1; i < argc; i++) {
|
|
std::string arg = argv[i];
|
|
|
|
if (arg == "-s" || arg == "--seed") {
|
|
params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-t" || arg == "--threads") {
|
|
params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-p" || arg == "--prompt") {
|
|
params.prompt = get_next_arg(i, argc, argv, arg, params);
|
|
} else if (arg == "-n" || arg == "--n_predict") {
|
|
params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-np" || arg == "--n_parallel") {
|
|
params.n_parallel = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--top_k") {
|
|
params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--top_p") {
|
|
params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--temp") {
|
|
params.temp = std::stof(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--repeat-last-n") {
|
|
params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--repeat-penalty") {
|
|
params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-b" || arg == "--batch_size") {
|
|
params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-c" || arg == "--context") {
|
|
params.n_ctx= std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") {
|
|
params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "--ignore-eos") {
|
|
params.ignore_eos = true;
|
|
} else if (arg == "-m" || arg == "--model") {
|
|
params.model = get_next_arg(i, argc, argv, arg, params);
|
|
} else if (arg == "-i" || arg == "--interactive") {
|
|
params.interactive = true;
|
|
} else if (arg == "-ip" || arg == "--interactive-port") {
|
|
params.interactive = true;
|
|
params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params));
|
|
} else if (arg == "-h" || arg == "--help") {
|
|
gpt_print_usage(argc, argv, params);
|
|
exit(0);
|
|
} else if (arg == "-f" || arg == "--file") {
|
|
get_next_arg(i, argc, argv, arg, params);
|
|
std::ifstream file(argv[i]);
|
|
if (!file) {
|
|
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
|
break;
|
|
}
|
|
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
|
|
if (params.prompt.back() == '\n') {
|
|
params.prompt.pop_back();
|
|
}
|
|
} else if (arg == "-tt" || arg == "--token_test") {
|
|
params.token_test = get_next_arg(i, argc, argv, arg, params);
|
|
}
|
|
else {
|
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
|
gpt_print_usage(argc, argv, params);
|
|
exit(0);
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
|
fprintf(stderr, "\n");
|
|
fprintf(stderr, "options:\n");
|
|
fprintf(stderr, " -h, --help show this help message and exit\n");
|
|
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
|
|
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
|
|
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
|
|
fprintf(stderr, " prompt to start generation with (default: random)\n");
|
|
fprintf(stderr, " -f FNAME, --file FNAME\n");
|
|
fprintf(stderr, " load prompt from a file\n");
|
|
fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n");
|
|
fprintf(stderr, " test tokenization\n");
|
|
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
|
|
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
|
|
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
|
|
fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp);
|
|
fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
|
|
fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
|
|
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
|
fprintf(stderr, " -c N, --context N context / KV cache size (default: %d)\n", params.n_ctx);
|
|
fprintf(stderr, " --ignore-eos ignore EOS token during generation\n");
|
|
fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers);
|
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
|
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
|
fprintf(stderr, "\n");
|
|
}
|
|
|
|
// Return one of ten fixed opening phrases, chosen by `rng() % 10`.
std::string gpt_random_prompt(std::mt19937 & rng) {
    static const char * const k_openers[10] = {
        "So", "Once upon a time", "When", "The", "After",
        "If", "import",           "He",   "She", "They",
    };

    return k_openers[rng() % 10];
}
|
|
|
|
// Return `s` with leading and trailing whitespace removed.
//
// Whitespace is the ASCII set " \t\n\v\f\r", which is what `\s` matches in
// the "C" locale. Scanning directly avoids building a std::regex on every
// call.
std::string trim(const std::string & s) {
    static const char * const k_ws = " \t\n\v\f\r";

    const size_t first = s.find_first_not_of(k_ws);
    if (first == std::string::npos) {
        return "";  // empty or all whitespace
    }

    const size_t last = s.find_last_not_of(k_ws);
    return s.substr(first, last - first + 1);
}
|
|
|
|
// Return a copy of `s` in which every non-overlapping occurrence of `from`,
// scanned left to right, is replaced by `to`. Inserted text is never
// rescanned. An empty `from` leaves `s` unchanged.
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
    // An empty pattern matches at every position; without this guard the
    // loop below would never terminate.
    if (from.empty()) {
        return s;
    }

    std::string result = s;
    size_t pos = 0;
    while ((pos = result.find(from, pos)) != std::string::npos) {
        result.replace(pos, from.length(), to);
        pos += to.length();  // continue after the inserted text
    }
    return result;
}
|
|
|
|
// Append `token` to the special-token list. gpt_tokenize() splits its input
// on these tokens (matched literally) before the regular word splitting.
void gpt_vocab::add_special_token(const std::string & token) {
    special_tokens.emplace_back(token);
}
|
|
|
|
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
|
std::map<std::string, int32_t> result;
|
|
|
|
// read file into string
|
|
std::string json;
|
|
{
|
|
std::ifstream ifs(fname);
|
|
if (!ifs) {
|
|
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
|
exit(1);
|
|
}
|
|
|
|
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
|
(std::istreambuf_iterator<char>()));
|
|
}
|
|
|
|
if (json[0] != '{') {
|
|
return result;
|
|
}
|
|
|
|
// parse json
|
|
{
|
|
bool has_key = false;
|
|
bool in_token = false;
|
|
|
|
std::string str_key = "";
|
|
std::string str_val = "";
|
|
|
|
int n = json.size();
|
|
for (int i = 1; i < n; ++i) {
|
|
if (!in_token) {
|
|
if (json[i] == ' ') continue;
|
|
if (json[i] == '"') {
|
|
in_token = true;
|
|
continue;
|
|
}
|
|
} else {
|
|
if (json[i] == '\\' && i+1 < n) {
|
|
if (has_key == false) {
|
|
str_key += json[i];
|
|
} else {
|
|
str_val += json[i];
|
|
}
|
|
++i;
|
|
} else if (json[i] == '"') {
|
|
if (has_key == false) {
|
|
has_key = true;
|
|
++i;
|
|
while (json[i] == ' ') ++i;
|
|
++i; // :
|
|
while (json[i] == ' ') ++i;
|
|
if (json[i] != '\"') {
|
|
while (json[i] != ',' && json[i] != '}') {
|
|
str_val += json[i++];
|
|
}
|
|
has_key = false;
|
|
} else {
|
|
in_token = true;
|
|
continue;
|
|
}
|
|
} else {
|
|
has_key = false;
|
|
}
|
|
|
|
str_key = ::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
|
str_key = ::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
|
str_key = ::replace(str_key, "\\\"", "\""); // \\\" -> "
|
|
|
|
try {
|
|
result[str_key] = std::stoi(str_val);
|
|
} catch (...) {
|
|
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
|
|
|
}
|
|
str_key = "";
|
|
str_val = "";
|
|
in_token = false;
|
|
continue;
|
|
}
|
|
if (has_key == false) {
|
|
str_key += json[i];
|
|
} else {
|
|
str_val += json[i];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
// Split `str` into GPT-2 style pre-tokenization pieces and append them to
// `words`. The pieces are: contractions ('s 't 're 've 'm 'll 'd), runs of
// letters, digits or other symbols (each optionally preceded by one space),
// and whitespace runs. Characters matched by no alternative are skipped.
//
// All matches are collected in one pass over the string instead of
// re-searching a freshly copied suffix after every match, which was
// quadratic in the input length. The pattern has no anchors or lookbehind,
// so the matches are the same.
void gpt_split_words(std::string str, std::vector<std::string>& words) {
    const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
    const std::regex re(pattern);

    for (std::sregex_iterator it(str.cbegin(), str.cend(), re), end; it != end; ++it) {
        for (const auto & sub : *it) {
            words.push_back(sub);
        }
    }
}
|
|
|
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
|
std::vector<std::string> words;
|
|
|
|
// first split the text into words
|
|
{
|
|
std::string str = text;
|
|
|
|
// Generate the subpattern from the special_tokens vector if it's not empty
|
|
if (!vocab.special_tokens.empty()) {
|
|
const std::regex escape(R"([\[\\\^\$\.\|\?\*\+\(\)\{\}])");
|
|
std::string special_tokens_subpattern;
|
|
for (const auto & token : vocab.special_tokens) {
|
|
if (!special_tokens_subpattern.empty()) {
|
|
special_tokens_subpattern += "|";
|
|
}
|
|
special_tokens_subpattern += std::regex_replace(token, escape, R"(\$&)");
|
|
}
|
|
|
|
std::regex re(special_tokens_subpattern);
|
|
std::smatch m;
|
|
// Split the text by special tokens.
|
|
while (std::regex_search(str, m, re)) {
|
|
// Split the substrings in-between special tokens into words.
|
|
gpt_split_words(m.prefix(), words);
|
|
// Add matched special tokens as words.
|
|
for (auto x : m) {
|
|
words.push_back(x);
|
|
}
|
|
str = m.suffix();
|
|
}
|
|
// Remaining text without special tokens will be handled below.
|
|
}
|
|
|
|
gpt_split_words(str, words);
|
|
}
|
|
|
|
// find the longest token that forms each word in words:
|
|
std::vector<gpt_vocab::id> tokens;
|
|
for (const auto & word : words) {
|
|
for (int i = 0; i < (int) word.size(); ){
|
|
for (int j = word.size() - 1; j >= i; j--){
|
|
auto cand = word.substr(i, j-i+1);
|
|
auto it = vocab.token_to_id.find(cand);
|
|
if (it != vocab.token_to_id.end()){ // word.substr(i, j-i+1) in vocab
|
|
tokens.push_back(it->second);
|
|
i = j + 1;
|
|
break;
|
|
}
|
|
else if (j == i){ // word.substr(i, 1) has no matching
|
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, word.substr(i, 1).data());
|
|
i++;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return tokens;
|
|
}
|
|
|
|
static std::vector<gpt_vocab::id> parse_tokens_from_string(const std::string& input, char delimiter) {
|
|
std::vector<gpt_vocab::id> output;
|
|
std::stringstream ss(input);
|
|
std::string token;
|
|
|
|
while (std::getline(ss, token, delimiter)) {
|
|
output.push_back(std::stoi(token));
|
|
}
|
|
|
|
return output;
|
|
}
|
|
|
|
// Load tokenizer test cases from `fpath_test`.
//
// Each line of the form "<text> => <id>,<id>,..." becomes the entry
// text -> ids. Lines without the " => " separator are skipped. An empty path
// prints a message to stderr and yields an empty map. A file that cannot be
// opened also yields an empty map, because no lines are read.
static std::map<std::string, std::vector<gpt_vocab::id>> extract_tests_from_file(const std::string & fpath_test){
    std::map<std::string, std::vector<gpt_vocab::id>> tests;

    if (fpath_test.empty()) {
        fprintf(stderr, "%s : No test file found.\n", __func__);
        return tests;
    }

    const std::string separator = " => ";
    std::ifstream fin(fpath_test, std::ios_base::in);

    for (std::string line; std::getline(fin, line); ) {
        const size_t sep_pos = line.find(separator);
        if (sep_pos == std::string::npos) {
            continue;
        }

        const std::string text     = line.substr(0, sep_pos);
        const std::string s_tokens = line.substr(sep_pos + separator.size());
        tests[text] = parse_tokens_from_string(s_tokens, ',');
    }

    return tests;
}
|
|
|
|
// Run the tokenizer test cases loaded from `fpath_test` against
// gpt_tokenize(). Each mismatch (expected vs. produced tokens) and a final
// summary line are printed to stderr. The vocabulary is not modified.
void test_gpt_tokenizer(gpt_vocab & vocab, const std::string & fpath_test){
    std::map<std::string, std::vector<gpt_vocab::id>> tests = extract_tests_from_file(fpath_test);

    // Look up a token's text without modifying the vocabulary. Using
    // id_to_token[id] would insert an empty entry for any unknown id.
    // Unknown ids print as an empty string.
    const auto token_text = [&vocab](gpt_vocab::id id) -> const char * {
        const auto it = vocab.id_to_token.find(id);
        return it == vocab.id_to_token.end() ? "" : it->second.c_str();
    };

    size_t n_fails = 0;

    for (const auto & test : tests) {
        std::vector<gpt_vocab::id> tokens = gpt_tokenize(vocab, test.first);

        if (tokens != test.second){
            n_fails++;

            // print out failure cases
            fprintf(stderr, "%s : failed test: '%s'\n", __func__, test.first.c_str());
            fprintf(stderr, "%s : tokens in hf: ", __func__);
            for (const auto & t : test.second) {
                fprintf(stderr, "%s(%d), ", token_text(t), t);
            }
            fprintf(stderr, "\n");
            fprintf(stderr, "%s : tokens in ggml: ", __func__);
            for (const auto & t : tokens) {
                fprintf(stderr, "%s(%d), ", token_text(t), t);
            }
            fprintf(stderr, "\n");
        }
    }

    fprintf(stderr, "%s : %zu tests failed out of %zu tests.\n", __func__, n_fails, tests.size());
}
|
|
|
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
|
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
|
|
|
vocab.token_to_id = ::json_parse(fname);
|
|
|
|
for (const auto & kv : vocab.token_to_id) {
|
|
vocab.id_to_token[kv.second] = kv.first;
|
|
}
|
|
|
|
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
|
|
|
// print the vocabulary
|
|
//for (auto kv : vocab.token_to_id) {
|
|
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
|
//}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Sample a token id from `logits` using temperature scaling, then top-k and
// then top-p (nucleus) filtering.
//
// logits: one value per vocabulary entry. The count is taken from
//         vocab.id_to_token.size(), which is assumed to be > 0 and to match
//         the logits buffer (NOTE(review): confirm with callers).
// top_k:  number of highest-scoring candidates kept. Values <= 0 or larger
//         than the vocabulary keep every candidate.
// top_p:  cumulative-probability cutoff applied to the top-k candidates;
//         values >= 1.0 disable it.
// temp:   logits are multiplied by 1/temp. temp <= 0 picks the first token
//         with the highest logit instead of sampling.
// rng:    random source for the final draw.
//
// Returns the chosen token id.
gpt_vocab::id gpt_sample_top_k_top_p(
        const gpt_vocab & vocab,
        const float * logits,
        int    top_k,
        double top_p,
        double temp,
        std::mt19937 & rng) {
    int n_logits = vocab.id_to_token.size();

    if (temp <= 0) {
        // Greedy. Dividing by a zero temperature would otherwise produce
        // inf/NaN probabilities.
        gpt_vocab::id max_id = 0;
        for (int i = 1; i < n_logits; ++i) {
            if (logits[i] > logits[max_id]) {
                max_id = i;
            }
        }
        return max_id;
    }

    // keep begin() + top_k inside logits_id
    if (top_k <= 0 || top_k > n_logits) {
        top_k = n_logits;
    }

    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const double scale = 1.0/temp;
        for (int i = 0; i < n_logits; ++i) {
            logits_id.push_back(std::make_pair(logits[i]*scale, i));
        }
    }

    // find the top K tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);

    // subtract the max logit before exp() to avoid overflow
    double maxl = -INFINITY;
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top K tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        // keep the shortest prefix whose cumulative probability reaches
        // top_p, then renormalize over the kept candidates
        double cumsum = 0.0f;
        for (int i = 0; i < top_k; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                top_k = i + 1;
                probs.resize(top_k);
                logits_id.resize(top_k);
                break;
            }
        }

        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}
|
|
|
|
// Sample a token id like gpt_sample_top_k_top_p(), with an added CTRL-style
// repetition penalty on the most recent tokens.
//
// last_n_tokens_data[0 .. last_n_tokens_data_size): token history, oldest
//         first (the penalty window is taken from its end).
// repeat_last_n: how many of the most recent history entries are penalized.
//         The window is clamped to the available history; <= 0 disables it.
// repeat_penalty: a penalized token's scaled logit is multiplied by this
//         value if its raw logit is negative and divided by it otherwise.
// top_k:  values <= 0 or larger than the vocabulary keep every candidate.
// temp <= 0 returns the first token with the highest raw logit. In that case
//         no penalty or sampling is applied.
//
// Assumes vocab.id_to_token.size() > 0 and matches the logits buffer
// (NOTE(review): confirm with callers).
gpt_vocab::id gpt_sample_top_k_top_p_repeat(
        const gpt_vocab & vocab,
        const float * logits,
        const int32_t * last_n_tokens_data,
        size_t last_n_tokens_data_size,
        int    top_k,
        double top_p,
        double temp,
        int    repeat_last_n,
        float  repeat_penalty,
        std::mt19937 & rng) {

    int n_logits = vocab.id_to_token.size();

    const auto * plogits = logits;

    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size);

    if (temp <= 0) {
        // select the token with the highest logit directly
        float max_logit = plogits[0];
        gpt_vocab::id max_id = 0;

        for (int i = 1; i < n_logits; ++i) {
            if (plogits[i] > max_logit) {
                max_logit = plogits[i];
                max_id = i;
            }
        }
        return max_id;
    }

    // keep begin() + top_k inside logits_id
    if (top_k <= 0 || top_k > n_logits) {
        top_k = n_logits;
    }

    // Penalty window: the last `repeat_last_n` history entries, clamped to
    // the history actually provided. Otherwise end() - repeat_last_n could
    // point before begin().
    const size_t n_window = repeat_last_n > 0
        ? std::min(static_cast<size_t>(repeat_last_n), last_n_tokens.size())
        : 0;
    const auto window_begin = last_n_tokens.end() - static_cast<std::vector<int32_t>::difference_type>(n_window);

    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const float scale = 1.0f/temp;
        for (int i = 0; i < n_logits; ++i) {
            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
            if (std::find(window_begin, last_n_tokens.end(), i) != last_n_tokens.end()) {
                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                if (plogits[i] < 0.0f) {
                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                } else {
                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                }
            } else {
                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
            }
        }
    }

    // find the top K tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);

    // subtract the max logit before exp() to avoid overflow
    double maxl = -INFINITY;
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top K tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        // keep the shortest prefix whose cumulative probability reaches
        // top_p, then renormalize over the kept candidates
        double cumsum = 0.0f;
        for (int i = 0; i < top_k; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                top_k = i + 1;
                probs.resize(top_k);
                logits_id.resize(top_k);
                break;
            }
        }

        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}
|
|
|
|
// Apply a first-order (single-pole RC) high-pass filter to `data` in place:
//
//   RC    = 1 / (2*pi*cutoff),  dt = 1 / sample_rate
//   alpha = RC / (RC + dt)
//   y[0]  = x[0]
//   y[i]  = alpha * (y[i-1] + x[i] - x[i-1])
//
// `cutoff` and `sample_rate` must use the same unit (e.g. Hz). Empty input
// is left unchanged.
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty()) {
        return;
    }

    const float rc = 1.0f / (2.0f * M_PI * cutoff);
    const float dt = 1.0f / sample_rate;
    const float alpha = rc / (rc + dt);   // high-pass coefficient (dt/(rc+dt) would be low-pass)

    float y      = data[0];
    float x_prev = data[0];   // previous *input* sample; data[i-1] already holds output

    for (size_t i = 1; i < data.size(); i++) {
        const float x = data[i];
        y = alpha * (y + x - x_prev);
        x_prev  = x;
        data[i] = y;
    }
}
|
|
|
|
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
|
const int n_samples = pcmf32.size();
|
|
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
|
|
|
if (n_samples_last >= n_samples) {
|
|
// not enough samples - assume no speech
|
|
return false;
|
|
}
|
|
|
|
if (freq_thold > 0.0f) {
|
|
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
|
}
|
|
|
|
float energy_all = 0.0f;
|
|
float energy_last = 0.0f;
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
energy_all += fabsf(pcmf32[i]);
|
|
|
|
if (i >= n_samples - n_samples_last) {
|
|
energy_last += fabsf(pcmf32[i]);
|
|
}
|
|
}
|
|
|
|
energy_all /= n_samples;
|
|
energy_last /= n_samples_last;
|
|
|
|
if (verbose) {
|
|
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
|
}
|
|
|
|
if (energy_last > vad_thold*energy_all) {
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
// Normalized Levenshtein similarity of two byte strings:
//
//   1 - edit_distance(s0, s1) / max(|s0|, |s1|)
//
// The result is in [0, 1], and 1 means the strings are identical. Two empty
// strings return 1 instead of computing 0/0 (NaN).
float similarity(const std::string & s0, const std::string & s1) {
    if (s0.empty() && s1.empty()) {
        return 1.0f;
    }

    const size_t len0 = s0.size() + 1;
    const size_t len1 = s1.size() + 1;

    // Two-row dynamic programming: prevCol holds the distances for the
    // previous prefix of s0, and col is the row being computed.
    std::vector<int> col(len1, 0);
    std::vector<int> prevCol(len1, 0);

    for (size_t i = 0; i < len1; i++) {
        prevCol[i] = (int) i;
    }

    for (size_t i = 0; i < len0; i++) {
        col[0] = (int) i;
        for (size_t j = 1; j < len1; j++) {
            const int cost = (i > 0 && s0[i - 1] == s1[j - 1]) ? 0 : 1;
            col[j] = std::min(std::min(1 + col[j - 1], 1 + prevCol[j]), prevCol[j - 1] + cost);
        }
        col.swap(prevCol);
    }

    const float dist = prevCol[len1 - 1];

    return 1.0f - (dist / std::max(s0.size(), s1.size()));
}
|
|
|
|
// Return true if `filename` can be opened for reading. A file that exists
// but cannot be opened (for example because permissions deny it) also
// returns false.
bool is_file_exist(const char * filename) {
    return std::ifstream(filename).good();
}
|