SYCL: revert "sycl: simplify bin_bcast_kernel (ggml/13383)" (llama/13752)

Temporarily reverted due to failing fp16 DIV operation

This reverts commit 02cdd2d8b092b5a4bb18e013c6887ce49ba20ac5.

ggml-ci
This commit is contained in:
Akarshan Biswas 2025-05-25 12:38:37 +05:30 committed by Georgi Gerganov
parent 99596d6031
commit 1cd7028428

View File

@ -1,74 +1,93 @@
#include "binbcast.hpp" #include "binbcast.hpp"
#include <array>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <sycl/sycl.hpp> #include <sycl/sycl.hpp>
#include "dpct/helper.hpp"
#include "ggml.h" #include "ggml.h"
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t> template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast_contiguous(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
dst_t * dst, std::size_t num_elements, const sycl::nd_item<1> & it) { int ne0, int ne1, int ne2, int ne3,
auto element_id = it.get_global_id(0); int ne10, int ne11, int ne12, int ne13,
auto global_range = it.get_global_range(0); /*int s0, */ int s1, int s2, int s3,
for (; element_id < num_elements; element_id += global_range) { /*int s00,*/ int s01, int s02, int s03,
auto src0_float_val = sycl::vec(src0[element_id]).template convert<float, sycl::rounding_mode::rte>(); /*int s10,*/ int s11, int s12, int s13,
auto src1_float_val = sycl::vec(src1[element_id]).template convert<float, sycl::rounding_mode::rte>(); const sycl::nd_item<3> &item_ct1) {
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]); const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>(); item_ct1.get_local_id(2);
dst[element_id] = val_to_store; const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
item_ct1.get_local_id(1));
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) /
ne3;
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
item_ct1.get_local_id(0)) %
ne3;
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0;
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
} }
} }
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t> template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __dpct_inline__ void k_bin_bcast(const src0_t * __restrict__ src0, const src1_t * __restrict__ src1, dst_t * dst, static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, int ne0, int ne1, int ne2, int ne3,
int s0, int s1, int s2, int s3, int s00, int s01, int s02, int s03, int s10, int ne10, int ne11, int ne12, int ne13,
int s11, int s12, int s13, std::size_t num_dst_elements, /*int s0, */ int s1, int s2, int s3,
const sycl::nd_item<1> & item_ct1) { /*int s00,*/ int s01, int s02, int s03,
auto calculate_logical_index = /*int s10,*/ int s11, int s12, int s13,
[](const std::array<int, 4> & dims, std::size_t element_id) __attribute__((always_inline))->std::array<int, 4> { const sycl::nd_item<3> &item_ct1) {
std::array<int, 4> logical_index;
#pragma unroll(4)
for (int i = 3; i >= 0; i--) {
logical_index[i] = element_id % dims[i];
element_id /= dims[i];
}
return logical_index;
};
auto calculate_index = [](const std::array<int, 4> & dims, const std::array<int, 4> & strides, const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
const std::array<int, 4> & indices) __attribute__((always_inline)) item_ct1.get_local_id(2);
->std::size_t {
std::size_t index = 0;
#pragma unroll(4)
for (int i = 0; i < 4; i++) {
auto index_i = indices[i];
if (indices[i] >= dims[i]) {
index_i = indices[i] % dims[i];
}
index += strides[i] * index_i;
}
return index;
};
auto element_id = item_ct1.get_global_id(0); const int i3 = i/(ne2*ne1*ne0);
for (; element_id < num_dst_elements; element_id += item_ct1.get_global_range(0)) { const int i2 = (i/(ne1*ne0)) % ne2;
auto logical_index = calculate_logical_index({ ne3, ne2, ne1, ne0 }, element_id); const int i1 = (i/ne0) % ne1;
auto src_0_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s03, s02, s01, s00 }, logical_index); const int i0 = i % ne0;
auto src_1_index = calculate_index({ ne13, ne12, ne11, ne10 }, { s13, s12, s11, s10 }, logical_index);
auto dst_index = calculate_index({ ne3, ne2, ne1, ne0 }, { s3, s2, s1, s0 }, logical_index); if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
auto src0_float_val = sycl::vec(src0[src_0_index]).template convert<float, sycl::rounding_mode::rte>(); return;
auto src1_float_val = sycl::vec(src1[src_1_index]).template convert<float, sycl::rounding_mode::rte>();
float dst_val = bin_op(src0_float_val[0], src1_float_val[0]);
auto val_to_store = sycl::vec(dst_val).template convert<dst_t, sycl::rounding_mode::rte>();
dst[dst_index] = val_to_store;
} }
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
} }
template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
template<float (*bin_op)(const float, const float)>
struct bin_bcast_sycl {
template <typename src0_t, typename src1_t, typename dst_t> template <typename src0_t, typename src1_t, typename dst_t>
void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
@ -77,17 +96,90 @@ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
auto check_bcast_required = [](const std::array<int64_t, 4> & src_dims, int nr0 = ne10 / ne0;
const std::array<int64_t, 4> & dst_dims) -> bool { int nr1 = ne11/ne1;
for (int i = 0; i < 4; i++) { int nr2 = ne12/ne2;
if (dst_dims[i] > src_dims[i]) { int nr3 = ne13/ne3;
return true;
} int nr[4] = { nr0, nr1, nr2, nr3 };
}
return false; // collapse dimensions until first broadcast dimension
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
}; };
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_UNUSED(s00);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0); GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0); GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
@ -104,48 +196,67 @@ template <float (*bin_op)(const float, const float)> struct bin_bcast_sycl {
GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
// dst strides in number of elements GGML_ASSERT(s0 == 1);
size_t s0 = nb0 / sizeof(dst_t); GGML_ASSERT(s10 == 1);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
// src1 strides in number of elements const int block_size = 128;
size_t s10 = nb10 / sizeof(src0_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
// src0 strides in number of elements int64_t hne0 = std::max(ne0/2LL, 1LL);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
std::size_t num_dst_elements = static_cast<std::size_t>(ne0) * static_cast<std::size_t>(ne1) * sycl::range<3> block_dims(1, 1, 1);
static_cast<std::size_t>(ne2) * static_cast<std::size_t>(ne3); block_dims[2] = std::min<unsigned int>(hne0, block_size);
std::size_t local_range = 256; block_dims[1] = std::min<unsigned int>(
std::size_t global_range = ceil_div(num_dst_elements, local_range) * local_range; ne1, block_size / (unsigned int)block_dims[2]);
block_dims[0] = std::min(
std::min<unsigned int>(
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
(unsigned int)block_dims[1]),
64U);
bool needs_broadcasting = check_bcast_required({ ne00, ne01, ne02, ne03 }, { ne0, ne1, ne2, ne3 }) || sycl::range<3> block_nums(
check_bcast_required({ ne10, ne11, ne12, ne13 }, { ne0, ne1, ne2, ne3 }); (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
bool all_contiguous = src0_is_contiguous && src1_is_contiguous && dst_is_contiguous; (ne1 + block_dims[1] - 1) / block_dims[1],
(hne0 + block_dims[2] - 1) / block_dims[2]);
if (! needs_broadcasting && all_contiguous) { if (block_nums[0] > 65535) {
stream->submit([&](sycl::handler & cgh) { // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
k_bin_bcast_contiguous<bin_op>(src0_dd, src1_dd, dst_dd, num_dst_elements, it); {
}); dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
sycl::range<3>(1, 1, block_size),
sycl::range<3>(1, 1, block_size)),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast_unravel<bin_op>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
s03, s11, s12, s13, item_ct1);
}); });
}
} else { } else {
stream->submit([&](sycl::handler & cgh) { /*
cgh.parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), [=](sycl::nd_item<1> it) { DPCT1049:16: The work-group size passed to the SYCL kernel may
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, s0, s1, exceed the limit. To get the device limit, query
s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, num_dst_elements, it); info::device::max_work_group_size. Adjust the work-group size if
}); needed.
*/
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
ne2, ne3, ne10, ne11, ne12, ne13,
s1, s2, s3, s01, s02, s03, s11, s12, s13,
item_ct1);
}); });
} }
} }
}
}; };
template <class op> template <class op>