mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-27 00:09:30 +01:00
cuda : replace remaining shfl_xor with calls to warp_reduce functions (llama/5744)
This commit is contained in:
parent
1c71816eab
commit
d83f371b5f
73
ggml-cuda.cu
73
ggml-cuda.cu
@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
|||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
|
|
||||||
//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
#ifdef GGML_CUDA_F16
|
||||||
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
//#pragma unroll
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
// for (int mask = 16; mask > 0; mask >>= 1) {
|
#pragma unroll
|
||||||
// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
// }
|
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
|
||||||
// return a;
|
}
|
||||||
//#else
|
return a;
|
||||||
// (void) a;
|
#else
|
||||||
// NO_DEVICE_CODE;
|
(void) a;
|
||||||
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
NO_DEVICE_CODE;
|
||||||
//}
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_F16
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|||||||
float amax = fabsf(xi);
|
float amax = fabsf(xi);
|
||||||
float sum = xi;
|
float sum = xi;
|
||||||
|
|
||||||
#pragma unroll
|
amax = warp_reduce_max(amax);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
sum = warp_reduce_sum(sum);
|
||||||
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
|
|
||||||
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float d = amax / 127;
|
const float d = amax / 127;
|
||||||
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||||
@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
#ifdef GGML_CUDA_F16
|
#ifdef GGML_CUDA_F16
|
||||||
@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
|
|||||||
const int idst = channel*nrows_dst + row_dst;
|
const int idst = channel*nrows_dst + row_dst;
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[idst] = tmp;
|
dst[idst] = tmp;
|
||||||
@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
|||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
tmp = warp_reduce_sum(tmp);
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[idst] = tmp;
|
dst[idst] = tmp;
|
||||||
|
Loading…
Reference in New Issue
Block a user