diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt
index 969a178f..c9ff4aa3 100644
--- a/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ggml/src/ggml-cuda/CMakeLists.txt
@@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_CXX_FLAGS "")
 
-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 919217df..64fb4ff4 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -296,6 +296,25 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+// The compiler is always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh
index ecb65999..63d0c482 100644
--- a/ggml/src/ggml-cuda/cp-async.cuh
+++ b/ggml/src/ggml-cuda/cp-async.cuh
@@ -2,6 +2,17 @@
 
 #include "common.cuh"
 
+
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+    return __cvta_generic_to_shared(generic_ptr);
+#else
+    GGML_UNUSED(generic_ptr);
+    NO_DEVICE_CODE;
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
 // Copies data from global to shared memory, cg == cache global.
 // Both the src and dst pointers must be aligned to 16 bit.
 // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index c7dc7288..b7180d59 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
     const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -691,7 +691,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -754,10 +754,13 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -769,14 +772,11 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -853,19 +853,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 04804a15..2b6bdc30 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -13,104 +13,217 @@ typedef tile<16, 16, float> tile_C_KQ_16;
 typedef tile<16,  4, half2> tile_C_VKQ;
 typedef tile<16,  8, half2> tile_C_VKQ_16;
 
-template<int D, int nwarps, int KQ_per_iter>
+// Config options for specific head sizes.
+// Should not affect results, only speed/register pressure/shared memory use.
+//
+// nbatch_fa:      number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+// nwarps_max:     maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
+// Q_in_reg:       whether the Q values should be kept permanently in registers.
+// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
+// nbatch_K2:      number of K half2 values in direction of DKQ to load in parallel.
+// nbatch_V2:      number of V half2 values in direction of DV to load in parallel.
+// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
+
+template <int DKQ, int DV>
+struct fattn_mma_f16_config;
+
+template <>
+struct fattn_mma_f16_config< 64,  64> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 32;
+    static constexpr int  nbatch_V2      = 32;
+    static constexpr int  nbatch_combine = 32;
+};
+
+template <>
+struct fattn_mma_f16_config< 80,  80> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 40;
+    static constexpr int  nbatch_V2      = 40;
+    static constexpr int  nbatch_combine = 40;
+};
+
+template <>
+struct fattn_mma_f16_config< 96,  96> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 48;
+    static constexpr int  nbatch_V2      = 48;
+    static constexpr int  nbatch_combine = 48;
+};
+
+template <>
+struct fattn_mma_f16_config<112, 112> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 56;
+    static constexpr int  nbatch_V2      = 56;
+    static constexpr int  nbatch_combine = 56;
+};
+
+template <>
+struct fattn_mma_f16_config<128, 128> {
+    static constexpr int  nbatch_fa      = 64;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 64;
+    static constexpr int  nbatch_V2      = 64;
+    static constexpr int  nbatch_combine = 64;
+};
+
+template <>
+struct fattn_mma_f16_config<256, 256> {
+    static constexpr int  nbatch_fa      = 32;
+    static constexpr int  nwarps_max     = 4;
+    static constexpr bool Q_in_reg       = true;
+    static constexpr int  nstages_target = 2;
+    static constexpr int  nbatch_K2      = 128;
+    static constexpr int  nbatch_V2      = 128;
+    static constexpr int  nbatch_combine = 128;
+};
+
+template <>
+struct fattn_mma_f16_config<576, 512> {
+    static constexpr int  nbatch_fa      = 32;
+    static constexpr int  nwarps_max     = 8;
+    static constexpr bool Q_in_reg       = false;
+    static constexpr int  nstages_target = 1;
+    static constexpr int  nbatch_K2      = 160;
+    static constexpr int  nbatch_V2      = 128;
+    static constexpr int  nbatch_combine = 128;
+};
+
+// ------------------------------------------------------------------------------------------------------------------
+
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
-        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
-    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
 
-    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
-#ifdef CP_ASYNC_AVAILABLE
-    static_assert(D >= 64 && D < 512, "bad D");
-    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
-
-    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
-
-    constexpr int preload = 64;
-    constexpr int h2_per_chunk = 16/sizeof(half2);
-    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
-    constexpr int stride_i = WARP_SIZE / chunks_per_row;
-#pragma unroll
-    for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
-        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
-        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
-
-        cp_async_cg_16<preload>(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
-    }
-#else
-    constexpr int k0_sync_start = 0;
-#endif // CP_ASYNC_AVAILABLE
-    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
-
-    // If D is not a power of 2, the rest is loaded synchronously.
     // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
-    static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
-#pragma unroll
-    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
-        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
-        const int stride_i = WARP_SIZE / stride_k;
+    // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
 
-        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
-            continue;
-        }
+    if (use_cp_async) {
+        constexpr int preload = 64;
+        constexpr int h2_per_chunk = 16/sizeof(half2);
+        const int chunks_per_row = D2 / h2_per_chunk;
 
-#pragma unroll
-        for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
-            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+        const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
-#pragma unroll
-            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+            const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
 
-                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
+            if (k0_start == k0_stop) {
+                return;
             }
-        }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+                }
+            }
+        };
+        ggml_cuda_unroll<5>{}(load);
+    } else {
+        static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
+        auto load = [&] __device__ (const int n) {
+            const int stride_k = WARP_SIZE >> n;
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+            const int k0_stop  =                             D2 - D2 % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                return;
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+                const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+                    break;
+                }
+
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
+                }
+            }
+        };
+        ggml_cuda_unroll<3>{}(load);
     }
 }
 
-template<int ncols1, int nwarps, int KQ_per_iter>
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
         const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
-    static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
-#ifdef CP_ASYNC_AVAILABLE
-    constexpr int preload = KQ_per_iter * sizeof(half);
-    constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
-    constexpr int stride_j = nwarps * cols_per_warp;
+    static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
 
-    const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+    if (use_cp_async) {
+        constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
+        constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+        constexpr int stride_j = nwarps * cols_per_warp;
+
+        const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
 
+#pragma unroll
+        for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+            const int j = j0 + threadIdx.y*cols_per_warp +
+                (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
+
+            if (j0 + stride_j > ncols1 && j >= ncols1) {
+                break;
+            }
+
+            const int i = 4 * (threadIdx.x % (nbatch_fa/8));
+
+            cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+        }
+        return;
+    }
+
+    constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+    constexpr int stride_j = nwarps * cols_per_warp;
 #pragma unroll
     for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-        const int j = j0 + threadIdx.y*cols_per_warp +
-            (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
+        const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
 
         if (j0 + stride_j > ncols1 && j >= ncols1) {
             break;
         }
 
-        const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
+        const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
 
-        cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+        tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
     }
-#else
-    constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
-    constexpr int stride_j = nwarps * cols_per_warp;
-#pragma unroll
-    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
-        const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
-
-        if (j0 + stride_j > ncols1 && j >= ncols1) {
-            break;
-        }
-
-        const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
-
-        tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
-    }
-#endif // CP_ASYNC_AVAILABLE
 }
 
-template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -123,9 +236,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float logit_softcap,
         const int ne01,
         const int ne02,
-        const int stride_KV,
+        const int stride_K,
+        const int stride_V,
         const int stride_mask,
         const int jt,
+        half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
         half2        * const __restrict__ tile_mask,
@@ -135,59 +250,107 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float        * const __restrict__ KQ_rowsum,
         const int kb0) {
 #ifdef NEW_MMA_AVAILABLE
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int nstages = c::nstages_target;
+#else
+    constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
     constexpr int cols_per_warp   = ntiles * tile_B::I;
     constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
     constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
-    constexpr int D2_padded       = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
-    const int k_VKQ_0 = kb0 * KQ_per_iter;
-    tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
+    constexpr int stride_tile_Q = DKQ/2        + 4;
+    constexpr int stride_tile_K = c::nbatch_K2 + 4;
+    constexpr int stride_tile_V = c::nbatch_V2 + 4;
+
+    const int k_VKQ_0 = kb0 * c::nbatch_fa;
+    tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
 
     // Use wide variants of tiles if ntiles >= 2.
     tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
     tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
     tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
 
-#ifdef CP_ASYNC_AVAILABLE
-    cp_async_wait_all();
-    __syncthreads();
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
-#else
-    if (ncols2 > 1 || mask_h2) {
-        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
-    }
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
-    __syncthreads();
-#endif // CP_ASYNC_AVAILABLE
-
-    // Calculate tile of KQ:
-#pragma unroll
-    for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
-        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
-#pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
-            tile_A K_A;
-            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
-            if (ntiles == 1) {
-                mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
-            } else {
-#pragma unroll
-                for (int t = 0; t < ntiles/2; ++t) {
-                    // Wide version of KQ_C is column-major => swap A and B.
-                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
-                }
-            }
+    if constexpr (nstages > 1) {
+        static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
+        constexpr bool use_cp_async = true;
+        cp_async_wait_all();
+        __syncthreads();
+        flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+            (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
+    } else {
+        constexpr bool use_cp_async = nstages == 1;
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
         }
     }
 
-#ifndef CP_ASYNC_AVAILABLE
-    __syncthreads(); // Only needed if tile_K == tile_V.
-#endif // CP_ASYNC_AVAILABLE
+#pragma unroll
+    for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
+        const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
+        const int k0_diff = k0_stop - k0_start;
+
+        if (nstages <= 1) {
+            constexpr bool use_cp_async = nstages == 1;
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+            if (use_cp_async) {
+                cp_async_wait_all();
+            }
+            __syncthreads();
+        }
+
+        // Calculate tile of KQ:
+        if constexpr (c::Q_in_reg) {
+#pragma unroll
+            for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+                const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+                    tile_A K_A;
+                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+                    if (ntiles == 1) {
+                        mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+                    } else {
+#pragma unroll
+                        for (int t = 0; t < ntiles/2; ++t) {
+                            // Wide version of KQ_C is column-major => swap A and B.
+                            mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+                        }
+                    }
+                }
+            }
+        } else {
+            static_assert(ntiles == 2, "ntiles != 2 not implemented");
+#pragma unroll
+            for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
+                load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+
+#pragma unroll
+                for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
+                    const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+
+                    tile_A K_A;
+                    load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+
+                    // Wide version of KQ_C is column-major => swap A and B.
+                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
+                }
+            }
+        }
+
+        if (nstages <= 1) {
+            __syncthreads(); // Only needed if tile_K == tile_V.
+        }
+    }
 
     if (use_logit_softcap) {
-        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
+        for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
                 KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
@@ -205,7 +368,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     if (ntiles == 1) {
         if (ncols2 > 1 || mask_h2) {
 #pragma unroll
-            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
+            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
                 const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
 #pragma unroll
                 for (int l = 0; l < tile_C_KQ::ne; ++l) {
@@ -213,16 +376,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
 
                     KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
-                        __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
+                        __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
                 }
             }
         }
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
                 KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
@@ -238,10 +401,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             }
         }
 
-        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
-
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
 #pragma unroll
             for (int l = 0; l < tile_C_KQ::ne; ++l) {
                 KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
@@ -252,7 +414,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     } else { // ntiles > 1
         if (ncols2 > 1 || mask_h2) {
 #pragma unroll
-            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
+            for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
                 const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
 #pragma unroll
                 for (int t = 0; t < ntiles/2; ++t) {
@@ -261,7 +423,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                         const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
                         const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
 
-                        const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
+                        const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
                         const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
                         KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
                         KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
@@ -272,9 +434,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
         // Calculate softmax for each KQ column using the current max. value.
         // The divisor is stored in KQ_rowsum and will be applied at the end.
-        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+        static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
 #pragma unroll
             for (int t = 0; t < ntiles/2; ++t) {
 #pragma unroll
@@ -294,9 +456,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             }
         }
 
-        static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+        static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
 #pragma unroll
-        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
 #pragma unroll
             for (int t = 0; t < ntiles/2; ++t) {
 #pragma unroll
@@ -325,7 +487,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (ntiles == 1) {
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
 #pragma unroll
-            for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
+            for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
 #pragma unroll
                 for (int l = 0; l < tile_C_VKQ::ne; ++l) {
                     VKQ_C[i].x[l] *= KQ_max_scale_h2;
@@ -336,7 +498,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             for (int col = 0; col < cols_per_thread; ++col) {
                 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
 #pragma unroll
-                for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
+                for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
 #pragma unroll
                     for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
                         VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
@@ -347,16 +509,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
     // Convert KQ C tiles into B tiles for VKQ calculation:
-    tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
+    tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
     tile_B_16 * B_16 = (tile_B_16 *) B;
-    static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
+    static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
     if (ntiles == 1) {
 #pragma unroll
-        for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
             B[k] = get_transposed(get_half2(KQ_C[k]));
         }
     } else {
-        for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
+        for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
 #pragma unroll
             for (int t = 0; t < ntiles/2; ++t) {
                 B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
@@ -364,52 +526,67 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
-#ifdef CP_ASYNC_AVAILABLE
-    // Preload K tile for next iteration:
-    cp_async_wait_all();
-    __syncthreads();
-    if (!last_iter) {
-        if (ncols2 > 1 || mask_h2) {
-            flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
+    if (nstages > 1) {
+        // Preload K tile for next iteration:
+        constexpr bool use_cp_async = true;
+        cp_async_wait_all();
+        __syncthreads();
+        if (!last_iter) {
+            if (ncols2 > 1 || mask_h2) {
+                flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+                    (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
+            }
+            flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
         }
-        flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
     }
-#else
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
-    __syncthreads();
-#endif // CP_ASYNC_AVAILABLE
 
-    // Calculate VKQ tile:
 #pragma unroll
-    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
-        static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
-#pragma unroll
-        for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
-            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
+        const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
+        const int i0_diff = i0_stop - i0_start;
 
-            tile_A A;
-            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
-            if (ntiles == 1) {
-                mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
-            } else {
+        if (nstages == 1) {
+            constexpr bool use_cp_async = nstages == 1;
+            flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
+                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+            if (use_cp_async) {
+                cp_async_wait_all();
+            }
+            __syncthreads();
+        }
+
+        // Calculate VKQ tile:
 #pragma unroll
-                for (int t = 0; t < ntiles/2; ++t) {
-                    // Wide version of VKQ_C is column-major => swap A and B.
-                    mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+        for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
+            static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+            for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
+                const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+                tile_A A;
+                load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                if (ntiles == 1) {
+                    mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+                } else {
+#pragma unroll
+                    for (int t = 0; t < ntiles/2; ++t) {
+                        // Wide version of VKQ_C is column-major => swap A and B.
+                        mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+                    }
                 }
             }
         }
+
+        if (nstages <= 1) {
+            __syncthreads(); // Only needed if tile_K == tile_V.
+        }
     }
-
-#ifndef CP_ASYNC_AVAILABLE
-    __syncthreads(); // Only needed if tile_K == tile_V.
-#endif // CP_ASYNC_AVAILABLE
-
 #else
     GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
-    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
+    GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
     GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
     GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
@@ -419,7 +596,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -434,7 +611,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int ne02,
         const int stride_Q1,
         const int stride_Q2,
-        const int stride_KV,
+        const int stride_K,
+        const int stride_V,
         const int stride_mask,
         const int jt,
         const int kb0_start,
@@ -442,6 +620,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #ifdef NEW_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int nstages = c::nstages_target;
+#else
+    constexpr int nstages = 0;
+#endif // CP_ASYNC_AVAILABLE
+
     constexpr int ncols           = ncols1 * ncols2;
     constexpr int cols_per_warp   = ntiles * tile_B::I;
     constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
@@ -449,22 +635,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
 
-    static_assert(D           % nwarps == 0, "bad D");
-    static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
+    constexpr int stride_tile_Q = DKQ/2        + 4;
+    constexpr int stride_tile_K = c::nbatch_K2 + 4;
+    constexpr int stride_tile_V = c::nbatch_V2 + 4;
 
-    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+    constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
-    // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
-    extern __shared__ half2 tile_K[];
-#ifdef CP_ASYNC_AVAILABLE
-    half2 * tile_V    = tile_K + KQ_per_iter*D2_padded;
-#else
-    half2 * tile_V    = tile_K;
-#endif // CP_ASYNC_AVAILABLE
-    half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
+    extern __shared__ half2 tile_Q[];
+    half2 * tile_K    = c::Q_in_reg ? tile_Q                                : tile_Q + ncols        * stride_tile_Q;
+    half2 * tile_V    = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
+    half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
 
-    tile_B       Q_B[D/(2*tile_B::J) * ntiles];
-    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
+    tile_B       Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
+    tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I  * ntiles];
 
     tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
     tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
@@ -476,13 +659,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         KQ_max[col] = -FLT_MAX/2.0f;
     }
 
-    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
+    // Load Q data into tile_Q, either temporarily or permanently.
+    // Q in registers is faster, but register pressure is the biggest bottleneck.
     // The loading is done with decreasing granularity for D for better memory bandwidth.
     const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
     for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-        const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-        const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+        const int k0_start  = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+        const int k0_stop   =                             DKQ/2 - (DKQ/2) % (1*stride_k);
         const int stride_jc = WARP_SIZE / stride_k;
 
         if (k0_start == k0_stop) {
@@ -506,14 +690,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
                     const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
-                    tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                    tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
                 }
             } else {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
-                    tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
+                    tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
                 }
             }
         }
@@ -521,18 +705,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     __syncthreads();
 
-    {
+    if (c::Q_in_reg) {
         const int j0 = (threadIdx.y / np) * cols_per_warp;
 
 #pragma unroll
-        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+        for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
             if (ntiles == 1) {
-                load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
+                load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
             } else {
 #pragma unroll
                 for (int t = 0; t < ntiles/2; ++t) {
                     load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
-                        tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
+                        tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
                 }
             }
         }
@@ -540,35 +724,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     __syncthreads();
 
-    // Preload mask and K data for first iteration when using cp_async:
-#ifdef CP_ASYNC_AVAILABLE
-    if (ncols2 > 1 || mask_h2) {
-        flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
+    // Preload mask and K data for first iteration when using cp_async with multiple stages:
+    if constexpr (nstages > 1) {
+        static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+        constexpr bool use_cp_async = true;
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
+                (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
+        }
+        flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
+            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
     }
-    flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
-#endif // CP_ASYNC_AVAILABLE
 
     // Iterate over ne11 == previous tokens:
     for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
-        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
-        flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
-    // With cp_async there is no __syncthreads at the end of the iter,
+    // With multi-stage loading there is no __syncthreads at the end of the iter,
     //     there can be a race condition on shared memory access for combining/writing back results.
-#ifdef CP_ASYNC_AVAILABLE
-    if (nwarps*cols_per_warp > KQ_per_iter) {
+    if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
         __syncthreads();
     }
-#endif // CP_ASYNC_AVAILABLE
 
     // Finally, sum up partial KQ rowsums.
     // The partial sums are spread across 8/4 threads each, does not need full reduce.
@@ -584,38 +770,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
-    // Write VKQ accumulators to shared memory in column-major format.
-    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
-    // Also for np > 1 the combination is done via these values in shared memory.
-    if (ntiles == 1) {
-        const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
-#pragma unroll
-        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
-            const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+    // Combine VKQ accumulator values if np > 1.
+    // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
+    // So also write VKQ accumulators to shared memory in column-major format if np == 1.
 
-#pragma unroll
-            for (int l = 0; l < tile_B::ne; ++l) {
-                const int k = k0 + tile_B::get_j(l);
-
-                tile_K[jc_cwd*D2_padded + k] = B.x[l];
-            }
-        }
-    } else {
-#pragma unroll
-        for (int t = 0; t < ntiles/2; ++t) {
-            const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
-#pragma unroll
-            for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
-#pragma unroll
-                for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
-                    const int j = j0 + tile_C_VKQ_16::get_i(l);
-                    const int k = k0 + tile_C_VKQ_16::get_j(l);
-
-                    tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
-                }
-            }
-        }
-    }
+    constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
+    constexpr int tile_stride    = nbatch_combine + 4;
+    static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
 
     if constexpr (ntiles == 1) {
         const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
@@ -624,7 +785,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
         if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
             // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
-            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }
 
         __syncthreads();
@@ -649,7 +810,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
         if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
             // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
-            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+            ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
         }
 
         __syncthreads();
@@ -676,11 +837,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
 
         const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
-        float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
+        float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
         float2 meta[nmeta];
 #pragma unroll
         for (int imeta = 0; imeta < nmeta; ++imeta) {
-            meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
+            meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
         }
 
         float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -690,10 +851,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset >= WARP_SIZE) {
-                continue;
+            if (offset < WARP_SIZE) {
+                KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
             }
-            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
         }
 
         float KQ_cms[nmeta]; // KQ combine max scale per warp.
@@ -709,10 +869,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
 #pragma unroll
         for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
-            if (offset >= WARP_SIZE) {
-                continue;
+            if (offset < WARP_SIZE) {
+                KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
             }
-            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
         }
 
         // Write back combined meta data:
@@ -720,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         for (int imeta = 0; imeta < nmeta; ++imeta) {
             if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
                 // Combined KQ max scale + rowsum.
-                meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
+                meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
             }
         }
 
@@ -736,88 +895,118 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         }
     }
 
-    if (np > 1) {
-        __syncthreads();
-    }
-
-    if (np == 1 || threadIdx.y % np == 0) {
-        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
-        // The values after that are for the partial results of the individual blocks.
-        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
+#pragma unroll
+    for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
+        if (ntiles == 1) {
+            const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+#pragma unroll
+            for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
+                const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
 
 #pragma unroll
-        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
-            const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
-            const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
-            const int stride_jc = WARP_SIZE / stride_k;
+                for (int l = 0; l < tile_B::ne; ++l) {
+                    const int k = k0 + tile_B::get_j(l);
 
-            if (k0_start == k0_stop) {
-                continue;
+                    tile_Q[jc_cwd*tile_stride + k] = B.x[l];
+                }
             }
-
+        } else {
 #pragma unroll
-            for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
-                const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
-
-                if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
-                    break;
-                }
-
-                const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
-
-                const int j_dst = jc_dst / ncols2;
-                const int c_dst = jc_dst % ncols2;
-
-                if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
-                    continue;
-                }
-
-                const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
+            for (int t = 0; t < ntiles/2; ++t) {
+                const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
 #pragma unroll
-                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
-                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
-
-                    float2 dstk_val = make_float2(0.0f, 0.0f);
+                for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
 #pragma unroll
-                    for (int ip = 0; ip < np; ++ip) {
-                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
-                        const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
-                        dstk_val.x += dstk_val_add.x*KQ_crs;
-                        dstk_val.y += dstk_val_add.y*KQ_crs;
-                    }
+                    for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+                        const int j = j0 + tile_C_VKQ_16::get_i(l);
+                        const int k = k0 + tile_C_VKQ_16::get_j(l);
 
-                    if (!needs_fixup && !is_fixup) {
-                        const float KQ_rowsum_j = meta_j[1];
-                        dstk_val.x /= KQ_rowsum_j;
-                        dstk_val.y /= KQ_rowsum_j;
-                    }
-
-                    if (is_fixup) {
-                        dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
-                    } else {
-                        dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
+                        tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
                     }
                 }
             }
         }
-    }
 
-    if (np > 1) {
         __syncthreads();
+
+        if (np == 1 || threadIdx.y % np == 0) {
+            // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+            // The values after that are for the partial results of the individual blocks.
+            float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
+
+#pragma unroll
+            for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+                const int k0_start  = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+                const int k0_stop   =                             nbatch_combine - nbatch_combine % (1*stride_k);
+                const int stride_jc = WARP_SIZE / stride_k;
+
+                if (k0_start == k0_stop) {
+                    continue;
+                }
+
+#pragma unroll
+                for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+                    const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                    if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+                        break;
+                    }
+
+                    const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+                    const int j_dst = jc_dst / ncols2;
+                    const int c_dst = jc_dst % ncols2;
+
+                    if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+                        continue;
+                    }
+
+                    const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
+#pragma unroll
+                    for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                        const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                        float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                        for (int ip = 0; ip < np; ++ip) {
+                            const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
+                            const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
+                            dstk_val.x += dstk_val_add.x*KQ_crs;
+                            dstk_val.y += dstk_val_add.y*KQ_crs;
+                        }
+
+                        if (!needs_fixup && !is_fixup) {
+                            const float KQ_rowsum_j = meta_j[1];
+                            dstk_val.x /= KQ_rowsum_j;
+                            dstk_val.y /= KQ_rowsum_j;
+                        }
+
+                        if (is_fixup) {
+                            dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
+                        } else {
+                            dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
+                        }
+                    }
+                }
+            }
+        }
+        if (np > 1) {
+            __syncthreads();
+        }
     }
 #else
     GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
-    GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
+    GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
-__launch_bounds__(nwarps*WARP_SIZE, 2)
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
+__launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -857,24 +1046,27 @@ static __global__ void flash_attn_ext_f16(
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
+    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
         NO_DEVICE_CODE;
         return;
     }
 
-    static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+    static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
     const int stride_Q1   = nb01 / sizeof(float2);
     const int stride_Q2   = nb02 / sizeof(float2);
-    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_K    = nb11 / sizeof(half2);
+    const int stride_V    = nb21 / sizeof(half2);
     const int stride_mask = nb31 / sizeof(half2);
 
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
+    constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
 
     // kbc == k block continuous, current index in continuous ijk space.
     int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
@@ -893,9 +1085,9 @@ static __global__ void flash_attn_ext_f16(
 
         const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
         const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const half2  * V_h2    = (const half2  *) (V + nb22*(channel*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
@@ -905,14 +1097,14 @@ static __global__ void flash_attn_ext_f16(
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
-            flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+                 ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
 
         kbc += iter_k;
@@ -931,9 +1123,9 @@ static __global__ void flash_attn_ext_f16(
 
     const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
     const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const half2  * V_h2    = (const half2  *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
     const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
-    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
@@ -942,9 +1134,9 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+         ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
     GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
@@ -960,28 +1152,42 @@ static __global__ void flash_attn_ext_f16(
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
 
-template <int D, int ncols1, int ncols2>
+template <int DKQ, int DV, int ncols1, int ncols2>
 void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    constexpr int ncols         = ncols1 * ncols2;
-    constexpr int KQ_per_iter   = D <= 128 && ncols1 <= 64 ? 64 : 32;
-    constexpr int nwarps        = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
-    constexpr int ntiles        = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
-    constexpr int cols_per_warp = ntiles * tile_B::I;
+    const ggml_tensor * KQV = dst;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
 
-    static_assert(D     %    tile_B::J  == 0, "bad D");
+    typedef fattn_mma_f16_config<DKQ, DV> c;
+
+    constexpr int nbatch_K2      = c::nbatch_K2      < 1 ? DKQ/2 : c::nbatch_K2;
+    constexpr int nbatch_V2      = c::nbatch_V2      < 1 ? DV /2 : c::nbatch_V2;
+    constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
+
+    const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
+
+    constexpr int ncols         = ncols1 * ncols2;
+    constexpr int ntiles        = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+    constexpr int nwarps_max_x  = ncols / cols_per_warp;
+    constexpr int nwarps_max_y  = c::nbatch_fa / tile_A::I;
+    constexpr int nwarps        = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+
+    static_assert(DKQ   % tile_B::J     == 0, "bad DKQ");
+    static_assert(DV    % tile_A::J     == 0, "bad DV");
     static_assert(ncols % cols_per_warp == 0, "bad ncols");
 
-    const ggml_tensor * KQV = dst;
-    const int id    = ggml_cuda_get_device();
-    const int cc    = ggml_cuda_info().devices[id].cc;
+    const size_t nbytes_shared_KV_1stage = c::nbatch_fa         * std::max(c::nbatch_K2 + 4,  c::nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_2stage = c::nbatch_fa         *         (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                                   * sizeof(half2);
+    const size_t nbytes_shared_mask      = ncols1               * (c::nbatch_fa/2 + 4)                          * sizeof(half2);
+    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                          * sizeof(half2);
 
-    const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
+    const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
 
-    const size_t nbytes_shared_KV      = KQ_shared_rows       * (D           + 8) * sizeof(half);
-    const size_t nbytes_shared_mask    = ncols1               * (KQ_per_iter + 8) * sizeof(half);
-    const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D           + 8) * sizeof(half);
-
-    const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
+    const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
+        std::max(nbytes_shared_Q,  nbytes_shared_KV + nbytes_shared_mask) :
+                 nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -989,59 +1195,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     }
 
-    launch_fattn<D, ncols1, ncols2, KQ_per_iter>
+    launch_fattn<DV, ncols1, ncols2>
         (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }
 
 
-#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2)                          \
-    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
-    <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                           \
+    <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
 
-#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
-    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
-    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
-    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
-    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols)   \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1,  1); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2,  2); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4,  4); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8,  8); \
+    extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,   8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,   8)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  16)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  32)
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64)
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  80,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  96,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
 
-// Kernels with ncols == 128 are only 4% faster due to register pressure.
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
+// The number of viable configurations for Deepseek is very limited:
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu
index e0039e17..9283560d 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
@@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu
index fcb6f848..32673adb 100644
--- a/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
@@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index e17d2d0e..ef0addc1 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index d42ddca4..7064675d 100644
--- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index bc21b27a..c5668adb 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 7a2d1e45..9c5c803d 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -8,58 +8,32 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-template <int D, int ncols2>
+template <int DKQ, int DV, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
-    if (Q->ne[1] <= 8/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
-        return;
+    if constexpr (ncols2 <= 8) {
+        if (Q->ne[1] <= 8/ncols2) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
+            return;
+        }
     }
 
     if (Q->ne[1] <= 16/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
 }
 
-template <int ncols2>
-static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-
-    switch (Q->ne[0]) {
-        case 64:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
-            break;
-        case 80:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
-            break;
-        case 96:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
-            break;
-        case 112:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
-            break;
-        case 128:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
-            break;
-        case 256:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
-}
-
-static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <int DKQ, int DV>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const float use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
     if (use_gqa_opt && gqa_ratio % 8 == 0) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 4) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio % 4 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 2) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio % 2 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+}
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+
+    switch (Q->ne[0]) {
+        case 64:
+            GGML_ASSERT(V->ne[0] == 64);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
+            break;
+        case 80:
+            GGML_ASSERT(V->ne[0] == 80);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
+            break;
+        case 96:
+            GGML_ASSERT(V->ne[0] == 96);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
+            break;
+        case 112:
+            GGML_ASSERT(V->ne[0] == 112);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
+            break;
+        case 128:
+            GGML_ASSERT(V->ne[0] == 128);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
+            break;
+        case 256:
+            GGML_ASSERT(V->ne[0] == 256);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+            break;
+        case 576: {
+            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+            GGML_ASSERT(V->ne[0] == 512);
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+            const bool use_gqa_opt = mask && max_bias == 0.0f;
+            GGML_ASSERT(use_gqa_opt);
+
+            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+            const int gqa_ratio = Q->ne[2] / K->ne[2];
+            GGML_ASSERT(gqa_ratio % 16 == 0);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+        } break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
 
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
@@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 42302e4e..7643c4b8 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return false;
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                // different head sizes of K and V are not supported yet
-                return false;
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
+                    return false;
+                }
+                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek MLA
-                return false;
-            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
new file mode 100644
index 00000000..fb26abeb
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
index 80108615..dc168290 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 1, 8);
-DECL_FATTN_MMA_F16_CASE(80, 1, 8);
-DECL_FATTN_MMA_F16_CASE(96, 1, 8);
-DECL_FATTN_MMA_F16_CASE(112, 1, 8);
-DECL_FATTN_MMA_F16_CASE(128, 1, 8);
-DECL_FATTN_MMA_F16_CASE(256, 1, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
index 66161c0a..9d3cfd8e 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 1);
-DECL_FATTN_MMA_F16_CASE(80, 16, 1);
-DECL_FATTN_MMA_F16_CASE(96, 16, 1);
-DECL_FATTN_MMA_F16_CASE(112, 16, 1);
-DECL_FATTN_MMA_F16_CASE(128, 16, 1);
-DECL_FATTN_MMA_F16_CASE(256, 16, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
index ee88c72a..2e1883af 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 2);
-DECL_FATTN_MMA_F16_CASE(80, 16, 2);
-DECL_FATTN_MMA_F16_CASE(96, 16, 2);
-DECL_FATTN_MMA_F16_CASE(112, 16, 2);
-DECL_FATTN_MMA_F16_CASE(128, 16, 2);
-DECL_FATTN_MMA_F16_CASE(256, 16, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index d888a5a4..2074e954 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 4);
-DECL_FATTN_MMA_F16_CASE(80, 16, 4);
-DECL_FATTN_MMA_F16_CASE(96, 16, 4);
-DECL_FATTN_MMA_F16_CASE(112, 16, 4);
-DECL_FATTN_MMA_F16_CASE(128, 16, 4);
-DECL_FATTN_MMA_F16_CASE(256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
new file mode 100644
index 00000000..f011a208
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index d93a2d08..24c64cf0 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 2, 4);
-DECL_FATTN_MMA_F16_CASE(80, 2, 4);
-DECL_FATTN_MMA_F16_CASE(96, 2, 4);
-DECL_FATTN_MMA_F16_CASE(112, 2, 4);
-DECL_FATTN_MMA_F16_CASE(128, 2, 4);
-DECL_FATTN_MMA_F16_CASE(256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
index 617464c9..163b1d93 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 2, 8);
-DECL_FATTN_MMA_F16_CASE(80, 2, 8);
-DECL_FATTN_MMA_F16_CASE(96, 2, 8);
-DECL_FATTN_MMA_F16_CASE(112, 2, 8);
-DECL_FATTN_MMA_F16_CASE(128, 2, 8);
-DECL_FATTN_MMA_F16_CASE(256, 2, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
index 970d2b68..0543532e 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 32, 1);
-DECL_FATTN_MMA_F16_CASE(80, 32, 1);
-DECL_FATTN_MMA_F16_CASE(96, 32, 1);
-DECL_FATTN_MMA_F16_CASE(112, 32, 1);
-DECL_FATTN_MMA_F16_CASE(128, 32, 1);
-DECL_FATTN_MMA_F16_CASE(256, 32, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
index 65cd377c..407b6cf4 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 32, 2);
-DECL_FATTN_MMA_F16_CASE(80, 32, 2);
-DECL_FATTN_MMA_F16_CASE(96, 32, 2);
-DECL_FATTN_MMA_F16_CASE(112, 32, 2);
-DECL_FATTN_MMA_F16_CASE(128, 32, 2);
-DECL_FATTN_MMA_F16_CASE(256, 32, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
new file mode 100644
index 00000000..f5fd0e23
--- /dev/null
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
index f4a8bf34..5e466850 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 2);
-DECL_FATTN_MMA_F16_CASE(80, 4, 2);
-DECL_FATTN_MMA_F16_CASE(96, 4, 2);
-DECL_FATTN_MMA_F16_CASE(112, 4, 2);
-DECL_FATTN_MMA_F16_CASE(128, 4, 2);
-DECL_FATTN_MMA_F16_CASE(256, 4, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index de191a8a..1ada657f 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 4);
-DECL_FATTN_MMA_F16_CASE(80, 4, 4);
-DECL_FATTN_MMA_F16_CASE(96, 4, 4);
-DECL_FATTN_MMA_F16_CASE(112, 4, 4);
-DECL_FATTN_MMA_F16_CASE(128, 4, 4);
-DECL_FATTN_MMA_F16_CASE(256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
index e8cb0e1b..bad296b4 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 8);
-DECL_FATTN_MMA_F16_CASE(80, 4, 8);
-DECL_FATTN_MMA_F16_CASE(96, 4, 8);
-DECL_FATTN_MMA_F16_CASE(112, 4, 8);
-DECL_FATTN_MMA_F16_CASE(128, 4, 8);
-DECL_FATTN_MMA_F16_CASE(256, 4, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
index a532e962..0d7a9c72 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 64, 1);
-DECL_FATTN_MMA_F16_CASE(80, 64, 1);
-DECL_FATTN_MMA_F16_CASE(96, 64, 1);
-DECL_FATTN_MMA_F16_CASE(112, 64, 1);
-DECL_FATTN_MMA_F16_CASE(128, 64, 1);
-DECL_FATTN_MMA_F16_CASE(256, 64, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
index bf25181a..9d5a9976 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 1);
-DECL_FATTN_MMA_F16_CASE(80, 8, 1);
-DECL_FATTN_MMA_F16_CASE(96, 8, 1);
-DECL_FATTN_MMA_F16_CASE(112, 8, 1);
-DECL_FATTN_MMA_F16_CASE(128, 8, 1);
-DECL_FATTN_MMA_F16_CASE(256, 8, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
index 378c132e..a6e6f093 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 2);
-DECL_FATTN_MMA_F16_CASE(80, 8, 2);
-DECL_FATTN_MMA_F16_CASE(96, 8, 2);
-DECL_FATTN_MMA_F16_CASE(112, 8, 2);
-DECL_FATTN_MMA_F16_CASE(128, 8, 2);
-DECL_FATTN_MMA_F16_CASE(256, 8, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 372641be..86d4ffae 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 4);
-DECL_FATTN_MMA_F16_CASE(80, 8, 4);
-DECL_FATTN_MMA_F16_CASE(96, 8, 4);
-DECL_FATTN_MMA_F16_CASE(112, 8, 4);
-DECL_FATTN_MMA_F16_CASE(128, 8, 4);
-DECL_FATTN_MMA_F16_CASE(256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
index 9ff5968b..680a13ca 100644
--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 8);
-DECL_FATTN_MMA_F16_CASE(80, 8, 8);
-DECL_FATTN_MMA_F16_CASE(96, 8, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
index dd373a09..3428113d 100755
--- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
+++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
 
 """
 
-SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
 
-for ncols in [8, 16, 32, 64, 128]:
-    for ncols2 in [1, 2, 4, 8]:
+for ncols in [8, 16, 32, 64]:
+    for ncols2 in [1, 2, 4, 8, 16]:
+        if ncols2 > ncols:
+            continue
         ncols1 = ncols // ncols2
-        if ncols == 128:
-            continue  # Too much register pressure.
         with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
             f.write(SOURCE_FATTN_MMA_START)
 
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                if ncols == 128 and head_size == 256:
-                    continue  # Needs too much shared memory.
-                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
+            for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
+                if head_size_kq != 576 and ncols2 == 16:
+                    continue
+                if head_size_kq == 576 and ncols2 != 16:
+                    continue
+                head_size_v = head_size_kq if head_size_kq != 576 else 512
+                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
 
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: