mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-30 22:40:14 +02:00
whisper : use flash attention (#2152)
* whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log
This commit is contained in:
@ -36,6 +36,7 @@ struct whisper_params {
|
||||
bool tinydiarize = false;
|
||||
bool save_audio = false; // save audio to wav file
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
@ -72,6 +73,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
||||
else if (arg == "-sa" || arg == "--save-audio") { params.save_audio = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
||||
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
||||
fprintf(stderr, " -sa, --save-audio [%-7s] save the recorded audio to a file\n", params.save_audio ? "true" : "false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU inference\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention during inference\n", params.flash_attn ? "true" : "false");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
@ -153,7 +156,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
|
||||
|
||||
|
Reference in New Issue
Block a user