mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-19 08:07:09 +02:00
CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)
This commit is contained in:
parent
aeaed9806f
commit
a433680a2f
@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
|
|||||||
float * __restrict__ dst, const int64_t L) {
|
float * __restrict__ dst, const int64_t L) {
|
||||||
GGML_UNUSED(src1_nb0);
|
GGML_UNUSED(src1_nb0);
|
||||||
GGML_UNUSED(src2_nb0);
|
GGML_UNUSED(src2_nb0);
|
||||||
|
|
||||||
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
const int bidx = blockIdx.x; // split along B
|
const int bidx = blockIdx.x; // split along B
|
||||||
const int bidy = blockIdx.y; // split along D
|
const int bidy = blockIdx.y; // split along D
|
||||||
const int tid = threadIdx.x;
|
const int tid = threadIdx.x;
|
||||||
@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
|
|||||||
if (N == 16) {
|
if (N == 16) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||||
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
|
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
|
||||||
// todo: bank conflict
|
// todo: bank conflict
|
||||||
// I am always confused with how to use the swizzling method to solve
|
// I am always confused with how to use the swizzling method to solve
|
||||||
// bank conflit. Hoping somebody can tell me.
|
// bank conflit. Hoping somebody can tell me.
|
||||||
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||||
}
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (size_t i = 0; i < splitD / 4; i += 2) {
|
for (size_t i = 0; i < splitD / 4; i += 2) {
|
||||||
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
|
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
|
||||||
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user