cuda : fix im2col kernel

This commit is contained in:
Georgi Gerganov
2023-11-10 19:39:24 +02:00
parent 000b952c2d
commit 9c1ddc77a7
3 changed files with 30 additions and 25 deletions

View File

@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
}
// im2col: each thread converts one float sample of `x` to half precision and
// stores it at one position of `dst`, writing zero for samples that fall
// outside the input plane.
//
// Indexing as used by the body:
//   iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0   -> input column, valid in [0, IW)
//   iih = blockIdx.y * s1 + threadIdx.y * d1 - p1   -> input row,    valid in [0, IH)
//   source plane base = threadIdx.x * ofs0 + blockIdx.x * ofs1
//   destination row   = threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z
//   destination col   = blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z
// with CHW as the length of one destination row.
//
// NOTE(review): presumably launched with grid = (IC, OH, OW) and
// block = (N, KH, KW), with s*/p*/d* being stride/padding/dilation and
// CHW = IC * KH * KW -- confirm against the launch site.
//
// Every thread writes exactly one element of `dst`, so the output is fully
// defined even if the buffer was not cleared beforehand. The kernel uses no
// shared memory and threads exchange no data, so no __syncthreads() is needed.
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
    // Input coordinates sampled by this thread; may be out of range because
    // of the p0/p1 offsets.
    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;

    // Computed outside the bounds check: both branches store to the same slot.
    const int offset_dst =
        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);

    if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
    } else {
        // Out-of-bounds sample: store an explicit zero instead of leaving
        // whatever the destination buffer previously contained.
        dst[offset_dst] = __float2half(0.0f);
    }
}

View File

@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
// convolution + gelu
{
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
//cur = ggml_add(ctx0, cur, model.e_conv_1_b);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_1_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_1_b,
// cur),
// cur);
cur = ggml_gelu(ctx0, cur);
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
//cur = ggml_add(ctx0, cur, model.e_conv_2_b);
cur = ggml_add(ctx0,
ggml_repeat(ctx0,
model.e_conv_2_b,
cur),
cur);
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
//cur = ggml_add(ctx0,
// ggml_repeat(ctx0,
// model.e_conv_2_b,
// cur),
// cur);
cur = ggml_gelu(ctx0, cur);
}