# Copyright © 2023 Apple Inc.

from functools import wraps
from typing import Callable

import mlx.core as mx


def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for.
        fn (Callable): The scalar function to compute gradients for. It is
            called with exactly the positional and keyword arguments given to
            the returned callable; the parameters are NOT passed to it, so
            ``fn`` is expected to reference ``model`` itself.

    Returns:
        A callable that returns a ``(value, grad)`` pair: the value of ``fn``
        and the gradients wrt ``model.trainable_parameters()``. The callable
        carries ``fn``'s metadata (``__name__``, ``__doc__``, ``__wrapped__``).

    Note:
        Every call invokes ``model.update`` with the parameters supplied by
        ``mx.value_and_grad``, so the model's parameter attributes are
        reassigned as a side effect of calling the returned function.
    """

    def inner_fn(params, *args, **kwargs):
        # mx.value_and_grad differentiates wrt the first argument (by default),
        # so the parameters are threaded through explicitly and loaded into
        # the model here before ``fn`` runs against it.
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    # Copy fn's name/docstring/__wrapped__ onto the wrapper so the transformed
    # function stays introspectable (logging, debugging, inspect.signature).
    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn