mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
ggml : sync latest llama.cpp (view_src + alloc improvements) (#1247)
* ggml : sync latest llama.cpp (view_src + alloc improvements) * ggml : fix build
This commit is contained in:
parent
ba3c333611
commit
c3f319d7c2
107
ggml-cuda.cu
107
ggml-cuda.cu
@ -81,12 +81,29 @@
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
||||
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
#if __has_builtin(__builtin_elementwise_sub_sat)
|
||||
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
||||
return reinterpret_cast<const int&>(c);
|
||||
#else
|
||||
int8x4_t c;
|
||||
int16_t tmp;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp = va[i] - vb[i];
|
||||
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
||||
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
||||
c[i] = tmp;
|
||||
}
|
||||
return reinterpret_cast<int&>(c);
|
||||
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
||||
@ -447,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
|
||||
dst[i] = x[i] / (1.0f + expf(-x[i]));
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
|
||||
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const float eps = 1e-5f;
|
||||
|
||||
float mean = 0.0f;
|
||||
float var = 0.0f;
|
||||
float2 mean_var = make_float2(0.f, 0.f);
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
mean += xi;
|
||||
var += xi * xi;
|
||||
mean_var.x += xi;
|
||||
mean_var.y += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
|
||||
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float2 s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = mean_var;
|
||||
}
|
||||
__syncthreads();
|
||||
mean_var = s_sum[lane_id];
|
||||
mean_var = warp_reduce_sum(mean_var);
|
||||
}
|
||||
|
||||
mean /= ncols;
|
||||
var = var / ncols - mean * mean;
|
||||
const float inv_var = rsqrtf(var + eps);
|
||||
const float mean = mean_var.x / ncols;
|
||||
const float var = mean_var.y / ncols - mean * mean;
|
||||
const float inv_std = rsqrtf(var + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
template <int block_size>
|
||||
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
float tmp = 0.0f; // partial sum for thread in warp
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
const float xi = x[row*ncols + col];
|
||||
tmp += xi * xi;
|
||||
}
|
||||
|
||||
// sum up partial sums
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__shared__ float s_sum[32];
|
||||
int warp_id = threadIdx.x / WARP_SIZE;
|
||||
int lane_id = threadIdx.x % WARP_SIZE;
|
||||
if (lane_id == 0) {
|
||||
s_sum[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
const float mean = tmp / ncols;
|
||||
const float scale = rsqrtf(mean + eps);
|
||||
|
||||
for (int col = tid; col < ncols; col += WARP_SIZE) {
|
||||
for (int col = tid; col < ncols; col += block_size) {
|
||||
dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||
}
|
||||
}
|
||||
@ -4186,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
|
||||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
if (ncols < 1024) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
}
|
||||
}
|
||||
|
||||
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
|
||||
@ -5721,7 +5781,6 @@ inline void ggml_cuda_op_alibi(
|
||||
(void) src1;
|
||||
(void) src0_ddq_i;
|
||||
(void) src1_ddf_i;
|
||||
(void) i02;
|
||||
(void) i1;
|
||||
}
|
||||
|
||||
|
133
ggml-metal.m
133
ggml-metal.m
@ -11,6 +11,7 @@
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
// TODO: temporary - reuse llama.cpp logging
|
||||
#ifdef GGML_METAL_NDEBUG
|
||||
#define metal_printf(...)
|
||||
#else
|
||||
@ -75,6 +76,7 @@ struct ggml_metal_context {
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -113,12 +115,26 @@ static NSString * const msl_library_source = @"see metal.metal";
|
||||
@end
|
||||
|
||||
struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
fprintf(stderr, "%s: allocating\n", __func__);
|
||||
metal_printf("%s: allocating\n", __func__);
|
||||
|
||||
// Show all the Metal device instances in the system
|
||||
NSArray * devices = MTLCopyAllDevices();
|
||||
id <MTLDevice> device;
|
||||
NSString * s;
|
||||
for (device in devices) {
|
||||
s = [device name];
|
||||
metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
|
||||
}
|
||||
|
||||
// Pick and show default Metal device
|
||||
device = MTLCreateSystemDefaultDevice();
|
||||
s = [device name];
|
||||
metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
|
||||
|
||||
// Configure context
|
||||
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
|
||||
|
||||
ctx->device = device;
|
||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
||||
ctx->device = MTLCreateSystemDefaultDevice();
|
||||
ctx->queue = [ctx->device newCommandQueue];
|
||||
ctx->n_buffers = 0;
|
||||
ctx->concur_list_len = 0;
|
||||
@ -132,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
|
||||
ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
|
||||
if (error) {
|
||||
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
@ -146,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
//NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
|
||||
NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
|
||||
NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
|
||||
fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
|
||||
metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
|
||||
|
||||
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
||||
if (error) {
|
||||
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -162,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
||||
#endif
|
||||
if (error) {
|
||||
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
@ -174,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
#define GGML_METAL_ADD_KERNEL(name) \
|
||||
ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
|
||||
ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
|
||||
fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
||||
metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
|
||||
(int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
|
||||
(int) ctx->pipeline_##name.threadExecutionWidth); \
|
||||
if (error) { \
|
||||
fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
||||
metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
||||
return NULL; \
|
||||
}
|
||||
|
||||
@ -204,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -230,19 +247,19 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||
#undef GGML_METAL_ADD_KERNEL
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
fprintf(stderr, "%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||
if (ctx->device.maxTransferRate != 0) {
|
||||
fprintf(stderr, "%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
|
||||
} else {
|
||||
fprintf(stderr, "%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
|
||||
}
|
||||
|
||||
return ctx;
|
||||
}
|
||||
|
||||
void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
fprintf(stderr, "%s: deallocating\n", __func__);
|
||||
metal_printf("%s: deallocating\n", __func__);
|
||||
#define GGML_METAL_DEL_KERNEL(name) \
|
||||
[ctx->function_##name release]; \
|
||||
[ctx->pipeline_##name release];
|
||||
@ -269,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
|
||||
@ -311,7 +329,7 @@ void * ggml_metal_host_malloc(size_t n) {
|
||||
void * data = NULL;
|
||||
const int result = posix_memalign((void **) &data, getpagesize(), n);
|
||||
if (result != 0) {
|
||||
fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
|
||||
metal_printf("%s: error: posix_memalign failed\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@ -339,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
|
||||
// Metal buffer based on the host memory pointer
|
||||
//
|
||||
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
|
||||
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
||||
//metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
|
||||
|
||||
const int64_t tsize = ggml_nbytes(t);
|
||||
|
||||
@ -350,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
|
||||
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
|
||||
*offs = (size_t) ioffs;
|
||||
|
||||
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
||||
//metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
|
||||
|
||||
return ctx->buffers[i].metal;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
|
||||
metal_printf("%s: error: buffer is nil\n", __func__);
|
||||
|
||||
return nil;
|
||||
}
|
||||
@ -368,7 +386,7 @@ bool ggml_metal_add_buffer(
|
||||
size_t size,
|
||||
size_t max_size) {
|
||||
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
|
||||
fprintf(stderr, "%s: too many buffers\n", __func__);
|
||||
metal_printf("%s: too many buffers\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -378,7 +396,7 @@ bool ggml_metal_add_buffer(
|
||||
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
|
||||
|
||||
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
|
||||
fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
||||
metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -399,11 +417,11 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
|
||||
|
||||
++ctx->n_buffers;
|
||||
} else {
|
||||
@ -423,27 +441,27 @@ bool ggml_metal_add_buffer(
|
||||
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
|
||||
|
||||
if (ctx->buffers[ctx->n_buffers].metal == nil) {
|
||||
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
||||
metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
|
||||
return false;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
||||
metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
|
||||
if (i + size_step < size) {
|
||||
fprintf(stderr, "\n");
|
||||
metal_printf("\n");
|
||||
}
|
||||
|
||||
++ctx->n_buffers;
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, ", (%8.2f / %8.2f)",
|
||||
metal_printf(", (%8.2f / %8.2f)",
|
||||
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
|
||||
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
|
||||
|
||||
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
|
||||
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
|
||||
metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
|
||||
} else {
|
||||
fprintf(stderr, "\n");
|
||||
metal_printf("\n");
|
||||
}
|
||||
}
|
||||
|
||||
@ -453,8 +471,6 @@ bool ggml_metal_add_buffer(
|
||||
void ggml_metal_set_tensor(
|
||||
struct ggml_metal_context * ctx,
|
||||
struct ggml_tensor * t) {
|
||||
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
|
||||
|
||||
size_t offs;
|
||||
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
|
||||
|
||||
@ -464,8 +480,6 @@ void ggml_metal_set_tensor(
|
||||
void ggml_metal_get_tensor(
|
||||
struct ggml_metal_context * ctx,
|
||||
struct ggml_tensor * t) {
|
||||
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
|
||||
|
||||
size_t offs;
|
||||
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
|
||||
|
||||
@ -560,15 +574,13 @@ void ggml_metal_graph_find_concurrency(
|
||||
}
|
||||
|
||||
if (ctx->concur_list_len > GGML_MAX_CONCUR) {
|
||||
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
|
||||
metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_metal_graph_compute(
|
||||
struct ggml_metal_context * ctx,
|
||||
struct ggml_cgraph * gf) {
|
||||
metal_printf("%s: evaluating graph\n", __func__);
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
// if there is ctx->concur_list, dispatch concurrently
|
||||
@ -616,7 +628,7 @@ void ggml_metal_graph_compute(
|
||||
continue;
|
||||
}
|
||||
|
||||
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
||||
|
||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
||||
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
||||
@ -685,6 +697,12 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
// utilize float4
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
const int64_t nb = ne00/4;
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
[encoder setComputePipelineState:ctx->pipeline_add_row];
|
||||
@ -694,14 +712,20 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
// utilize float4
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
const int64_t nb = ne00/4;
|
||||
|
||||
if (ggml_nelements(src1) == ne10) {
|
||||
// src1 is a row
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_row];
|
||||
@ -711,9 +735,9 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&nb length:sizeof(nb) atIndex:3];
|
||||
|
||||
const int64_t n = ggml_nelements(dst);
|
||||
const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
@ -764,7 +788,7 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} break;
|
||||
@ -845,9 +869,13 @@ void ggml_metal_graph_compute(
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
nth0 = 64;
|
||||
nth0 = 32;
|
||||
nth1 = 1;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
if (ne11 * ne12 < 4) {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||
} else {
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
@ -899,8 +927,8 @@ void ggml_metal_graph_compute(
|
||||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
nth0 = 4; //1;
|
||||
nth1 = 8; //32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
@ -923,7 +951,7 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
|
||||
metal_printf("Asserting on type %d\n",(int)src0t);
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
};
|
||||
@ -948,9 +976,12 @@ void ggml_metal_graph_compute(
|
||||
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
|
||||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
@ -964,8 +995,8 @@ void ggml_metal_graph_compute(
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
int64_t ny = (ne11 + 3)/4;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
@ -1161,7 +1192,7 @@ void ggml_metal_graph_compute(
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||
metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
@ -1186,7 +1217,7 @@ void ggml_metal_graph_compute(
|
||||
|
||||
MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
|
||||
if (status != MTLCommandBufferStatusCompleted) {
|
||||
fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
238
ggml-metal.metal
238
ggml-metal.metal
@ -25,9 +25,9 @@ typedef struct {
|
||||
} block_q8_0;
|
||||
|
||||
kernel void kernel_add(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig];
|
||||
}
|
||||
@ -35,18 +35,18 @@ kernel void kernel_add(
|
||||
// assumption: src1 is a row
|
||||
// broadcast src1 into src0
|
||||
kernel void kernel_add_row(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig % ne00];
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_mul(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src1[tpig];
|
||||
}
|
||||
@ -54,12 +54,12 @@ kernel void kernel_mul(
|
||||
// assumption: src1 is a row
|
||||
// broadcast src1 into src0
|
||||
kernel void kernel_mul_row(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] * src1[tpig % ne00];
|
||||
dst[tpig] = src0[tpig] * src1[tpig % nb];
|
||||
}
|
||||
|
||||
kernel void kernel_scale(
|
||||
@ -133,19 +133,24 @@ kernel void kernel_soft_max(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
if (tpitg[0] == 0) {
|
||||
buf[0] = buf[0];
|
||||
}
|
||||
//// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
|
||||
// the loop, and when that is done, buf[0] has the correct (synchronized) value
|
||||
//if (tpitg[0] == 0) {
|
||||
// buf[0] = buf[0];
|
||||
//}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float max = buf[0];
|
||||
|
||||
// parallel sum
|
||||
buf[tpitg[0]] = 0.0f;
|
||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
buf[tpitg[0]] += exp(psrc0[i00] - max);
|
||||
const float exp_psrc0 = exp(psrc0[i00] - max);
|
||||
buf[tpitg[0]] += exp_psrc0;
|
||||
// Remember the result of exp here. exp is expensive, so we really do not
|
||||
// whish to compute it twice.
|
||||
pdst[i00] = exp_psrc0;
|
||||
}
|
||||
|
||||
// reduce
|
||||
@ -157,17 +162,18 @@ kernel void kernel_soft_max(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
// broadcast
|
||||
if (tpitg[0] == 0) {
|
||||
buf[0] = buf[0];
|
||||
}
|
||||
// broadcast - not needed, see above
|
||||
//// broadcast
|
||||
//if (tpitg[0] == 0) {
|
||||
// buf[0] = buf[0];
|
||||
//}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
const float sum = buf[0];
|
||||
|
||||
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||
pdst[i00] = exp(psrc0[i00] - max) / sum;
|
||||
pdst[i00] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
@ -214,25 +220,27 @@ kernel void kernel_norm(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
// broadcast
|
||||
if (tpitg == 0) {
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//// broadcast
|
||||
//if (tpitg == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
const float mean = sum[0];
|
||||
|
||||
// recenter
|
||||
// recenter and VARIANCE
|
||||
device float * y = dst + tgpig*ne00;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
y[i00] = x[i00] - mean;
|
||||
}
|
||||
|
||||
// VARIANCE
|
||||
// parallel sum
|
||||
sum[tpitg] = 0.0f;
|
||||
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
y[i00] = x[i00] - mean;
|
||||
sum[tpitg] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
//// VARIANCE
|
||||
//// parallel sum
|
||||
//sum[tpitg] = 0.0f;
|
||||
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||
// sum[tpitg] += y[i00] * y[i00];
|
||||
//}
|
||||
// reduce
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint i = ntg/2; i > 0; i /= 2) {
|
||||
@ -241,11 +249,11 @@ kernel void kernel_norm(
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
// broadcast
|
||||
if (tpitg == 0) {
|
||||
sum[0] /= ne00;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
//// broadcast
|
||||
//if (tpitg == 0) {
|
||||
// sum[0] /= ne00;
|
||||
//}
|
||||
//threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
const float variance = sum[0];
|
||||
|
||||
const float scale = 1.0f/sqrt(variance + eps);
|
||||
@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
|
||||
}
|
||||
|
||||
#define NB_Q8_0 8
|
||||
|
||||
kernel void kernel_mul_mat_q8_0_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
|
||||
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[16];
|
||||
float yl[NB_Q8_0];
|
||||
float sumf[nr]={0.f};
|
||||
|
||||
const int ix = tiisg/2;
|
||||
const int il = tiisg%2;
|
||||
const int ix = tiisg/4;
|
||||
const int il = tiisg%4;
|
||||
|
||||
device const float * yb = y + ix * QK8_0 + 16*il;
|
||||
device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
|
||||
|
||||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
||||
for (int i = 0; i < 16; ++i) {
|
||||
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
|
||||
for (int ib = ix; ib < nb; ib += nw/4) {
|
||||
for (int i = 0; i < NB_Q8_0; ++i) {
|
||||
yl[i] = yb[i];
|
||||
}
|
||||
|
||||
for (int row = 0; row < nr; row++) {
|
||||
device const int8_t * qs = x[ib+row*nb].qs + 16*il;
|
||||
device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
|
||||
float sumq = 0.f;
|
||||
for (int iq = 0; iq < 16; ++iq) {
|
||||
for (int iq = 0; iq < NB_Q8_0; ++iq) {
|
||||
sumq += qs[iq] * yl[iq];
|
||||
}
|
||||
sumf[row] += sumq*x[ib+row*nb].d;
|
||||
}
|
||||
|
||||
yb += QK8_0 * 16;
|
||||
yb += NB_Q8_0 * nw;
|
||||
}
|
||||
|
||||
for (int row = 0; row < nr; ++row) {
|
||||
@ -497,6 +507,60 @@ kernel void kernel_mul_mat_q8_0_f32(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
if (ne00 < 128) {
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *) x;
|
||||
device const float4 * y4 = (device const float4 *) y;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
||||
}
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define N_F16_F32 4
|
||||
|
||||
kernel void kernel_mul_mat_f16_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@ -515,37 +579,58 @@ kernel void kernel_mul_mat_f16_f32(
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
threadgroup float * sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpig[[thread_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
const int64_t rb = tgpig.y*N_F16_F32;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||
|
||||
sum[tpitg.x] = 0.0f;
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = tpitg.x; i < ne00; i += tptg.x) {
|
||||
sum[tpitg.x] += (float) x[i] * (float) y[i];
|
||||
}
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
// accumulate the sum from all threads in the threadgroup
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (uint i = tptg.x/2; i > 0; i /= 2) {
|
||||
if (tpitg.x < i) {
|
||||
sum[tpitg.x] += sum[tpitg.x + i];
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *)x;
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
device const float4 * y4 = (device const float4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_alibi_f32(
|
||||
@ -1244,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32(
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int r2 = tgpig.z;
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const int first_row = r0 * N_DST;
|
||||
const int ib_row = first_row * nb;
|
||||
const uint offset0 = r2/gqa*(nb*ne0);
|
||||
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
|
||||
|
@ -1334,7 +1334,7 @@ void ggml_cl_free_data(const struct ggml_tensor* tensor) {
|
||||
return;
|
||||
}
|
||||
|
||||
cl_mem mem = (cl_mem)tensor->data;
|
||||
cl_mem mem = (cl_mem)tensor->extra;
|
||||
clReleaseMemObject(mem);
|
||||
}
|
||||
|
||||
@ -1393,7 +1393,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
size_t d_size;
|
||||
|
||||
cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
|
||||
cl_mem d_Y = (cl_mem) src1->data; // src1 is already on device, broadcasted.
|
||||
cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
|
||||
cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
|
||||
|
||||
|
||||
@ -1491,9 +1491,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
||||
size_t d_size;
|
||||
cl_mem d_X;
|
||||
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
|
||||
d_X = (cl_mem) src0->data;
|
||||
d_X = (cl_mem) src0->extra;
|
||||
} else {
|
||||
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
|
||||
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
}
|
||||
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
@ -1567,7 +1567,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
||||
size_t d_size;
|
||||
cl_mem d_X;
|
||||
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
|
||||
d_X = (cl_mem) src0->data;
|
||||
d_X = (cl_mem) src0->extra;
|
||||
} else {
|
||||
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
|
||||
}
|
||||
@ -1697,7 +1697,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
||||
events.emplace_back();
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
|
||||
} else if (src0->backend == GGML_BACKEND_GPU) {
|
||||
d_Q = (cl_mem) src0->data;
|
||||
d_Q = (cl_mem) src0->extra;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -1860,6 +1860,6 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
|
||||
|
||||
CL_CHECK(clFinish(queue));
|
||||
|
||||
tensor->data = dst;
|
||||
tensor->extra = dst;
|
||||
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
|
||||
}
|
||||
|
34
ggml.h
34
ggml.h
@ -479,6 +479,9 @@ extern "C" {
|
||||
int64_t perf_cycles;
|
||||
int64_t perf_time_us;
|
||||
|
||||
struct ggml_tensor * view_src;
|
||||
size_t view_offs;
|
||||
|
||||
void * data;
|
||||
|
||||
char name[GGML_MAX_NAME];
|
||||
@ -661,7 +664,7 @@ extern "C" {
|
||||
GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
|
||||
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
|
||||
GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
|
||||
|
||||
@ -952,11 +955,11 @@ extern "C" {
|
||||
|
||||
// a - x
|
||||
// b - dy
|
||||
// TODO: update with configurable eps
|
||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
struct ggml_tensor * b,
|
||||
float eps);
|
||||
|
||||
// A: n columns, m rows
|
||||
// B: n columns, p rows (i.e. we transpose it internally)
|
||||
@ -1612,7 +1615,8 @@ extern "C" {
|
||||
struct ggml_tensor * tensor);
|
||||
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
|
||||
|
||||
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
|
||||
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
|
||||
@ -1677,6 +1681,8 @@ extern "C" {
|
||||
GGML_LINESEARCH_INVALID_PARAMETERS,
|
||||
};
|
||||
|
||||
typedef void (*ggml_opt_callback)(void * data, float * sched);
|
||||
|
||||
// optimization parameters
|
||||
//
|
||||
// see ggml.c (ggml_opt_default_params) for default values
|
||||
@ -1712,12 +1718,14 @@ extern "C" {
|
||||
|
||||
float sched; // schedule multiplier (fixed, decay or warmup)
|
||||
float decay; // weight decay for AdamW, use 0.0f to disable
|
||||
int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
|
||||
float alpha; // learning rate
|
||||
float beta1;
|
||||
float beta2;
|
||||
float eps; // epsilon for numerical stability
|
||||
float eps_f; // epsilon for convergence test
|
||||
float eps_g; // epsilon for convergence test
|
||||
float gclip; // gradient clipping
|
||||
} adam;
|
||||
|
||||
// LBFGS parameters
|
||||
@ -1745,14 +1753,12 @@ extern "C" {
|
||||
|
||||
bool just_initialized;
|
||||
|
||||
float loss_before;
|
||||
float loss_after;
|
||||
|
||||
struct {
|
||||
struct ggml_tensor * x; // view of the parameters
|
||||
struct ggml_tensor * g1; // gradient
|
||||
struct ggml_tensor * g2; // gradient squared
|
||||
struct ggml_tensor * m; // first moment
|
||||
struct ggml_tensor * v; // second moment
|
||||
struct ggml_tensor * mh; // first moment hat
|
||||
struct ggml_tensor * vh; // second moment hat
|
||||
struct ggml_tensor * pf; // past function values
|
||||
float fx_best;
|
||||
float fx_prev;
|
||||
@ -1789,10 +1795,10 @@ extern "C" {
|
||||
|
||||
// initialize optimizer context
|
||||
GGML_API void ggml_opt_init(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_opt_context * opt,
|
||||
struct ggml_opt_params params,
|
||||
int64_t nx);
|
||||
struct ggml_opt_params params,
|
||||
int64_t nx);
|
||||
|
||||
// continue optimizing the function defined by the tensor f
|
||||
GGML_API enum ggml_opt_result ggml_opt_resume(
|
||||
@ -1806,7 +1812,9 @@ extern "C" {
|
||||
struct ggml_opt_context * opt,
|
||||
struct ggml_tensor * f,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb);
|
||||
struct ggml_cgraph * gb,
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data);
|
||||
|
||||
//
|
||||
// quantization
|
||||
|
Loading…
Reference in New Issue
Block a user