# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de
from typing import List, Union, Callable, Optional, Dict

import torch
from loguru import logger
from tqdm import tqdm

from SMPLX.transfer_model.utils import (
    from_torch, Tensor, Array, rel_change)

def minimize(
    optimizer: torch.optim.Optimizer,
    closure: Callable[[], Tensor],
    params: List[Tensor],
    summary_closure: Optional[Callable[[], Dict[str, float]]] = None,
    maxiters: int = 100,
    ftol: float = -1.0,
    gtol: float = 1e-9,
    interactive: bool = True,
    summary_steps: int = 10,
    **kwargs
) -> Optional[float]:
    ''' Helper function for running an optimization process

        Args:
            - optimizer: The PyTorch optimizer object
            - closure: The function used to calculate the loss and the
              gradients
            - params: A list containing the parameters that will be optimized

        Keyword arguments:
            - summary_closure (None): Optional function that returns a
              dictionary of extra scalar values to log
            - maxiters (100): The maximum number of iterations for the
              optimizer
            - ftol (-1.0): The tolerance for the relative change in the loss
              function. If the relative change drops below this value, then
              the process stops. Disabled when non-positive.
            - gtol (1e-9): The tolerance for the maximum change in the
              gradient. If the maximum absolute value of all the gradient
              tensors is less than this, then the process stops.
            - interactive (True): Whether to log progress messages
            - summary_steps (10): Log a summary every `summary_steps`
              iterations
    '''
    prev_loss = None
    for n in tqdm(range(maxiters), desc='Fitting iterations'):
        loss = optimizer.step(closure)

        # Stop if the relative change in the loss drops below the tolerance.
        # `rel_change` is expected to normalize |prev - curr| by the scale of
        # the loss, so that `ftol` acts as a scale-invariant threshold.
        if n > 0 and prev_loss is not None and ftol > 0:
            loss_rel_change = rel_change(prev_loss, loss.item())
            if loss_rel_change <= ftol:
                prev_loss = loss.item()
                break

        # Stop if every gradient entry is smaller than the tolerance. The
        # cheap `gtol > 0` check runs first, so a disabled tolerance skips
        # the gradient scan entirely.
        if gtol > 0 and all(
                var.grad.view(-1).abs().max().item() < gtol
                for var in params if var.grad is not None):
            prev_loss = loss.item()
            break

        if interactive and n % summary_steps == 0:
            logger.info(f'[{n:05d}] Loss: {loss.item():.4f}')
            if summary_closure is not None:
                summaries = summary_closure()
                for key, val in summaries.items():
                    logger.info(f'[{n:05d}] {key}: {val:.4f}')

        prev_loss = loss.item()
    # Log the final state of the optimization
    if interactive:
        logger.info(f'[{n + 1:05d}] Loss: {loss.item():.4f}')
        if summary_closure is not None:
            summaries = summary_closure()
            for key, val in summaries.items():
                logger.info(f'[{n + 1:05d}] {key}: {val:.4f}')
    return prev_loss
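

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one plausible
# way to drive `minimize`, assuming an LBFGS optimizer and a closure that
# zeroes the gradients, evaluates the loss, and backpropagates. The quadratic
# target below is a made-up toy problem; real callers would plug in their own
# model-fitting loss instead.
if __name__ == '__main__':
    target = torch.tensor([1.0, -2.0, 0.5])
    param = torch.zeros(3, requires_grad=True)

    optimizer = torch.optim.LBFGS(
        [param], lr=1.0, line_search_fn='strong_wolfe')

    def closure() -> Tensor:
        optimizer.zero_grad()
        # Simple quadratic loss with a unique minimum at `target`
        loss = ((param - target) ** 2).sum()
        loss.backward()
        return loss

    final_loss = minimize(
        optimizer, closure, params=[param],
        maxiters=50, ftol=1e-9, gtol=1e-9, summary_steps=5)
    logger.info(f'Final loss: {final_loss:.6f}')
    logger.info(f'Recovered parameters: {param.detach().tolist()}')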