diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 2d37e271..2a789edf 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -19,5 +19,8 @@ #include "dmmv.hpp" #include "mmq.hpp" #include "mmvq.hpp" +#include "rope.hpp" +#include "norm.hpp" +#include "softmax.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 414c37ee..9a1c161b 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -17,6 +17,7 @@ #include #include "dpct/helper.hpp" +#include "ggml-sycl.h" #include "presets.hpp" #define GGML_COMMON_DECL_SYCL @@ -46,10 +47,6 @@ static int g_ggml_sycl_debug = 0; } \ }() -// #define DEBUG_SYCL_MALLOC - -static int g_work_group_size = 0; -// typedef sycl::half ggml_fp16_t; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP #define VER_4VEC 610 // todo for hardward optimize. @@ -192,6 +189,8 @@ struct ggml_sycl_device_info { sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; std::array default_tensor_split = {}; + + int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; }; const ggml_sycl_device_info & ggml_sycl_info(); @@ -294,5 +293,57 @@ struct ggml_backend_sycl_context { } }; +// common device functions + +static __dpct_inline__ float warp_reduce_sum(float x, + const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + /* + DPCT1096:98: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". + */ + x += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), x, mask); + } + return x; +} + +static __dpct_inline__ sycl::float2 +warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + a.x() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.x(), + mask); + a.y() += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), a.y(), + mask); + } + return a; +} + +static __dpct_inline__ float warp_reduce_max(float x, + const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + /* + DPCT1096:97: The right-most dimension of the work-group used in the SYCL + kernel that calls this function may be less than "32". The function + "dpct::permute_sub_group_by_xor" may return an unexpected result on the + CPU device. Modify the size of the work-group to ensure that the value + of the right-most dimension is a multiple of "32". 
+ */ + x = sycl::fmax(x, dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), x, mask)); + } + return x; +} + +// Helper for vec loading aligned data +template +inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { + return *reinterpret_cast*>(aligned_ptr); +} #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index ce9de2b4..a15271b5 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); - stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_K(vx, y, item_ct1); + dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1); }); + }); } } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index b6080d83..ed8ad098 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { - d = q[j] & 63; m = q[j + 4] & 63; + d = q[j] & 63; + m = q[j + 4] & 63; } else { d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); @@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, - const sycl::nd_item<3> &item_ct1) { + uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { const block_q4_K * x = (const block_q4_K *) vx; const int i = item_ct1.get_group(2); @@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri dst_t * y = yy + i*QK_K + 64*il + n*ir; - const float dall = x[i].dm[0]; - const float dmin = x[i].dm[1]; + const sycl::half2 dm = x[i].dm; + const float dall = dm[0]; + const float dmin = dm[1]; - const uint8_t * q = x[i].qs + 32*il + n*ir; + if (tid < 12) + scales_local[tid] = x[i].scales[tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); uint8_t sc, m; - get_scale_min_k4(is + 0, x[i].scales, sc, m); - const float d1 = dall * sc; const float m1 = dmin * m; - get_scale_min_k4(is + 1, x[i].scales, sc, m); - const float d2 = dall * sc; const float m2 = dmin * m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(x[i].qs + 32*il + n*ir); for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q[l] & 0xF) - m1; - y[l +32] = d2 * (q[l] >> 4) - m2; + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l +32] = d2 * (q_vec[l] >> 4) - m2; } #else const int tid = item_ct1.get_local_id(2); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 3a87d3ef..70a94fc1 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,7 @@ #include "dequantize.hpp" #include "presets.hpp" + static void convert_f16(const void * 
vx, const int ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -76,7 +77,7 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -104,7 +105,7 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1); }); @@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -774,7 +775,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec( vx, y, dst, ncols, nrows, item_ct1); }); @@ -795,7 +796,7 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec( vx, y, dst, ncols, nrows, item_ct1); }); @@ -816,7 +817,7 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + 
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec( vx, y, dst, ncols, nrows, item_ct1); }); @@ -837,7 +838,7 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec( vx, y, dst, ncols, nrows, item_ct1); }); @@ -858,7 +859,7 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { dequantize_mul_mat_vec( vx, y, dst, ncols, nrows, item_ct1); }); @@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y, const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, ny, 32); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1); }); } @@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, ny, 32); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1); }); } @@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, ny, 32); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1); }); } @@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const sycl::range<3> block_dims(1, 1, 32); + const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE); stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1); }); } @@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, 
const float *y, const int ny = 2 / K_QUANTS_PER_ITERATION; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, ny, 32); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1); }); } diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 1ff29721..31df1cb9 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -255,7 +255,7 @@ namespace dpct void set_pitch(size_t pitch) { _pitch = pitch; } size_t get_x() { return _x; } - void set_x(size_t x) { _x = x; }; + void set_x(size_t x) { _x = x; } size_t get_y() { return _y; } void set_y(size_t y) { _y = y; } @@ -1056,7 +1056,7 @@ namespace dpct #error "Only support Windows and Linux." #endif next_free = mapped_address_space; - }; + } public: using buffer_id_t = int; @@ -1077,7 +1077,7 @@ namespace dpct #else #error "Only support Windows and Linux." #endif - }; + } mem_mgr(const mem_mgr &) = delete; mem_mgr &operator=(const mem_mgr &) = delete; @@ -2426,6 +2426,7 @@ namespace dpct b, ldb, beta, c, ldc, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): @@ -2458,7 +2459,6 @@ namespace dpct batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): @@ -2595,6 +2595,7 @@ namespace dpct stride_c, batch_size); break; } +#endif case detail::get_type_combination_id( library_data_t::real_int8, library_data_t::real_int8, library_data_t::real_int32, library_data_t::real_int32): @@ -2623,7 +2624,6 @@ namespace dpct beta, c, ldc, stride_c, batch_size); break; } -#endif case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_float): diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 23227649..3fbc4dd6 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -37,7 +37,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -85,7 +85,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -133,7 +133,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -181,7 +181,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write 
back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -229,7 +229,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -277,7 +277,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -325,7 +325,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -373,7 +373,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -421,7 +421,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -470,7 +470,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -495,7 +495,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -519,7 +519,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -543,7 +543,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -567,7 +567,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -591,7 +591,7 @@ static void 
mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -615,7 +615,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -639,7 +639,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -663,7 +663,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -687,7 +687,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -711,7 +711,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -734,8 +734,8 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq2_xxs_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -759,8 +759,8 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq2_xs_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -784,8 +784,8 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq2_s_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq2_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -809,8 +809,8 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq3_xxs_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -833,8 +833,8 @@ static void 
mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq3_s_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq3_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -858,7 +858,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q_iq1_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -879,7 +879,7 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q_iq1_m_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -901,7 +901,7 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { + [[intel::reqd_sub_group_size(WARP_SIZE)]] { mul_mat_vec_q_iq4_nl_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -923,8 +923,8 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q_iq4_xs_q8_1( + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_iq4_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); }); @@ -936,7 +936,7 @@ void ggml_sycl_op_mul_mat_vec_q( const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, + const int64_t src1_ncols, const int64_t src1_padded_col_size, const dpct::queue_ptr &stream) { const int64_t ne10 = src1->ne[0]; @@ -948,77 +948,80 @@ void ggml_sycl_op_mul_mat_vec_q( int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? 
ne00 : row_diff; - - switch (src0->type) { + for (int i = 0; i < src1_ncols; i++) + { + const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; + const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + switch (src0->type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, 
stream); + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; default: GGML_ASSERT(false); break; + } } - (void) src1; (void) dst; (void) src1_ddf_i; - (void) src1_ncols; - (void) src1_padded_row_size; } diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp new file mode 100644 index 00000000..e0c5dfec --- /dev/null +++ b/ggml/src/ggml-sycl/norm.cpp @@ -0,0 +1,374 @@ +#include "norm.hpp" + +static void norm_f32(const float* x, float* dst, const int ncols, const float eps, + const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + const int tid = item_ct1.get_local_id(2); + + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); + sycl::float2 mean_var = sycl::float2(0.f, 0.f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + mean_var.x() += xi; + mean_var.y() += xi * xi; + } + + // sum up partial sums + mean_var = warp_reduce_sum(mean_var, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = mean_var; + } + /* + DPCT1118:0: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. 
+ */ + item_ct1.barrier(sycl::access::fence_space::local_space); + mean_var = 0.f; + int nreduce = nwarps / WARP_SIZE; + for (size_t i = 0; i < nreduce; i += 1) + { + mean_var += s_sum[lane_id + i * WARP_SIZE]; + } + mean_var = warp_reduce_sum(mean_var, item_ct1); + } + + const float mean = mean_var.x() / ncols; + const float var = mean_var.y() / ncols - mean * mean; + const float inv_std = sycl::rsqrt(var + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std; + } +} + +static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + int start = item_ct1.get_group(2) * group_size; + int end = start + group_size; + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); + start += item_ct1.get_local_id(2); + int nreduce = nwarps / WARP_SIZE; + + if (end >= ne_elements) { + end = ne_elements; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += block_size) { + tmp += x[j]; + } + + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:1: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:54: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. + */ + item_ct1.barrier(); + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += block_size) { + float xi = x[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:2: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + /* + DPCT1065:55: Consider replacing sycl::nd_item::barrier() with + sycl::nd_item::barrier(sycl::access::fence_space::local_space) for + better performance if there is no access to global memory. 
+ */ + item_ct1.barrier(); + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + float variance = tmp / group_size; + float scale = sycl::rsqrt(variance + eps); + for (int j = start; j < end; j += block_size) { + dst[j] *= scale; + } +} + +static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + const int tid = item_ct1.get_local_id(2); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + assert(nwarps % WARP_SIZE == 0); + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:3: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + int nreduce = nwarps / WARP_SIZE; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float mean = tmp / ncols; + const float scale = sycl::rsqrt(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = scale * x[row * ncols + col]; + } +} + +static void norm_f32_sycl(const float* x, float* dst, const int ncols, + const int nrows, const float eps, + queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:17: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1( + sycl::range<1>(work_group_size / WARP_SIZE), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + norm_f32(x, dst, ncols, eps, item_ct1, + s_sum_acc_ct1.get_pointer(), work_group_size); + }); + }); + } +} + +static void group_norm_f32_sycl(const float* x, float* dst, + const int num_groups, const int group_size, + const int ne_elements, queue_ptr stream, int device) { + static const float eps = 1e-6f; + if (group_size < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + const float eps_ct4 = eps; + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32( + x, dst, group_size, ne_elements, eps_ct4, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:18: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + + const float eps_ct4 = eps; + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + group_norm_f32(x, dst, group_size, ne_elements, + eps_ct4, item_ct1, + s_sum_acc_ct1.get_pointer(), work_group_size); + }); + }); + } +} + +static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, + const int nrows, const float eps, + queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:19: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. 
+ */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + rms_norm_f32(x, dst, ncols, eps, item_ct1, + s_sum_acc_ct1.get_pointer(), work_group_size); + }); + }); + } +} + +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, + ggml_tensor* dst, const float* src0_dd, + const float* src1_dd, float* dst_dd, + const queue_ptr& main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + + (void)src1; + (void)dst; + (void)src1_dd; +} + +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + int num_groups = dst->op_params[0]; + int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device); + + (void)src1; + (void)dst; + (void)src1_dd; +} + +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + + (void)src1; + (void)dst; + (void)src1_dd; +} diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp new file mode 100644 index 00000000..a9ad9156 --- /dev/null +++ b/ggml/src/ggml-sycl/norm.hpp @@ -0,0 +1,35 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_NORM_HPP +#define GGML_SYCL_NORM_HPP + +#include "common.hpp" + +void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, + ggml_tensor* dst, const float* src0_dd, + const float* src1_dd, float* dst_dd, + const queue_ptr& main_stream); + +void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream); + +void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream); + +#endif // GGML_SYCL_NORM_HPP diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index 5e6b6181..15ddcac1 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -15,10 +15,8 @@ #define GGML_SYCL_MAX_STREAMS 8 #define GGML_SYCL_MAX_BUFFERS 256 -#define GGML_SYCL_MAX_DEVICES 48 -#define GGML_SYCL_NAME "SYCL" -#define WARP_SIZE 32 +#define WARP_SIZE GGML_SYCL_WARP_SIZE #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define SYCL_GELU_BLOCK_SIZE 256 @@ -64,4 +62,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MUL_MAT_SRC1_COL_STRIDE 128 +#define QK_WARP_SIZE 32 #endif // GGML_SYCL_PRESETS_HPP diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp new file mode 100644 index 00000000..eabf1693 --- /dev/null +++ b/ggml/src/ggml-sycl/rope.cpp @@ -0,0 +1,275 @@ +#include "rope.hpp" + +struct rope_corr_dims { + float v[2]; +}; + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low); + return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
+static void rope_yarn( + float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale); + } + *cos_theta = sycl::cos(theta) * mscale; + *sin_theta = sycl::sin(theta) * mscale; +} + +template +static void rope_norm( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, + const sycl::nd_item<3> &item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne0) { + return; + } + + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0; + const int i2 = row/p_delta_rows; + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + 1]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + 1] = x0*sin_theta + x1*cos_theta; +} + +template +static void rope_neox( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, + const sycl::nd_item<3> &item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne0) { + return; + } + + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0/2; + const int i2 = row/p_delta_rows; + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? 
freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + const float x0 = x[i + 0]; + const float x1 = x[i + n_dims/2]; + + dst[i + 0] = x0*cos_theta - x1*sin_theta; + dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; +} + +template +static void rope_norm_sycl( + const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const sycl::range<3> block_nums(1, num_blocks_x, nr); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + if (freq_factors == nullptr) { + /* + DPCT1049:40: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, + ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, + item_ct1); + }); + } else { + /* + DPCT1049:41: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, + ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, + item_ct1); + }); + } +} + +template +static void rope_neox_sycl( + const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const sycl::range<3> block_nums(1, num_blocks_x, nr); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + if (freq_factors == nullptr) { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, + item_ct1); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, n_dims, pos, freq_scale, + p_delta_rows, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, + item_ct1); + }); + } +} + +void ggml_sycl_op_rope( + ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream) { + const ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || 
src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t nr = ggml_nrows(src0); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + // RoPE alteration for extended context + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + const int32_t * pos = (const int32_t *) src1_dd; + + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); + + // compute + if (is_neox) { + if (src0->type == GGML_TYPE_F32) { + rope_neox_sycl( + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_neox_sycl( + (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream + ); + } else { + GGML_ASSERT(false); + } + } else { + if (src0->type == GGML_TYPE_F32) { + rope_norm_sycl( + (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream + ); + } else if (src0->type == GGML_TYPE_F16) { + rope_norm_sycl( + (const sycl::half *)src0_dd, (sycl::half *)dst_dd, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, main_stream + ); + } else { + GGML_ASSERT(false); + } + } + + (void) src1; + (void) dst; + (void) src1_dd; +} diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp new file mode 100644 index 00000000..00354c31 --- /dev/null +++ b/ggml/src/ggml-sycl/rope.hpp @@ -0,0 +1,22 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_ROPE_HPP
+#define GGML_SYCL_ROPE_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_rope(
+    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const float *src0_dd, const float *src1_dd, float *dst_dd, const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_ROPE_HPP
diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp
new file mode 100644
index 00000000..e624b6ba
--- /dev/null
+++ b/ggml/src/ggml-sycl/softmax.cpp
@@ -0,0 +1,250 @@
+#include "norm.hpp"
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
+                         const int nrows_y, const float scale, const float max_bias, const float m0,
+                         const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+
+    const int tid = item_ct1.get_local_id(2);
+    const int rowx = item_ct1.get_group(2);
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
+
+    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
+    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
+    const int nthreads = block_size;
+    const int nwarps = nthreads / WARP_SIZE;
+    int nreduce = nwarps / WARP_SIZE;
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const uint32_t h = rowx/nrows_y; // head index
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = sycl::pow(base, float(exp));
+    }
+
+    float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
+    float max_val = -INFINITY;
+
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
+
+        vals[col] = val;
+        max_val = sycl::max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val, item_ct1);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = -INFINITY;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        max_val = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
+        }
+        max_val = warp_reduce_max(max_val, item_ct1);
+    }
+
+    float tmp = 0.f;
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+        if (ncols_template == 0 && col >= ncols) {
+            break;
+        }
+
+        const float val = sycl::native::exp(vals[col] - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp, item_ct1);
+    if (block_size > WARP_SIZE) {
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+            for (size_t i = 1; i < nreduce; i += 1)
+                buf[lane_id + i * WARP_SIZE] = 0.f;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+        tmp = buf[lane_id];
+        for (size_t i = 1; i < nreduce; i += 1)
+        {
+            tmp += buf[lane_id + i * WARP_SIZE];
+        }
+        tmp = warp_reduce_sum(tmp, item_ct1);
+    }
+
+    const float inv_sum = 1.f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (ncols_template == 0 && col >= ncols) {
+            return;
+        }
+
+        const int idst = rowx*ncols + col;
+        dst[idst] = vals[col] * inv_sum;
+    }
+}
+
+template <bool vals_smem, int ncols_template, int block_size_template>
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
+                                   const int nrows_y, const float scale, const float max_bias, const float m0,
+                                   const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
+                                   const size_t n_local_scratch, queue_ptr stream) {
+    stream->submit([&](sycl::handler &cgh) {
+        sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
+
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
+                                                                             nrows_y, scale, max_bias, m0,
+                                                                             m1, n_head_log2, item_ct1,
+                                                                             local_buf_acc.get_pointer());
+            });
+    });
+}
+
+static void soft_max_f32_sycl(const float * x, const float * mask,
+                              float * dst, const int ncols_x, const int nrows_x,
+                              const int nrows_y, const float scale, const float max_bias,
+                              queue_ptr stream, int device) {
+    int nth = WARP_SIZE;
+    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
+    while (nth < ncols_x && nth < max_block_size) nth *= 2;
+    if (nth>max_block_size) nth = max_block_size;
+
+    const sycl::range<3> block_dims(1, 1, nth);
+    const sycl::range<3> block_nums(1, 1, nrows_x);
+    const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
+
+    const uint32_t n_head_kv = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
+    if (n_local_scratch*sizeof(float) < local_mem_size) {
+        if (ncols_x > max_block_size) {
+            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                               max_bias, m0, m1, n_head_log2, block_nums,
+                                               block_dims, n_local_scratch, stream);
+            return;
+        }
+        switch (ncols_x) {
+            case 32:
+                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 64:
+                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                     max_bias, m0, m1, n_head_log2, block_nums,
+                                                     block_dims, n_local_scratch, stream);
+                break;
+            case 128:
+                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 256:
+                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 512:
+                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                       max_bias, m0, m1, n_head_log2, block_nums,
+                                                       block_dims, n_local_scratch, stream);
+                break;
+            case 1024:
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 2048:
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            case 4096:
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                         max_bias, m0, m1, n_head_log2, block_nums,
+                                                         block_dims, n_local_scratch, stream);
+                break;
+            default:
+                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                                   max_bias, m0, m1, n_head_log2, block_nums,
+                                                   block_dims, n_local_scratch, stream);
+                break;
+        }
+    } else {
+        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
+                                            max_bias, m0, m1, n_head_log2, block_nums,
+                                            block_dims, WARP_SIZE, stream);
+    }
+}
+
+void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                           const ggml_tensor *src1, ggml_tensor *dst,
+                           const float *src0_dd, const float *src1_dd,
+                           float *dst_dd,
+                           const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src0->ne[1];
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
+                      nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
+}
diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp
new file mode 100644
index 00000000..bdb8f712
--- /dev/null
+++ b/ggml/src/ggml-sycl/softmax.hpp
@@ -0,0 +1,24 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_SOFTMAX_HPP +#define GGML_SYCL_SOFTMAX_HPP + +#include "common.hpp" + +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream); + +#endif // GGML_SYCL_SOFTMAX_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 5e2e8254..d2dccade 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -820,7 +820,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq, #if QK_K == 256 const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; -#if QR2_XXS == 8 const int ib32 = iqs; const uint16_t * q2 = bq2->qs + 4*ib32; const uint8_t * aux8 = (const uint8_t *)q2; @@ -838,26 +837,6 @@ vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq, } const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f; return d * sumi; -#else - // iqs is 0...15 - const int ib32 = iqs/2; - const int il = iqs%2; - const uint16_t * q2 = bq2->qs + 4*ib32; - const uint8_t * aux8 = (const uint8_t *)q2; - const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]); - const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]); - const uint32_t aux32 = q2[2] | (q2[3] << 16); - const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f; - const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127]; - const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127]; - const int8_t * q8 = bq8_1[ib32].qs + 16*il; - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1); - sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1); - } - return d * (sumi1 + sumi2); -#endif #else assert(false); return 0.f;
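
Editorial note (not part of the patch): the soft_max_f32 kernel added in softmax.cpp above computes a scaled, optionally masked softmax with per-head ALiBi slopes, using one work-group per row, warp/work-group reductions for the row max and the exp-sum, and local memory for the intermediate values when it fits. The host-side C++ sketch below restates the same per-row math for reference; the function name soft_max_ref and the row-major layout are illustrative assumptions, not code from this change.

// Minimal reference sketch of the per-row computation parallelized by soft_max_f32.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static void soft_max_ref(const float * x, const float * mask, float * dst,
                         int ncols, int nrows_x, int nrows_y,
                         float scale, float max_bias) {
    // ALiBi slope parameters, derived exactly as in soft_max_f32_sycl.
    const uint32_t n_head      = nrows_x / nrows_y;
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (int rowx = 0; rowx < nrows_x; ++rowx) {
        const int rowy = rowx % nrows_y; // the mask is broadcast across heads

        float slope = 1.0f;
        if (max_bias > 0.0f) {           // per-head ALiBi slope
            const uint32_t h = rowx / nrows_y;
            const float base = h < n_head_log2 ? m0 : m1;
            const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
            slope = powf(base, (float) exp);
        }

        std::vector<float> vals(ncols);
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            vals[c] = x[rowx*ncols + c]*scale + (mask ? slope*mask[rowy*ncols + c] : 0.0f);
            max_val = std::max(max_val, vals[c]);
        }

        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) { // subtract the row max for numerical stability
            vals[c] = expf(vals[c] - max_val);
            sum += vals[c];
        }
        for (int c = 0; c < ncols; ++c) {
            dst[rowx*ncols + c] = vals[c] / sum;
        }
    }
}

The SYCL kernel differs only in how a row is traversed: columns are strided across the work-group, the max and the sum are combined with warp_reduce_max/warp_reduce_sum plus local-memory staging, and vals lives in local memory when the scratch allocation fits the device's local memory limit.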