From 2f77a9e9bd31e01db57ad4d435d4dc71b77f9434 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Fri, 21 Mar 2025 19:27:47 +0000 Subject: [PATCH] vulkan: workaround for AMD Windows driver 16 bit unpack8 bug (llama/12472) --- .../ggml-vulkan/vulkan-shaders/dequant_funcs.comp | 4 ++-- .../vulkan-shaders/mul_mat_vec_iq2_s.comp | 4 ++-- .../vulkan-shaders/mul_mat_vec_iq3_s.comp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp | 14 +++++++------- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 8835c442..2a162a2c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]); - const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; return vec4(v0.x, v0.y, v1.x, v1.y); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index 9718a05e..8d01536f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const float db = d * (0.5 + scale) * 0.25; const uint qh = data_a[ibi].qh[ib32]; - const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]); - const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]); + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; [[unroll]] for (uint l = 0; l < 2; ++l) { const uint8_t sign = sign16[l]; const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index af48f329..f021e404 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, sum[j] = 0.0; } [[unroll]] for (uint l = 0; l < 4; ++l) { - const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]); + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 const uint sign = data_a[ibi].signs[4 * ib32 + l]; const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 0d03411f..5a0054ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -336,8 +336,8 @@ void main() { const uint iqs = idx & 0x07; const float d = float(data_a_packed16[ib].d); - const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]); - const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); @@ -544,7 +544,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -564,7 +564,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -586,7 +586,7 @@ void main() { const float db = d * 0.25 * (0.5 + scale); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -611,7 +611,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -631,7 +631,7 @@ void main() { const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);