From 5d8b068249c567665b6a24caa64666abdb7870c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 12 May 2025 14:44:49 +0200 Subject: [PATCH] llama/ggml: add LLM training support (llama/10544) * llama/ggml: add LLM training support more compact progress bar llama_save_model_to_file llama_opt_param_filter ggml_graph_dup force_grads refactor ggml_opt, fix test-opt * remove logits_all * refactor CUDA implementation for ACC * reset graph at beginning of opt period --- ggml/include/ggml-opt.h | 75 +++-- ggml/include/ggml.h | 13 +- ggml/src/ggml-backend.cpp | 2 +- ggml/src/ggml-cuda/acc.cu | 66 +++-- ggml/src/ggml-cuda/sum.cu | 2 +- ggml/src/ggml-opt.cpp | 570 +++++++++++++++++++++++++------------- ggml/src/ggml.c | 41 +-- 7 files changed, 492 insertions(+), 277 deletions(-) diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index eb5eab9d..da0c24b4 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -37,13 +37,16 @@ extern "C" { // ====== Dataset ====== GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( - int64_t ne_datapoint, // number of elements per datapoint - int64_t ne_label, // number of elements per label - int64_t ndata, // total number of datapoints/labels - int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); // get underlying tensors that store the data + GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] @@ -56,13 +59,19 @@ extern "C" { struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + int64_t ibatch); // ====== Model / Context ====== enum ggml_opt_build_type { - GGML_OPT_BUILD_TYPE_FORWARD, - GGML_OPT_BUILD_TYPE_GRAD, - GGML_OPT_BUILD_TYPE_OPT, + GGML_OPT_BUILD_TYPE_FORWARD = 10, + GGML_OPT_BUILD_TYPE_GRAD = 20, + GGML_OPT_BUILD_TYPE_OPT = 30, }; // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss @@ -81,20 +90,22 @@ extern "C" { // userdata can be used to pass arbitrary data typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); - // returns the default optimizer params (constant) + // returns the default optimizer params (constant, hard-coded values) // userdata is not used GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + // casts userdata to ggml_opt_optimizer_params and returns it + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); + // parameters for initializing a new optimization context struct ggml_opt_params { ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs - struct ggml_context * ctx_compute; // created in user code, holds non-static tensors - - // the forward graph is defined by inputs and outputs - // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts - struct ggml_tensor * inputs; - struct ggml_tensor * outputs; + // by default the forward graph needs to be reconstructed for each eval + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically + struct ggml_context * ctx_compute; + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; enum ggml_opt_loss_type loss_type; enum ggml_opt_build_type build_type; @@ -107,12 +118,9 @@ extern "C" { // get parameters for an optimization context with defaults set where possible // parameters for which no sensible defaults exist are supplied as arguments to this function - GGML_API ggml_opt_params ggml_opt_default_params( - ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, - enum ggml_opt_loss_type loss_type); + GGML_API struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + enum ggml_opt_loss_type loss_type); GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); @@ -121,6 +129,7 @@ extern "C" { GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); // get underlying tensors that store data + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against @@ -128,11 +137,12 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + // get the gradient accumulator for a node from the forward graph GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); // ====== Optimization Result ====== - GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API ggml_opt_result_t ggml_opt_result_init(void); GGML_API void ggml_opt_result_free(ggml_opt_result_t result); GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); @@ -144,11 +154,20 @@ extern "C" { // ====== Computation ====== - // do forward pass, increment result if not NULL - GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // if not using static graphs, this function must be called prior to ggml_opt_alloc + GGML_API void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs); - // do forward pass, increment result if not NULL, do backward pass - GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // allocate the next graph for evaluation, either forward or forward + backward + // must be called exactly once prior to calling ggml_opt_eval + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); + + // do forward pass, increment result if not NULL, do backward pass if allocated + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); // ############################################################################ // ## The high-level functions start here. They do not depend on any private ## @@ -200,9 +219,9 @@ extern "C" { // fit model defined by inputs and outputs to dataset GGML_API void ggml_opt_fit( ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs - ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs - ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] - ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used ggml_opt_dataset_t dataset, // dataset with data and optionally also labels enum ggml_opt_loss_type loss_type, // loss to minimize ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c518366d..e91dedf1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -768,7 +768,7 @@ extern "C" { // Tensor flags GGML_API void ggml_set_input(struct ggml_tensor * tensor); GGML_API void ggml_set_output(struct ggml_tensor * tensor); - GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); // @@ -938,7 +938,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride // concat a and b along dim // used in stable-diffusion @@ -2049,15 +2049,14 @@ extern "C" { GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( - struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) - struct ggml_context * ctx_compute, // context for gradient computation - struct ggml_cgraph * cgraph, - bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static + struct ggml_context * ctx, // context for gradient computation + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs); // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); - GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 6f69d895..b30b4cb3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1111,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - assert(node_backend_id != -1); // all nodes should be assigned by now + assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback // check if we should start a new split based on the sources of the current node bool need_new_split = false; diff --git a/ggml/src/ggml-cuda/acc.cu b/ggml/src/ggml-cuda/acc.cu index 96bfe1c9..e084607c 100644 --- a/ggml/src/ggml-cuda/acc.cu +++ b/ggml/src/ggml-cuda/acc.cu @@ -1,47 +1,61 @@ #include "acc.cuh" -static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset) { - const int i = blockDim.x * blockIdx.x + threadIdx.x; +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } -static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, const int offset, cudaStream_t stream) { - int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; - acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) { + const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); } void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); - acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream); + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); + + acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream); } diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index f9589080..eb3d7cdb 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 7c3e2410..58d77578 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -28,16 +28,19 @@ struct ggml_opt_dataset { }; struct ggml_opt_context { - ggml_backend_sched_t backend_sched = nullptr; - ggml_cgraph * allocated_graph = nullptr; - ggml_cgraph * allocated_graph_copy = nullptr; - struct ggml_context * ctx_static = nullptr; - struct ggml_context * ctx_static_cpu = nullptr; - struct ggml_context * ctx_compute = nullptr; - struct ggml_context * ctx_copy = nullptr; - ggml_backend_buffer_t buf_static = nullptr; - ggml_backend_buffer_t buf_static_cpu = nullptr; - std::mt19937 rng; + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_cpu = nullptr; + std::mt19937 rng; + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + enum ggml_opt_build_type build_type_alloc; struct ggml_tensor * inputs = nullptr; struct ggml_tensor * outputs = nullptr; @@ -50,6 +53,11 @@ struct ggml_opt_context { struct ggml_cgraph * gf = nullptr; struct ggml_cgraph * gb_grad = nullptr; struct ggml_cgraph * gb_opt = nullptr; + bool static_graphs = false; + bool eval_ready = false; + std::vector grad_accs; + std::vector grad_m; + std::vector grad_v; int64_t iter = 1; int32_t opt_period = 1; @@ -73,7 +81,13 @@ struct ggml_opt_result { // ====== Dataset ====== -ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { +ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, + enum ggml_type type_label, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ndata, + int64_t ndata_shard) { GGML_ASSERT(ne_datapoint > 0); GGML_ASSERT(ne_label >= 0); GGML_ASSERT(ndata > 0); @@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, result->ctx = ggml_init(params); } - result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; if (ne_label > 0) { - result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; } else { result->labels = nullptr; @@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { delete dataset; } +int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) { + return dataset->ndata; +} + struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { return dataset->data; } @@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT( data_batch->type == dataset->data->type); + GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type); const size_t nb_data_batch = ggml_nbytes(data_batch); GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); @@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * } } +void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } +} + // ====== Model / Context ====== struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { @@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us return result; } +struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { + return *((struct ggml_opt_optimizer_params *) userdata); +} + struct ggml_opt_params ggml_opt_default_params( ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, enum ggml_opt_loss_type loss_type) { return { /*backend_sched =*/ backend_sched, - /*ctx_compute =*/ ctx_compute, - /*inputs =*/ inputs, - /*logits =*/ outputs, + /*ctx_compute =*/ nullptr, + /*inputs =*/ nullptr, + /*logits =*/ nullptr, /*loss_type =*/ loss_type, /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, /*opt_period =*/ 1, @@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { return dst; } -static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { - GGML_ASSERT(graph); - if (opt_ctx->allocated_graph == graph) { - return; - } +static void ggml_opt_build(ggml_opt_context_t opt_ctx) { + GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); + GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); - ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && + !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - ggml_free(opt_ctx->ctx_copy); - opt_ctx->ctx_copy = ggml_init(params); - } - - opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); - - ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); - opt_ctx->allocated_graph = graph; -} - -ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { - ggml_opt_context_t result = new struct ggml_opt_context; - result->backend_sched = params.backend_sched; - result->ctx_compute = params.ctx_compute; - result->inputs = params.inputs; - result->outputs = params.outputs; - result->opt_period = params.opt_period; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; - - GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); - GGML_ASSERT(result->opt_period >= 1); - - const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || - (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); - - ggml_set_input(result->inputs); - ggml_set_output(result->outputs); - - result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. - ggml_build_forward_expand(result->gf, result->outputs); + ggml_set_input(opt_ctx->inputs); + ggml_set_output(opt_ctx->outputs); int n_param = 0; - for (int i = 0; i < result->gf->n_nodes; ++i) { - if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) { + const struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { n_param++; } + GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented"); } - { + if (!opt_ctx->ctx_static) { // The static context is used for: - // - gradients (1 tensor per param if using gradient accumulation) + // - gradients (1 per loss, 1 tensor per param if using gradient accumulation) // - optimizer momenta (2 tensors per param) - // - labels - // - loss + its gradient (up to 5 tensors) - // - pred - // - ncorrect (2 tensors). - const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); - const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + // - labels (if using static graphs) + // - loss (if using static graphs, up to 5 tensors) + // - pred (if using static graphs) + // - ncorrect (if using static graphs, 2 tensors). + constexpr size_t n_loss = 1; + const size_t tensors_per_param = (accumulate ? 1 : 0) + + (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; + const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static = ggml_init(params); + opt_ctx->ctx_static = ggml_init(params); } + GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc); + { - // The static cpu context is used for: - // - optimizer parameters (1 for the entire context) + // The cpu context is allocated statically if using static graphs, dynamically otherwise. + // It is used for: + // - optimizer parameters (1 shared for all optimizer invocations) const size_t size_meta = 1 * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static_cpu = ggml_init(params); + ggml_free(opt_ctx->ctx_cpu); + opt_ctx->ctx_cpu = ggml_init(params); + + ggml_backend_buffer_free(opt_ctx->buf_cpu); + opt_ctx->buf_cpu = nullptr; } + struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute; - switch (params.loss_type) { + switch (opt_ctx->loss_type) { case GGML_OPT_LOSS_TYPE_MEAN: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean"); - result->loss_per_datapoint = true; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean"); + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_SUM: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - result->loss_per_datapoint = false; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + opt_ctx->loss_per_datapoint = false; break; } case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_cross_entropy"); - if (result->opt_period > 1) { - result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); - ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled"); } - result->loss_per_datapoint = true; + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_error"); - result->loss = ggml_sqr(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_squared_error"); - result->loss = ggml_sum(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_sum_squared_error"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean_squared_error"); - result->loss_per_datapoint = true; + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_error"); + opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_squared_error"); + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean_squared_error"); + opt_ctx->loss_per_datapoint = true; break; } } - ggml_set_output(result->loss); - ggml_set_loss(result->loss); - ggml_build_forward_expand(result->gf, result->loss); + ggml_set_output(opt_ctx->loss); + ggml_set_loss(opt_ctx->loss); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); - result->pred = ggml_argmax(result->ctx_static, result->outputs); - ggml_set_name(result->pred, "pred"); - ggml_set_output(result->pred); - ggml_build_forward_expand(result->gf, result->pred); + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->pred, "pred"); + ggml_set_output(opt_ctx->pred); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); - if (result->labels) { - result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); - ggml_set_name(result->ncorrect, "ncorrect"); - ggml_set_output(result->ncorrect); - ggml_build_forward_expand(result->gf, result->ncorrect); - } else { - result->ncorrect = nullptr; + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); } - if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - return result; + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + return; } - // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. - result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); - ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + if (opt_ctx->grad_accs.empty()) { + GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD); - if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - ggml_graph_reset(result->gb_grad); - return result; - } + const int n_nodes = opt_ctx->gf->n_nodes; + opt_ctx->grad_accs.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_accs[i] = nullptr; + } + } - GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); - - // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. - result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); - - result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); - ggml_set_input(result->adamw_params); - ggml_set_name(result->adamw_params, "adamw_params"); - - for (int i = result->gf->n_nodes-1; i >= 0; --i) { - struct ggml_tensor * node = result->gb_opt->nodes[i]; - struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); - ggml_build_forward_expand(result->gb_opt, opt_step); + if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + opt_ctx->grad_m.resize(n_nodes); + opt_ctx->grad_v.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_m[i] = nullptr; + opt_ctx->grad_v[i] = nullptr; + } + } } } - result->buf_static = ggml_backend_alloc_ctx_tensors( - result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); + ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); - result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_grad); + } - ggml_graph_reset(result->gb_opt); + GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); + + opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); + ggml_set_input(opt_ctx->adamw_params); + ggml_set_name(opt_ctx->adamw_params, "adamw_params"); + + for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); + + if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + struct ggml_tensor * m = opt_ctx->grad_m[i]; + struct ggml_tensor * v = opt_ctx->grad_v[i]; + struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); + + ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); + ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); + ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); + + ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); + } + } + + if (!opt_ctx->buf_static) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_opt); + } + + opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type()); +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->loss_type = params.loss_type; + result->build_type = params.build_type; + result->build_type_alloc = params.build_type; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->opt_period >= 1); + + result->static_graphs = result->ctx_compute; + + if (!result->static_graphs) { + GGML_ASSERT(!result->inputs); + GGML_ASSERT(!result->outputs); + return result; + } + + GGML_ASSERT(result->inputs); + GGML_ASSERT(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + ggml_opt_build(result); return result; } @@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { return; } ggml_backend_buffer_free(opt_ctx->buf_static); - ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); - ggml_free(opt_ctx->ctx_static_cpu); + ggml_free(opt_ctx->ctx_cpu); delete opt_ctx; } @@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl // ====== Computation ====== -static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { - if (graph != opt_ctx->gf) { +void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs) { + GGML_ASSERT(!opt_ctx->static_graphs); + opt_ctx->ctx_compute = ctx_compute; + opt_ctx->gf = gf; + opt_ctx->inputs = inputs; + opt_ctx->outputs = outputs; +} + +void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { + GGML_ASSERT(!opt_ctx->eval_ready); + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) { + ggml_graph_reset(opt_ctx->gb_grad); + } + if (backward) { + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD; + } else { + opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD; + } + + if (!opt_ctx->static_graphs) { + ggml_opt_build(opt_ctx); + } + + struct ggml_cgraph * graph = nullptr; + switch (opt_ctx->build_type) { + case GGML_OPT_BUILD_TYPE_FORWARD: { + graph = opt_ctx->gf; + } break; + case GGML_OPT_BUILD_TYPE_GRAD: { + graph = opt_ctx->gb_grad; + } break; + case GGML_OPT_BUILD_TYPE_OPT: { + graph = opt_ctx->gb_opt; + } break; + } + GGML_ASSERT(graph); + + if (opt_ctx->allocated_graph == graph) { + opt_ctx->eval_ready = true; + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + if (opt_ctx->static_graphs) { + ggml_init_params params = { + /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + } else { + opt_ctx->allocated_graph_copy = graph; + } + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; + + opt_ctx->eval_ready = true; +} + +void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { + GGML_ASSERT(opt_ctx->eval_ready); + if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); @@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, adamw_par_data[6] = beta2h; } - ggml_opt_alloc_graph(opt_ctx, graph); ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + + if (!opt_ctx->static_graphs) { + opt_ctx->gf = nullptr; + opt_ctx->gb_grad = nullptr; + opt_ctx->gb_opt = nullptr; + opt_ctx->allocated_graph = nullptr; + opt_ctx->allocated_graph_copy = nullptr; + } + + opt_ctx->eval_ready = false; if (!result) { return; @@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); result->loss.push_back(loss); - GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); - std::vector pred(ndata); - ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); - result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + if (opt_ctx->pred) { + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + } - if (!opt_ctx->labels || result->ncorrect < 0) { + if (!opt_ctx->ncorrect || result->ncorrect < 0) { result->ncorrect = -1; return; } @@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, result->ncorrect += ncorrect; } -void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); -} - -void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - if (opt_ctx->opt_period == 1) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - return; - } - - const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; - if (opt_i_next == 0) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - ggml_opt_reset(opt_ctx, /*optimizer =*/ false); - } else { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); - } - opt_ctx->opt_i = opt_i_next; -} - // ====== High-Level Functions ====== void ggml_opt_epoch( @@ -700,16 +860,18 @@ void ggml_opt_epoch( int64_t ibatch = 0; int64_t t_loop_start = ggml_time_us(); for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ true); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward_backward(opt_ctx, result_train); + ggml_opt_eval(opt_ctx, result_train); if (callback_train) { callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); } } t_loop_start = ggml_time_us(); for (; ibatch < nbatches; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ false); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward(opt_ctx, result_eval); + ggml_opt_eval(opt_ctx, result_eval); if (callback_eval) { callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); } @@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar( int64_t t_start_us) { fprintf(stderr, "%s[", train ? "train: " : "val: "); - constexpr int64_t bar_length = 25; + // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels. + constexpr int64_t bar_length = 8; + const int64_t ibatch8 = 8 * ibatch; for (int64_t j = 0; j < bar_length; ++j) { - const int64_t ibatch_j = ibatch_max * j/bar_length; - if (ibatch_j < ibatch) { - fprintf(stderr, "="); - } else if (ibatch_max * (j - 1)/bar_length < ibatch) { - fprintf(stderr, ">"); + if (ibatch_max * (8*j + 8) / bar_length < ibatch8) { + fprintf(stderr, "\u2588"); // full block + } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) { + fprintf(stderr, "\u2589"); // 7/8 filled + } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) { + fprintf(stderr, "\u258A"); // 6/8 filled + } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) { + fprintf(stderr, "\u258B"); // 5/8 filled + } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) { + fprintf(stderr, "\u258C"); // 4/8 filled + } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) { + fprintf(stderr, "\u258D"); // 3/8 filled + } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) { + fprintf(stderr, "\u258E"); // 2/8 filled + } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) { + fprintf(stderr, "\u258F"); // 1/8 filled } else { fprintf(stderr, " "); } @@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar( const int64_t t_eta_m = t_eta_s / 60; t_eta_s -= t_eta_m * 60; - fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " - "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r", idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); if (ibatch == ibatch_max) { @@ -806,7 +981,10 @@ void ggml_opt_fit( int64_t epoch = 1; - ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type); + params.ctx_compute = ctx_compute; + params.inputs = inputs; + params.outputs = outputs; params.opt_period = opt_period; params.get_opt_pars = get_opt_pars; params.get_opt_pars_ud = &epoch; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ee4fe9f7..01f7e05b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5499,7 +5499,7 @@ static void ggml_compute_backward( // tensor = src0 * 1 + src1 * 0 if (src0_needs_grads) { // dsrc0 = dtensor * 1 - ggml_add_or_set(ctx, cgraph, isrc0, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0)); } if (src1_needs_grads) { // dsrc1 = dtensor * 0 -> noop @@ -5780,10 +5780,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * } void ggml_build_backward_expand( - struct ggml_context * ctx_static, - struct ggml_context * ctx_compute, - struct ggml_cgraph * cgraph, - bool accumulate) { + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs) { GGML_ASSERT(cgraph->n_nodes > 0); GGML_ASSERT(cgraph->grads); GGML_ASSERT(cgraph->grad_accs); @@ -5856,21 +5855,24 @@ void ggml_build_backward_expand( GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); - GGML_ASSERT(igrad != GGML_HASHSET_FULL); - GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); - if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { - cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node); - cgraph->grads[igrad] = cgraph->grad_accs[igrad]; - ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); + const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(ihash != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash)); + if (grad_accs && grad_accs[i]) { + cgraph->grad_accs[ihash] = grad_accs[i]; + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } else if (node->flags & GGML_TENSOR_FLAG_LOSS) { + // loss tensors always need a gradient accumulator + cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; } - grads_needed[igrad] = true; + grads_needed[ihash] = true; } for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); + ggml_compute_backward(ctx, cgraph, i, grads_needed); } free(grads_needed); @@ -6016,8 +6018,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { } } -struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads); ggml_graph_cpy(cgraph, result); return result; } @@ -6036,6 +6038,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { } void ggml_graph_reset(struct ggml_cgraph * cgraph) { + if (!cgraph) { + return; + } GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6345,8 +6350,8 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } -void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { - GGML_UNUSED(ctx); // TODO: remove this parameter +void ggml_set_param(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_NONE); tensor->flags |= GGML_TENSOR_FLAG_PARAM; }