ggml : add mrope kernel for metal (llama/13457)

This commit is contained in:
Xuan-Son Nguyen 2025-05-13 13:10:08 +03:00 committed by Georgi Gerganov
parent 41ed62bdbc
commit 75e9a840c5
3 changed files with 192 additions and 16 deletions

View File

@ -207,6 +207,10 @@ typedef struct {
float attn_factor;
float beta_fast;
float beta_slow;
int32_t sect_0;
int32_t sect_1;
int32_t sect_2;
int32_t sect_3;
} ggml_metal_kargs_rope;
typedef struct {

View File

@ -332,6 +332,10 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
return false;
}
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
return true;
}
return true;
case GGML_OP_IM2COL:
return op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_POOL_1D:
@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
} break;
case GGML_OP_ROPE:
{
// make sure we have one or more position id(ne10) per token(ne02)
GGML_ASSERT(ne10 % ne02 == 0);
GGML_ASSERT(ne10 >= ne02);
@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
// mrope
const int sect_0 = ((const int32_t *) dst->op_params)[11];
const int sect_1 = ((const int32_t *) dst->op_params)[12];
const int sect_2 = ((const int32_t *) dst->op_params)[13];
const int sect_3 = ((const int32_t *) dst->op_params)[14];
id<MTLComputePipelineState> pipeline = nil;
if (!is_neox) {
if (is_neox) {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else if (is_mrope && !is_vision) {
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else if (is_vision) {
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
} else {
switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
default: GGML_ABORT("fatal error");
};
}
@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
/*.attn_factor =*/ attn_factor,
/*.beta_fast =*/ beta_fast,
/*.beta_slow =*/ beta_slow,
/* sect_0 =*/ sect_0,
/* sect_1 =*/ sect_1,
/* sect_2 =*/ sect_2,
/* sect_3 =*/ sect_3,
};
[encoder setComputePipelineState:pipeline];

View File

@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
}
}
template<typename T>
kernel void kernel_rope_multi(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < args.n_dims) {
const int ic = i0/2;
// mrope theta calculations
// note: the rest is the same as kernel_rope_neox
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
const int sector = ic % sect_dims;
float theta_base;
if (sector < args.sect_0) {
theta_base = (float) pos[i2];
} else if (sector < sec_w01) {
theta_base = (float) pos[i2 + args.ne02];
} else if (sector < sec_w012) {
theta_base = (float) pos[i2 + args.ne02 * 2];
} else {
theta_base = (float) pos[i2 + args.ne02 * 3];
}
// end of mrope
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
template<typename T>
kernel void kernel_rope_vision(
constant ggml_metal_kargs_rope & args,
device const char * src0,
device const char * src1,
device const char * src2,
device char * dst,
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tptg [[threads_per_threadgroup]],
uint3 tgpig[[threadgroup_position_in_grid]]) {
const int i3 = tgpig[2];
const int i2 = tgpig[1];
const int i1 = tgpig[0];
float corr_dims[2];
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
device const int32_t * pos = (device const int32_t *) src1;
const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
const int ic = i0/2;
// mrope theta calculations (only support 2 dimensions)
const int sect_dims = args.sect_0 + args.sect_1;
const int sector = ic % sect_dims;
float p;
float theta_base;
if (sector < args.sect_1) {
p = (float) sector;
theta_base = (float) pos[i2];
} else {
p = (float) sector - args.sect_0;
theta_base = (float) pos[i2 + args.ne02];
}
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
// end of mrope
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
} else {
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
}
}
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
typedef void (im2col_t)(
device const float * x,
device char * dst,