From b275e52b46e62c05b336f755ad5101ee52793a7f Mon Sep 17 00:00:00 2001 From: deepsek <166548550+deepsek@users.noreply.github.com> Date: Sat, 26 Jul 2025 18:28:14 -0400 Subject: [PATCH] HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 (llama/14624) This commit adds support for MFMA instructions to MMQ. CDNA1/GFX908 CDNA2/GFX90a and CDNA3/GFX942 are supported by the MFMA-enabled code path added by this commit. The code path and stream-k is only enabled on CDNA3 for now as it fails to outperform blas in all cases on the other devices. Blas is currently only consistently outperformed on CDNA3 due to issues in the amd-provided blas libraries. This commit also improves the awareness of MMQ towards different warp sizes and as a side effect improves the performance of all quant formats besides q4_0 and q4_1, which regress slightly, on GCN gpus. --- ggml/src/ggml-cuda/common.cuh | 16 +- ggml/src/ggml-cuda/mma.cuh | 114 +- ggml/src/ggml-cuda/mmq.cu | 10 +- ggml/src/ggml-cuda/mmq.cuh | 1857 +++++++++++++++++++----------- ggml/src/ggml-cuda/vendors/hip.h | 14 +- 5 files changed, 1303 insertions(+), 708 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9435daf0..cdc3bb5a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -56,7 +56,7 @@ #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 @@ -72,8 +72,9 @@ #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) -#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 @@ -226,6 +227,10 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) { return GGML_CUDA_CC_IS_CDNA(cc); } +// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later. +static bool amd_mfma_available(const int cc) { + return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc); +} + // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. static bool new_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 2af63355..d6817d80 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -12,7 +12,8 @@ // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // -// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. +// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. +// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior. #include "common.cuh" @@ -66,7 +67,44 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; - static constexpr int ne = I * J / WARP_SIZE; + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static constexpr int ne = I * J / 64; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 4) { + return threadIdx.x % 32; + } else if constexpr (I == 16 && J == 16) { + return 4 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 32) { + return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return (2 * ((threadIdx.x / 16) % 2) + l); + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 4) { + return 2 * (threadIdx.x / 32) + l; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 32) { + return threadIdx.x % 32; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#else + static constexpr int ne = I * J / 32; T x[ne] = {0}; static __device__ __forceinline__ int get_i(const int l) { @@ -94,6 +132,7 @@ namespace ggml_cuda_mma { static_assert(I == -1 && J == -1, "template specialization not implemented"); } } +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) }; template @@ -148,10 +187,23 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#if defined(AMD_MFMA_AVAILABLE) + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } else { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + } +#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } +#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -186,7 +238,7 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_ldmatrix( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" @@ -393,4 +445,60 @@ namespace ggml_cuda_mma { NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } + + static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + int32x16_t * acc = (int32x16_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2db5b4ab..e2fd0c1c 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q( // The stream-k decomposition is only faster for recent NVIDIA GPUs. // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))) + && src1_ncols == ne11; const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, @@ -304,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (new_mma_available(cc)) { + if (new_mma_available(cc) || amd_mfma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 9696a320..36e84be1 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -90,7 +90,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return new_mma_available(cc) ? 128 : + return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -100,12 +100,12 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) return 128; -#else // NEW_MMA_AVAILABLE +#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) - return 128; + return 64; #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA @@ -115,12 +115,11 @@ static constexpr __device__ int get_mmq_x_max_device() { return MMQ_DP4A_MAX_BATCH_SIZE; #endif // GGML_CUDA_FORCE_MMQ #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - return 64; #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -144,16 +143,25 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. +#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { @@ -179,11 +187,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -215,42 +223,80 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { } } -#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (new_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } } -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(NEW_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } #else -static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { return 8; } -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int mmq_get_nwarps_host(const int cc) { + return amd_mfma_available(cc) ? 8 : 4; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/) { + return 8; +} +#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 4; +#endif // AMD_MFMA_AVAILABLE +#else + return 8; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +} // ------------------------------------------------------------ -template static __device__ __forceinline__ void load_tiles_q4_0( +template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + 2*WARP_SIZE); + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_0; - const int kqsx = threadIdx.x % QI4_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -259,20 +305,21 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -280,17 +327,19 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q4_1( +template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_1; - const int kqsx = threadIdx.x % QI4_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -354,20 +408,21 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -375,17 +430,19 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_0( +template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_0; - const int kqsx = threadIdx.x % QI5_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -449,7 +511,7 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b2(bxi->qs, kqsx); - const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -465,21 +527,22 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -487,32 +550,37 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q5_1( +template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_1; - const int kqsx = threadIdx.x % QI5_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -521,7 +589,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b4(bxi->qs, kqsx); - const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -535,21 +603,22 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -557,32 +626,38 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q8_0( +template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + 2*WARP_SIZE); + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI8_0; - const int kqsx = threadIdx.x % QI8_0; + // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -590,21 +665,22 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else - x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { - int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -612,17 +688,19 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -631,7 +709,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -639,21 +717,76 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; typedef tile<16, 8, int> tile_C; @@ -662,23 +795,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + 2*WARP_SIZE; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_0]; - float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); @@ -689,7 +822,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; @@ -700,7 +833,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { tile_B B; float dB[tile_C::ne/2]; @@ -729,11 +862,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; @@ -742,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -750,45 +886,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; - typedef tile<16, 8, int> tile_A; - typedef tile< 8, 8, int> tile_B; - typedef tile<16, 8, int> tile_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_1]; - float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); @@ -799,7 +985,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); @@ -810,7 +996,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B; float2 dsB[tile_C::ne/2]; @@ -836,11 +1022,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +// Used for Q3_K, IQ2_S, and IQ2_XS +template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; const int * x_qs = (const int *) x; @@ -849,7 +1039,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; #pragma unroll @@ -857,23 +1047,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( - &x_qs[i*(2*WARP_SIZE + 1) + k0], + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +// Used for Q3_K, IQ2_S, and IQ2_XS: +template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -884,10 +1124,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -899,7 +1139,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); @@ -910,7 +1150,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { const int k0 = k00 + k01; dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; @@ -921,7 +1161,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -952,26 +1192,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q2_K( +template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI2_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -987,11 +1230,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1002,17 +1245,19 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else - x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; @@ -1029,7 +1274,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1037,13 +1282,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 2; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } @@ -1052,7 +1297,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. // As a workaround 2 separate loops are used instead. #pragma unroll - for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1060,23 +1305,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 1; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1087,10 +1398,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1103,7 +1414,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); @@ -1117,7 +1428,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { const int k0 = k00 + k01; const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); @@ -1140,7 +1451,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B[2]; // Here load_generic is faster than load_ldmatrix. @@ -1148,7 +1459,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); tile_C Cm[2]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tile_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; @@ -1166,16 +1477,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y); } } } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { float2 sB[tile_C::ne/2]; #pragma unroll @@ -1198,27 +1509,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q3_K( +template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI3_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { - int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1238,17 +1553,18 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; if (need_check) { i = min(i, i_max); @@ -1256,7 +1572,7 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % 4; const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); @@ -1268,23 +1584,23 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) ≻ const float d = bxi->d; #pragma unroll for (int l = 0; l < int(sizeof(int)); ++l) { - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; } #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifndef NEW_MMA_AVAILABLE +#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { - int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1294,12 +1610,14 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // NEW_MMA_AVAILABLE +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) } -template +template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1309,7 +1627,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1317,13 +1635,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } @@ -1340,72 +1658,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits } -template static __device__ __forceinline__ void load_tiles_q4_K( +template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const int qs0 = get_int_b4(bxi->qs, threadIdx.x); + const int qs0 = get_int_b4(bxi->qs, txi); -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } - if (need_check) { - i = min(i, i_max); - } + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } } } - #else - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { - int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1415,30 +1746,32 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } - + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1448,7 +1781,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1456,97 +1789,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_K( +template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - const int ky = QR5_K*threadIdx.x; + const int ky = QR5_K*txi; - const int ql = get_int_b4(bxi->qs, threadIdx.x); + const int ql = get_int_b4(bxi->qs, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; - const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); - - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); - - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; - - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); - -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); - } - } - + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { #else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { - int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1557,9 +1903,10 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1569,17 +1916,19 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1589,7 +1938,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1597,36 +1946,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q6_K( +template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1634,67 +1989,67 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_b2(bxi->ql, threadIdx.x); + const int ql = get_int_b2(bxi->ql, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); - const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; - const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; - const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; - const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#ifdef NEW_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1704,7 +2059,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1712,23 +2067,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; @@ -1738,11 +2144,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1755,7 +2161,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); @@ -1763,7 +2169,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { const int k0 = k00 + k01; #pragma unroll @@ -1793,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( float tmp[ntx][tile_C::ne] = {{0.0f}}; #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -1832,27 +2238,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_iq4_nl( +template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_NL; - const int kqsx = threadIdx.x % QI4_NL; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1862,22 +2273,24 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) { - int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -1885,31 +2298,35 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else - x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xxs( +template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1932,42 +2349,46 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xs( +template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1986,44 +2407,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_s( +template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2049,44 +2474,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_xxs( +template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2107,42 +2536,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_s( +template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2170,42 +2603,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq1_s( +template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI1_S; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2225,66 +2662,71 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef NEW_MMA_AVAILABLE - x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else - x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // NEW_MMA_AVAILABLE + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq4_xs( +template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = 0; // threadIdx.x / QI4_XS - const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } + constexpr int rows_per_warp = warp_size / 8; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); if (need_check) { i = min(i, i_max); @@ -2297,18 +2739,21 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else - x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -2318,32 +2763,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } -template +template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { - typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; +#endif constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -2371,179 +2824,181 @@ static __device__ __forceinline__ void mmq_write_back_mma( // ------------------------------------------------------------------------------------------------------------------------------------- -template +template struct mmq_type_traits; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template +template static __device__ __forceinline__ void mul_mat_q_process_tile( const char * __restrict__ x, const int offset_x, const int * __restrict__ y, const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int stride_row_x, const int ncols_y, const int stride_col_dst, const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; extern __shared__ int data_mul_mat_q[]; int * tile_y = data_mul_mat_q + mmq_x; - int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#ifdef NEW_MMA_AVAILABLE - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; - constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; - constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // NEW_MMA_AVAILABLE + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); @@ -2551,8 +3006,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2567,8 +3022,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2576,7 +3031,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( __syncthreads(); - vec_dot(tile_x, tile_y, sum, WARP_SIZE); + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); __syncthreads(); } @@ -2591,16 +3046,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 -template +template #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - __launch_bounds__(WARP_SIZE*nwarps, 1) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) #else - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( @@ -2616,6 +3071,9 @@ static __global__ void mul_mat_q( return; } + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); @@ -2627,10 +3085,10 @@ static __global__ void mul_mat_q( // For MoE the correct indices are loaded from ids_dst. extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2639,7 +3097,7 @@ static __global__ void mul_mat_q( __syncthreads(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { const int wt = blockIdx.z / nchannels_y; const int zt = blockIdx.z - wt*nchannels_y; @@ -2667,10 +3125,10 @@ static __global__ void mul_mat_q( // __syncthreads(); // There is no previous tile that could cause a race condition. #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2688,12 +3146,12 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2745,10 +3203,10 @@ static __global__ void mul_mat_q( __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2766,7 +3224,7 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); @@ -2812,10 +3270,10 @@ static __global__ void mul_mat_q( // The memory layout for the fixup buffer is always contiguous, therefore reset ids: __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2833,13 +3291,13 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } -template +template static __global__ void mul_mat_q_stream_k_fixup( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, @@ -2849,7 +3307,10 @@ static __global__ void mul_mat_q_stream_k_fixup( constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int64_t blocks_per_ne00 = ncols_x / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; const int nty = (nrows_x + mmq_y - 1) / mmq_y; @@ -2893,10 +3354,10 @@ static __global__ void mul_mat_q_stream_k_fixup( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } } @@ -2937,14 +3398,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } return; @@ -2955,7 +3416,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int col_high = expert_bounds[zt + 1]; const int col_diff = col_high - col_low; - for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { ids_dst_shared[j] = ids_dst[col_low + j]; } __syncthreads(); @@ -2975,14 +3436,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } @@ -2996,13 +3457,13 @@ struct mmq_args { }; template -static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) { +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); - return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } template @@ -3010,14 +3471,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_y = get_mmq_y_host(cc); - const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); + const dim3 block_dims(warp_size, nwarps, 1); - const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; @@ -3032,14 +3495,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3059,8 +3522,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3070,13 +3532,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3086,7 +3547,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } @@ -3094,9 +3555,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); @@ -3107,7 +3570,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { continue; } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 184d445f..56e59a05 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -160,7 +160,19 @@ #endif #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA +#define CDNA // For the entire family +#endif + +#if defined(__gfx942__) +#define CDNA3 +#endif + +#if defined(__gfx90a__) +#define CDNA2 +#endif + +#if defined(__gfx908__) +#define CDNA1 #endif #if defined(__GFX12__)