From d1286cf32b1b18219d33b692ee52383c25428e7f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 12 Jul 2025 14:33:49 +0300 Subject: [PATCH] ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (llama/14435) --- ggml/CMakeLists.txt | 7 - ggml/include/ggml.h | 26 +- ggml/src/ggml-cann/ggml-cann.cpp | 7 +- ggml/src/ggml-cpu/ops.cpp | 115 ++++----- ggml/src/ggml-cuda/ggml-cuda.cu | 11 +- ggml/src/ggml-metal/ggml-metal-impl.h | 20 +- ggml/src/ggml-metal/ggml-metal.m | 23 +- ggml/src/ggml-metal/ggml-metal.metal | 72 +++--- ggml/src/ggml-sycl/ggml-sycl.cpp | 10 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 222 +++++++++--------- .../ggml-vulkan/vulkan-shaders/upscale.comp | 74 +----- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 - ggml/src/ggml.c | 20 +- 13 files changed, 304 insertions(+), 305 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 675f3bf7..215eb234 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -360,13 +360,6 @@ write_basic_package_version_file( VERSION ${GGML_INSTALL_VERSION} COMPATIBILITY SameMajorVersion) -target_compile_definitions(ggml-base PRIVATE - GGML_VERSION="${GGML_INSTALL_VERSION}" - GGML_COMMIT="${GGML_BUILD_COMMIT}" -) -message(STATUS "ggml version: ${GGML_INSTALL_VERSION}") -message(STATUS "ggml commit: ${GGML_BUILD_COMMIT}") - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5ceccb8f..a92495db 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -646,9 +646,6 @@ extern "C" { // misc - GGML_API const char * ggml_version(void); - GGML_API const char * ggml_commit(void); - GGML_API void ggml_time_init(void); // call this once at the beginning of the program GGML_API int64_t ggml_time_ms(void); GGML_API int64_t ggml_time_us(void); @@ -1513,8 +1510,14 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // a [ne0, ne01, ne02, ne03] + // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional + // + // broadcast: + // ne02 % ne12 == 0 + // ne03 % ne13 == 0 + // // fused soft_max(a*scale + mask*(ALiBi slope)) - // mask is optional // max_bias = 0.0f for no ALiBi GGML_API struct ggml_tensor * ggml_soft_max_ext( struct ggml_context * ctx, @@ -1977,11 +1980,16 @@ extern "C" { #define GGML_KQ_MASK_PAD 64 - // q: [n_embd_k, n_batch, n_head, 1] - // k: [n_embd_k, n_kv, n_head_kv, 1] - // v: [n_embd_v, n_kv, n_head_kv, 1] !! not transposed !! - // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd_v, n_head, n_batch, 1] !! permuted !! + // q: [n_embd_k, n_batch, n_head, ne3] + // k: [n_embd_k, n_kv, n_head_kv, ne3] + // v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !! + // mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd_v, n_head, n_batch, ne3] !! permuted !! + // + // broadcast: + // n_head % n_head_kv == 0 + // ne3 % ne32 == 0 + // GGML_API struct ggml_tensor * ggml_flash_attn_ext( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index d1a0ad37..8a3d2026 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2187,7 +2187,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_SQRT: case GGML_OP_CLAMP: case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: @@ -2205,6 +2204,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_SOFT_MAX: + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_FLASH_ATTN_EXT:{ // derived from [ggml-cuda.cu] if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ @@ -2227,6 +2230,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, // DeepSeek MLA return false; } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index dd83efde..2ae0721e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5232,14 +5232,17 @@ static void ggml_compute_forward_soft_max_f32( memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - // TODO: handle transposed/permuted matrices - const int ith = params->ith; const int nth = params->nth; GGML_TENSOR_UNARY_OP_LOCALS - //const int64_t ne11 = src1 ? src1->ne[1] : 1; + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -5249,68 +5252,66 @@ static void ggml_compute_forward_soft_max_f32( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const int64_t i11 = i01; + const int64_t i12 = i02%ne12; + const int64_t i13 = i03%ne13; - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + // ALiBi + const uint32_t h = i02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL; + + ggml_vec_cpy_f32 (ne00, wp, sp); + ggml_vec_scale_f32(ne00, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < ne00; ++i) { + wp[i] += slope*mp_f32[i]; + } + } } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(ne00, &max, wp); + + ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(ne00, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < ne00; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif } } - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); - - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif } } @@ -7766,7 +7767,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; + ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; @@ -7798,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( memset(VKQ32, 0, DV*sizeof(float)); } - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL; // k indices const int ik3 = iq3 / rk3; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 086f9a56..186492be 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3327,8 +3327,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: return true; + case GGML_OP_SOFT_MAX: + // TODO: support batching + if (op->src[0]->ne[3] != 1) { + return false; + } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_SOFT_MAX_BACK: { float max_bias = 0.0f; memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); @@ -3375,6 +3382,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (op->src[0]->ne[0] == 192) { return false; } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 if (op->src[0]->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7a9aab31..8c2a9719 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -229,7 +229,9 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne32; uint64_t nb31; + uint64_t nb32; int32_t ne1; int32_t ne2; float scale; @@ -461,9 +463,21 @@ typedef struct { } ggml_metal_kargs_sum_rows; typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float scale; float max_bias; float m0; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 12a36695..1a0968ed 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1725,7 +1725,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); @@ -2644,10 +2644,7 @@ static bool ggml_metal_encode_node( memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -2707,6 +2704,18 @@ static bool ggml_metal_encode_node( /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, /*.scale =*/ scale, /*.max_bias =*/ max_bias, /*.m0 =*/ m0, @@ -2726,7 +2735,7 @@ static bool ggml_metal_encode_node( [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DIAG_MASK_INF: { @@ -4979,7 +4988,9 @@ static bool ggml_metal_encode_node( /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne32 =*/ ne32, /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, /*.ne1 =*/ ne1, /*.ne2 =*/ ne2, /*.scale =*/ scale, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index dac45c7a..dfff6697 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1320,24 +1320,28 @@ kernel void kernel_soft_max( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; // ALiBi if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1348,13 +1352,13 @@ kernel void kernel_soft_max( // parallel max float lmax = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); } // find the max value in the block float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1373,7 +1377,7 @@ kernel void kernel_soft_max( // parallel sum float lsum = 0.0f; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); lsum += exp_psrc0; pdst[i00] = exp_psrc0; @@ -1385,7 +1389,7 @@ kernel void kernel_soft_max( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1404,7 +1408,7 @@ kernel void kernel_soft_max( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { pdst[i00] *= inv_sum; } } @@ -1416,23 +1420,27 @@ kernel void kernel_soft_max_4( device char * dst, constant ggml_metal_kargs_soft_max & args, threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (args.ne02*args.ne01); - const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01; - const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); + uint3 tptg[[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x; - device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; + const int32_t i13 = i03%args.ne13; + const int32_t i12 = i02%args.ne12; + const int32_t i11 = i01; + + device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; + device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); float slope = 1.0f; if (args.max_bias > 0.0f) { - const int64_t h = i02; + const int32_t h = i02; const float base = h < args.n_head_log2 ? args.m0 : args.m1; const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; @@ -1443,14 +1451,14 @@ kernel void kernel_soft_max_4( // parallel max float4 lmax4 = -INFINITY; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = -INFINITY; } @@ -1469,7 +1477,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; @@ -1483,7 +1491,7 @@ kernel void kernel_soft_max_4( float sum = simd_sum(lsum); - if (ntg > N_SIMDWIDTH) { + if (tptg.x > N_SIMDWIDTH) { if (sgitg == 0) { buf[tiisg] = 0.0f; } @@ -1502,7 +1510,7 @@ kernel void kernel_soft_max_4( const float inv_sum = 1.0f/sum; - for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { pdst4[i00] *= inv_sum; } } @@ -3776,7 +3784,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32); const float m = pm[ic + tiisg]; @@ -4262,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*args.nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32); float slope = 1.0f; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ae5e0625..1d41d7a4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4395,9 +4395,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; - case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - return true; + // TODO: support batching + if (op->src[0]->ne[3] != 1) { + return false; + } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); + case GGML_OP_DIAG_MASK_INF: case GGML_OP_ROPE: case GGML_OP_IM2COL: return true; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 09157bf5..b8e25ba2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -410,14 +410,13 @@ struct vk_device_struct { vk_pipeline pipeline_div_norepeat[2][2][2]; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; - vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; @@ -689,37 +688,6 @@ struct vk_op_unary_push_constants { }; static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); -static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) { - GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst))); - ne = ne != 0 ? ne : ggml_nelements(dst); - GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); - - vk_op_unary_push_constants p{}; - p.ne = (uint32_t)ne; - - size_t src0_tsize = ggml_type_size(src0->type); - p.ne00 = (uint32_t)src0->ne[0]; - p.ne01 = (uint32_t)src0->ne[1]; - p.ne02 = (uint32_t)src0->ne[2]; - p.ne03 = (uint32_t)src0->ne[3]; - p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); - p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); - p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); - p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); - - size_t dst_tsize = ggml_type_size(dst->type); - p.ne10 = (uint32_t)dst->ne[0]; - p.ne11 = (uint32_t)dst->ne[1]; - p.ne12 = (uint32_t)dst->ne[2]; - p.ne13 = (uint32_t)dst->ne[3]; - p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); - p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); - p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); - p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); - - return p; // fastdiv values and offsets are initialized later in ggml_vk_op -} - // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. // Precompute mp (m' in the paper) and L such that division // can be computed using a multiply (high 32b of 64b result) @@ -881,7 +849,6 @@ struct vk_op_conv2d_dw_push_constants { struct vk_op_upscale_push_constants { uint32_t ne; uint32_t a_offset; uint32_t d_offset; - uint32_t ne00; uint32_t ne01; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; @@ -2775,9 +2742,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -2789,8 +2754,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -6453,16 +6416,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - int mode = ggml_get_op_params_i32(dst, 0); - switch (mode) { - case GGML_SCALE_MODE_NEAREST: - return ctx->device->pipeline_upscale_nearest_f32; - case GGML_SCALE_MODE_BILINEAR: - return ctx->device->pipeline_upscale_bilinear_f32; - case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: - return ctx->device->pipeline_upscale_bilinear_ac_f32; - } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { + return ctx->device->pipeline_upscale_f32; } return nullptr; case GGML_OP_SCALE: @@ -6495,11 +6450,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_pad_f32; } return nullptr; - case GGML_OP_ROLL: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_roll_f32; - } - return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { return ctx->device->pipeline_repeat_f32; @@ -7042,7 +6992,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: - case GGML_OP_ROLL: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_CPY: @@ -7479,21 +7428,14 @@ static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, co static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const uint32_t src0_type_size = ggml_type_size(src0->type); - const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); - float sf0 = (float)dst->ne[0] / src0->ne[0]; - float sf1 = (float)dst->ne[1] / src0->ne[1]; - float sf2 = (float)dst->ne[2] / src0->ne[2]; - float sf3 = (float)dst->ne[3] / src0->ne[3]; - - if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); - } + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], sf0, sf1, sf2, sf3, @@ -7501,60 +7443,117 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c } static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); - p.param1 = ggml_get_op_params_f32(dst, 0); + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); - p.param1 = ggml_get_op_params_f32(dst, 0); - p.param2 = ggml_get_op_params_f32(dst, 1); + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], op_params[1], + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); -} + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); -static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - const int32_t s0 = ggml_get_op_params_i32(dst, 0); - const int32_t s1 = ggml_get_op_params_i32(dst, 1); - const int32_t s2 = ggml_get_op_params_i32(dst, 2); - const int32_t s3 = ggml_get_op_params_i32(dst, 3); - const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); - const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000); - - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); - memcpy(&p.param1, &s01_packed, sizeof(float)); - memcpy(&p.param2, &s23_packed, sizeof(float)); - - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { @@ -7572,8 +7571,14 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const } } - vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { + ne, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -8885,7 +8890,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: - case GGML_OP_ROLL: case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -9055,10 +9059,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_PAD: ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); - break; - case GGML_OP_ROLL: - ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun); - break; case GGML_OP_CPY: case GGML_OP_CONT: @@ -9276,7 +9276,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_COS: case GGML_OP_CLAMP: case GGML_OP_PAD: - case GGML_OP_ROLL: case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: @@ -10249,6 +10248,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + if (op->src[0]->ne[3] != 1) { + return false; + } // It's straightforward to support different K/V dequant, but would // significantly increase the number of pipelines if (op->src[1]->type != op->src[2]->type) { @@ -10401,13 +10405,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_CLAMP: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_ACC: case GGML_OP_CONCAT: case GGML_OP_SCALE: case GGML_OP_PAD: - case GGML_OP_ROLL: case GGML_OP_DIAG_MASK_INF: + return true; case GGML_OP_SOFT_MAX: + // TODO: support batching + if (op->src[0]->ne[3] != 1) { + return false; + } + // TODO: support broadcast + // ref: https://github.com/ggml-org/llama.cpp/pull/14435 + return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1); case GGML_OP_SOFT_MAX_BACK: case GGML_OP_ARGSORT: case GGML_OP_SUM: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index d78c6190..6f607380 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -3,7 +3,6 @@ layout (push_constant) uniform parameter { uint ne; uint a_offset; uint d_offset; - uint ne00; uint ne01; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; @@ -16,61 +15,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; -// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag -#define NEAREST 0 -#define BILINEAR 1 -#define ALIGN_CORNERS (1 << 8) - -layout (constant_id = 0) const uint scale_mode = 0; - -float fetch_nearest(uint i10, uint i11, uint i12, uint i13) { - const uint i00 = uint(i10 / p.sf0); - const uint i01 = uint(i11 / p.sf1); - const uint i02 = uint(i12 / p.sf2); - const uint i03 = uint(i13 / p.sf3); - - return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]; -} - -float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { - const uint i02 = uint(i12 / p.sf2); - const uint i03 = uint(i13 / p.sf3); - const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02; - - const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00]; - const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00]; - const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00]; - const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00]; - - return - v00 * (1.0-d.x) * (1.0-d.y) + - v01 * d.x * (1.0-d.y) + - v10 * (1.0-d.x) * d.y + - v11 * d.x * d.y; -} - -float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { - const ivec2 ne0 = ivec2(p.ne00, p.ne01); - - const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; - const vec2 c0f = floor(c); - const vec2 d = c - c0f; - const ivec2 c0 = max(ivec2(c0f), 0); - const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1); - - return fetch_bilinear(c0, c1, d, i12, i13); -} - -float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { - const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); - const vec2 c0f = floor(c); - const vec2 d = c - c0f; - const ivec2 c0 = ivec2(c0f); - const ivec2 c1 = c0 + 1; - - return fetch_bilinear(c0, c1, d, i12, i13); -} - void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -83,18 +27,10 @@ void main() { const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; - float result; - switch (scale_mode) { - case NEAREST: - result = fetch_nearest(i10, i11, i12, i13); - break; - case BILINEAR: - result = interpolate_bilinear(i10, i11, i12, i13); - break; - case BILINEAR | ALIGN_CORNERS: - result = interpolate_bilinear_align_corners(i10, i11, i12, i13); - break; - } + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); - data_d[p.d_offset + idx] = D_TYPE(result); + data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4658ad0c..297a2a77 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -644,8 +644,6 @@ void process_shaders() { string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); - string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - for (auto &c : compiles) { c.wait(); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 814f01b7..f1245c50 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -473,14 +473,6 @@ bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b) { return memcmp(guid_a, guid_b, sizeof(ggml_guid)) == 0; } -const char * ggml_version(void) { - return GGML_VERSION; -} - -const char * ggml_commit(void) { - return GGML_COMMIT; -} - // // timing // @@ -3674,9 +3666,11 @@ static struct ggml_tensor * ggml_soft_max_impl( if (mask) { GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(ggml_is_matrix(mask)); + GGML_ASSERT(ggml_is_3d(mask)); GGML_ASSERT(mask->ne[0] == a->ne[0]); GGML_ASSERT(mask->ne[1] >= a->ne[1]); + GGML_ASSERT(a->ne[2]%mask->ne[2] == 0); + GGML_ASSERT(a->ne[3]%mask->ne[3] == 0); } if (max_bias > 0.0f) { @@ -4697,13 +4691,17 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) + GGML_ASSERT(q->ne[3] == k->ne[3]); + GGML_ASSERT(q->ne[3] == v->ne[3]); + if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[2] == 1); - GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[2] == q->ne[3]); GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + + GGML_ASSERT(q->ne[3] % mask->ne[2] == 0); } if (max_bias > 0.0f) {