From 1e5544b39bf766e2d1af29f7eb8459e65848b32b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 2 Jan 2024 10:57:44 +0200 Subject: [PATCH] metal : enable shader debugging (cmake option) (llama/4705) * ggml : disable fast-math for Metal (cmake build only) ggml-ci * metal : fix Metal API debug warnings * cmake : add -fno-inline for Metal build (llama/4545) * metal : fix API debug warnings * metal : fix compile warnings * metal : use uint64_t for strides * cmake : rename option to LLAMA_METAL_SHADER_DEBUG * metal : fix mat-vec Q8_0 kernel for BS > 1 * metal : normalize mat-vec kernel signatures * cmake : respect LLAMA_QKK_64 option * metal : fix mat-vec Q4_K kernel for QK_K == 64 ggml-ci --- ggml-metal.m | 28 ++- ggml-metal.metal | 475 ++++++++++++++++++++++++++--------------------- 2 files changed, 284 insertions(+), 219 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 51a72ae3..cd9d0045 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -257,13 +257,14 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; #endif NSError * error = nil; - NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; + NSString * libPath = [bundle pathForResource:@"ggml" ofType:@"metallib"]; if (libPath != nil) { + // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; } else { - GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + GGML_METAL_LOG_INFO("%s: ggml.metallib not found, loading from source\n", __func__); NSString * sourcePath; NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; @@ -291,6 +292,13 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { options = [MTLCompileOptions new]; options.preprocessorMacros = @{ @"QK_K" : @(64) }; #endif + // try to disable fast-math + // NOTE: this seems to have no effect whatsoever + // instead, in order to disable fast-math, we have to build ggml.metallib from the command line + // using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air + // and go through the "pre-compiled library found" path above + //[options setFastMathEnabled:false]; + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; } @@ -1230,7 +1238,7 @@ void ggml_metal_graph_compute( // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel - const int nth = MIN(1024, ne00); + const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00); [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1285,7 +1293,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - const int nth = MIN(1024, ne0); + const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00); [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -1785,8 +1793,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; // TODO: how to make this an array? read Metal docs - for (int j = 0; j < n_as; ++j) { - struct ggml_tensor * src_cur = dst->src[2 + j]; + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); @@ -1909,8 +1918,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&r3 length:sizeof(r3) atIndex:21]; [encoder setBytes:&idx length:sizeof(idx) atIndex:22]; // TODO: how to make this an array? read Metal docs - for (int j = 0; j < n_as; ++j) { - struct ggml_tensor * src_cur = dst->src[2 + j]; + for (int j = 0; j < 8; ++j) { + // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8 + struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)]; size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); @@ -2229,7 +2239,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; [encoder setBytes:&sf length:sizeof(sf) atIndex:18]; - const int nth = MIN(1024, ne0); + const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0); [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index d5b54e11..1d5b8f6f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -59,26 +59,26 @@ kernel void kernel_add( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, constant int64_t & offs, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], @@ -109,26 +109,26 @@ kernel void kernel_mul( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -158,26 +158,26 @@ kernel void kernel_div( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -205,7 +205,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -214,7 +214,7 @@ kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig % nb]; } @@ -223,7 +223,7 @@ kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb [[buffer(28)]], + constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -307,26 +307,26 @@ kernel void kernel_sum_rows( constant int64_t & ne01, constant int64_t & ne02, constant int64_t & ne03, - constant int64_t & nb00, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & nb03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & nb13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant int64_t & ne2, constant int64_t & ne3, - constant int64_t & nb0, - constant int64_t & nb1, - constant int64_t & nb2, - constant int64_t & nb3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, uint3 tpig[[thread_position_in_grid]]) { int64_t i3 = tpig.z; int64_t i2 = tpig.y; @@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1071,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne10, + constant int64_t & ne11, constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1182,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1209,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1346,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1452,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); @@ -1478,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1543,7 +1578,8 @@ kernel void kernel_alibi_f32( const int64_t i3 = n / (ne2*ne1*ne0); const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t k = i3*ne3 + i2; float m_k; @@ -2410,22 +2446,6 @@ typedef struct { } block_q6_K; // 210 bytes / block -static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { - uchar4 r; - if (j < 4) { - r[0] = q[j+0] & 63; - r[2] = q[j+1] & 63; - r[1] = q[j+4] & 63; - r[3] = q[j+5] & 63; - } else { - r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); - } - return r; -} - //====================================== dot products ========================= void kernel_mul_mv_q2_K_f32_impl( @@ -2584,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2841,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2984,8 +3018,8 @@ void kernel_mul_mv_q4_K_f32_impl( constant uint & r2, constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int ix = tiisg/4; // 0...7 const int it = tiisg%4; // 0...3 @@ -2994,7 +3028,7 @@ void kernel_mul_mv_q4_K_f32_impl( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; const int ib_row = first_row * nb; const uint i12 = im%ne12; @@ -3060,7 +3094,7 @@ void kernel_mul_mv_q4_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -3072,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3271,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3398,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0 [[buffer(15)]], - constant int64_t & ne1 [[buffer(16)]], - constant uint & r2 [[buffer(17)]], - constant uint & r3 [[buffer(18)]], + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -3523,7 +3578,7 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg device const int8_t * qs = ((device const int8_t *)xb->qs); const half d = xb->d; - for (int i=0;i<16;i++) { + for (int i = 0; i < 16; i++) { reg[i/4][i%4] = (qs[i + 16*il] * d); } } @@ -3792,12 +3847,12 @@ void kernel_mul_mm_impl(device const uchar * src0, device float * dst, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3924,12 +3979,12 @@ kernel void kernel_mul_mm(device const uchar * src0, device float * dst, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3965,19 +4020,19 @@ kernel void kernel_mul_mm_id( device const uchar * ids, device const uchar * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4070,12 +4125,12 @@ typedef void (mat_mm_t)( device float * dst, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4104,19 +4159,19 @@ typedef void (mat_mm_id_t)( device const uchar * ids, device const uchar * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, + constant uint64_t & nb01, + constant uint64_t & nb02, constant int64_t & ne12, constant int64_t & ne13, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4153,7 +4208,7 @@ kernel void kernel_mul_mv_id_f32_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4169,7 +4224,7 @@ kernel void kernel_mul_mv_id_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4222,7 +4277,7 @@ kernel void kernel_mul_mv_id_f16_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4238,7 +4293,7 @@ kernel void kernel_mul_mv_id_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4291,7 +4346,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4307,7 +4362,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4354,7 +4409,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4370,7 +4425,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4417,7 +4472,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4433,7 +4488,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4480,7 +4535,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4496,7 +4551,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4543,7 +4598,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4559,7 +4614,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4606,7 +4661,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4622,7 +4677,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4669,7 +4724,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4685,7 +4740,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4732,7 +4787,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4748,7 +4803,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4795,7 +4850,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4811,7 +4866,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -4858,7 +4913,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( device const char * ids, device const char * src1, device uchar * dst, - constant int64_t & nbi1, + constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -4874,7 +4929,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant int64_t & nb1, + constant uint64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx,