mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-16 01:37:55 +02:00
cuda : fix im2col kernel
This commit is contained in:
17
ggml-cuda.cu
17
ggml-cuda.cu
@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
||||
}
|
||||
|
||||
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
|
||||
int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
__syncthreads();
|
||||
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
|
||||
const int offset_dst =
|
||||
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
|
||||
|
||||
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||
int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
|
||||
int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
} else {
|
||||
dst[offset_dst] = __float2half(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
|
24
whisper.cpp
24
whisper.cpp
@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
// convolution + gelu
|
||||
{
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||
//cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
model.e_conv_1_b,
|
||||
cur),
|
||||
cur);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
//cur = ggml_add(ctx0,
|
||||
// ggml_repeat(ctx0,
|
||||
// model.e_conv_1_b,
|
||||
// cur),
|
||||
// cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||
//cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
model.e_conv_2_b,
|
||||
cur),
|
||||
cur);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
//cur = ggml_add(ctx0,
|
||||
// ggml_repeat(ctx0,
|
||||
// model.e_conv_2_b,
|
||||
// cur),
|
||||
// cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
}
|
||||
|
Reference in New Issue
Block a user