sycl: use oneDNN for matrices multiplication (llama/12972)

Łukasz Ślusarczyk 2025-05-15 16:53:41 +02:00 committed by Georgi Gerganov
parent 8e9bf548f4
commit d807c497a4
4 changed files with 189 additions and 95 deletions


@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
 option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
+option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING
     "ggml: sycl target device")
 set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING


@@ -49,34 +49,38 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")

 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
         if (NOT DEFINED DNNL_GPU_VENDOR)
             # default to intel target
             set(DNNL_GPU_VENDOR "INTEL")
             if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
                 message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
             endif()
         endif()

         # Verify oneDNN was compiled for the same target as llama
         if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
             target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
             set(GGML_SYCL_DNNL 1)
             get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
             foreach(CONFIG ${CONFIGS})
                 get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
                 message(STATUS "Found oneDNN: ${DNNL_LIB}")
             endforeach()
         else()
             message(WARNING
                 "oneDNN must be compiled for the same target as llama.cpp.
                 llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
                 Disabling oneDNN support.")
         endif()
+    else()
+        message(STATUS "oneDNN not found, disabling oneDNN support")
+    endif()
 else()
-    message(STATUS "oneDNN not found, disabling oneDNN support")
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})


@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }

-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+            const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+            void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {

         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);

         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

@@ -63,6 +83,15 @@ public:
         matmul_prim.execute(stream, matmul_args);
     }

+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
+
 };

 #endif
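The new gemm() describes each operand with explicit strides instead of the transpose-flag tags used before. To make that bookkeeping concrete, here is a small host-side reference (illustrative only, not part of this commit) that applies the same convention for a single batch: element (r, c) of the logical m x k matrix A sits at a[r*nra + c*nca], element (r, c) of the logical k x n matrix B at b[r*nrb + c*ncb], and C is written densely the way tag::abc stores one batch. The values passed by row_gemm() correspond to nra = k, nca = 1 and nrb = 1, ncb = k, i.e. both inputs column major with k rows.

    // Host-side reference for the strided single-batch case (illustrative only).
    #include <vector>

    static void gemm_ref(int m, int n, int k,
                         const float * a, long nra, long nca,
                         const float * b, long nrb, long ncb,
                         float * c) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int l = 0; l < k; ++l) {
                    acc += a[i * nra + l * nca] * b[l * nrb + j * ncb]; // A(i,l) * B(l,j)
                }
                c[i * n + j] = acc; // dense m x n output, as tag::abc lays out one batch
            }
        }
    }

    int main() {
        const int m = 2, n = 3, k = 4;
        std::vector<float> a(m * k, 1.0f), b(k * n, 2.0f), c(m * n, 0.0f);
        // same strides row_gemm() passes for column-major inputs with k rows
        gemm_ref(m, n, k, a.data(), /*nra=*/k, /*nca=*/1, b.data(), /*nrb=*/1, /*ncb=*/k, c.data());
        return 0; // every element of c ends up as k * 1.0f * 2.0f = 8.0f
    }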


@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;

 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
     g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
     GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
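With this change the oneDNN path is gated twice: GGML_SYCL_DNNL is a 0/1 compile definition produced by the CMake logic earlier in the commit, and GGML_SYCL_DISABLE_DNN is a runtime environment variable stored in g_ggml_sycl_disable_dnn. A standalone sketch of how the two levels combine (the getenv helper below is a simplified stand-in for get_sycl_env, whose definition is not part of this diff):

    // Illustrative only: compile-time gate (GGML_SYCL_DNNL) plus runtime gate (env var).
    #include <cstdio>
    #include <cstdlib>

    #ifndef GGML_SYCL_DNNL
    #define GGML_SYCL_DNNL 0   // the CMake logic defines this to 0 or 1 for the real build
    #endif

    int main() {
        // simplified stand-in for get_sycl_env("GGML_SYCL_DISABLE_DNN", 0)
        const char * v = std::getenv("GGML_SYCL_DISABLE_DNN");
        const int disable_dnn = v ? std::atoi(v) : 0;

    #if GGML_SYCL_DNNL
        if (!disable_dnn) {
            std::puts("mul_mat: oneDNN path");
            return 0;
        }
    #endif
        std::puts("mul_mat: oneMath/MKL fallback path");
        return 0;
    }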
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne00 == ne10);

     const int64_t row_diff = row_high - row_low;

     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));

-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only

 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
             : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);

-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-            DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();

-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-            DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-            dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
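In both the fp16 and fp32 branches the two code paths fill the same column-major dst of shape row_diff x src1_ncols, with dst(i, j) = sum over l of src0(l, i) * src1(l, j): oneMath expresses this as column_major gemm(trans(src0), nontrans(src1)), while the oneDNN wrapper is asked for row_gemm with A = src1 and B = src0, i.e. a row-major src1^T * src0 that occupies the identical bytes. A small standalone check of that equivalence (illustrative only; sizes are arbitrary, not from the patch):

    // Illustrative check that the two formulations above fill the buffer identically.
    #include <cassert>
    #include <vector>

    int main() {
        const int row_diff = 3, src1_ncols = 4, ne10 = 5;

        // src0: ne10 x row_diff, column major; src1: ne10 x src1_ncols, column major
        std::vector<float> src0(ne10 * row_diff), src1(ne10 * src1_ncols);
        for (int i = 0; i < ne10 * row_diff;   ++i) src0[i] = 0.5f  * float(i) - 1.0f;
        for (int i = 0; i < ne10 * src1_ncols; ++i) src1[i] = 0.25f * float(i) + 2.0f;

        // oneMath view: column-major dst(i, j) = sum_l src0(l, i) * src1(l, j)
        std::vector<float> dst_a(row_diff * src1_ncols, 0.0f);
        for (int j = 0; j < src1_ncols; ++j)
            for (int i = 0; i < row_diff; ++i)
                for (int l = 0; l < ne10; ++l)
                    dst_a[i + j * row_diff] += src0[l + i * ne10] * src1[l + j * ne10];

        // oneDNN view: row-major C(j, i) = sum_l src1(l, j) * src0(l, i), C is src1_ncols x row_diff
        std::vector<float> dst_b(src1_ncols * row_diff, 0.0f);
        for (int j = 0; j < src1_ncols; ++j)
            for (int i = 0; i < row_diff; ++i)
                for (int l = 0; l < ne10; ++l)
                    dst_b[j * row_diff + i] += src1[l + j * ne10] * src0[l + i * ne10];

        // same accumulation order, so the two buffers match exactly
        for (int i = 0; i < row_diff * src1_ncols; ++i) assert(dst_a[i] == dst_b[i]);
        return 0;
    }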
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }

-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {

@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);

     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);

     GGML_TENSOR_BINARY_OP_LOCALS

@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     }

     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char * dst_t = reinterpret_cast<char *>(dst_ddf);

     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);

     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;

+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                                                     oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
                                                     src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
                                                     mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
         const int ne23 = ne12 * ne13;

         ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
         ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
         ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);

         sycl::range<3> block_dims(1, ne12, ne13);
         queue->submit([&](sycl::handler & cgh) {
             const void ** ptrs_src_get = ptrs_src.get();
             void ** ptrs_dst_get = ptrs_dst.get();
             size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
             size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
             cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                        nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
         });

         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
             (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
             (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
+    }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
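The batched path above follows ggml's usual mul_mat broadcast convention: src1 carries ne12 x ne13 matrices, src0 carries ne02 x ne03, and src1 batch (i12, i13) pairs with src0 batch (i12 / r2, i13 / r3), where r2 = ne12 / ne02 and r3 = ne13 / ne03. When r2 or r3 is greater than one, the loop over (ie02, ie03) therefore hands oneDNN a single src0 matrix together with a batch of r2*r3 src1 matrices. A standalone enumeration of that pairing (example shapes are arbitrary, not taken from the patch):

    // Illustrative enumeration of ggml's mul_mat broadcast pairing (not from the commit).
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne02 = 2, ne03 = 2;                 // src0 batches along dims 2 and 3
        const int64_t ne12 = 4, ne13 = 2;                 // src1 batches along dims 2 and 3
        const int64_t r2 = ne12 / ne02, r3 = ne13 / ne03; // broadcast factors (2 and 1 here)

        for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
                std::printf("src0[i03=%lld][i02=%lld] multiplies %lld src1 matrices:",
                            (long long) ie03, (long long) ie02, (long long) (r2 * r3));
                // the src1 matrices that broadcast onto this src0 matrix
                for (int64_t i13 = ie03 * r3; i13 < (ie03 + 1) * r3; ++i13) {
                    for (int64_t i12 = ie02 * r2; i12 < (ie02 + 1) * r2; ++i12) {
                        std::printf(" (i12=%lld, i13=%lld)", (long long) i12, (long long) i13);
                    }
                }
                std::printf("\n");
            }
        }
        return 0;
    }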
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }

-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();