mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-10 15:28:09 +02:00
metal : add im2col support + mul mat-vec f16 x f16
This commit is contained in:
21
ggml-cuda.cu
21
ggml-cuda.cu
@ -4736,7 +4736,10 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
|
|||||||
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __global__ void im2col_f32_f16(const float* x, half* dst, int ofs0, int ofs1, int IW,int IH,int CHW,int s0,int s1,int p0,int p1,int d0,int d1) {
|
static __global__ void im2col_f32_f16(
|
||||||
|
const float * x, half * dst,
|
||||||
|
int ofs0, int ofs1, int IW, int IH, int CHW,
|
||||||
|
int s0, int s1, int p0, int p1, int d0, int d1) {
|
||||||
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
|
||||||
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
|
||||||
|
|
||||||
@ -5734,11 +5737,10 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
|
|||||||
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void im2col_f32_f16_cuda(const float* x, half* dst,
|
static void im2col_f32_f16_cuda(const float * x, half * dst,
|
||||||
int OH, int IW, int IH,
|
int OH, int IW, int IH, int OW, int IC,
|
||||||
int OW, int IC,
|
int KH, int KW, int N, int ofs0, int ofs1,
|
||||||
int KH, int KW, int N, int ofs0, int ofs1,
|
int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
|
||||||
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
|
|
||||||
dim3 block_nums(IC, OH, OW);
|
dim3 block_nums(IC, OH, OW);
|
||||||
dim3 block_dims(N, KH, KW);
|
dim3 block_dims(N, KH, KW);
|
||||||
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
|
||||||
@ -6730,11 +6732,12 @@ inline void ggml_cuda_op_im2col(
|
|||||||
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
const int64_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
const int64_t OW = dst->ne[1];
|
const int64_t OW = dst->ne[1];
|
||||||
|
|
||||||
|
const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||||
|
const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||||
|
|
||||||
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
|
||||||
OH, IW, IH, OW, IC, KH, KW, N,
|
OH, IW, IH, OW, IC, KH, KW, N,
|
||||||
src1->nb[is_2D ? 3 : 2] / 4, // nb is byte offset, src is type float32
|
ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
|
||||||
src1->nb[is_2D ? 2 : 1] / 4, // nb is byte offset, src is type float32
|
|
||||||
s0, s1, p0, p1, d0, d1, main_stream);
|
|
||||||
|
|
||||||
(void) src0;
|
(void) src0;
|
||||||
(void) src0_dd;
|
(void) src0_dd;
|
||||||
|
76
ggml-metal.m
76
ggml-metal.m
@ -86,6 +86,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(norm);
|
GGML_METAL_DECL_KERNEL(norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -114,6 +115,7 @@ struct ggml_metal_context {
|
|||||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||||
@ -287,6 +289,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(norm);
|
GGML_METAL_ADD_KERNEL(norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -317,6 +320,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||||
@ -386,6 +390,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DEL_KERNEL(norm);
|
GGML_METAL_DEL_KERNEL(norm);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
|
||||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
|
||||||
@ -416,6 +421,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||||
|
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||||
@ -1139,6 +1145,7 @@ void ggml_metal_graph_compute(
|
|||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
{
|
{
|
||||||
|
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
} break;
|
} break;
|
||||||
@ -1146,13 +1153,18 @@ void ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
nth0 = 32;
|
nth0 = 32;
|
||||||
nth1 = 1;
|
nth1 = 1;
|
||||||
if (ne11 * ne12 < 4) {
|
if (src1t == GGML_TYPE_F32) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
if (ne11 * ne12 < 4) {
|
||||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
nrows = ne11;
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
|
||||||
|
nrows = ne11;
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
||||||
|
nrows = 4;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
|
||||||
nrows = 4;
|
nrows = 4;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
@ -1464,6 +1476,58 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_IM2COL:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
||||||
|
|
||||||
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||||
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||||
|
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||||
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||||
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||||
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||||
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
||||||
|
const int32_t IC = src1->ne[is_2D ? 2 : 1];
|
||||||
|
const int32_t IH = is_2D ? src1->ne[1] : 1;
|
||||||
|
const int32_t IW = src1->ne[0];
|
||||||
|
|
||||||
|
const int32_t KH = is_2D ? src0->ne[1] : 1;
|
||||||
|
const int32_t KW = src0->ne[0];
|
||||||
|
|
||||||
|
const int32_t OH = is_2D ? dst->ne[2] : 1;
|
||||||
|
const int32_t OW = dst->ne[1];
|
||||||
|
|
||||||
|
const int32_t CHW = IC * KH * KW;
|
||||||
|
|
||||||
|
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
||||||
|
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
||||||
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
|
||||||
|
default: GGML_ASSERT(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
|
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||||
|
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||||
|
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||||
|
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||||
|
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||||
|
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||||
|
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||||
|
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||||
|
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||||
|
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||||
|
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||||
|
} break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
case GGML_OP_CONT:
|
case GGML_OP_CONT:
|
||||||
|
108
ggml-metal.metal
108
ggml-metal.metal
@ -792,7 +792,7 @@ kernel void kernel_mul_mv_f32_f32(
|
|||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t rb = tgpig.y*N_F32_F32;
|
const int64_t rb = tgpig.y*N_F32_F32;
|
||||||
@ -844,6 +844,79 @@ kernel void kernel_mul_mv_f32_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define N_F16_F16 4
|
||||||
|
|
||||||
|
// Matrix-vector product where both src0 and src1 hold f16 data; the result is written as f32.
//
// Work split, as read from the indexing below:
//   tgpig.x - row of src0
//   tgpig.y - group of up to N_F16_F16 consecutive rows of src1
//   tgpig.z - batch index (src0 is addressed with im/(ne12/ne02))
// Every lane walks the row with a stride of 32, the per-lane partial sums are combined with
// simd_sum() and lane 0 stores the final value.
kernel void kernel_mul_mv_f16_f16(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {

    const int64_t i_row0      = tgpig.x;
    const int64_t i_row1_base = tgpig.y*N_F16_F16;
    const int64_t i_batch     = tgpig.z;

    device const half  * a  = (device const half  *) (src0 + i_row0*nb01 + i_batch/(ne12/ne02)*nb02);
    device const half4 * a4 = (device const half4 *) a;

    // rows shorter than 128 elements are processed element by element,
    // longer rows are processed four elements at a time
    const bool wide = ne00 >= 128;

    for (int row = 0; row < N_F16_F16; ++row) {
        const int r1 = i_row1_base + row;
        if (r1 >= ne11) {
            // no more src1 rows left for this threadgroup
            break;
        }

        device const half  * b  = (device const half  *) (src1 + r1*nb11 + i_batch*nb12);
        device const half4 * b4 = (device const half4 *) b;

        float acc = 0;

        if (wide) {
            for (int i = tiisg; i < ne00/4; i += 32) {
                for (int k = 0; k < 4; ++k) {
                    acc += (half) a4[i][k] * b4[i][k];
                }
            }
        } else {
            for (int i = tiisg; i < ne00; i += 32) {
                acc += (half) a[i] * (half) b[i];
            }
        }

        float total = simd_sum(acc);

        if (tiisg == 0) {
            if (wide) {
                // elements left over when ne00 is not a multiple of 4
                for (int i = 4*(ne00/4); i < ne00; ++i) {
                    total += (half) a[i] * b[i];
                }
            }

            dst[i_batch*ne1*ne0 + r1*ne0 + i_row0] = total;
        }
    }
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_f16_f32_1row(
|
kernel void kernel_mul_mv_f16_f32_1row(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
@ -1229,6 +1302,39 @@ kernel void kernel_rope(
|
|||||||
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
|
||||||
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
|
||||||
|
|
||||||
|
// im2col: copies f32 input samples into an f16 destination, one element per thread.
//
// Index usage, as read from the code below:
//   tgpig.y / tgpig.z together with tpitg.y / tpitg.z select the input coordinate
//   (iih, iiw) through stride (s0, s1), dilation (d0, d1) and padding (p0, p1);
//   tpitg.x and tgpig.x scale the source offsets ofs0 and ofs1.
// Coordinates that fall outside [0, IH) x [0, IW) are written as zero.
kernel void kernel_im2col_f16(
        device const float * x,
        device        half * dst,
        constant   int32_t & ofs0,
        constant   int32_t & ofs1,
        constant   int32_t & IW,
        constant   int32_t & IH,
        constant   int32_t & CHW,
        constant   int32_t & s0,
        constant   int32_t & s1,
        constant   int32_t & p0,
        constant   int32_t & p1,
        constant   int32_t & d0,
        constant   int32_t & d1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tgpg[[threadgroups_per_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    // input coordinate addressed by this thread (may lie in the padding area)
    const int32_t iiw = tgpig.z * s0 + tpitg.z * d0 - p0;
    const int32_t iih = tgpig.y * s1 + tpitg.y * d1 - p1;

    // destination index = row * CHW + column
    const uint dst_row = tpitg.x * tgpg.y * tgpg.z + tgpig.y * tgpg.z + tgpig.z;
    const uint dst_col = tgpig.x * (ntg.y * ntg.z) + tpitg.y * ntg.z + tpitg.z;

    const int32_t offset_dst = dst_row * CHW + dst_col;

    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
        // outside of the input -> padding
        dst[offset_dst] = 0.0f;
        return;
    }

    const int32_t offset_src = tpitg.x * ofs0 + tgpig.x * ofs1;

    dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
|
||||||
|
|
||||||
kernel void kernel_cpy_f16_f16(
|
kernel void kernel_cpy_f16_f16(
|
||||||
device const half * src0,
|
device const half * src0,
|
||||||
device half * dst,
|
device half * dst,
|
||||||
|
53
ggml.c
53
ggml.c
@ -5131,13 +5131,15 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
|
|||||||
int s0,
|
int s0,
|
||||||
int p0,
|
int p0,
|
||||||
int d0) {
|
int d0) {
|
||||||
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
|
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false); // [N, OL, IC * K]
|
||||||
result =
|
|
||||||
ggml_reshape_3d(ctx,
|
struct ggml_tensor * result =
|
||||||
ggml_mul_mat(ctx,
|
ggml_mul_mat(ctx,
|
||||||
ggml_reshape_2d(ctx, result, result->ne[0], (result->ne[2] * result->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
|
ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
|
||||||
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])), // [OC,IC, K] => [OC, IC * K]
|
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC,IC, K] => [OC, IC * K]
|
||||||
result->ne[1], a->ne[2], result->ne[2]); // [N, OC, OL]
|
|
||||||
|
result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5252,22 +5254,24 @@ struct ggml_tensor * ggml_im2col(
|
|||||||
// b: [N, IC, IH, IW]
|
// b: [N, IC, IH, IW]
|
||||||
// result: [N, OC, OH, OW]
|
// result: [N, OC, OH, OW]
|
||||||
struct ggml_tensor * ggml_conv_2d(
|
struct ggml_tensor * ggml_conv_2d(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b,
|
struct ggml_tensor * b,
|
||||||
int s0,
|
int s0,
|
||||||
int s1,
|
int s1,
|
||||||
int p0,
|
int p0,
|
||||||
int p1,
|
int p1,
|
||||||
int d0,
|
int d0,
|
||||||
int d1) {
|
int d1) {
|
||||||
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
|
struct ggml_tensor * result = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true); // [N, OH, OW, IC * KH * KW]
|
||||||
|
|
||||||
result =
|
result =
|
||||||
ggml_reshape_4d(ctx,
|
ggml_reshape_4d(ctx,
|
||||||
ggml_mul_mat(ctx,
|
ggml_mul_mat(ctx,
|
||||||
ggml_reshape_2d(ctx, result, result->ne[0], result->ne[3] * result->ne[2] * result->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
|
ggml_reshape_2d(ctx, result, result->ne[0], result->ne[3] * result->ne[2] * result->ne[1]), // [N, OH, OW, IC * KH * KW] => [N*OH*OW, IC * KH * KW]
|
||||||
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]
|
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2]), a->ne[3])), // [OC,IC, KH, KW] => [OC, IC * KH * KW]
|
||||||
result->ne[1], result->ne[2], a->ne[3], result->ne[3]); // [N, OC, OH, OW]
|
result->ne[1], result->ne[2], a->ne[3], result->ne[3]); // [N, OC, OH, OW]
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -11724,17 +11728,18 @@ static void ggml_compute_forward_im2col_f16(
|
|||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS;
|
GGML_TENSOR_BINARY_OP_LOCALS;
|
||||||
|
|
||||||
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
||||||
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
||||||
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
|
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
|
||||||
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
||||||
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
||||||
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
||||||
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
const int64_t N = is_2D ? ne13 : ne12;
|
|
||||||
|
const int64_t N = is_2D ? ne13 : ne12;
|
||||||
const int64_t IC = is_2D ? ne12 : ne11;
|
const int64_t IC = is_2D ? ne12 : ne11;
|
||||||
const int64_t IH = is_2D ? ne11 : 1;
|
const int64_t IH = is_2D ? ne11 : 1;
|
||||||
const int64_t IW = ne10;
|
const int64_t IW = ne10;
|
||||||
|
Reference in New Issue
Block a user