mirror of https://github.com/ggerganov/whisper.cpp.git

metal : decoder works on GPU!
ggml-metal.m (14 changed lines)
@@ -368,6 +368,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
+        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
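This lookup resolves a graph tensor to one of the Metal buffers registered with the context: the tensor's host pointer is tested against each mapped region, and a match yields the buffer index plus the byte offset to hand to the encoder. A minimal CPU-side sketch of the same interval test (the struct and function names here are illustrative, not from the source):

    #include <cstdint>
    #include <cstddef>

    struct buffer_t {
        void * data; // host address of the mapped region
        size_t size; // region size in bytes
    };

    // Return the index of the buffer containing [ptr, ptr + tsize), or -1.
    // Mirrors the interval check in ggml_metal_get_buffer.
    int find_buffer(const buffer_t * bufs, int n, const void * ptr, size_t tsize, size_t * offs) {
        for (int i = 0; i < n; ++i) {
            const int64_t ioffs = (int64_t) ptr - (int64_t) bufs[i].data;
            if (ioffs >= 0 && ioffs + (int64_t) tsize <= (int64_t) bufs[i].size) {
                *offs = (size_t) ioffs;
                return i;
            }
        }
        return -1;
    }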
@@ -701,6 +702,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_ADD:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);
@@ -708,6 +710,7 @@ void ggml_metal_graph_compute(
 
                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
+                            GGML_ASSERT(ne11 == 1);
                             [encoder setComputePipelineState:ctx->pipeline_add_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_add];
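The dispatch logic for GGML_OP_ADD (and GGML_OP_MUL in the next hunks) selects between two kernels: when src1 holds exactly one row (ggml_nelements(src1) == ne10), the *_row pipeline broadcasts that row over every row of src0; otherwise the plain element-wise pipeline runs. The new asserts make those layout assumptions explicit. A rough C++ analogue of the row-broadcast path (a sketch for clarity; the real kernels are Metal and operate on float4 vectors):

    #include <cstdint>

    // Broadcast-add one row of length ne00 over an [ne00 x nrows] matrix.
    // CPU analogue of the pipeline_add_row / pipeline_mul_row kernels.
    void add_row(const float * src0, const float * src1, float * dst,
                 int64_t ne00, int64_t nrows) {
        for (int64_t r = 0; r < nrows; ++r) {
            for (int64_t i = 0; i < ne00; ++i) {
                dst[r*ne00 + i] = src0[r*ne00 + i] + src1[i]; // src1 index wraps per row
            }
        }
    }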
@@ -724,6 +727,7 @@ void ggml_metal_graph_compute(
                 case GGML_OP_MUL:
                     {
                         GGML_ASSERT(ggml_is_contiguous(src0));
+                        GGML_ASSERT(ggml_is_contiguous(src1));
 
                         // utilize float4
                         GGML_ASSERT(ne00 % 4 == 0);
@@ -731,6 +735,7 @@ void ggml_metal_graph_compute(
 
                         if (ggml_nelements(src1) == ne10) {
                             // src1 is a row
+                            GGML_ASSERT(ne11 == 1);
                             [encoder setComputePipelineState:ctx->pipeline_mul_row];
                         } else {
                             [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -746,6 +751,8 @@ void ggml_metal_graph_compute(
                     } break;
                 case GGML_OP_SCALE:
                     {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
                         const float scale = *(const float *) src1->data;
 
                         [encoder setComputePipelineState:ctx->pipeline_scale];
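Note that for GGML_OP_SCALE the factor is not a GPU buffer at all: the host dereferences the 1-element src1 tensor and passes the float to the kernel by value via setBytes. Host-side, such a node is built roughly like this inside a graph-building function (a hedged sketch assuming the ggml API of this period, where the scale factor travels as a 1-element tensor):

    // the Metal backend later reads this constant with *(const float *) src1->data
    struct ggml_tensor * KQ_scaled =
        ggml_scale(ctx0,
                   KQ,
                   ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_state)/n_head)));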
@@ -805,7 +812,6 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                        [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
@@ -1022,9 +1028,9 @@ void ggml_metal_graph_compute(
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                        [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                        [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
 
                         const int64_t n = ggml_nelements(src1);
ggml-metal.metal

@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
-       threadgroup float  * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
@@ -119,58 +118,23 @@ kernel void kernel_soft_max(
    device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
    // parallel max
-   buf[tpitg[0]] = -INFINITY;
-   for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-       buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+   float lmax = psrc0[tpitg[0]];
+   for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+       lmax = MAX(lmax, psrc0[i00]);
    }
 
-   // reduce
-   threadgroup_barrier(mem_flags::mem_threadgroup);
-   for (uint i = ntg[0]/2; i > 0; i /= 2) {
-       if (tpitg[0] < i) {
-           buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-       }
-       threadgroup_barrier(mem_flags::mem_threadgroup);
-   }
-
-   //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-   // the loop, and when that is done, buf[0] has the correct (synchronized) value
-   //if (tpitg[0] == 0) {
-   //    buf[0] = buf[0];
-   //}
-
-   //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-   const float max = buf[0];
+   const float max = simd_max(lmax);
 
    // parallel sum
-   buf[tpitg[0]] = 0.0f;
+   float lsum = 0.0f;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        const float exp_psrc0 = exp(psrc0[i00] - max);
-       buf[tpitg[0]] += exp_psrc0;
+       lsum += exp_psrc0;
        // Remember the result of exp here. exp is expensive, so we really do not
        // wish to compute it twice.
        pdst[i00] = exp_psrc0;
    }
 
-   // reduce
-   threadgroup_barrier(mem_flags::mem_threadgroup);
-   for (uint i = ntg[0]/2; i > 0; i /= 2) {
-       if (tpitg[0] < i) {
-           buf[tpitg[0]] += buf[tpitg[0] + i];
-       }
-       threadgroup_barrier(mem_flags::mem_threadgroup);
-   }
-
-   // broadcast - not needed, see above
-   //// broadcast
-   //if (tpitg[0] == 0) {
-   //    buf[0] = buf[0];
-   //}
-
-   //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-   const float sum = buf[0];
+   const float sum = simd_sum(lsum);
 
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        pdst[i00] /= sum;
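The rewrite keeps the numerically stable formulation — softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)) — but replaces the two threadgroup tree reductions, with their barriers and shared scratch buffer, by simd_max and simd_sum, which reduce across a SIMD group entirely in registers. That is also why the threadgroup buffer argument and the host-side setThreadgroupMemoryLength call could be deleted. A reference C++ version of what one row's worth of threads computes (a sketch for clarity, not the Metal source):

    #include <cmath>
    #include <cstdint>
    #include <algorithm>

    // Numerically stable softmax over one row, matching kernel_soft_max:
    // 1) find the row max, 2) accumulate exp(x - max) while caching it, 3) normalize.
    void soft_max_row(const float * src, float * dst, int64_t ne00) {
        float max = src[0];
        for (int64_t i = 1; i < ne00; ++i) max = std::max(max, src[i]);

        float sum = 0.0f;
        for (int64_t i = 0; i < ne00; ++i) {
            const float e = std::exp(src[i] - max); // cached: exp is expensive
            dst[i] = e;
            sum += e;
        }
        for (int64_t i = 0; i < ne00; ++i) dst[i] /= sum;
    }

One caveat worth keeping in mind: simd_max and simd_sum reduce only within a single SIMD group, so this form assumes the threadgroup handling a row does not span multiple SIMD groups; otherwise a second reduction step would still be needed.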
@@ -195,22 +159,6 @@ kernel void kernel_diag_mask_inf(
    }
 }
 
-kernel void kernel_get_rows_f32(
-       device const float * src0,
-       device const int   * src1,
-       device       float * dst,
-       constant   int64_t & ne00,
-       constant  uint64_t & nb01,
-       constant  uint64_t & nb1,
-       uint tpig[[thread_position_in_grid]]) {
-   const int i = tpig;
-   const int r = ((device int32_t *) src1)[i];
-
-   for (int j = 0; j < ne00; j++) {
-       dst[i*nb1 + j] = ((device float *) ((device char *) src0 + r*nb01))[j];
-   }
-}
-
 kernel void kernel_norm(
        device const void  * src0,
        device       float * dst,
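The hand-written f32 kernel can go because the templated kernel_get_rows (instantiated further down in this diff) now covers f32 alongside f16 and the quantized types. Its semantics are simply a gather by row index: dst[i, :] = src0[src1[i], :]. A plain C++ rendering (a sketch; src0's row stride is in bytes as in the kernel, and dst is assumed contiguous here):

    #include <cstdint>
    #include <cstring>

    // dst[i, :] = src0[src1[i], :] for i in [0, nrows)
    void get_rows_f32(const float * src0, const int32_t * src1, float * dst,
                      int64_t ne00, uint64_t nb01 /* bytes per src0 row */, int64_t nrows) {
        for (int64_t i = 0; i < nrows; ++i) {
            const float * row = (const float *)((const char *) src0 + src1[i]*nb01);
            memcpy(dst + i*ne00, row, ne00*sizeof(float));
        }
    }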
@@ -1763,6 +1711,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -2111,6 +2068,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
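This new f32 specialization is what makes the deleted kernel_get_rows_f32 redundant: one kernel template, parameterized over a block type, a chunks-per-block count, and a dequantize functor, serves every type — and for f32 the "dequantization" is just a copy, as the NOTE in the hunk above says. A simplified C++ sketch of the pattern (the real template is a Metal kernel; the block layout and the names nl and F follow the typedef above, but the signatures here are illustrative):

    #include <cstdint>

    // Generic gather: each block of type T expands to 16*nl output floats via F.
    template <typename T, short nl, void (*F)(const T *, short, float (&)[16])>
    void get_rows(const void * src0, const int32_t * src1, float * dst,
                  int64_t ne00, uint64_t nb01, int64_t nrows) {
        float reg[16];
        for (int64_t i = 0; i < nrows; ++i) {
            const T * row = (const T *)((const char *) src0 + src1[i]*nb01);
            for (int64_t b = 0; b < ne00/(16*nl); ++b) {   // one T yields 16*nl floats
                for (short il = 0; il < nl; ++il) {
                    F(&row[b], il, reg);                   // decode 16 floats at a time
                    for (int k = 0; k < 16; ++k) {
                        dst[i*ne00 + (b*nl + il)*16 + k] = reg[k];
                    }
                }
            }
        }
    }

    // f32 "dequantization" is a plain copy (il is unused since nl == 1)
    struct f32x16 { float v[16]; };
    inline void dequantize_f32(const f32x16 * src, short il, float (&reg)[16]) {
        for (int k = 0; k < 16; ++k) reg[k] = src->v[k];
    }

Instantiating get_rows<f32x16, 1, dequantize_f32> mirrors the kernel_get_rows<float4x4, 1, dequantize_f32> specialization added in the hunk.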
whisper.cpp (43 changed lines)
@@ -2031,17 +2031,15 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
             struct ggml_tensor * K =
-                ggml_permute(ctx0,
-                        ggml_reshape_3d(ctx0,
-                            ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state),
-                            n_state/n_head, n_head, n_past + N),
-                        0, 2, 1, 3);
+                ggml_view_3d(ctx0, kv_self.k,
+                        n_state/n_head, n_past + N, n_head,
+                        ggml_element_size(kv_self.k)*n_state,
+                        ggml_element_size(kv_self.k)*n_state/n_head,
+                        ggml_element_size(kv_self.k)*n_state*n_ctx*il);
 
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
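Two simplifications here and in the hunks below. For Q, the ggml_cpy into a freshly allocated tensor becomes a zero-copy ggml_reshape_3d. For K, the view_1d + reshape_3d + permute chain collapses into a single ggml_view_3d whose strides directly encode the permuted layout: dimension 0 is the within-head dimension (contiguous), dimension 1 advances by a full row of n_state elements (one token position), and dimension 2 advances by n_state/n_head elements (one head). The offset arithmetic from the hunk, spelled out (a worked sketch of the strides, not source code):

    #include <cstdint>
    #include <cstddef>

    // Byte offset of element (i0, i1, i2) = (dim-within-head, position, head)
    // inside layer il of the K cache, per the view_3d strides above.
    size_t k_cache_offset(int64_t i0, int64_t i1, int64_t i2,
                          int64_t il, int64_t n_state, int64_t n_head, int64_t n_ctx,
                          size_t es /* element size in bytes */) {
        return es*n_state*n_ctx*il      // start of this layer's keys (view offset)
             + i2*es*(n_state/n_head)   // stride per head     (nb2)
             + i1*es*n_state            // stride per position (nb1)
             + i0*es;                   // contiguous head dim (nb0)
    }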
@@ -2108,9 +2106,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
-                ggml_reshape_3d(ctx0,
-                        ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state),
-                        n_state/n_head, n_head, M);
+                ggml_view_3d(ctx0, wstate.kv_cross.k,
+                        n_state/n_head, M, n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state,
+                        ggml_element_size(wstate.kv_cross.k)*n_state/n_head,
+                        ggml_element_size(wstate.kv_cross.k)*n_state*M*il);
 
             //struct ggml_tensor * Vcross =
             //    ggml_reshape_3d(ctx0,
@@ -2133,15 +2133,11 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_cpy(ctx0,
-                            Qcur,
-                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state/n_head, n_head, N)),
+                        ggml_reshape_3d(ctx0, Qcur, n_state/n_head, n_head, N),
                         0, 2, 1, 3);
 
-            struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
-
             // K * Q
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q);
 
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
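With Kcross already laid out as [n_state/n_head, M, n_head] by the view above, the explicit ggml_permute is a no-op and is dropped: ggml_mul_mat contracts over dimension 0 and batches over the trailing dimensions, so it can consume Kcross directly. A shape walkthrough under ggml's convention (dimension 0 is the contracted one; a sketch, not source code):

    // A: Kcross = [n_state/n_head, M, n_head]   (dim0 = head dimension)
    // B: Q      = [n_state/n_head, N, n_head]   (after reshape + permute)
    // ggml_mul_mat(A, B) contracts dim0 and batches dim2:
    //   KQ     = [M, N, n_head]
    // i.e. one attention-score matrix of M cross positions x N new tokens per head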
@@ -2288,8 +2284,7 @@ static bool whisper_decode_internal(
     logits = gf->nodes[gf->n_nodes - 1];
 
 #ifdef GGML_USE_METAL
-    if (wstate.ctx_metal && n_tokens == 1) {
-        // TODO: fix for non-1 tokens
+    if (wstate.ctx_metal) {
         ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
         ggml_metal_graph_compute(wstate.ctx_metal, gf);
         ggml_metal_get_tensor   (wstate.ctx_metal, logits);
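Dropping the n_tokens == 1 guard is the heart of the commit title: the decoder graph now runs on the GPU for batched tokens too, not only single-token decoding. For context, the surrounding control flow looks roughly like this (a sketch; the CPU fallback branch is assumed from the usual whisper.cpp pattern and is not part of this diff):

    #ifdef GGML_USE_METAL
        if (wstate.ctx_metal) {
            ggml_metal_set_n_cb     (wstate.ctx_metal, n_threads);
            ggml_metal_graph_compute(wstate.ctx_metal, gf);
            ggml_metal_get_tensor   (wstate.ctx_metal, logits); // copy logits back to host
        } else
    #endif
        {
            // CPU path (assumed shape of the fallback)
            ggml_graph_compute_helper(wstate.work_buffer, gf, n_threads);
        }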
@@ -2302,8 +2297,8 @@ static bool whisper_decode_internal(
     }
 
     // extract logits for all N tokens
-    //logits_out.resize(N*n_vocab);
-    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab);
+    //logits_out.resize(n_tokens*n_vocab);
+    //memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
 
     // extract logits only for the last token
     logits_out.resize(n_vocab);
@@ -2794,9 +2789,7 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
-        log("checkpoint 0\n");
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        log("checkpoint 1\n");
 
         ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
 
@@ -2818,15 +2811,11 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
         meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
-        log("checkpoint 2\n");
         alloc = ggml_allocr_new_measure(tensor_alignment);
-        log("checkpoint 3\n");
 
         ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
-        log("checkpoint 4\n");
 
         const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
-        log("checkpoint 5\n");
 
         ggml_allocr_free(alloc);
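The deleted checkpoint logs were debug scaffolding around ggml's measure-allocator pattern, which remains intact: create a measuring allocator, run graph allocation once to learn the peak size, then free it and allocate a real buffer of that size. A condensed sketch of the pattern (assumes the ggml-alloc API used in this file; the buffer name is illustrative):

    // pass 1: measure — no real memory is assigned, only sizes are tracked
    ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    ggml_cgraph * gf    = whisper_build_graph_cross(*ctx, *state);
    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
    ggml_allocr_free(alloc);

    // pass 2: allocate for real, using the measured size
    std::vector<uint8_t> buf(alloc_size);
    alloc = ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);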