mnist: fix segmentation fault (ggml/1227)

This commit is contained in:
Johannes Gäßler 2025-05-19 09:33:35 +02:00 committed by Georgi Gerganov
parent 9c3bfc1499
commit 405b9c77ad
2 changed files with 7 additions and 0 deletions

View File

@ -128,6 +128,8 @@ extern "C" {
// set gradients to zero, initilize loss, and optionally reset the optimizer // set gradients to zero, initilize loss, and optionally reset the optimizer
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
// get underlying tensors that store data // get underlying tensors that store data
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor

View File

@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
} }
} }
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
return opt_ctx->static_graphs;
}
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
return opt_ctx->inputs; return opt_ctx->inputs;
} }
@ -842,6 +846,7 @@ void ggml_opt_epoch(
int64_t idata_split, int64_t idata_split,
ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval) { ggml_opt_epoch_callback callback_eval) {
GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
struct ggml_tensor * data = ggml_opt_dataset_data(dataset); struct ggml_tensor * data = ggml_opt_dataset_data(dataset);