llava : add MobileVLM support (llama/5132)

* New Feature:
    1. Sum_Rows:
        fix cuda kernel overflow
        fix block shape error when nrows too big
    2. Im2Col:
        Support Batch in cuda
        Support f32 to f32 both in cpu && cuda
    3. DepthWiseConv:
        Support by Im2Col && MulMat
    4. Pool_2d:
        Supoort avg pooling in cuda
    5. HardSigmoid:
        Imp in cuda
    6. HardSwish:
        Imp in cuda

* fix tabs instead of spaces

* code clean

* CUDA POOL2D

* ADD POOL2D test case in test-backend-ops.cpp

* code clean

* fix pool2d_kernel

nits

* fix bug in pool2d kernel

* fix avg pooling, count_include_pad

nits

* test-backend-ops : add more pool_2d tests

* cuda : fix warnings and formatting

* ggml : check types in release builds too in pool_2d

* test-backend-ops : remove f16 pool_2d tests

* cuda : more style fixes

* Add assert in ggml_cuda_op_pool2d

* pool2d float padding fallback

* test-backend-ops : add dst_type to im2col

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
JidongZhang-THU 2024-01-31 21:10:15 +08:00 committed by Georgi Gerganov
parent fc7b0e2c28
commit 12c462d656
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 296 additions and 34 deletions

View File

@ -524,6 +524,8 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
@ -540,6 +542,7 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong
#define CUDA_PAD_BLOCK_SIZE 256
#define CUDA_ACC_BLOCK_SIZE 256
#define CUDA_IM2COL_BLOCK_SIZE 256
#define CUDA_POOL2D_BLOCK_SIZE 256
#define CUDA_Q8_0_NE_ALIGN 2048
@ -823,6 +826,24 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
dst[i] = fmaxf(x[i], 0);
}
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
@ -5823,7 +5844,7 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
}
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.y;
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
@ -6145,9 +6166,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static __global__ void im2col_f32_f16(
const float * x, half * dst,
int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int batch_offset,
int offset_delta, int IC, int IW, int IH, int OH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
@ -6160,21 +6182,73 @@ static __global__ void im2col_f32_f16(
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int oh = blockIdx.y;
const int batch = blockIdx.z / IC;
const int ic = blockIdx.z % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
(blockIdx.y * OW + ix) * CHW +
(blockIdx.z * (KW * KH) + ky * KW + kx);
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = blockIdx.z * offset_delta;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = max(0, start_h);
const int eh = min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = max(0, start_w);
const int ew = min(iw, start_w + kw);
const To scale = 1. / (kh * kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
Ti cur = __ldg(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += cur * scale; break;
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@ -6388,6 +6462,16 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
@ -7475,7 +7559,7 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(1, nrows, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
@ -7587,14 +7671,15 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
}
}
static void im2col_f32_f16_cuda(const float* x, half* dst,
template <typename T>
static void im2col_cuda(const float* x, T* dst,
int IW, int IH, int OW, int OH, int KW, int KH, int IC,
int offset_delta,
int batch, int batch_offset, int offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, IC);
im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
@ -8179,6 +8264,34 @@ static void ggml_cuda_op_relu(
(void) src1_dd;
}
static void ggml_cuda_op_hardsigmoid(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardsigmoid_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_hardswish(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardswish_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
(void) src1;
(void) dst;
(void) src1_dd;
}
static void ggml_cuda_op_leaky_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
@ -8810,13 +8923,46 @@ static void ggml_cuda_op_alibi(
(void) src1_dd;
}
static void ggml_cuda_op_pool2d(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, main_stream>>>(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_dd, dst_dd, op);
(void) src1;
(void) src1_dd;
}
static void ggml_cuda_op_im2col(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, cudaStream_t main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
@ -8838,8 +8984,14 @@ static void ggml_cuda_op_im2col(
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
if(dst->type == GGML_TYPE_F16) {
im2col_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_cuda(src1_dd, (float*) dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0;
(void) src0_dd;
@ -9435,6 +9587,13 @@ static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, g
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
static void ggml_cuda_hardsigmoid(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardsigmoid);
}
static void ggml_cuda_hardswish(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_hardswish);
}
static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
}
@ -10220,6 +10379,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
}
static void ggml_cuda_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pool2d);
}
static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
}
@ -10321,6 +10484,12 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
case GGML_UNARY_OP_HARDSIGMOID:
func = ggml_cuda_hardsigmoid;
break;
case GGML_UNARY_OP_HARDSWISH:
func = ggml_cuda_hardswish;
break;
default:
return false;
}
@ -10395,6 +10564,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
case GGML_OP_IM2COL:
func = ggml_cuda_im2col;
break;
case GGML_OP_POOL_2D:
func = ggml_cuda_pool2d;
break;
case GGML_OP_SUM_ROWS:
func = ggml_cuda_sum_rows;
break;
@ -11123,6 +11295,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_HARDSIGMOID:
case GGML_UNARY_OP_HARDSWISH:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_TANH:
return true;
@ -11221,6 +11395,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_ROPE:
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:

118
ggml.c
View File

@ -5349,7 +5349,7 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
int s0,
int p0,
int d0) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@ -5427,16 +5427,15 @@ struct ggml_tensor * ggml_conv_depthwise_2d(
int p1,
int d0,
int d1) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
s0, s1, p0, p1, d0, d1, true); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1), // [OC1, KH, KW] => [1, OC, 1, KH * KW]
ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3])); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC1, KH, KW] => [1, OC, 1, KH * KW]
struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
@ -5457,7 +5456,8 @@ struct ggml_tensor * ggml_im2col(
int p1,
int d0,
int d1,
bool is_2D) {
bool is_2D,
enum ggml_type dst_type) {
if(is_2D) {
GGML_ASSERT(a->ne[2] == b->ne[2]);
@ -5481,7 +5481,7 @@ struct ggml_tensor * ggml_im2col(
is_2D ? b->ne[3] : 1,
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);
struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
ggml_set_op_params(result, params, sizeof(params));
@ -5506,7 +5506,7 @@ struct ggml_tensor * ggml_conv_2d(
int p1,
int d0,
int d1) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N, OH, OW, IC * KH * KW]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
@ -5632,12 +5632,13 @@ struct ggml_tensor * ggml_pool_2d(
is_node = true;
}
struct ggml_tensor * result;
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
ggml_set_op_params(result, params, sizeof(params));
@ -5645,7 +5646,6 @@ struct ggml_tensor * ggml_pool_2d(
result->op = GGML_OP_POOL_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
return result;
}
@ -12493,6 +12493,92 @@ static void ggml_compute_forward_conv_transpose_1d(
}
}
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
static void ggml_compute_forward_im2col_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = is_2D ? ne13 : ne12;
const int64_t IC = is_2D ? ne12 : ne11;
const int64_t IH = is_2D ? ne11 : 1;
const int64_t IW = ne10;
const int64_t KH = is_2D ? ne01 : 1;
const int64_t KW = ne00;
const int64_t OH = is_2D ? ne2 : 1;
const int64_t OW = ne1;
int ofs0 = is_2D ? nb13 : nb12;
int ofs1 = is_2D ? nb12 : nb11;
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_TASK_INIT) {
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
{
float * const wdata = (float *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
for (int64_t ikh = 0; ikh < KH; ikh++) { // 1
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
} else {
dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
}
}
}
}
}
}
}
}
}
// src0: kernel [OC, IC, KH, KW]
// src1: image [N, IC, IH, IW]
// dst: result [N, OH, OW, IC*KH*KW]
@ -12583,14 +12669,14 @@ static void ggml_compute_forward_im2col(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
switch (src0->type) {
switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_f16(params, src0, src1, dst);
} break;
case GGML_TYPE_F32:
{
GGML_ASSERT(false);
ggml_compute_forward_im2col_f32(params, src0, src1, dst);
} break;
default:
{
@ -12781,8 +12867,8 @@ static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
const struct ggml_tensor * src,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
GGML_ASSERT(src->type == GGML_TYPE_F32);
GGML_ASSERT(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;

3
ggml.h
View File

@ -1500,7 +1500,8 @@ extern "C" {
int p1,
int d0,
int d1,
bool is_2D);
bool is_2D,
enum ggml_type dst_type);
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,