diff --git a/ggml-metal.m b/ggml-metal.m
index 58d62911..b21564db 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -368,6 +368,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
+        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
@@ -701,6 +702,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ADD:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);
@@ -708,6 +710,7 @@ void ggml_metal_graph_compute(
 
                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
+                            GGML_ASSERT(ne11 == 1);
                             [encoder setComputePipelineState:ctx->pipeline_add_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_add];
@@ -724,6 +727,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);
@@ -731,6 +735,7 @@ void ggml_metal_graph_compute(
 
                        if (ggml_nelements(src1) == ne10) {
                            // src1 is a row
+                           GGML_ASSERT(ne11 == 1);
                            [encoder setComputePipelineState:ctx->pipeline_mul_row];
                        } else {
                            [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -746,6 +751,8 @@ void ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SCALE:
                     {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                         const float scale = *(const float *) src1->data;
 
                         [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -805,7 +812,6 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
@@ -1022,9 +1028,9 @@ void ggml_metal_graph_compute(
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                        [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                        [encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
 
                         const int64_t n = ggml_nelements(src1);
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 8cf59f4e..484cecb9 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
-        threadgroup float * buf [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,58 +118,23 @@ kernel void kernel_soft_max(
     device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    // broadcast - not needed, see above
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         pdst[i00] /= sum;
@@ -195,22 +159,6 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
-kernel void kernel_get_rows_f32(
-        device const float * src0,
-        device const int * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb1,
-        uint tpig[[thread_position_in_grid]]) {
-    const int i = tpig;
-    const int r = ((device int32_t *) src1)[i];
-
-    for (int j = 0; j < ne00; j++) {
-        dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
-    }
-}
-
 kernel void kernel_norm(
         device const void * src0,
         device float * dst,
@@ -1763,6 +1711,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -2111,6 +2068,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
diff --git a/whisper.cpp b/whisper.cpp
index a608de0b..d8e7f370 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -2031,17 +2031,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        0, 2, 1, 3);
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_state/n_head, n_past + N, n_head,
+                        ggml_element_size(kv_self.k)*n_state,
+                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
@@ -2108,9 +2106,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
-                        n_state/n_head, n_head, M);
+                ggml_view_3d(ctx0, wstate.kv_cross.k,
+                        n_state/n_head, M, n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state,
+                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
@@ -2133,15 +2133,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
 
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -2288,8 +2284,7 @@ static bool whisper_decode_internal(
         logits = gf->nodes[gf->n_nodes - 1];
 
 #ifdef GGML_USE_METAL
-        if (wstate.ctx_metal && n_tokens == 1) {
-            // TODO: fix for non-1 tokens
+        if (wstate.ctx_metal) {
             ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
             ggml_metal_graph_compute(wstate.ctx_metal, gf);
             ggml_metal_get_tensor   (wstate.ctx_metal, logits);
@@ -2302,8 +2297,8 @@ static bool whisper_decode_internal(
     }
 
     // extract logits for all N tokens
-    //logits_out.resize(N*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+    //logits_out.resize(n_tokens*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
@@ -2794,9 +2789,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
-        log("checkpoint 0\n");
        alloc = ggml_allocr_new_measure(tensor_alignment);
-        log("checkpoint 1\n");
 
         ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
@@ -2818,15 +2811,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
-        log("checkpoint 2\n");
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        log("checkpoint 3\n");
 
         ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
-        log("checkpoint 4\n");
 
         const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-        log("checkpoint 5\n");
 
         ggml_allocr_free(alloc);
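
Note on the kernel_soft_max change above: the threadgroup-memory tree reduction (buf + threadgroup_barrier) is replaced by per-thread partial results combined with simd_max/simd_sum, which reduce across a single SIMD-group; this presumably relies on the threadgroup size not exceeding the SIMD width. Numerically the kernel still computes the standard max-subtracted softmax per row. The scalar C++ sketch below is a reference illustration only, not part of the patch; the helper name soft_max_row_ref is made up for this example.

#include <cmath>

// Reference softmax over one row of n values: find the row max, subtract it
// before exponentiating so exp() cannot overflow, cache the exp results, then
// normalize by their sum - the same three passes the Metal kernel performs.
static void soft_max_row_ref(const float * src, float * dst, int n) {
    float max = -INFINITY;
    for (int i = 0; i < n; ++i) {
        max = std::fmax(max, src[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = std::exp(src[i] - max); // remember the exp result, exp is expensive
        sum += dst[i];
    }
    for (int i = 0; i < n; ++i) {
        dst[i] /= sum;
    }
}

In the kernel each thread strides over the row (i00 = tpitg[0], tpitg[0] + ntg[0], ...), keeping a local lmax/lsum, and the SIMD-group reductions then produce the shared max and sum without any threadgroup memory or barriers.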