mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-16 00:57:53 +02:00
opencl: add fused rms_norm_mul
(llama/14841)
* opencl: add fused `rms_norm` + `mul` * opencl: improve workgroup size for `rms_norm_mul`
This commit is contained in:
@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
|
|||||||
size_t max_alloc_size;
|
size_t max_alloc_size;
|
||||||
bool fp16_support;
|
bool fp16_support;
|
||||||
bool has_vector_subgroup_broadcast;
|
bool has_vector_subgroup_broadcast;
|
||||||
|
bool disable_fusion;
|
||||||
ggml_cl_compiler_version adreno_cl_compiler_version;
|
ggml_cl_compiler_version adreno_cl_compiler_version;
|
||||||
|
|
||||||
int adreno_wave_size;
|
int adreno_wave_size;
|
||||||
@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
|
|||||||
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
|
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
|
||||||
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
|
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
|
||||||
cl_kernel kernel_norm;
|
cl_kernel kernel_norm;
|
||||||
cl_kernel kernel_rms_norm;
|
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
|
||||||
cl_kernel kernel_group_norm;
|
cl_kernel kernel_group_norm;
|
||||||
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
|
||||||
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
cl_kernel kernel_soft_max, kernel_soft_max_4;
|
||||||
@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||||||
backend_ctx->program_rms_norm =
|
backend_ctx->program_rms_norm =
|
||||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||||
|
|
||||||
CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
|
CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
|
||||||
|
CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
|
||||||
GGML_LOG_CONT(".");
|
GGML_LOG_CONT(".");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
|
|||||||
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
|
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
|
||||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||||
|
|
||||||
|
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
|
||||||
|
|
||||||
dev_ctx->backend_ctx = backend_ctx.release();
|
dev_ctx->backend_ctx = backend_ctx.release();
|
||||||
return dev_ctx->backend_ctx;
|
return dev_ctx->backend_ctx;
|
||||||
}
|
}
|
||||||
@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
|
|||||||
sync_with_other_backends(backend_ctx);
|
sync_with_other_backends(backend_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
|
||||||
|
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
|
||||||
|
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
|
||||||
|
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
|
||||||
|
|
||||||
|
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
// rms_norm only supports f32
|
||||||
|
if (mul->src[0]->type != GGML_TYPE_F32 ||
|
||||||
|
mul->src[1]->type != GGML_TYPE_F32 ||
|
||||||
|
mul->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// if rms_norm is the B operand, then we don't handle broadcast
|
||||||
|
if (rms_norm == mul->src[1] &&
|
||||||
|
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// rms_norm assumes contiguous rows
|
||||||
|
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
|
||||||
|
|
||||||
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||||
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||||
|
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
bool ok = ggml_cl_compute_forward(backend, node);
|
bool ok = ggml_cl_compute_forward(backend, node);
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||||
@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
|
|||||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Runs RMS_NORM followed by MUL as a single kernel_rms_norm_mul launch.
//
// backend         - backend whose context provides the kernel, the GPU family
//                   and the enqueue helper.
// rms_norm_tensor - the RMS_NORM node; its src[0] is the kernel input and its
//                   op_params hold eps.
// mul_tensor      - the MUL node; one of its sources must be rms_norm_tensor,
//                   the other one is the multiplier. It is also the destination.
//
// Aborts via GGML_ASSERT on invalid arguments, on a row length that is not a
// multiple of 4, or on a GPU family other than ADRENO / INTEL.
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
    GGML_ASSERT(mul_tensor);
    GGML_ASSERT(rms_norm_tensor);

    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
    const ggml_tensor * src0 = rms_norm_tensor->src[0];
    const ggml_tensor * src1 = nullptr;
    if (mul_tensor->src[0] == rms_norm_tensor) {
        src1 = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == rms_norm_tensor) {
        src1 = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
    }
    const ggml_tensor * dst = mul_tensor;

    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(src1);
    GGML_ASSERT(src1->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    // Each buffer offset must use the view offset of its own tensor.
    cl_ulong offset0 = extra0->offset + src0->view_offs;
    cl_ulong offset1 = extra1->offset + src1->view_offs;
    cl_ulong offsetd = extrad->offset + dst->view_offs;

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    // eps is stored as raw bytes at the start of op_params.
    float eps;
    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));

    const int ne00 = src0->ne[0];
    const int ne01 = src0->ne[1];
    const int ne02 = src0->ne[2];
    const int ne03 = src0->ne[3];

    const cl_ulong nb01 = src0->nb[1];
    const cl_ulong nb02 = src0->nb[2];
    const cl_ulong nb03 = src0->nb[3];

    const int ne10 = src1->ne[0];
    const int ne11 = src1->ne[1];
    const int ne12 = src1->ne[2];
    const int ne13 = src1->ne[3];

    const cl_ulong nb11 = src1->nb[1];
    const cl_ulong nb12 = src1->nb[2];
    const cl_ulong nb13 = src1->nb[3];

    const cl_ulong nb1 = dst->nb[1];
    const cl_ulong nb2 = dst->nb[2];
    const cl_ulong nb3 = dst->nb[3];

    // The kernel processes rows as float4.
    GGML_ASSERT(ne00 % 4 == 0);

    // Subgroup size per GPU family; matches the REQD_SUBGROUP_SIZE_* attribute
    // placed on the kernel for the same family.
    size_t sgs = 0;
    if (backend_ctx->gpu_family == ADRENO) {
        sgs = 64;
    } else if (backend_ctx->gpu_family == INTEL) {
        sgs = 32;
    } else {
        GGML_ASSERT(false && "Unsupported GPU");
    }

    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;

    // Grow the work-group size from one subgroup by doubling until it covers
    // the row or hits the kernel's work-group limit.
    // NOTE(review): the kernel's reduction halves local_size / subgroup_size,
    // which looks like it assumes a power-of-two number of subgroups; the
    // MIN(nth, ne00) clamp can break that for unusual row lengths - confirm.
    int nth = sgs;
    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
    while (nth < ne00 && nth < max_workgroup_size) {
        nth *= 2;
    }
    nth = MIN(nth, max_workgroup_size);
    nth = MIN(nth, ne00);

    // One float of local scratch per subgroup. Round up so that a partial
    // subgroup still gets its slot (and the size is never zero).
    const size_t n_subgroups = ((size_t)nth + sgs - 1) / sgs;

    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
    size_t local_work_size[] = {(size_t)nth, 1, 1};

    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne13));
    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float),    &eps));
    // sum[] lives in local memory: a NULL value with a byte size.
    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*n_subgroups, NULL));

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
|
||||||
|
|
||||||
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
GGML_ASSERT(src0);
|
GGML_ASSERT(src0);
|
||||||
GGML_ASSERT(src0->extra);
|
GGML_ASSERT(src0->extra);
|
||||||
|
@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// rms_norm_mul
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// Fused RMS norm + element-wise multiply:
//   dst[row] = (src0[row] / sqrt(mean(src0[row]^2) + eps)) * src1[row']
//
// Each work-group handles the row selected by its group ids (i01, i02, i03).
// The matching src1 row is selected modulo ne11/ne12/ne13, and src1 elements
// are selected modulo ne10/4, so src1 may be smaller than src0 in any dim.
//
// src0/src1/dst     - byte pointers; offset0/offset1/offsetd are added first.
// ne00..ne03        - src0 element counts; rows are read as float4, so ne00/4
//                     vector elements are processed per row.
// nb01..nb03        - src0 byte strides for dims 1..3.
// ne10..ne13        - src1 element counts.
// nb11..nb13        - src1 byte strides for dims 1..3.
// nb1..nb3          - dst byte strides for dims 1..3.
// eps               - added to the mean of squares before the square root.
// sum               - local scratch, one float per subgroup of the work-group.
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_rms_norm_mul(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        int ne13,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        float eps,
        local float * sum
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
    global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);

    float sumf = 0;

    // parallel sum: each work-item accumulates squares of a strided slice
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        sumf += dot(x[i00], x[i00]);
    }
    // combine within the subgroup; the first lane publishes the partial sum
    sumf = sub_group_reduce_add(sumf);
    if (get_sub_group_local_id() == 0) {
        sum[get_sub_group_id()] = sumf;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // tree reduction of the per-subgroup partial sums in local memory
    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
        if (get_local_id(0) < i) {
            sum[get_local_id(0)] += sum[get_local_id(0) + i];
        }
        // Each step reads values written by other work-items in the previous
        // step, so the writes must be made visible before continuing. The trip
        // count is the same for every work-item, so all of them reach this.
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    if (get_local_id(0) == 0) {
        sum[0] /= ne00;
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    float mean  = sum[0];
    float scale = 1.0f/sqrt(mean + eps);

    // normalize and multiply in one pass over the row
    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
    }
}
|
||||||
|
Reference in New Issue
Block a user