llama : add high-throughput mode (llama/14363)

* kv-cache : prepare K/V buffers for separation

ggml-ci

* batched-bench : fix oob write

ggml-ci

* llama : add "virtual sequences"

ggml-ci

* llama : use "stream" vs "virtual sequence"

ggml-ci

* graph : fix stream splitting when KV cache is not used

ggml-ci

* kv-cache : add multi-stream save/load support

ggml-ci

* llama : add "--attn-streams" flag

ggml-ci

* kv-cache : fix handling when find_slot fails

ggml-ci

* kv-cache : restore find_slot impl

ggml-ci

* kv-cache : add comments

* kv-cache : add bounds checks for sequence id

ggml-ci

* cont : add n_seq_max to batch allocr

ggml-ci

* kv-cache : perform stream copies lazily after llama_synchronize

ggml-ci

* kv-cache : avoid throwing exceptions across the C boundary

ggml-ci

* CUDA: 4D FlashAttention support (llama/14628)

* CUDA: 4D FlashAttention support

* CUDA: fix WMMA FA kernel

* llama : rename attn_streams -> kv_unified

ggml-ci

* common : rename kv_split -> kv_unified

ggml-ci

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Georgi Gerganov
2025-07-16 16:35:42 +03:00
parent 9cc645fec0
commit ae1bb2c8ea
8 changed files with 143 additions and 102 deletions
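The CUDA diffs below generalize the FlashAttention kernels from 3D to 4D tensors: the stream-k kernels (flash_attn_ext_f16 and flash_attn_stream_k_fixup) now fold the sequence dimension ne03 into the flattened work counter kbc, which previously only spanned (head group, Q tile, K/V tile). A minimal host-side sketch of that index decomposition, reusing the kernel's variable names with made-up sizes (an illustration, not code from the commit):

// Sketch: decompose a flattened stream-k work counter kbc into
// (sequence, head group, j tile, k tile), mirroring the diffs below.
#include <cassert>
#include <cstdio>

int main() {
    // Made-up example sizes; the real values come from the tensor shapes.
    const int iter_k = 4;   // K/V tiles per head, ne11 / FATTN_KQ_STRIDE
    const int iter_j = 3;   // Q column tiles, ceil(ne01 / ncols1)
    const int ne02   = 8;   // number of Q heads
    const int ncols2 = 2;   // heads handled together by one block
    const int ne03   = 2;   // number of sequences (the new 4th dimension)

    const int kbc_max = iter_k*iter_j*(ne02/ncols2)*ne03; // total work items

    for (int kbc = 0; kbc < kbc_max; ++kbc) {
        // Same arithmetic as in flash_attn_ext_f16 / flash_attn_stream_k_fixup:
        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
        const int head     = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
        const int jt       = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k;
        const int kt       = kbc % iter_k; // innermost index, as in kb0_start = kbc % iter_k

        // "head" indexes a group of ncols2 heads; the first head of the group is head*ncols2.
        // Re-flatten to confirm the mapping is a bijection.
        assert(kbc == ((sequence*(ne02/ncols2) + head)*iter_j + jt)*iter_k + kt);
    }
    printf("all %d work items round-trip\n", kbc_max);
    return 0;
}

Because the K/V-tile index stays innermost, the existing fixup logic for output tiles that straddle two CUDA blocks largely carries over; only the index arithmetic changes.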

View File

@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
 constexpr int ncols = ncols1*ncols2;
 const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
 const int iter_k = ne11 / FATTN_KQ_STRIDE;
 const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
-const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 const bool did_not_have_any_data = kbc0 == kbc0_stop;
 const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
 return;
 }
-const int channel = kbc0 / (iter_k*iter_j);
-const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
+const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 if (jt*ncols1 + j >= ne01) {
 return;
 }
-dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 // Load the partial result that needs a fixup:
 float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
 int bidx = bidx0 - 1;
 int kbc_stop = kbc0;
 while(true) {
-const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 if (kbc == kbc_stop) { // Did not have any data.
 bidx--;
 kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
 const float2 * __restrict__ VKQ_meta,
 float * __restrict__ dst,
 const int parallel_blocks) {
-VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
-dst += D * gridDim.z*blockIdx.x;
+// Dimension 0: threadIdx.x
+// Dimension 1: blockIdx.x
+// Dimension 2: blockIdx.y
+// Dimension 3: blockIdx.z
+// Memory layout is permuted with [0, 2, 1, 3]
+const int ne01 = gridDim.x;
+const int ne02 = gridDim.y;
+const int col = blockIdx.x;
+const int head = blockIdx.y;
+const int sequence = blockIdx.z;
+const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+VKQ_meta += j_dst_unrolled * parallel_blocks;
+dst += j_dst_unrolled * D;
 const int tid = threadIdx.x;
 __builtin_assume(tid < D);
 extern __shared__ float2 meta[];
 for (int i = tid; i < 2*parallel_blocks; i += D) {
-((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+((float *) meta)[i] = ((const float *)VKQ_meta) [i];
 }
 __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
 const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
 *((uint32_t *) &KQ_max_scale) &= ftz_mask;
-VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
 VKQ_denominator += KQ_max_scale * meta[l].y;
 }
-dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
-GGML_ASSERT(Q->ne[3] == 1);
 ggml_cuda_pool & pool = ctx.pool();
 cudaStream_t main_stream = ctx.stream();
 const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
 scale, max_bias, m0, m1, n_head_log2, logit_softcap,
 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
 Q->nb[1], Q->nb[2], Q->nb[3],
 nb11, nb12, nb13,
 nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 flash_attn_stream_k_fixup<DV, ncols1, ncols2>
 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
 }
 } else if (parallel_blocks > 1) {
 const dim3 block_dim_combine(DV, 1, 1);
-const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
 const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 flash_attn_combine_results<DV>
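All of the kernels in this commit address the output through the same unrolled index, j_dst_unrolled = (sequence*ne01 + col)*ne02 + head, i.e. the [0, 2, 1, 3] permuted layout noted in flash_attn_combine_results above: head size contiguous, then heads, then Q columns, then sequences. A self-contained sketch of that addressing with made-up sizes, ignoring the extra gridDim.y factor used for the per-parallel-block dst_meta entries (again an illustration, not code from the commit):

// Sketch: linear offset into dst for element d of (sequence, col, head),
// matching j_dst_unrolled*D + d in the kernels above.
#include <cstdio>

static long long dst_offset(int sequence, int col, int head, int d,
                            int ne01 /*Q columns*/, int ne02 /*heads*/, int D /*head size*/) {
    const long long j_dst_unrolled = (static_cast<long long>(sequence)*ne01 + col)*ne02 + head;
    return j_dst_unrolled*D + d;
}

int main() {
    const int ne01 = 5, ne02 = 4, ne03 = 2, D = 8; // made-up sizes

    // Walking (sequence, col, head, d) in this nesting order visits dst contiguously;
    // adjacent d values map to adjacent addresses, which keeps per-thread writes coalesced.
    long long expected = 0;
    for (int sequence = 0; sequence < ne03; ++sequence)
    for (int col = 0; col < ne01; ++col)
    for (int head = 0; head < ne02; ++head)
    for (int d = 0; d < D; ++d) {
        if (dst_offset(sequence, col, head, d, ne01, ne02, D) != expected++) {
            printf("layout mismatch\n");
            return 1;
        }
    }
    printf("layout check passed, %lld elements\n", expected);
    return 0;
}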

View File

@@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
 constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 // kbc == k block continuous, current index in continuous ijk space.
-int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
 // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
 int kb0_start = kbc % iter_k;
 int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
 while (kbc < kbc_stop && kb0_stop == iter_k) {
-const int channel = kbc / (iter_k*iter_j);
-const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
-const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
-const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
+(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
-const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
-const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 const int kb0_start_kernel = kb0_start * kb_niter;
 const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
 return;
 }
-const int channel = kbc / (iter_k*iter_j);
-const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
+const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
-const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
-const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
+const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
+const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
 const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
-(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
+(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
-float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
-const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
-const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
+const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
+const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 const int kb0_start_kernel = kb0_start * kb_niter;
 const int kb0_stop_kernel = kb0_stop * kb_niter;

View File

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
 const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+const int sequence = blockIdx.z / ne02;
+const int head = blockIdx.z - sequence*ne02;
 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 const int stride_KV2 = nb11 / sizeof(half2);
-const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 const half slopeh = __float2half(slopef);
 static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
 __syncthreads();
 }
+float2 * dst2 = (float2 *) dst;
 #pragma unroll
 for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
 const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
 half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
 kqsum_j = warp_reduce_sum((float)kqsum_j);
-#pragma unroll
-for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-const int i0 = i00 + 2*threadIdx.x;
-half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+#pragma unroll
+for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+const int i0 = i00 + threadIdx.x;
+half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
 if (gridDim.y == 1) {
 dst_val /= __half2half2(kqsum_j);
 }
-const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
-dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
+dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
 }
 if (gridDim.y != 1 && threadIdx.x == 0) {
-dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
 }
 }
 #else
@@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
 GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
 GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
 GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
 GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
 GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View File

@@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
 const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+const int sequence = blockIdx.z / ne02;
+const int head = blockIdx.z - sequence*ne02;
 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
-const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
-const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 const int stride_KV2 = nb11 / sizeof(half2);
-const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
 __syncthreads();
 }
+float2 * dst2 = (float2 *) dst;
 #pragma unroll
 for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
 const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
 float kqsum_j = kqsum[j_VKQ_0/nwarps];
 kqsum_j = warp_reduce_sum(kqsum_j);
-#pragma unroll
-for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
-const int i0 = i00 + 2*threadIdx.x;
-float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
+const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+#pragma unroll
+for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
+const int i0 = i00 + threadIdx.x;
+float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
 if (gridDim.y == 1) {
 dst_val.x /= kqsum_j;
 dst_val.y /= kqsum_j;
 }
-const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
-dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
+dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
 }
 if (gridDim.y != 1 && threadIdx.x == 0) {
-dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
 }
 }
 #else

View File

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
 const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+const int sequence = blockIdx.z / ne02;
+const int head = blockIdx.z - sequence*ne02;
 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-Q += nb02* blockIdx.z + nb01*ic0;
-K += nb12*(blockIdx.z / gqa_ratio);
-V += nb22*(blockIdx.z / gqa_ratio);
+Q += nb03*sequence + nb02* head + nb01*ic0;
+K += nb13*sequence + nb12*(head / gqa_ratio);
+V += nb23*sequence + nb22*(head / gqa_ratio);
-const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 const half slopeh = __float2half(slopef);
 static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
 if (gridDim.y == 1) {
 dst_val /= kqsum[j_VKQ];
 }
-const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
 }
 if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
 }
 #else
 GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
 GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
 GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
+GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
 GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
 GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
 GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

View File

@@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
 GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
 GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
-GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
 GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
 GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
 GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
@@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
 const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+const int sequence = blockIdx.z / ne02;
+const int head = blockIdx.z - sequence*ne02;
 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-Q += nb02* blockIdx.z + nb01*ic0;
-K += nb12*(blockIdx.z / gqa_ratio);
-V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
+Q += nb03*sequence + nb02* head + nb01*ic0;
+K += nb13*sequence + nb12*(head / gqa_ratio);
+V += nb23*sequence + nb22*(head / gqa_ratio);
-const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 constexpr int nwarps = D / WARP_SIZE;
@@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
 if (gridDim.y == 1) {
 dst_val /= kqsum[j_VKQ];
 }
-const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
-dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
+dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
 }
 if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
+dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
 }
 #else
 GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
 GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
 GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-GGML_UNUSED(ne31); GGML_UNUSED(ne32);
-GGML_UNUSED(nb31); GGML_UNUSED(nb32);
+GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
+GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
 GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
 GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
 GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);

View File

@@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
 const int ne13,
 const int ne31,
 const int ne32,
+const int ne33,
 const int nb31,
 const int nb32,
+const int nb33,
 const int nb01,
 const int nb02,
 const int nb03,
@@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
 constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
 constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+const int sequence = blockIdx.z / ne02;
+const int head = blockIdx.z - sequence*ne02;
 const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
-const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
-const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
-const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
+const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
 const half2 * mask2 = (const half2 *) maskh;
 const int stride_Q = nb01 / sizeof(float);
 const int stride_KV = nb11 / sizeof(half);
-const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
+const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 const half slopeh = __float2half(slopef);
 const half2 slope2 = make_half2(slopef, slopef);
@@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
 if (ic0 + j_VKQ >= ne01) {
 return;
 }
-const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
 float KQ_rowsum_j;
 if (std::is_same<KQ_acc_t, float>::value) {
@@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
 KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
 }
+const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
 #pragma unroll
 for (int i0 = 0; i0 < D; i0 += warp_size) {
 const int i = i0 + threadIdx.x;
@@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
 if (gridDim.y == 1) {
 dst_val /= KQ_rowsum_j;
 }
-dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
+dst[j_dst_unrolled*D + i] = dst_val;
 }
 if (gridDim.y == 1 || threadIdx.x != 0) {
@@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
 dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
 }
 dst_meta_val.y = KQ_rowsum_j;
-dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
+dst_meta[j_dst_unrolled] = dst_meta_val;
 }
 #else
 GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
 GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
 GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
 GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
-GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
+GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
 GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
 GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
 GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);

View File

@@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 if (op->src[0]->ne[0] == 192) {
 return false;
 }
-// TODO: support broadcast
-// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
-// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
-if (op->src[0]->ne[3] != 1) {
-return false;
-}
 if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
 return false;
 }
@@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
 return true;
 }
+if (op->src[3] && op->src[3]->ne[2] != 1) {
+return false;
+}
 return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
 op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
 }