From 905b834af1367d0166d1c3cec61ca0cc3dd35782 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 17 Mar 2025 04:43:35 -0500 Subject: [PATCH] vulkan: use fp32 in coopmat2 q4_k dequant function (llama/12309) --- .../ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 4ccbe613..8efe4653 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -178,7 +178,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 uvec4 v = bl128.block.q4k[0]; - const f16vec2 loadd = unpackFloat2x16(v.x); + const vec2 loadd = vec2(unpackFloat2x16(v.x)); uint32_t sc; uint32_t mbyte; @@ -199,15 +199,15 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 sc &= 0x3F; mbyte &= 0x3F; - const float16_t d = loadd.x * float16_t(sc); - const float16_t m = loadd.y * float16_t(mbyte); + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; - float16_t ret = d * float16_t(qs) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {