CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (llama/14196)

This commit is contained in:
uvos 2025-06-15 17:30:13 +02:00 committed by Georgi Gerganov
parent aeaed9806f
commit a433680a2f

View File

@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
float * __restrict__ dst, const int64_t L) { float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0); GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0); GGML_UNUSED(src2_nb0);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int bidx = blockIdx.x; // split along B const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x; const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
if (N == 16) { if (N == 16) {
#pragma unroll #pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) { for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid]; float value = A_block[(wid * warp_size + i) * stride_A + wtid];
// todo: bank conflict // todo: bank conflict
// I am always confused with how to use the swizzling method to solve // I am always confused with how to use the swizzling method to solve
// bank conflict. Hoping somebody can tell me. // bank conflict. Hoping somebody can tell me.
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
} }
#pragma unroll #pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) { for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
} }
} }