Mirror of https://github.com/ggerganov/whisper.cpp.git (synced 2025-06-19 17:28:09 +02:00)
CUDA: skip fully masked-out KV in FA vec kernel (llama/13584)

* CUDA: skip fully masked-out KV in FA vec kernel

parent: f44b53480f
commit: 926fe234e9
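The hunks below add a per-slice skip test to the CUDA flash-attention vector kernels: the mask for the current KV slice is staged in shared memory (pre-multiplied by the ALiBi slope), and if every lane of the warp sees only -inf entries, the slice contributes nothing to the softmax and can be skipped. As orientation, a minimal, self-contained sketch of that test is given here; it is not the patched kernel itself, and kv_slice_fully_masked, mask_tile and n are hypothetical names.

    // Minimal sketch of the skip test, assuming a 32-thread warp; the real
    // kernels do this inline against a slope-scaled mask held in shared memory.
    #include <cuda_fp16.h>

    __device__ bool kv_slice_fully_masked(const half * mask_tile, int n) {
        bool skip = true;
        // Each lane strides over the tile: a finite mask value means some token
        // of this KV slice is still visible, so the slice must be processed.
        for (int i = threadIdx.x; i < n; i += 32) {
            skip = skip && isinf(__half2float(mask_tile[i]));
        }
        // Drop the slice only if every lane of the warp saw nothing but -inf.
        return __all_sync(0xFFFFFFFF, skip);
    }

Both kernels guard this check with #ifndef GGML_USE_HIP, since __all_sync assumes a 32-wide warp while AMD wavefronts are 64 wide. The hunks for the f16 kernel, flash_attn_vec_ext_f16, come first.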
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
                 sum = logit_softcap*tanhf(sum);
             }
 
-            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += maskh_shared[j*D + i_KQ];
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
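The f16 hunks above also pair a new device-side guard (the ncols > 1 branch compiled to NO_DEVICE_CODE on non-HIP, non-MUSA builds) with a host-side dispatch change that always selects the single-column kernel on NVIDIA GPUs. A hedged sketch of that dispatch shape follows; choose_cols_per_block, n_q_cols and is_nvidia are hypothetical stand-ins for the real ggml checks, not the actual API.

    // Sketch only: the ncols > 1 instantiations can become NO_DEVICE_CODE stubs
    // on CUDA proper because the host never launches them on NVIDIA devices.
    static int choose_cols_per_block(int n_q_cols, bool is_nvidia) {
        if (n_q_cols == 1 || is_nvidia) {
            return 1;  // mirrors `Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)`
        }
        // HIP/MUSA builds keep the multi-column variants; the values chosen
        // upstream are outside the hunks shown here, so use a placeholder.
        return 2;
    }

The same change is applied to the f32 variant of the kernel, flash_attn_vec_ext_f32; its hunks follow.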
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    skip = skip && isinf(maskf_shared[j*D + i]);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
                 sum = logit_softcap*tanhf(sum);
             }
 
-            sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+            sum += maskf_shared[j*D + i_KQ];
 
             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
 
@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
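For completeness, here is a small, hypothetical driver kernel (not part of this commit) that applies the same warp-wide test to a batch of KV slices and records which ones the patched kernels would skip; mark_skippable and its parameters are illustrative only.

    #include <cuda_fp16.h>

    // One 32-thread block per KV slice; writes 1 for slices whose mask is
    // entirely -inf, i.e. slices the patched vector kernels would now skip.
    __global__ void mark_skippable(const half * mask, int slice_len, int * skippable) {
        const half * slice = mask + (size_t) blockIdx.x * slice_len;
        bool skip = true;
        for (int i = threadIdx.x; i < slice_len; i += 32) {
            skip = skip && isinf(__half2float(slice[i]));
        }
        skip = __all_sync(0xFFFFFFFF, skip);
        if (threadIdx.x == 0) {
            skippable[blockIdx.x] = skip ? 1 : 0;
        }
    }

    // Hypothetical launch: mark_skippable<<<n_slices, 32>>>(d_mask, slice_len, d_skippable);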