Mirror of https://github.com/ggerganov/whisper.cpp.git
ggml : fix FA mask dim 2 and 3 (llama/14505)
* ggml : fix FA mask dim 2 and 3
* backends : unsupport batched FA in CUDA and Vulkan
* vulkan : disable FA for mask->ne[2] != 1
@@ -1980,15 +1980,16 @@ extern "C" {

     #define GGML_KQ_MASK_PAD 64

-    // q:    [n_embd_k, n_batch,     n_head,    ne3]
-    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
-    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, ne32,      1]   !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
-    //   ne3    % ne32      == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
     //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
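For orientation, below is a minimal sketch (not part of the commit) of building a flash-attention node against the updated mask layout: the mask's dim 2 broadcasts over the heads and its dim 3 over the ne3 batch dimension. The concrete sizes, the F16 types for k/v and the mask, and the scale/max_bias/logit_softcap arguments are assumptions for illustration.

#include <math.h>
#include "ggml.h"

int main(void) {
    // shape-only sketch: no_alloc avoids allocating tensor data
    struct ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t n_embd_k = 128, n_embd_v = 128;      // head sizes (assumed)
    const int64_t n_head   = 32,  n_head_kv = 8;       // GQA: n_head % n_head_kv == 0
    const int64_t n_batch  = 5,   n_kv = 256, ne3 = 4; // sequences in dim 3 (assumed)

    // broadcast rules from the header comment: n_head % ne32 == 0, ne3 % ne33 == 0
    const int64_t ne32 = 1;   // one mask slice shared by all heads
    const int64_t ne33 = ne3; // one mask slice per sequence in dim 3

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch, n_head,    ne3);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,    n_head_kv, ne3);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,    n_head_kv, ne3);

    // mask rows padded to GGML_KQ_MASK_PAD, as the FA kernels require
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), ne32, ne33);

    const float scale = 1.0f/sqrtf((float) n_embd_k);
    struct ggml_tensor * res = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);
    (void) res; // res: [n_embd_v, n_head, n_batch, ne3], permuted

    ggml_free(ctx);
    return 0;
}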
@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 memset(VKQ32, 0, DV*sizeof(float));
             }

-            const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+            const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;

             // k indices
             const int ik3 = iq3 / rk3;
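The CPU change above is the heart of the fix: the mask row is now selected by broadcasting the head index iq2 over mask->ne[2] and the sequence index iq3 over mask->ne[3], instead of folding iq3 into dim 2. A standalone sketch of that addressing scheme (an illustrative helper, not an actual ggml function):

#include <stddef.h>
#include <stdint.h>

// Byte offset of the mask row for query row i1, head i2 and sequence i3,
// broadcasting dims 2 and 3 the same way the kernel does.
static inline size_t mask_row_offset(const int64_t ne[4], const size_t nb[4],
                                     int64_t i1, int64_t i2, int64_t i3) {
    return (size_t) i1*nb[1]           // query row
         + (size_t)(i2 % ne[2])*nb[2]  // broadcast over heads
         + (size_t)(i3 % ne[3])*nb[3]; // broadcast over sequences
}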
@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
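Because CUDA (and Vulkan, further down) now report batched FA as unsupported, callers that build such a node can probe the backend and fall back, for example to the CPU backend. A rough sketch using the public backend API; pick_fa_backend and the fa_node argument are illustrative names, not part of ggml:

#include "ggml-backend.h"

// Sketch: check whether `gpu` can run the flash-attention node before
// assigning it; otherwise keep it on the CPU backend.
static ggml_backend_t pick_fa_backend(ggml_backend_t gpu, ggml_backend_t cpu,
                                      const struct ggml_tensor * fa_node) {
    if (gpu && ggml_backend_supports_op(gpu, fa_node)) {
        return gpu; // e.g. rejected here when src[0]->ne[3] != 1 on CUDA
    }
    return cpu;
}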
@@ -230,8 +230,10 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
     uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
@@ -5018,8 +5018,10 @@ static bool ggml_metal_encode_node(
         /*.nb22 =*/ nb22,
         /*.nb23 =*/ nb23,
         /*.ne32 =*/ ne32,
+        /*.ne33 =*/ ne33,
         /*.nb31 =*/ nb31,
         /*.nb32 =*/ nb32,
+        /*.nb33 =*/ nb33,
         /*.ne1  =*/ ne1,
         /*.ne2  =*/ ne2,
         /*.scale =*/ scale,
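The two new fields carry the mask's dim-3 extent and byte stride down to the kernel. A reduced host-side sketch of how they would be filled from the mask tensor, mirroring the encoder above (struct and function names are illustrative, and the struct is trimmed to the mask-related fields):

#include <stdint.h>
#include "ggml.h"

// Reduced sketch of the kernel-argument fields touched by this commit.
typedef struct {
    int32_t  ne32, ne33;       // mask dims 2 and 3 (broadcast extents)
    uint64_t nb31, nb32, nb33; // mask byte strides for dims 1..3
} fa_mask_args_sketch;

// Fill the args from the (optional) mask tensor.
static void set_fa_mask_args(fa_mask_args_sketch * a, const struct ggml_tensor * mask) {
    a->ne32 = mask ? (int32_t) mask->ne[2] : 1;
    a->ne33 = mask ? (int32_t) mask->ne[3] : 1;
    a->nb31 = mask ? mask->nb[1] : 0;
    a->nb32 = mask ? mask->nb[2] : 0;
    a->nb33 = mask ? mask->nb[3] : 0;
}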
@@ -3857,7 +3857,7 @@ kernel void kernel_flash_attn_ext(
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

                     const float m = pm[ic + tiisg];
@@ -4343,7 +4343,7 @@ kernel void kernel_flash_attn_ext_vec(
     const bool has_mask = mask != q;

     // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);

     float slope = 1.0f;
@@ -10265,6 +10265,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
+            // TODO: support broadcast
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+            //       the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+            if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+                return false;
+            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
@@ -3666,7 +3666,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
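With the ggml_is_3d() assert removed, ggml_soft_max_ext() can take a 4-D mask that broadcasts against the KQ scores along dims 2 and 3. A small shape sketch, with made-up sizes and an existing ggml_context supplied by the caller:

#include "ggml.h"

// Sketch: soft-max over KQ scores with a mask that may now be 4-D.
static struct ggml_tensor * soft_max_with_bcast_mask(struct ggml_context * ctx) {
    struct ggml_tensor * kq   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 256, 32, 16, 4); // [n_kv, n_tokens, n_head, n_seq]
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 256, 32,  1, 1); // dims 2/3 broadcast against kq
    return ggml_soft_max_ext(ctx, kq, mask, /*scale =*/ 0.125f, /*max_bias =*/ 0.0f);
}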
@@ -4696,12 +4695,12 @@ struct ggml_tensor * ggml_flash_attn_ext(

     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));

-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }

     if (max_bias > 0.0f) {
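In short, the new contract replaces the old tie between mask->ne[2] and q->ne[3] with two broadcast relations: mask dim 2 divides the head count and mask dim 3 divides the ne3 dimension. A compact pre-flight check mirroring these asserts (an illustrative helper, not part of the ggml API):

#include <stdbool.h>
#include "ggml.h"

// Illustrative check mirroring the asserts in ggml_flash_attn_ext().
static bool fa_mask_is_compatible(const struct ggml_tensor * q, const struct ggml_tensor * mask) {
    if (!mask) {
        return true;
    }
    return ggml_is_contiguous(mask)
        && mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)
        && q->ne[2] % mask->ne[2] == 0   // heads broadcast over mask dim 2
        && q->ne[3] % mask->ne[3] == 0;  // sequences broadcast over mask dim 3
}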