From 9fb1ebdaadaf3a8a73ed357df5dea7a896cfa519 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 26 Jun 2025 15:51:19 +0300 Subject: [PATCH] metal : add special-case mat-vec mul for ne00 == 4 (llama/14385) ggml-ci --- ggml/src/ggml-metal/ggml-metal.m | 25 +++++++++-- ggml/src/ggml-metal/ggml-metal.metal | 64 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 248fa378..d8d30cc0 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -211,11 +211,14 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, @@ -1175,11 +1178,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); @@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node( nsg = 1; nr0 = 1; nr1 = 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + if (ne00 == 4) { + nr0 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + } } break; case GGML_TYPE_F16: { nsg = 1; nr0 = 1; if (src1t == GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { + if (ne00 == 4) { + nr0 = 32; + nr1 = 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline; + } else if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; @@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node( nsg = 1; nr0 = 1; if (src1t == GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { + if (ne00 == 4) { + nr0 = 32; + nr1 = 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline; + } else if (ne11 * ne12 < 4) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f0282760..5f004a85 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv< template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; #endif +template +void kernel_mul_mv_c4_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x*32 + tiisg; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; + + if (r0 >= args.ne01) { + return; + } + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T04 * x = (device const T04 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T14 * y = (device const T14 *) (src1 + offset1); + + dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + } +} + +template +kernel void kernel_mul_mv_c4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_c4_impl( + args, + src0, + src1, + dst, + tgpig, + tiisg); +} + +typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; + +template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +#endif + template kernel void kernel_mul_mv_1row( constant ggml_metal_kargs_mul_mv & args,