metal : reduce command encoding overhead (llama/9698)

This commit is contained in:
Georgi Gerganov 2024-10-02 15:12:16 +03:00
parent ff2cb0811f
commit 162a455402
2 changed files with 1999 additions and 1893 deletions

View File

@ -25,9 +25,6 @@
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
struct ggml_tensor;
struct ggml_cgraph;
@ -48,8 +45,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);

View File

@ -12,6 +12,12 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64
// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
#ifdef GGML_METAL_NDEBUG
#define GGML_METAL_LOG(...)
#define GGML_METAL_LOG_INFO(...)
@ -221,11 +227,11 @@ enum ggml_metal_kernel_type {
};
struct ggml_backend_metal_context {
int n_cb;
id<MTLDevice> device;
id<MTLCommandQueue> queue;
MTLComputePassDescriptor * edesc;
dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
@ -233,7 +239,27 @@ struct ggml_backend_metal_context {
bool support_simdgroup_reduction;
bool support_simdgroup_mm;
bool should_capture_next_compute;
// capture state
bool capture_next_compute;
bool capture_started;
id<MTLCaptureScope> capture_scope;
// command buffer state
int n_cb; // number of extra threads used to submit the command buffers
int n_nodes_0; // number of nodes submitted by the main thread
int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
int n_nodes_per_cb;
struct ggml_cgraph * gf;
// the callback given to the thread pool
// TODO: ideally, this should be created once, utilizing the command buffer state above
// for some reason, doing it like this leads to a crash
void (^encode_async)(size_t ith);
// n_cb command buffers + 1 used by the main thread
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
@ -303,7 +329,7 @@ static void * ggml_metal_host_malloc(size_t n) {
return data;
}
static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
static struct ggml_backend_metal_context * ggml_metal_init(void) {
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@ -322,8 +348,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
// Configure context
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
ctx->device = device;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
ctx->queue = [ctx->device newCommandQueue];
ctx->edesc = MTLComputePassDescriptor.computePassDescriptor;
ctx->edesc.dispatchType = MTLDispatchTypeSerial;
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library;
@ -455,7 +482,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
ctx->should_capture_next_compute = false;
ctx->capture_next_compute = false;
ctx->capture_started = false;
ctx->capture_scope = nil;
ctx->gf = nil;
ctx->encode_async = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->command_buffers[i] = nil;
}
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
@ -686,6 +721,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
}
[metal_library release];
return ctx;
}
@ -874,78 +910,23 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
}
}
static enum ggml_status ggml_metal_graph_compute(
static void ggml_metal_encode_node(
struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
int idx,
id<MTLComputeCommandEncoder> encoder) {
struct ggml_cgraph * gf = ctx->gf;
@autoreleasepool {
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
edesc.dispatchType = MTLDispatchTypeSerial;
struct ggml_tensor * node = ggml_graph_node(gf, idx);
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
const int n_nodes = gf->n_nodes;
const int n_cb = ctx->n_cb;
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
const bool should_capture = ctx->should_capture_next_compute;
if (should_capture) {
ctx->should_capture_next_compute = false;
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->queue;
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
GGML_ABORT("capture failed");
}
}
id<MTLCommandBuffer> command_buffer_builder[n_cb];
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
}
}
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
const int cb_idx = iter;
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
for (int i = node_start; i < node_end; ++i) {
if (i == -1) {
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * dst = gf->nodes[i];
struct ggml_tensor * src0 = node->src[0];
struct ggml_tensor * src1 = node->src[1];
struct ggml_tensor * src2 = node->src[2];
struct ggml_tensor * dst = node;
if (ggml_is_empty(dst)) {
continue;
return;
}
switch (dst->op) {
@ -956,7 +937,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_PERMUTE:
{
// noop -> next node
} continue;
} return;
default:
{
} break;
@ -967,10 +948,6 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ABORT("unsupported op");
}
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
}
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0;
@ -1015,6 +992,11 @@ static enum ggml_status ggml_metal_graph_compute(
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
@ -1039,7 +1021,7 @@ static enum ggml_status ggml_metal_graph_compute(
{
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
const int32_t dim = ((int32_t *) dst->op_params)[0];
const int32_t dim = ((const int32_t *) dst->op_params)[0];
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1203,12 +1185,12 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
const size_t offs = ((int32_t *) dst->op_params)[3];
const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
const size_t offs = ((const int32_t *) dst->op_params)[3];
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
@ -1309,8 +1291,8 @@ static enum ggml_status ggml_metal_graph_compute(
float min;
float max;
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1323,7 +1305,7 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
switch (ggml_get_unary_op(node)) {
// we are not taking into account the strides, so for now require contiguous tensors
GGML_ASSERT(ggml_is_contiguous(src0));
@ -1422,7 +1404,7 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
GGML_ABORT("fatal error");
}
} break;
@ -1551,8 +1533,8 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
@ -1585,7 +1567,7 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_DIAG_MASK_INF:
{
const int n_past = ((int32_t *)(dst->op_params))[0];
const int n_past = ((const int32_t *)(dst->op_params))[0];
id<MTLComputePipelineState> pipeline = nil;
@ -1644,9 +1626,9 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_SSM_SCAN:
{
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
struct ggml_tensor * src4 = gf->nodes[i]->src[4];
struct ggml_tensor * src5 = gf->nodes[i]->src[5];
struct ggml_tensor * src3 = node->src[3];
struct ggml_tensor * src4 = node->src[4];
struct ggml_tensor * src5 = node->src[5];
GGML_ASSERT(src3);
GGML_ASSERT(src4);
@ -2425,7 +2407,7 @@ static enum ggml_status ggml_metal_graph_compute(
float eps;
memcpy(&eps, dst->op_params + 1, sizeof(float));
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
int nth = 32; // SIMD width
@ -2479,11 +2461,11 @@ static enum ggml_status ggml_metal_graph_compute(
const int nth = MIN(1024, ne00);
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_past = ((const int32_t *) dst->op_params)[0];
const int n_dims = ((const int32_t *) dst->op_params)[1];
const int mode = ((const int32_t *) dst->op_params)[2];
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
float freq_base;
float freq_scale;
@ -2492,12 +2474,12 @@ static enum ggml_status ggml_metal_graph_compute(
float beta_fast;
float beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@ -2686,8 +2668,8 @@ static enum ggml_status ggml_metal_graph_compute(
float start;
float step;
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
@ -2786,7 +2768,7 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ggml_are_same_shape (src1, src2));
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
struct ggml_tensor * src3 = node->src[3];
size_t offs_src3 = 0;
@ -2811,9 +2793,9 @@ static enum ggml_status ggml_metal_graph_compute(
float scale;
float max_bias;
float logit_softcap;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
@ -3014,10 +2996,86 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
default:
{
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
GGML_ABORT("fatal error");
}
}
}
static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_context * ctx,
struct ggml_cgraph * gf) {
// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
// number of threads in addition to the main thread
const int n_cb = ctx->n_cb;
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
// each thread creates it's own command buffer and enqueues the ops in parallel
//
// tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
@autoreleasepool {
ctx->gf = gf;
ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
const bool should_capture = ctx->capture_next_compute;
if (should_capture) {
ctx->capture_next_compute = false;
if (!ctx->capture_started) {
// create capture scope
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
descriptor.captureObject = ctx->capture_scope;
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
NSError * error = nil;
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
} else {
[ctx->capture_scope beginScope];
ctx->capture_started = true;
}
}
}
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
ctx->encode_async = ^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
const int n_nodes_0 = ctx->n_nodes_0;
const int n_nodes_1 = ctx->n_nodes_1;
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: ctx->edesc];
int node_start = 0;
int node_end = n_nodes_0;
if (cb_idx < n_cb_l) {
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
}
for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
}
ggml_metal_encode_node(ctx, idx, encoder);
if (should_capture) {
[encoder popDebugGroup];
@ -3029,13 +3087,52 @@ static enum ggml_status ggml_metal_graph_compute(
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer commit];
}
});
};
// Wait for completion and check status of each command buffer
// the main thread commits the first few commands immediately
// command_buffer[n_cb]
{
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[n_cb] = command_buffer;
[command_buffer enqueue];
ctx->encode_async(n_cb);
}
// prepare the rest of the command buffers asynchronously
// command_buffer[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
ctx->command_buffers[cb_idx] = command_buffer;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[command_buffer enqueue];
}
}
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
}
}
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> command_buffer = command_buffers[i];
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
[command_buffer waitUntilCompleted];
MTLCommandBufferStatus status = [command_buffer status];
@ -3048,12 +3145,12 @@ static enum ggml_status ggml_metal_graph_compute(
return GGML_STATUS_FAILED;
}
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
if (!next_buffer) {
continue;
}
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
}
@ -3066,11 +3163,12 @@ static enum ggml_status ggml_metal_graph_compute(
[next_buffer commit];
}
if (should_capture) {
if (!should_capture && ctx->capture_started) {
[ctx->capture_scope endScope];
[[MTLCaptureManager sharedCaptureManager] stopCapture];
}
}
return GGML_STATUS_SUCCESS;
}
@ -3405,6 +3503,25 @@ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, g
UNUSED(backend);
}
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
if (ctx->n_cb != n_cb) {
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
if (ctx->n_cb > 2) {
GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
}
}
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
//ctx->encode_async = ^(size_t iter) {
// ...
//};
}
static struct ggml_backend_i ggml_backend_metal_i = {
/* .get_name = */ ggml_backend_metal_name,
/* .free = */ ggml_backend_metal_free,
@ -3439,35 +3556,29 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
}
ggml_backend_t ggml_backend_metal_init(void) {
struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
struct ggml_backend_metal_context * ctx = ggml_metal_init();
if (ctx == NULL) {
GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
return NULL;
}
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
*metal_backend = (struct ggml_backend) {
*backend = (struct ggml_backend) {
/* .guid = */ ggml_backend_metal_guid(),
/* .interface = */ ggml_backend_metal_i,
/* .context = */ ctx,
};
return metal_backend;
ggml_backend_metal_set_n_cb(backend, 1);
return backend;
}
bool ggml_backend_is_metal(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
}
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
GGML_ASSERT(ggml_backend_is_metal(backend));
@ -3489,7 +3600,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
GGML_ASSERT(ggml_backend_is_metal(backend));
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
ctx->should_capture_next_compute = true;
ctx->capture_next_compute = true;
}
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning