mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-05-17 14:41:10 +02:00
metal : optimize MoE for large batches (llama/13388)
This commit is contained in:
parent
029c8837f8
commit
41ed62bdbc
@ -299,21 +299,42 @@ typedef struct {
|
|||||||
} ggml_metal_kargs_mul_mv_ext;
|
} ggml_metal_kargs_mul_mv_ext;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t nei0;
|
int32_t ne10;
|
||||||
int32_t nei1;
|
int32_t ne11; // n_expert_used (bcast)
|
||||||
uint64_t nbi1;
|
uint64_t nb11;
|
||||||
|
uint64_t nb12;
|
||||||
|
int32_t neh11; // n_tokens
|
||||||
|
uint64_t nbh11;
|
||||||
|
int32_t ne20; // n_expert_used
|
||||||
|
uint64_t nb21;
|
||||||
|
} ggml_metal_kargs_mul_mm_id_map0;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne20; // n_expert_used
|
||||||
|
int32_t neh0;
|
||||||
|
int32_t neh1;
|
||||||
|
uint64_t nbh1;
|
||||||
|
uint64_t nbh2;
|
||||||
|
int32_t ne0;
|
||||||
|
uint64_t nb1;
|
||||||
|
uint64_t nb2;
|
||||||
|
} ggml_metal_kargs_mul_mm_id_map1;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
int32_t ne00;
|
int32_t ne00;
|
||||||
int32_t ne02;
|
int32_t ne02;
|
||||||
uint64_t nb01;
|
uint64_t nb01;
|
||||||
uint64_t nb02;
|
uint64_t nb02;
|
||||||
int32_t ne11;
|
uint64_t nb03;
|
||||||
int32_t ne12;
|
int32_t neh12;
|
||||||
int32_t ne13;
|
uint64_t nbh10;
|
||||||
uint64_t nb10;
|
uint64_t nbh11;
|
||||||
uint64_t nb11;
|
uint64_t nbh12;
|
||||||
uint64_t nb12;
|
uint64_t nbh13;
|
||||||
int32_t ne0;
|
int32_t neh0;
|
||||||
int32_t ne1;
|
int32_t neh1;
|
||||||
|
int16_t r2;
|
||||||
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mm_id;
|
} ggml_metal_kargs_mul_mm_id;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -306,28 +306,30 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||||
@ -490,7 +492,264 @@ enum ggml_metal_kernel_type {
|
|||||||
GGML_METAL_KERNEL_TYPE_COUNT
|
GGML_METAL_KERNEL_TYPE_COUNT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//
|
||||||
|
// ggml_metal_heap
|
||||||
|
//
|
||||||
|
|
||||||
|
struct ggml_metal_heap {
|
||||||
|
// number of times the heap was unused
|
||||||
|
int n_unused;
|
||||||
|
|
||||||
|
// total number of buffer allocations in this heap across all computes
|
||||||
|
int64_t n_alloc;
|
||||||
|
|
||||||
|
// current offset in the heap - we reset this after each node in order to reuse the memory
|
||||||
|
size_t offs;
|
||||||
|
|
||||||
|
// the currently allocated MTLBuffer objects in this heap
|
||||||
|
id<MTLHeap> obj;
|
||||||
|
|
||||||
|
NSMutableArray * bufs;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
|
||||||
|
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
|
||||||
|
|
||||||
|
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
|
||||||
|
desc.storageMode = MTLStorageModePrivate;
|
||||||
|
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
|
||||||
|
desc.type = MTLHeapTypePlacement;
|
||||||
|
desc.size = size;
|
||||||
|
|
||||||
|
heap->n_unused = 0;
|
||||||
|
heap->n_alloc = 0;
|
||||||
|
|
||||||
|
heap->obj = [device newHeapWithDescriptor:desc];
|
||||||
|
if (!heap->obj) {
|
||||||
|
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
|
||||||
|
|
||||||
|
free(heap);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
[desc release];
|
||||||
|
|
||||||
|
heap->bufs = [[NSMutableArray alloc] init];
|
||||||
|
|
||||||
|
return heap;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
|
||||||
|
heap->offs = 0;
|
||||||
|
|
||||||
|
// count how many graph computes the heap ended up being unused
|
||||||
|
if ([heap->bufs count] > 0) {
|
||||||
|
heap->n_unused = 0;
|
||||||
|
} else {
|
||||||
|
heap->n_unused++;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (id<MTLBuffer> buf in heap->bufs) {
|
||||||
|
[buf release];
|
||||||
|
}
|
||||||
|
[heap->bufs removeAllObjects];
|
||||||
|
|
||||||
|
// tell the OS that it can reuse this memory if needed
|
||||||
|
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
||||||
|
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
|
||||||
|
if (heap == nil) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_heap_reset(heap);
|
||||||
|
|
||||||
|
[heap->obj release];
|
||||||
|
[heap->bufs release];
|
||||||
|
|
||||||
|
free(heap);
|
||||||
|
}
|
||||||
|
|
||||||
|
@interface ggml_metal_heap_ptr : NSObject
|
||||||
|
|
||||||
|
@property (nonatomic, assign) struct ggml_metal_heap * data;
|
||||||
|
|
||||||
|
@end
|
||||||
|
|
||||||
|
@implementation ggml_metal_heap_ptr
|
||||||
|
@end
|
||||||
|
|
||||||
|
//
|
||||||
|
// ggml_metal_mem_pool
|
||||||
|
//
|
||||||
|
|
||||||
|
struct ggml_metal_mem_pool {
|
||||||
|
id<MTLDevice> device;
|
||||||
|
|
||||||
|
int n_heaps; // total number of heaps ever created (including those that were removed)
|
||||||
|
|
||||||
|
NSMutableArray * heaps;
|
||||||
|
NSMutableArray * heaps_to_remove;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
|
||||||
|
struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
|
||||||
|
|
||||||
|
mem_pool->n_heaps = 0;
|
||||||
|
|
||||||
|
mem_pool->heaps = [[NSMutableArray alloc] init];
|
||||||
|
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
|
||||||
|
|
||||||
|
return mem_pool;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
|
||||||
|
GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
|
||||||
|
|
||||||
|
size_t size_all = 0;
|
||||||
|
size_t size_cur = 0;
|
||||||
|
|
||||||
|
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||||
|
GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
|
||||||
|
GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
|
||||||
|
GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
|
||||||
|
GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
|
||||||
|
GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
|
||||||
|
|
||||||
|
if ([ptr.data->bufs count] > 0) {
|
||||||
|
size_cur += [ptr.data->obj size];
|
||||||
|
}
|
||||||
|
size_all += [ptr.data->obj size];
|
||||||
|
|
||||||
|
ggml_metal_heap_free(ptr.data);
|
||||||
|
[ptr release];
|
||||||
|
}
|
||||||
|
[mem_pool->heaps release];
|
||||||
|
[mem_pool->heaps_to_remove release];
|
||||||
|
|
||||||
|
if (size_all > 0) {
|
||||||
|
GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
|
||||||
|
GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
free(mem_pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
|
||||||
|
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
|
||||||
|
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
|
||||||
|
|
||||||
|
struct ggml_metal_heap * heap = ptr.data;
|
||||||
|
ggml_metal_heap_reset(heap);
|
||||||
|
|
||||||
|
// if the heap hasn't been used for a while, remove it
|
||||||
|
if (heap->n_unused >= 128) {
|
||||||
|
[mem_pool->heaps_to_remove addObject:@(i)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mem_pool->heaps_to_remove.count > 0) {
|
||||||
|
// remove in reverse order
|
||||||
|
for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
|
||||||
|
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
|
||||||
|
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
|
||||||
|
|
||||||
|
struct ggml_metal_heap * heap = ptr.data;
|
||||||
|
ggml_metal_heap_free(heap);
|
||||||
|
|
||||||
|
[mem_pool->heaps removeObjectAtIndex:index];
|
||||||
|
[ptr release];
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[mem_pool->heaps_to_remove removeAllObjects];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
|
||||||
|
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||||
|
ptr.data->offs = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
|
||||||
|
const size_t alignment = 256;
|
||||||
|
|
||||||
|
const size_t size_aligned = GGML_PAD(size, alignment);
|
||||||
|
|
||||||
|
// try one of the existing heaps
|
||||||
|
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
|
||||||
|
struct ggml_metal_heap * heap = ptr.data;
|
||||||
|
if (heap->offs + size_aligned <= [heap->obj size]) {
|
||||||
|
// if this is the first buffer in the heap for the current command buffer, tell the OS that
|
||||||
|
// it cannot free the memory used by the heap
|
||||||
|
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
|
||||||
|
if ([heap->bufs count] == 0) {
|
||||||
|
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
||||||
|
}
|
||||||
|
|
||||||
|
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
||||||
|
if (buf == nil) {
|
||||||
|
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
||||||
|
return nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
heap->n_alloc++;
|
||||||
|
heap->offs += size_aligned;
|
||||||
|
|
||||||
|
[heap->bufs addObject:buf];
|
||||||
|
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a new heap that can fit this buffer
|
||||||
|
ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
|
||||||
|
|
||||||
|
struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
|
||||||
|
if (heap == NULL) {
|
||||||
|
GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
//GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
|
||||||
|
|
||||||
|
heap_ptr.data = heap;
|
||||||
|
ggml_metal_heap_reset(heap);
|
||||||
|
|
||||||
|
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
|
||||||
|
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
|
||||||
|
if (buf == nil) {
|
||||||
|
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
heap->n_alloc++;
|
||||||
|
heap->offs += size_aligned;
|
||||||
|
|
||||||
|
[heap->bufs addObject:buf];
|
||||||
|
|
||||||
|
[mem_pool->heaps addObject:heap_ptr];
|
||||||
|
mem_pool->n_heaps++;
|
||||||
|
|
||||||
|
return buf;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_metal_command_buffer {
|
||||||
|
id<MTLCommandBuffer> obj;
|
||||||
|
|
||||||
|
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
|
||||||
|
struct ggml_metal_mem_pool * mem_pool;
|
||||||
|
};
|
||||||
|
|
||||||
struct ggml_backend_metal_context {
|
struct ggml_backend_metal_context {
|
||||||
|
id<MTLDevice> device;
|
||||||
id<MTLCommandQueue> queue;
|
id<MTLCommandQueue> queue;
|
||||||
|
|
||||||
dispatch_queue_t d_queue;
|
dispatch_queue_t d_queue;
|
||||||
@ -515,7 +774,7 @@ struct ggml_backend_metal_context {
|
|||||||
void (^encode_async)(size_t ith);
|
void (^encode_async)(size_t ith);
|
||||||
|
|
||||||
// n_cb command buffers + 1 used by the main thread
|
// n_cb command buffers + 1 used by the main thread
|
||||||
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
||||||
|
|
||||||
// abort ggml_metal_graph_compute if callback returns true
|
// abort ggml_metal_graph_compute if callback returns true
|
||||||
ggml_abort_callback abort_callback;
|
ggml_abort_callback abort_callback;
|
||||||
@ -705,8 +964,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
struct ggml_backend_metal_device_context * ctx_dev = dev->context;
|
||||||
|
|
||||||
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
|
||||||
|
|
||||||
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
|
||||||
|
|
||||||
|
ctx->device = device;
|
||||||
ctx->queue = [device newCommandQueue];
|
ctx->queue = [device newCommandQueue];
|
||||||
if (ctx->queue == nil) {
|
if (ctx->queue == nil) {
|
||||||
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
|
||||||
@ -768,7 +1029,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
ctx->gf = nil;
|
ctx->gf = nil;
|
||||||
ctx->encode_async = nil;
|
ctx->encode_async = nil;
|
||||||
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||||
ctx->command_buffers[i] = nil;
|
ctx->cmd_bufs[i].obj = nil;
|
||||||
|
|
||||||
|
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
|
||||||
|
ctx->cmd_bufs[i].mem_pool->device = device;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
||||||
@ -985,28 +1249,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||||
@ -1181,6 +1447,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
|||||||
|
|
||||||
[ctx->queue release];
|
[ctx->queue release];
|
||||||
|
|
||||||
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||||
|
// ctx->cmd_bufs[i].obj is auto released
|
||||||
|
|
||||||
|
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
|
||||||
|
}
|
||||||
|
|
||||||
dispatch_release(ctx->d_queue);
|
dispatch_release(ctx->d_queue);
|
||||||
|
|
||||||
free(ctx);
|
free(ctx);
|
||||||
@ -1486,10 +1758,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_metal_encode_node(
|
static bool ggml_metal_encode_node(
|
||||||
ggml_backend_t backend,
|
ggml_backend_t backend,
|
||||||
int idx,
|
int idx,
|
||||||
id<MTLComputeCommandEncoder> encoder) {
|
id<MTLComputeCommandEncoder> encoder,
|
||||||
|
struct ggml_metal_mem_pool * mem_pool) {
|
||||||
struct ggml_backend_metal_context * ctx = backend->context;
|
struct ggml_backend_metal_context * ctx = backend->context;
|
||||||
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
|
||||||
|
|
||||||
@ -1505,7 +1778,7 @@ static void ggml_metal_encode_node(
|
|||||||
struct ggml_tensor * dst = node;
|
struct ggml_tensor * dst = node;
|
||||||
|
|
||||||
if (ggml_is_empty(dst)) {
|
if (ggml_is_empty(dst)) {
|
||||||
return;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
@ -1516,7 +1789,7 @@ static void ggml_metal_encode_node(
|
|||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
{
|
{
|
||||||
// noop -> next node
|
// noop -> next node
|
||||||
} return;
|
} return true;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
} break;
|
} break;
|
||||||
@ -1527,6 +1800,8 @@ static void ggml_metal_encode_node(
|
|||||||
GGML_ABORT("unsupported op");
|
GGML_ABORT("unsupported op");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_mem_pool_clear(mem_pool);
|
||||||
|
|
||||||
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||||
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
||||||
@ -2173,6 +2448,56 @@ static void ggml_metal_encode_node(
|
|||||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||||
|
|
||||||
|
// use this branch to test the ggml_metal_mem_pool functionality
|
||||||
|
#if 0
|
||||||
|
// cpy to tmp buffer in MTLHeap
|
||||||
|
|
||||||
|
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
|
||||||
|
if (!h_src0) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
offs_src0 = 0;
|
||||||
|
|
||||||
|
ggml_metal_kargs_cpy args_cpy = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.ne03 =*/ ne03,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne0 =*/ ne00,
|
||||||
|
/*.ne1 =*/ ne01,
|
||||||
|
/*.ne2 =*/ ne02,
|
||||||
|
/*.ne3 =*/ ne03,
|
||||||
|
/*.nb0 =*/ nb00,
|
||||||
|
/*.nb1 =*/ nb01,
|
||||||
|
/*.nb2 =*/ nb02,
|
||||||
|
/*.nb3 =*/ nb03,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
|
||||||
|
} else {
|
||||||
|
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
|
||||||
|
}
|
||||||
|
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
[encoder setBuffer:h_src0 offset:0 atIndex:2];
|
||||||
|
|
||||||
|
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||||
|
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
|
||||||
|
|
||||||
|
#else
|
||||||
|
id<MTLBuffer> h_src0 = id_src0;
|
||||||
|
#endif
|
||||||
|
// softmax
|
||||||
|
|
||||||
ggml_metal_kargs_soft_max args = {
|
ggml_metal_kargs_soft_max args = {
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne01 =*/ ne01,
|
/*.ne01 =*/ ne01,
|
||||||
@ -2185,11 +2510,11 @@ static void ggml_metal_encode_node(
|
|||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
|
||||||
if (id_src1) {
|
if (id_src1) {
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||||
@ -2683,7 +3008,7 @@ static void ggml_metal_encode_node(
|
|||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
} else {
|
} else {
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
@ -2903,8 +3228,6 @@ static void ggml_metal_encode_node(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
const int n_as = src0->ne[2];
|
|
||||||
|
|
||||||
// src2 = ids
|
// src2 = ids
|
||||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||||
|
|
||||||
@ -2918,24 +3241,21 @@ static void ggml_metal_encode_node(
|
|||||||
GGML_ASSERT(ne03 == 1);
|
GGML_ASSERT(ne03 == 1);
|
||||||
GGML_ASSERT(ne13 == 1);
|
GGML_ASSERT(ne13 == 1);
|
||||||
|
|
||||||
|
const uint32_t r2 = 1;
|
||||||
|
const uint32_t r3 = 1;
|
||||||
|
|
||||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
// ne20 = n_used_experts
|
// ne20 = n_used_experts
|
||||||
// ne21 = n_rows
|
// ne21 = n_rows (batch size)
|
||||||
const int dst_rows = ne20*ne21;
|
const int ne21_mm_id_min = 32;
|
||||||
const int dst_rows_min = n_as;
|
|
||||||
const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
|
|
||||||
|
|
||||||
// max size of the rowids array in the kernel shared buffer
|
|
||||||
//GGML_ASSERT(dst_rows <= dst_rows_max);
|
|
||||||
|
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||||
//ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
|
(ne21 >= ne21_mm_id_min)) {
|
||||||
dst_rows > dst_rows_min &&
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
dst_rows <= dst_rows_max) {
|
|
||||||
|
|
||||||
// some Metal matrix data types require aligned pointers
|
// some Metal matrix data types require aligned pointers
|
||||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||||
@ -2946,62 +3266,169 @@ static void ggml_metal_encode_node(
|
|||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int64_t neh10 = ne10; // n_embd
|
||||||
|
const int64_t neh11 = ne21; // n_tokens
|
||||||
|
const int64_t neh12 = ne02; // n_expert
|
||||||
|
|
||||||
|
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
||||||
|
const uint64_t nbh11 = nbh10*neh10;
|
||||||
|
const uint64_t nbh12 = nbh11*neh11;
|
||||||
|
const uint64_t nbh13 = nbh12*neh12;
|
||||||
|
|
||||||
|
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
||||||
|
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
||||||
|
if (!h_src1) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t neh0 = ne0;
|
||||||
|
const int64_t neh1 = ne21;
|
||||||
|
const int64_t neh2 = ne02;
|
||||||
|
|
||||||
|
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
||||||
|
const uint64_t nbh1 = nbh0*neh0;
|
||||||
|
const uint64_t nbh2 = nbh1*neh1;
|
||||||
|
//const uint64_t nbh3 = nbh2*neh2;
|
||||||
|
|
||||||
|
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
||||||
|
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
||||||
|
if (!h_dst) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokens per expert
|
||||||
|
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
||||||
|
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
||||||
|
if (!h_tpe) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// id map
|
||||||
|
// [n_expert_used, n_tokens]
|
||||||
|
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
||||||
|
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
||||||
|
if (!h_ids) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const int nth = MIN(1024, ne10/4);
|
||||||
|
|
||||||
|
ggml_metal_kargs_mul_mm_id_map0 args = {
|
||||||
|
ne10,
|
||||||
|
ne11, // n_expert_used (bcast)
|
||||||
|
nb11,
|
||||||
|
nb12,
|
||||||
|
neh11, // n_tokens
|
||||||
|
nbh11,
|
||||||
|
ne20, // n_expert_used
|
||||||
|
nb21,
|
||||||
|
};
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||||
|
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
||||||
|
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
||||||
|
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break;
|
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break;
|
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break;
|
case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
|
case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
|
||||||
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
|
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
|
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
|
||||||
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break;
|
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break;
|
case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
|
||||||
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
|
||||||
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
default: GGML_ABORT("MUL_MAT_ID not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_kargs_mul_mm_id args = {
|
ggml_metal_kargs_mul_mm_id args = {
|
||||||
/*.nei0 =*/ ne20,
|
|
||||||
/*.nei1 =*/ ne21,
|
|
||||||
/*.nbi1 =*/ nb21,
|
|
||||||
/*.ne00 =*/ ne00,
|
/*.ne00 =*/ ne00,
|
||||||
/*.ne02 =*/ ne02,
|
/*.ne02 =*/ ne02,
|
||||||
/*.nb01 =*/ nb01,
|
/*.nb01 =*/ nb01,
|
||||||
/*.nb02 =*/ nb02,
|
/*.nb02 =*/ nb02,
|
||||||
/*.ne11 =*/ ne11,
|
/*.nb03 =*/ nb03,
|
||||||
/*.ne12 =*/ ne12,
|
/*.neh12 =*/ neh12,
|
||||||
/*.ne13 =*/ ne13,
|
/*.nbh10 =*/ nbh10,
|
||||||
/*.nb10 =*/ nb10,
|
/*.nbh11 =*/ nbh11,
|
||||||
/*.nb11 =*/ nb11,
|
/*.nbh12 =*/ nbh12,
|
||||||
/*.nb12 =*/ nb12,
|
/*.nbh13 =*/ nbh13,
|
||||||
/*.ne0 =*/ ne0,
|
/*.neh0 =*/ neh0,
|
||||||
/*.ne1 =*/ ne1,
|
/*.neh1 =*/ neh1,
|
||||||
|
/*.r2 =*/ r2,
|
||||||
|
/*.r3 =*/ r3,
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
||||||
|
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
||||||
|
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
||||||
|
|
||||||
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne0 % 4 == 0);
|
||||||
|
|
||||||
|
const int nth = MIN(1024, ne0/4);
|
||||||
|
|
||||||
|
ggml_metal_kargs_mul_mm_id_map1 args = {
|
||||||
|
ne20, // n_expert_used
|
||||||
|
neh0,
|
||||||
|
neh1,
|
||||||
|
nbh1,
|
||||||
|
nbh2,
|
||||||
|
ne0,
|
||||||
|
nb1,
|
||||||
|
nb2,
|
||||||
|
};
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
|
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
||||||
|
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
}
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
|
||||||
} else {
|
} else {
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
@ -3195,7 +3622,7 @@ static void ggml_metal_encode_node(
|
|||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||||
|
|
||||||
const int64_t _ne1 = 1;
|
const int64_t _ne1 = 1;
|
||||||
const int64_t ne123 = dst_rows;
|
const int64_t ne123 = ne20*ne21;
|
||||||
|
|
||||||
if (smem > 0) {
|
if (smem > 0) {
|
||||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||||
@ -4601,6 +5028,8 @@ static void ggml_metal_encode_node(
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_status ggml_metal_graph_compute(
|
static enum ggml_status ggml_metal_graph_compute(
|
||||||
@ -4654,25 +5083,25 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// the main thread commits the first few commands immediately
|
// the main thread commits the first few commands immediately
|
||||||
// command_buffer[n_cb]
|
// cmd_buf[n_cb]
|
||||||
{
|
{
|
||||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||||
ctx->command_buffers[n_cb] = command_buffer;
|
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
||||||
|
|
||||||
[command_buffer enqueue];
|
[cmd_buf enqueue];
|
||||||
ctx->encode_async(n_cb);
|
ctx->encode_async(n_cb);
|
||||||
}
|
}
|
||||||
|
|
||||||
// prepare the rest of the command buffers asynchronously
|
// prepare the rest of the command buffers asynchronously
|
||||||
// command_buffer[0.. n_cb)
|
// cmd_buf[0.. n_cb)
|
||||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
|
||||||
ctx->command_buffers[cb_idx] = command_buffer;
|
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
|
||||||
|
|
||||||
// always enqueue the first two command buffers
|
// always enqueue the first two command buffers
|
||||||
// enqueue all of the command buffers if we don't need to abort
|
// enqueue all of the command buffers if we don't need to abort
|
||||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
[command_buffer enqueue];
|
[cmd_buf enqueue];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4681,14 +5110,14 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
// wait for completion and check status of each command buffer
|
// wait for completion and check status of each command buffer
|
||||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||||
{
|
{
|
||||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
|
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
|
||||||
[command_buffer waitUntilCompleted];
|
[cmd_buf waitUntilCompleted];
|
||||||
|
|
||||||
MTLCommandBufferStatus status = [command_buffer status];
|
MTLCommandBufferStatus status = [cmd_buf status];
|
||||||
if (status != MTLCommandBufferStatusCompleted) {
|
if (status != MTLCommandBufferStatusCompleted) {
|
||||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
||||||
if (status == MTLCommandBufferStatusError) {
|
if (status == MTLCommandBufferStatusError) {
|
||||||
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return GGML_STATUS_FAILED;
|
return GGML_STATUS_FAILED;
|
||||||
@ -4696,20 +5125,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_cb; ++i) {
|
for (int i = 0; i < n_cb; ++i) {
|
||||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
|
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
|
||||||
[command_buffer waitUntilCompleted];
|
[cmd_buf waitUntilCompleted];
|
||||||
|
|
||||||
MTLCommandBufferStatus status = [command_buffer status];
|
MTLCommandBufferStatus status = [cmd_buf status];
|
||||||
if (status != MTLCommandBufferStatusCompleted) {
|
if (status != MTLCommandBufferStatusCompleted) {
|
||||||
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||||
if (status == MTLCommandBufferStatusError) {
|
if (status == MTLCommandBufferStatusError) {
|
||||||
GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
|
||||||
}
|
}
|
||||||
|
|
||||||
return GGML_STATUS_FAILED;
|
return GGML_STATUS_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
|
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
|
||||||
if (!next_buffer) {
|
if (!next_buffer) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -5092,8 +5521,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|||||||
|
|
||||||
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
|
||||||
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
|
||||||
|
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
|
||||||
|
|
||||||
int node_start = 0;
|
int node_start = 0;
|
||||||
int node_end = n_nodes_0;
|
int node_end = n_nodes_0;
|
||||||
@ -5105,22 +5535,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|||||||
|
|
||||||
const bool should_capture = ctx->capture_next_compute;
|
const bool should_capture = ctx->capture_next_compute;
|
||||||
|
|
||||||
|
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
|
||||||
|
ggml_metal_mem_pool_reset(mem_pool);
|
||||||
|
|
||||||
for (int idx = node_start; idx < node_end; ++idx) {
|
for (int idx = node_start; idx < node_end; ++idx) {
|
||||||
if (should_capture) {
|
if (should_capture) {
|
||||||
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_metal_encode_node(backend, idx, encoder);
|
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
|
||||||
|
|
||||||
if (should_capture) {
|
if (should_capture) {
|
||||||
[encoder popDebugGroup];
|
[encoder popDebugGroup];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!res) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[encoder endEncoding];
|
[encoder endEncoding];
|
||||||
|
|
||||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
[command_buffer commit];
|
[cmd_buf commit];
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -6336,127 +6336,219 @@ kernel void kernel_mul_mm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
|
template<typename T4>
|
||||||
// TODO: this kernel needs to be reimplemented from scratch for better performance
|
kernel void kernel_mul_mm_id_map0(
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
constant ggml_metal_kargs_mul_mm_id_map0 & args,
|
||||||
void kernel_mul_mm_id_impl(
|
device const char * src1,
|
||||||
int32_t ne00,
|
device const char * src2,
|
||||||
int32_t ne02,
|
device char * hsrc1,
|
||||||
uint64_t nb01,
|
device char * htpe,
|
||||||
uint64_t nb02,
|
device char * hids,
|
||||||
int32_t ne11,
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
int32_t ne12,
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint64_t nb10,
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
uint64_t nb11,
|
const int ide = tgpig[0]; // expert id
|
||||||
uint64_t nb12,
|
|
||||||
int32_t ne0,
|
int n_all = 0;
|
||||||
int32_t ne1,
|
|
||||||
int64_t ne0ne1,
|
device int32_t * ids_i32 = (device int32_t *) (hids);
|
||||||
|
|
||||||
|
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
|
||||||
|
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
|
||||||
|
|
||||||
|
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
|
||||||
|
if (src2_i32[i20] != ide) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
|
||||||
|
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
|
||||||
|
|
||||||
|
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
|
||||||
|
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tpitg.x == 0) {
|
||||||
|
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
|
||||||
|
}
|
||||||
|
|
||||||
|
++n_all;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tpitg.x == 0) {
|
||||||
|
device int32_t * tpe_i32 = (device int32_t *) (htpe);
|
||||||
|
tpe_i32[ide] = n_all;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
kernel void kernel_mul_mm_id_map1(
|
||||||
|
constant ggml_metal_kargs_mul_mm_id_map1 & args,
|
||||||
|
device const char * hdst,
|
||||||
|
device const char * hids,
|
||||||
|
device char * dst,
|
||||||
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||||
|
const int i20 = tgpig[0]; // used expert
|
||||||
|
const int i21 = tgpig[1]; // token
|
||||||
|
|
||||||
|
device const int32_t * ids_i32 = (device const int32_t *) (hids);
|
||||||
|
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
|
||||||
|
|
||||||
|
const int id = ids_i32[i21*args.ne20 + i20];
|
||||||
|
|
||||||
|
const int ide = id / args.neh1;
|
||||||
|
const int idt = id % args.neh1;
|
||||||
|
|
||||||
|
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
|
||||||
|
|
||||||
|
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
|
||||||
|
dst_f32x4[i0] = hdst_f32x4[i0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
|
||||||
|
|
||||||
|
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
|
||||||
|
|
||||||
|
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||||
|
kernel void kernel_mul_mm_id(
|
||||||
|
constant ggml_metal_kargs_mul_mm_id & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
threadgroup ushort2 * rowids,
|
device const char * tpe,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
threadgroup char * shmem,
|
threadgroup char * shmem [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
threadgroup half * sa = (threadgroup half *)(shmem);
|
threadgroup T * sa = (threadgroup T *)(shmem);
|
||||||
threadgroup float * sb = (threadgroup float *)(shmem + 4096);
|
threadgroup half * sb = (threadgroup half *)(shmem + 4096);
|
||||||
|
|
||||||
const int r0 = tgpig.y;
|
const int r0 = tgpig.y;
|
||||||
const int r1 = tgpig.x;
|
const int r1 = tgpig.x;
|
||||||
|
const int im = tgpig.z;
|
||||||
|
|
||||||
if (r1*BLOCK_SIZE_N >= ne1) return;
|
device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
|
||||||
|
|
||||||
|
const int neh1 = tpe_i32[im];
|
||||||
|
|
||||||
|
if (r1*BLOCK_SIZE_N >= neh1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// if this block is of 64x32 shape or smaller
|
// if this block is of 64x32 shape or smaller
|
||||||
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
|
||||||
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
|
||||||
|
|
||||||
// a thread shouldn't load data outside of the matrix
|
// a thread shouldn't load data outside of the matrix
|
||||||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||||
|
|
||||||
simdgroup_half8x8 ma[4];
|
simdgroup_T8x8 ma[4];
|
||||||
simdgroup_float8x8 mb[2];
|
simdgroup_half8x8 mb[2];
|
||||||
simdgroup_float8x8 mc[8];
|
simdgroup_float8x8 mc[8];
|
||||||
for (int i = 0; i < 8; i++){
|
|
||||||
|
for (short i = 0; i < 8; i++){
|
||||||
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
||||||
}
|
}
|
||||||
|
|
||||||
short il = (tiitg % THREAD_PER_ROW);
|
short il = (tiitg % THREAD_PER_ROW);
|
||||||
|
|
||||||
ushort offset1 = il/nl;
|
const int i12 = im%args.neh12;
|
||||||
|
const int i13 = im/args.neh12;
|
||||||
|
|
||||||
threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
|
const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||||
|
const short offset1 = il/nl;
|
||||||
|
|
||||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
|
device const block_q * x = (device const block_q *)(src0
|
||||||
device const float * y = (device const float *)(src1
|
+ args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
|
||||||
+ nb12 * id[1]
|
|
||||||
+ nb11 * (id[0] % ne11)
|
|
||||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
|
||||||
|
|
||||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
device const half * y = (device const half *)(src1
|
||||||
|
+ args.nbh13*i13
|
||||||
|
+ args.nbh12*i12
|
||||||
|
+ args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
|
||||||
|
+ args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||||
|
|
||||||
|
for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
|
||||||
// load data and store to threadgroup memory
|
// load data and store to threadgroup memory
|
||||||
half4x4 temp_a;
|
T4x4 temp_a;
|
||||||
dequantize_func(x, il, temp_a);
|
dequantize_func(x, il, temp_a);
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (int i = 0; i < 16; i++) {
|
#pragma unroll(16)
|
||||||
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
|
for (short i = 0; i < 16; i++) {
|
||||||
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
|
*(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
|
||||||
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
|
+ (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
|
||||||
|
+ (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
|
||||||
}
|
}
|
||||||
|
|
||||||
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
|
*(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
|
||||||
|
|
||||||
il = (il + 2 < nl) ? il + 2 : il % 2;
|
il = (il + 2 < nl) ? il + 2 : il % 2;
|
||||||
x = (il < 2) ? x + (2+nl-1)/nl : x;
|
x = (il < 2) ? x + (2 + nl - 1)/nl : x;
|
||||||
y += BLOCK_SIZE_K;
|
y += BLOCK_SIZE_K;
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
// load matrices from threadgroup memory and conduct outer products
|
// load matrices from threadgroup memory and conduct outer products
|
||||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
|
||||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
|
||||||
|
|
||||||
#pragma unroll(BLOCK_SIZE_K/8)
|
|
||||||
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
|
||||||
#pragma unroll(4)
|
#pragma unroll(4)
|
||||||
for (int i = 0; i < 4; i++) {
|
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (short i = 0; i < 4; i++) {
|
||||||
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
|
||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
simdgroup_barrier(mem_flags::mem_none);
|
||||||
|
|
||||||
#pragma unroll(2)
|
#pragma unroll(2)
|
||||||
for (int i = 0; i < 2; i++) {
|
for (short i = 0; i < 2; i++) {
|
||||||
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
|
||||||
}
|
}
|
||||||
|
|
||||||
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
||||||
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
|
|
||||||
|
|
||||||
#pragma unroll(8)
|
#pragma unroll(8)
|
||||||
for (int i = 0; i < 8; i++){
|
for (short i = 0; i < 8; i++){
|
||||||
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||||
|
lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
|
||||||
|
device float * C = (device float *) dst +
|
||||||
|
(BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \
|
||||||
|
(BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
|
||||||
|
|
||||||
|
for (short i = 0; i < 8; i++) {
|
||||||
|
simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// block is smaller than 64x32, we should avoid writing data outside of the matrix
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
threadgroup float * temp_str = ((threadgroup float *) shmem) \
|
||||||
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
|
+ 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
|
||||||
for (int i = 0; i < 8; i++) {
|
for (short i = 0; i < 8; i++) {
|
||||||
simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
|
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
|
||||||
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
|
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
|
||||||
int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
|
|
||||||
|
|
||||||
device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
|
|
||||||
device float4 * D4 = (device float4 *) D;
|
device float4 * D4 = (device float4 *) D;
|
||||||
|
|
||||||
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
|
||||||
@ -6476,66 +6568,6 @@ void kernel_mul_mm_id_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
|
||||||
kernel void kernel_mul_mm_id(
|
|
||||||
constant ggml_metal_kargs_mul_mm_id & args,
|
|
||||||
device const char * src0s,
|
|
||||||
device const char * src1,
|
|
||||||
device char * dst,
|
|
||||||
device const char * ids,
|
|
||||||
threadgroup char * shmem [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
|
||||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
||||||
|
|
||||||
const int32_t i02 = tgpig.z;
|
|
||||||
|
|
||||||
tgpig.z = 0;
|
|
||||||
|
|
||||||
device const char * src0 = src0s + i02*args.nb02;
|
|
||||||
|
|
||||||
// row indices
|
|
||||||
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
|
|
||||||
|
|
||||||
// TODO: parallelize this loop
|
|
||||||
int32_t _ne1 = 0;
|
|
||||||
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
|
|
||||||
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
|
|
||||||
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
|
|
||||||
if (id == i02) {
|
|
||||||
if (tiitg == 0) {
|
|
||||||
rowids[_ne1] = ushort2(ii0, ii1);
|
|
||||||
}
|
|
||||||
_ne1++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
|
||||||
args.ne00,
|
|
||||||
args.ne02,
|
|
||||||
args.nb01,
|
|
||||||
args.nb02,
|
|
||||||
args.ne11,
|
|
||||||
args.ne12,
|
|
||||||
args.nb10,
|
|
||||||
args.nb11,
|
|
||||||
args.nb12,
|
|
||||||
args.ne0,
|
|
||||||
_ne1,
|
|
||||||
(int64_t)args.ne0*args.ne1,
|
|
||||||
src0,
|
|
||||||
src1,
|
|
||||||
rowids,
|
|
||||||
dst,
|
|
||||||
shmem,
|
|
||||||
tgpig,
|
|
||||||
tiitg,
|
|
||||||
sgitg);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define QK_NL 16
|
#define QK_NL 16
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -6576,63 +6608,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
|||||||
// matrix-matrix multiplication
|
// matrix-matrix multiplication
|
||||||
//
|
//
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||||
#if defined(GGML_METAL_USE_BF16)
|
#if defined(GGML_METAL_USE_BF16)
|
||||||
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||||
#endif
|
#endif
|
||||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||||
|
|
||||||
//
|
//
|
||||||
// indirect matrix-matrix multiplication
|
// indirect matrix-matrix multiplication
|
||||||
//
|
//
|
||||||
|
|
||||||
typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
|
typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
|
||||||
|
|
||||||
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||||
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||||
#if defined(GGML_METAL_USE_BF16)
|
#if defined(GGML_METAL_USE_BF16)
|
||||||
template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
|
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
|
||||||
#endif
|
#endif
|
||||||
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
|
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||||
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
|
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
|
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
|
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
|
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
|
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
|
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
|
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
|
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
|
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||||
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// matrix-vector multiplication
|
// matrix-vector multiplication
|
||||||
|
@ -2732,11 +2732,11 @@ void ggml_mul_mat_set_prec(
|
|||||||
c = ggml_mul_mat_id(ctx, as, b, ids);
|
c = ggml_mul_mat_id(ctx, as, b, ids);
|
||||||
|
|
||||||
as -> [cols, rows, n_expert]
|
as -> [cols, rows, n_expert]
|
||||||
ids -> [n_experts_used, n_tokens] (i32)
|
|
||||||
b -> [cols, n_expert_used, n_tokens]
|
b -> [cols, n_expert_used, n_tokens]
|
||||||
|
ids -> [n_expert_used, n_tokens] (i32)
|
||||||
c -> [rows, n_expert_used, n_tokens]
|
c -> [rows, n_expert_used, n_tokens]
|
||||||
|
|
||||||
in b, n_experts_used can be broadcasted to match the n_expert_used of ids
|
in b, n_expert_used can be broadcasted to match the n_expert_used of ids
|
||||||
|
|
||||||
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
|
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
|
||||||
*/
|
*/
|
||||||
|
Loading…
Reference in New Issue
Block a user