CUDA: more warps for mmvq on NVIDIA (llama/5394)

commit 9711bae0b3
parent eec38f63bd
Author: Johannes Gäßler
Date: 2024-02-08 21:56:40 +01:00
Committed by: Georgi Gerganov

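One thread block now computes one matrix row using several warps instead of a single warp: 4 warps on NVIDIA (MMVQ_NWARPS_NVIDIA), and on AMD 1 warp for RDNA2 or newer and 4 for older architectures. The warps' partial sums are combined through shared memory before the existing warp-level reduction writes the result.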

@@ -5310,22 +5310,26 @@ template <bool need_check> static __global__ void
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
-template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#define MMVQ_NWARPS_NVIDIA 4
+#define MMVQ_NWARPS_AMD_RDNA2 1
+#define MMVQ_NWARPS_AMD_OLD 4
+
+template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
 
     const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
 
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-
-    if (row >= nrows_x) {
-        return;
-    }
+    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int row = blockIdx.x;
 
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
 
     // partial sum for each thread
     float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
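Two things change here. First, the work distribution: previously each row was handled by a single warp (row = blockIdx.x*blockDim.y + threadIdx.y, with GGML_CUDA_MMV_Y rows per block), so the kernel needed a bounds check on row; now the grid has exactly one block of nwarps warps per row, the check disappears, the flat index tid replaces threadIdx.x, and the per-iteration stride over a row's quant blocks grows from vdr*WARP_SIZE/qi to vdr*nwarps*WARP_SIZE/qi. Second, __launch_bounds__(nwarps*WARP_SIZE, 1) (compiled out under HIP) declares the exact block size and a minimum of one resident block per multiprocessor, which, as the inline comment says, lets the compiler spend registers freely.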
@@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q(
     const block_q_t * x = (const block_q_t *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
+    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
         const int ibx = row*blocks_per_row_x + i; // x block index
 
         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
-        const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
+        const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
 
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -5346,9 +5350,25 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
+    if (threadIdx.y > 0) {
+#pragma unroll
+        for (int j = 0; j < ncols_y; ++j) {
+            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
+        }
+    }
+    __syncthreads();
+    if (threadIdx.y > 0) {
+        return;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_y; ++j) {
+#pragma unroll
+        for (int i = 0; i < nwarps-1; ++i) {
+            tmp[j] += tmp_shared[i][j][threadIdx.x];
+        }
+
         tmp[j] = warp_reduce_sum(tmp[j]);
 
         if (threadIdx.x == 0) {
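For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of the same two-stage reduction with the quantization machinery stripped out: a plain float dot product of each matrix row with a vector, one block per row, nwarps warps per block. The kernel and its warp_reduce_sum are illustrative stand-ins, not the commit's code (ggml defines an equivalent warp_reduce_sum elsewhere in ggml-cuda.cu).

#include <cuda_runtime.h>

#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
    // butterfly reduction over the 32 lanes of one warp
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, WARP_SIZE);
    }
    return x;
}

template <int nwarps>
__global__ void row_dot(const float * x, const float * y, float * dst, const int ncols) {
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; // flat thread index within the block
    const int row = blockIdx.x;                          // one thread block per matrix row

    // stage 0: every thread of every warp accumulates a strided partial sum
    float tmp = 0.0f;
    for (int i = tid; i < ncols; i += nwarps*WARP_SIZE) {
        tmp += x[row*ncols + i]*y[i];
    }

    // stage 1: warps 1..nwarps-1 park their partial sums in shared memory ...
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][WARP_SIZE];
    if (threadIdx.y > 0) {
        tmp_shared[threadIdx.y-1][threadIdx.x] = tmp;
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return; // ... and retire; only warp 0 finishes the job
    }

    // stage 2: warp 0 folds in the other warps' partials lane by lane,
    // then reduces across its own 32 lanes as before
#pragma unroll
    for (int i = 0; i < nwarps-1; ++i) {
        tmp += tmp_shared[i][threadIdx.x];
    }
    tmp = warp_reduce_sum(tmp);

    if (threadIdx.x == 0) {
        dst[row] = tmp;
    }
}

Launched as row_dot<4><<<nrows, dim3(WARP_SIZE, 4, 1)>>>(x, y, dst, ncols), this mirrors the block_nums/block_dims choice made by the launcher in the next hunk.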
@@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % qk == 0);
     GGML_ASSERT(ncols_y <= 4);
 
-    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(block_num_y, 1, 1);
-    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    switch (ncols_y) {
-        case 1:
-            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 2:
-            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 3:
-            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        case 4:
-            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
-                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-            break;
-        // case 5:
-        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 6:
-        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 7:
-        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
-        // case 8:
-        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
-        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
-        //     break;
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    int nwarps;
+    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
+        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
+    } else {
+        nwarps = MMVQ_NWARPS_NVIDIA;
+    }
+
+    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+    switch (nwarps) {
+        case 1: switch(ncols_y) {
+                case 1:
+                    mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 2:
+                    mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 3:
+                    mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 4:
+                    mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            } break;
+        case 4: switch(ncols_y) {
+                case 1:
+                    mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 2:
+                    mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 3:
+                    mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                case 4:
+                    mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
+                        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            } break;
         default:
             GGML_ASSERT(false);
-            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
-            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
            break;
     }
 }
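One detail worth spelling out: nwarps and ncols_y are template parameters of the kernel, so the run-time values picked from g_device_caps must be funneled through switches into compile-time constants, which is all the nested switch above does. Below is a hedged sketch of the pattern with a toy kernel standing in for mul_mat_vec_q; the names here are illustrative, not from the codebase.

#include <cuda_runtime.h>

#define WARP_SIZE 32

// toy stand-in: both parameters must be known at compile time
template <int nwarps, int ncols_y>
__global__ void toy_kernel(int * out) {
    if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0) {
        out[0] = nwarps;
        out[1] = ncols_y;
    }
}

// collapses one level of the nested switch: nwarps is fixed by the template,
// ncols_y is still dispatched at run time
template <int nwarps>
static void launch_for_nwarps(const int ncols_y, const int nrows_x, int * out, cudaStream_t stream) {
    const dim3 block_nums(nrows_x, 1, 1);        // one block per row
    const dim3 block_dims(WARP_SIZE, nwarps, 1); // nwarps warps per block
    switch (ncols_y) {
        case 1: toy_kernel<nwarps, 1><<<block_nums, block_dims, 0, stream>>>(out); break;
        case 2: toy_kernel<nwarps, 2><<<block_nums, block_dims, 0, stream>>>(out); break;
        case 3: toy_kernel<nwarps, 3><<<block_nums, block_dims, 0, stream>>>(out); break;
        case 4: toy_kernel<nwarps, 4><<<block_nums, block_dims, 0, stream>>>(out); break;
        default: break; // the real launcher hits GGML_ASSERT(false) here
    }
}

// outer level: pick nwarps at run time, then branch to a compile-time constant
static void launch(const int nwarps, const int ncols_y, const int nrows_x, int * out, cudaStream_t stream) {
    switch (nwarps) {
        case 1: launch_for_nwarps<1>(ncols_y, nrows_x, out, stream); break;
        case 4: launch_for_nwarps<4>(ncols_y, nrows_x, out, stream); break;
        default: break; // only 1 and 4 are instantiated, matching the defines above
    }
}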