From 5e9d6baa48e263dd2f9e216022012fcfee4c581f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 30 Sep 2024 09:55:23 +0200 Subject: [PATCH] test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974) --- ggml/include/ggml.h | 1 + ggml/src/ggml.c | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a8a74bee..ce3d92cb 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2052,6 +2052,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aac4e3a7..bcbc32d9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, + struct ggml_tensor * grad, float alpha, float beta1, float beta2, float eps, float wd) { GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); GGML_ASSERT(alpha > 0.0f); GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); @@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw( result->op = GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; - result->src[1] = a->grad; - result->src[2] = ggml_dup_tensor(ctx, a); - result->src[3] = ggml_dup_tensor(ctx, a); + result->src[1] = grad; + result->src[2] = ggml_dup_tensor(ctx, grad); + result->src[3] = ggml_dup_tensor(ctx, grad); return result; } @@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw( if (node->flags & GGML_TENSOR_FLAG_PARAM) { GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); ggml_build_forward_expand(gb, opt_step); } }