#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
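
// Specialization constants: each workgroup processes a Br x Bc tile of the
// attention matrix for a head of size D. The D elements of a row are split
// across D_split threads, so WorkGroupSize / D_split columns are handled per
// inner-loop iteration.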
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;

layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;

const uint32_t cols_per_iter = WorkGroupSize / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
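
// Host-provided parameters: tensor extents and strides, the softmax scale,
// logit soft-capping, ALiBi slope parameters, masking, and the grouped-query /
// split-k configuration.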
layout (push_constant) uniform parameter {
    uint32_t N;
    uint32_t KV;

    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    uint32_t nem1;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;
    uint32_t nb31;

    float scale;
    float max_bias;
    float logit_softcap;

    uint32_t mask;
    uint32_t n_head_log2;
    float m0;
    float m1;

    uint32_t gqa_ratio;
    uint32_t split_kv;
    uint32_t k_num;
} p;
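
// Buffer bindings: Q, K, V, the optional mask M, and the output O.
// The *V4 declarations alias the same bindings as vec4 / f16vec4 views
// so that four elements can be loaded per access.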
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
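
// When K/V are stored in a quantized format, the same bindings are also exposed
// as packed 16-bit data and dequantized on the fly by dequantize4() below.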
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
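
// Q4_0: 18-byte blocks holding an f16 scale d and 32 4-bit quants; a value
// dequantizes to d * (q - 8).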
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
    uint shift = (iqs & 0x10) >> 2;
    vui_lo >>= shift;
    vui_hi >>= shift;

    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
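
// Q8_0: 34-byte blocks holding an f16 scale d and 32 signed 8-bit quants; a
// value dequantizes to d * q.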
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    uint32_t offset = (iq2 + r) * D + c;
    data_o[o_offset + offset] = D_TYPE(elem);
    return elem;
}

// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    if (r < N && c == 0) {
        uint32_t offset = iq2 + r;
        data_o[o_offset + offset] = D_TYPE(elem);
    }
    return elem;
}

// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    const uint32_t h = iq2 + (r % p.gqa_ratio);

    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
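
// Shared memory: scalar and vec4 scratch for cross-thread reductions, the
// current mask tile, and the Br x D tile of Q kept resident for the whole row block.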
shared FLOAT_TYPE tmpsh[WorkGroupSize];
shared vec4 tmpshv4[WorkGroupSize];

shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint32_t tid = gl_LocalInvocationIndex;
    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;

    uint32_t i = gl_WorkGroupID.x;
    uint32_t split_k_index = 0;
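
    // With split_k, the workgroups along x each cover a slice of the KV range
    // for row block 0 instead of covering different row blocks.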
    if (p.k_num > 1) {
        i = 0;
        split_k_index = gl_WorkGroupID.x;
    }

    const uint32_t Tr = CEIL_DIV(N, Br);

    const uint32_t start_j = split_k_index * p.split_kv / Bc;
    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // broadcast factors
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;

    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // k indices
    const uint32_t ik3 = iq3 / rk3;
    const uint32_t ik2 = iq2 / rk2;

    // v indices
    const uint32_t iv3 = iq3 / rv3;
    const uint32_t iv2 = iq2 / rv2;

    // nb?1 are already divided by the type size and are in units of elements.
    // When using grouped query attention, Q is indexed by iq2, so the stride
    // should be nb02 (which is in bytes).
    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;
    // When using grouped query attention, all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
    // that prevents the compiler from folding the "&" through the select
    // and breaking the alignment detection.
    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;

    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
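
    // Load the Br x D tile of Q into shared memory, pre-multiplied by the softmax scale.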
    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
        uint32_t d = (idx + tid) % (D / 4);
        uint32_t r = (idx + tid) / (D / 4);
        if (r < Br && d < D / 4 &&
            i * Br + r < N) {
            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
        }
    }
    barrier();
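
    // Per-thread output accumulator Of and per-row running softmax sum (L) and max (M).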
    vec4 Of[Br][D_per_thread / 4];
    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] = vec4(0.0);
        }
    }

    float Lf[Br], Mf[Br];

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lf[r] = 0;
        Mf[r] = NEG_FLT_MAX_OVER_2;
    }

    float slope[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        slope[r] = 1.0;
    }

    // ALiBi
    if (p.max_bias > 0.0f) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
        }
    }
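
    // Base offsets into K and V: in blocks for quantized types (BLOCK_SIZE > 1),
    // in float16 elements otherwise.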
#if BLOCK_SIZE > 1
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
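
    // Main loop over Bc-wide column blocks of the KV range assigned to this
    // workgroup, accumulating O, L, and M online (flash attention).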
    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        float Sf[Br][cols_per_thread];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Sf[r][c] = 0.0;
            }
        }
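
        // S = Q * K^T for this tile: each thread accumulates partial dot products
        // over its D_per_thread slice of the head dimension.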
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
                }
            }
        }
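
        // Reduce the partial dot products across the D_split threads sharing a
        // column, using a subgroup XOR-shuffle butterfly.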
        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            // Compute sum across the D_split
            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                }
            }
        }
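
        // Optionally soft-cap the attention logits.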
        if (p.logit_softcap != 0.0f) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
                }
            }
        }
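
        // Apply the attention mask: stage the Bc x Br mask tile in shared memory,
        // then add the (slope-scaled, for ALiBi) mask values to S.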
        if (p.mask != 0) {

            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
                    masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
                }
            }
            barrier();

            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    float mvf = masksh[c * cols_per_iter + col_tid][r];

                    Sf[r][c] += slope[r]*mvf;
                }
            }
            barrier();
        }
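
        // Online softmax: update each row's running max M and sum L, rescaling
        // previously accumulated results by exp(Mold - M) when the max increases.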
        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            rowmaxf[r] = Sf[r][0];
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
            }
            Moldf[r] = Mf[r];

            // M = max(rowmax, Mold)
            // P = e^(S - M)
            // eM = e^(Mold - M)
            Mf[r] = max(rowmaxf[r], Moldf[r]);
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
            }
            eMf[r] = exp(Moldf[r] - Mf[r]);

            // Compute sum across row of P
            rowsumf[r] = 0.0;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                rowsumf[r] += Pf[r][c];
            }

            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
        }
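
        // Rescale the existing output accumulator by eM, then accumulate P * V
        // for this column block.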
        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                Of[r][d] = eMf[r] * Of[r][d];
            }
        }

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                    Of[r][d] += Pf[r][c] * Vf;
                }
            }
        }

        barrier();
    }

    // Reduce across threads: threads with different col_tid processed different
    // columns, so their running maxima (M), sums (L), and partial outputs (O)
    // are merged here, leaving one result per d_tid lane.
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
|
float rowmaxf, eMf;
|
|
|
|
tmpsh[tid] = Mf[r];
|
|
// Compute max across the row
|
|
barrier();
|
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
|
if (tid < s) {
|
|
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
|
|
}
|
|
barrier();
|
|
}
|
|
rowmaxf = tmpsh[d_tid];
|
|
barrier();
|
|
|
|
float Moldf = Mf[r];
|
|
|
|
// M = max(rowmax, Mold)
|
|
// eM = e^(Mold - M)
|
|
Mf[r] = max(rowmaxf, Moldf);
|
|
eMf = exp(Moldf - Mf[r]);
|
|
|
|
Lf[r] = eMf*Lf[r];
|
|
|
|
tmpsh[tid] = Lf[r];
|
|
|
|
// Compute sum across the row
|
|
barrier();
|
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
|
if (tid < s) {
|
|
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
|
|
}
|
|
barrier();
|
|
}
|
|
Lf[r] = tmpsh[d_tid];
|
|
barrier();
|
|
|
|
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
|
|
|
|
Of[r][d] = eMf * Of[r][d];
|
|
tmpshv4[tid] = Of[r][d];
|
|
|
|
barrier();
|
|
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
|
if (tid < s) {
|
|
Of[r][d] += tmpshv4[tid + s];
|
|
tmpshv4[tid] = Of[r][d];
|
|
}
|
|
barrier();
|
|
}
|
|
Of[r][d] = tmpshv4[d_tid];
|
|
barrier();
|
|
}
|
|
}
|
|
|
|
|
|
    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        uint32_t o_offset = D * p.ne1 * split_k_index;

        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }

        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
            }
        }

        return;
    }
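
    // No split_k: apply the final 1/L normalization here before storing the output.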
    float Lfrcp[Br];
    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
        Lfrcp[r] = 1.0 / Lf[r];
    }

    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
        }
    }

    uint32_t o_offset = iq3*p.ne2*p.ne1;
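
    // With grouped query attention the Br rows map to gqa_ratio consecutive heads
    // (dimension 2); otherwise they map to consecutive rows of dimension 1.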
    if (p.gqa_ratio > 1) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
                    }
                }
            }
        }
    } else {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            if (i * Br + r < N) {
                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
                    }
                }
            }
        }
    }
}