mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-18 23:57:09 +02:00
CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)
This commit is contained in:
parent
aeaed9806f
commit
a433680a2f
@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||
float * __restrict__ dst, const int64_t L) {
|
||||
GGML_UNUSED(src1_nb0);
|
||||
GGML_UNUSED(src2_nb0);
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int bidx = blockIdx.x; // split along B
|
||||
const int bidy = blockIdx.y; // split along D
|
||||
const int tid = threadIdx.x;
|
||||
@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
||||
if (N == 16) {
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
|
||||
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
||||
// todo: bank conflict
|
||||
// I am always confused with how to use the swizzling method to solve
|
||||
// bank conflit. Hoping somebody can tell me.
|
||||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
|
||||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
||||
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user