mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-01 07:25:49 +02:00
vulkan: use scalar FA rather than coopmat2 when N==1 (llama/13554)
This commit is contained in:
parent
4fedad988b
commit
6d61a09bc4
@ -5872,10 +5872,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|||||||
vk_pipeline *pipelines;
|
vk_pipeline *pipelines;
|
||||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||||
|
|
||||||
|
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||||
|
// So use scalar instead.
|
||||||
if (small_rows && path == FA_COOPMAT1) {
|
if (small_rows && path == FA_COOPMAT1) {
|
||||||
path = FA_SCALAR;
|
path = FA_SCALAR;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// scalar is faster than coopmat2 when N==1
|
||||||
|
if (N == 1 && path == FA_COOPMAT2) {
|
||||||
|
path = FA_SCALAR;
|
||||||
|
}
|
||||||
|
|
||||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||||
|
|
||||||
switch (path) {
|
switch (path) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user