Compare commits

...

5 Commits

8 changed files with 1929 additions and 221 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@ build-em/
build-debug/ build-debug/
build-release/ build-release/
build-static/ build-static/
build-no-accel/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/

View File

@ -145,7 +145,15 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
var db = event.target.result; var db = event.target.result;
var tx = db.transaction(['models'], 'readwrite'); var tx = db.transaction(['models'], 'readwrite');
var os = tx.objectStore('models'); var os = tx.objectStore('models');
var rq = null;
try {
var rq = os.put(data, url); var rq = os.put(data, url);
} catch (e) {
cbPrint('loadRemote: failed to store "' + url + '" in the IndexedDB: \n' + e);
cbCancel();
return;
}
rq.onsuccess = function (event) { rq.onsuccess = function (event) {
cbPrint('loadRemote: "' + url + '" stored in the IndexedDB'); cbPrint('loadRemote: "' + url + '" stored in the IndexedDB');
@ -180,7 +188,6 @@ function loadRemote(url, dst, size_mb, cbProgress, cbReady, cbCancel, cbPrint) {
rq.onabort = function (event) { rq.onabort = function (event) {
cbPrint('loadRemote: failed to open IndexedDB: abort'); cbPrint('loadRemote: failed to open IndexedDB: abort');
cbCancel();
}; };
} }

View File

@ -31,9 +31,9 @@ endif()
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
--bind \ --bind \
-s USE_PTHREADS=1 \ -s USE_PTHREADS=1 \
-s PTHREAD_POOL_SIZE=8 \ -s PTHREAD_POOL_SIZE_STRICT=0 \
-s INITIAL_MEMORY=1500MB \ -s INITIAL_MEMORY=2000MB \
-s TOTAL_MEMORY=1500MB \ -s TOTAL_MEMORY=2000MB \
-s FORCE_FILESYSTEM=1 \ -s FORCE_FILESYSTEM=1 \
-s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \ -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \
${EXTRA_FLAGS} \ ${EXTRA_FLAGS} \

View File

@ -10,6 +10,12 @@ std::thread g_worker;
std::vector<struct whisper_context *> g_contexts(4, nullptr); std::vector<struct whisper_context *> g_contexts(4, nullptr);
static inline int mpow2(int n) {
int p = 1;
while (p <= n) p *= 2;
return p/2;
}
EMSCRIPTEN_BINDINGS(whisper) { EMSCRIPTEN_BINDINGS(whisper) {
emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { emscripten::function("init", emscripten::optional_override([](const std::string & path_model) {
if (g_worker.joinable()) { if (g_worker.joinable()) {
@ -43,7 +49,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
} }
})); }));
emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, bool translate) { emscripten::function("full_default", emscripten::optional_override([](size_t index, const emscripten::val & audio, const std::string & lang, int nthreads, bool translate) {
if (g_worker.joinable()) { if (g_worker.joinable()) {
g_worker.join(); g_worker.join();
} }
@ -66,7 +72,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
params.print_special = false; params.print_special = false;
params.translate = translate; params.translate = translate;
params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en"; params.language = whisper_is_multilingual(g_contexts[index]) ? lang.c_str() : "en";
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency()); params.n_threads = std::min(nthreads, std::min(16, mpow2(std::thread::hardware_concurrency())));
params.offset_ms = 0; params.offset_ms = 0;
std::vector<float> pcmf32; std::vector<float> pcmf32;

View File

@ -40,21 +40,34 @@
Note that the computation is quite heavy and may take a few seconds to complete.<br> Note that the computation is quite heavy and may take a few seconds to complete.<br>
The transcription results will be displayed in the text area below.<br><br> The transcription results will be displayed in the text area below.<br><br>
<b>Important: your browser must support WASM SIMD instructions for this to work.</b> <b>Important:</b>
<ul>
<li>your browser must support WASM SIMD instructions for this to work</li>
<li>quantized models are still in experimental stage (<a href="https://github.com/ggerganov/ggml/pull/27">more info</a>)</li>
<li>Firefox cannot load files larger than 256 MB - use Chrome instead</li>
</ul>
<br><br><hr> <hr>
<div id="model"> <div id="model">
Whisper model: <span id="model-whisper-status"></span> Whisper models: <span id="model-whisper-status"></span><br><br>
<button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button> <button id="fetch-whisper-tiny-en" onclick="loadWhisper('tiny.en')">tiny.en (75 MB)</button>
<button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button> <button id="fetch-whisper-tiny" onclick="loadWhisper('tiny')">tiny (75 MB)</button>
<button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button> <button id="fetch-whisper-base-en" onclick="loadWhisper('base.en')">base.en (142 MB)</button>
<button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button> <button id="fetch-whisper-base" onclick="loadWhisper('base')">base (142 MB)</button>
<button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button> <button id="fetch-whisper-small-en" onclick="loadWhisper('small.en')">small.en (466 MB)</button>
<button id="fetch-whisper-small" onclick="loadWhisper('small')">small (466 MB)</button> <button id="fetch-whisper-small" onclick="loadWhisper('small')">small (466 MB)</button>
<span id="fetch-whisper-progress"></span>
<input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" /> <input type="file" id="whisper-file" name="file" onchange="loadFile(event, 'whisper.bin')" />
<br><br>
Quantized models:<br><br>
<button id="fetch-whisper-base-en-q4_0" onclick="loadWhisper('base-en-q4_0')">base.en (4bit, 49 MB)</button>
<button id="fetch-whisper-base-q4_0" onclick="loadWhisper('base-q4_0')">base (4bit, 49 MB)</button>
<button id="fetch-whisper-small-en-q4_0" onclick="loadWhisper('small-en-q4_0')">small.en (4bit, 152 MB)</button>
<button id="fetch-whisper-small-q4_0" onclick="loadWhisper('small-q4_0')">small (4bit, 152 MB)</button><br>
<button id="fetch-whisper-medium-en-q4_0" onclick="loadWhisper('medium-en-q4_0')">medium.en (4bit, 469 MB)</button>
<button id="fetch-whisper-medium-q4_0" onclick="loadWhisper('medium-q4_0')">medium (4bit, 469 MB)</button>
<button id="fetch-whisper-large-q4_0" onclick="loadWhisper('large-q4_0')">large (4bit, 985 MB)</button>
<span id="fetch-whisper-progress"></span>
</div> </div>
<br> <br>
@ -161,6 +174,12 @@
<option value="yi">Yiddish</option> <option value="yi">Yiddish</option>
</select> </select>
</td> </td>
<!-- Slider to select number of threads between 1 and 16 -->
<td>
Threads:
<input type="range" id="threads" name="threads" min="1" max="16" value="8" onchange="changeThreads(this.value)" />
<span id="threads-value">8</span>
</td>
<td> <td>
<button onclick="onProcess(false);">Transcribe</button> <button onclick="onProcess(false);">Transcribe</button>
</td> </td>
@ -263,11 +282,13 @@
Module.FS_createDataFile("/", fname, buf, true, true); Module.FS_createDataFile("/", fname, buf, true, true);
model_whisper = fname; //model_whisper = fname;
document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!'; document.getElementById('model-whisper-status').innerHTML = 'loaded "' + model_whisper + '"!';
printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length); printTextarea('storeFS: stored model: ' + fname + ' size: ' + buf.length);
document.getElementById('model').innerHTML = 'Model fetched: ' + model_whisper;
} }
function loadFile(event, fname) { function loadFile(event, fname) {
@ -292,6 +313,15 @@
document.getElementById('fetch-whisper-tiny' ).style.display = 'none'; document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none'; document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none'; document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-base-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q4_0').style.display = 'none';
document.getElementById('fetch-whisper-medium-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-large-q4_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none'; document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name; document.getElementById('model-whisper-status' ).innerHTML = 'loaded model: ' + file.name;
} }
@ -304,6 +334,14 @@
'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin', 'base': 'https://whisper.ggerganov.com/ggml-model-whisper-base.bin',
'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin', 'small.en': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en.bin',
'small': 'https://whisper.ggerganov.com/ggml-model-whisper-small.bin', 'small': 'https://whisper.ggerganov.com/ggml-model-whisper-small.bin',
'base-en-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-base.en-q4_0.bin',
'base-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-base-q4_0.bin',
'small-en-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-small.en-q4_0.bin',
'small-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-small-q4_0.bin',
'medium-en-q4_0':'https://whisper.ggerganov.com/ggml-model-whisper-medium.en-q4_0.bin',
'medium-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-medium-q4_0.bin',
'large-q4_0': 'https://whisper.ggerganov.com/ggml-model-whisper-large-q4_0.bin',
}; };
let sizes = { let sizes = {
@ -313,6 +351,14 @@
'base': 142, 'base': 142,
'small.en': 466, 'small.en': 466,
'small': 466, 'small': 466,
'base-en-q4_0': 49,
'base-q4_0': 49,
'small-en-q4_0': 152,
'small-q4_0': 152,
'medium-en-q4_0': 469,
'medium-q4_0': 469,
'large-q4_0': 985,
}; };
let url = urls[model]; let url = urls[model];
@ -327,6 +373,15 @@
document.getElementById('fetch-whisper-tiny' ).style.display = 'none'; document.getElementById('fetch-whisper-tiny' ).style.display = 'none';
document.getElementById('fetch-whisper-base' ).style.display = 'none'; document.getElementById('fetch-whisper-base' ).style.display = 'none';
document.getElementById('fetch-whisper-small' ).style.display = 'none'; document.getElementById('fetch-whisper-small' ).style.display = 'none';
document.getElementById('fetch-whisper-base-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-base-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-en-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-small-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-medium-en-q4_0').style.display = 'none';
document.getElementById('fetch-whisper-medium-q4_0' ).style.display = 'none';
document.getElementById('fetch-whisper-large-q4_0' ).style.display = 'none';
document.getElementById('whisper-file' ).style.display = 'none'; document.getElementById('whisper-file' ).style.display = 'none';
document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model; document.getElementById('model-whisper-status' ).innerHTML = 'loading model: ' + model;
@ -337,12 +392,22 @@
cbCancel = function() { cbCancel = function() {
var el; var el;
el = document.getElementById('fetch-whisper-tiny-en' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base-en' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-small-en'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-tiny' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-base' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('fetch-whisper-small' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-en-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-base-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-en-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-small-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-en-q4_0'); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-medium-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('fetch-whisper-large-q4_0' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block'; el = document.getElementById('whisper-file' ); if (el) el.style.display = 'inline-block';
el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = ''; el = document.getElementById('model-whisper-status' ); if (el) el.innerHTML = '';
}; };
@ -354,7 +419,8 @@
// audio file // audio file
// //
const kMaxAudio_s = 120; const kMaxAudio_s = 30*60;
const kMaxRecording_s = 2*60;
const kSampleRate = 16000; const kSampleRate = 16000;
window.AudioContext = window.AudioContext || window.webkitAudioContext; window.AudioContext = window.AudioContext || window.webkitAudioContext;
@ -423,7 +489,7 @@
doRecording = false; doRecording = false;
} }
// record up to kMaxAudio_s seconds of audio from the microphone // record up to kMaxRecording_s seconds of audio from the microphone
// check if doRecording is false every 1000 ms and stop recording if so // check if doRecording is false every 1000 ms and stop recording if so
// update progress information // update progress information
function startRecording() { function startRecording() {
@ -479,9 +545,9 @@
printTextarea('js: audio recorded, size: ' + audio.length); printTextarea('js: audio recorded, size: ' + audio.length);
// truncate to first 30 seconds // truncate to first 30 seconds
if (audio.length > kMaxAudio_s*kSampleRate) { if (audio.length > kMaxRecording_s*kSampleRate) {
audio = audio.slice(0, kMaxAudio_s*kSampleRate); audio = audio.slice(0, kMaxRecording_s*kSampleRate);
printTextarea('js: truncated audio to first ' + kMaxAudio_s + ' seconds'); printTextarea('js: truncated audio to first ' + kMaxRecording_s + ' seconds');
} }
setAudio(audio); setAudio(audio);
}); });
@ -509,24 +575,31 @@
}); });
} }
document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxAudio_s) + '%'; document.getElementById('progress-bar').style.width = (100*(Date.now() - startTime)/1000/kMaxRecording_s) + '%';
document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxAudio_s).toFixed(0) + '%'; document.getElementById('progress-text').innerHTML = (100*(Date.now() - startTime)/1000/kMaxRecording_s).toFixed(0) + '%';
}, 1000); }, 1000);
printTextarea('js: recording ...'); printTextarea('js: recording ...');
setTimeout(function() { setTimeout(function() {
if (doRecording) { if (doRecording) {
printTextarea('js: recording stopped after ' + kMaxAudio_s + ' seconds'); printTextarea('js: recording stopped after ' + kMaxRecording_s + ' seconds');
stopRecording(); stopRecording();
} }
}, kMaxAudio_s*1000); }, kMaxRecording_s*1000);
} }
// //
// transcribe // transcribe
// //
var nthreads = 8;
function changeThreads(value) {
nthreads = value;
document.getElementById('threads-value').innerHTML = nthreads;
}
function onProcess(translate) { function onProcess(translate) {
if (!instance) { if (!instance) {
instance = Module.init('whisper.bin'); instance = Module.init('whisper.bin');
@ -553,7 +626,7 @@
printTextarea(''); printTextarea('');
setTimeout(function() { setTimeout(function() {
var ret = Module.full_default(instance, audio, document.getElementById('language').value, translate); var ret = Module.full_default(instance, audio, document.getElementById('language').value, nthreads, translate);
console.log('js: full_default returned: ' + ret); console.log('js: full_default returned: ' + ret);
if (ret) { if (ret) {
printTextarea("js: whisper returned: " + ret); printTextarea("js: whisper returned: " + ret);

1790
ggml.c

File diff suppressed because it is too large Load Diff

7
ggml.h
View File

@ -198,6 +198,8 @@ struct ggml_object;
struct ggml_context; struct ggml_context;
enum ggml_type { enum ggml_type {
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_I8, GGML_TYPE_I8,
GGML_TYPE_I16, GGML_TYPE_I16,
GGML_TYPE_I32, GGML_TYPE_I32,
@ -326,7 +328,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
int ggml_nelements(const struct ggml_tensor * tensor); int ggml_nelements(const struct ggml_tensor * tensor);
size_t ggml_nbytes (const struct ggml_tensor * tensor); size_t ggml_nbytes (const struct ggml_tensor * tensor);
size_t ggml_type_size (enum ggml_type type); int ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
size_t ggml_element_size(const struct ggml_tensor * tensor); size_t ggml_element_size(const struct ggml_tensor * tensor);
struct ggml_context * ggml_init(struct ggml_init_params params); struct ggml_context * ggml_init(struct ggml_init_params params);

View File

@ -252,12 +252,34 @@ static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
{ MODEL_LARGE, 9ull*MB }, { MODEL_LARGE, 9ull*MB },
}; };
static const std::map<e_model, size_t> MEM_REQ_MODEL = { static const std::map<ggml_type, std::map<e_model, size_t>> MEM_REQ_MODEL = {
{ GGML_TYPE_F16,
{
{ MODEL_TINY, 74ull*MB }, { MODEL_TINY, 74ull*MB },
{ MODEL_BASE, 142ull*MB }, { MODEL_BASE, 142ull*MB },
{ MODEL_SMALL, 466ull*MB }, { MODEL_SMALL, 466ull*MB },
{ MODEL_MEDIUM, 1464ull*MB }, { MODEL_MEDIUM, 1464ull*MB },
{ MODEL_LARGE, 2952ull*MB }, { MODEL_LARGE, 2952ull*MB },
},
},
{ GGML_TYPE_Q4_0,
{
{ MODEL_TINY, 26ull*MB },
{ MODEL_BASE, 50ull*MB },
{ MODEL_SMALL, 154ull*MB },
{ MODEL_MEDIUM, 470ull*MB },
{ MODEL_LARGE, 940ull*MB },
},
},
{ GGML_TYPE_Q4_1,
{
{ MODEL_TINY, 31ull*MB },
{ MODEL_BASE, 57ull*MB },
{ MODEL_SMALL, 181ull*MB },
{ MODEL_MEDIUM, 559ull*MB },
{ MODEL_LARGE, 1122ull*MB },
},
},
}; };
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = { static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
@ -681,7 +703,7 @@ static bool kv_cache_reinit(struct whisper_kv_cache & cache) {
const ggml_type wtype = cache.k->type; const ggml_type wtype = cache.k->type;
WHISPER_ASSERT(wtype == cache.v->type); WHISPER_ASSERT(wtype == cache.v->type);
WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_sizef(wtype));
struct ggml_init_params params; struct ggml_init_params params;
params.mem_size = cache.buf.size(); params.mem_size = cache.buf.size();
@ -776,12 +798,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.type = e_model::MODEL_LARGE; model.type = e_model::MODEL_LARGE;
} }
// for the big tensors, we have the option to store the data in 16-bit floats // for the big tensors, we have the option to store the data in 16-bit floats or quantized
// in order to save memory and also to speed up the computation // in order to save memory and also to speed up the computation
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; wctx.wtype = GGML_TYPE_COUNT;
switch (model.hparams.f16) {
case 0: wctx.wtype = GGML_TYPE_F32; break;
case 1: wctx.wtype = GGML_TYPE_F16; break;
case 2: wctx.wtype = GGML_TYPE_Q4_0; break;
case 3: wctx.wtype = GGML_TYPE_Q4_1; break;
default:
{
fprintf(stderr, "%s: invalid model (bad f16 value %d)\n", __func__, model.hparams.f16);
return false;
}
}
const size_t scale = model.hparams.f16 ? 1 : 2; const size_t scale = model.hparams.f16 ? 1 : 2;
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx);
fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state);
@ -792,7 +827,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head); fprintf(stderr, "%s: n_text_head = %d\n", __func__, hparams.n_text_head);
fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); fprintf(stderr, "%s: n_text_layer = %d\n", __func__, hparams.n_text_layer);
fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels);
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: ftype = %s\n", __func__, ftype_str[model.hparams.f16]);
fprintf(stderr, "%s: type = %d\n", __func__, model.type); fprintf(stderr, "%s: type = %d\n", __func__, model.type);
// print memory requirements // print memory requirements
@ -803,7 +838,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH1.at (model.type) + MEM_REQ_SCRATCH1.at (model.type) +
MEM_REQ_SCRATCH2.at (model.type) + MEM_REQ_SCRATCH2.at (model.type) +
MEM_REQ_SCRATCH3.at (model.type) + MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) + scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) + scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)); scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
@ -819,9 +854,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// always have at least one decoder // always have at least one decoder
wctx.model.buf = new std::vector<uint8_t>(); wctx.model.buf = new std::vector<uint8_t>();
wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(wctx.wtype).at(model.type));
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, GGML_TYPE_F16, model.hparams.n_text_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
return false; return false;
} }
@ -831,7 +866,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
} }
if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, GGML_TYPE_F16, model.hparams.n_audio_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__);
return false; return false;
} }
@ -963,92 +998,92 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// encoder // encoder
{ {
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe; ctx_size += n_audio_ctx*n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_pe;
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w ctx_size += 3*n_mels*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_1_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_1_b ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_1_b
ctx_size += 3*n_audio_state*n_audio_state*ggml_type_size(wtype); // e_conv_2_w ctx_size += 3*n_audio_state*n_audio_state*ggml_type_sizef(GGML_TYPE_F16); // e_conv_2_w
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_conv_2_b ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_conv_2_b
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_w; ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_w;
ctx_size += n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_ln_b; ctx_size += n_audio_state*ggml_type_sizef(GGML_TYPE_F32); // e_ln_b;
} }
// decoder // decoder
{ {
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe; ctx_size += n_text_ctx*n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_pe;
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te; ctx_size += n_vocab*n_text_state*ggml_type_sizef(wtype); // d_te;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_w; ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_w;
ctx_size += n_text_state*ggml_type_size(GGML_TYPE_F32); // d_ln_b; ctx_size += n_text_state*ggml_type_sizef(GGML_TYPE_F32); // d_ln_b;
} }
// encoder layers // encoder layers
{ {
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_0_w ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_0_w
ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b ctx_size += n_audio_layer*( 4*n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_size(wtype)); // mlp_1_w ctx_size += n_audio_layer*(4*n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // mlp_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_audio_layer*(n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b ctx_size += n_audio_layer*(n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_q_w ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_q_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_k_w ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_k_w
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_v_w ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_v_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_size(wtype)); // attn_ln_1_w ctx_size += n_audio_layer*(n_audio_state*n_audio_state*ggml_type_sizef(wtype)); // attn_ln_1_w
ctx_size += n_audio_layer*( n_audio_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b ctx_size += n_audio_layer*( n_audio_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
} }
// decoder layers // decoder layers
{ {
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_w ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_ln_b ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_ln_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_0_w ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_0_w
ctx_size += n_text_layer*( 4*n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_0_b ctx_size += n_text_layer*( 4*n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_0_b
ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_size(wtype)); // mlp_1_w ctx_size += n_text_layer*(4*n_text_state*n_text_state*ggml_type_sizef(wtype)); // mlp_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // mlp_1_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // mlp_1_b
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_w ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_0_b ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_q_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_q_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_k_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_v_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_v_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // attn_ln_1_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // attn_ln_1_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // attn_ln_1_b
// //
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_w ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_w
ctx_size += n_text_layer*(n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_0_b ctx_size += n_text_layer*(n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_0_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_q_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_q_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_q_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_q_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_k_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_k_w
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_v_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_v_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_v_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_v_b
ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_size(wtype)); // cross_attn_ln_1_w ctx_size += n_text_layer*(n_text_state*n_text_state*ggml_type_sizef(wtype)); // cross_attn_ln_1_w
ctx_size += n_text_layer*( n_text_state*ggml_type_size(GGML_TYPE_F32)); // cross_attn_ln_1_b ctx_size += n_text_layer*( n_text_state*ggml_type_sizef(GGML_TYPE_F32)); // cross_attn_ln_1_b
} }
ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead
@ -1094,10 +1129,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
{ {
model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx);
model.e_conv_1_w = ggml_new_tensor_3d(ctx, wtype, 3, n_mels, n_audio_state); model.e_conv_1_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_mels, n_audio_state);
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_conv_2_w = ggml_new_tensor_3d(ctx, wtype, 3, n_audio_state, n_audio_state); model.e_conv_2_w = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 3, n_audio_state, n_audio_state);
model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state);
model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state);
@ -1313,9 +1348,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
return false; return false;
} }
const size_t bpe = (ftype == 0) ? sizeof(float) : sizeof(ggml_fp16_t); size_t bpe = 0;
if (nelements*bpe != ggml_nbytes(tensor)) { switch (ftype) {
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
default:
{
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
return false;
}
};
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe); __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
return false; return false;
@ -1513,14 +1560,14 @@ static bool whisper_encode(
ggml_permute(ctx0, ggml_permute(ctx0,
ggml_cpy(ctx0, ggml_cpy(ctx0,
Qcur, Qcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctx0, ggml_permute(ctx0,
ggml_cpy(ctx0, ggml_cpy(ctx0,
Kcur, Kcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
struct ggml_tensor * V = struct ggml_tensor * V =
@ -1530,7 +1577,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
1, 2, 0, 3), 1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
); );
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false);
@ -1546,7 +1593,7 @@ static bool whisper_encode(
ggml_permute(ctx0, ggml_permute(ctx0,
ggml_cpy(ctx0, ggml_cpy(ctx0,
Kcur, Kcur,
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
0, 2, 1, 3); 0, 2, 1, 3);
// K * Q // K * Q
@ -1564,7 +1611,7 @@ static bool whisper_encode(
// ggml_permute(ctx0, // ggml_permute(ctx0,
// ggml_cpy(ctx0, // ggml_cpy(ctx0,
// Vcur, // Vcur,
// ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)), // ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
// 1, 2, 0, 3); // 1, 2, 0, 3);
//struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@ -1576,7 +1623,7 @@ static bool whisper_encode(
Vcur, Vcur,
n_state/n_head, n_head, n_ctx), n_state/n_head, n_head, n_ctx),
0, 2, 1, 3), 0, 2, 1, 3),
ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_ctx, n_head) ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
); );
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_transpose(ctx0, V), KQ_soft_max);
@ -1634,7 +1681,7 @@ static bool whisper_encode(
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 0);
cur = ggml_flash_ff(ctx0, cur = ggml_flash_ff(ctx0,
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wctx.wtype, n_state, n_ctx)), ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_state, n_ctx)),
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else #else
wctx.use_buf(ctx0, 0); wctx.use_buf(ctx0, 0);
@ -4445,23 +4492,32 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float) // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256); std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*256);
// put a bunch of random data in the buffer
for (size_t i = 0; i < buf.size(); i++) buf[i] = i; for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
for (int j = 0; j < (int) sizes.size(); j++) { for (int j = 0; j < (int) sizes.size(); j++) {
int n_q4_0 = 0;
int n_q4_1 = 0;
int n_fp16 = 0; int n_fp16 = 0;
int n_fp32 = 0; int n_fp32 = 0;
// GFLOPS/s // GFLOPS/s
double s_q4_0 = 0.0;
double s_q4_1 = 0.0;
double s_fp16 = 0.0; double s_fp16 = 0.0;
double s_fp32 = 0.0; double s_fp32 = 0.0;
const size_t N = sizes[j]; const size_t N = sizes[j];
for (int k = 0; k < 2; ++k) { for (int k = 0; k < 4; ++k) {
const ggml_type wtype = k == 0 ? GGML_TYPE_F16 : GGML_TYPE_F32; const ggml_type wtype =
k == 0 ? GGML_TYPE_Q4_0 :
k == 1 ? GGML_TYPE_Q4_1 :
k == 2 ? GGML_TYPE_F16 :
GGML_TYPE_F32;
double & s = k == 0 ? s_fp16 : s_fp32; double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_fp16 : s_fp32;
int & n = k == 0 ? n_fp16 : n_fp32; int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_fp16 : n_fp32;
struct ggml_init_params gparams = { struct ggml_init_params gparams = {
/*.mem_size =*/ buf.size(), /*.mem_size =*/ buf.size(),
@ -4504,8 +4560,8 @@ WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) {
s = ((2.0*N*N*N*n)/tsum)*1e-9; s = ((2.0*N*N*N*n)/tsum)*1e-9;
} }
fprintf(stderr, "ggml_mul_mat: %5zu x %5zu: F16 %8.1f GFLOPS (%3d runs) / F32 %8.1f GFLOPS (%3d runs)\n", fprintf(stderr, "ggml_mul_mat: %4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) / Q4_1 %7.1f GFLOPS (%3d runs) / F16 %7.1f GFLOPS (%3d runs) / F32 %7.1f GFLOPS (%3d runs)\n",
N, N, s_fp16, n_fp16, s_fp32, n_fp32); N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1, s_fp16, n_fp16, s_fp32, n_fp32);
} }
return 0; return 0;