mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-25 16:18:58 +01:00
whisper : add whisper_tokenize()
Tokenizes a string into a list of vocabulary tokens
This commit is contained in:
parent
ea19ed33f1
commit
bf69b669a0
81
whisper.cpp
81
whisper.cpp
@ -14,6 +14,7 @@
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <regex>
|
||||
|
||||
#define USE_FLASH_ATTN
|
||||
//#define USE_FLASH_FF
|
||||
@ -2161,6 +2162,71 @@ static bool log_mel_spectrogram(
|
||||
return true;
|
||||
}
|
||||
|
||||
// split text into tokens
|
||||
//
|
||||
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||
//
|
||||
// Regex (Python):
|
||||
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
//
|
||||
// Regex (C++):
|
||||
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||
//
|
||||
static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) {
|
||||
std::vector<std::string> words;
|
||||
|
||||
// first split the text into words
|
||||
{
|
||||
std::string str = text;
|
||||
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||
|
||||
std::regex re(pat);
|
||||
std::smatch m;
|
||||
|
||||
while (std::regex_search(str, m, re)) {
|
||||
for (auto x : m) {
|
||||
words.push_back(x);
|
||||
}
|
||||
str = m.suffix();
|
||||
}
|
||||
}
|
||||
|
||||
// find the longest tokens that form the words:
|
||||
std::vector<whisper_vocab::id> tokens;
|
||||
for (const auto & word : words) {
|
||||
if (word.size() == 0) continue;
|
||||
|
||||
int i = 0;
|
||||
int n = word.size();
|
||||
while (i < n) {
|
||||
int j = n;
|
||||
while (j > i) {
|
||||
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||
if (it != vocab.token_to_id.end()) {
|
||||
tokens.push_back(it->second);
|
||||
i = j;
|
||||
break;
|
||||
}
|
||||
--j;
|
||||
}
|
||||
if (i == n) {
|
||||
break;
|
||||
}
|
||||
if (j == i) {
|
||||
auto sub = word.substr(i, 1);
|
||||
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||
tokens.push_back(vocab.token_to_id.at(sub));
|
||||
} else {
|
||||
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||
}
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
@ -2291,6 +2357,21 @@ struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx,
|
||||
return res;
|
||||
}
|
||||
|
||||
int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
|
||||
const auto res = tokenize(ctx->vocab, text);
|
||||
|
||||
if (res.size() > n_max_tokens) {
|
||||
fprintf(stderr, "%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < res.size(); i++) {
|
||||
tokens[i] = res[i];
|
||||
}
|
||||
|
||||
return res.size();
|
||||
}
|
||||
|
||||
int whisper_lang_id(const char * lang) {
|
||||
if (!g_lang.count(lang)) {
|
||||
fprintf(stderr, "%s: unknown language '%s'\n", __func__, lang);
|
||||
|
11
whisper.h
11
whisper.h
@ -139,6 +139,17 @@ extern "C" {
|
||||
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
// Returns -1 on failure
|
||||
// TODO: not sure if correct
|
||||
WHISPER_API int whisper_tokenize(
|
||||
struct whisper_context * ctx,
|
||||
const char * text,
|
||||
whisper_token * tokens,
|
||||
int n_max_tokens);
|
||||
|
||||
// Return the id of the specified language, returns -1 if not found
|
||||
WHISPER_API int whisper_lang_id(const char * lang);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user