metal : add special-case mat-vec mul for ne00 == 4 (llama/14385)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-26 15:51:19 +03:00
parent c848b9fbef
commit c300f1e32d
2 changed files with 86 additions and 3 deletions

View File

@ -211,11 +211,14 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
@ -1175,11 +1178,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4, mul_mv_f32_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4, mul_mv_bf16_f32_c4, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4, mul_mv_f16_f32_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction);
@ -3111,14 +3117,23 @@ static bool ggml_metal_encode_node(
nsg = 1;
nr0 = 1;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
if (ne00 == 4) {
nr0 = 32;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32_C4].pipeline;
} else {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
}
} break;
case GGML_TYPE_F16:
{
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_C4].pipeline;
} else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
@ -3137,7 +3152,11 @@ static bool ggml_metal_encode_node(
nsg = 1;
nr0 = 1;
if (src1t == GGML_TYPE_F32) {
if (ne11 * ne12 < 4) {
if (ne00 == 4) {
nr0 = 32;
nr1 = 4;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_C4].pipeline;
} else if (ne11 * ne12 < 4) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;

View File

@ -2532,6 +2532,70 @@ template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv<
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
#endif
template<typename T04, typename T14, typename args_t>
void kernel_mul_mv_c4_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig,
ushort tiisg) {
const int r0 = tgpig.x*32 + tiisg;
const int rb = tgpig.y*N_MV_T_T;
const int im = tgpig.z;
if (r0 >= args.ne01) {
return;
}
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T04 * x = (device const T04 *) (src0 + offset0);
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= args.ne11) {
break;
}
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T14 * y = (device const T14 *) (src1 + offset1);
dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]);
}
}
template<typename T04, typename T14>
kernel void kernel_mul_mv_c4(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_c4_impl<T04, T14, constant ggml_metal_kargs_mul_mv &>(
args,
src0,
src1,
dst,
tgpig,
tiisg);
}
typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<float4, float4>;
template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<half4, float4>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
#endif
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
constant ggml_metal_kargs_mul_mv & args,