From af947cb72e0268f970a8a45407a92d8c34ce4c0c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 14 Sep 2023 15:16:22 +0300
Subject: [PATCH] whisper : add ggml_mul_mat_pad

---
 examples/bench/bench.cpp |  1 -
 extra/bench-all.sh       |  6 ++--
 ggml-metal.m             | 12 ++++---
 ggml-metal.metal         | 75 +++++++++++++++++++++++++---------------
 whisper.cpp              | 62 +++++++++++++++++++++------------
 5 files changed, 98 insertions(+), 58 deletions(-)

diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp
index f728348e..1a2f0393 100644
--- a/examples/bench/bench.cpp
+++ b/examples/bench/bench.cpp
@@ -69,7 +69,6 @@ int whisper_bench_full(const whisper_params & params) {
         fprintf(stderr, "error: failed to set mel: %d\n", ret);
         return 3;
     }
-
     // heat encoder
     if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
         fprintf(stderr, "error: failed to encode model: %d\n", ret);
diff --git a/extra/bench-all.sh b/extra/bench-all.sh
index 98c8cfd6..352a2235 100755
--- a/extra/bench-all.sh
+++ b/extra/bench-all.sh
@@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
     printf "\n"
 fi

-printf "| CPU | OS | Config | Model | Th | Enc. | Dec. | PP | Commit |\n"
-printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ---- | ------ |\n"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
+printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"

 for model in "${models[@]}"; do
     # actual run
@@ -86,6 +86,6 @@ for model in "${models[@]}"; do
     commit=$(git rev-parse --short HEAD)

     if [ $ret -eq 0 ]; then
-        printf "| | | $config | $model | $n_threads | $encode_time | $decode_time | $prompt_time | $commit |\n"
+        printf "| | | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
     fi
 done
diff --git a/ggml-metal.m b/ggml-metal.m
index 7f587aab..059da6ee 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -878,7 +878,8 @@ void ggml_metal_graph_compute(

                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if (ggml_is_contiguous(src1) &&
+                            if (!ggml_is_transposed(src0) &&
+                                !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
@@ -903,9 +904,12 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
                                 [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
                                 [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:8];
-                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:9];
-                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:10];
+                                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:13];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index b9ea8f60..0db037c1 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2113,22 +2113,25 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const uchar * src0,
-                           device const float * src1,
-                           device float * dst,
-                           constant int64_t & ne00,
-                           constant int64_t & ne02,
-                           constant int64_t & nb01,
-                           constant int64_t & nb02,
-                           constant int64_t & ne12,
-                           constant int64_t & ne0,
-                           constant int64_t & ne1,
-                           constant uint & gqa,
-                           threadgroup uchar * shared_memory [[threadgroup(0)]],
-                           uint3 tgpig[[threadgroup_position_in_grid]],
-                           uint tiitg[[thread_index_in_threadgroup]],
-                           uint sgitg[[simdgroup_index_in_threadgroup]]) {
+                          device const uchar * src1,
+                          device float * dst,
+                          constant int64_t & ne00,
+                          constant int64_t & ne02,
+                          constant int64_t & nb01,
+                          constant int64_t & nb02,
+                          constant int64_t & ne12,
+                          constant int64_t & nb10,
+                          constant int64_t & nb11,
+                          constant int64_t & nb12,
+                          constant int64_t & ne0,
+                          constant int64_t & ne1,
+                          constant uint & gqa,
+                          threadgroup uchar * shared_memory [[threadgroup(0)]],
+                          uint3 tgpig[[threadgroup_position_in_grid]],
+                          uint tiitg[[thread_index_in_threadgroup]],
+                          uint sgitg[[simdgroup_index_in_threadgroup]]) {

-    threadgroup half * sa = ((threadgroup half *)shared_memory);
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

     const uint r0 = tgpig.y;
@@ -2141,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

-    simdgroup_half8x8 ma[4];
+    simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -2149,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }

     short il = (tiitg % THREAD_PER_ROW);
-    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
-                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    uint   offset0 = im/gqa*nb02;
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
@@ -2243,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
-                        constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
-                        constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
+typedef void (mat_mm_t)(
+        device const uchar * src0,
+        device const uchar * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne02,
+        constant int64_t & nb01,
+        constant int64_t & nb02,
+        constant int64_t & ne12,
+        constant int64_t & nb10,
+        constant int64_t & nb11,
+        constant int64_t & nb12,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant uint & gqa,
+        threadgroup uchar *, uint3, uint, uint);

-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
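Note on the kernel change above (editorial commentary, not part of the patch): kernel_mul_mm previously reconstructed the address of an element of src1 from ne00 alone, which only works when src1 is a contiguous float array, hence the old ggml_is_contiguous(src1) guard in ggml-metal.m. Passing the byte strides nb10/nb11/nb12 instead lets the same kernel walk any 3-D view of src1. The stand-alone C++ sketch below is illustrative only (view3d and at are invented names, not ggml or Metal API); it mirrors the kernel's src1 + nb12*im + nb11*row + nb10*col computation and shows why a narrowed view still addresses correctly: only the extents shrink, while the strides stay those of the parent tensor.

    #include <cstdint>
    #include <cstdio>

    // Illustrative stand-in for ggml's stride convention: nb{0,1,2} is the
    // byte stride of dimension {0,1,2}, so a view can describe
    // non-contiguous data without copying it.
    struct view3d {
        const uint8_t * data;   // base pointer in bytes, as the kernel sees src1
        int64_t ne0, ne1, ne2;  // extents
        int64_t nb0, nb1, nb2;  // byte strides (nb10/nb11/nb12 in the kernel)
    };

    // Address of element (i0, i1, i2), mirroring the kernel's
    // (device const float *)(src1 + nb12*im + nb11*row + nb10*col).
    static const float * at(const view3d & v, int64_t i0, int64_t i1, int64_t i2) {
        return (const float *)(v.data + i2*v.nb2 + i1*v.nb1 + i0*v.nb0);
    }

    int main() {
        // a contiguous 8 x 2 x 1 float tensor
        float buf[16];
        for (int i = 0; i < 16; ++i) buf[i] = (float) i;

        const view3d full = { (const uint8_t *) buf, 8, 2, 1,
                              sizeof(float), 8*sizeof(float), 16*sizeof(float) };

        // a view of the first 5 columns: same strides, smaller ne0, so the
        // rows of the view are no longer back-to-back in memory
        view3d sub = full;
        sub.ne0 = 5;

        printf("%.0f %.0f\n", *at(sub, 4, 0, 0), *at(sub, 4, 1, 0)); // prints: 4 12
        return 0;
    }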
diff --git a/whisper.cpp b/whisper.cpp
index b3f1dac0..9e699046 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -136,6 +136,22 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
     ggml_graph_compute(graph, &plan);
 }

+static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
+    if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) {
+        return ggml_mul_mat(ctx, x, y);
+    }
+
+    struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
+    struct ggml_tensor * x_1 = ggml_view_3d(ctx, x,  x->ne[0]%pad,      x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);
+
+    struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
+    struct ggml_tensor * y_1 = ggml_view_3d(ctx, y,  y->ne[0]%pad,      y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);
+
+    return ggml_add(ctx,
+            ggml_mul_mat(ctx, x_0, y_0),
+            ggml_mul_mat(ctx, x_1, y_1));
+}
+
 // available whisper models
 enum e_model {
     MODEL_UNKNOWN,
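The helper added in the hunk above rests on a simple identity (not stated in the patch, but easy to verify): a dot product over k terms equals the dot product over the first k0 terms plus the one over the remaining k - k0, applied to every row/column pair of the two operands. ggml_mul_mat_pad picks k0 = (ne[0]/pad)*pad so the first product has a width divisible by pad (32, matching the ne00%32 == 0 requirement of the Metal mat-mat kernel). A stand-alone C++ check of the identity, with no ggml dependency; k = 1500 here is the Whisper encoder's audio context, an example of a width that is not a multiple of 32:

    #include <cassert>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // dot product of a and b over the index range [lo, hi)
    static double dot(const std::vector<double> & a, const std::vector<double> & b, int lo, int hi) {
        double s = 0.0;
        for (int i = lo; i < hi; ++i) s += a[i]*b[i];
        return s;
    }

    int main() {
        const int k   = 1500;         // e.g. the Whisper encoder audio context
        const int pad = 32;
        const int k0  = (k/pad)*pad;  // 1472, the part the fast kernel can take

        std::vector<double> a(k), b(k);
        for (int i = 0; i < k; ++i) { a[i] = std::sin((double) i); b[i] = std::cos((double) i); }

        const double full  = dot(a, b, 0, k);
        const double split = dot(a, b, 0, k0) + dot(a, b, k0, k);

        assert(std::fabs(full - split) < 1e-9);
        printf("k0 = %d, remainder = %d\n", k0, k - k0);
        return 0;
    }

ggml_mul_mat_pad performs the same decomposition with views instead of copies: x_0/y_0 cover the first (ne[0]/pad)*pad elements of each row, x_1/y_1 cover the remainder starting at byte offset x_0->ne[0]*x_0->nb[0], and ggml_add recombines the two partial products.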
@@ -1626,7 +1642,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
                     layer.attn_q_w,
                     cur);

@@ -1635,13 +1651,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
             //Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
                     layer.attn_k_w,
                     cur);

             //Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

-            struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
                     layer.attn_v_w,
                     cur);

@@ -1690,7 +1706,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         0, 2, 1, 3);

             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
+            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);

             struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);

@@ -1706,7 +1722,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                         ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
                         );

-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
 #endif

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -1717,7 +1733,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

         // projection
         {
-            cur = ggml_mul_mat(ctx0,
+            cur = ggml_mul_mat_pad(ctx0,
                     layer.attn_ln_1_w,
                     cur);

@@ -1747,7 +1763,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
                 layer.mlp_0_w, layer.mlp_0_b,
                 layer.mlp_1_w, layer.mlp_1_b);
 #else
         // fully connected
-        cur = ggml_mul_mat(ctx0,
+        cur = ggml_mul_mat_pad(ctx0,
                 layer.mlp_0_w,
                 cur);
@@ -1757,7 +1773,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
         cur = ggml_gelu(ctx0, cur);

         // projection
-        cur = ggml_mul_mat(ctx0,
+        cur = ggml_mul_mat_pad(ctx0,
                 layer.mlp_1_w,
                 cur);

@@ -1835,13 +1851,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
     for (int il = 0; il < model.hparams.n_text_layer; ++il) {
         auto & layer = model.layers_decoder[il];

-        struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor* Kcross = ggml_mul_mat_pad(ctx0,
                 layer.cross_attn_k_w,
                 cur);

         Kcross = ggml_scale(ctx0, Kcross, Kscale);

-        struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
+        struct ggml_tensor* Vcross = ggml_mul_mat_pad(ctx0,
                 layer.cross_attn_v_w,
                 cur);

@@ -2038,7 +2054,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

         // self-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
                     layer.attn_q_w,
                     cur);

@@ -2049,7 +2065,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
             Qcur = ggml_scale(ctx0, Qcur, KQscale);

             // note: no bias for Key
-            struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
                     layer.attn_k_w,
                     cur);

@@ -2057,7 +2073,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

             // store key and value to memory
             {
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
+                struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
                         layer.attn_v_w,
                         cur);

@@ -2091,7 +2107,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         ggml_element_size(kv_self.k)*n_state*n_ctx*il);

             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
+            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);

             //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);

@@ -2106,7 +2122,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
                         il*n_ctx*ggml_element_size(kv_self.v)*n_state);

-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@@ -2117,7 +2133,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

         // projection
         {
-            cur = ggml_mul_mat(ctx0,
+            cur = ggml_mul_mat_pad(ctx0,
                     layer.attn_ln_1_w,
                     cur);

@@ -2143,7 +2159,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

         // cross-attention
         {
-            struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
+            struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
                     layer.cross_attn_q_w,
                     cur);

@@ -2186,7 +2202,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
                         0, 2, 1, 3);

             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, ggml_cont(ctx0, Q));
+            struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, Kcross, Q);

             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
@@ -2199,7 +2215,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);

-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);

             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

@@ -2211,7 +2227,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

         // projection
         {
-            cur = ggml_mul_mat(ctx0,
+            cur = ggml_mul_mat_pad(ctx0,
                     layer.cross_attn_ln_1_w,
                     cur);

@@ -2240,7 +2256,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         }

         // fully connected
-        cur = ggml_mul_mat(ctx0,
+        cur = ggml_mul_mat_pad(ctx0,
                 layer.mlp_0_w,
                 cur);

@@ -2252,7 +2268,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
         cur = ggml_gelu(ctx0, cur);

         // projection
-        cur = ggml_mul_mat(ctx0,
+        cur = ggml_mul_mat_pad(ctx0,
                 layer.mlp_1_w,
                 cur);

@@ -2282,7 +2298,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     // might be useful in the future
     cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);

-    struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
+    struct ggml_tensor * logits = ggml_mul_mat_pad(ctx0, model.d_te, cur);

     ggml_build_forward_expand(gf, logits);
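A closing note on when the split actually triggers (an observation about the code above, not additional patch content): the early-out in ggml_mul_mat_pad returns a plain ggml_mul_mat when the inner dimension is already a multiple of pad, or when it is smaller than 2*pad and the aligned part would be too small to pay for the extra ggml_add. A tiny stand-alone mirror of that predicate in C++:

    #include <cstdio>

    // mirrors: if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) use plain mul_mat
    static bool use_padded_path(long ne0, long pad = 32) {
        return ne0 % pad != 0 && ne0 / pad >= 2;
    }

    int main() {
        const long widths[] = { 64, 70, 1500, 1536, 33 };
        for (long w : widths) {
            printf("ne0 = %4ld -> %s\n", w,
                   use_padded_path(w) ? "split (padded) mul_mat" : "plain mul_mat");
        }
        return 0;
    }

For the encoder this means, for example, 1500 = 46*32 + 28: the 1472-wide part of the product runs on the simdgroup matrix-matrix kernel enabled in ggml-metal.m above, and only the 28-wide remainder takes the slower path, which is presumably where the encoder speed-up of this patch comes from.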