From 1cd702842889833b554e1496c9175fdc744c5601 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Sun, 25 May 2025 12:38:37 +0530 Subject: [PATCH] SYCL: revert "sycl: simplify bin_bcast_kernel (ggml/13383)" (llama/13752) Temporarily reverted due to failing fp16 DIV operation This reverts commit 02cdd2d8b092b5a4bb18e013c6887ce49ba20ac5. ggml-ci --- ggml/src/ggml-sycl/binbcast.cpp | 333 +++++++++++++++++++++----------- 1 file changed, 222 insertions(+), 111 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index aaa94176..0a9d3a92 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -1,74 +1,93 @@ #include "binbcast.hpp" -#include #include #include #include -#include "dpct/helper.hpp" #include "ggml.h" -template -static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, - dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) { - auto element_id = it.get_global_id(0); - auto global_range = it.get_global_range(0); - for (; element_id < num_elements; element_id += global_range) { - auto src0_float_val = sycl::vec(src0[element_id]).template convert(); - auto src1_float_val = sycl::vec(src1[element_id]).template convert(); - float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); - auto val_to_store = sycl::vec(dst_val).template convert(); - dst[element_id] = val_to_store; +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } } -template -static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, - int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10, - int s11, int s12, int s13, std::size_t num_dst_elements, - const sycl::nd_item<1> & item_ct1) { - auto calculate_logical_index = - [](const std::array & dims, std::size_t element_id) __attribute__((always_inline))->std::array { - std::array logical_index; -#pragma unroll(4) - for (int i = 3; i >= 0; i--) { - logical_index[i] = element_id % dims[i]; - element_id /= dims[i]; - } - return logical_index; - }; +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { - auto calculate_index = [](const std::array & dims, const std::array & strides, - const std::array & indices) __attribute__((always_inline)) - ->std::size_t { - std::size_t index = 0; -#pragma unroll(4) - for (int i = 0; i < 4; i++) { - auto index_i = indices[i]; - if (indices[i] >= dims[i]) { - index_i = indices[i] % dims[i]; - } - index += strides[i] * index_i; - } - return index; - }; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - auto element_id = item_ct1.get_global_id(0); - for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) { - auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id); - auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index); - auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index); - auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index); - auto src0_float_val = sycl::vec(src0[src_0_index]).template convert(); - auto src1_float_val = sycl::vec(src1[src_1_index]).template convert(); - float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); - auto val_to_store = sycl::vec(dst_val).template convert(); - dst[dst_index] = val_to_store; + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } -template struct bin_bcast_sycl { + +template +struct bin_bcast_sycl { template void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, @@ -77,73 +96,165 @@ template struct bin_bcast_sycl { const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { - auto check_bcast_required = [](const std::array & src_dims, - const std::array & dst_dims) -> bool { - for (int i = 0; i < 4; i++) { - if (dst_dims[i] > src_dims[i]) { - return true; - } - } - return false; + int nr0 = ne10 / ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; }; - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; - // dst strides in number of elements - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; - // src1 strides in number of elements - size_t s10 = nb10 / sizeof(src0_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; - // src0 strides in number of elements - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); - std::size_t num_dst_elements = static_cast(ne0) * static_cast(ne1) * - static_cast(ne2) * static_cast(ne3); - std::size_t local_range = 256; - std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range; + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); - bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) || - check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 }); - bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous; + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); - if (! needs_broadcasting && all_contiguous) { - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { - k_bin_bcast_contiguous(src0_dd, src1_dd, dst_dd, num_dst_elements, it); - }); - }); - } else { - stream->submit([&](sycl::handler & cgh) { - cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1, - s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it); - }); - }); + GGML_UNUSED(s00); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, + s03, s11, s12, s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. + */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s01, s02, s03, s11, s12, s13, + item_ct1); + }); + } } } };