forked from extern/whisper.cpp
Reduce memory usage even more + better sampling
- The encode/decode memory buffers are now reused
- If the 30-sec segment goes for too long without a timestamp token, we force one. This improves transcription for the large model
- Stereo support
- Add "micro-machines.wav" sample
1 changed file: main.cpp (120 changed lines)
@@ -158,11 +158,11 @@ const std::map<e_model, size_t> MEM_REQ_ENCODE_LAYER = {
 };

 const std::map<e_model, size_t> MEM_REQ_DECODE = {
-    { MODEL_TINY,   190ull*MB },
-    { MODEL_BASE,   190ull*MB },
-    { MODEL_SMALL,  190ull*MB },
-    { MODEL_MEDIUM, 200ull*MB },
-    { MODEL_LARGE,  200ull*MB },
+    { MODEL_TINY,    94ull*MB },
+    { MODEL_BASE,    96ull*MB },
+    { MODEL_SMALL,   98ull*MB },
+    { MODEL_MEDIUM, 100ull*MB },
+    { MODEL_LARGE,  102ull*MB },
 };

 const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
@@ -173,6 +173,11 @@ const std::map<e_model, size_t> MEM_REQ_DECODE_LAYER = {
     { MODEL_LARGE,  110ull*MB },
 };

+// the memory buffers used to store the model in memory and perform the inference computations
+std::vector<uint8_t> g_buf_model;
+std::vector<uint8_t> g_buf_compute;
+std::vector<uint8_t> g_buf_compute_layer;
+
 const int SAMPLE_RATE = 16000;
 const int N_FFT       = 400;
 const int N_MEL       = 80;
@@ -542,13 +547,15 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
     printf("%s: f16   = %d\n", __func__, hparams.f16);
     printf("%s: type  = %d\n", __func__, model.type);

+    g_buf_model.resize(MEM_REQ_MODEL.at(model.type));
+    g_buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
+    g_buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
+
     // this is the total memory required to run the inference
     const size_t mem_required =
-        MEM_REQ_MODEL.at(model.type) +
-        MEM_REQ_ENCODE.at(model.type) +
-        MEM_REQ_ENCODE_LAYER.at(model.type) +
-        MEM_REQ_DECODE.at(model.type) +
-        MEM_REQ_DECODE_LAYER.at(model.type);
+        g_buf_model.size() +
+        g_buf_compute.size() +
+        g_buf_compute_layer.size();

     printf("%s: mem_required  = %.2f MB\n", __func__, mem_required / 1024.0 / 1024.0);
 }
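Worth spelling out why the std::max sizing in this hunk is enough: the encoder and decoder never run at the same time, so one scratch buffer sized to the larger of the two requirements serves both passes. A minimal standalone sketch of that sizing; the encode size here is an assumed placeholder, since the MEM_REQ_ENCODE table is not part of this hunk:

// Standalone sketch: one compute buffer shared by encode and decode.
// The decode size is MODEL_LARGE from the table above; the encode size
// is a placeholder, since MEM_REQ_ENCODE is not shown in this diff.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    const size_t MB = 1024*1024;

    const size_t mem_req_encode = 320*MB; // placeholder (assumption)
    const size_t mem_req_decode = 102*MB; // MODEL_LARGE, per MEM_REQ_DECODE

    // the passes run sequentially, so the buffer only needs the maximum
    const size_t buf_compute = std::max(mem_req_encode, mem_req_decode);

    printf("g_buf_compute: %.2f MB\n", buf_compute / 1024.0 / 1024.0);
    return 0;
}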
@@ -752,8 +759,8 @@ bool whisper_model_load(const std::string & fname, whisper_model & model, whispe
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size   = ctx_size,
-            .mem_buffer = NULL,
+            .mem_size   = g_buf_model.size(),
+            .mem_buffer = g_buf_model.data(),
         };

         model.ctx = ggml_init(params);
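For readers outside the codebase: with .mem_buffer = NULL, ggml allocates mem_size bytes internally and that memory belongs to the one context; handing it a caller-owned buffer is what makes the reuse in this commit possible. A rough illustration of the caller-owned-arena idea; this is a sketch of the concept, not ggml's actual allocator:

// Conceptual sketch of a caller-owned arena: contexts carve allocations
// out of a fixed buffer, and "freeing" a context just resets the offset.
#include <cstddef>
#include <cstdint>
#include <vector>

struct arena {
    std::vector<uint8_t> buf; // caller-owned backing memory
    size_t used = 0;

    explicit arena(size_t size) : buf(size) {}

    // carve the next n bytes out of the buffer; nullptr when exhausted
    void * alloc(size_t n) {
        if (used + n > buf.size()) return nullptr;
        void * ptr = buf.data() + used;
        used += n;
        return ptr;
    }

    // resetting is O(1), which is what makes reuse across passes cheap
    void reset() { used = 0; }
};

int main() {
    arena a(1024);
    void * p = a.alloc(256); // pretend this backs one inference pass
    a.reset();               // the next pass reuses the same memory
    (void)p;
    return 0;
}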
@@ -1089,17 +1096,10 @@ bool whisper_encode(
     const int n_mels = hparams.n_mels;
     assert(mel_inp.n_mel == n_mels);

-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_ENCODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size   = buf_size,
-            .mem_buffer = buf,
-        };
-    }
+    struct ggml_init_params params = {
+        .mem_size   = g_buf_compute.size(),
+        .mem_buffer = g_buf_compute.data(),
+    };

     struct ggml_context * ctx0 = ggml_init(params);
@@ -1151,16 +1151,10 @@ bool whisper_encode(

         // create separate context for each layer to reduce memory usage

-        struct ggml_init_params paramsL;
-        {
-            static size_t buf_size = MEM_REQ_ENCODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size   = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size   = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };

         struct ggml_context * ctxL = ggml_init(paramsL);
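The per-layer pattern above keeps peak usage bounded by a single layer's scratch space: each iteration re-initializes a context at offset zero of the same buffer, so nothing accumulates across layers. A small sketch of that lifecycle, using a hypothetical layer_ctx type and a toy buffer size in place of the real MEM_REQ_*_LAYER value:

// Sketch: every layer gets a fresh context over the same scratch buffer,
// so peak memory is one layer's worth, not n_layer times that.
#include <cstddef>
#include <cstdint>
#include <vector>

struct layer_ctx { uint8_t * mem; size_t size; size_t used; };

layer_ctx layer_init(std::vector<uint8_t> & buf) {
    return { buf.data(), buf.size(), 0 }; // start each layer at offset 0
}

int main() {
    const int n_layer = 32;                              // e.g. the large encoder
    std::vector<uint8_t> g_buf_compute_layer(1024*1024); // toy size

    for (int il = 0; il < n_layer; ++il) {
        layer_ctx ctxL = layer_init(g_buf_compute_layer); // reuse, no malloc
        // ... build and evaluate this layer's graph inside ctxL ...
        (void)ctxL; // contents are dead once the layer finishes
    }
    return 0;
}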
@@ -1492,17 +1486,10 @@ bool whisper_decode(
     const int N = prompt.size();
     const int M = hparams.n_audio_ctx;

-    struct ggml_init_params params;
-
-    {
-        static size_t buf_size = MEM_REQ_DECODE.at(model.type);
-        static void * buf = malloc(buf_size);
-
-        params = {
-            .mem_size   = buf_size,
-            .mem_buffer = buf,
-        };
-    }
+    struct ggml_init_params params = {
+        .mem_size   = g_buf_compute.size(),
+        .mem_buffer = g_buf_compute.data(),
+    };

     struct ggml_context * ctx0 = ggml_init(params);
@@ -1525,17 +1512,10 @@ bool whisper_decode(
     for (int il = 0; il < n_layer; ++il) {
         const auto & layer = model.layers_decoder[il];

-        struct ggml_init_params paramsL;
-
-        {
-            static size_t buf_size = MEM_REQ_DECODE_LAYER.at(model.type);
-            static void * buf = malloc(buf_size);
-
-            paramsL = {
-                .mem_size   = buf_size,
-                .mem_buffer = buf,
-            };
-        }
+        struct ggml_init_params paramsL = {
+            .mem_size   = g_buf_compute_layer.size(),
+            .mem_buffer = g_buf_compute_layer.data(),
+        };

         struct ggml_context * ctxL = ggml_init(paramsL);
         struct ggml_cgraph gf = { .n_threads = n_threads };
@@ -1849,7 +1829,7 @@ bool whisper_decode(
 // TODO: temperature
 whisper_vocab::id whisper_sample_best(
         const whisper_vocab & vocab,
-        const float * probs) {
+        const float * probs, bool need_timestamp) {
     int n_logits = vocab.id_to_token.size();

     std::vector<std::pair<double, whisper_vocab::id>> probs_id;
@@ -1859,7 +1839,7 @@ whisper_vocab::id whisper_sample_best(
         probs_id.push_back(std::make_pair(probs[i], i));
     }

-    const int top_k = 10;
+    const int top_k = 4;

     // find the top K tokens
     std::partial_sort(
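For context, std::partial_sort is what makes the top-k cut cheap: it fully orders only the first top_k elements instead of sorting the whole vocabulary. A self-contained example of the same pattern, on toy probabilities rather than Whisper's:

// Toy top-k selection with std::partial_sort, mirroring the code above:
// only the first top_k pairs end up sorted (descending by probability).
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<double, int>> probs_id = {
        {0.05, 0}, {0.40, 1}, {0.10, 2}, {0.30, 3}, {0.15, 4},
    };

    const int top_k = 4;

    std::partial_sort(
            probs_id.begin(),
            probs_id.begin() + top_k, probs_id.end(),
            [](const auto & a, const auto & b) { return a.first > b.first; });

    for (int i = 0; i < top_k; i++) {
        printf("top %d: token %d, p = %.2f\n", i, probs_id[i].second, probs_id[i].first);
    }
    return 0;
}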
@@ -1876,6 +1856,15 @@ whisper_vocab::id whisper_sample_best(
     //    printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
     //}

+    if (need_timestamp) {
+        // at the end of the 30-second audio segment, we start giving preference to time tokens
+        for (int i = 0; i < top_k; i++) {
+            if (probs_id[i].second > vocab.token_beg + 1300 && probs_id[i].first > probs_id[0].first*0.1) {
+                return probs_id[i].second;
+            }
+        }
+    }
+
     int res = 0;
     while ((probs_id[res].second == vocab.token_sot ||
             probs_id[res].second == vocab.token_solm ||
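The token_beg + 1300 threshold is easier to read with Whisper's timestamp layout in mind: timestamp tokens start at vocab.token_beg (the <|0.00|> token) and advance in 20 ms steps, so id token_beg + 1300 corresponds to 1300 × 0.02 s = 26 s into the 30 s window. A hedged sketch of that arithmetic; the 20 ms step is Whisper's convention and the token_beg value is illustrative, neither is stated in this diff:

// Sketch: why "probs_id[i].second > vocab.token_beg + 1300" means
// "a timestamp at least ~26 s into the 30 s segment".
#include <cstdio>

int main() {
    const int token_beg = 50364;  // id of <|0.00|> (illustrative value)
    const double step_s = 0.02;   // one timestamp token per 20 ms

    const int offset = 1300;
    printf("forced timestamps start at %.1f s of 30.0 s\n", offset * step_s); // 26.0 s

    // a candidate token id is "late enough" when it encodes a time past 26 s
    const int candidate = token_beg + 1400;
    const bool late = candidate > token_beg + offset;
    printf("candidate encodes %.1f s, late = %d\n",
           (candidate - token_beg) * step_s, late);
    return 0;
}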
@@ -2136,8 +2125,8 @@ int main(int argc, char ** argv) {
         return 2;
     }

-    if (wav.channels != 1) {
-        fprintf(stderr, "%s: WAV file '%s' must be mono\n", argv[0], params.fname_inp.c_str());
+    if (wav.channels != 1 && wav.channels != 2) {
+        fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", argv[0], params.fname_inp.c_str());
         return 3;
     }
@@ -2158,8 +2147,14 @@ int main(int argc, char ** argv) {

     // convert to float
     pcmf32.resize(pcm16.size());
-    for (size_t i = 0; i < pcm16.size(); i++) {
-        pcmf32[i] = float(pcm16[i])/32768.0f;
-    }
+    if (wav.channels == 1) {
+        for (size_t i = 0; i < pcm16.size(); i++) {
+            pcmf32[i] = float(pcm16[i])/32768.0f;
+        }
+    } else {
+        for (size_t i = 0; i < pcm16.size(); i++) {
+            pcmf32[i] = float(pcm16[i*2 + 0] + pcm16[i*2 + 1])/32768.0f/2.0f;
+        }
+    }
 }
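The stereo path above averages the two interleaved channels into one mono stream before feeding the model. A standalone sketch of the same downmix, assuming pcm16 holds interleaved L/R frames; this version sizes the output to the frame count rather than the raw sample count:

// Standalone stereo-to-mono downmix: average interleaved L/R int16 samples
// into normalized floats. n_frames = pcm16.size() / 2 for stereo input.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<float> downmix_stereo(const std::vector<int16_t> & pcm16) {
    const size_t n_frames = pcm16.size() / 2;
    std::vector<float> mono(n_frames);
    for (size_t i = 0; i < n_frames; i++) {
        // sum in int32 to avoid int16 overflow, then normalize and halve
        const int32_t sum = int32_t(pcm16[2*i + 0]) + int32_t(pcm16[2*i + 1]);
        mono[i] = float(sum) / 32768.0f / 2.0f;
    }
    return mono;
}

int main() {
    const std::vector<int16_t> pcm16 = { 1000, 3000, -2000, -2000 }; // 2 frames
    const auto mono = downmix_stereo(pcm16);
    for (float s : mono) printf("%f\n", s); // 0.061035, -0.061035
    return 0;
}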
@@ -2252,6 +2247,7 @@ int main(int argc, char ** argv) {
     int seek_delta = 100*CHUNK_SIZE;
+    whisper_vocab::id last_id = 0;

     // print the prompt
     //printf("\n\n");
     //for (int i = 0; i < prompt.size(); i++) {
     //    printf("%s: prompt[%d] = %s\n", __func__, i, vocab.id_to_token[prompt[i]].c_str());
@@ -2294,7 +2290,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();

-                id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab));
+                id = whisper_sample_best(vocab, probs.data() + (probs.size() - n_vocab), result_len == 0);
                 if (i > 0) {
                     tid = whisper_sample_timestamp(vocab, probs.data() + (probs.size() - n_vocab));
                 }
@@ -2313,6 +2309,8 @@ int main(int argc, char ** argv) {
                 prompt.push_back(id);
                 result_cur.push_back({ id, seek + 2*(tid - vocab.token_beg) });

+                //printf("%s: %s\n", __func__, vocab.id_to_token[id].c_str());
+
                 // end of text token
                 if (id == vocab.token_eot) {
                     break;