From ab36d02560dc9faba360fa0c278d5ffbbb2306c7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Oct 2024 22:26:15 +0300 Subject: [PATCH] metal : support permuted matrix multiplicaions (llama/10033) * metal : support permuted matrix multiplicaions ggml-ci * cont : use nb01 directly for row steps ggml-ci * cont : add comments [no ci] * metal : minor refactor * metal : minor --- ggml/src/ggml-metal.m | 75 ++--- ggml/src/ggml-metal.metal | 578 +++++++++++++++++++++++++------------- 2 files changed, 423 insertions(+), 230 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index e9541441..80c08f15 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node( id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} +#if 0 + GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + if (src0) { + GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif id device = ctx_dev->mtl_device; @@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node( [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node( [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { @@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node( GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 71b58be1..defde624 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32( const int64_t i3 = tgpig.z; const int64_t nc = ne10; - const int64_t ncs = ne00; - const int64_t nr = ne01; - const int64_t n_t = ne1; - const int64_t n_s = ne2; + //const int64_t ncs = ne00; + //const int64_t nr = ne01; + //const int64_t n_t = ne1; + //const int64_t n_s = ne2; device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); @@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32( const int64_t i3 = tgpig.y; const int64_t nc = d_state; - const int64_t nr = d_inner; + //const int64_t nr = d_inner; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; + //const int64_t n_s = n_seqs; for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); @@ -1064,17 +1064,18 @@ kernel void kernel_group_norm( inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (sumy * -8.f + acc[0] + acc[1]); + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (sumy * -16.f + acc[0] + acc[1]); + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); } // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) @@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre float d = qb_curr->d; float m = qb_curr->m; - float2 acc = 0.f; + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); const uint32_t qh = *((device const uint32_t *)qb_curr->qh); for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); } - return d * (acc[0] + acc[1]) + sumy * m; + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; } // putting them in the kernel cause a significant performance penalty @@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, uint r3, threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; @@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; @@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl( // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; + float sumy[2] = { 0.f, 0.f }; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; +#pragma unroll + for (int i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; } +#pragma unroll for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } yb += QK4_0 * 16; @@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } @@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } float yl[NB_Q8_0]; float sumf[nr]={0.f}; @@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } - sumf[row] += sumq*x[ib+row*nb].d; + sumf[row] += sumq*ax[row][ib].d; } yb += NB_Q8_0 * nw; @@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } #define N_MV_T_T 4 @@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { @@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl( break; } - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; @@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv( nb00, nb01, nb02, + nb03, ne10, ne11, ne12, nb10, nb11, nb12, + nb13, ne0, ne1, r2, @@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float * y = (device const float *) (src1 + offset1); float sumf = 0; if (ne00 < 128) { @@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { @@ -3416,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3433,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q2_K) * nb; - const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -3492,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += step/2; - sc += step; - dh += step/2; + qs += nb01/2; + sc += nb01; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3519,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3533,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q3_K_f32_impl( @@ -3543,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3565,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float yl[32]; @@ -3608,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - const int step = sizeof(block_q3_K) * nb / 2; - device const float * y1 = yy + ix*QK_K + y_offset; uint32_t scales32, aux32; @@ -3619,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl( float sumf1[2] = {0.f}; float sumf2[2] = {0.f}; for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; @@ -3633,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl( device const half * dh = &x[i].d; for (int row = 0; row < 2; ++row) { - const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -3673,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += step; - h += step; - a += step; - dh += step; - + q += nb01/2; + h += nb01/2; + a += nb01/2; + dh += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -3706,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3720,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q4_K_f32_impl( @@ -3730,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3756,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl( const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; - const int step = sizeof(block_q4_K) * nb / 2; - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; ++i) { yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; @@ -3792,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl( device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { - sc16[0] = sc[0] & kmask1; sc16[1] = sc[2] & kmask1; sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); @@ -3821,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - sc += step; - dh += step; + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; } y4 += 4 * QK_K; @@ -3848,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -3862,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( @@ -3872,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -3894,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf[2]={0.f}; - const int step = sizeof(block_q5_K) * nb; - float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -3930,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y1 = yy + ix*QK_K + y_offset; for (int i = ix; i < nb; i += 4) { - device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * qh = x[i].qh + l0; device const half * dh = &x[i].d; @@ -3946,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl( } for (int row = 0; row < 2; ++row) { - device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -3975,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += step; - qh += step; - dh += step/2; - a += step/2; - + q1 += nb01; + qh += nb01; + dh += nb01/2; + a += nb01/2; } y1 += 4 * QK_K; - } for (int row = 0; row < 2; ++row) { @@ -4005,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4019,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( @@ -4029,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4056,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl( const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); + device const float * yy = (device const float *) ((device char *) src1 + offset1); float sumf = 0; @@ -4115,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4129,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4141,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4158,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4219,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( } sumf[row] += d * sum; - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; + dh += nb01/2; + q2 += nb01/2; } y4 += 32 * 32; @@ -4245,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4260,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_xs_f32_impl( @@ -4270,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4287,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4357,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); + dh += nb01/2; + q2 += nb01/2; + sc += nb01; } y4 += 32 * 32; @@ -4384,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4399,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_xxs_f32_impl( @@ -4409,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4426,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4489,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; + dh += nb01/2; + q3 += nb01; + gas += nb01/2; } y4 += 32 * 32; @@ -4516,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4531,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_s_f32_impl( @@ -4541,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4558,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4619,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4648,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4663,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_s_f32_impl( @@ -4673,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4690,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4752,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); + dh += nb01/2; + qs += nb01; + qh += nb01; + sc += nb01; + signs += nb01; } y4 += 32 * 32; @@ -4781,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -4796,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq1_s_f32_impl( @@ -4806,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4823,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4873,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; + dh += nb01/2; + qs += nb01; + qh += nb01/2; } y4 += 32 * 32; @@ -4896,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -4913,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl( const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4972,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); + sc += nb01/2; + qs += nb01; + qh += nb01; } y4 += 32 * 32; @@ -4995,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -5012,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 @@ -5089,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -5106,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; const uint i12 = im%ne12; const uint i13 = im/ne12; - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); + device const float * y = (device const float *) ((device char *) src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 @@ -5188,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5202,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5216,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5230,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5244,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5259,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5273,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32( constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5288,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); } //============================= templates and their specializations ============================= @@ -5833,10 +6000,12 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & ne02, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant int64_t & ne12, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, + constant uint64_t & nb13, constant int64_t & ne0, constant int64_t & ne1, constant uint & r2, @@ -5873,12 +6042,13 @@ kernel void kernel_mul_mm(device const uchar * src0, const uint i12 = im%ne12; const uint i13 = im/ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; ushort offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb12 * im + + nb13 * i13 + + nb12 * i12 + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); @@ -6257,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6277,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)( int64_t ne00, int64_t ne01, int64_t ne02, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne12, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint r2, @@ -6299,6 +6477,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6306,6 +6485,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6316,7 +6496,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); + impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); } template @@ -6330,6 +6510,7 @@ void mmv_fn( uint64_t nb00, uint64_t nb01, uint64_t nb02, + uint64_t nb03, int64_t ne10, int64_t ne11, int64_t ne12, @@ -6337,6 +6518,7 @@ void mmv_fn( uint64_t nb10, uint64_t nb11, uint64_t nb12, + uint64_t nb13, int64_t ne0, int64_t ne1, uint64_t nb1, @@ -6347,7 +6529,7 @@ void mmv_fn( uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6396,8 +6578,8 @@ kernel void kernel_mul_mv_id( const int64_t i2 = i12; device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src1_cur = src1 + i11*nb11 + i12*nb12; + device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; impl_fn( /* src0 */ src0_cur, @@ -6405,19 +6587,21 @@ kernel void kernel_mul_mv_id( /* dst */ dst_cur, /* ne00 */ ne00, /* ne01 */ ne01, - /* ne02 */ 1,//ne02, + /* ne02 */ 1, // ne02, /* nb00 */ nb00, /* nb01 */ nb01, /* nb02 */ nb02, + /* nb03 */ nb02, // ne02 == 1 /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, + /* ne11 */ 1, // ne11, + /* ne12 */ 1, // ne12, + /* ne13 */ 1, // ne13, /* nb10 */ nb10, /* nb11 */ nb11, /* nb12 */ nb12, + /* ne13 */ nb12, // ne12 == 1 /* ne0 */ ne0, - /* ne1 */ 1,//ne1, + /* ne1 */ 1, // ne1, /* nb1 */ nb1, /* r2 */ 1, /* r3 */ 1,