mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-11-07 08:34:37 +01:00
test: fix OPT_STEP_ADAMW for test-backend-ops (ggml/974)
This commit is contained in:
parent
845f8d663e
commit
5e9d6baa48
@ -2052,6 +2052,7 @@ extern "C" {
|
|||||||
GGML_API struct ggml_tensor * ggml_opt_step_adamw(
|
GGML_API struct ggml_tensor * ggml_opt_step_adamw(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * grad,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta1,
|
float beta1,
|
||||||
float beta2,
|
float beta2,
|
||||||
|
@ -7818,12 +7818,14 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
|
|||||||
struct ggml_tensor * ggml_opt_step_adamw(
|
struct ggml_tensor * ggml_opt_step_adamw(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * grad,
|
||||||
float alpha,
|
float alpha,
|
||||||
float beta1,
|
float beta1,
|
||||||
float beta2,
|
float beta2,
|
||||||
float eps,
|
float eps,
|
||||||
float wd) {
|
float wd) {
|
||||||
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
|
GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(a, grad));
|
||||||
GGML_ASSERT(alpha > 0.0f);
|
GGML_ASSERT(alpha > 0.0f);
|
||||||
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
|
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
|
||||||
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
|
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
|
||||||
@ -7842,9 +7844,9 @@ struct ggml_tensor * ggml_opt_step_adamw(
|
|||||||
|
|
||||||
result->op = GGML_OP_OPT_STEP_ADAMW;
|
result->op = GGML_OP_OPT_STEP_ADAMW;
|
||||||
result->src[0] = a;
|
result->src[0] = a;
|
||||||
result->src[1] = a->grad;
|
result->src[1] = grad;
|
||||||
result->src[2] = ggml_dup_tensor(ctx, a);
|
result->src[2] = ggml_dup_tensor(ctx, grad);
|
||||||
result->src[3] = ggml_dup_tensor(ctx, a);
|
result->src[3] = ggml_dup_tensor(ctx, grad);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -18769,7 +18771,7 @@ void ggml_build_opt_adamw(
|
|||||||
|
|
||||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||||
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
||||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
|
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
|
||||||
ggml_build_forward_expand(gb, opt_step);
|
ggml_build_forward_expand(gb, opt_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user