cuda : fix bounds check for src0 rows in MMVQ kernel

This commit is contained in:
Georgi Gerganov 2024-06-11 11:30:12 +03:00
parent 20c542c713
commit 9df6298a91
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -75,7 +75,7 @@ static __global__ void mul_mat_vec_q(
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block) {
if (threadIdx.x < rows_per_cuda_block && row0 + threadIdx.x < nrows_dst) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}