mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-01 07:25:49 +02:00
mnist: fix segmentation fault (ggml/1227)
This commit is contained in:
parent
9c3bfc1499
commit
405b9c77ad
@ -128,6 +128,8 @@ extern "C" {
|
|||||||
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
||||||
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
||||||
|
|
||||||
|
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
||||||
|
|
||||||
// get underlying tensors that store data
|
// get underlying tensors that store data
|
||||||
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
||||||
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
||||||
|
@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
|
||||||
|
return opt_ctx->static_graphs;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
||||||
return opt_ctx->inputs;
|
return opt_ctx->inputs;
|
||||||
}
|
}
|
||||||
@ -842,6 +846,7 @@ void ggml_opt_epoch(
|
|||||||
int64_t idata_split,
|
int64_t idata_split,
|
||||||
ggml_opt_epoch_callback callback_train,
|
ggml_opt_epoch_callback callback_train,
|
||||||
ggml_opt_epoch_callback callback_eval) {
|
ggml_opt_epoch_callback callback_eval) {
|
||||||
|
GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
|
||||||
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
||||||
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
||||||
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user