kleinhe
init
c3d0293
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de
from typing import List, Union, Callable, Optional, Dict
import torch
from loguru import logger
from tqdm import tqdm
from SMPLX.transfer_model.utils import (
from_torch, Tensor, Array, rel_change)
def _log_progress(
    step: int,
    loss_value: float,
    summary_closure: Optional[Callable[[], Dict[str, float]]] = None,
) -> None:
    ''' Log the loss and any extra summaries for one iteration index.

    Args:
        - step: The iteration index printed in the log prefix
        - loss_value: The scalar loss value to report
        - summary_closure: Optional callable returning a dict of named
          scalar values that are logged after the loss
    '''
    logger.info(f'[{step:05d}] Loss: {loss_value:.4f}')
    if summary_closure is not None:
        for key, val in summary_closure().items():
            logger.info(f'[{step:05d}] {key}: {val:.4f}')


def _grads_below(params: List[Tensor], gtol: float) -> bool:
    ''' Check whether every populated gradient is below gtol in magnitude.

    Parameters whose `.grad` is None are skipped. If no parameter has a
    gradient, this returns True (the behaviour of `all` on an empty
    iterable).
    '''
    # `.abs().max()` works on any shape, including non-contiguous grads where
    # `.view(-1)` would raise. The generator lets `all` stop at the first
    # gradient that is too large, avoiding further `.item()` syncs.
    return all(var.grad.abs().max().item() < gtol
               for var in params if var.grad is not None)


def minimize(
    optimizer: torch.optim.Optimizer,
    closure,
    params: List[Tensor],
    summary_closure: Optional[Callable[[], Dict[str, float]]] = None,
    maxiters=100,
    ftol=-1.0,
    gtol=1e-9,
    interactive=True,
    summary_steps=10,
    **kwargs
) -> Optional[float]:
    ''' Helper function for running an optimization process

    Args:
        - optimizer: The PyTorch optimizer object. Its `step` method is
          called with `closure` once per iteration.
        - closure: The function passed to `optimizer.step`, used to
          calculate the gradients. The value returned by `optimizer.step`
          must support `.item()`, e.g. a scalar tensor.
        - params: a list containing the parameters that will be optimized.
          Their `.grad` attributes are inspected for the gtol test.
    Keyword arguments:
        - summary_closure (None): Optional callable returning a dict of
          named scalar values that are logged together with the loss
        - maxiters (100): The maximum number of iterations for the
          optimizer
        - ftol (-1.0): The tolerance for the relative change in the loss
          function, as computed by `rel_change(previous, current)`.
          If it is lower than or equal to this value, then the process
          stops. A non-positive value disables this check.
        - gtol (1e-9): The tolerance for the maximum change in the gradient.
          If the maximum absolute values of all the gradient tensors
          are less than this, then the process will stop. Parameters
          without a gradient are ignored. A non-positive value disables
          this check.
        - interactive (True): If True, log progress through `logger`
        - summary_steps (10): Log the loss every `summary_steps`
          iterations. A non-positive value disables periodic logging; the
          final step is still logged when `interactive` is True.
        - **kwargs: Accepted for call-site compatibility and ignored
    Returns:
        The loss value (a float) of the last iteration that ran, or None
        if no iteration ran (maxiters <= 0).
    '''
    prev_loss = None
    # Index of the last iteration executed; stays None if the loop never runs.
    last_step = None
    for n in tqdm(range(maxiters), desc='Fitting iterations'):
        last_step = n
        loss = optimizer.step(closure)
        # Convert once per iteration: `.item()` forces a host/device sync.
        loss_value = loss.item()

        # `prev_loss` is only None on the first iteration, so the ftol test
        # starts from the second step. `or` keeps the original order: the
        # gradient test is evaluated only when the ftol test did not fire.
        converged = (
            (ftol > 0 and prev_loss is not None
             and rel_change(prev_loss, loss_value) <= ftol)
            or (gtol > 0 and _grads_below(params, gtol))
        )
        prev_loss = loss_value
        if converged:
            break

        if interactive and summary_steps > 0 and n % summary_steps == 0:
            _log_progress(n, loss_value, summary_closure)

    # Save the final step (nothing to report if no iteration ran)
    if interactive and last_step is not None:
        _log_progress(last_step + 1, prev_loss, summary_closure)
    return prev_loss