mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-18 15:47:08 +02:00
SYCL: Add mrope kernel (llama/13755)
* SYCL: Add mrope kernel * feat: Optimize rope operations with vectorization Uses `sycl::vec` to load and store two elements at a time, significantly improving performance in `rope_norm`, `rope_neox`, and `rope_multi`. This reduces the number of memory accesses and leverages SIMD instructions for faster execution. * Use ceil_div
This commit is contained in:
parent
1893359cfd
commit
f7f92d0aab
@ -4257,14 +4257,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
|
||||||
const int mode = ((const int32_t *) op->op_params)[2];
|
|
||||||
// mode is not used as a bitmask in practice, the various rope type modes are independent implementations
|
|
||||||
if (mode == GGML_ROPE_TYPE_MROPE) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
case GGML_OP_IM2COL:
|
case GGML_OP_IM2COL:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
|
@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
|
|||||||
|
|
||||||
if (i0 >= n_dims) {
|
if (i0 >= n_dims) {
|
||||||
const int i = row * ne0 + i0;
|
const int i = row * ne0 + i0;
|
||||||
|
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||||
dst[i + 0] = x[i + 0];
|
|
||||||
dst[i + 1] = x[i + 1];
|
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|||||||
|
|
||||||
if (i0 >= n_dims) {
|
if (i0 >= n_dims) {
|
||||||
const int i = row * ne0 + i0;
|
const int i = row * ne0 + i0;
|
||||||
|
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||||
dst[i + 0] = x[i + 0];
|
|
||||||
dst[i + 1] = x[i + 1];
|
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
|
|||||||
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, bool has_ff>
|
||||||
|
static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||||
|
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||||
|
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
|
||||||
|
const float theta_scale, const float * freq_factors, const mrope_sections sections,
|
||||||
|
const sycl::nd_item<3> & item_ct1) {
|
||||||
|
// get index pos
|
||||||
|
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
|
||||||
|
if (i0 >= ne0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
|
||||||
|
|
||||||
|
if (i0 >= n_dims) {
|
||||||
|
const int i = row_dst*ne0 + i0;
|
||||||
|
*reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int row_x = row_dst % ne1;
|
||||||
|
const int channel_x = row_dst / ne1;
|
||||||
|
const int idst = (row_dst * ne0) + (i0 / 2);
|
||||||
|
const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
|
||||||
|
|
||||||
|
const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
|
||||||
|
const int sec_w = sections.v[1] + sections.v[0];
|
||||||
|
const int sector = (i0 / 2) % sect_dims;
|
||||||
|
|
||||||
|
|
||||||
|
float theta_base = 0.0;
|
||||||
|
if (sector < sections.v[0]) {
|
||||||
|
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= sections.v[0] && sector < sec_w) {
|
||||||
|
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
|
||||||
|
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
else if (sector >= sec_w + sections.v[2]) {
|
||||||
|
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
|
||||||
|
float cos_theta;
|
||||||
|
float sin_theta;
|
||||||
|
rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
|
||||||
|
const float x0 = x[ix + 0];
|
||||||
|
const float x1 = x[ix + n_dims/2];
|
||||||
|
|
||||||
|
// store results in dst
|
||||||
|
dst[idst + 0] = x0 * cos_theta - x1 * sin_theta;
|
||||||
|
dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T, bool has_ff>
|
template <typename T, bool has_ff>
|
||||||
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||||
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
|
||||||
@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|||||||
const float * freq_factors, queue_ptr stream) {
|
const float * freq_factors, queue_ptr stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne0 % 2 == 0);
|
||||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||||
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|||||||
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne0 % 2 == 0);
|
||||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||||
const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||||
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
const sycl::range<3> block_nums(1, num_blocks_x, nr);
|
||||||
|
|
||||||
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
const float theta_scale = powf(freq_base, -2.0f / n_dims);
|
||||||
@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||||
|
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
|
||||||
|
const float freq_scale, const float freq_base, const float ext_factor,
|
||||||
|
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
|
||||||
|
const mrope_sections sections, queue_ptr stream) {
|
||||||
|
GGML_ASSERT(ne0 % 2 == 0);
|
||||||
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||||
|
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||||
|
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||||
|
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||||
|
|
||||||
|
const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
|
||||||
|
// Add FP16 capability check if T could be sycl::half
|
||||||
|
if constexpr (std::is_same_v<T, sycl::half>) {
|
||||||
|
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
|
||||||
|
}
|
||||||
|
// launch kernel
|
||||||
|
if (freq_factors == nullptr) {
|
||||||
|
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||||
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
|
||||||
|
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
|
||||||
|
corr_dims, theta_scale, freq_factors, sections, item_ct1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// rope vision
|
// rope vision
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
|
||||||
@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
|
|||||||
const mrope_sections sections, queue_ptr stream) {
|
const mrope_sections sections, queue_ptr stream) {
|
||||||
GGML_ASSERT(ne0 % 2 == 0);
|
GGML_ASSERT(ne0 % 2 == 0);
|
||||||
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
|
||||||
const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
|
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
|
||||||
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
const sycl::range<3> grid_dims(1, n_blocks_y, nr);
|
||||||
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
|
||||||
|
|
||||||
@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|||||||
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
||||||
|
|
||||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||||
|
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||||
|
|
||||||
|
if (is_mrope) {
|
||||||
|
GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_vision) {
|
||||||
|
GGML_ASSERT(n_dims == ne00/2);
|
||||||
|
}
|
||||||
|
|
||||||
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
const int32_t * pos = (const int32_t *) dst->src[1]->data;
|
||||||
|
|
||||||
const float * freq_factors = nullptr;
|
const float * freq_factors = nullptr;
|
||||||
@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
|
|||||||
} else {
|
} else {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
} else if (is_mrope && !is_vision) {
|
||||||
|
GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
|
||||||
|
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||||
|
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
|
||||||
|
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
|
||||||
|
freq_factors, sections, main_stream);
|
||||||
|
} else if (dst->src[0]->type == GGML_TYPE_F32) {
|
||||||
|
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
|
||||||
|
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
|
||||||
|
main_stream);
|
||||||
|
} else {
|
||||||
|
GGML_ABORT("Fatal error: Tensor type unsupported!");
|
||||||
|
}
|
||||||
} else if (is_vision) {
|
} else if (is_vision) {
|
||||||
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
GGML_SYCL_DEBUG("%s: vision path\n", __func__);
|
||||||
if (dst->src[0]->type == GGML_TYPE_F16) {
|
if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user