diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1832314e..4d4ac47c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2) +#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__) c = __builtin_amdgcn_sdot4(a, b, c, false); #elif defined(RDNA3) c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); -#elif defined(__gfx1010__) || defined(__gfx900__) +#elif defined(RDNA1) || defined(__gfx900__) int tmp1; int tmp2; asm("\n \ diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 4fb466ca..a7d518a5 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { 1; } +enum mmvq_parameter_table_id { + MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_GCN, + MMVQ_PARAMETERS_RDNA2 +}; + +static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { +#if defined(RDNA2) || defined(RDNA3) + return MMVQ_PARAMETERS_RDNA2; +#elif defined(GCN) || defined(CDNA) + return MMVQ_PARAMETERS_GCN; +#else + return MMVQ_PARAMETERS_GENERIC; +#endif +} + +static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { + return MMVQ_PARAMETERS_RDNA2; + } + if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { + return MMVQ_PARAMETERS_GCN; + } + return MMVQ_PARAMETERS_GENERIC; +} + +static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC) { + switch (ncols_y) { + case 1: + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } else if (table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_y) { + case 1: + case 2: + case 3: + case 4: + return 2; + case 5: + case 6: + case 7: + case 8: + default: + return 1; + } + } + return 1; +} + +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_y) { + case 1: + return 1; + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } + return 1; +} + template -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -59,24 +137,20 @@ static __global__ void mul_mat_vec_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); + constexpr mmvq_parameter_table_id table_id = get_device_table_id(); + constexpr int nwarps = calc_nwarps(ncols_y, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); -#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) - constexpr int nwarps = 1; - constexpr int rows_per_cuda_block = 1; -#else - constexpr int nwarps = ncols_y <= 4 ? 4 : 2; - constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; -#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) - - const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int tid = warp_size*threadIdx.y + threadIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x; const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; + constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; -// partial sum for each thread + // partial sum for each thread float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; const block_q8_1 * y = (const block_q8_1 *) vy; @@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q( } } - __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE]; + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll for (int j = 0; j < ncols_y; ++j) { @@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q( for (int l = 0; l < nwarps-1; ++l) { tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; } - tmp[j][i] = warp_reduce_sum(tmp[j][i]); + tmp[j][i] = warp_reduce_sum(tmp[j][i]); } if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { @@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q( } } +static std::pair calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id); + const dim3 block_nums(nblocks, 1, 1); + const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1); + return {block_nums, block_dims}; +} + template static void mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, @@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda( GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); - int id = ggml_cuda_get_device(); - - int64_t nwarps = 1; - int64_t rows_per_cuda_block = 1; - - if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - switch(ncols_y) { - case 1: - nwarps = 4; - rows_per_cuda_block = 1; - break; - case 2: - case 3: - case 4: - nwarps = 4; - rows_per_cuda_block = 2; - break; - case 5: - case 6: - case 7: - case 8: - nwarps = 2; - rows_per_cuda_block = 2; - break; - default: - GGML_ABORT("fatal error"); - break; - } - } - - const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, 1, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 1; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 2; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 3; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 4; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 5; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 6; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 7; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + { + constexpr int c_ncols_y = 8; + std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; + } default: GGML_ABORT("fatal error"); break;