mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-20 09:47:59 +02:00
sycl: use oneDNN for matrices multiplication (llama/12972)
This commit is contained in:
parent
8e9bf548f4
commit
d807c497a4
@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
|
|||||||
option(GGML_SYCL "ggml: use SYCL" OFF)
|
option(GGML_SYCL "ggml: use SYCL" OFF)
|
||||||
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
|
||||||
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
|
||||||
|
option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
|
||||||
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
|
||||||
"ggml: sycl target device")
|
"ggml: sycl target device")
|
||||||
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
|
||||||
|
@ -49,34 +49,38 @@ endif()
|
|||||||
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
|
||||||
|
|
||||||
# Link against oneDNN
|
# Link against oneDNN
|
||||||
find_package(DNNL)
|
|
||||||
set(GGML_SYCL_DNNL 0)
|
set(GGML_SYCL_DNNL 0)
|
||||||
if(DNNL_FOUND)
|
if(GGML_SYCL_DNN)
|
||||||
if (NOT DEFINED DNNL_GPU_VENDOR)
|
find_package(DNNL)
|
||||||
# default to intel target
|
if(DNNL_FOUND)
|
||||||
set(DNNL_GPU_VENDOR "INTEL")
|
if (NOT DEFINED DNNL_GPU_VENDOR)
|
||||||
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
# default to intel target
|
||||||
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
set(DNNL_GPU_VENDOR "INTEL")
|
||||||
|
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
|
||||||
|
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
endif()
|
|
||||||
|
|
||||||
# Verify oneDNN was compiled for the same target as llama
|
# Verify oneDNN was compiled for the same target as llama
|
||||||
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
|
||||||
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
|
||||||
set(GGML_SYCL_DNNL 1)
|
set(GGML_SYCL_DNNL 1)
|
||||||
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
|
||||||
foreach(CONFIG ${CONFIGS})
|
foreach(CONFIG ${CONFIGS})
|
||||||
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
|
||||||
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
message(STATUS "Found oneDNN: ${DNNL_LIB}")
|
||||||
endforeach()
|
endforeach()
|
||||||
|
else()
|
||||||
|
message(WARNING
|
||||||
|
"oneDNN must be compiled for the same target as llama.cpp.
|
||||||
|
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
||||||
|
Disabling oneDNN support.")
|
||||||
|
endif()
|
||||||
else()
|
else()
|
||||||
message(WARNING
|
message(STATUS "oneDNN not found, disabling oneDNN support")
|
||||||
"oneDNN must be compiled for the same target as llama.cpp.
|
|
||||||
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
|
|
||||||
Disabling oneDNN support.")
|
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
message(STATUS "oneDNN not found, disabling oneDNN support")
|
message(STATUS "oneDNN support disabled by the user")
|
||||||
endif()
|
endif()
|
||||||
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
|
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
|
||||||
|
|
||||||
|
@ -32,16 +32,36 @@ public:
|
|||||||
else static_assert(0);
|
else static_assert(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
|
// matrix A has m rows, k columns
|
||||||
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
// matrix B has k rows, n columns
|
||||||
|
// nra - number of elements to skip when moving into next row in A
|
||||||
|
// nrb - number of elements to skip when moving into next row in B
|
||||||
|
// nca - number of elements to skip when moving into next column in A
|
||||||
|
// ncb - number of elements to skip when moving into next column in B
|
||||||
|
// stride_a - number of elements to skip when moving to next A matrix
|
||||||
|
// stride_b - number of elements to skip when moving to next B matrix
|
||||||
|
// batches_a - number of A matrices
|
||||||
|
// batches_b - number of B matrices
|
||||||
|
static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||||
|
const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
|
||||||
|
const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
|
||||||
|
void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
|
||||||
|
|
||||||
auto stream = ctx.stream_dnnl(q);
|
auto stream = ctx.stream_dnnl(q);
|
||||||
auto eng = ctx.engine_dnnl(q);
|
auto eng = ctx.engine_dnnl(q);
|
||||||
dnnl::memory::dims a_dims = { m, k };
|
|
||||||
dnnl::memory::dims b_dims = { k, n };
|
// { # batches, # rows, # columns }
|
||||||
dnnl::memory::dims c_dims = { m, n };
|
dnnl::memory::dims a_dims = { batches_a, m, k };
|
||||||
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
dnnl::memory::dims b_dims = { batches_b, k, n };
|
||||||
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
|
||||||
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
|
||||||
|
// { # elements to skip to next matrix, # elements to skip to next row, # elements to skip to next column }
|
||||||
|
dnnl::memory::dims a_strides = { stride_a, nra, nca };
|
||||||
|
dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
|
||||||
|
|
||||||
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
|
||||||
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
|
||||||
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
|
||||||
|
|
||||||
dnnl::primitive_attr primitive_attr;
|
dnnl::primitive_attr primitive_attr;
|
||||||
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||||
@ -63,6 +83,15 @@ public:
|
|||||||
|
|
||||||
matmul_prim.execute(stream, matmul_args);
|
matmul_prim.execute(stream, matmul_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// matrices A and B are column major, both having k rows
|
||||||
|
// matrix A has m columns, matrix B has n columns
|
||||||
|
// output: column major matrix C = A transposed * B
|
||||||
|
static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
|
||||||
|
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
|
||||||
|
|
||||||
|
gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
|
|||||||
int g_ggml_sycl_debug = 0;
|
int g_ggml_sycl_debug = 0;
|
||||||
int g_ggml_sycl_disable_optimize = 0;
|
int g_ggml_sycl_disable_optimize = 0;
|
||||||
int g_ggml_sycl_disable_graph = 0;
|
int g_ggml_sycl_disable_graph = 0;
|
||||||
|
int g_ggml_sycl_disable_dnn = 0;
|
||||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||||
|
|
||||||
static ggml_sycl_device_info ggml_sycl_init() {
|
static ggml_sycl_device_info ggml_sycl_init() {
|
||||||
@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
|
|||||||
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
|
||||||
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
|
||||||
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
|
||||||
|
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
|
||||||
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
|
||||||
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
|
||||||
GGML_LOG_INFO("Running with Environment Variables:\n");
|
GGML_LOG_INFO("Running with Environment Variables:\n");
|
||||||
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
|
||||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
|
||||||
|
#ifdef GGML_SYCL_GRAPH
|
||||||
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
|
||||||
|
#else
|
||||||
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
|
||||||
|
#endif
|
||||||
|
#if GGML_SYCL_DNNL
|
||||||
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
|
||||||
|
#else
|
||||||
|
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
|
||||||
|
#endif
|
||||||
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
|
||||||
GGML_LOG_INFO("Build with Macros:\n");
|
GGML_LOG_INFO("Build with Macros:\n");
|
||||||
#if defined(GGML_SYCL_FORCE_MMQ)
|
#if defined(GGML_SYCL_FORCE_MMQ)
|
||||||
@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t ne10 = src1->ne[0];
|
const int64_t ne10 = src1->ne[0];
|
||||||
|
GGML_ASSERT(ne00 == ne10);
|
||||||
|
|
||||||
const int64_t row_diff = row_high - row_low;
|
const int64_t row_diff = row_high - row_low;
|
||||||
|
|
||||||
int id;
|
int id;
|
||||||
SYCL_CHECK(
|
SYCL_CHECK(
|
||||||
CHECK_TRY_ERROR(id = get_current_device_id()));
|
CHECK_TRY_ERROR(id = get_current_device_id()));
|
||||||
#if !GGML_SYCL_DNNL
|
|
||||||
const int64_t ne0 = dst->ne[0];
|
const int64_t ne0 = dst->ne[0]; // used by MKL only
|
||||||
// the main device has a larger memory buffer to hold the results from all GPUs
|
// the main device has a larger memory buffer to hold the results from all GPUs
|
||||||
// ldc == nrows of the matrix that cuBLAS writes into
|
// ldc == nrows of the matrix that cuBLAS writes into
|
||||||
int ldc = id == ctx.device ? ne0 : row_diff;
|
int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef GGML_SYCL_F16
|
#ifdef GGML_SYCL_F16
|
||||||
bool use_fp16 = true; // TODO(Yu) SYCL capability check
|
bool use_fp16 = true; // TODO(Yu) SYCL capability check
|
||||||
@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|||||||
: src1_as_f16.get();
|
: src1_as_f16.get();
|
||||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
|
||||||
|
|
||||||
#if !GGML_SYCL_DNNL
|
#if GGML_SYCL_DNNL
|
||||||
const sycl::half alpha_f16 = 1.0f;
|
if (!g_ggml_sycl_disable_dnn) {
|
||||||
const sycl::half beta_f16 = 0.0f;
|
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||||
*stream, oneapi::math::transpose::trans,
|
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
||||||
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||||
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||||
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
}
|
||||||
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
else
|
||||||
dpct::library_data_t::real_half)));
|
|
||||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
|
||||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
|
||||||
#else
|
|
||||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
|
|
||||||
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
|
||||||
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
|
|
||||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
|
||||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
|
||||||
#endif
|
#endif
|
||||||
|
{
|
||||||
|
const sycl::half alpha_f16 = 1.0f;
|
||||||
|
const sycl::half beta_f16 = 0.0f;
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||||
|
*stream, oneapi::math::transpose::trans,
|
||||||
|
oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
|
&alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
|
||||||
|
src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
|
||||||
|
dst_f16.get(), dpct::library_data_t::real_half, ldc,
|
||||||
|
dpct::library_data_t::real_half)));
|
||||||
|
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
|
||||||
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
||||||
@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
|||||||
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
|
||||||
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
|
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
|
||||||
|
|
||||||
#if !GGML_SYCL_DNNL
|
#if GGML_SYCL_DNNL
|
||||||
const float alpha = 1.0f;
|
if (!g_ggml_sycl_disable_dnn) {
|
||||||
const float beta = 0.0f;
|
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
||||||
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
||||||
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
}
|
||||||
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
else
|
||||||
#else
|
|
||||||
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
|
|
||||||
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
|
|
||||||
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
|
|
||||||
#endif
|
#endif
|
||||||
|
{
|
||||||
|
const float alpha = 1.0f;
|
||||||
|
const float beta = 0.0f;
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
|
||||||
|
get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
|
||||||
|
src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
|
||||||
|
dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
GGML_UNUSED(dst);
|
GGML_UNUSED(dst);
|
||||||
GGML_UNUSED(src1_ddq_i);
|
GGML_UNUSED(src1_ddq_i);
|
||||||
@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
|
|||||||
std::exit(1);
|
std::exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
|
static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
|
||||||
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
|
const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
|
||||||
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
|
size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
|
||||||
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
|
int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
|
||||||
@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
|
|||||||
|
|
||||||
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
|
const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
|
||||||
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
|
const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
|
||||||
uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
|
uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
|
||||||
|
|
||||||
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
|
ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
|
||||||
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
|
ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
|
||||||
@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|||||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||||
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
||||||
char * dst_t = reinterpret_cast<char *>(dst_ddf);
|
|
||||||
|
|
||||||
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
|
dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
|
||||||
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
|
dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
|
||||||
@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
|
|||||||
|
|
||||||
GGML_ASSERT(ne12 % ne02 == 0);
|
GGML_ASSERT(ne12 % ne02 == 0);
|
||||||
GGML_ASSERT(ne13 % ne03 == 0);
|
GGML_ASSERT(ne13 % ne03 == 0);
|
||||||
|
GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
|
||||||
|
GGML_ASSERT(ne10 == ne00);
|
||||||
|
|
||||||
// broadcast factors
|
// broadcast factors
|
||||||
const int64_t r2 = ne12 / ne02;
|
const int64_t r2 = ne12 / ne02;
|
||||||
const int64_t r3 = ne13 / ne03;
|
const int64_t r3 = ne13 / ne03;
|
||||||
|
|
||||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
#if GGML_SYCL_DNNL
|
||||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
if (!g_ggml_sycl_disable_dnn) {
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
|
||||||
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
(const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
|
||||||
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
|
||||||
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
|
|
||||||
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
|
||||||
} else {
|
|
||||||
const int ne23 = ne12 * ne13;
|
|
||||||
|
|
||||||
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
|
||||||
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
|
||||||
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
|
||||||
|
dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
|
||||||
|
};
|
||||||
|
|
||||||
sycl::range<3> block_dims(1, ne12, ne13);
|
if (r2 == 1 && r3 == 1) {
|
||||||
queue->submit([&](sycl::handler & cgh) {
|
if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||||
const void ** ptrs_src_get = ptrs_src.get();
|
dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
|
||||||
void ** ptrs_dst_get = ptrs_dst.get();
|
}
|
||||||
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
else {
|
||||||
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||||
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
|
||||||
k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
|
||||||
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
|
||||||
|
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// iterate over batches from smaller set of matrices (matrix 0)
|
||||||
|
for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
|
||||||
|
for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
|
||||||
|
const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
|
||||||
|
const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
|
||||||
|
float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
|
||||||
|
dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||||
|
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||||
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
|
||||||
|
oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||||
|
src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
|
||||||
|
src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
|
||||||
|
mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
|
||||||
|
} else {
|
||||||
|
const int ne23 = ne12 * ne13;
|
||||||
|
|
||||||
|
ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
|
||||||
|
ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
|
||||||
|
ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
|
||||||
|
|
||||||
|
sycl::range<3> block_dims(1, ne12, ne13);
|
||||||
|
queue->submit([&](sycl::handler & cgh) {
|
||||||
|
const void ** ptrs_src_get = ptrs_src.get();
|
||||||
|
void ** ptrs_dst_get = ptrs_dst.get();
|
||||||
|
size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
|
||||||
|
size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
|
||||||
|
cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
|
||||||
|
nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
|
||||||
|
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
|
||||||
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
*queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
|
||||||
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
(const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
|
||||||
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
(const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
|
||||||
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
(void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} catch (const sycl::exception & exc) {
|
} catch (const sycl::exception & exc) {
|
||||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
|
||||||
@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
|
|||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
|
sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
|
||||||
|
|
||||||
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
|
||||||
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
|
||||||
model_sycl_graph.end_recording();
|
model_sycl_graph.end_recording();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user