From 1e8d692365e44f165980acad9954b945e35b44f0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 15 Aug 2025 17:16:36 +0300 Subject: [PATCH] vulkan : fix out-of-bounds access in argmax kernel (llama/15342) ggml-ci --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c5354293..76a0cfa4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8392,7 +8392,7 @@ static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp index eaf4da34..a1d4c240 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -5,6 +5,8 @@ #extension GL_EXT_control_flow_attributes : enable +#define FLT_MAX 3.402823466e+38F + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -19,19 +21,26 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint col = gl_LocalInvocationID.x; - if (col >= p.KX) { + if (row >= p.KY) { return; } - A_TYPE amax = data_a[row*p.KX + col]; - tmp[col] = col; + + A_TYPE amax = -FLT_MAX; + uint acol = col; + + if (col < p.KX) { + amax = data_a[row*p.KX + col]; + } for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { A_TYPE val = data_a[row*p.KX + i]; if (val > amax) { amax = val; - tmp[col] = i; + acol = i; } } + + tmp[col] = acol; tmpmax[col] = amax; barrier();