diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/Doxyfile b/ggml/src/ggml-cann/Doxyfile old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp old mode 100644 new mode 100755 index f5462c5a..f311864d --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_FLOAT; case GGML_TYPE_F16: return ACL_FLOAT16; + case GGML_TYPE_BF16: + return ACL_BF16; case GGML_TYPE_I8: return ACL_INT8; case GGML_TYPE_I16: diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100644 new mode 100755 index 9c67664a..437ece2d --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -66,6 +66,7 @@ #include #include #include +#include #include #include @@ -74,11 +75,13 @@ #include #include "ggml-impl.h" +#include "ggml.h" #define GGML_COMMON_DECL_C #include "../ggml-common.h" + void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); @@ -2861,3 +2864,330 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { break; } } + +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + + ggml_tensor* src0 = dst->src[0]; // q, fp32 + ggml_tensor* src1 = dst->src[1]; // k, fp16 + ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src3 = dst->src[3]; // mask, fp16 + + float maxBias = 0.0f; + float scaleValue = 1.0f; + float logitSoftcap = 0.0f; + memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); + memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); + + if(logitSoftcap == 0.0f){ + size_t faElemSize = sizeof(uint16_t); + auto faDataType = ACL_FLOAT16; //ACL_BF16; + + aclTensor* acl_src0_f16_tensor = nullptr; + aclTensor* acl_src1_f16_tensor = nullptr; + aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor* acl_dst_f16_tensor = nullptr; + + // Step 1: cast the src0 (Query) to fp16 if needed + ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); + void* src0_f16_buffer = nullptr; + + if(ggml_cann_type_mapping(src0->type) != faDataType){ + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + src0_f16_buffer = src0_f16_allocator.alloc( + ggml_nelements(src0) * faElemSize); + + int64_t* src0_f16_ne = src0->ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; + src0_f16_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; + } + + acl_src0_f16_tensor = ggml_cann_create_tensor( + src0_f16_buffer, faDataType, faElemSize, + src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS + ); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); + ggml_cann_release_resources(ctx, acl_src0_f32_tensor); + }else{ + acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + } + + // Step 2: create the acl tensors for src1 (Key), src2 (Value), + // and the direct output from FusedInferAttention + + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2); + + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); + + int64_t* out_f16_ne = src0->ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + acl_dst_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + + // Step 3: create the PSEShift tensor if needed + // this tensor is considered as mask (f16) in the llama.cpp + + aclTensor* bcast_pse_tensor = nullptr; + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); + void* bcast_pse_buffer = nullptr; + + if(src3 != nullptr){ + bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + + if(src0->ne[1] > 1){ + // Case 1: broadcast pse for prefill stage with multiple head + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src3->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + }else{ + // Case 2: trunc the first row and broadcast pse for decode stage with multiple head + int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); + + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src0->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } + + // Compute the slope if needed. Derived from ggml_cann_softmax(). + if(maxBias != 0.0f){ + // alibi + const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; + const int64_t n_head = src0->ne[2]; + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + // init arange + ggml_cann_pool_alloc arange_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_arange_buffer = arange_allocator.get(); + + // arange1: [1, ..., n_heads_log2_floor+1) + float start = 1; + float stop = n_heads_log2_floor + 1; + float step = 1; + int64_t n_elements_arange = n_heads_log2_floor; + + int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + size_t tmp_arange1_nb[] = {faElemSize}; + aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_arange1_ne, tmp_arange1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); + + aclTensor* tmp_arange2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + start = 1; + stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + step = 2; + n_elements_arange = ne2_ne3 - n_heads_log2_floor; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_arange2_nb[] = {faElemSize}; + + aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( + (char*)tmp_arange_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, + n_elements_arange); + } + + // init mk_base + ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_mk_base_buffer = mk_base_allocator.get(); + int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + size_t tmp_mk_base1_nb[] = {faElemSize}; + aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base1_ne, tmp_mk_base1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); + + aclTensor* tmp_mk_base2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_mk_base2_nb[] = {faElemSize}; + aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( + (char*)tmp_mk_base_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + } + + // init mk + int64_t tmp_mk_base_ne[] = {ne2_ne3}; + size_t tmp_mk_base_nb[] = {faElemSize}; + aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); + + // reshape mk + int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; + size_t tmp_mk_nb[GGML_MAX_DIMS]; + tmp_mk_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + } + aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, + ACL_FORMAT_ND); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); + + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor); + } + } + + // Step 4: set the inputs for FusedInferAttention. + int kvTensorNum = 1; + aclTensor* acl_q_tensor = acl_src0_f16_tensor; + aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; + aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; + auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; + // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = {'B', 'N', 'S', 'D', 0}; + int64_t sparseMode = 0; + int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; + int64_t valueAntiquantMode = 0; + + // Step 5: launch the FusedInferAttentionScoreV2 kernel. + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, + acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v + bcast_pse_tensor, nullptr, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + acl_dst_f16_tensor, // attentionOut + nullptr // softmaxLse + ); + + // Step 6: post-processing, permute and cast to f32 + + int64_t new_dim[] = {0, 2, 1, 3}; + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + + if(ggml_cann_type_mapping(dst->type) != faDataType){ + ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); + perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + + int64_t* perm_out_f16_ne = dst->ne; + size_t perm_out_f16_nb[GGML_MAX_DIMS]; + perm_out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + } + aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( + perm_out_f16_buffer, faDataType, faElemSize, + perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_cast(ctx, + acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); + ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + }else{ + // only need to permute + aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + } + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + acl_dst_f16_tensor, + acl_dst_tensor); + if(src3 != nullptr){ + ggml_cann_release_resources(ctx, bcast_pse_tensor); + } + }else{ + GGML_ABORT("Function is not implemented."); + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h old mode 100644 new mode 100755 index 15993cce..80ce80ba --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Performs the Flash Attention extended operator using the CANN backend. + * + * @details This function implements the memory-efficient Flash Attention algorithm + * for computing scaled dot-product attention with hardware acceleration. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`. + */ +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /* * @brief A generic wrapper for ACL resources with custom deleter support. */ diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100644 new mode 100755 index 605b6a73..c0ea2600 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,6 +36,7 @@ #include "ggml-backend-impl.h" #include "ggml-cann/aclnn_ops.h" #include "ggml-cann/common.h" +#include "ggml.h" #define GGML_COMMON_DECL_C @@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_COUNT_EQUAL: ggml_cann_count_equal(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cann_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_FLASH_ATTN_EXT:{ + // derived from [ggml-cuda.cu] + if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ + return false; + } + if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){ + return false; + } + if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ + return false; + } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek MLA + return false; + } + if (op->src[0]->ne[3] != 1) { + return false; + } + float logitSoftcap = 0.0f; + memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); + if(logitSoftcap != 0.0f) { + return false; + } + return true; + } default: return false; }