mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-14 03:18:42 +02:00
cuda : fix im2col kernel
This commit is contained in:
27
ggml-cuda.cu
27
ggml-cuda.cu
@ -4737,13 +4737,18 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
||||
}
|
||||
|
||||
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
|
||||
int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
__syncthreads();
|
||||
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||
|
||||
const int offset_dst =
|
||||
(threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
|
||||
(blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
|
||||
|
||||
if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
|
||||
int offset_dst = (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW;
|
||||
int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst + (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z)] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
|
||||
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
|
||||
} else {
|
||||
dst[offset_dst] = __float2half(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@ -5735,7 +5740,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
|
||||
int KH, int KW, int N, int ofs0, int ofs1,
|
||||
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
||||
dim3 block_nums(IC, OH, OW);
|
||||
dim3 block_dims(N, KH, KW);
|
||||
dim3 block_dims(N, KH, KW);
|
||||
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||
}
|
||||
|
||||
@ -6714,16 +6719,16 @@ inline void ggml_cuda_op_im2col(
|
||||
|
||||
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
||||
|
||||
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
||||
const int64_t N = src1->ne[is_2D ? 3 : 2];
|
||||
const int64_t IC = src1->ne[is_2D ? 2 : 1];
|
||||
const int64_t IH = is_2D ? src1->ne[1] : 1;
|
||||
const int64_t IW = src1->ne[0];
|
||||
const int64_t IW = src1->ne[0];
|
||||
|
||||
const int64_t KH = is_2D ? src0->ne[1] : 1;
|
||||
const int64_t KW = src0->ne[0];
|
||||
const int64_t KW = src0->ne[0];
|
||||
|
||||
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const int64_t OW = dst->ne[1];
|
||||
const int64_t OW = dst->ne[1];
|
||||
|
||||
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
||||
OH, IW, IH, OW, IC, KH, KW, N,
|
||||
|
4
ggml.c
4
ggml.c
@ -5227,13 +5227,13 @@ struct ggml_tensor * ggml_im2col(
|
||||
}
|
||||
|
||||
const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
|
||||
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
||||
const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
|
||||
|
||||
const int64_t ne[4] = {
|
||||
is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0],
|
||||
OW,
|
||||
is_2D ? OH : b->ne[2],
|
||||
is_2D ? b->ne[3] : 1,
|
||||
is_2D ? b->ne[3] : 1,
|
||||
};
|
||||
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
|
||||
|
24
whisper.cpp
24
whisper.cpp
@ -1604,22 +1604,22 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
||||
// convolution + gelu
|
||||
{
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
|
||||
//cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
model.e_conv_1_b,
|
||||
cur),
|
||||
cur);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_1_b);
|
||||
//cur = ggml_add(ctx0,
|
||||
// ggml_repeat(ctx0,
|
||||
// model.e_conv_1_b,
|
||||
// cur),
|
||||
// cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
|
||||
//cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_repeat(ctx0,
|
||||
model.e_conv_2_b,
|
||||
cur),
|
||||
cur);
|
||||
cur = ggml_add(ctx0, cur, model.e_conv_2_b);
|
||||
//cur = ggml_add(ctx0,
|
||||
// ggml_repeat(ctx0,
|
||||
// model.e_conv_2_b,
|
||||
// cur),
|
||||
// cur);
|
||||
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
}
|
||||
|
Reference in New Issue
Block a user