From 8081e7a23d4f60b9edafc12af24926734372c9b2 Mon Sep 17 00:00:00 2001
From: Svetlozar Georgiev <55534064+sgeor255@users.noreply.github.com>
Date: Thu, 15 May 2025 16:35:44 +0100
Subject: [PATCH] sycl: reordered Q4_K MMVQ (llama/13109)

---
 ggml/src/ggml-sycl/convert.cpp    |  31 ++++++++-
 ggml/src/ggml-sycl/dequantize.hpp |  80 +++++++++++++++------
 ggml/src/ggml-sycl/dmmv.cpp       |   8 ++-
 ggml/src/ggml-sycl/ggml-sycl.cpp  |  80 +++++++++++++++++----
 ggml/src/ggml-sycl/mmvq.cpp       |  31 ++++++++-
 ggml/src/ggml-sycl/quants.hpp     |  22 ++++++
 ggml/src/ggml-sycl/vecdotq.hpp    | 112 ++++++++++++++++++------------
 7 files changed, 280 insertions(+), 84 deletions(-)

diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp
index b2f8a656..75bac98e 100644
--- a/ggml/src/ggml-sycl/convert.cpp
+++ b/ggml/src/ggml-sycl/convert.cpp
@@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
     }
 }
 
+template <typename dst_t>
+static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    const int64_t nb = k / QK_K;
+    const size_t  local_size  = 32;
+    const size_t  global_size = nb * local_size;
+
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
+
+    stream->submit([&](sycl::handler & cgh) {
+        sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
+
+        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
+                         [=](sycl::nd_item<1> item_ct1) {
+                             dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
+                         });
+    });
+}
+
 template <typename dst_t>
 static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
@@ -504,7 +522,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
@@ -556,7 +578,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
         case GGML_TYPE_Q3_K:
             return dequantize_row_q3_K_sycl;
         case GGML_TYPE_Q4_K:
-            return dequantize_row_q4_K_sycl;
+            if (dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) {
+                return dequantize_row_q4_K_sycl_reorder;
+            } else {
+                return dequantize_row_q4_K_sycl;
+            }
         case GGML_TYPE_Q5_K:
             return dequantize_row_q5_K_sycl;
         case GGML_TYPE_Q6_K:
diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp
index 651c2160..64e92f73 100644
--- a/ggml/src/ggml-sycl/dequantize.hpp
+++ b/ggml/src/ggml-sycl/dequantize.hpp
@@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8
 }
 #endif
 
+template <typename dst_t>
+inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall,
+                                   const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) {
+    const int is = 2 * il;
+    constexpr int n  = 4;
+
+    uint8_t sc, m;
+    get_scale_min_k4(is + 0, scales_local, sc, m);
+    const float d1 = dall * sc;
+    const float m1 = dmin * m;
+
+    get_scale_min_k4(is + 1, scales_local, sc, m);
+    const float d2 = dall * sc;
+    const float m2 = dmin * m;
+
+    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(qs_ptr + 32 * il + n * ir);
+    for (int l = 0; l < n; ++l) {
+        y[l + 0]  = d1 * (q_vec[l] & 0xF) - m1;
+        y[l + 32] = d2 * (q_vec[l] >> 4) - m2;
+    }
+}
+
 template<typename dst_t>
 static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
@@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
     const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
-    // assume 32 threads
     const int64_t tid = item_ct1.get_local_id(2);
-    const int64_t il  = tid/8;
-    const int64_t ir  = tid%8;
-    const int64_t is  = 2*il;
-    const int64_t n   = 4;
+    const int64_t il  = tid / 8;
+    const int64_t ir  = tid % 8;
 
-    dst_t * y = yy + i*QK_K + 64*il + n*ir;
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
 
     const sycl::half2 dm = x[i].dm;
     const float dall = dm[0];
     const float dmin = dm[1];
 
-    if (tid < 12)
+    if (tid < 12) {
         scales_local[tid] = x[i].scales[tid];
-    item_ct1.barrier(sycl::access::fence_space::local_space);
-
-    uint8_t sc, m;
-    get_scale_min_k4(is + 0, scales_local, sc, m);
-    const float d1 = dall * sc;
-    const float m1 = dmin * m;
-    get_scale_min_k4(is + 1, scales_local, sc, m);
-    const float d2 = dall * sc;
-    const float m2 = dmin * m;
-
-    sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
-    for (int l = 0; l < n; ++l) {
-        y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
-        y[l +32] = d2 * (q_vec[l] >>  4) - m2;
     }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir);
 #else
     const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
@@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
 #endif
 }
 
+template <typename dst_t>
+static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local,
+                                          const sycl::nd_item<1> & item_ct1, int64_t nb) {
+    const int64_t i   = item_ct1.get_group(0);     // block index
+    const int64_t tid = item_ct1.get_local_id(0);  // thread index within block
+    const int64_t il  = tid / 8;
+    const int64_t ir  = tid % 8;
+
+    dst_t * y = yy + i * QK_K + 64 * il + 4 * ir;
+
+    const uint8_t * base          = static_cast<const uint8_t *>(vx);
+    const size_t    qs_offset     = i * (QK_K / 2);
+    const size_t    scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE;
+    const size_t    dm_offset     = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2);
+
+    const uint8_t *    qs_ptr     = base + qs_offset;
+    const uint8_t *    scales_ptr = base + scales_offset;
+    ggml_half2         dm_values  = *reinterpret_cast<const ggml_half2 *>(base + dm_offset);
+
+    const float dall = dm_values.x();
+    const float dmin = dm_values.y();
+
+    if (tid < 12) {
+        scales_local[tid] = scales_ptr[tid];
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+    dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir);
+}
+
 template<typename dst_t>
 static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
index 04a85fa3..b58150c6 100644
--- a/ggml/src/ggml-sycl/dmmv.cpp
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -1129,7 +1129,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
             dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                // reorder is currently not supported for dmmv
+                GGML_ABORT("Unimplemented dequantize case case for q4_k reorder");
+            } else {
+                dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            }
             break;
         case GGML_TYPE_Q5_K:
             dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 1205fce0..5ff7fa13 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -352,7 +352,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
         assert(tensor->view_src->buffer->buft == buffer->buft);
         return GGML_STATUS_SUCCESS;
     }
-    if (tensor->type == GGML_TYPE_Q4_0 && !g_ggml_sycl_disable_optimize) {
+    if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
         ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
         tensor->extra                 = extra;
         ctx->tensor_extras.push_back(extra);  //used to release it when destroy ctx.
@@ -2900,6 +2900,8 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return true;
+        case GGML_TYPE_Q4_K:
+            return !g_ggml_sycl_prioritize_dmmv;
         default:
             return false;
     }
@@ -2917,6 +2919,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) {
 inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_K:
             return true;
         default:
             return false;
@@ -2942,16 +2945,16 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
     }
 }
 
-static void reorder_qw(char *data_device, const int ncols, const int nrows,
-                size_t size, size_t offset, dpct::queue_ptr stream) {
-    auto tmp_buf = sycl::malloc_shared<char>(size, *stream);
+static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
+                            dpct::queue_ptr stream) {
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
     SYCL_CHECK(
         CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
             .wait()));
     GGML_ASSERT((size % sizeof(block_q4_0) == 0));
     GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
     int offset_blks = offset / sizeof(block_q4_0);
-    auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;
+    auto qs_ptr      = data_device + offset_blks * QK4_0 / 2;
     auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
 
     stream->parallel_for(
@@ -2965,18 +2968,59 @@ static void reorder_qw(char *data_device, const int ncols, const int nrows,
                 *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
             }
             *(d_ptr + ib) = x[ib].d;
-        });
+        }).wait_and_throw();
+
+    sycl::free(tmp_buf, *stream);
+}
+
+static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
+    GGML_ASSERT(size % sizeof(block_q4_K) == 0);
+    GGML_ASSERT(offset % sizeof(block_q4_K) == 0);
+
+    const int nblocks = size / sizeof(block_q4_K);
+
+    auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
+
+    auto * qs_ptr     = data_device;
+    auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
+    auto * dm_ptr     = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
+
+    stream->parallel_for(nblocks, [=](auto i) {
+        const block_q4_K * x  = (const block_q4_K *) tmp_buf;
+        const int          ib = i;
+
+        for (int j = 0; j < QK_K / 2; ++j) {
+            qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j];
+        }
+
+        for (int j = 0; j < K_SCALE_SIZE; ++j) {
+            scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j];
+        }
+
+        dm_ptr[ib] = x[ib].dm;
+    }).wait_and_throw();
 
     sycl::free(tmp_buf, *stream);
 }
 
 static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
-    char*data_device = (char*)src0->data;
+    uint8_t * data_device = (uint8_t *) src0->data;
     size_t ncols = src0->ne[0];
     size_t nrows = src0->ne[1];
     size_t size = ggml_nbytes(src0);
 
-    reorder_qw(data_device, ncols, nrows, size, 0, stream);
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            reorder_qw_q4_k(data_device, size, 0, stream);
+            break;
+        default:
+            GGML_ABORT("reorder_qw() called with unsupported type");
+            break;
+    }
 }
 
 static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) {
@@ -3019,8 +3063,18 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor *
     extra->optimized_feature.reorder = true;  // Used to decode/dequan in next steps and avoid re-reordering
 }
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
+static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+}
+
+static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 &&
+           src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+}
+
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
     int64_t min_compute_capability = INT_MAX;
 
@@ -3043,13 +3097,9 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     }
 
     // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst);
 
-    bool use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst);
 
     bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 3cade1a4..23eeb74d 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -24,6 +24,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
     const int     blocks_per_row              = ncols / block_traits::qk;
     constexpr int blocks_per_subgroup         = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi);
     constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq;
+    const int     nblocks                     = nrows * (ncols / block_traits::qk);
 
     static_assert(blocks_per_subgroup > 0);
     static_assert(block_elements_per_subgroup > 0);
@@ -45,7 +46,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
             // x block quant index when casting the quants to int
             const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);
 
-            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs);
+            partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, &y[iby], iqs, nblocks);
         }
     }
 
@@ -739,6 +740,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     }
 }
 
+static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
+    const int nrows, dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
+
+    const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
+    constexpr size_t num_subgroups = 16;
+    GGML_ASSERT(block_num_y % num_subgroups == 0);
+
+    const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
+    const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
+
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
+                            [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
+                                                                                            nrows, nd_item);
+                            });
+    });
+}
+
+
 static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
                                        const int nrows,
@@ -1035,7 +1057,12 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
                 mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                 break;
             case GGML_TYPE_Q4_K:
-                mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
+                    ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
+                    reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                } else {
+                    mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                }
                 break;
             case GGML_TYPE_Q5_K:
                 mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp
index a74e3052..88ec13ea 100644
--- a/ggml/src/ggml-sycl/quants.hpp
+++ b/ggml/src/ggml-sycl/quants.hpp
@@ -56,6 +56,28 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
     static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
 };
 
+template <> struct block_q_t<GGML_TYPE_Q4_K> {
+    struct traits {
+        static constexpr uint32_t qk       = QK_K;
+        static constexpr uint32_t qi       = QI4_K;
+        static constexpr uint32_t qr       = QR4_K;
+        static constexpr uint32_t vdr_mmvq = 2;
+    };
+
+    static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
+
+    static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
+        auto nblocks = (nrows * (ncols / traits::qk));
+        return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
+    }
+
+    static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
+
+    constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }
+
+    constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
+};
+
 }  // namespace ggml_sycl_reordered
 
 #endif  // GGML_SYCL_QUANTS_HPP
diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp
index cbf664fc..ed369931 100644
--- a/ggml/src/ggml-sycl/vecdotq.hpp
+++ b/ggml/src/ggml-sycl/vecdotq.hpp
@@ -285,7 +285,7 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     }
 
     __dpct_inline__ float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
-                     const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int /* nblocks */) {
         const uint8_t * bq4_0 = static_cast<const uint8_t *>(vbq) + ibx_offset;
         const ggml_half d     = *(reinterpret_cast<const ggml_half *>(static_cast<const uint8_t *>(vbq) + d_offset));
         int             v[q4_0_traits::vdr_mmvq];
@@ -303,6 +303,67 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> {
     };
 };
 
+static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales,
+                                             const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1,
+                                             const int &        iqs) {
+    int   v[2];
+    int   u[2 * QR4_K];
+    float d8[QR4_K];
+
+    v[0] = q4[0];
+    v[1] = q4[4];
+
+    uint16_t  aux[2];
+    const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2;
+    if (j < 2) {
+        aux[0] = scales[j + 0] & 0x3f3f;
+        aux[1] = scales[j + 2] & 0x3f3f;
+    } else {
+        aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2);
+        aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2);
+    }
+
+    const uint8_t * sc = (const uint8_t *) aux;
+    const uint8_t * m  = sc + 2;
+
+    const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+
+    for (int i = 0; i < QR4_K; ++i) {
+        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+        d8[i]                   = bq8i->ds[0];
+
+        const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4);
+        u[2 * i + 0]   = q8[0];
+        u[2 * i + 1]   = q8[4];
+    }
+
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8);
+}
+
+template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> {
+    static constexpr ggml_type gtype = GGML_TYPE_Q4_K;
+
+    using q4_k_block  = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q4_K>;
+    using q4_k_traits = typename q4_k_block::traits;
+
+    float operator()(const void * __restrict__ vbq, const int ibx_offset, const int d_offset,
+                     const block_q8_1 * __restrict__ bq8_1, const int & iqs, int nblocks) {
+        const int ib = ibx_offset / (QK_K / 2);
+
+        const uint8_t *    base           = static_cast<const uint8_t *>(vbq);
+        const uint8_t *    qs             = base + ibx_offset;
+        const int          total_qs_bytes = nblocks * (QK_K / 2);
+        const uint8_t *    scs            = base + total_qs_bytes + ib * K_SCALE_SIZE;
+        const ggml_half2 * dms            = reinterpret_cast<const ggml_half2 *>(base + d_offset);
+
+        const int        bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+        const int *      q4         = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+        const uint16_t * scales     = (const uint16_t *) scs;
+
+        return vec_dot_q4_K_q8_1_common(q4, scales, *dms, bq8_1, iqs);
+    }
+};
+
 #define VDR_Q4_0_Q8_1_MMVQ 2
 #define VDR_Q4_0_Q8_1_MMQ  4
 
@@ -649,52 +710,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq,
     return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
-static __dpct_inline__ float
-vec_dot_q4_K_q8_1(const void *__restrict__ vbq,
-                  const block_q8_1 *__restrict__ bq8_1, const int &iqs) {
-
+static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
+                                               const int & iqs) {
 #ifndef GGML_QKK_64
+
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;
 
-    int    v[2];
-    int    u[2*QR4_K];
-    float d8[QR4_K];
+    const int        bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2));
+    const int *      q4         = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4));
+    const uint16_t * scales     = (const uint16_t *) bq4_K->scales;
 
-    // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
-
-    // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
-    // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
-    // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
-    // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
-
-    const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
-    v[0] = q4[0];
-    v[1] = q4[4];
-
-    const uint16_t * scales = (const uint16_t *)bq4_K->scales;
-    uint16_t aux[2];
-    const int j = bq8_offset/2;
-    if (j < 2) {
-        aux[0] = scales[j+0] & 0x3f3f;
-        aux[1] = scales[j+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m  = sc + 2;
-
-    for (int i = 0; i < QR4_K; ++i) {
-        const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
-        d8[i] = bq8i->ds[0];
-
-        const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
-        u[2*i+0] = q8[0];
-        u[2*i+1] = q8[4];
-    }
-
-    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+    return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs);
 
 #else