forked from extern/whisper.cpp
whisper : add whisper_tokenize()
Tokenizes a string into a list of vocabulary tokens
This commit is contained in:
parent
ea19ed33f1
commit
bf69b669a0
81
whisper.cpp
81
whisper.cpp
@ -14,6 +14,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
#define USE_FLASH_ATTN
|
#define USE_FLASH_ATTN
|
||||||
//#define USE_FLASH_FF
|
//#define USE_FLASH_FF
|
||||||
@ -2161,6 +2162,71 @@ static bool log_mel_spectrogram(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// split text into tokens
|
||||||
|
//
|
||||||
|
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||||
|
//
|
||||||
|
// Regex (Python):
|
||||||
|
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||||
|
//
|
||||||
|
// Regex (C++):
|
||||||
|
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||||
|
//
|
||||||
|
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
|
||||||
|
std::vector<std::string> words;
|
||||||
|
|
||||||
|
// first split the text into words
|
||||||
|
{
|
||||||
|
std::string str = text;
|
||||||
|
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||||
|
|
||||||
|
std::regex re(pat);
|
||||||
|
std::smatch m;
|
||||||
|
|
||||||
|
while (std::regex_search(str, m, re)) {
|
||||||
|
for (auto x : m) {
|
||||||
|
words.push_back(x);
|
||||||
|
}
|
||||||
|
str = m.suffix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the longest tokens that form the words:
|
||||||
|
std::vector<whisper_vocab::id> tokens;
|
||||||
|
for (const auto & word : words) {
|
||||||
|
if (word.size() == 0) continue;
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
int n = word.size();
|
||||||
|
while (i < n) {
|
||||||
|
int j = n;
|
||||||
|
while (j > i) {
|
||||||
|
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||||
|
if (it != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(it->second);
|
||||||
|
i = j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
--j;
|
||||||
|
}
|
||||||
|
if (i == n) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (j == i) {
|
||||||
|
auto sub = word.substr(i, 1);
|
||||||
|
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(vocab.token_to_id.at(sub));
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// interface implementation
|
// interface implementation
|
||||||
//
|
//
|
||||||
@ -2291,6 +2357,21 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
||||||
|
const auto res = tokenize(ctx->vocab, text);
|
||||||
|
|
||||||
|
if (res.size() > n_max_tokens) {
|
||||||
|
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < res.size(); i++) {
|
||||||
|
tokens[i] = res[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.size();
|
||||||
|
}
|
||||||
|
|
||||||
int whisper_lang_id(const char * lang) {
|
int whisper_lang_id(const char * lang) {
|
||||||
if (!g_lang.count(lang)) {
|
if (!g_lang.count(lang)) {
|
||||||
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
||||||
|
11
whisper.h
11
whisper.h
@ -139,6 +139,17 @@ extern "C" {
|
|||||||
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
|
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
|
||||||
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
|
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
|
||||||
|
|
||||||
|
// Convert the provided text into tokens.
|
||||||
|
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||||
|
// Returns the number of tokens on success, no more than n_max_tokens
|
||||||
|
// Returns -1 on failure
|
||||||
|
// TODO: not sure if correct
|
||||||
|
WHISPER_API int whisper_tokenize(
|
||||||
|
struct whisper_context * ctx,
|
||||||
|
const char * text,
|
||||||
|
whisper_token * tokens,
|
||||||
|
int n_max_tokens);
|
||||||
|
|
||||||
// Return the id of the specified language, returns -1 if not found
|
// Return the id of the specified language, returns -1 if not found
|
||||||
WHISPER_API int whisper_lang_id(const char * lang);
|
WHISPER_API int whisper_lang_id(const char * lang);
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user