metal : single allocation of encode_async block (llama/9747)

* Single allocation of encode_async block with non-ARC capture in ggml-metal.m

* Moving Block_release to the deallocation code

* Release encode block when re-setting encoding buffer count if needed

* Update ggml/src/ggml-metal.m

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Paul Tsochantaris 2024-10-07 13:26:31 +01:00 committed by Georgi Gerganov
parent 8f9bdca4c4
commit 80753d4da8

View File

@ -239,8 +239,6 @@ struct ggml_backend_metal_context {
struct ggml_cgraph * gf;
// the callback given to the thread pool
// TODO: ideally, this should be created once, utilizing the command buffer state above
// for some reason, doing it like this leads to a crash
void (^encode_async)(size_t ith);
// n_cb command buffers + 1 used by the main thread
@ -683,6 +681,8 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->kernels[i].pipeline release];
}
Block_release(ctx->encode_async);
[ctx->queue release];
[ctx->device release];
@ -3000,46 +3000,6 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
ctx->encode_async = ^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
const int n_nodes_0 = ctx->n_nodes_0;
const int n_nodes_1 = ctx->n_nodes_1;
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
int node_start = 0;
int node_end = n_nodes_0;
if (cb_idx < n_cb_l) {
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
}
for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(ctx, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
}
}
[encoder endEncoding];
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
}
};
// the main thread commits the first few commands immediately
// command_buffer[n_cb]
{
@ -3129,7 +3089,7 @@ static enum ggml_status ggml_metal_graph_compute(
// default buffer
static id<MTLDevice> g_backend_device = nil;
static int g_backend_device_ref_count = 0; // TODO: make thread-safe
static int g_backend_device_ref_count = 0;
static id<MTLDevice> ggml_backend_metal_get_device(void) {
if (g_backend_device == nil) {
@ -3468,10 +3428,50 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
}
}
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
//ctx->encode_async = ^(size_t iter) {
// ...
//};
if (ctx->encode_async) {
Block_release(ctx->encode_async);
}
ctx->encode_async = Block_copy(^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
const int n_nodes_0 = ctx->n_nodes_0;
const int n_nodes_1 = ctx->n_nodes_1;
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
int node_start = 0;
int node_end = n_nodes_0;
if (cb_idx < n_cb_l) {
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
}
const bool should_capture = ctx->capture_next_compute;
for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(ctx, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
}
}
[encoder endEncoding];
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
}
});
}
static struct ggml_backend_i ggml_backend_metal_i = {