# NOTE: page-scrape residue, not part of the module: "Spaces: Runtime error"
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de
from typing import Optional, Dict, Callable | |
import sys | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from tqdm import tqdm | |
from loguru import logger | |
from SMPLX.transfer_model.utils import get_vertices_per_edge | |
from SMPLX.transfer_model.optimizers import build_optimizer, minimize | |
from SMPLX.transfer_model.utils import ( | |
Tensor, batch_rodrigues, apply_deformation_transfer) | |
from SMPLX.transfer_model.losses import build_loss | |
def summary_closure(gt_vertices, var_dict, body_model, mask_ids=None):
    """Report the mean vertex-to-vertex error of the current estimate.

    Every entry of ``var_dict`` whose key contains ``'pose'`` or ``'orient'``
    is decoded with ``batch_rodrigues`` (axis-angle decoding) and reshaped to
    ``(len(var), -1, 3, 3)``; all other entries are forwarded untouched. The
    body model is then evaluated with ``return_full_pose=True`` and
    ``get_skin=True`` and its ``vertices`` are compared against
    ``gt_vertices``.

    Args:
        gt_vertices: Target vertices; indexed as ``[:, mask_ids]`` when a
            mask is supplied.
        var_dict: Mapping from parameter name to the current variable.
        body_model: Callable body model whose output exposes ``vertices``.
        mask_ids: Optional vertex indices restricting the comparison.

    Returns:
        A dict with the single key ``'Vertex-to-Vertex'`` holding the mean
        Euclidean distance (over the last dimension) multiplied by 1000
        (presumably metres to millimetres -- confirm units with the caller).
    """
    def _as_model_input(name, value):
        # Rotation-like variables are decoded; the rest pass straight through.
        if 'pose' in name or 'orient' in name:
            rot_mats = batch_rodrigues(value.reshape(-1, 3))
            return rot_mats.reshape(len(value), -1, 3, 3)
        return value

    model_inputs = {
        name: _as_model_input(name, value)
        for name, value in var_dict.items()
    }
    output = body_model(return_full_pose=True, get_skin=True, **model_inputs)

    predicted, target = output.vertices, gt_vertices
    if mask_ids is not None:
        predicted = predicted[:, mask_ids]
        target = target[:, mask_ids]

    offsets = predicted - target
    distances = offsets.pow(2).sum(dim=-1).sqrt()
    return {'Vertex-to-Vertex': distances.mean() * 1000}
def build_model_forward_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    """Build a zero-argument callable that evaluates the body model.

    The returned callable reads ``var_dict`` each time it is invoked, decodes
    the axis-angle entries (keys containing ``'pose'`` or ``'orient'``) with
    ``batch_rodrigues`` and calls ``body_model`` with
    ``return_full_pose=True, get_skin=True``.

    Args:
        body_model: The body model to evaluate.
        var_dict: Mapping from parameter name to the current variable.
        per_part: When True, the entry stored under ``part_key`` is always
            decoded and its slice ``[:, jidx]`` is replaced by the decoded
            ``part`` before the model is called.
        part_key: Name of the variable that contains the overridden joint.
        jidx: Index of the overridden joint along dimension 1.
        part: Axis-angle value substituted for joint ``jidx``.

    Returns:
        The forward closure.

    Raises:
        AssertionError: If ``per_part`` is True and any of ``part``,
            ``part_key`` or ``jidx`` is None.
    """
    def _to_rotmats(axis_angles):
        # Decode the axis-angles into one 3x3 matrix per joint.
        flat = batch_rodrigues(axis_angles.reshape(-1, 3))
        return flat.reshape(len(axis_angles), -1, 3, 3)

    def _call_model(model_inputs):
        return body_model(
            return_full_pose=True, get_skin=True, **model_inputs)

    if not per_part:
        def model_forward():
            model_inputs = {}
            for name, value in var_dict.items():
                is_rotation = 'pose' in name or 'orient' in name
                model_inputs[name] = (
                    _to_rotmats(value) if is_rotation else value)
            return _call_model(model_inputs)

        return model_forward

    all_given = not (part is None or part_key is None or jidx is None)
    assert all_given, (
        'When per-part is True, "part", "part_key", "jidx" must not be'
        ' None.'
    )

    def model_forward():
        model_inputs = {}
        for name, value in var_dict.items():
            if name == part_key:
                # Decode the whole variable, then swap in the joint that is
                # currently being optimized.
                rot_mats = _to_rotmats(value)
                rot_mats[:, jidx] = batch_rodrigues(
                    part.reshape(-1, 3)).reshape(-1, 3, 3)
                model_inputs[name] = rot_mats
            elif 'pose' in name or 'orient' in name:
                model_inputs[name] = _to_rotmats(value)
            else:
                # Simply pass the variable
                model_inputs[name] = value
        return _call_model(model_inputs)

    return model_forward
def build_edge_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    edge_loss: nn.Module,
    optimizer_dict,
    gt_vertices: Tensor,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None
) -> Callable:
    """Build the optimizer closure for the edge objective.

    Args:
        body_model: The body model to evaluate.
        var_dict: Mapping from parameter name to the current variable.
        edge_loss: Callable invoked as ``edge_loss(est_vertices, gt_vertices)``.
        optimizer_dict: Mapping providing ``'optimizer'`` and
            ``'create_graph'``.
        gt_vertices: Target vertices handed to ``edge_loss``.
        per_part: When True only ``part`` is differentiated; otherwise every
            entry of ``var_dict`` whose key contains ``'pose'`` is (note that
            this excludes keys such as ``'global_orient'``).
        part_key, jidx, part: Forwarded to ``build_model_forward_closure``.

    Returns:
        ``closure(backward=True)``, which returns the loss and, when
        ``backward`` is True, zeroes the optimizer gradients and
        back-propagates first.
    """
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']

    params_to_opt = (
        [part] if per_part
        else [value for name, value in var_dict.items() if 'pose' in name]
    )

    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)

    def closure(backward=True):
        if not backward:
            # Evaluation only: no gradient bookkeeping.
            return edge_loss(model_forward().vertices, gt_vertices)

        optimizer.zero_grad()
        loss = edge_loss(model_forward().vertices, gt_vertices)
        if not create_graph:
            loss.backward()
            return loss

        # Use this instead of .backward to avoid GPU memory leaks
        grads = torch.autograd.grad(loss, params_to_opt, create_graph=True)
        torch.autograd.backward(params_to_opt, grads, create_graph=True)
        return loss

    return closure
def build_vertex_closure(
    body_model: nn.Module,
    var_dict: Dict[str, Tensor],
    optimizer_dict,
    gt_vertices: Tensor,
    vertex_loss: nn.Module,
    mask_ids=None,
    per_part: bool = True,
    part_key: Optional[str] = None,
    jidx: Optional[int] = None,
    part: Optional[Tensor] = None,
    params_to_opt: Optional[Tensor] = None,
) -> Callable:
    """Build the optimizer closure for the vertex objective.

    Args:
        body_model: The body model to evaluate.
        var_dict: Mapping from parameter name to the current variable.
        optimizer_dict: Mapping providing ``'optimizer'`` and
            ``'create_graph'``.
        gt_vertices: Target vertices.
        vertex_loss: Callable invoked as ``vertex_loss(est, gt)``.
        mask_ids: Optional vertex indices; when given, both estimated and
            target vertices are indexed as ``[:, mask_ids]`` before the loss.
        per_part, part_key, jidx, part: Forwarded to
            ``build_model_forward_closure``.
        params_to_opt: Variables to differentiate when the optimizer asks
            for ``create_graph``; defaults to every value in ``var_dict``.

    Returns:
        ``closure(backward=True)``, which returns the loss and, when
        ``backward`` is True, zeroes the optimizer gradients and
        back-propagates first.
    """
    optimizer = optimizer_dict['optimizer']
    create_graph = optimizer_dict['create_graph']

    model_forward = build_model_forward_closure(
        body_model, var_dict, per_part=per_part, part_key=part_key,
        jidx=jidx, part=part)

    if params_to_opt is None:
        params_to_opt = list(var_dict.values())

    def _masked_loss():
        est_vertices = model_forward().vertices
        if mask_ids is None:
            return vertex_loss(est_vertices, gt_vertices)
        return vertex_loss(
            est_vertices[:, mask_ids], gt_vertices[:, mask_ids])

    def closure(backward=True):
        if not backward:
            # Evaluation only: no gradient bookkeeping.
            return _masked_loss()

        optimizer.zero_grad()
        loss = _masked_loss()
        if not create_graph:
            loss.backward()
            return loss

        # Use this instead of .backward to avoid GPU memory leaks
        grads = torch.autograd.grad(loss, params_to_opt, create_graph=True)
        torch.autograd.backward(params_to_opt, grads, create_graph=True)
        return loss

    return closure
def get_variables(
    batch_size: int,
    body_model: nn.Module,
    dtype: torch.dtype = torch.float32
) -> Dict[str, Tensor]:
    """Create zero-initialized, gradient-enabled variables for a body model.

    Which variables are created depends on ``body_model.name()``:

    * ``'SMPL'``, ``'SMPL+H'``, ``'SMPL-X'``: ``transl``, ``global_orient``,
      ``body_pose`` and ``betas``.
    * ``'SMPL+H'``, ``'SMPL-X'``: additionally ``left_hand_pose`` and
      ``right_hand_pose``.
    * ``'SMPL-X'``: additionally ``jaw_pose``, ``leye_pose``, ``reye_pose``
      and ``expression``.

    Any other name yields an empty dict.

    Args:
        batch_size: Size of the leading dimension of every variable.
        body_model: Model providing ``name()``, ``buffers()`` and the
            joint / coefficient counts read below.
        dtype: Data type of the created tensors.

    Returns:
        Mapping from variable name to a zero tensor with
        ``requires_grad=True``, placed on the device of the model's first
        buffer.
    """
    device = next(body_model.buffers()).device
    model_name = body_model.name()

    def _zeros(*shape):
        # Every variable starts at zero with gradients switched on.
        tensor = torch.zeros(list(shape), device=device, dtype=dtype)
        return tensor.requires_grad_(True)

    variables = {}
    if model_name in ('SMPL', 'SMPL+H', 'SMPL-X'):
        variables['transl'] = _zeros(batch_size, 3)
        variables['global_orient'] = _zeros(batch_size, 1, 3)
        variables['body_pose'] = _zeros(
            batch_size, body_model.NUM_BODY_JOINTS, 3)
        variables['betas'] = _zeros(batch_size, body_model.num_betas)

    if model_name in ('SMPL+H', 'SMPL-X'):
        for hand_key in ('left_hand_pose', 'right_hand_pose'):
            variables[hand_key] = _zeros(
                batch_size, body_model.NUM_HAND_JOINTS, 3)

    if model_name == 'SMPL-X':
        for face_key in ('jaw_pose', 'leye_pose', 'reye_pose'):
            variables[face_key] = _zeros(batch_size, 1, 3)
        variables['expression'] = _zeros(
            batch_size, body_model.num_expression_coeffs)

    return variables
def _faces_touching_vertices(faces, num_verts, vertex_ids):
    """Return the sorted indices of faces using at least one of ``vertex_ids``."""
    faces_per_vertex = [[] for _ in range(num_verts)]
    for face_idx, face in enumerate(faces):
        for vertex_idx in face:
            faces_per_vertex[vertex_idx].append(face_idx)
    selected = {
        face_idx
        for vertex_idx in vertex_ids
        for face_idx in faces_per_vertex[vertex_idx]
    }
    # Sorted so the face (and hence edge) order is deterministic.
    return sorted(selected)


def _decode_model_params(var_dict):
    """Decode the axis-angle entries of ``var_dict`` for the body model."""
    param_dict = {}
    for key, var in var_dict.items():
        if 'pose' in key or 'orient' in key:
            # Decode the axis-angles
            param_dict[key] = batch_rodrigues(
                var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
        else:
            # Simply pass the variable
            param_dict[key] = var
    return param_dict


def run_fitting(
    batch: Dict[str, Tensor],
    body_model: nn.Module,
    def_matrix: Tensor,
    mask_ids
) -> Dict[str, Tensor]:
    """Fit the parameters of ``body_model`` to the vertices in ``batch``.

    The configuration that used to come from an ``exp_cfg`` argument is
    inlined below (optimizer settings, summary interval, loss options).

    Stages:
        1. The input vertices are mapped with ``apply_deformation_transfer``.
        2. An edge-based loss initializes the pose (per joint when
           ``per_part`` is enabled, jointly otherwise).
        3. The translation alone is optimized with the vertex loss.
        4. All variables are optimized with the vertex loss.
        5. The model is evaluated once more and the results are collected.

    Args:
        batch: Mapping with at least ``'vertices'`` and ``'faces'``.
        body_model: Target body model; must provide ``name()``, ``faces``,
            ``v_template`` and ``get_num_verts()``.
        def_matrix: Deformation transfer matrix passed to
            ``apply_deformation_transfer``.
        mask_ids: Optional vertex indices restricting the losses; ``None``
            uses every face / vertex.

    Returns:
        The optimized variables, updated with those of the model outputs
        ``vertices``, ``joints``, ``betas``, ``global_orient``,
        ``body_pose``, ``left_hand_pose``, ``right_hand_pose`` and
        ``full_pose`` that the output object provides, plus ``'faces'``.
        Every tensor value is converted to a numpy array; other values are
        returned unchanged.
    """
    vertices = batch['vertices']
    faces = batch['faces']

    batch_size = len(vertices)
    dtype, device = vertices.dtype, vertices.device

    # Settings inlined from the former ``exp_cfg``.
    summary_steps = 100
    interactive = True
    optim_cfg = {
        'type': 'trust-ncg',
        'lr': 1.0,
        'gtol': 1e-06,
        'ftol': -1.0,
        'maxiters': 100,
        'lbfgs': {'line_search_fn': 'strong_wolfe', 'max_iter': 50},
        'sgd': {'momentum': 0.9, 'nesterov': True},
        'adam': {'betas': [0.9, 0.999], 'eps': 1e-08, 'amsgrad': False},
        'trust_ncg': {
            'max_trust_radius': 1000.0,
            'initial_trust_radius': 0.05,
            'eta': 0.15,
            'gtol': 1e-05,
        },
    }
    edge_fitting_cfg = {'per_part': False, 'reduction': 'mean'}
    vertex_fitting_cfg = {}

    # Get the parameters from the model
    var_dict = get_variables(batch_size, body_model)

    def_vertices = apply_deformation_transfer(def_matrix, vertices, faces)

    # Select the faces whose edges take part in the edge loss.
    if mask_ids is None:
        f_sel = np.ones_like(body_model.faces[:, 0], dtype=np.bool_)
    else:
        f_sel = _faces_touching_vertices(
            body_model.faces, body_model.get_num_verts(), mask_ids)
    vpe = get_vertices_per_edge(
        body_model.v_template.detach().cpu().numpy(), body_model.faces[f_sel])

    def log_closure():
        return summary_closure(def_vertices, var_dict, body_model,
                               mask_ids=mask_ids)

    edge_loss = build_loss(type='vertex-edge', gt_edges=vpe, est_edges=vpe,
                           **edge_fitting_cfg)
    edge_loss = edge_loss.to(device=device)

    vertex_loss = build_loss(**vertex_fitting_cfg)
    vertex_loss = vertex_loss.to(device=device)

    per_part = edge_fitting_cfg.get('per_part', True)
    logger.info(f'Per-part: {per_part}')

    # Optimize edge-based loss to initialize pose
    if per_part:
        for key, var in tqdm(var_dict.items(), desc='Parts'):
            if 'pose' not in key:
                continue

            for jidx in tqdm(range(var.shape[1]), desc='Joints'):
                part = torch.zeros(
                    [batch_size, 3], dtype=dtype, device=device,
                    requires_grad=True)
                # Build the optimizer for the current part
                optimizer_dict = build_optimizer([part], optim_cfg)
                closure = build_edge_closure(
                    body_model, var_dict, edge_loss, optimizer_dict,
                    def_vertices, per_part=per_part, part_key=key, jidx=jidx,
                    part=part)

                minimize(optimizer_dict['optimizer'], closure,
                         params=[part],
                         summary_closure=log_closure,
                         summary_steps=summary_steps,
                         interactive=interactive,
                         **optim_cfg)
                # Write the optimized joint back into the full variable.
                with torch.no_grad():
                    var[:, jidx] = part
    else:
        all_params = list(var_dict.values())
        optimizer_dict = build_optimizer(all_params, optim_cfg)
        closure = build_edge_closure(
            body_model, var_dict, edge_loss, optimizer_dict,
            def_vertices, per_part=per_part)

        minimize(optimizer_dict['optimizer'], closure,
                 params=all_params,
                 summary_closure=log_closure,
                 summary_steps=summary_steps,
                 interactive=interactive,
                 **optim_cfg)

    # Optimize translation on its own. get_variables() stores it under
    # 'transl'; the previous check for 'translation' could never succeed,
    # which silently skipped this stage.
    if 'transl' in var_dict:
        transl = var_dict['transl']
        optimizer_dict = build_optimizer([transl], optim_cfg)
        closure = build_vertex_closure(
            body_model, var_dict,
            optimizer_dict,
            def_vertices,
            vertex_loss=vertex_loss,
            mask_ids=mask_ids,
            per_part=False,
            params_to_opt=[transl],
        )
        minimize(optimizer_dict['optimizer'],
                 closure,
                 params=[transl],
                 summary_closure=log_closure,
                 summary_steps=summary_steps,
                 interactive=interactive,
                 **optim_cfg)

    # Optimize all model parameters with vertex-based loss
    all_params = list(var_dict.values())
    optimizer_dict = build_optimizer(all_params, optim_cfg)
    closure = build_vertex_closure(
        body_model, var_dict,
        optimizer_dict,
        def_vertices,
        vertex_loss=vertex_loss,
        per_part=False,
        mask_ids=mask_ids)
    minimize(optimizer_dict['optimizer'], closure,
             params=all_params,
             summary_closure=log_closure,
             summary_steps=summary_steps,
             interactive=interactive,
             **optim_cfg)

    # Final evaluation; no gradients are needed because every result is
    # detached and converted to numpy right after.
    with torch.no_grad():
        body_model_output = body_model(
            return_full_pose=True, get_skin=True,
            **_decode_model_params(var_dict))

    output_keys = ["vertices", "joints", "betas", "global_orient",
                   "body_pose", "left_hand_pose", "right_hand_pose",
                   "full_pose"]
    for key in output_keys:
        # Outputs of models without hands may not carry the hand-pose
        # attributes; skip what the output does not provide instead of
        # failing after the whole optimization has finished.
        if hasattr(body_model_output, key):
            var_dict[key] = getattr(body_model_output, key)
    var_dict['faces'] = body_model.faces

    # Convert tensors to numpy; non-tensor values (e.g. the faces array)
    # are passed through unchanged.
    return {
        key: val.detach().cpu().numpy() if torch.is_tensor(val) else val
        for key, val in var_dict.items()
    }