whisper : add ggml_mul_mat_pad

This commit is contained in:
Georgi Gerganov
2023-09-14 15:16:22 +03:00
parent e81c67a125
commit af947cb72e
5 changed files with 98 additions and 58 deletions

View File

@ -69,7 +69,6 @@ int whisper_bench_full(const whisper_params & params) {
fprintf(stderr, "error: failed to set mel: %d\n", ret);
return 3;
}
// heat encoder
if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) {
fprintf(stderr, "error: failed to encode model: %d\n", ret);

View File

@ -44,8 +44,8 @@ if [ "$encoder_only" -eq 0 ]; then
printf "\n"
fi
printf "| CPU | OS | Config | Model | Th | Enc. | Dec. | PP | Commit |\n"
printf "| --- | -- | ------ | ----- | -- | ---- | ---- | ---- | ------ |\n"
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "CPU" "OS" "Config" "Model" "Th" "Enc." "Dec." "PP" "Commit"
printf "| %6s | %6s | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "---" "---" "---" "---" "---" "---" "---" "---" "---"
for model in "${models[@]}"; do
# actual run
@ -86,6 +86,6 @@ for model in "${models[@]}"; do
commit=$(git rev-parse --short HEAD)
if [ $ret -eq 0 ]; then
printf "| <todo> | <todo> | $config | $model | $n_threads | $encode_time | $decode_time | $prompt_time | $commit |\n"
printf "| <todo> | <todo> | %12s | %9s | %3s | %7s | %7s | %7s | %7s |\n" "$config" "$model" "$n_threads" "$encode_time" "$decode_time" "$prompt_time" "$commit"
fi
done

View File

@ -878,7 +878,8 @@ void ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src1) &&
if (!ggml_is_transposed(src0) &&
!ggml_is_transposed(src1) &&
src1t == GGML_TYPE_F32 &&
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00%32 == 0 &&
@ -903,9 +904,12 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {

View File

@ -2113,22 +2113,25 @@ kernel void kernel_get_rows(
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
threadgroup half * sa = ((threadgroup half *)shared_memory);
threadgroup half * sa = (threadgroup half *)(shared_memory);
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
const uint r0 = tgpig.y;
@ -2141,7 +2144,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
simdgroup_half8x8 ma[4];
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
simdgroup_float8x8 c_res[8];
for (int i = 0; i < 8; i++){
@ -2149,10 +2152,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
}
short il = (tiitg % THREAD_PER_ROW);
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
uint offset0 = im/gqa*nb02;
ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
//load data and store to threadgroup memory
@ -2243,14 +2251,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar *, uint3, uint, uint);
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;

View File

@ -136,6 +136,22 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
ggml_graph_compute(graph, &plan);
}
// Matrix multiplication for operands whose dimension 0 is not a multiple of "pad".
//
// The original product
//
//   Z = X @ Y
//
// is rewritten as the sum of two products over views of the operands:
//
//   Z = (X_0 @ Y_0) + (X_1 @ Y_1)
//
// X_0 / Y_0 cover the first (ne[0]/pad)*pad elements of every row, so their dimension 0 is
// divisible by "pad". X_1 / Y_1 cover the remaining ne[0]%pad elements of every row.
// NOTE(review): presumably this lets the large aligned part use kernels that require an
// aligned dimension 0 (e.g. a "ne00 % 32 == 0" requirement) - confirm against the backends.
//
// ctx - context in which the view / mul_mat / add nodes are created
// x   - left operand, passed as the 2nd argument of ggml_mul_mat
// y   - right operand, passed as the 3rd argument of ggml_mul_mat
// pad - required alignment of dimension 0 (default: 32)
//
// Returns the node holding the product. Falls back to a plain ggml_mul_mat(ctx, x, y) when
// the split is not applicable (see the checks below).
static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) {
    // a non-positive pad would make the modulo / division below undefined
    if (pad <= 0) {
        return ggml_mul_mat(ctx, x, y);
    }

    // the split below is expressed with 3D views, which would drop a 4th dimension
    if (x->ne[3] != 1 || y->ne[3] != 1) {
        return ggml_mul_mat(ctx, x, y);
    }

    // already aligned, or fewer than 2 full "pad" chunks per row - nothing to gain from splitting
    if (x->ne[0] % pad == 0 || x->ne[0] / pad < 2) {
        return ggml_mul_mat(ctx, x, y);
    }

    // aligned leading part of each row, followed by the remainder that starts right after it
    struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0);
    struct ggml_tensor * x_1 = ggml_view_3d(ctx, x,  x->ne[0]%pad,      x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]);

    // same split applied to the second operand
    struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0);
    struct ggml_tensor * y_1 = ggml_view_3d(ctx, y,  y->ne[0]%pad,      y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]);

    // the partial products have the same shape as the full product, so they can be summed
    return ggml_add(ctx,
            ggml_mul_mat(ctx, x_0, y_0),
            ggml_mul_mat(ctx, x_1, y_1));
}
// available whisper models
enum e_model {
MODEL_UNKNOWN,
@ -1626,7 +1642,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
layer.attn_q_w,
cur);
@ -1635,13 +1651,13 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
//Qcur = ggml_scale(ctx0, Qcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
layer.attn_k_w,
cur);
//Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
layer.attn_v_w,
cur);
@ -1690,7 +1706,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale);
@ -1706,7 +1722,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)
);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
#endif
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@ -1717,7 +1733,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// projection
{
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.attn_ln_1_w,
cur);
@ -1747,7 +1763,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
#else
// fully connected
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.mlp_0_w,
cur);
@ -1757,7 +1773,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
cur = ggml_gelu(ctx0, cur);
// projection
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.mlp_1_w,
cur);
@ -1835,13 +1851,13 @@ static struct ggml_cgraph * whisper_build_graph_cross(
for (int il = 0; il < model.hparams.n_text_layer; ++il) {
auto & layer = model.layers_decoder[il];
struct ggml_tensor* Kcross = ggml_mul_mat(ctx0,
struct ggml_tensor* Kcross = ggml_mul_mat_pad(ctx0,
layer.cross_attn_k_w,
cur);
Kcross = ggml_scale(ctx0, Kcross, Kscale);
struct ggml_tensor* Vcross = ggml_mul_mat(ctx0,
struct ggml_tensor* Vcross = ggml_mul_mat_pad(ctx0,
layer.cross_attn_v_w,
cur);
@ -2038,7 +2054,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// self-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
layer.attn_q_w,
cur);
@ -2049,7 +2065,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
Qcur = ggml_scale(ctx0, Qcur, KQscale);
// note: no bias for Key
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Kcur = ggml_mul_mat_pad(ctx0,
layer.attn_k_w,
cur);
@ -2057,7 +2073,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// store key and value to memory
{
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Vcur = ggml_mul_mat_pad(ctx0,
layer.attn_v_w,
cur);
@ -2091,7 +2107,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_element_size(kv_self.k)*n_state*n_ctx*il);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, ggml_cont(ctx0, Q));
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, K, Q);
//struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
@ -2106,7 +2122,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
n_ctx*ggml_element_size(kv_self.v)*n_state/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_state);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@ -2117,7 +2133,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// projection
{
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.attn_ln_1_w,
cur);
@ -2143,7 +2159,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// cross-attention
{
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
struct ggml_tensor * Qcur = ggml_mul_mat_pad(ctx0,
layer.cross_attn_q_w,
cur);
@ -2186,7 +2202,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
0, 2, 1, 3);
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, ggml_cont(ctx0, Q));
struct ggml_tensor * KQ = ggml_mul_mat_pad(ctx0, Kcross, Q);
//struct ggml_tensor * KQ_scaled =
// ggml_scale(ctx0,
@ -2199,7 +2215,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV = ggml_mul_mat_pad(ctx0, V, KQ_soft_max);
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@ -2211,7 +2227,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// projection
{
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.cross_attn_ln_1_w,
cur);
@ -2240,7 +2256,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
}
// fully connected
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.mlp_0_w,
cur);
@ -2252,7 +2268,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
cur = ggml_gelu(ctx0, cur);
// projection
cur = ggml_mul_mat(ctx0,
cur = ggml_mul_mat_pad(ctx0,
layer.mlp_1_w,
cur);
@ -2282,7 +2298,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
// might be useful in the future
cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);
struct ggml_tensor * logits = ggml_mul_mat_pad(ctx0, model.d_te, cur);
ggml_build_forward_expand(gf, logits);