From 75e9a840c57fe781c3e203d58caa10333940d58d Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 13 May 2025 13:10:08 +0300 Subject: [PATCH] ggml : add mrope kernel for metal (llama/13457) --- ggml/src/ggml-metal/ggml-metal-impl.h | 4 + ggml/src/ggml-metal/ggml-metal.m | 58 +++++++--- ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++ 3 files changed, 192 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ea8cf458..17eab976 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -207,6 +207,10 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sect_0; + int32_t sect_1; + int32_t sect_2; + int32_t sect_3; } ggml_metal_kargs_rope; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 943f0fe7..576f9581 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -332,6 +332,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, @@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); @@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } + return true; case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: @@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node( } break; case GGML_OP_ROPE: { + // make sure we have one or more position id(ne10) per token(ne02) GGML_ASSERT(ne10 % ne02 == 0); GGML_ASSERT(ne10 >= ne02); @@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node( memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + // mrope + const int sect_0 = ((const int32_t *) dst->op_params)[11]; + const int sect_1 = ((const int32_t *) dst->op_params)[12]; + const int sect_2 = ((const int32_t *) dst->op_params)[13]; + const int sect_3 = ((const int32_t *) dst->op_params)[14]; id pipeline = nil; - if (!is_neox) { + if (is_neox) { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } @@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node( /*.attn_factor =*/ attn_factor, /*.beta_fast =*/ beta_fast, /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a808e79c..9cfddf45 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox( } } +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + // mrope theta calculations + // note: the rest is the same as kernel_rope_neox + const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; + const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + + float theta_base; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < 2*args.n_dims) { // different from kernel_rope_multi + const int ic = i0/2; + + // mrope theta calculations (only support 2 dimensions) + const int sect_dims = args.sect_0 + args.sect_1; + const int sector = ic % sect_dims; + + float p; + float theta_base; + if (sector < args.sect_1) { + p = (float) sector; + theta_base = (float) pos[i2]; + } else { + p = (float) sector - args.sect_0; + theta_base = (float) pos[i2 + args.ne02]; + } + + const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); + // end of mrope + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; // different from kernel_rope_multi + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst,