diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6a4d57cb..1242c041 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { return a; } -//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -//#pragma unroll -// for (int mask = 16; mask > 0; mask >>= 1) { -// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); -// } -// return a; -//#else -// (void) a; -// NO_DEVICE_CODE; -//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL -//} +#ifdef GGML_CUDA_F16 +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#else + (void) a; + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} +#endif // GGML_CUDA_F16 static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll @@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { dst[row] = tmp; @@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[row] = tmp; @@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, #endif // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { dst[row] = tmp; @@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest float amax = fabsf(xi); float sum = xi; -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); - } + amax = warp_reduce_max(amax); + sum = warp_reduce_sum(sum); const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); @@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (tid == 0) { #ifdef GGML_CUDA_F16 @@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32( const int idst = channel*nrows_dst + row_dst; // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[idst] = tmp; @@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } // sum up partial sums and write back result -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); - } + tmp = warp_reduce_sum(tmp); if (threadIdx.x == 0) { dst[idst] = tmp;