sycl: use oneDNN for matrix multiplication (llama/12972)

Łukasz Ślusarczyk 2025-05-15 16:53:41 +02:00 committed by Georgi Gerganov
parent 8e9bf548f4
commit d807c497a4
4 changed files with 189 additions and 95 deletions


@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
 option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
+option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING
     "ggml: sycl target device")
 set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING


@@ -49,9 +49,10 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")

 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
         if (NOT DEFINED DNNL_GPU_VENDOR)
             # default to intel target
             set(DNNL_GPU_VENDOR "INTEL")
@@ -75,8 +76,11 @@ if(DNNL_FOUND)
             llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
             Disabling oneDNN support.")
         endif()
     else()
         message(STATUS "oneDNN not found, disabling oneDNN support")
+    endif()
+else()
+    message(STATUS "oneDNN support disabled by the user")
 endif()

 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
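The net effect of the new switches is a two-level gate: the CMake option GGML_SYCL_DNN decides whether the oneDNN paths are compiled at all (by setting the GGML_SYCL_DNNL define), while the GGML_SYCL_DISABLE_DNN environment variable introduced later in this commit can still turn them off at run time. A minimal sketch of that combination, assuming a getenv-based parse similar to the backend's existing get_sycl_env helper (illustrative only, the helper name below is hypothetical):

    // Sketch: how the compile-time define from CMake and the runtime environment
    // variable combine. GGML_SYCL_DNNL is 1 only when GGML_SYCL_DNN=ON and oneDNN
    // was found; GGML_SYCL_DISABLE_DNN is read from the environment.
    #include <cstdlib>

    #ifndef GGML_SYCL_DNNL
    #define GGML_SYCL_DNNL 0
    #endif

    static bool use_onednn_gemm() {       // hypothetical helper for illustration
    #if GGML_SYCL_DNNL
        const char * v = std::getenv("GGML_SYCL_DISABLE_DNN");
        const bool disabled = (v != nullptr) && (std::atoi(v) != 0);
        return !disabled;                 // compiled in and not disabled at run time
    #else
        return false;                     // oneDNN paths were not compiled in
    #endif
    }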


@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }

-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+                     const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+                     const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+                     void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);

-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);

         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

@@ -63,6 +83,15 @@ public:
         matmul_prim.execute(stream, matmul_args);
     }

+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+                         const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
 };

 #endif
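To make the new stride bookkeeping concrete: row_gemm() describes A as a column-major k x m buffer read as an m x k logical matrix (nra = k, nca = 1), B as a column-major k x n buffer read as a k x n logical matrix (nrb = 1, ncb = k), and a single batch of each (stride_a = k*m, stride_b = n*k). The standalone snippet below reproduces that indexing with a plain CPU loop and hypothetical sizes, so the stride arguments can be checked by hand (illustrative only, not part of the commit):

    #include <cstdio>
    #include <vector>

    int main() {
        const int m = 3, n = 2, k = 4;                        // hypothetical sizes
        std::vector<float> A(k * m), B(k * n), C(m * n, 0.0f);
        for (int i = 0; i < k * m; ++i) A[i] = 1.0f + i;      // column-major k x m
        for (int i = 0; i < k * n; ++i) B[i] = 0.5f * i;      // column-major k x n

        // Same addressing as the oneDNN descriptors built by row_gemm():
        //   A_logical[i][p] at i*nra + p*nca = i*k + p   (column i of A)
        //   B_logical[p][j] at p*nrb + j*ncb = p + j*k   (column j of B)
        //   C[i][j] written row-major at i*n + j         (tag::abc, one batch)
        for (int i = 0; i < m; ++i)
            for (int j = 0; j < n; ++j)
                for (int p = 0; p < k; ++p)
                    C[i * n + j] += A[i * k + p] * B[j * k + p];

        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) std::printf("%7.1f", C[i * n + j]);
            std::printf("\n");                                // prints C = A^T * B
        }
        return 0;
    }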


@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;

 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
     g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
     GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
+    GGML_ASSERT(ne00 == ne10);

     const int64_t row_diff = row_high - row_low;

     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+    const int64_t ne0 = dst->ne[0]; // used by MKL only

     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only

 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,7 +2043,17 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                              : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);

-#if !GGML_SYCL_DNNL
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                                      DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
+#endif
+        {
         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16 = 0.0f;
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
@@ -2045,13 +2065,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
-#endif
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();

-#if !GGML_SYCL_DNNL
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                      DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
+#endif
+        {
         const float alpha = 1.0f;
         const float beta = 0.0f;
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
             get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
             src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
             dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
-#endif
+        }
     }

     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
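Both branches compute the same values; only the description of the operands changes. The oneMKL call produces a column-major dst of shape row_diff x src1_ncols as src0^T * src1, while the new path calls row_gemm() with m = src1_ncols, n = row_diff and the operands swapped, so the logical result is the transpose but it lands on the same memory offsets. A small standalone self-check of that equivalence with hypothetical sizes (illustrative only, assuming ldc == row_diff as in the single-device case):

    #include <cassert>
    #include <vector>

    int main() {
        const int row_diff = 4, src1_ncols = 3, ne10 = 5;               // hypothetical sizes
        std::vector<float> src0(ne10 * row_diff), src1(ne10 * src1_ncols);
        for (size_t i = 0; i < src0.size(); ++i) src0[i] = 0.25f * i;   // column-major ne10 x row_diff
        for (size_t i = 0; i < src1.size(); ++i) src1[i] = 1.0f + i;    // column-major ne10 x src1_ncols

        std::vector<float> dst_mkl(row_diff * src1_ncols, 0.0f), dst_dnn(dst_mkl);

        // old path: column-major dst(r, c) at c*ldc + r, dst = src0^T * src1
        for (int r = 0; r < row_diff; ++r)
            for (int c = 0; c < src1_ncols; ++c)
                for (int p = 0; p < ne10; ++p)
                    dst_mkl[c * row_diff + r] += src0[r * ne10 + p] * src1[c * ne10 + p];

        // new path: row_gemm(m = src1_ncols, n = row_diff, k = ne10, a = src1, b = src0)
        // writes logical element (i, j) = dot(column i of src1, column j of src0) at i*n + j
        for (int i = 0; i < src1_ncols; ++i)
            for (int j = 0; j < row_diff; ++j)
                for (int p = 0; p < ne10; ++p)
                    dst_dnn[i * row_diff + j] += src1[i * ne10 + p] * src0[j * ne10 + p];

        assert(dst_mkl == dst_dnn);   // identical buffers
        return 0;
    }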
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);

     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);

     GGML_TENSOR_BINARY_OP_LOCALS
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     }

     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char * dst_t = reinterpret_cast<char *>(dst_ddf);

     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,17 +2801,57 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);

     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;

+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+                DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
+                    src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                    src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                    dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+            };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                                                     oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
                                                     src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
+                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
                                                     mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
     } else {
         const int ne23 = ne12 * ne13;
@@ -2809,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
             size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
             cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                        nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });
         });
@@ -2820,6 +2878,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
                                       (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
                                       (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
     }
+    }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
     std::exit(1);
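The bookkeeping behind the oneDNN branch: with broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03, every src0 matrix (ie02, ie03) is shared by r2*r3 consecutive src1 matrices, which is why the fallback loop above calls dnn_gemm(..., r2*r3, 1) once per src0 matrix and lets oneDNN broadcast the single B batch. A tiny standalone sketch of that mapping with hypothetical batch shapes, using the same i02 = i12 / r2, i03 = i13 / r3 rule that k_compute_batched_ptrs applies (illustrative only):

    #include <cstdio>

    int main() {
        const int ne02 = 2, ne03 = 1;   // hypothetical src0 batch dims
        const int ne12 = 6, ne13 = 2;   // hypothetical src1 batch dims
        const int r2 = ne12 / ne02, r3 = ne13 / ne03;

        for (int i13 = 0; i13 < ne13; ++i13)
            for (int i12 = 0; i12 < ne12; ++i12) {
                // each src1 batch picks the src0 batch it is multiplied against
                const int i03 = i13 / r3;
                const int i02 = i12 / r2;
                std::printf("src1 batch (%d,%d) -> src0 batch (%d,%d)\n", i12, i13, i02, i03);
            }
        return 0;
    }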
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }

-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()),
+                                            {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();
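assume_buffer_outlives_graph is a property from the SYCL Graph extension (sycl_ext_oneapi_graph): the application promises that any sycl::buffer recorded into the graph outlives the graph object, so the runtime can skip the corresponding lifetime tracking. A minimal record-and-replay sketch using that property, based on the extension API as I understand it (simplified and illustrative, not the backend's code):

    #include <sycl/sycl.hpp>
    namespace sycl_ex = sycl::ext::oneapi::experimental;

    void run_twice(sycl::queue & q, sycl::buffer<float, 1> & buf) {
        // Promise: `buf` will not be destroyed before `graph`.
        sycl_ex::command_graph graph(q, {sycl_ex::property::graph::assume_buffer_outlives_graph{}});

        graph.begin_recording(q);
        q.submit([&](sycl::handler & cgh) {
            auto acc = buf.get_access<sycl::access_mode::read_write>(cgh);
            cgh.parallel_for(sycl::range<1>(buf.size()), [=](sycl::id<1> i) { acc[i] += 1.0f; });
        });
        graph.end_recording();

        auto exec = graph.finalize();        // build the executable graph once
        q.ext_oneapi_graph(exec).wait();     // ...then replay it as often as needed
        q.ext_oneapi_graph(exec).wait();
    }

    int main() {
        sycl::queue q;
        sycl::buffer<float, 1> buf(sycl::range<1>(1024));
        run_twice(q, buf);
        return 0;
    }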