diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 63ac4a98..c87a3238 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -333,6 +333,7 @@ struct ggml_backend_opencl_context { size_t max_alloc_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool disable_fusion; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; @@ -411,7 +412,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm; - cl_kernel kernel_rms_norm; + cl_kernel kernel_rms_norm, kernel_rms_norm_mul; cl_kernel kernel_group_norm; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; @@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_rms_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err)); GGML_LOG_CONT("."); } @@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); #endif // GGML_OPENCL_USE_ADRENO_KERNELS + backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) { sync_with_other_backends(backend_ctx); } +static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + return false; + } + + // rms_norm assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + + return true; +} + +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); + static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); + i++; + continue; + } + bool ok = ggml_cl_compute_forward(backend, node); if (!ok) { GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); @@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) { + GGML_ASSERT(mul_tensor); + GGML_ASSERT(rms_norm_tensor); + + // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm) + const ggml_tensor * src0 = rms_norm_tensor->src[0]; + const ggml_tensor * src1; + if (mul_tensor->src[0] == rms_norm_tensor) { + src1 = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == rms_norm_tensor) { + src1 = mul_tensor->src[0]; + } else { + GGML_ASSERT(false && "Invalid args for rms_norm and mul"); + } + const ggml_tensor * dst = mul_tensor; + + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + float eps; + memcpy(&eps, rms_norm_tensor->op_params, sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + GGML_ASSERT(ne00 % 4 == 0); + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + cl_kernel kernel = backend_ctx->kernel_rms_norm_mul; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00 && nth < max_workgroup_size) { + nth *= 2; + } + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl index 9d21f339..ecd053cb 100644 --- a/ggml/src/ggml-opencl/kernels/rms_norm.cl +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -94,3 +94,82 @@ kernel void kernel_rms_norm( } } } + +//------------------------------------------------------------------------------ +// rms_norm_mul +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, + float eps, + local float * sum +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += dot(x[i00], x[i00]); + } + sumf = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + float mean = sum[0]; + float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1); + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = (x[i00] * scale) * f[i00%(ne10/4)]; + } +}