reach-vb's picture
reach-vb HF staff
ce304fafe19161978ad512b385c65426bad519e5a0b8fb3f0659eace3d2ea3cc
f14e74e
raw
history blame
999 Bytes
# Copyright © 2023 Apple Inc.
from functools import wraps
from typing import Callable

import mlx.core as mx
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    On every call, the returned function:

    1. Reads the model's current trainable parameters.
    2. Passes them to :func:`mlx.core.value_and_grad`.
    3. Inside the differentiated function, writes whatever parameter arrays
       the transform supplies back into ``model`` via ``model.update`` before
       evaluating ``fn``.

    ``model`` is therefore mutated in place as a side effect of each call.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for. It is
            called with the same positional and keyword arguments that are
            passed to the returned callable.

    Returns:
        A callable that returns a tuple ``(value, grad)``. ``value`` is the
        value of ``fn`` and ``grad`` holds the gradients wrt the trainable
        parameters of ``model``. The callable carries ``fn``'s name and
        docstring (via :func:`functools.wraps`).
    """

    def inner_fn(params, *args, **kwargs):
        # ``fn`` closes over ``model`` rather than taking the parameters
        # explicitly. Loading ``params`` into the model makes ``fn``'s
        # output a function of ``params``, so it can be differentiated.
        model.update(params)
        return fn(*args, **kwargs)

    # Presumably differentiates wrt the first positional argument of
    # ``inner_fn`` (``params``) -- TODO confirm against the default
    # ``argnums`` of ``mx.value_and_grad``.
    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        # Fetch the parameters fresh on each call so gradients are always
        # taken at the model's current trainable state.
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn