diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ff53bdfb..e46007a5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1476,26 +1476,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // spec constants and tile sizes for quant matmul (non-Qi_K)
         l_warptile_mmq = { 256, 128, 256, 64 };
         m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        s_warptile_mmq = { 256, 32, 64, 128 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64 };
+        m_warptile_mmq_k = { 256, 32, 64, 64 };
+        s_warptile_mmq_k = { 256, 32, 32, 128 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
+        l_warptile_mmqid = { 256, 128, 64, 16 };
         m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        s_warptile_mmqid = { 256, 128, 64, 16 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align = 64;
@@ -3850,10 +3850,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
@@ -3898,13 +3902,17 @@ static void ggml_vk_matmul(
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
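For reference, here is a minimal standalone sketch (not part of the patch) of the crossover heuristic the new pipeline-selection code uses: instead of requiring M and N to divide a variant's tile size exactly, the shader size is chosen by comparing N against the next-smaller variant's N tile extent (`wg_denoms[1]`). The `tile_variant` struct, the example `wg_denoms` values, and `pick_tile` below are illustrative assumptions, not names from ggml; the real `ggml_vk_guess_matmul_pipeline` also checks which shader variants the device actually supports.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for a shader variant's workgroup tile extents.
struct tile_variant {
    const char *name;
    uint32_t    wg_denoms[3];   // { tile M, tile N, 1 }
};

// Example values roughly matching the non-Qi_K quant matmul wg_denoms in the diff.
static const tile_variant tile_l = { "large",  { 128, 256, 1 } };
static const tile_variant tile_m = { "medium", { 128, 128, 1 } };
static const tile_variant tile_s = { "small",  {  32,  64, 1 } };

// Pick a variant the way the coopmat2 path now does: large when N exceeds the
// medium tile's N extent, medium when N exceeds the small tile's N extent,
// otherwise small. (Shader-availability checks from the real code are omitted.)
static const tile_variant &pick_tile(uint32_t n) {
    const uint32_t crossover_large  = tile_m.wg_denoms[1];
    const uint32_t crossover_medium = tile_s.wg_denoms[1];
    if (n > crossover_large)  return tile_l;
    if (n > crossover_medium) return tile_m;
    return tile_s;
}

int main() {
    // Small N (e.g. token-by-token generation) now lands on the small tiles,
    // while large batches still get the large shader.
    for (uint32_t n : { 1u, 64u, 65u, 128u, 129u, 512u }) {
        std::printf("n = %3u -> %s\n", n, pick_tile(n).name);
    }
    return 0;
}
```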