diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 0b409ce8..259a2928 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -753,69 +753,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; - aclTensor* acl_src = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); if (ggml_are_same_shape(src0, dst)) { + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); if (dst->type == src0->type) { cann_copy(ctx, acl_src, acl_dst); } else { aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); } + ggml_cann_release_resources(ctx, acl_src, acl_dst); } else { - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { - if (dst->type == src0->type) { - size_t cpy_size = ggml_nbytes(dst); - ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE); - return; - } else { - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), - ggml_nelements(dst) * ggml_type_size(dst->type)); - void* src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = ggml_type_size(dst->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; - } - aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, - GGML_MAX_DIMS); - - aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - size_t cpy_size = ggml_nbytes(dst); - ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE); - ggml_cann_release_resources(ctx, src_trans_tensor); - return; - } - } else if (ggml_is_contiguous(dst)) { - ggml_cann_pool_alloc src_buffer_allocator( - ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type)); - void* src_trans_buffer = src_buffer_allocator.get(); + void* src_trans_buffer = src0->data; + ggml_cann_pool_alloc src_buffer_allocator; + if (!ggml_is_contiguous(src0)) { + aclTensor* acl_src = ggml_cann_create_tensor(src0); + src_buffer_allocator.alloc(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(src0->type)); + src_trans_buffer = src_buffer_allocator.get(); size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = ggml_type_size(dst->type); + src_trans_nb[0] = ggml_type_size(src0->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } aclTensor* src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, + src_trans_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - - aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); - - size_t cpy_size = ggml_nbytes(dst); - ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE); - ggml_cann_release_resources(ctx, src_trans_tensor); - return; - } else { - GGML_ABORT("Unsupport dst is not contiguous."); + cann_copy(ctx, acl_src, src_trans_tensor); + ggml_cann_release_resources(ctx, acl_src, src_trans_tensor); } + + size_t src_reshape_nb[GGML_MAX_DIMS]; + src_reshape_nb[0] = ggml_type_size(src0->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1]; + } + + aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer, + ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type), + dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + + if (dst->type == src0->type) { + cann_copy(ctx, trans_acl_src, acl_dst); + } else { + aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); + } + ggml_cann_release_resources(ctx, trans_acl_src, acl_dst); } - ggml_cann_release_resources(ctx, acl_src, acl_dst); + return; } /** diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3d3520f1..cb8af42e 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // only support F32 and F16. return false; } - return ggml_is_contiguous(op); + return true; } break; case GGML_OP_CONT: { // TODO: support GGML_TYPE_BF16 @@ -2457,7 +2457,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2)); } case GGML_OP_DUP: - return ggml_is_contiguous(op); case GGML_OP_SUM: case GGML_OP_IM2COL: case GGML_OP_CONCAT: