Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2024-12-26 16:48:50 +01:00)

commit 8171e621fc (parent: ec03661b20)

sync : ggml (Metal fixes, new ops, tests) (#1633)

* sync : ggml (Metal fixes, new ops, tests)
* cuda : fix bin bcast when src1 and dst have different types
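For readers unfamiliar with the CUDA fix mentioned above, the sketch below shows the general shape of a binary broadcast ("bin bcast") operation in which src1 is repeated across the rows of src0 while dst/src0 and src1 use different element types. It is an illustration only, not the actual ggml-cuda.cu kernel; the function name and the double/float pairing are made up for the example.

    #include <stddef.h>

    /* dst and src0 use double, src1 uses float, so the loop must convert
     * explicitly instead of assuming a single element type */
    static void add_row_bcast(const double * src0, const float * src1,
                              double * dst, size_t nrows, size_t ncols) {
        for (size_t r = 0; r < nrows; ++r) {
            for (size_t c = 0; c < ncols; ++c) {
                dst[r*ncols + c] = src0[r*ncols + c] + (double) src1[c];
            }
        }
    }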
@@ -43,7 +43,7 @@ GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph

// ggml-backend v2 API
//
// Seperate tensor and graph allocator objects
// Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
ggml-cuda.cu: 776 changed lines (file diff suppressed because it is too large)

ggml-metal.m: 581 changed lines
@ -66,9 +66,11 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(div_row);
|
||||
GGML_METAL_DECL_KERNEL(scale);
|
||||
GGML_METAL_DECL_KERNEL(scale_4);
|
||||
GGML_METAL_DECL_KERNEL(silu);
|
||||
GGML_METAL_DECL_KERNEL(tanh);
|
||||
GGML_METAL_DECL_KERNEL(relu);
|
||||
GGML_METAL_DECL_KERNEL(gelu);
|
||||
GGML_METAL_DECL_KERNEL(gelu_quick);
|
||||
GGML_METAL_DECL_KERNEL(silu);
|
||||
GGML_METAL_DECL_KERNEL(soft_max);
|
||||
GGML_METAL_DECL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||
@ -86,6 +88,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(group_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
|
||||
@ -102,6 +105,21 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
||||
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
|
||||
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||
@ -130,8 +148,11 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(upscale_f32);
|
||||
GGML_METAL_DECL_KERNEL(pad_f32);
|
||||
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_DECL_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
||||
@ -140,6 +161,7 @@ struct ggml_metal_context {
|
||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(concat);
|
||||
GGML_METAL_DECL_KERNEL(sqr);
|
||||
GGML_METAL_DECL_KERNEL(sum_rows);
|
||||
@ -318,9 +340,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(div_row);
|
||||
GGML_METAL_ADD_KERNEL(scale);
|
||||
GGML_METAL_ADD_KERNEL(scale_4);
|
||||
GGML_METAL_ADD_KERNEL(silu);
|
||||
GGML_METAL_ADD_KERNEL(tanh);
|
||||
GGML_METAL_ADD_KERNEL(relu);
|
||||
GGML_METAL_ADD_KERNEL(gelu);
|
||||
GGML_METAL_ADD_KERNEL(gelu_quick);
|
||||
GGML_METAL_ADD_KERNEL(silu);
|
||||
GGML_METAL_ADD_KERNEL(soft_max);
|
||||
GGML_METAL_ADD_KERNEL(soft_max_4);
|
||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||
@ -338,6 +362,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(group_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
|
||||
@ -354,6 +379,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
||||
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
|
||||
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||
@ -384,8 +424,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||
GGML_METAL_ADD_KERNEL(alibi_f32);
|
||||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(upscale_f32);
|
||||
GGML_METAL_ADD_KERNEL(pad_f32);
|
||||
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_ADD_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
||||
@ -394,6 +437,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(concat);
|
||||
GGML_METAL_ADD_KERNEL(sqr);
|
||||
GGML_METAL_ADD_KERNEL(sum_rows);
|
||||
@ -418,9 +462,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(div_row);
|
||||
GGML_METAL_DEL_KERNEL(scale);
|
||||
GGML_METAL_DEL_KERNEL(scale_4);
|
||||
GGML_METAL_DEL_KERNEL(silu);
|
||||
GGML_METAL_DEL_KERNEL(tanh);
|
||||
GGML_METAL_DEL_KERNEL(relu);
|
||||
GGML_METAL_DEL_KERNEL(gelu);
|
||||
GGML_METAL_DEL_KERNEL(gelu_quick);
|
||||
GGML_METAL_DEL_KERNEL(silu);
|
||||
GGML_METAL_DEL_KERNEL(soft_max);
|
||||
GGML_METAL_DEL_KERNEL(soft_max_4);
|
||||
GGML_METAL_DEL_KERNEL(diag_mask_inf);
|
||||
@ -438,6 +484,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(group_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
|
||||
@ -454,6 +501,21 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
||||
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
|
||||
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||
@ -484,8 +546,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||
GGML_METAL_DEL_KERNEL(alibi_f32);
|
||||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(upscale_f32);
|
||||
GGML_METAL_DEL_KERNEL(pad_f32);
|
||||
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
|
||||
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
|
||||
GGML_METAL_DEL_KERNEL(leaky_relu_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
||||
@ -494,6 +559,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(concat);
|
||||
GGML_METAL_DEL_KERNEL(sqr);
|
||||
GGML_METAL_DEL_KERNEL(sum_rows);
|
||||
@ -795,9 +861,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
switch (op->op) {
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(op)) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@ -809,6 +877,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_CONCAT:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ACC:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
case GGML_OP_SCALE:
|
||||
@ -816,21 +885,50 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_ALIBI:
|
||||
case GGML_OP_ROPE:
|
||||
case GGML_OP_IM2COL:
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return true;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CONT:
|
||||
{
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_F16:
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
};
|
||||
}
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->ne[0] % 4 == 0;
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
@@ -906,7 +1004,10 @@ void ggml_metal_graph_compute(
} break;
}

GGML_ASSERT(ggml_metal_supports_op(dst));
if (!ggml_metal_supports_op(dst)) {
GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(!"unsupported op");
}

const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||
@ -1003,34 +1104,39 @@ void ggml_metal_graph_compute(
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
const size_t offs = 0;
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
int64_t nb = ne00;
|
||||
|
||||
if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
nb = ne00 / 4;
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
|
||||
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
|
||||
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
|
||||
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
bcast_row = true;
|
||||
} else {
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
|
||||
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
|
||||
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
|
||||
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
||||
default: GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
@ -1058,18 +1164,99 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:27];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
||||
|
||||
if (bcast_row) {
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} else {
|
||||
const int nth = MIN(1024, ne0);
|
||||
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ACC:
|
||||
{
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dstt == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
||||
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
||||
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
||||
const size_t offs = ((int32_t *) dst->op_params)[3];
|
||||
|
||||
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
||||
|
||||
if (!inplace) {
|
||||
// run a separete kernel to cpy src->dst
|
||||
// not sure how to avoid this
|
||||
// TODO: make a simpler cpy_bytes kernel
|
||||
|
||||
const int nth = MIN(1024, ne00);
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_add];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
[encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
|
||||
[encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
|
||||
[encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
|
||||
[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
@ -1093,16 +1280,15 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_TANH:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||
[encoder setComputePipelineState:ctx->pipeline_tanh];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
@ -1123,6 +1309,28 @@ void ggml_metal_graph_compute(
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_gelu_quick];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_UNARY_OP_SILU:
|
||||
{
|
||||
[encoder setComputePipelineState:ctx->pipeline_silu];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
GGML_ASSERT(n % 4 == 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
default:
|
||||
@ -1197,6 +1405,8 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
if (id_src1) {
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
@ -1448,7 +1658,7 @@ void ggml_metal_graph_compute(
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||
const int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
@ -1460,7 +1670,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
||||
|
||||
const int n_as = ne00;
|
||||
const int n_as = ((int32_t *) dst->op_params)[1];
|
||||
|
||||
// TODO: make this more general
|
||||
GGML_ASSERT(n_as <= 8);
|
||||
@@ -1492,14 +1702,22 @@ void ggml_metal_graph_compute(

// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = 0;
int ne11_mm_min = 1;

const int idx = ((int32_t *) dst->op_params)[0];

// batch size
GGML_ASSERT(ne01 == ne11);

const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory

// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne11 > ne11_mm_min) {
// !!!
// TODO: for now, always use mat-vec kernels until we figure out how to improve the
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
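The comments above capture the dispatch policy: the simdgroup matrix-matrix kernel is only available on Apple7-family GPUs and only pays off once the batch dimension exceeds a break-even point, and for the indirect (mul_mat_id) path it is effectively disabled for now by processing one row at a time. A minimal sketch of that decision in C, with illustrative names (only ne11_mm_min comes from the diff):

    /* returns non-zero when the matrix-matrix kernel should be used */
    static int use_mat_mat_kernel(int has_apple7_simdgroups, long batch_rows, int ne11_mm_min) {
        return has_apple7_simdgroups && batch_rows > ne11_mm_min;
    }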
@ -1518,19 +1736,22 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:15];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < n_as; ++j) {
|
||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||
@ -1538,11 +1759,157 @@ void ggml_metal_graph_compute(
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
||||
}
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
|
||||
// TODO: processing one row at a time (ne11 -> 1) is not efficient
|
||||
[encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
int nrows = 1;
|
||||
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src2t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
{
|
||||
nth0 = 8;
|
||||
nth1 = 8;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
{
|
||||
nth0 = 4; //1;
|
||||
nth1 = 8; //32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
{
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
};
|
||||
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < n_as; ++j) {
|
||||
struct ggml_tensor * src_cur = dst->src[2 + j];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
||||
}
|
||||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#else
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#endif
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
@ -1563,16 +1930,19 @@ void ggml_metal_graph_compute(
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
|
||||
[encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
|
||||
[encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
|
||||
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
|
||||
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
|
||||
|
||||
const int64_t n = ggml_nelements(src1);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
{
|
||||
@ -1599,6 +1969,38 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_GROUP_NORM:
|
||||
{
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
//float eps;
|
||||
//memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
const float eps = 1e-6f; // TODO: temporarily hardcoded
|
||||
|
||||
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
//while (nth < ne00/4 && nth < 1024) {
|
||||
// nth *= 2;
|
||||
//}
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_group_norm];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
|
||||
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
|
||||
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
|
||||
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
|
||||
[encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&eps length:sizeof( float) atIndex:9];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_NORM:
|
||||
{
|
||||
float eps;
|
||||
@ -1768,6 +2170,65 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
const int sf = dst->op_params[0];
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_upscale_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
[encoder setBytes:&sf length:sizeof(sf) atIndex:18];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_PAD:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_pad_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
||||
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
|
||||
[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
|
||||
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
|
||||
|
||||
const int nth = MIN(1024, ne0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_ARGSORT:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
@ -1789,6 +2250,22 @@ void ggml_metal_graph_compute(
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
{
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
|
||||
float slope;
|
||||
memcpy(&slope, dst->op_params, sizeof(float));
|
||||
|
||||
[encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&slope length:sizeof(slope) atIndex:2];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
@ -1817,7 +2294,7 @@ void ggml_metal_graph_compute(
|
||||
{
|
||||
switch (dstt) {
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
|
||||
case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
};
|
||||
} break;
|
||||
ggml-metal.metal: 1616 changed lines (file diff suppressed because it is too large)
@@ -3114,7 +3114,7 @@ void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restri

size_t vl = __riscv_vsetvl_e8m1(qk/2);

// These tempory registers are for masking and shift operations
// These temporary registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);

@@ -4757,7 +4757,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri

vl = 16;

// retreive lane to multiply with scale
// retrieve lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
ggml.c: 363 changed lines
@@ -1,4 +1,4 @@
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC

#include "ggml-impl.h"
@@ -33,7 +33,7 @@
// we should just be careful :)
#pragma warning(disable: 4244 4267)

// disable POSIX deprecation warnigns
// disable POSIX deprecation warnings
// these functions are never going away, anyway
#pragma warning(disable: 4996)
#endif
@@ -1395,7 +1395,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }

static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
@ -1623,7 +1623,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"POOL_1D",
|
||||
"POOL_2D",
|
||||
"UPSCALE",
|
||||
"PAD",
|
||||
"ARGSORT",
|
||||
"LEAKY_RELU",
|
||||
|
||||
"FLASH_ATTN",
|
||||
"FLASH_FF",
|
||||
@ -1650,7 +1652,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
|
||||
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@ -1707,7 +1709,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"pool_1d(x)",
|
||||
"pool_2d(x)",
|
||||
"upscale(x)",
|
||||
"pad(x)",
|
||||
"argsort(x)",
|
||||
"leaky_relu(x)",
|
||||
|
||||
"flash_attn(x)",
|
||||
"flash_ff(x)",
|
||||
@ -1734,7 +1738,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
|
||||
static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@ -1750,17 +1754,16 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"GELU",
|
||||
"GELU_QUICK",
|
||||
"SILU",
|
||||
"LEAKY",
|
||||
};
|
||||
|
||||
static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
|
||||
static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
|
||||
|
||||
|
||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
||||
|
||||
// WARN:
|
||||
// Mis-confguration can lead to problem that's hard to reason about:
|
||||
// Mis-configuration can lead to problem that's hard to reason about:
|
||||
// * At best it crash or talks nosense.
|
||||
// * At worst it talks slightly difference but hard to perceive.
|
||||
//
|
||||
@@ -3830,12 +3833,25 @@ struct ggml_tensor * ggml_relu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}

// ggml_leaky
// ggml_leaky_relu

struct ggml_tensor * ggml_leaky(
struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
struct ggml_tensor * a, float negative_slope, bool inplace) {
bool is_node = false;

if (!inplace && (a->grad)) {
is_node = true;
}

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));

result->op = GGML_OP_LEAKY_RELU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}
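The new op computes y = x for x > 0 and y = negative_slope * x otherwise (ggml_vec_leaky_relu_f32 earlier in this diff expresses the same thing as max(x, 0) + ns * min(x, 0)). A hedged usage sketch; the ctx variable is assumed to exist and the sizes are arbitrary:

    /* illustrative only; ggml_leaky_relu is the function added above */
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * y = ggml_leaky_relu(ctx, x, 0.1f, /*inplace=*/false);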
// ggml_gelu
@ -4022,8 +4038,9 @@ static struct ggml_tensor * ggml_group_norm_impl(
|
||||
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_GROUP_NORM;
|
||||
result->op_params[0] = n_groups;
|
||||
|
||||
result->op = GGML_OP_GROUP_NORM;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = NULL; // TODO: maybe store epsilon here?
|
||||
@@ -4075,17 +4092,18 @@ struct ggml_tensor * ggml_mul_mat(

struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as[],
struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {

int64_t n_as = ids->ne[0];

GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_vector(ids));
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
GGML_ASSERT(ids->ne[1] == b->ne[1]);
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
GGML_ASSERT(id >= 0 && id < n_as);
GGML_ASSERT(id >= 0 && id < ids->ne[0]);

bool is_node = false;

@@ -4097,13 +4115,14 @@ struct ggml_tensor * ggml_mul_mat_id(
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);

ggml_set_op_params_i32(result, 0, id);
ggml_set_op_params_i32(result, 1, n_as);

result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = ids;
result->src[1] = b;

for (int64_t i = 0; i < n_as; i++) {
for (int i = 0; i < n_as; i++) {
struct ggml_tensor * a = as[i];
GGML_ASSERT(ggml_are_same_shape(as[0], a));
GGML_ASSERT(ggml_can_mul_mat(a, b));
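In the updated API the caller passes the expert count explicitly, and ids carries one row of expert indices per column of b, which is what later lets the Metal and CPU paths pick a possibly different expert per token. A hedged call sketch; all variable names besides ggml_mul_mat_id are illustrative and assumed to be set up elsewhere:

    struct ggml_tensor * experts[8];   /* n_as expert matrices, all the same shape */
    struct ggml_tensor * ids;          /* GGML_TYPE_I32, ids->ne[1] == b->ne[1]    */
    struct ggml_tensor * b;            /* activations                              */
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, experts, 8, ids, /*id=*/0, b);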
@@ -4731,7 +4750,9 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_I32);

bool is_node = false;

@@ -4741,7 +4762,7 @@ struct ggml_tensor * ggml_get_rows(

// TODO: implement non F32 return
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
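The gather is now batched: b may have up to three dimensions and the result shape is (a->ne[0], b->ne[0], b->ne[1], b->ne[2]), with one batch of a per row of b (a->ne[2] == b->ne[1]). A hedged shape example; ctx is assumed to exist, the sizes are arbitrary, and the index values of ids are not shown:

    struct ggml_tensor * a   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 100, 4); /* 4 batches of 100 rows */
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 8, 4);       /* 8 row indices per batch */
    struct ggml_tensor * r   = ggml_get_rows(ctx, a, ids);                         /* r: 64 x 8 x 4 x 1 */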
@@ -5519,6 +5540,30 @@ static struct ggml_tensor * ggml_upscale_impl(
return result;
}

struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0, int p1, int p2, int p3) {
bool is_node = false;

if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}

struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
a->ne[0] + p0,
a->ne[1] + p1,
a->ne[2] + p2,
a->ne[3] + p3);

result->op = GGML_OP_PAD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;

return result;
}
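Each p_i is the number of elements appended to dimension i, so the result shape is (ne0+p0, ne1+p1, ne2+p2, ne3+p3). A hedged usage sketch (ctx is assumed to exist; the numbers are arbitrary):

    /* pad a 30x17 f32 tensor up to 32x32 */
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 30, 17);
    struct ggml_tensor * y = ggml_pad(ctx, x, 2, 15, 0, 0);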
struct ggml_tensor * ggml_upscale(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
@ -7520,7 +7565,7 @@ static void ggml_compute_forward_acc_f32(
|
||||
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
||||
|
||||
// view src0 and dst with these strides and data offset inbytes during acc
|
||||
// nb0 is implicitely element_size because src0 and dst are contiguous
|
||||
// nb0 is implicitly element_size because src0 and dst are contiguous
|
||||
size_t nb1 = ((int32_t *) dst->op_params)[0];
|
||||
size_t nb2 = ((int32_t *) dst->op_params)[1];
|
||||
size_t nb3 = ((int32_t *) dst->op_params)[2];
|
||||
@ -7714,8 +7759,10 @@ static void ggml_compute_forward_mul_f32(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
// TODO: OpenCL kernel support broadcast
|
||||
#ifdef GGML_USE_CLBLAST
|
||||
if (src1->backend == GGML_BACKEND_GPU) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||
if (ith == 0) {
|
||||
ggml_cl_mul(src0, src1, dst);
|
||||
}
|
||||
@ -8981,10 +9028,9 @@ static void ggml_compute_forward_silu(
|
||||
} break;
|
||||
}
|
||||
}
|
||||
// ggml_compute_forward_leaky_relu
|
||||
|
||||
// ggml_compute_forward_leaky
|
||||
|
||||
static void ggml_compute_forward_leaky_f32(
|
||||
static void ggml_compute_forward_leaky_relu_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
@ -8998,24 +9044,27 @@ static void ggml_compute_forward_leaky_f32(
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
float negative_slope;
|
||||
memcpy(&negative_slope, dst->op_params, sizeof(float));
|
||||
|
||||
assert(dst->nb[0] == sizeof(float));
|
||||
assert(src0->nb[0] == sizeof(float));
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
ggml_vec_leaky_f32(nc,
|
||||
ggml_vec_leaky_relu_f32(nc,
|
||||
(float *) ((char *) dst->data + i*( dst->nb[1])),
|
||||
(float *) ((char *) src0->data + i*(src0->nb[1])));
|
||||
(float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_leaky(
|
||||
static void ggml_compute_forward_leaky_relu(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
struct ggml_tensor * dst) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_leaky_f32(params, src0, dst);
|
||||
ggml_compute_forward_leaky_relu_f32(params, src0, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
@@ -9504,8 +9553,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];

// NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
// all the experts for each batch element and the processing would become incredibly slow
// TODO: find the optimal values for these
if (ggml_is_contiguous(src0) &&
if (dst->op != GGML_OP_MUL_MAT_ID &&
ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
//src0->type == GGML_TYPE_F32 &&
src1->type == GGML_TYPE_F32 &&
@@ -9519,11 +9571,16 @@ static bool ggml_compute_forward_mul_mat_use_blas(
}
#endif

// off1 = offset in i11 and i1
// cne1 = ne11 and ne1
// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
struct ggml_tensor * dst,
int64_t off1, int64_t cne1) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
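The new off1/cne1 parameters restrict the multiplication to a slice of cne1 rows of src1 (and of dst) starting at row off1, as the comment block above describes. A hedged sketch of the two call patterns that appear later in this diff; the surrounding variables are illustrative:

    /* normal mat-mul: operate on all of src1 */
    ggml_compute_forward_mul_mat(params, src0, src1, dst, 0, dst->ne[1]);

    /* mul_mat_id: one call per token, each handling a single row at offset i01 */
    ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);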
@ -9591,10 +9648,9 @@ static void ggml_compute_forward_mul_mat(
|
||||
const int64_t i03 = i13/r3;
|
||||
const int64_t i02 = i12/r2;
|
||||
|
||||
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
||||
const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
|
||||
|
||||
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
|
||||
const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
|
||||
const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
|
||||
float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
if (type != GGML_TYPE_F32) {
|
||||
float * const wdata = params->wdata;
|
||||
@ -9611,10 +9667,10 @@ static void ggml_compute_forward_mul_mat(
|
||||
}
|
||||
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
ne11, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne00,
|
||||
0.0f, d, ne01);
|
||||
cne1, ne01, ne10,
|
||||
1.0f, y, ne10,
|
||||
x, ne00,
|
||||
0.0f, d, ne01);
|
||||
}
|
||||
}
|
||||
|
||||
@ -9630,6 +9686,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
|
||||
|
||||
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
||||
assert(src1->type == GGML_TYPE_F32);
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
@ -9652,7 +9709,7 @@ static void ggml_compute_forward_mul_mat(
|
||||
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = ne11*ne12*ne13; // src1 rows
|
||||
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
|
||||
|
||||
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
||||
|
||||
@ -9694,9 +9751,9 @@ static void ggml_compute_forward_mul_mat(
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||
const int64_t i13 = (ir1/(ne12*ne11));
|
||||
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
|
||||
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
|
||||
const int64_t i13 = (ir1/(ne12*cne1));
|
||||
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
||||
const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
|
||||
|
||||
// broadcast src0 into src1
|
||||
const int64_t i03 = i13/r3;
|
||||
@ -9736,20 +9793,28 @@ static void ggml_compute_forward_mul_mat(

static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {

const struct ggml_tensor * ids = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
// during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
return;
}

const int id = ggml_get_op_params_i32(dst, 0);
const struct ggml_tensor * ids = src0;
const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = ggml_get_op_params_i32(dst, 1);

const int a_id = ((int32_t *)ids->data)[id];
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);

GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);

const struct ggml_tensor * src0 = dst->src[a_id + 2];

ggml_compute_forward_mul_mat(params, src0, src1, dst);
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
}
}

// ggml_compute_forward_out_prod
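In words, the id-based matmul no longer picks a single candidate matrix up front: it walks the rows of the ids tensor and forwards each one as a single-row slice. A condensed restatement of the dispatch loop above (illustrative, not additional code in the patch):

// for each output row i01:
//     row_id = ids[i01][id]            // candidate chosen for this row
//     src0   = dst->src[row_id + 2]    // the corresponding weight matrix
//     compute dst row i01 from src1 row i01   (off1 = i01, cne1 = 1)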
@ -10161,7 +10226,7 @@ static void ggml_compute_forward_set_f32(
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));

// view src0 and dst with these strides and data offset inbytes during set
// nb0 is implicitely element_size because src0 and dst are contiguous
// nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
@ -10325,21 +10390,30 @@ static void ggml_compute_forward_get_rows_q(
return;
}

const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
GGML_TENSOR_BINARY_OP_LOCALS

const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);

const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;

assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
assert(src0->nb[0] == ggml_type_size(type));
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_type_size(type));
assert(ggml_nrows(dst) == nr);

for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

dequantize_row_q(
(const void *) ((char *) src0->data + r*src0->nb[1]),
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
dequantize_row_q(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
}
}

@ -10354,19 +10428,26 @@ static void ggml_compute_forward_get_rows_f16(
return;
}

const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
GGML_TENSOR_BINARY_OP_LOCALS

assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
assert(src0->nb[0] == sizeof(ggml_fp16_t));
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);

for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(ggml_fp16_t));
assert(ggml_nrows(dst) == nr);

for (int j = 0; j < nc; ++j) {
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

ggml_fp16_to_fp32_row(
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
}
}
}
}
@ -10382,19 +10463,27 @@ static void ggml_compute_forward_get_rows_f32(
return;
}

const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
GGML_TENSOR_BINARY_OP_LOCALS

assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
assert(src0->nb[0] == sizeof(float));
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);

for (int i = 0; i < nr; ++i) {
const int r = ((int32_t *) src1->data)[i];
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_nrows(dst) == nr);

ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*dst->nb[1]),
(float *) ((char *) src0->data + r*src0->nb[1]));
// TODO: multi-thread
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
for (int64_t i10 = 0; i10 < ne10; ++i10) {
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
(float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
}
}
}
}

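All three get_rows variants (quantized, f16, f32) now implement the same batched gather. A compact restatement of the new loop nest (illustrative, not part of the patch):

// for each (i12, i11, i10):
//     i01 = src1[i10, i11, i12]                          // int32 row index
//     dst[:, i10, i11, i12] = src0[:, i01, i11, i12]     // copy/dequantize one row of nc elements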
@ -12114,6 +12203,7 @@ static void ggml_compute_forward_upscale_f32(
GGML_ASSERT(src0->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

@ -12121,16 +12211,17 @@ static void ggml_compute_forward_upscale_f32(

// TODO: optimize

for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = ith; i02 < ne02; i02++) {
for (int m = 0; m < dst->ne[1]; m++) {
int i01 = m / scale_factor;
for (int n = 0; n < dst->ne[0]; n++) {
int i00 = n / scale_factor;
for (int64_t i3 = 0; i3 < ne3; i3++) {
const int64_t i03 = i3;
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
const int64_t i02 = i2;
for (int64_t i1 = 0; i1 < ne1; i1++) {
const int64_t i01 = i1 / scale_factor;
for (int64_t i0 = 0; i0 < ne0; i0++) {
const int64_t i00 = i0 / scale_factor;

const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);

float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);

*y = *x;
}
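The rewritten loops iterate over destination coordinates, divide by scale_factor to find the nearest source element, and split work across threads along ne2. At the graph level the op is unchanged; a minimal, hypothetical use (sizes chosen for illustration):

// nearest-neighbour upscale of a 2D feature map by 2x in dims 0 and 1
struct ggml_tensor * x  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
struct ggml_tensor * up = ggml_upscale(ctx, x, 2);   // result is 128 x 128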
@ -12155,6 +12246,64 @@ static void ggml_compute_forward_upscale(
}
}

// ggml_compute_forward_pad

static void ggml_compute_forward_pad_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {

if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}

GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT( dst->nb[0] == sizeof(float));

const int ith = params->ith;
const int nth = params->nth;

GGML_TENSOR_UNARY_OP_LOCALS

float * dst_ptr = (float *) dst->data;

// TODO: optimize

for (int64_t i2 = 0; i2 < ne2; ++i2) {
for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;

const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
}
}
}
}
}
}

static void ggml_compute_forward_pad(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_pad_f32(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
} break;
}
}

// ggml_compute_forward_argsort

static void ggml_compute_forward_argsort_f32(
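The pad kernel walks the already enlarged destination shape, copies the source value where the index still falls inside src0, and writes zeros elsewhere. A hedged note on shapes, assuming GGML_OP_PAD grows each dimension by its padding amount (consistent with the ggml_pad declaration later in this patch):

// assumed shape relation for GGML_OP_PAD
//     src0: [ne00, ne01, ne02, ne03]
//     dst : [ne00 + p0, ne01 + p1, ne02 + p2, ne03 + p3]   // the new cells are zero-filled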
@ -13362,10 +13511,6 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_silu(params, src0, dst);
} break;
case GGML_UNARY_OP_LEAKY:
{
ggml_compute_forward_leaky(params, src0, dst);
} break;
default:
{
GGML_ASSERT(false);
@ -14037,11 +14182,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_MUL_MAT:
{
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
} break;
case GGML_OP_MUL_MAT_ID:
{
ggml_compute_forward_mul_mat_id(params, tensor);
ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_OUT_PROD:
{
@ -14147,10 +14292,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_upscale(params, tensor->src[0], tensor);
} break;
case GGML_OP_PAD:
{
ggml_compute_forward_pad(params, tensor->src[0], tensor);
} break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor->src[0], tensor);
} break;
case GGML_OP_LEAKY_RELU:
{
ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
} break;
case GGML_OP_FLASH_ATTN:
{
const int32_t t = ggml_get_op_params_i32(tensor, 0);
@ -14475,7 +14628,7 @@ void ggml_build_backward_gradient_checkpointing(
// insert new tensors recomputing src, reusing already made replacements,
// remember replacements: remember new tensors with mapping from corresponding gf nodes
// recurse for input tensors,
// unless (i.e. terminating when) input tensors are replacments (like checkpoints)
// unless (i.e. terminating when) input tensors are replacements (like checkpoints)
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
}
// insert rewritten backward node with replacements made into resulting backward graph gb
@ -15143,10 +15296,18 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_PAD:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_LEAKY_RELU:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN:
{
struct ggml_tensor * flash_grad = NULL;
@ -15752,6 +15913,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_LEAKY_RELU:
{
n_tasks = 1;
} break;
@ -15764,7 +15926,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_LEAKY:
{
n_tasks = 1;
} break;
@ -15883,6 +16044,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = n_threads;
} break;
case GGML_OP_PAD:
{
n_tasks = n_threads;
} break;
case GGML_OP_ARGSORT:
{
n_tasks = n_threads;
28
ggml.h
@ -215,9 +215,9 @@
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this

#define GGML_MAX_DIMS 4
#define GGML_MAX_PARAMS 1024
#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_SRC 10
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
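GGML_MAX_SRC grows from 6 to 10 because GGML_OP_MUL_MAT_ID keeps its candidate matrices in dst->src. A sketch of the layout implied by the forward pass above (the indices are from this diff, the capacity conclusion is interpretation):

// dst->src[0]     = ids   (I32 tensor of candidate/expert indices)
// dst->src[1]     = b     (the activations, src1)
// dst->src[2 + i] = as[i] (candidate matrices, i in [0, n_as))
// => with this layout n_as can be at most GGML_MAX_SRC - 2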
@ -423,7 +423,9 @@ extern "C" {
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
GGML_OP_PAD,
GGML_OP_ARGSORT,
GGML_OP_LEAKY_RELU,

GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
@ -463,7 +465,6 @@ extern "C" {
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
GGML_UNARY_OP_LEAKY,

GGML_UNARY_OP_COUNT,
};
@ -793,6 +794,9 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

// dst = a
// view(dst, nb1, nb2, nb3, offset) += b
// return dst
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -957,15 +961,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_leaky(
GGML_API struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * a, float negative_slope, bool inplace);

GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

// TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
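ggml_leaky becomes ggml_leaky_relu, a dedicated GGML_OP_LEAKY_RELU with an explicit slope. A minimal, hypothetical call site (the 0.1f slope is an example value, not mandated by the patch):

struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
struct ggml_tensor * y = ggml_leaky_relu(ctx, x, 0.1f, /*inplace=*/false);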
@ -1051,7 +1054,8 @@ extern "C" {
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as[],
struct ggml_tensor * const as[],
int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
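A sketch of how a mixture-of-experts style layer might call the updated API. The expert count, tensor sizes and the shape of ids are illustrative assumptions, not taken from this patch:

// illustrative sizes only
const int64_t n_embd = 64, n_ff = 256, n_tokens = 8;
const int n_as = 4; // hypothetical number of candidate matrices

struct ggml_tensor * as[4];
for (int i = 0; i < n_as; ++i) {
    as[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // candidate matrices
}

struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); // activations
struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, n_tokens);      // one index per row of b (layout assumed, see the forward pass above)

// ~= ggml_mul_mat(as[ids[row]], b), computed one dst row at a time
struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, n_as, ids, 0, b);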
@ -1263,6 +1267,7 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

// supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
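The new comment documents the batched form implemented by the CPU kernels above. A hypothetical shape example (sizes are illustrative):

// a: 4 batches of 32 rows, 16 floats each; b: 8 int32 row ids per batch
struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 16, 32, 4);
struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,  8,  4);   // b->ne[1] == a->ne[2]
struct ggml_tensor * r = ggml_get_rows(ctx, a, b);                         // gathers 8 rows from each batch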
@ -1549,6 +1554,15 @@ extern "C" {
struct ggml_tensor * a,
int scale_factor);

// pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
GGML_API struct ggml_tensor * ggml_pad(
struct ggml_context * ctx,
struct ggml_tensor * a,
int p0,
int p1,
int p2,
int p3);

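A minimal, hypothetical use of the new op, padding only the first two dimensions (sizes chosen for illustration):

// pad a 30x30 map to 32x32 by appending zeros on dims 0 and 1
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 30, 30);
struct ggml_tensor * p = ggml_pad(ctx, x, 2, 2, 0, 0);   // result is 32 x 32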
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,
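The diff cuts off inside the new sort-order enum. For context, a hypothetical call to the argsort op wired up above (the signature is assumed from the op and enum names in this patch, not shown in the hunk):

// per-row ascending argsort: returns an I32 tensor of indices, one row per input row
struct ggml_tensor * scores = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 4); // 4 rows of 16 scores
struct ggml_tensor * idx    = ggml_argsort(ctx, scores, GGML_SORT_ASC);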