metal : use F32 prec in FA kernels (llama/12688)

* metal : use F32 prec in FA kernels

ggml-ci

* cont : fix FA vec kernel

ggml-ci
commit f6ce10e4a1 (parent 6cb2b86581)
Author: Georgi Gerganov
Date:   2025-04-01 14:57:19 +03:00
2 changed files with 48 additions and 48 deletions

ggml-metal.m

@@ -4179,7 +4179,7 @@ static void ggml_metal_encode_node(
     // ne00*(nsg)
     // each simdgroup has a full f16 head vector in shared mem to accumulate results
     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
 
     int64_t nsgmax = 2;
     while (true) {
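
The 2 → 4 factor on ncpsg tracks the kernel change below: the per-simdgroup attention scratch now holds F32 values, which take twice as many half-sized slots in the shared-memory budget (the macro counts in units of sizeof(float)/2 = 2 bytes). A minimal standalone C sketch of the arithmetic; the concrete values (nqptg = 8, ncpsg = 32, ne00 = ne20 = 128, nsg = 4) are illustrative assumptions, not taken from this diff:

#include <stdio.h>

#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main(void) {
    // assumed launch parameters, for illustration only:
    const long nqptg = 8;              // queries processed per threadgroup
    const long ncpsg = 32;             // cache values processed per simdgroup
    const long ne00 = 128, ne20 = 128; // K and V head sizes
    const long nsg = 4;                // simdgroups per threadgroup

    // before: scores kept as half -> 2*ncpsg half-units per simdgroup
    const long smem_old = GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 2*ncpsg*nsg) + ne20*nsg)*(long)(sizeof(float)/2), 16);
    // after: scores kept as float -> 4*ncpsg half-units per simdgroup
    const long smem_new = GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*nsg) + ne20*nsg)*(long)(sizeof(float)/2), 16);

    printf("smem old: %ld bytes, new: %ld bytes\n", smem_old, smem_new); // 7168 vs 11264
    return 0;
}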

ggml-metal.metal

@@ -3184,8 +3184,8 @@ kernel void kernel_flash_attn_ext(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S[Q] = { [0 ... Q-1] = 0.0f };
-        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float S[Q] = { [0 ... Q-1] = 0.0f };
+        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3202,13 +3202,13 @@ kernel void kernel_flash_attn_ext(
         const bool has_mask = mask != q;
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
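
Written out, the ALiBi slope the code above evaluates (with m0, m1 precomputed on the host) is:

\mathrm{slope}(h) =
\begin{cases}
  m_0^{\,h+1}                               & h < n_{\mathrm{head\_log2}} \\
  m_1^{\,2(h - n_{\mathrm{head\_log2}}) + 1} & \text{otherwise}
\end{cases}

The vec kernel further down computes the same expression; this change just keeps the base and the result in F32.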
@@ -3224,14 +3224,14 @@ kernel void kernel_flash_attn_ext(
         if (has_mask) {
             // used to detect blocks full of -INF
-            half smax = -INFINITY;
+            float smax = -INFINITY;
 
             // load the mask in shared memory
             #pragma unroll(Q)
             for (short j = 0; j < Q; ++j) {
                 device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
-                const half m = pm[ic + tiisg];
+                const float m = pm[ic + tiisg];
 
                 ss[j*TS + C + tiisg] = m;
                 smax = max(smax, m);
@@ -3327,10 +3327,10 @@ kernel void kernel_flash_attn_ext(
         // online softmax
         {
             for (ushort j = 0; j < Q; ++j) {
-                const half m = M[j];
+                const float m = M[j];
 
                 // scale and apply the logitcap / mask
-                half s = ss[j*TS + tiisg]*args.scale;
+                float s = ss[j*TS + tiisg]*args.scale;
 
                 if (args.logit_softcap != 0.0f) {
                     s = args.logit_softcap*precise::tanh(s);
@@ -3341,8 +3341,8 @@ kernel void kernel_flash_attn_ext(
                 M[j] = simd_max(max(M[j], s));
 
-                const half ms = exp(m - M[j]);
-                const half vs = exp(s - M[j]);
+                const float ms = exp(m - M[j]);
+                const float vs = exp(s - M[j]);
 
                 S[j] = S[j]*ms + simd_sum(vs);
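
The quantities moved to F32 here are exactly the state of the streaming (online) softmax. Per query row j, with m the previous row maximum and s this lane's scaled score, the update above is:

M_j \leftarrow \max\!\big(M_j, \max_{\text{lanes}} s\big), \qquad
S_j \leftarrow S_j\, e^{\,m - M_j} + \sum_{\text{lanes}} e^{\,s - M_j}

Keeping the running maximum M and normalizer S in F32 is the precision fix the commit title refers to.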
@@ -3444,8 +3444,8 @@ kernel void kernel_flash_attn_ext(
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
-        half S = { 0.0f };
-        half M = { -__FLT16_MAX__/2 };
+        float S = { 0.0f };
+        float M = { -__FLT16_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3461,16 +3461,16 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const half S0 = ss[j*TS +         0];
-                const half S1 = ss[j*TS + sg*SH + 0];
+                const float S0 = ss[j*TS +         0];
+                const float S1 = ss[j*TS + sg*SH + 0];
 
-                const half M0 = ss[j*TS +         1];
-                const half M1 = ss[j*TS + sg*SH + 1];
+                const float M0 = ss[j*TS +         1];
+                const float M1 = ss[j*TS + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const half ms0 = exp(M0 - M);
-                const half ms1 = exp(M1 - M);
+                const float ms0 = exp(M0 - M);
+                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
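
Both reduction paths (this sequential one and the vec kernel's parallel reduce further down) merge two partial softmax states with the same exact identity, which is why each simdgroup only needs to publish its (S, M) pair at scratch offsets 0 and 1:

M = \max(M_0, M_1), \qquad S = S_0\, e^{M_0 - M} + S_1\, e^{M_1 - M}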
@@ -3646,16 +3646,16 @@ kernel void kernel_flash_attn_ext_vec(
     constexpr short DV4 = DV/4;
     constexpr short NW  = N_SIMDWIDTH;
     constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    constexpr short SH  = 2*C;   // shared memory per simdgroup
+    constexpr short SH  = 4*C;   // shared memory per simdgroup
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
   //threadgroup q_t   * sq  = (threadgroup q_t   *) (shmem_f16 + 0*DK); // holds the query data
     threadgroup q4_t  * sq4 = (threadgroup q4_t  *) (shmem_f16 + 0*DK); // same as above but in q4_t
     threadgroup s_t   * ss  = (threadgroup s_t   *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
     threadgroup s4_t  * ss4 = (threadgroup s4_t  *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
-    threadgroup half  * sm  = (threadgroup half  *) (shmem_f16 + sgitg*SH +   C + Q*DK); // scratch buffer for mask
+    threadgroup float * sm  = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
     threadgroup o4_t  * sr4 = (threadgroup o4_t  *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory (the O matrix from the paper)
     o4_t lo[DV4/NL];
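
For orientation, the per-simdgroup scratch layout implied by the new offsets (a sketch derived from this hunk and the FA_TYPES change at the bottom; shmem_f16 is addressed in half units):

// per-simdgroup scratch inside shmem_f16, in units of half:
//
//   [0,   2*C) : attention scores ss -- C floats (s_t is now float)
//   [2*C, 4*C) : mask values      sm -- C floats (was C halves at offset C)
//
// hence SH = 4*C halves per simdgroup, and a float* view of the same memory
// (ss, as used by the parallel reduce below) steps between simdgroups in
// strides of SH/2 floats -- which is why r*SH becomes r*(SH/2) in that hunk.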
@@ -3684,8 +3684,8 @@ kernel void kernel_flash_attn_ext_vec(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        half S = 0.0f;
-        half M = -__FLT16_MAX__/2;
+        float S = 0.0f;
+        float M = -__FLT16_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
@@ -3703,13 +3703,13 @@ kernel void kernel_flash_attn_ext_vec(
         // pointer to the mask
         device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
-        half slope = 1.0f;
+        float slope = 1.0f;
 
         // ALiBi
         if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const float base = h < args.n_head_log2 ? args.m0 : args.m1;
             const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
@@ -3799,13 +3799,13 @@ kernel void kernel_flash_attn_ext_vec(
         // online softmax
         {
-            const half m = M;
-            const half s = ss[tiisg];
+            const float m = M;
+            const float s = ss[tiisg];
 
             M = simd_max(max(M, s));
 
-            const half ms = exp(m - M);
-            const half vs = exp(s - M);
+            const float ms = exp(m - M);
+            const float vs = exp(s - M);
 
             S = S*ms + simd_sum(vs);
@@ -3836,7 +3836,7 @@ kernel void kernel_flash_attn_ext_vec(
                     v4_t mv;
                     deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
 
-                    lo[ii/NL] += mv*ms;
+                    lo[ii/NL] += o4_t(float4(mv)*float4(ms));
                 }
             }
         }
@@ -3907,18 +3907,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const half S0 = ss[       0];
-            const half S1 = ss[r*SH + 0];
+            const float S0 = ss[           0];
+            const float S1 = ss[r*(SH/2) + 0];
 
-            const half M0 = ss[       1];
-            const half M1 = ss[r*SH + 1];
+            const float M0 = ss[           1];
+            const float M1 = ss[r*(SH/2) + 1];
 
-            const half M = max(M0, M1);
+            const float M = max(M0, M1);
 
-            const half ms0 = exp(M0 - M);
-            const half ms1 = exp(M1 - M);
+            const float ms0 = exp(M0 - M);
+            const float ms1 = exp(M1 - M);
 
-            const half S = S0*ms0 + S1*ms1;
+            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3950,11 +3950,11 @@ kernel void kernel_flash_attn_ext_vec(
 // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
 //
 #define FA_TYPES \
     half4,  \
     half4,  \
     half4,  \
     float,  \
-    half,  half4,  \
+    float, float4, \
     half4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
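
For reference: the seven FA_TYPES entries bind positionally to the vec kernel's type parameters; judging from the template definition earlier in the file, these are the Q/K/V vector types, the QK accumulator, the softmax scratch types s_t/s4_t, and the output type o4_t (parameter names assumed from the surrounding code, not shown in this diff). The change therefore moves only the softmax scratch from half/half4 to float/float4, while Q, K, V and the output accumulator stay in half precision.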