Compare commits

..

9 Commits

11 changed files with 1964 additions and 782 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ build-em/
build-debug/
build-release/
build-static/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/

View File

@ -469,7 +469,8 @@ in [models](models).
- [X] .NET: | [#422](https://github.com/ggerganov/whisper.cpp/discussions/422)
- [sandrohanea/whisper.net](https://github.com/sandrohanea/whisper.net)
- [NickDarvey/whisper](https://github.com/NickDarvey/whisper)
- [ ] Python: soon | [WIP](https://github.com/ggerganov/whisper.cpp/issues/9)
- [X] Python: | [#9](https://github.com/ggerganov/whisper.cpp/issues/9)
- [stlukey/whispercpp.py](https://github.com/stlukey/whispercpp.py) (Cython)
## Examples

View File

@ -145,7 +145,15 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
var db = event.target.result;
var tx = db.transaction(['models'], 'readwrite');
var os = tx.objectStore('models');
var rq = os.put(data, url);
var rq = null;
try {
var rq = os.put(data, url);
} catch (e) {
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
cbCancel();
return;
}
rq.onsuccess = function (event) {
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
@ -180,7 +188,6 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
rq.onabort = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: abort');
cbCancel();
};
}

View File

@ -618,8 +618,6 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
}
whisper_full_cluster_segments(ctx);
}
// output stuff

View File

@ -31,9 +31,9 @@ endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \
-s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \
-s INITIAL_MEMORY=1500MB \
-s TOTAL_MEMORY=1500MB \
-s PTHREAD_POOL_SIZE_STRICT=0 \
-s INITIAL_MEMORY=2000MB \
-s TOTAL_MEMORY=2000MB \
-s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \

View File

@ -10,6 +10,12 @@ std::thread g_worker;
std::vector<struct whisper_context *> g_contexts(4, nullptr);
static inline int mpow2(int n) {
int p = 1;
while (p <= n) p *= 2;
return p/2;
}
EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_worker.joinable()) {
@ -43,7 +49,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
}
}));
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, bool translate) {
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, int nthreads, bool translate) {
if (g_worker.joinable()) {
g_worker.join();
}
@ -66,7 +72,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
params.print_special = false;
params.translate = translate;
params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
params.n_threads = std::min(nthreads, std::min(16, mpow2(std::thread::hardware_concurrency())));
params.offset_ms = 0;
std::vector<float> pcmf32;

View File

@ -40,21 +40,34 @@
Note that the computation is quite heavy and may take a few seconds to complete.<br>
The transcription results will be displayed in the text area below.<br><br>
<b>Important: your browser must support WASM SIMD instructions for this to work.</b>
<b>Important:</b>
<ul>
<li>your browser must support WASM SIMD instructions for this to work</li>
<li>quantized models are still in experimental stage (<a href="https://github.com/ggerganov/ggml/pull/27">more info</a>)</li>
<li>Firefox cannot load files larger than 256 MB - use Chrome instead</li>
</ul>
<br><br><hr>
<hr>
<div id="model">
Whisper model: <span id="model-whisper-status"></span>
Whisper models: <span id="model-whisper-status"></span><br><br>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
<button id="fetch-whisper-small" onclick="loadWhisper('small')">small (466 MB)</button>
<span id="fetch-whisper-progress"></span>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-base-en-q4_0" onclick="loadWhisper('base-en-q4_0')">base.en (4bit, 49 MB)</button>
<button id="fetch-whisper-base-q4_0" onclick="loadWhisper('base-q4_0')">base (4bit, 49 MB)</button>
<button id="fetch-whisper-small-en-q4_0" onclick="loadWhisper('small-en-q4_0')">small.en (4bit, 152 MB)</button>
<button id="fetch-whisper-small-q4_0" onclick="loadWhisper('small-q4_0')">small (4bit, 152 MB)</button><br>
<button id="fetch-whisper-medium-en-q4_0" onclick="loadWhisper('medium-en-q4_0')">medium.en (4bit, 469 MB)</button>
<button id="fetch-whisper-medium-q4_0" onclick="loadWhisper('medium-q4_0')">medium (4bit, 469 MB)</button>
<button id="fetch-whisper-large-q4_0" onclick="loadWhisper('large-q4_0')">large (4bit, 985 MB)</button>
<span id="fetch-whisper-progress"></span>
</div>
<br>
@ -161,6 +174,12 @@
<option value="yi">Yiddish</option>
</select>
</td>
<!-- Slider to select number of threads between 1 and 16 -->
<td>
Threads:
<input type="range" id="threads" name="threads" min="1" max="16" value="8" onchange="changeThreads(this.value)" />
<span id="threads-value">8</span>
</td>
<td>
<button onclick="onProcess(false);">Transcribe</button>
</td>
@ -263,11 +282,13 @@
Module.FS_createDataFile("/", fname, buf, true, true);
model_whisper = fname;
//model_whisper = fname;
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model').innerHTML = 'Model fetched: ' + model_whisper;
}
function loadFile(event, fname) {
@ -292,6 +313,15 @@
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-base-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q4_0').style.display = 'none';
document.getElementById('fetch-whisper-medium-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-large-q4_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
}
@ -304,6 +334,14 @@
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'small': 'https://whisper.ggerganov.com/ggml-model-whisper-small.bin',
'base-en-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q4_0.bin',
'base-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-base-q4_0.bin',
'small-en-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q4_0.bin',
'small-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-small-q4_0.bin',
'medium-en-q4_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q4_0.bin',
'medium-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-medium-q4_0.bin',
'large-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q4_0.bin',
};
let sizes = {
@ -313,6 +351,14 @@
'base': 142,
'small.en': 466,
'small': 466,
'base-en-q4_0': 49,
'base-q4_0': 49,
'small-en-q4_0': 152,
'small-q4_0': 152,
'medium-en-q4_0': 469,
'medium-q4_0': 469,
'large-q4_0': 985,
};
let url = urls[model];
@ -327,6 +373,15 @@
document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-base-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q4_0').style.display = 'none';
document.getElementById('fetch-whisper-medium-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-large-q4_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model;
@ -337,12 +392,22 @@
cbCancel = function() {
var el;
el = document.getElementById('fetch-whisper-tiny-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-en-q4_0'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-large-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = '';
};
@ -354,7 +419,8 @@
// audio file
//
const kMaxAudio_s = 120;
const kMaxAudio_s = 30*60;
const kMaxRecording_s = 2*60;
const kSampleRate = 16000;
window.AudioContext = window.AudioContext || window.webkitAudioContext;
@ -423,7 +489,7 @@
doRecording = false;
}
// record up to kMaxAudio_s seconds of audio from the microphone
// record up to kMaxRecording_s seconds of audio from the microphone
// check if doRecording is false every 1000 ms and stop recording if so
// update progress information
function startRecording() {
@ -479,9 +545,9 @@
printTextarea('js: audio recorded, size: ' + audio.length);
// truncate to first 30 seconds
if (audio.length > kMaxAudio_s*kSampleRate) {
audio = audio.slice(0, kMaxAudio_s*kSampleRate);
printTextarea('js: truncated audio to first ' + kMaxAudio_s + ' seconds');
if (audio.length > kMaxRecording_s*kSampleRate) {
audio = audio.slice(0, kMaxRecording_s*kSampleRate);
printTextarea('js: truncated audio to first ' + kMaxRecording_s + ' seconds');
}
setAudio(audio);
});
@ -509,24 +575,31 @@
});
}
document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxAudio_s) + '%';
document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxAudio_s).toFixed(0) + '%';
document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxRecording_s) + '%';
document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxRecording_s).toFixed(0) + '%';
}, 1000);
printTextarea('js: recording ...');
setTimeout(function() {
if (doRecording) {
printTextarea('js: recording stopped after ' + kMaxAudio_s + ' seconds');
printTextarea('js: recording stopped after ' + kMaxRecording_s + ' seconds');
stopRecording();
}
}, kMaxAudio_s*1000);
}, kMaxRecording_s*1000);
}
//
// transcribe
//
var nthreads = 8;
function changeThreads(value) {
nthreads = value;
document.getElementById('threads-value').innerHTML = nthreads;
}
function onProcess(translate) {
if (!instance) {
instance = Module.init('whisper.bin');
@ -553,7 +626,7 @@
printTextarea('');
setTimeout(function() {
var ret = Module.full_default(instance, audio, document.getElementById('language').value, translate);
var ret = Module.full_default(instance, audio, document.getElementById('language').value, nthreads, translate);
console.log('js: full_default returned: ' + ret);
if (ret) {
printTextarea("js: whisper returned: " + ret);

1979
ggml.c

File diff suppressed because it is too large Load Diff

17
ggml.h
View File

@ -198,6 +198,8 @@ struct ggml_object;
struct ggml_context;
enum ggml_type {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@ -326,7 +328,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
int ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor);
size_t ggml_type_size (enum ggml_type type);
int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
size_t ggml_element_size(const struct ggml_tensor * tensor);
struct ggml_context * ggml_init(struct ggml_init_params params);
@ -726,16 +731,6 @@ enum ggml_opt_result ggml_opt(
struct ggml_opt_params params,
struct ggml_tensor * f);
//
// Temp stuff
//
void ggml_svd_reduce_dims(
int ne0,
int ne1,
float * a,
int nd);
//
// system info
//

View File

@ -252,12 +252,34 @@ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
{ MODEL_LARGE, 9ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_MODEL = {
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
{ GGML_TYPE_F16,
{
{ MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB },
},
},
{ GGML_TYPE_Q4_0,
{
{ MODEL_TINY, 26ull*MB },
{ MODEL_BASE, 50ull*MB },
{ MODEL_SMALL, 154ull*MB },
{ MODEL_MEDIUM, 470ull*MB },
{ MODEL_LARGE, 940ull*MB },
},
},
{ GGML_TYPE_Q4_1,
{
{ MODEL_TINY, 31ull*MB },
{ MODEL_BASE, 57ull*MB },
{ MODEL_SMALL, 181ull*MB },
{ MODEL_MEDIUM, 559ull*MB },
{ MODEL_LARGE, 1122ull*MB },
},
},
};
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
@ -268,14 +290,6 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
{ MODEL_LARGE, 71ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_ENC_SELF = {
{ MODEL_TINY, 23ull*MB },
{ MODEL_BASE, 26ull*MB },
{ MODEL_SMALL, 216ull*MB },
{ MODEL_MEDIUM, 243ull*MB },
{ MODEL_LARGE, 271ull*MB },
};
static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
{ MODEL_TINY, 9ull*MB },
{ MODEL_BASE, 18ull*MB },
@ -579,7 +593,6 @@ struct whisper_context {
// cross-attention KV cache for the decoders
// shared between all decoders
whisper_kv_cache kv_cross;
whisper_kv_cache kv_enc_self;
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
@ -601,18 +614,16 @@ struct whisper_context {
mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id;
int lang_id = 0; // english by default
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
int64_t t_beg = 0;
int64_t t_last = 0;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
// [EXPERIMENTAL] speed-up techniques
int32_t exp_n_audio_ctx; // 0 - use default
std::vector<float> audio_embd;
int32_t exp_n_audio_ctx = 0; // 0 - use default
void use_buf(struct ggml_context * ctx, int i) {
#if defined(WHISPER_USE_SCRATCH)
@ -692,7 +703,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
const ggml_type wtype = cache.k->type;
WHISPER_ASSERT(wtype == cache.v->type);
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
struct ggml_init_params params;
params.mem_size = cache.buf.size();
@ -787,12 +798,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.type = e_model::MODEL_LARGE;
}
// for the big tensors, we have the option to store the data in 16-bit floats
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
wctx.wtype = GGML_TYPE_COUNT;
switch (model.hparams.f16) {
case 0: wctx.wtype = GGML_TYPE_F32; break;
case 1: wctx.wtype = GGML_TYPE_F16; break;
case 2: wctx.wtype = GGML_TYPE_Q4_0; break;
case 3: wctx.wtype = GGML_TYPE_Q4_1; break;
default:
{
fprintf(stderr, "%s: invalid model (bad f16 value %d)\n", __func__, model.hparams.f16);
return false;
}
}
const size_t scale = model.hparams.f16 ? 1 : 2;
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
@ -803,7 +827,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
fprintf(stderr, "%s: ftype = %s\n", __func__, ftype_str[model.hparams.f16]);
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
// print memory requirements
@ -814,7 +838,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH1.at (model.type) +
MEM_REQ_SCRATCH2.at (model.type) +
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
@ -830,9 +854,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// always have at least one decoder
wctx.model.buf = new std::vector<uint8_t>();
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return false;
}
@ -842,12 +866,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false;
}
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_ENC_SELF.at(model.type), wctx.kv_enc_self, wctx.wtype, model.hparams.n_audio_ctx)) {
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, GGML_TYPE_F16, model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false;
}
@ -979,92 +998,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// encoder
{
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b
ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_1_w
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_2_w
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w;
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b;
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
}
// decoder
{
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b;
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
}
// encoder layers
{
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
}
// decoder layers
{
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
//
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
}
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@ -1110,10 +1129,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
{
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@ -1329,9 +1348,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return false;
}
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t);
size_t bpe = 0;
if (nelements*bpe != ggml_nbytes(tensor)) {
switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
return false;
}
};
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false;
@ -1374,8 +1405,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
static bool whisper_encode(
whisper_context & wctx,
const int mel_offset,
const int n_threads,
bool repeat = false) {
const int n_threads) {
const int64_t t_start_us = ggml_time_us();
const auto & model = wctx.model;
@ -1407,31 +1437,13 @@ static bool whisper_encode(
const int i0 = std::min(mel_offset, mel_inp.n_len);
const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
if (repeat == false) {
for (int j = 0; j < mel_inp.n_mel; ++j) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
} else {
for (int j = 0; j < mel_inp.n_mel; ++j) {
int k = 0;
while (k < 2*n_ctx) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + k] = mel_inp.data[j*mel_inp.n_len + i];
k++;
if (k >= 2*n_ctx) {
break;
}
}
}
for (int j = 0; j < mel_inp.n_mel; ++j) {
for (int i = i0; i < i1; ++i) {
dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i];
}
}
}
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
struct ggml_tensor * cur;
// convolution + gelu
@ -1459,18 +1471,6 @@ static bool whisper_encode(
cur = ggml_gelu(ctx0, cur);
}
//{
// //printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(0*n_ctx));
// //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
// //ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
wctx.use_buf(ctx0, 3);
// ===================================================================
@ -1551,18 +1551,6 @@ static bool whisper_encode(
Vcur),
Vcur);
//{
// //printf("Kcur: %d %d %d %d, size element = %d\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2], Kcur->ne[3], ggml_element_size(Kcur));
// wctx.use_buf(ctx0, -1);
// struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
// struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
// ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//}
// ------
wctx.use_buf(ctx0, 0);
@ -1572,14 +1560,14 @@ static bool whisper_encode(
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Kcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
struct ggml_tensor * V =
@ -1589,7 +1577,7 @@ static bool whisper_encode(
Vcur,
n_state/n_head, n_head, n_ctx),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
@ -1605,7 +1593,7 @@ static bool whisper_encode(
ggml_permute(ctx0,
ggml_cpy(ctx0,
Kcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3);
// K * Q
@ -1623,7 +1611,7 @@ static bool whisper_encode(
// ggml_permute(ctx0,
// ggml_cpy(ctx0,
// Vcur,
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
// ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@ -1635,7 +1623,7 @@ static bool whisper_encode(
Vcur,
n_state/n_head, n_head, n_ctx),
0, 2, 1, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head)
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
@ -1647,18 +1635,6 @@ static bool whisper_encode(
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
{
//printf("cur: %d %d %d %d, size element = %d\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_element_size(cur));
wctx.use_buf(ctx0, -1);
struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_enc_self.k, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.k)*n_state)*(il*n_ctx));
//struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_enc_self.v, n_state*n_ctx, (ggml_element_size(wctx.kv_enc_self.v)*n_state)*(il*n_ctx));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, cur, k));
//ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}
}
// projection
@ -1705,7 +1681,7 @@ static bool whisper_encode(
wctx.use_buf(ctx0, 0);
cur = ggml_flash_ff(ctx0,
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)),
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_state, n_ctx)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
wctx.use_buf(ctx0, 0);
@ -1768,6 +1744,8 @@ static bool whisper_encode(
// run the computation
{
struct ggml_cgraph gf = {};
gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur);
ggml_graph_compute (ctx0, &gf);
@ -1789,24 +1767,6 @@ static bool whisper_encode(
// printf("\n");
//}
{
//const int i0 = std::min(mel_offset, mel_inp.n_len);
//const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len);
const int i0 = 0;
const int i1 = cur->ne[1];
//printf("i0 = %d, i1 = %d, (i1 - i0) = %d, embd size = %d\n", i0, i1, i1 - i0, cur->ne[0]);
wctx.audio_embd.clear();
wctx.audio_embd.resize(cur->ne[0], 0.0f);
for (int j = 0; j < cur->ne[0]; ++j) {
for (int i = i0; i < i1; ++i) {
wctx.audio_embd[j] += ((float *)(cur->data))[(i - i0)*cur->ne[0] + j];
}
wctx.audio_embd[j] /= (i1 - i0);
}
}
// pre-compute cross-attention memory
{
struct ggml_cgraph gf = {};
@ -3049,6 +3009,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};
switch (strategy) {
@ -3176,7 +3139,7 @@ static const std::vector<std::string> non_speech_tokens = {
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
@ -3232,6 +3195,9 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;
if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}
// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@ -3935,7 +3901,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});
unsigned int cur_c = 0;
uint32_t cur_c = 0;
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
@ -4420,7 +4386,7 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
}
int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->lang_id;
return ctx->lang_id;
}
int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) {
@ -4526,23 +4492,32 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
// put a bunch of random data in the buffer
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
for (int j = 0; j < (int) sizes.size(); j++) {
int n_q4_0 = 0;
int n_q4_1 = 0;
int n_fp16 = 0;
int n_fp32 = 0;
// GFLOPS/s
double s_q4_0 = 0.0;
double s_q4_1 = 0.0;
double s_fp16 = 0.0;
double s_fp32 = 0.0;
const size_t N = sizes[j];
for (int k = 0; k < 2; ++k) {
const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32;
for (int k = 0; k < 4; ++k) {
const ggml_type wtype =
k == 0 ? GGML_TYPE_Q4_0 :
k == 1 ? GGML_TYPE_Q4_1 :
k == 2 ? GGML_TYPE_F16 :
GGML_TYPE_F32;
double & s = k == 0 ? s_fp16 : s_fp32;
int & n = k == 0 ? n_fp16 : n_fp32;
double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_fp16 : s_fp32;
int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_fp16 : n_fp32;
struct ggml_init_params gparams = {
/*.mem_size =*/ buf.size(),
@ -4585,8 +4560,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
s = ((2.0*N*N*N*n)/tsum)*1e-9;
}
fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n",
N, N, s_fp16, n_fp16, s_fp32, n_fp32);
fprintf(stderr, "ggml_mul_mat: %4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) / Q4_1 %7.1f GFLOPS (%3d runs) / F16 %7.1f GFLOPS (%3d runs) / F32 %7.1f GFLOPS (%3d runs)\n",
N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1, s_fp16, n_fp16, s_fp32, n_fp32);
}
return 0;
@ -4893,258 +4868,3 @@ static void whisper_exp_compute_token_level_timestamps(
// }
//}
}
//
// diarization stuff
//
void whisper_full_cluster_segments(struct whisper_context * ctx) {
const int n_segments = ctx->result_all.size();
printf("%s: clustering %d segments\n", __func__, n_segments);
const auto mel_len_save = ctx->mel.n_len;
printf("%s: mel_len_save = %d\n", __func__, mel_len_save);
const int n_ctx = ctx->model.hparams.n_audio_ctx;
const int n_state = ctx->model.hparams.n_audio_state;
const int n_layer = ctx->model.hparams.n_audio_layer;
#if 0
// use the last layer of the encoder
{
std::vector<float> embd(n_segments*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
for (int j = 0; j < n_state; ++j) {
embd[i*n_state + j] = ctx->audio_embd[j];
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_state, n_segments, embd.data(), n_features);
#elif 0
// use cross kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_cross.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_cross.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#elif 0
// use conv embedding
for (int il = 0; il < 1; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(3, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#else
// use enc self kv cache of various layers
for (int il = 0; il < n_layer; ++il) {
std::vector<float> embd(n_segments*n_ctx*n_state);
for (int i = 0; i < n_segments; ++i) {
const auto & segment_i = ctx->result_all[i];
printf("%s: layer %2d, segment %3d: t0 = %7d, t1 = %7d, text = %s\n", __func__, il, i, (int) segment_i.t0, (int) segment_i.t1, segment_i.text.c_str());
ctx->mel.n_len = segment_i.t1;
whisper_encode(*ctx, segment_i.t0, 7, true);
const size_t offs = ggml_element_size(ctx->kv_enc_self.k)*(il*n_ctx*n_state);
const ggml_fp16_t * f = (const ggml_fp16_t * )((const char *) ctx->kv_enc_self.k->data + offs);
for (int j = 0; j < n_ctx*n_state; ++j) {
embd[i*n_ctx*n_state + j] = ggml_fp16_to_fp32(f[j]);
}
}
const int n_features = std::min(4, n_segments);
ggml_svd_reduce_dims(n_ctx*n_state, n_segments, embd.data(), n_features);
#endif
std::vector<std::vector<double>> features(n_segments);
for (int i = 0; i < n_segments; ++i) {
features[i].resize(n_features);
for (int j = 0; j < n_features; ++j) {
features[i][j] = embd[i*n_features + j];
}
}
// fuzzy c-means clustering
const int n_clusters = 2;
std::vector<std::vector<double>> centroids(n_clusters, std::vector<double>(n_features, 0.0));
std::vector<std::vector<double>> membership(n_segments, std::vector<double>(n_clusters, 0.0));
// initialize the centroids
for (int i = 0; i < n_clusters; ++i) {
for (int j = 0; j < n_features; ++j) {
centroids[i][j] = features[i][j];
}
}
// initialize the membership
for (int i = 0; i < n_segments; ++i) {
//membership[i][i % n_clusters] = 1.0;
//for (int j = 0; j < n_clusters; ++j) {
// membership[i][j] = rand() / (float) RAND_MAX;
//}
for (int j = 0; j < n_clusters; ++j) {
membership[i][j] = 1.0 / n_clusters;
}
}
const int niter = 10000;
// iterate
for (int i = 0; i < niter; ++i) {
// print the membership
if (i == niter - 1) {
//{
for (int i = 0; i < n_segments; ++i) {
#if 1
printf("%s: membership %3d: ", __func__, i);
for (int j = 0; j < n_clusters; ++j) {
printf("%.1f ", membership[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#else
printf("%s: features : ", __func__);
for (int j = 0; j < n_features; ++j) {
printf("%8.3f ", features[i][j]);
}
printf(" '%s'\n", ctx->result_all[i].text.c_str());
#endif
}
printf("----------------\n");
// print the centroids
for (int i = 0; i < n_clusters; ++i) {
printf("%s: centroid %d: ", __func__, i);
for (int j = 0; j < n_features; ++j) {
printf("%f ", centroids[i][j]);
}
printf("\n");
}
}
// update the membership
for (int j = 0; j < n_segments; ++j) {
for (int k = 0; k < n_clusters; ++k) {
double sum = 0.0;
for (int l = 0; l < n_clusters; ++l) {
//sum += std::pow(whisper_distance(features[j], centroids[k])/whisper_distance(features[j], centroids[l]), 2.0/(2.0 - 1.0));
double d0 = 0.0;
double d1 = 0.0;
#if 1
// use the euclidean distance
{
for (int m = 0; m < n_features; ++m) {
d0 += std::pow(features[j][m] - centroids[k][m], 2.0);
}
d0 = std::sqrt(d0);
for (int m = 0; m < n_features; ++m) {
d1 += std::pow(features[j][m] - centroids[l][m], 2.0);
}
d1 = std::sqrt(d1);
}
#else
// use the cosine distance
{
double dot = 0.0;
double norm0 = 0.0;
double norm1 = 0.0;
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[k][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[k][m], 2.0);
}
d0 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
dot = 0.0;
norm0 = 0.0;
norm1 = 0.0;
for (int m = 0; m < n_features; ++m) {
dot += features[j][m]*centroids[l][m];
norm0 += std::pow(features[j][m], 2.0);
norm1 += std::pow(centroids[l][m], 2.0);
}
d1 = 1.0 - dot/(std::sqrt(norm0)*std::sqrt(norm1));
}
#endif
if (d1 > 0.0) {
sum += std::pow(d0/d1, 2.0/(1.20 - 1.0));
} else {
sum += 1.0;
}
}
membership[j][k] = sum == 0.0 ? 1.0 : 1.0/sum;
}
}
// update the centroids
for (int j = 0; j < n_clusters; ++j) {
for (int k = 0; k < n_features; ++k) {
double sum = 0.0;
double sum2 = 0.0;
for (int l = 0; l < n_segments; ++l) {
sum += membership[l][j]*features[l][k];
sum2 += membership[l][j];
}
centroids[j][k] = sum2 == 0.0 ? 0.0 : sum/sum2;
}
}
}
}
// restore the mel length
ctx->mel.n_len = mel_len_save;
}

View File

@ -243,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);
// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);
// Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
@ -315,6 +325,10 @@ extern "C" {
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};
WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
@ -372,10 +386,6 @@ extern "C" {
WHISPER_API int whisper_bench_memcpy(int n_threads);
WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads);
// Temporary experimental API
WHISPER_API void whisper_full_cluster_segments(struct whisper_context * ctx);
#ifdef __cplusplus
}
#endif