From 51a3580c7931590182fa5ea01eacc7c8b0a8ddb9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Mon, 17 Feb 2025 14:03:24 +0100
Subject: [PATCH] CUDA: use async data loading for FlashAttention (llama/11894)

* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa
---
 ggml/src/ggml-cuda/common.cuh        |  21 +-
 ggml/src/ggml-cuda/cp-async.cuh      |  46 ++
 ggml/src/ggml-cuda/fattn-common.cuh  |  15 +-
 ggml/src/ggml-cuda/fattn-mma-f16.cuh | 614 +++++++++++++++------------
 ggml/src/ggml-cuda/mma.cuh           | 509 ++++++++--------------
 ggml/src/ggml-cuda/mmq.cuh           | 278 ++++++------
 6 files changed, 744 insertions(+), 739 deletions(-)
 create mode 100644 ggml/src/ggml-cuda/cp-async.cuh

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index fd4dcfa9..4a92d35f 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -41,12 +41,13 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL     600
-#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA      700
-#define GGML_CUDA_CC_TURING     750
-#define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_PASCAL       600
+#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA        700
+#define GGML_CUDA_CC_TURING       750
+#define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
 
 // GCN/CDNA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
@@ -199,6 +200,10 @@ typedef float2 dfloat2;
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -231,6 +236,10 @@ static bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return __AMDGCN_WAVEFRONT_SIZE;
diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh
new file mode 100644
index 00000000..51aa41e7
--- /dev/null
+++ b/ggml/src/ggml-cuda/cp-async.cuh
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
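For orientation, this is how a hypothetical caller would combine __cvta_generic_to_shared with the copy; the helper name, buffer size, and indexing are illustrative only, not part of the patch:

```cuda
// Hypothetical usage sketch: one warp copies 128 half2 values,
// i.e. 32 chunks of 16 bytes, from global to shared memory.
__device__ void copy_row_async(const half2 * __restrict__ src) {
    __shared__ half2 dst[128];
    const unsigned int dst_32 = __cvta_generic_to_shared(dst); // 32 bit shared address
    // One 16 byte copy (4 half2) per thread; the template argument requests a
    // 64 byte L2 prefetch:
    cp_async_cg_16<64>(dst_32 + 16*threadIdx.x, src + 4*threadIdx.x);
    cp_async_wait_all(); // wait for this thread's copy
    __syncthreads();     // then make the tile visible to the whole block
}
```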
+// Only the 16 byte copy is exposed because 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index d40ee2da..fefbd319 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -716,7 +716,9 @@ void launch_fattn(
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
 
     ggml_cuda_pool_alloc<half> K_f16(pool);
     ggml_cuda_pool_alloc<half> V_f16(pool);
@@ -768,13 +770,14 @@ void launch_fattn(
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves       = (ntiles_total - nsm - 1) / nsm;
-        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
-        const bool short_context     = K->ne[1] < 4096;
+        const int tiles_nwaves             = (ntiles_total + 2*nsm - 1) / (2*nsm);
+        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
 
         const int nblocks_stream_k = 2*nsm;
 
-        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
@@ -827,7 +830,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if constexpr (parallel_blocks == 0) {
-        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
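For intuition, the wave-efficiency heuristic introduced above can be reproduced in isolation. The SM count (108, A100-like) and the tile counts below are assumptions for the example, not values from the patch:

```cpp
// Standalone illustration of the stream-k decision in launch_fattn.
#include <cstdio>

static bool use_stream_k(const int ntiles_total, const int nsm, const int cc) {
    const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm); // waves of 2 blocks per SM
    const int tiles_efficiency_percent = 100*ntiles_total / (2*nsm*tiles_nwaves);
    return tiles_efficiency_percent < 75 || cc >= 890; // 890 == GGML_CUDA_CC_ADA_LOVELACE
}

int main() {
    // 300 tiles on 108 SMs: ceil(300/216) = 2 waves, 100*300/432 = 69 % -> stream-k.
    printf("%d\n", use_stream_k(300, 108, 800)); // prints 1
    // 400 tiles on 108 SMs: 2 waves, 100*400/432 = 92 % -> whole tiles on pre-Ada GPUs.
    printf("%d\n", use_stream_k(400, 108, 800)); // prints 0
}
```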
const dim3 block_dim_combine(D, 1, 1); const dim3 blocks_num_combine = blocks_num; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 05bc91a3..d777f541 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1,7 +1,252 @@ #include "common.cuh" +#include "cp-async.cuh" #include "mma.cuh" #include "fattn-common.cuh" +using namespace ggml_cuda_mma; + +typedef tile<16, 8, half2> tile_A; +typedef tile< 8, 8, half2> tile_B; +typedef tile<16, 8, float> tile_C_KQ; +typedef tile<16, 4, half2> tile_C_VKQ; + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + // If cp.async is available, load up to the highest power of 2 in D asynchronously: +#ifdef CP_ASYNC_AVAILABLE + static_assert(D >= 64 && D < 512, "bad D"); + constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); + + const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); + + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; + constexpr int stride_i = WARP_SIZE / chunks_per_row; +#pragma unroll + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); + const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; + + cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + } +#else + constexpr int k0_sync_start = 0; +#endif // CP_ASYNC_AVAILABLE + static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + + // If D is not a power of 2, the rest is loaded synchronously. + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); + const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop || k0_stop <= k0_sync_start) { + continue; + } + +#pragma unroll + for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + } + } + } +} + +template +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ maskh, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const int ne01, + const int ne02, + const int stride_Q, + const int stride_KV, + const int stride_mask, + const int jt, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + const tile_B * const __restrict__ Q_B, + tile_C_VKQ * const __restrict__ VKQ_C, + float2 & KQ_max, + float2 & KQ_rowsum, + const int kb0) { +#ifdef NEW_MMA_AVAILABLE + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. + constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + + const int k_VKQ_0 = kb0*KQ_stride; + tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)]; + +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); +#else + flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]); + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. +#endif // CP_ASYNC_AVAILABLE + + if (use_logit_softcap) { + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + if (maskh) { + static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size"); + static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const int i = i0 + tile_C_KQ::get_i(l); + const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l); + + KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. 
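The update that follows implements the usual streaming ("online") softmax. A scalar sketch of the same bookkeeping for a single row, with illustrative names (the kernel keeps KQ_max/KQ_rowsum as float2, two Q columns per thread, starting from KQ_max = -FLT_MAX/2 and KQ_rowsum = 0):

```cpp
#include <cmath>

void online_softmax_step(const float * scores, const int n,
                         float & KQ_max, float & KQ_rowsum, float * numerators) {
    float KQ_max_new = KQ_max;
    for (int i = 0; i < n; ++i) {
        KQ_max_new = fmaxf(KQ_max_new, scores[i]);
    }
    const float KQ_max_scale = expf(KQ_max - KQ_max_new); // <= 1, rescales older results
    KQ_max = KQ_max_new;

    float KQ_rowsum_add = 0.0f;
    for (int i = 0; i < n; ++i) {
        numerators[i] = expf(scores[i] - KQ_max);
        KQ_rowsum_add += numerators[i];
    }
    // The kernel scales the VKQ accumulators by KQ_max_scale in the same way:
    KQ_rowsum = KQ_max_scale*KQ_rowsum + KQ_rowsum_add;
    // Division by the final KQ_rowsum is deferred until all tiles are processed.
}
```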
+ float2 KQ_max_new = KQ_max; + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) { + KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); + KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); + } + } + + // Values per KQ column are spread across 8 threads, does not need full warp reduce: +#pragma unroll + for (int offset = 16; offset > 2; offset >>= 1) { + KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); + KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); + } + + float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); + static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) { +#pragma unroll + for (int l = 0; l < tile_C_KQ::ne; ++l) { + const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y; + const float diff = KQ_C[k].x[l] - KQ_max_l; + KQ_C[k].x[l] = expf(diff); + + if (l % 2 == 0) { + KQ_rowsum_add.x += KQ_C[k].x[l]; + } else { + KQ_rowsum_add.y += KQ_C[k].x[l]; + } + } + } + + { + const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); + const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); + KQ_max = KQ_max_new; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; + KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; + + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); +#pragma unroll + for (int i = 0; i < D/tile_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < tile_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + tile_B B[KQ_stride/(np*2*tile_B::J)]; + static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size"); +#pragma unroll + for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + +#ifdef CP_ASYNC_AVAILABLE + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV); + } +#else + flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); + __syncthreads(); +#endif // CP_ASYNC_AVAILABLE + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } + } + +#ifndef CP_ASYNC_AVAILABLE + __syncthreads(); // Only needed if tile_K == tile_V. 
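Taken together, the two CP_ASYNC_AVAILABLE branches of this function form a double-buffered pipeline. Schematically (pseudocode sketch with simplified control flow, names as in this patch):

```cuda
// load_tile(K[kb0_start]) asynchronously             // preload before the loop
// for (kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
//     cp_async_wait_all(); __syncthreads();          // K[kb0] is ready in tile_K
//     start asynchronous load of V[kb0] into tile_V
//     KQ  = K[kb0] * Q                               // overlaps the V copy
//     softmax bookkeeping (KQ_max, KQ_rowsum)
//     cp_async_wait_all(); __syncthreads();          // V[kb0] is ready in tile_V
//     if (!last_iter) start asynchronous load of K[kb0+1] into tile_K
//     VKQ += V[kb0]^T * softmax(KQ)                  // overlaps the K copy
// }
//
// Without CP_ASYNC_AVAILABLE, tile_K and tile_V alias the same buffer, the
// copies are synchronous, and extra __syncthreads() calls separate the phases.
```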
+#endif // CP_ASYNC_AVAILABLE + +#else + NO_DEVICE_CODE; +#endif // NEW_MMA_AVAILABLE +} + template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, @@ -13,61 +258,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float scale, const float slope, const float logit_softcap, - const int ne00, const int ne01, const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3, + const int stride_Q, + const int stride_KV, + const int stride_mask, const int jt, const int kb0_start, const int kb0_stop) { #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C_KQ; - typedef mma_C_I16J8 mma_C_VKQ; - - static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps"); - constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column. + static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps"); + constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column. static_assert(D % nwarps == 0, "bad D"); static_assert(KQ_stride % nwarps == 0, "bad KQ_stride"); constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. - extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements. - const int stride_Q = nb01 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); - const int stride_mask = nb31 / sizeof(half); + // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements: + extern __shared__ half2 tile_K[]; +#ifdef CP_ASYNC_AVAILABLE + half2 * tile_V = tile_K + KQ_stride*D2_padded; +#else + half2 * tile_V = tile_K; +#endif // CP_ASYNC_AVAILABLE - mma_B Q_B[D/(2*mma_B::K)]; - mma_C_VKQ VKQ_C[D/mma_C_VKQ::I]; + tile_B Q_B[D/(2*tile_B::J)]; + tile_C_VKQ VKQ_C[D/tile_C_VKQ::I]; - float2 KQ_rowsum = {0.0f, 0.0f}; - float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; - float2 KQ_max_scale = {0.0f, 0.0f}; + float2 KQ_rowsum = {0.0f, 0.0f}; + float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f}; - // Temporarily load Q data into tile_KV, will be loaded into registers afterwards. + // Temporarily load Q data into tile_K, will be loaded into registers afterwards. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll @@ -76,6 +300,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k0_stop = D/2 - (D/2) % (1*stride_k); const int stride_j = WARP_SIZE / stride_k; + if (k0_start == k0_stop) { + continue; + } + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { break; } @@ -90,14 +318,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k]; - tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f); } } } @@ -106,198 +334,42 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); { - const int j0 = (threadIdx.y / np) * mma_B::J; + const int j0 = (threadIdx.y / np) * tile_B::I; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded); + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); } } __syncthreads(); + // Preload K data for first iteration when using cp_async: +#ifdef CP_ASYNC_AVAILABLE + flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV); +#endif // CP_ASYNC_AVAILABLE + // Iterate over ne11 == previous tokens: - for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) { - const int k_VKQ_0 = kb0*KQ_stride; - mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)]; - - // Load K data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) { - const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - -#pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) { - const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); - - tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ]; - } - } - } - - __syncthreads(); - - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) { - mma_A K_A; - K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]); - } - } - - __syncthreads(); - - if (use_logit_softcap) { - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) { -#pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); - } - } - } - - if (maskh) { - static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size"); - static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size"); -#pragma unroll - for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) { - const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I; -#pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const int i = i0 + mma_C_KQ::get_i(l); - const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l); - - KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]); - } - } - } - - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. - float2 KQ_max_new = KQ_max; - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { -#pragma unroll - for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) { - KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]); - KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]); - } - } - - // Values per KQ column are spread across 8 threads, does not need full warp reduce: -#pragma unroll - for (int offset = 16; offset > 2; offset >>= 1) { - KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE)); - KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE)); - } - - { - const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y); - KQ_max_scale = make_float2(expf(diff.x), expf(diff.y)); - if (diff.x <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.x = 0.0f; - } - if (diff.y <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale.y = 0.0f; - } - KQ_max = KQ_max_new; - } - - float2 KQ_rowsum_add = make_float2(0.0f, 0.0f); - static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) { -#pragma unroll - for (int l = 0; l < mma_C_KQ::ne; ++l) { - const float KQ_max_l = l % 2 == 0 ? 
KQ_max.x : KQ_max.y; - const float diff = KQ_C[k].x[l] - KQ_max_l; - KQ_C[k].x[l] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_C[k].x[l] = 0.0f; - } - - if (l % 2 == 0) { - KQ_rowsum_add.x += KQ_C[k].x[l]; - } else { - KQ_rowsum_add.y += KQ_C[k].x[l]; - } - } - } - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x; - KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y; - - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y); -#pragma unroll - for (int i = 0; i < D/mma_C_VKQ::I; ++i) { -#pragma unroll - for (int l = 0; l < mma_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; - } - } - - // Convert KQ C tiles into B tiles for VKQ calculation: - mma_B B[KQ_stride/(np*2*mma_B::K)]; - static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size"); -#pragma unroll - for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) { - B[k] = KQ_C[k].to_mma_B(); - } - - // Load V data into tile with decreasing granularity for D for better memory bandwidth: - static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i); - const int i0_stop = D/2 - (D/2) % (1*stride_i); - const int stride_k = WARP_SIZE / stride_i; - -#pragma unroll - for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) { - const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i); - -#pragma unroll - for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) { - const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i); - - tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V]; - } - } - } - - __syncthreads(); - - // Calculate VKQ tile: -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) { - static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size"); -#pragma unroll - for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) { - const int k0 = k00 + (threadIdx.y % np)*mma_A::K; - - mma_A A; - A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]); - } - } + for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + } + { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + constexpr bool last_iter = true; + flash_attn_ext_f16_iter + (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + } + // With cp_async there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. +#ifdef CP_ASYNC_AVAILABLE + if (nwarps*tile_B::I > KQ_stride) { __syncthreads(); } +#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8 threads each, does not need full reduce. @@ -310,26 +382,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write VKQ accumulators to shared memory in column-major format. 
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // Also for np > 1 the combination is done via these values in shared memory. - const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data + const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += mma_B::K) { - const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format. + for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int l = 0; l < mma_B::ne; ++l) { - const int k = k0 + mma_B::get_k(l); + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - tile_KV[j_cwd*D2_padded + k] = B.x[l]; + tile_K[j_cwd*D2_padded + k] = B.x[l]; } } - const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset - const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta + const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset + const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum - if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) { + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr; } __syncthreads(); @@ -337,11 +409,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( static_assert(np == 1 || np == 2 || np == 4, "bad np"); if (np == 1) { // No combination is needed, the meta data can be directly written from registers to VRAM. - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[j_cwm] = KQ_cmr; } @@ -350,42 +422,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2; + float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2; float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_cm = meta_j[0]; } float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps. 
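The shuffle loops below only combine the np partial results that belong to the same output column. A standalone sketch of this partial butterfly reduction, under the lane layout described above (lane ip*cols + j holds the partial result of parallel warp ip for column j; np and cols are illustrative parameters with np*cols <= 32):

```cuda
__device__ float max_over_parallel_warps(float v, const int np, const int cols) {
    for (int offset = np*cols/2; offset >= cols; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset, 32));
    }
    return v; // lanes sharing the same j == lane % cols now hold the same maximum
}
```

With np == 1 the loop body never executes, matching the kernel's early return in that case.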
#pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp. float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps. - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { KQ_crs = KQ_cms*meta_j[1]; } #pragma unroll - for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) { + for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) { KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } // Write back combined meta data: - if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) { - meta_j[0] = KQ_cmn; // Combined max. KQ values. - meta_j[1] = KQ_crs; // Combined KQ rowsums. - meta_j[2] = KQ_cms; // KQ max scales per parallel warp. + if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) { + *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum. } - if (needs_fixup && threadIdx.x < mma_B::J) { + if (needs_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && threadIdx.x < mma_B::J) { + if (is_fixup && threadIdx.x < tile_B::I) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; - dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } } @@ -404,6 +474,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k0_stop = D/2 - (D/2) % (1*stride_k); const int stride_j = WARP_SIZE / stride_k; + if (k0_start == k0_stop) { + continue; + } + if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) { break; } @@ -411,12 +485,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) { const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J; + const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I; if (!is_fixup && jt*ncols + j_dst >= ne01) { continue; } - const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2; + const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -424,8 +498,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2]; - const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]); + const float KQ_crs = np == 1 ? 
1.0f : meta_j[ip*tile_B::I*D2_padded + 0]; + const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]); dstk_val.x += dstk_val_add.x*KQ_crs; dstk_val.y += dstk_val_add.y*KQ_crs; } @@ -450,7 +524,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); } #else - NO_DEVICE_CODE; + NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } @@ -494,6 +568,11 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { +#ifndef NEW_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // NEW_MMA_AVAILABLE + // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -504,6 +583,10 @@ static __global__ void flash_attn_ext_f16( const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const int stride_Q = nb01 / sizeof(float2); + const int stride_KV = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + const int iter_k = ne11 / KQ_stride; const int iter_j = (ne01 + (ncols - 1)) / ncols; @@ -535,14 +618,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } kbc += iter_k; @@ -571,24 +652,27 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap, - ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3, - jt, kb0_start, kb0_stop); + ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop); } template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; + typedef tile<16, 8, half2> tile_A; + typedef tile< 8, 8, half2> tile_B; - static_assert(D % mma_B::K == 0, "bad D"); - static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block"); + static_assert(D % tile_B::J == 0, "bad D"); + static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block"); const ggml_tensor * KQV = dst; + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - constexpr int KQ_stride = D <= 128 ? 64 : 32; - constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? - cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8); - constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half); + constexpr int KQ_stride = D <= 128 ? 64 : 32; + constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ? 
+ cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8); + + const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride; + const int nrows_combine = nwarps*tile_B::J; + const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index bbc0a35a..0a5656e4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -4,11 +4,12 @@ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction // // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. -// A is a row-major matrix with shape I x K. -// B is a column-major matrix with shape K x J. -// C is a column-major matrix with shape I x J. -// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. -// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. 
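For reference, a minimal standalone sketch of how the new ggml_cuda_mma API introduced below is put together. The tile shapes follow the typedefs used in fattn-mma-f16.cuh; the function itself is hypothetical and the output buffer is assumed to be a 16x8 column-major float matrix:

```cuda
using namespace ggml_cuda_mma;

__device__ void tile_gemm_16x8(float * __restrict__ out,
                               const half2 * __restrict__ A_sh, const int stride_A,
                               const half2 * __restrict__ B_sh, const int stride_B) {
    tile<16, 8, half2> A; // logical 16x16 half operand, row-major
    tile< 8, 8, half2> B; // logical  8x16 half operand, column-major
    tile<16, 8, float> C; // 16x8 accumulator, zero-initialized by the struct

    load_ldmatrix(A, A_sh, stride_A); // pointers must be to shared memory
    load_ldmatrix(B, B_sh, stride_B);
    mma(C, A, B);

    // get_i/get_j recover which matrix element the lth register of this thread holds:
#pragma unroll
    for (int l = 0; l < C.ne; ++l) {
        out[C.get_j(l)*16 + C.get_i(l)] = C.x[l];
    }
}
```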
@@ -23,7 +24,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #ifdef NEW_MMA_AVAILABLE asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" - : "+r"(ret) : "r"(x)); + : "=r"(ret) : "r"(x)); #else NO_DEVICE_CODE; #endif // defined(NEW_MMA_AVAILABLE) @@ -52,407 +53,267 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { #endif // CUDART_VERSION >= 11080 +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} -template -struct mma_A_I16K4 { - static_assert(sizeof(T) == 4, "bad type size"); +namespace ggml_cuda_mma { - static constexpr int I = 16; - static constexpr int K = 4; - static constexpr int ne = 2; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + T x[ne] = {0}; - T x[ne]; + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && (J == 4 || J == 8)) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return 4 * l + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x % 4) + l % 2; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return l * 8 + threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return (l % 2) * 8 + threadIdx.x / 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return l * 4 + threadIdx.x % 4; + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return (l / 2) * 4 + threadIdx.x % 4; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + }; + + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; #pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = 
ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + + return ret; + } + + template + static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } } - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // NEW_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else load_generic(xs0, stride); #endif // NEW_MMA_AVAILABLE } -}; -template -struct mma_A_I16K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int I = 16; - static constexpr int K = 8; - static constexpr int ne = 4; - - T x[ne]; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_k(const int l) { - const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_i(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { + template + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); #else + load_generic(t, xs0, stride); +#endif // NEW_MMA_AVAILABLE + } + + template + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef NEW_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), 
"=r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED(t); GGML_UNUSED(xs0); GGML_UNUSED(stride); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { -#ifdef NEW_MMA_AVAILABLE - int * xi = (int * ) x; - const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); - asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" - : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) - : "l"(xs)); -#else - GGML_UNUSED(xs0); - GGML_UNUSED(stride); - NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE - } - - __device__ __forceinline__ void transpose() { - int * xi = (int *) x; - xi[0] = ggml_cuda_movmatrix(xi[0]); - - const int tmp = ggml_cuda_movmatrix(xi[1]); - xi[1] = ggml_cuda_movmatrix(xi[2]); - xi[2] = tmp; - - xi[3] = ggml_cuda_movmatrix(xi[3]); - } -}; - -template -struct mma_B_J8K4 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 4; - static constexpr int ne = 1; - - T x[ne]; - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int /* l */) { - const int ret = threadIdx.x % K; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { -#ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; - asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" - : "+r"(xi[0]) : "l"(xs)); -#else - load_generic(xs0, stride); -#endif // NEW_MMA_AVAILABLE - } -}; - -template -struct mma_B_J8K8 { - static_assert(sizeof(T) == 4, "bad type size"); - - static constexpr int J = 8; - static constexpr int K = 8; - static constexpr int ne = 2; - - T x[ne]; - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x / (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - static __device__ __forceinline__ int get_k(const int l) { - const int ret = l * (K/2) + threadIdx.x % (K/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < K); - return ret; - } - - __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_k(l)]; - } - } - - __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { -#ifdef NEW_MMA_AVAILABLE - int * xi = (int *) x; - const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; - asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" - : "+r"(xi[0]), "+r"(xi[1]) - : "l"(xs)); -#else - load_generic(xs0, stride); -#endif // NEW_MMA_AVAILABLE - } -}; - -template -struct mma_C_I16J8 {}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - - int x[ne] = {0}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; 
- } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K4 & mma_A, const mma_B_J8K4 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { #ifdef NEW_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); #else // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[0]), "+r"(x[1]) - : "r"(mma_A.x[2]), "r"(mma_B.x[1])); + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" - : "+r"(x[2]), "+r"(x[3]) - : "r"(mma_A.x[3]), "r"(mma_B.x[1])); + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -}; -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 4; - static constexpr int ne = 2; - - half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; - - static __device__ __forceinline__ int 
get_i(const int l) { - const int ret = l * (I/2) + threadIdx.x / J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int /* l */) { - const int ret = threadIdx.x % J; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" - : "+r"(xi[0]), "+r"(xi[1]) + : "+r"(Dxi[0]), "+r"(Dxi[1]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - - int * xi = (int *) x; - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(xi[0]); - Bxi[1] = ggml_cuda_movmatrix(xi[1]); - - return mma_B; - } -}; - -template <> -struct mma_C_I16J8 { - static constexpr int I = 16; - static constexpr int J = 8; - static constexpr int ne = 4; - - float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; - - static __device__ __forceinline__ int get_i(const int l) { - const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < I); - return ret; - } - - static __device__ __forceinline__ int get_j(const int l) { - const int ret = 2 * (threadIdx.x % (J/2)) + l%2; - GGML_CUDA_ASSUME(ret >= 0); - GGML_CUDA_ASSUME(ret < J); - return ret; - } - - __device__ __forceinline__ void mma(const mma_A_I16K8 & mma_A, const mma_B_J8K8 & mma_B) { + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { #ifdef NEW_MMA_AVAILABLE - int * Axi = (int *) mma_A.x; - int * Bxi = (int *) mma_B.x; - int * xi = (int *) x; + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); #else // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : 
"+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" - : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else - GGML_UNUSED(mma_A); - GGML_UNUSED(mma_B); + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } - __device__ __forceinline__ mma_B_J8K8 to_mma_B() { - mma_B_J8K8 mma_B; - mma_B.x[0] = make_half2(x[0], x[1]); - mma_B.x[1] = make_half2(x[2], x[3]); - - int * Bxi = (int *) mma_B.x; - Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); - Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); - - return mma_B; - } - - __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { -#pragma unroll - for (int l = 0; l < ne; ++l) { - x[l] = xs0[get_j(l)*stride + get_i(l)]; - } - } -}; +} diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 53915420..0451c65f 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -7,6 +7,8 @@ #include #include +using namespace ggml_cuda_mma; + #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 #define MMQ_NWARPS 8 @@ -647,15 +649,15 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - typedef mma_A_I16K8 mma_A; - typedef mma_B_J8K8 mma_B; - typedef mma_C_I16J8 mma_C; + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 53915420..0451c65f 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>
 
+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_A_I16K8<int> mma_A;
-    typedef mma_B_J8K8<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -674,12 +676,12 @@
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B B;
-            float dB[mma_C::ne/2];
+            tile_B B;
+            float dB[tile_C::ne/2];
 
-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                     dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma(A[n][k01/QI8_0], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_0], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                 }
             }
         }
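
For orientation, the q8_0 path above evaluates, per 32-value sub-block, nothing more than a scaled integer dot product: the tensor cores produce the integer part (C.x[l]) and the float scales dA and dB are applied afterwards. A scalar reference sketch (assuming the usual ggml q8_0 convention of 32 int8 quants sharing one scale; not code from this patch):

    #include <cstdint>

    // One sub-block of vec_dot_q8_0_q8_1_mma, written out serially.
    static float vec_dot_q8_0_q8_1_ref(const int8_t * qx, float dx, const int8_t * qy, float dy) {
        int sumi = 0;
        for (int k = 0; k < 32; ++k) {
            sumi += qx[k]*qy[k]; // what the mma() calls accumulate into C.x[l]
        }
        return dx*dy*sumi;       // dA[n][l/2][k01/QI8_0]*dB[l%2] in the kernel
    }
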
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_A_I16K8<int> mma_A;
-    typedef mma_B_J8K8<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
 
     const int * y_qs = (const int *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    tile_A A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -784,12 +786,12 @@
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B;
-            float2 dsB[mma_C::ne/2];
+            tile_B B;
+            float2 dsB[tile_C::ne/2];
 
-            B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma(A[n][k01/QI8_1], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_1], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                 }
             }
         }
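
The two sum += lines in the q8_1 path above add an affine correction on top of the q8_0 scheme: with dmA = (d_x, m_x) unpacked from x_dm and dsB = (d_y, s_y) from y_dm, where s_y is the precomputed sum of the y quants, each block pair contributes the term sketched below (a hypothetical helper for illustration only; the packing of the .x/.y components is inferred from how the kernel consumes them):

    #include <cuda_runtime.h> // float2
    #include <cstdio>

    // Schematic per-block contribution of vec_dot_q8_1_q8_1_mma.
    static float q8_1_block_term(float2 dmA, float2 dsB, int sumi) {
        return dmA.x*dsB.x*sumi // d_x * d_y * (qx . qy)
             + dmA.y*dsB.y;     // m_x * s_y corrects for the affine offset of x
    }

    int main() {
        const float2 dm = {0.5f, 1.25f};
        const float2 ds = {0.25f, -4.0f};
        printf("%f\n", q8_1_block_term(dm, ds, 100)); // 12.5 - 5.0 = 7.5
        return 0;
    }
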
@@ -868,26 +870,26 @@
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
 
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A A[ntx][8];
-    float dA[ntx][mma_C::ne/2][8];
+    tile_A A[ntx][8];
+    float dA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
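
In the q2_K variant below, the correction tile A1 is filled with the byte pattern 0x01010101, i.e. four int8 ones per register: an s8 MMA against rows of ones reduces each column of B, producing exactly the sums of the q8 values that the min terms mA have to multiply. A tiny demonstration of the trick (illustrative only, not code from the patch):

    #include <cstdio>

    int main() {
        const signed char ones[4] = {1, 1, 1, 1}; // the four bytes of 0x01010101
        const signed char q[4]    = {-5, 3, 7, -1};
        int dot = 0;
        for (int k = 0; k < 4; ++k) {
            dot += ones[k]*q[k]; // a dot product against all-ones ...
        }
        printf("%d\n", dot);     // ... is just the sum of the quants: 4
        return 0;
    }
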
@@ -1056,27 +1058,27 @@
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_A_I16K8<int> mma_A_K8;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
     const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
 
     const int * y_qs = (const int *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A A[ntx][8];
-    float dA[ntx][mma_C::ne/2][8];
-    float mA[ntx][mma_C::ne/2][8];
+    tile_A A[ntx][8];
+    float dA[ntx][tile_C::ne/2][8];
+    float mA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float2 dB[mma_C::ne/2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float2 dB[tile_C::ne/2];
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int j = j0 + tile_C::get_j(l);
 
             dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B[2];
+            tile_B B[2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
-            mma_C Cm[2];
+            tile_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
-                mma_A A1;
+                tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma(A1, B[0]);
-                Cm[1].mma(A1, B[1]);
+                mma(Cm[0], A1, B[0]);
+                mma(Cm[1], A1, B[1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C Cd[2];
+                tile_C Cd[2];
 
-                Cd[0].mma(A[n][k01/4 + 0], B[0]);
-                Cd[1].mma(A[n][k01/4 + 1], B[1]);
+                mma(Cd[0], A[n][k01/4 + 0], B[0]);
+                mma(Cd[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                     if (k01 >= WARP_SIZE * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
        }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-            float2 sB[mma_C::ne/2];
+            float2 sB[tile_C::ne/2];
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
@@ -1166,9 +1168,9 @@
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                 }
             }
         }
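
The q6_K path below keeps the accumulation in two stages: the int8 sub-block scales scA and the y scales dB are folded into a per-tile float tmp inside the k01 loop, and the x block scale dA is applied once per output at the end. Schematically (an illustrative helper with the eight k sub-blocks flattened into plain arrays; not code from the patch):

    #include <cstdio>

    // Two-stage q6_K accumulation: cheap per-sub-block weighting first,
    // one multiply by the block scale dA at the very end.
    static float q6_K_row_term(const int sumi[8], const int sc[8], const float dB[8], float dA) {
        float tmp = 0.0f;
        for (int k = 0; k < 8; ++k) {
            tmp += float(sumi[k]*sc[k])*dB[k]; // tmp[n][l] in the kernel
        }
        return dA*tmp;
    }

    int main() {
        const int   sumi[8] = {10, -3, 7, 0, 2, 5, -8, 1};
        const int   sc[8]   = {1, 2, -1, 3, 1, 1, 2, -2};
        const float dB[8]   = {0.5f, 0.5f, 0.5f, 0.5f, 0.25f, 0.25f, 0.25f, 0.25f};
        printf("%f\n", q6_K_row_term(sumi, sc, dB, 0.125f));
        return 0;
    }
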
@@ -1708,15 +1710,15 @@
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE
-    typedef mma_A_I16K4<int> mma_A;
-    typedef mma_B_J8K4<int>  mma_B;
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int * x_qs = (const int *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@
     const int * y_qs = (const int *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A A[ntx][8];
-    int scA[ntx][mma_C::ne/2][8];
-    float dA[ntx][mma_C::ne/2];
+    tile_A A[ntx][8];
+    int scA[ntx][tile_C::ne/2][8];
+    float dA[ntx][tile_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }
 
 #pragma unroll
@@ -1745,8 +1747,8 @@
            const int k0 = k00 + k01;
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
                 const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc = (const int8_t *) &sc_packed;
 
#pragma unroll
@@ -1759,41 +1761,41 @@
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
             // Here load_generic is faster than load_ldmatrix.
-            B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
-            B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma(A[n][k01/4 + 0], B[0]);
-                C[1].mma(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
         }
 
@@ -1802,8 +1804,8 @@
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
@@ -2312,36 +2314,36 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
 
-    typedef mma_C_I16J8<int> mma_C;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
 
 #ifdef NEW_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
 
                 if (j > j_max) {
                     continue;
                 }
 
-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);
 
                 if (need_check && i > i_max) {
                     continue;
                 }
 
-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
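
Taken together, the mmq.cuh half of the patch is mechanical: every per-shape fragment struct becomes an instance of the single ggml_cuda_mma::tile<I, J, T> template and every member call becomes a free function, so A, B, and C fragments share one type and one API. A hypothetical device-side fragment showing the resulting call-site shape (the pointers and stride stand in for the kernel's shared-memory tiles; a sketch, not a buildable kernel on its own):

    #include "mma.cuh" // the header rewritten above

    using namespace ggml_cuda_mma;

    static __device__ void tile_api_example(tile<16, 8, int> & C, const int * x, const int * y, int stride) {
        tile<16, 8, int> A;          // was: mma_A_I16K8<int> A;
        tile< 8, 8, int> B;          // was: mma_B_J8K8<int>  B;
        load_ldmatrix(A, x, stride); // was: A.load_ldmatrix(x, stride);
        load_generic (B, y, stride); // was: B.load_generic(y, stride);
        mma(C, A, B);                // was: C.mma(A, B);
    }
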