From c72d3ce93539408ed278e64989cdae60f81c0fac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Jun 2025 21:33:40 +0300 Subject: [PATCH] metal : use F32 accumulators in FA kernels (llama/13975) ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 8 ++- ggml/src/ggml-metal/ggml-metal.metal | 94 +++++++++++++++------------- 2 files changed, 57 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index f78e7eee..bc93bc63 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; + // 2*(2*ncpsg + nqptg)*(nsg) // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) // @@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node( // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; @@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node( // and store the soft_max values and the mask // // ne00*(nsg) - // each simdgroup has a full f16 head vector in shared mem to accumulate results + // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 59899550..58763e39 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext( constexpr short NW = N_SIMDWIDTH; constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = DK + 2*TS; // shared memory size per query in (half) + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = 2*DK + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t @@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext( if (iq1 + j < args.ne01) { sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*DK4 + i] = (q4_t) 0.0f; + sq4[j*DK4 + i] = 0; } } } @@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - float S = { 0.0f }; - float M = { -__FLT_MAX__/2 }; - threadgroup_barrier(mem_flags::mem_threadgroup); // each simdgroup stores its output to shared memory, reusing sq @@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext( const float M0 = ss[j*TS + 1]; const float M1 = ss[j*TS + sg*SH + 1]; - M = max(M0, M1); + const float M = max(M0, M1); const float ms0 = exp(M0 - M); const float ms1 = exp(M1 - M); - S = S0*ms0 + S1*ms1; + const float S = S0*ms0 + S1*ms1; if (tiisg == 0) { ss[j*TS + 0] = S; @@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext( } } - device float4 * dst4 = (device float4 *) dst; + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK); // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { - const float S = ss[j*TS + 0]; + for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { + const float S = 1.0f/sf[j*TS + 0]; - for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; - } + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*DV4 + i]*S; } } } @@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext( // template to be able to explore different combinations // #define FA_TYPES \ - half, half4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 + float, float4, simdgroup_float8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 + +#define FA_TYPES_BF \ + bfloat, bfloat4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + float, float4, simdgroup_float8x8 + //half, half4, simdgroup_half8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; @@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES +#undef FA_TYPES_BF template< typename q4_t, // query types in shared memory @@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec( const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results // store the result for all queries in local memory (the O matrix from the paper) o4_t lo[DV4/NL]; @@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec( half4, \ float, \ float, float4, \ - half4 + float4 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t;