mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-19 09:22:18 +02:00
whisper : fix bench regression
This commit is contained in:
33
whisper.cpp
33
whisper.cpp
@@ -118,6 +118,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
|
|||||||
#define WHISPER_USE_SCRATCH
|
#define WHISPER_USE_SCRATCH
|
||||||
#define WHISPER_MAX_SCRATCH_BUFFERS 16
|
#define WHISPER_MAX_SCRATCH_BUFFERS 16
|
||||||
|
|
||||||
|
//
|
||||||
|
// ggml helpers
|
||||||
|
//
|
||||||
|
|
||||||
|
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||||
|
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
|
||||||
|
|
||||||
|
if (plan.work_size > 0) {
|
||||||
|
buf.resize(plan.work_size);
|
||||||
|
plan.work_data = buf.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_graph_compute(graph, &plan);
|
||||||
|
}
|
||||||
|
|
||||||
// available whisper models
|
// available whisper models
|
||||||
enum e_model {
|
enum e_model {
|
||||||
MODEL_UNKNOWN,
|
MODEL_UNKNOWN,
|
||||||
@@ -666,6 +681,7 @@ struct whisper_state {
|
|||||||
|
|
||||||
// memory buffers used by encode / decode contexts
|
// memory buffers used by encode / decode contexts
|
||||||
std::vector<uint8_t> buf_compute;
|
std::vector<uint8_t> buf_compute;
|
||||||
|
std::vector<uint8_t> buf_work;
|
||||||
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
|
||||||
|
|
||||||
int buf_last = 0;
|
int buf_last = 0;
|
||||||
@@ -1830,8 +1846,8 @@ static bool whisper_encode_internal(
|
|||||||
{
|
{
|
||||||
struct ggml_cgraph gf = {};
|
struct ggml_cgraph gf = {};
|
||||||
|
|
||||||
ggml_build_forward_expand (&gf, cur);
|
ggml_build_forward_expand(&gf, cur);
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
|
||||||
|
|
||||||
//ggml_graph_print(&gf);
|
//ggml_graph_print(&gf);
|
||||||
}
|
}
|
||||||
@@ -1916,7 +1932,7 @@ static bool whisper_encode_internal(
|
|||||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
|
||||||
//ggml_graph_print(&gf);
|
//ggml_graph_print(&gf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2329,8 +2345,8 @@ static bool whisper_decode_internal(
|
|||||||
|
|
||||||
// run the computation
|
// run the computation
|
||||||
{
|
{
|
||||||
ggml_build_forward_expand (&gf, logits);
|
ggml_build_forward_expand(&gf, logits);
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_helper(wstate.buf_work, &gf, n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
// extract logits for all N tokens
|
// extract logits for all N tokens
|
||||||
@@ -5225,7 +5241,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|||||||
// b: N*N*sizeof(float)
|
// b: N*N*sizeof(float)
|
||||||
// c: N*N*sizeof(float)
|
// c: N*N*sizeof(float)
|
||||||
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
|
||||||
std::vector<char> buf(4llu*N_max*N_max*sizeof(float) + 4*512);
|
std::vector<uint8_t> buf (3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead());
|
||||||
|
std::vector<uint8_t> work(1llu*N_max*N_max*sizeof(float) + 1*ggml_tensor_overhead());
|
||||||
|
|
||||||
// put a bunch of random data in the buffer
|
// put a bunch of random data in the buffer
|
||||||
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
|
||||||
@@ -5280,12 +5297,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
|
|||||||
double tsum = 0.0;
|
double tsum = 0.0;
|
||||||
|
|
||||||
// heat-up
|
// heat-up
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_helper(work, &gf, n_threads);
|
||||||
|
|
||||||
for (int i = 0; i < n_max; ++i) {
|
for (int i = 0; i < n_max; ++i) {
|
||||||
const int64_t t0 = ggml_time_us();
|
const int64_t t0 = ggml_time_us();
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_helper(work, &gf, n_threads);
|
||||||
|
|
||||||
const int64_t t1 = ggml_time_us();
|
const int64_t t1 = ggml_time_us();
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user