diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4f765ab5..b32d5da3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2520,8 +2520,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: - // TODO: support attention sinks [TAG_ATTN_SINKS] - return op->src[2] == nullptr; case GGML_OP_NORM: case GGML_OP_RMS_NORM: return true; @@ -6594,17 +6592,24 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(src1->extra); } + const ggml_tensor * src2 = dst->src[2]; + if (src2) { + GGML_ASSERT(src2->extra); + } + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -6672,25 +6677,27 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_head_log2)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl index a6d8ede6..571d1650 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4_f16( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4_f16( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4_f16( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl index 35b5573b..1f944b22 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_4( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_4( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_4( global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_4( } // parallel max - float4 lmax4 = -INFINITY; + float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -92,7 +96,11 @@ kernel void kernel_soft_max_4( } float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { pdst4[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl index 9d292b57..4baa6c28 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max_f16( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max_f16( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max_f16( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max_f16( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max_f16( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl index 7c53dfbe..d503190b 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -26,6 +26,8 @@ kernel void kernel_soft_max( ulong offset0, global char * src1, ulong offset1, + global char * src2, + ulong offset2, global char * dst, ulong offsetd, int ne00, @@ -48,6 +50,7 @@ kernel void kernel_soft_max( ) { src0 = src0 + offset0; src1 = src1 + offset1; + src2 = src2 + offset2; dst = dst + offsetd; int i03 = get_group_id(2); @@ -60,6 +63,7 @@ kernel void kernel_soft_max( global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * psrc2 = src2 != src0 ? (global float *)(src2) : 0; global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; @@ -75,7 +79,7 @@ kernel void kernel_soft_max( } // parallel max - float lmax = -INFINITY; + float lmax = psrc2 ? psrc2[i02] : -INFINITY; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); } @@ -91,7 +95,11 @@ kernel void kernel_soft_max( pdst[i00] = exp_psrc0; } - const float sum = sub_group_reduce_add(lsum); + float sum = sub_group_reduce_add(lsum); + + if (psrc2) { + sum += exp(psrc2[i02] - max); + } for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { pdst[i00] /= sum;