Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-06-19 17:28:09 +02:00)
metal : use F32 accumulators in FA kernels (llama/13975)
ggml-ci
parent 126aeb4a49
commit c72d3ce935
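
In short, based on the hunks below: the flash-attention kernels now keep the query/output rows and the softmax state in F32 instead of F16, while the K/V data stays in its original type. The host-side shared-memory estimates are updated to match (the 2*ne00, 2*ne20, 2*DK and 2*Q*DK terms), the 16*32 staging area for dequantizing the KV cache is reserved only when the KV type is actually quantized (is_q), and the final rescale/store loop is distributed over all simdgroups instead of only simdgroup 0.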
@@ -4766,6 +4766,8 @@ static bool ggml_metal_encode_node(
 GGML_ASSERT(nqptg % 8 == 0);
 GGML_ASSERT(ncpsg % 32 == 0);
 
+const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
+
 // 2*(2*ncpsg + nqptg)*(nsg)
 // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
 //
@@ -4773,7 +4775,7 @@ static bool ggml_metal_encode_node(
 // the shared memory needed for the simdgroups to load the KV cache
 // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
 
 int64_t nsgmax = 2;
 
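
For intuition, here is a minimal host-side sketch (plain C, not part of the patch) that plugs example values into the old and new FATTN_SMEM formulas above. The values nqptg = 8, ncpsg = 32, ne00 = 128, nsg = 4 are assumptions chosen to satisfy the asserts in the first hunk, and pad16() stands in for GGML_PAD(x, 16):

// Illustrative only: evaluates the two FATTN_SMEM variants from the hunk above
// for example values. pad16() plays the role of GGML_PAD(x, 16) (round up to a
// multiple of 16 bytes); sizeof(float)/2 == 2 converts half-element counts to bytes.
#include <stdio.h>

static int pad16(int x) { return (x + 15) / 16 * 16; }

int main(void) {
    const int nqptg = 8, ncpsg = 32, ne00 = 128, nsg = 4; // assumed example values

    for (int is_q = 0; is_q <= 1; ++is_q) {
        // old: f16 accumulators, K/V dequantization staging always reserved
        const int smem_old = pad16((nqptg*(  ne00 + 2*(2*ncpsg + nqptg)*nsg) +       16*32*nsg) * (int)(sizeof(float)/2));
        // new: f32 accumulators (2*ne00), staging only when the KV type is quantized
        const int smem_new = pad16((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*nsg) + is_q*(16*32*nsg)) * (int)(sizeof(float)/2));

        printf("is_q=%d  old=%d bytes  new=%d bytes\n", is_q, smem_old, smem_new);
    }
    return 0;
}

For these particular values the new estimate is smaller when the KV type is not quantized (the staging area is dropped) and larger when it is (the doubled query/output term dominates).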
@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
 // and store the soft_max values and the mask
 //
 // ne00*(nsg)
-// each simdgroup has a full f16 head vector in shared mem to accumulate results
+// each simdgroup has a full f32 head vector in shared mem to accumulate results
 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
 
 int64_t nsgmax = 2;
 while (true) {
@@ -3329,13 +3329,13 @@ kernel void kernel_flash_attn_ext(
 constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
 const short TS = nsg*SH; // shared memory size per query in (s_t == float)
-const short T = DK + 2*TS; // shared memory size per query in (half)
+const short T = 2*DK + 2*TS; // shared memory size per query in (half)
 
 threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
 threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
 threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
-threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
+threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
 
 threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
 threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
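
A rough sketch of why the offsets above change: q_t/o_t and s_t are float while shmem_f16 is still addressed in half units, so the query/output block and the softmax scratch each occupy twice as many half slots as before. The plain-C illustration below prints the resulting extents; Q, DK, C and nsg are assumed example values, not taken from the patch:

// Prints the half-unit extents implied by the pointer arithmetic in the hunk
// above, using assumed example values for the template parameters.
#include <stdio.h>

int main(void) {
    const int Q = 8, DK = 128, C = 32, nsg = 4; // assumptions for illustration

    const int SH = 2*C + Q;      // per-simdgroup scratch, in floats (s_t == float)
    const int TS = nsg*SH;       // scratch per query, in floats
    const int T  = 2*DK + 2*TS;  // per-query total, in halves

    printf("sq/so (f32 query/output) : halves [%6d, %6d)\n", 0,      2*Q*DK);
    printf("ss    (f32 scratch)      : halves [%6d, %6d)\n", 2*Q*DK, 2*Q*DK + 2*Q*TS);
    printf("sk    (K staging)        : halves [%6d, ...)\n", Q*T);   // Q*T == 2*Q*DK + 2*Q*TS
    return 0;
}

Note that the K staging region (sk) starts at Q*T = 2*Q*DK + 2*Q*TS, i.e. exactly where the scratch region ends.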
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
 if (iq1 + j < args.ne01) {
 sq4[j*DK4 + i] = (q4_t) q4[i];
 } else {
-sq4[j*DK4 + i] = (q4_t) 0.0f;
+sq4[j*DK4 + i] = 0;
 }
 }
 }
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
 
 // reduce the warps sequentially
 for (ushort sg = 1; sg < nsg; ++sg) {
-float S = { 0.0f };
-float M = { -__FLT_MAX__/2 };
-
 threadgroup_barrier(mem_flags::mem_threadgroup);
 
 // each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
 const float M0 = ss[j*TS + 1];
 const float M1 = ss[j*TS + sg*SH + 1];
 
-M = max(M0, M1);
+const float M = max(M0, M1);
 
 const float ms0 = exp(M0 - M);
 const float ms1 = exp(M1 - M);
 
-S = S0*ms0 + S1*ms1;
+const float S = S0*ms0 + S1*ms1;
 
 if (tiisg == 0) {
 ss[j*TS + 0] = S;
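
The hunk above is the usual streaming-softmax merge: each simdgroup holds a partial sum S and a running maximum M, and combining two partials requires rescaling each side by exp(M_i - M). A minimal stand-alone illustration in plain C with made-up inputs:

// A minimal sketch of the merge performed in the loop above: two partial
// softmax states (S, M) from different simdgroups are combined into one,
// and ms0/ms1 are the factors used to rescale the corresponding partial
// output rows. Plain C with assumed inputs, for illustration only.
#include <math.h>
#include <stdio.h>

int main(void) {
    // partial states: S = sum of exp(scores - M), M = running max score
    float S0 = 3.5f, M0 = 1.0f;   // simdgroup 0 (assumed values)
    float S1 = 2.0f, M1 = 2.5f;   // simdgroup sg (assumed values)

    const float M   = fmaxf(M0, M1);
    const float ms0 = expf(M0 - M);   // rescale factor for simdgroup 0's partial output
    const float ms1 = expf(M1 - M);   // rescale factor for simdgroup sg's partial output
    const float S   = S0*ms0 + S1*ms1;

    printf("merged: M = %f, S = %f, ms0 = %f, ms1 = %f\n", M, S, ms0, ms1);
    return 0;
}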
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
 }
 }
 
-device float4 * dst4 = (device float4 *) dst;
+threadgroup_barrier(mem_flags::mem_threadgroup);
 
+threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
+
 // final rescale with 1/S and store to global memory
-if (sgitg == 0) {
-for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
-const float S = ss[j*TS + 0];
+for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
+const float S = 1.0f/sf[j*TS + 0];
 
+device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+
 for (short i = tiisg; i < DV4; i += NW) {
-dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
-}
+dst4[i] = (float4) so4[j*DV4 + i]*S;
 }
 }
 }
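
Net effect of this hunk: the per-row softmax sums are read back from the shared scratch via sf, each row's scale is computed once as a reciprocal and applied with a multiply, and the store loop iterates j = sgitg, sgitg + nsg, ... so every simdgroup writes a share of the output rows instead of leaving all stores to simdgroup 0.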
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
 // template to be able to explore different combinations
 //
 #define FA_TYPES \
-half, half4, simdgroup_half8x8, \
+float, float4, simdgroup_float8x8, \
 half, half4x4, simdgroup_half8x8, \
 half, half4x4, simdgroup_half8x8, \
 float, simdgroup_float8x8, \
 float, simdgroup_float8x8, \
-half, half4, simdgroup_half8x8
+float, float4, simdgroup_float8x8
+//half, half4, simdgroup_half8x8
 
+#define FA_TYPES_BF \
+bfloat, bfloat4, simdgroup_bfloat8x8, \
+bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+bfloat, bfloat4x4, simdgroup_bfloat8x8, \
+float, simdgroup_float8x8, \
+float, simdgroup_float8x8, \
+float, float4, simdgroup_float8x8
+//half, half4, simdgroup_half8x8
+
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
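
The practical change in this hunk: the query-side and output-side entries of FA_TYPES switch from half to float (the accumulator entries were already float), and a separate FA_TYPES_BF list is introduced so the bf16 instantiations below can keep bfloat loads for Q/K/V while using the same float accumulators.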
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
 
 #undef FA_TYPES
+#undef FA_TYPES_BF
 
 template<
 typename q4_t, // query types in shared memory
@@ -3852,7 +3862,7 @@ kernel void kernel_flash_attn_ext_vec(
 threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
 threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
 threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
-threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
+threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
 
 // store the result for all queries in local memory (the O matrix from the paper)
 o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
 half4, \
 float, \
 float, float4, \
-half4
+float4
 
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
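
The vec kernel gets the same treatment: its per-simdgroup result buffer sr4 now needs twice the half slots (hence the 2*sgitg*DV offset in the earlier hunk), matching the output entry of its FA_TYPES list switching from half4 to float4 above.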