bench : add batch size 5 bench

This commit is contained in:
Georgi Gerganov
2023-11-14 22:45:08 +02:00
parent 3ed9af34f2
commit ae1bd69041
3 changed files with 37 additions and 16 deletions

View File

@@ -81,7 +81,7 @@ int whisper_bench_full(const whisper_params & params) {
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}
@@ -90,13 +90,13 @@ int whisper_bench_full(const whisper_params & params) {
// prompt heat
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
// text-generation heat
if (int ret = whisper_decode(ctx, tokens, 1, 256, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
@@ -104,20 +104,30 @@ int whisper_bench_full(const whisper_params & params) {
// actual run
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
fprintf(stderr, "error: failed to encode: %d\n", ret);
return 4;
}
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// text-generation
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
for (int i = 0; i < 256; i++) {
if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);
// batched decoding
for (int i = 0; i < 64; i++) {
if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}
// prompt processing
for (int i = 0; i < 16; i++) {
if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to decode: %d\n", ret);
return 4;
}
}