# Copyright © 2023 Apple Inc.
from functools import wraps
from typing import Callable

import mlx.core as mx
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    ``fn`` is never handed the parameters directly: they are written into
    ``model`` with ``model.update`` right before ``fn`` is called, so ``fn``
    has to evaluate ``model`` (e.g. through a closure) for the parameters to
    influence its value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``. It takes the same positional and
        keyword arguments as ``fn`` and carries over ``fn``'s metadata (name,
        docstring, ``__wrapped__``).
    """

    def inner_fn(params, *args, **kwargs):
        # Install the parameters being differentiated into the model so that
        # ``fn`` picks them up when it evaluates ``model``.
        model.update(params)
        return fn(*args, **kwargs)

    # Differentiates ``inner_fn``; the parameter tree is passed as its first
    # argument below.
    value_grad_fn = mx.value_and_grad(inner_fn)

    @wraps(fn)  # keep fn's __name__/__doc__ visible on the returned callable
    def wrapped_value_grad_fn(*args, **kwargs):
        # Trainable parameters are fetched on every call, so changes made to
        # the model between calls are reflected.
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn