mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-08-25 03:25:55 +02:00
finetune: SGD optimizer, more CLI args (llama/13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt add unit tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating m, v tensors. support finetune.cpp arg -opt SGD (or sgd). (default adamw as before) llama 3.2-1b-F32 result: observed 11gb gpu ram (41 sec/epoch) when using SGD instead of 19gb (55 sec/epoch) using adamw. (wikipedia 100 lines finetune) ( using the same GPU memory, adamw can only do before OOM 512 batch/context, reaching: train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00 val: [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00 SGD is superior, though it converges slower, with max before OOM 1728 batch/context (esp see the better validation perf): train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00 val: [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00 ) note: when finetuning long enough (or w/ enough -lr), validation accuracy *eventually* drops ('catastrophic forgetting') -lr-half (halflife) option useful for SGD to avoid oscillation or super slow underdamped learning (makes setting -lr more forgiving). terminal -lr for now is set by lr-halvings i.e. if you want at most 1/8 the inital -lr you set -lr-halvings 3. note: objective loss not directly comparable between adamw, sgd? - check perplexity or accuracy or consider relative improvements for convergence new finetune args -wd 1e-9 to enable weight decay in sgd or adamw, and max -epochs N (default 2 as before) cache (1 - wd*alpha) in 'adamw' opt struct - no noticeable perf benefit, disabled (still done for new SGD though) since opt. memory is pre-allocated, the ggml_opt_get_optimizer_params would probably be able to change between SGD and AdamW with each epoch but would need to use adamw for the first (unconfirmed - no cmdline arg to set such a policy yet) test-opt checks adamw as before and now sgd (except for a few disabled tests for sgd only; probably just needs logging values and adding alternate reference values); tolerance on the 'regression' test is broader for sgd (so we don't need many more epochs) * Vulkan: Implement GGML_OP_OPT_STEP_SGD * tests: Fix OPT_STEP_SGD test-backend-ops * SGD op param store weight-decay and not 1-alpha*wd * minor + cosmetic changes * fix vulkan sgd * try CI fix --------- Co-authored-by: 0cc4m <picard12@live.de> Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
committed by
Georgi Gerganov
parent
cbaec6c4ac
commit
c76ec72d59
@@ -74,16 +74,26 @@ extern "C" {
|
||||
GGML_OPT_BUILD_TYPE_OPT = 30,
|
||||
};
|
||||
|
||||
enum ggml_opt_optimizer_type {
|
||||
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
||||
GGML_OPT_OPTIMIZER_TYPE_SGD,
|
||||
|
||||
GGML_OPT_OPTIMIZER_TYPE_COUNT
|
||||
};
|
||||
|
||||
// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
|
||||
struct ggml_opt_optimizer_params {
|
||||
// AdamW optimizer parameters
|
||||
struct {
|
||||
float alpha; // learning rate
|
||||
float beta1;
|
||||
float beta2;
|
||||
float beta1; // first AdamW momentum
|
||||
float beta2; // second AdamW momentum
|
||||
float eps; // epsilon for numerical stability
|
||||
float wd; // weight decay for AdamW, use 0.0f to disable
|
||||
float wd; // weight decay - 0.0f to disable
|
||||
} adamw;
|
||||
struct {
|
||||
float alpha; // learning rate
|
||||
float wd; // weight decay
|
||||
} sgd;
|
||||
};
|
||||
|
||||
// callback to calculate optimizer parameters prior to a backward pass
|
||||
@@ -112,8 +122,11 @@ extern "C" {
|
||||
|
||||
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
|
||||
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
|
||||
|
||||
// only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
|
||||
enum ggml_opt_optimizer_type optimizer;
|
||||
};
|
||||
|
||||
// get parameters for an optimization context with defaults set where possible
|
||||
@@ -142,6 +155,10 @@ extern "C" {
|
||||
// get the gradient accumulator for a node from the forward graph
|
||||
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
|
||||
|
||||
GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
|
||||
|
||||
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
|
||||
|
||||
// ====== Optimization Result ======
|
||||
|
||||
GGML_API ggml_opt_result_t ggml_opt_result_init(void);
|
||||
@@ -226,12 +243,14 @@ extern "C" {
|
||||
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
|
||||
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
|
||||
enum ggml_opt_loss_type loss_type, // loss to minimize
|
||||
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
|
||||
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
|
||||
int64_t nepoch, // how many times the dataset should be iterated over
|
||||
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
|
||||
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
|
||||
bool silent); // whether or not info prints to stderr should be suppressed
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@@ -542,6 +542,7 @@ extern "C" {
|
||||
GGML_OP_CROSS_ENTROPY_LOSS,
|
||||
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
|
||||
GGML_OP_OPT_STEP_ADAMW,
|
||||
GGML_OP_OPT_STEP_SGD,
|
||||
|
||||
GGML_OP_GLU,
|
||||
|
||||
@@ -2311,7 +2312,14 @@ extern "C" {
|
||||
struct ggml_tensor * grad,
|
||||
struct ggml_tensor * m,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * adamw_params); // parameters such a the learning rate
|
||||
struct ggml_tensor * adamw_params); // parameters such as the learning rate
|
||||
|
||||
// stochastic gradient descent step (with weight decay)
|
||||
GGML_API struct ggml_tensor * ggml_opt_step_sgd(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * grad,
|
||||
struct ggml_tensor * sgd_params); // alpha, weight decay
|
||||
|
||||
//
|
||||
// automatic differentiation
|
||||
|
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
ggml_compute_forward_opt_step_adamw(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
ggml_compute_forward_opt_step_sgd(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
{
|
||||
// nop
|
||||
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
|
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
|
||||
|
||||
const float alpha = adamw_params_ptr[0];
|
||||
const float beta1 = adamw_params_ptr[1];
|
||||
const float beta2 = adamw_params_ptr[2];
|
||||
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
||||
const float wd = adamw_params_ptr[4];
|
||||
const float beta1h = adamw_params_ptr[5];
|
||||
const float beta2h = adamw_params_ptr[6];
|
||||
|
||||
const float keep = 1.f - alpha * wd;
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
|
||||
// The weight decay is applied independently of the Adam momenta m and v.
|
||||
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
|
||||
// See: https://arxiv.org/pdf/1711.05101v3.pdf
|
||||
w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
|
||||
w[i00] = w[i00] * keep - alpha * mh / vh;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src0_grad = dst->src[1];
|
||||
const ggml_tensor * sgd_params = dst->src[2];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
|
||||
GGML_ASSERT(ggml_nelements(sgd_params) == 2);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1) / nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
// using adamw param subset we care about - alpha, wd - could have a separate struct
|
||||
const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
|
||||
const float alpha = sgd_params_ptr[0];
|
||||
const float keep = 1.f - alpha * sgd_params_ptr[1];
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir / (ne02 * ne01);
|
||||
const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
|
||||
const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
|
||||
|
||||
const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
|
||||
|
||||
float * w = (float *) ((char *) src0->data + offset); // weight
|
||||
const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
|
||||
|
||||
for (int i00 = 0; i00 < ne00; ++i00) {
|
||||
w[i00] = w[i00] * keep - alpha * g[i00];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_opt_step_sgd_f32(params, dst);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error - sgd is F32 only");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
|
||||
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
@@ -28,6 +28,7 @@
|
||||
#include "ggml-cuda/mmvq.cuh"
|
||||
#include "ggml-cuda/norm.cuh"
|
||||
#include "ggml-cuda/opt-step-adamw.cuh"
|
||||
#include "ggml-cuda/opt-step-sgd.cuh"
|
||||
#include "ggml-cuda/out-prod.cuh"
|
||||
#include "ggml-cuda/pad.cuh"
|
||||
#include "ggml-cuda/pool2d.cuh"
|
||||
@@ -2479,6 +2480,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_cuda_opt_step_adamw(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
ggml_cuda_opt_step_sgd(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -3536,6 +3540,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
49
ggml/src/ggml-cuda/opt-step-sgd.cu
Normal file
49
ggml/src/ggml-cuda/opt-step-sgd.cu
Normal file
@@ -0,0 +1,49 @@
|
||||
#include "ggml-impl.h"
|
||||
#include "opt-step-sgd.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
static __global__ void opt_step_sgd_f32(
|
||||
float * __restrict__ x, const float * __restrict__ g,
|
||||
const float * __restrict__ pars, const int64_t k) {
|
||||
|
||||
const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
|
||||
}
|
||||
|
||||
static void opt_step_sgd_f32_cuda(
|
||||
float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
|
||||
|
||||
const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
|
||||
const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
|
||||
opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
|
||||
}
|
||||
|
||||
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src0_grad = dst->src[1];
|
||||
const ggml_tensor * params = dst->src[2];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(params->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src0_grad));
|
||||
GGML_ASSERT(ggml_is_contiguous(params));
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
|
||||
GGML_ASSERT(ggml_nelements(params) == 2);
|
||||
|
||||
float * src0_d = (float *) src0->data;
|
||||
const float * src0_grad_d = (const float *) src0_grad->data;
|
||||
const float * params_d = (const float *) params->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
|
||||
opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
|
||||
}
|
5
ggml/src/ggml-cuda/opt-step-sgd.cuh
Normal file
5
ggml/src/ggml-cuda/opt-step-sgd.cuh
Normal file
@@ -0,0 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
|
||||
|
||||
void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@@ -64,9 +64,11 @@ struct ggml_opt_context {
|
||||
int32_t opt_i = 0;
|
||||
bool loss_per_datapoint = false;
|
||||
|
||||
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
||||
void * get_opt_pars_ud = nullptr;
|
||||
struct ggml_tensor * adamw_params = nullptr;
|
||||
ggml_opt_get_optimizer_params get_opt_pars = nullptr;
|
||||
void * get_opt_pars_ud = nullptr;
|
||||
struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
|
||||
|
||||
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
||||
};
|
||||
|
||||
struct ggml_opt_result {
|
||||
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
|
||||
result.adamw.eps = 1e-8f;
|
||||
result.adamw.wd = 0.0f;
|
||||
|
||||
result.sgd.alpha = 1e-3f;
|
||||
result.sgd.wd = 0.0f;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
|
||||
return *((struct ggml_opt_optimizer_params *) userdata);
|
||||
}
|
||||
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
|
||||
/*opt_period =*/ 1,
|
||||
/*get_opt_pars =*/ ggml_opt_get_default_optimizer_params,
|
||||
/*get_opt_pars_ud =*/ nullptr,
|
||||
/*optimizer =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
|
||||
GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");
|
||||
|
||||
const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
|
||||
|
||||
const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
|
||||
!(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);
|
||||
|
||||
const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
|
||||
opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
|
||||
|
||||
ggml_set_input(opt_ctx->inputs);
|
||||
ggml_set_output(opt_ctx->outputs);
|
||||
|
||||
@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
// - pred (if using static graphs)
|
||||
// - ncorrect (if using static graphs, 2 tensors).
|
||||
constexpr size_t n_loss = 1;
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) +
|
||||
(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
|
||||
const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
|
||||
const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
|
||||
const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
|
||||
struct ggml_init_params params = {
|
||||
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
|
||||
if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
|
||||
opt_ctx->grad_m.resize(n_nodes);
|
||||
opt_ctx->grad_v.resize(n_nodes);
|
||||
for (int i = 0; i < n_nodes; ++i) {
|
||||
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
|
||||
// gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
|
||||
opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
|
||||
|
||||
opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
|
||||
ggml_set_input(opt_ctx->adamw_params);
|
||||
ggml_set_name(opt_ctx->adamw_params, "adamw_params");
|
||||
|
||||
opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
|
||||
ggml_tensor * adamw_params = opt_ctx->opt_step_params;
|
||||
ggml_set_input(adamw_params);
|
||||
const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
|
||||
ggml_format_name(adamw_params, "%s_params", optimizer_name);
|
||||
for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
|
||||
struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
|
||||
struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
|
||||
|
||||
if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
|
||||
struct ggml_tensor * m = opt_ctx->grad_m[i];
|
||||
struct ggml_tensor * v = opt_ctx->grad_v[i];
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
|
||||
|
||||
ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
|
||||
ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
|
||||
|
||||
struct ggml_tensor * m = nullptr;
|
||||
struct ggml_tensor * v = nullptr;
|
||||
if (need_momenta) {
|
||||
m = opt_ctx->grad_m[i];
|
||||
v = opt_ctx->grad_v[i];
|
||||
ggml_format_name(m, "AdamW m for %s", node->name);
|
||||
ggml_format_name(v, "AdamW v for %s", node->name);
|
||||
}
|
||||
struct ggml_tensor * opt_step;
|
||||
switch (optimizer) {
|
||||
case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
|
||||
opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
|
||||
break;
|
||||
case GGML_OPT_OPTIMIZER_TYPE_SGD:
|
||||
opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
|
||||
ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
|
||||
}
|
||||
}
|
||||
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
|
||||
result->opt_period = params.opt_period;
|
||||
result->get_opt_pars = params.get_opt_pars;
|
||||
result->get_opt_pars_ud = params.get_opt_pars_ud;
|
||||
result->optimizer = params.optimizer;
|
||||
|
||||
GGML_ASSERT(result->opt_period >= 1);
|
||||
|
||||
@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
|
||||
void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
|
||||
GGML_ASSERT(opt_ctx->eval_ready);
|
||||
if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
|
||||
struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
||||
const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
|
||||
|
||||
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
|
||||
switch (opt_ctx->optimizer) {
|
||||
case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
|
||||
GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
|
||||
|
||||
// beta1, beta2 after applying warmup
|
||||
const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
|
||||
const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
|
||||
// beta1, beta2 after applying warmup
|
||||
const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
|
||||
const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
|
||||
|
||||
float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
|
||||
adamw_par_data[0] = opt_pars.adamw.alpha;
|
||||
adamw_par_data[1] = opt_pars.adamw.beta1;
|
||||
adamw_par_data[2] = opt_pars.adamw.beta2;
|
||||
adamw_par_data[3] = opt_pars.adamw.eps;
|
||||
adamw_par_data[4] = opt_pars.adamw.wd;
|
||||
adamw_par_data[5] = beta1h;
|
||||
adamw_par_data[6] = beta2h;
|
||||
float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
|
||||
adamw_par_data[0] = opt_pars.adamw.alpha;
|
||||
adamw_par_data[1] = opt_pars.adamw.beta1;
|
||||
adamw_par_data[2] = opt_pars.adamw.beta2;
|
||||
adamw_par_data[3] = opt_pars.adamw.eps;
|
||||
adamw_par_data[4] = opt_pars.adamw.wd;
|
||||
adamw_par_data[5] = beta1h;
|
||||
adamw_par_data[6] = beta2h;
|
||||
} break;
|
||||
case GGML_OPT_OPTIMIZER_TYPE_SGD: {
|
||||
GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
|
||||
GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
|
||||
GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
|
||||
float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
|
||||
sgd[0] = opt_pars.sgd.alpha;
|
||||
sgd[1] = opt_pars.sgd.wd;
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
|
||||
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
|
||||
ggml_tensor * outputs,
|
||||
ggml_opt_dataset_t dataset,
|
||||
enum ggml_opt_loss_type loss_type,
|
||||
enum ggml_opt_optimizer_type optimizer,
|
||||
ggml_opt_get_optimizer_params get_opt_pars,
|
||||
int64_t nepoch,
|
||||
int64_t nbatch_logical,
|
||||
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
|
||||
params.opt_period = opt_period;
|
||||
params.get_opt_pars = get_opt_pars;
|
||||
params.get_opt_pars_ud = &epoch;
|
||||
params.optimizer = optimizer;
|
||||
ggml_opt_context_t opt_ctx = ggml_opt_init(params);
|
||||
|
||||
// Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
|
||||
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
|
||||
ggml_opt_result_free(result_train);
|
||||
ggml_opt_result_free(result_val);
|
||||
}
|
||||
|
||||
enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
|
||||
return c->optimizer;
|
||||
}
|
||||
|
||||
GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
|
||||
switch (o) {
|
||||
case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
|
||||
return "adamw";
|
||||
case GGML_OPT_OPTIMIZER_TYPE_SGD:
|
||||
return "sgd";
|
||||
default:
|
||||
return "undefined";
|
||||
};
|
||||
}
|
||||
|
@@ -510,6 +510,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
||||
@@ -3123,6 +3124,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
uint32_t conv2d_WG_SIZE = 256;
|
||||
@@ -7193,6 +7196,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_sgd_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
@@ -7692,6 +7700,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_OPT_STEP_SGD) {
|
||||
// OPT_STEP_SGD works on src0, it does not need dst
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
|
||||
} else if (use_src2) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
@@ -8045,6 +8057,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
||||
const size_t n = ggml_nelements(dst->src[0]);
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
int * op_params = (int *)dst->op_params;
|
||||
|
||||
@@ -9598,6 +9616,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
break;
|
||||
default:
|
||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||
@@ -9662,6 +9681,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
||||
// do the only thing needed for the dryrun.
|
||||
@@ -9911,6 +9931,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
||||
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -10014,8 +10039,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
@@ -11154,6 +11179,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_ACC:
|
||||
@@ -11175,8 +11203,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
@@ -11774,6 +11800,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2]);
|
||||
}
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
22
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
Normal file
22
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
Normal file
@@ -0,0 +1,22 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.comp"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) buffer X {A_TYPE data_x[];};
|
||||
layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
|
||||
layout (binding = 2) readonly buffer P {float data_params[2];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float alpha = data_params[0];
|
||||
const float keep = 1.f - alpha * data_params[1];
|
||||
|
||||
data_x[i] = data_x[i] * keep - alpha * data_grad[i];
|
||||
}
|
@@ -657,6 +657,7 @@ void process_shaders() {
|
||||
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
|
||||
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
|
||||
|
@@ -1012,11 +1012,12 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS",
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
"OPT_STEP_ADAMW",
|
||||
"OPT_STEP_SGD",
|
||||
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1113,15 +1114,15 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
"adamw(x)",
|
||||
"sgd(x)",
|
||||
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
||||
static_assert(GGML_OP_COUNT == 88, "GGML_OP_COUNT != 88");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
||||
static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||
"ABS",
|
||||
"SGN",
|
||||
@@ -5606,6 +5607,28 @@ struct ggml_tensor * ggml_opt_step_adamw(
|
||||
return result;
|
||||
}
|
||||
|
||||
// opt_step_sgd
|
||||
|
||||
struct ggml_tensor * ggml_opt_step_sgd(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * grad,
|
||||
struct ggml_tensor * params) {
|
||||
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
|
||||
GGML_ASSERT(ggml_are_same_shape(a, grad));
|
||||
GGML_ASSERT(params->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_nelements(params) == 2);
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_OPT_STEP_SGD;
|
||||
result->src[0] = a;
|
||||
result->src[1] = grad;
|
||||
result->src[2] = params;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ggml_hash_set ggml_hash_set_new(size_t size) {
|
||||
|
Reference in New Issue
Block a user