# SemanticBoost/SMPLX/transfer_model/transfer_model.py
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: Vassilis Choutas, vassilis.choutas@tuebingen.mpg.de
from typing import Optional, Dict, Callable

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from loguru import logger

from SMPLX.transfer_model.utils import (
    Tensor, batch_rodrigues, apply_deformation_transfer,
    get_vertices_per_edge)
from SMPLX.transfer_model.optimizers import build_optimizer, minimize
from SMPLX.transfer_model.losses import build_loss


def summary_closure(gt_vertices, var_dict, body_model, mask_ids=None):
    ''' Computes the mean vertex-to-vertex error between the target and the
        current model estimate (in mm, assuming vertices are in meters).
    '''
param_dict = {}
for key, var in var_dict.items():
# Decode the axis-angles
if 'pose' in key or 'orient' in key:
param_dict[key] = batch_rodrigues(
var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
else:
# Simply pass the variable
param_dict[key] = var
body_model_output = body_model(
return_full_pose=True, get_skin=True, **param_dict)
est_vertices = body_model_output.vertices
if mask_ids is not None:
est_vertices = est_vertices[:, mask_ids]
gt_vertices = gt_vertices[:, mask_ids]
v2v = (est_vertices - gt_vertices).pow(2).sum(dim=-1).sqrt().mean()
return {
'Vertex-to-Vertex': v2v * 1000}
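
# A minimal sketch of how this summary is consumed: `minimize` calls it via
# the `summary_closure` argument in run_fitting below (tensors hypothetical):
#
#   stats = summary_closure(gt_vertices, var_dict, body_model)
#   logger.info('V2V: {:.2f} mm'.format(stats['Vertex-to-Vertex']))
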
def build_model_forward_closure(
body_model: nn.Module,
var_dict: Dict[str, Tensor],
per_part: bool = True,
part_key: Optional[str] = None,
jidx: Optional[int] = None,
part: Optional[Tensor] = None
) -> Callable:
    ''' Builds a thunk that decodes the axis-angle variables to rotation
        matrices and runs the body model. If per_part is True, the rotation
        of joint `jidx` of `part_key` is replaced by the separately
        optimized `part` tensor.
    '''
if per_part:
cond = part is not None and part_key is not None and jidx is not None
assert cond, (
'When per-part is True, "part", "part_key", "jidx" must not be'
' None.'
)
def model_forward():
param_dict = {}
for key, var in var_dict.items():
if part_key == key:
param_dict[key] = batch_rodrigues(
var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
param_dict[key][:, jidx] = batch_rodrigues(
part.reshape(-1, 3)).reshape(-1, 3, 3)
else:
# Decode the axis-angles
if 'pose' in key or 'orient' in key:
param_dict[key] = batch_rodrigues(
var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
else:
# Simply pass the variable
param_dict[key] = var
return body_model(
return_full_pose=True, get_skin=True, **param_dict)
else:
def model_forward():
param_dict = {}
for key, var in var_dict.items():
# Decode the axis-angles
if 'pose' in key or 'orient' in key:
param_dict[key] = batch_rodrigues(
var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
else:
# Simply pass the variable
param_dict[key] = var
return body_model(return_full_pose=True, get_skin=True,
**param_dict)
return model_forward
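
# A minimal usage sketch (hedged; `body_model` and `var_dict` as produced by
# get_variables below): the returned thunk decodes every axis-angle variable
# to rotation matrices and runs one body-model forward pass.
#
#   forward = build_model_forward_closure(body_model, var_dict, per_part=False)
#   est_vertices = forward().vertices   # (batch_size, num_vertices, 3)
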
def build_edge_closure(
body_model: nn.Module,
var_dict: Dict[str, Tensor],
edge_loss: nn.Module,
optimizer_dict,
gt_vertices: Tensor,
per_part: bool = True,
part_key: Optional[str] = None,
jidx: Optional[int] = None,
part: Optional[Tensor] = None
) -> Callable:
''' Builds the closure for the edge objective
'''
optimizer = optimizer_dict['optimizer']
create_graph = optimizer_dict['create_graph']
if per_part:
params_to_opt = [part]
else:
params_to_opt = [p for key, p in var_dict.items() if 'pose' in key]
model_forward = build_model_forward_closure(
body_model, var_dict, per_part=per_part, part_key=part_key,
jidx=jidx, part=part)
def closure(backward=True):
if backward:
optimizer.zero_grad()
body_model_output = model_forward()
est_vertices = body_model_output.vertices
loss = edge_loss(est_vertices, gt_vertices)
if backward:
if create_graph:
# Use this instead of .backward to avoid GPU memory leaks
grads = torch.autograd.grad(
loss, params_to_opt, create_graph=True)
torch.autograd.backward(
params_to_opt, grads, create_graph=True)
else:
loss.backward()
return loss
return closure


def build_vertex_closure(
body_model: nn.Module,
var_dict: Dict[str, Tensor],
optimizer_dict,
gt_vertices: Tensor,
vertex_loss: nn.Module,
mask_ids=None,
per_part: bool = True,
part_key: Optional[str] = None,
jidx: Optional[int] = None,
part: Optional[Tensor] = None,
params_to_opt: Optional[Tensor] = None,
) -> Callable:
''' Builds the closure for the vertex objective
'''
optimizer = optimizer_dict['optimizer']
create_graph = optimizer_dict['create_graph']
model_forward = build_model_forward_closure(
body_model, var_dict, per_part=per_part, part_key=part_key,
jidx=jidx, part=part)
if params_to_opt is None:
        params_to_opt = list(var_dict.values())
def closure(backward=True):
if backward:
optimizer.zero_grad()
body_model_output = model_forward()
est_vertices = body_model_output.vertices
loss = vertex_loss(
est_vertices[:, mask_ids] if mask_ids is not None else
est_vertices,
gt_vertices[:, mask_ids] if mask_ids is not None else gt_vertices)
if backward:
if create_graph:
# Use this instead of .backward to avoid GPU memory leaks
grads = torch.autograd.grad(
loss, params_to_opt, create_graph=True)
torch.autograd.backward(
params_to_opt, grads, create_graph=True)
else:
loss.backward()
return loss
return closure
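
# Both closure builders return the callable signature that `minimize`
# expects: zero the gradients, run the model, compute the loss, backprop,
# return the loss. A hedged sketch, assuming `optimizer_dict` comes from
# build_optimizer as in run_fitting below:
#
#   closure = build_vertex_closure(
#       body_model, var_dict, optimizer_dict, gt_vertices,
#       vertex_loss=vertex_loss, per_part=False)
#   loss = closure(backward=False)   # evaluate the loss without a step
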
def get_variables(
batch_size: int,
body_model: nn.Module,
dtype: torch.dtype = torch.float32
) -> Dict[str, Tensor]:
    ''' Creates zero-initialized optimization variables for the given body
        model, with gradients enabled.
    '''
var_dict = {}
device = next(body_model.buffers()).device
    if body_model.name() in ('SMPL', 'SMPL+H', 'SMPL-X'):
var_dict.update({
'transl': torch.zeros(
[batch_size, 3], device=device, dtype=dtype),
'global_orient': torch.zeros(
[batch_size, 1, 3], device=device, dtype=dtype),
'body_pose': torch.zeros(
[batch_size, body_model.NUM_BODY_JOINTS, 3],
device=device, dtype=dtype),
'betas': torch.zeros([batch_size, body_model.num_betas],
dtype=dtype, device=device),
})
    if body_model.name() in ('SMPL+H', 'SMPL-X'):
var_dict.update(
left_hand_pose=torch.zeros(
[batch_size, body_model.NUM_HAND_JOINTS, 3], device=device,
dtype=dtype),
right_hand_pose=torch.zeros(
[batch_size, body_model.NUM_HAND_JOINTS, 3], device=device,
dtype=dtype),
)
if body_model.name() == 'SMPL-X':
var_dict.update(
jaw_pose=torch.zeros([batch_size, 1, 3],
device=device, dtype=dtype),
leye_pose=torch.zeros([batch_size, 1, 3],
device=device, dtype=dtype),
reye_pose=torch.zeros([batch_size, 1, 3],
device=device, dtype=dtype),
expression=torch.zeros(
[batch_size, body_model.num_expression_coeffs],
device=device, dtype=dtype),
)
# Toggle gradients to True
for key, val in var_dict.items():
val.requires_grad_(True)
return var_dict
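
# For an SMPL-X body model and batch size B, the returned dictionary holds
# (joint counts are the smplx library defaults, NUM_BODY_JOINTS=21 and
# NUM_HAND_JOINTS=15):
#   transl (B, 3), global_orient (B, 1, 3), body_pose (B, 21, 3),
#   betas (B, num_betas), left/right_hand_pose (B, 15, 3),
#   jaw/leye/reye_pose (B, 1, 3), expression (B, num_expression_coeffs)
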
def run_fitting(
    batch: Dict[str, Tensor],
    body_model: nn.Module,
    def_matrix: Tensor,
    mask_ids
) -> Dict[str, Tensor]:
    ''' Fits the target body model to the given meshes: applies deformation
        transfer, initializes the pose with an edge-based loss, refines the
        translation, then optimizes all parameters with a vertex-based loss.
    '''
vertices = batch['vertices']
faces = batch['faces']
batch_size = len(vertices)
dtype, device = vertices.dtype, vertices.device
    # Hard-coded logging settings (formerly read from exp_cfg)
    summary_steps = 100
    interactive = True
# Get the parameters from the model
var_dict = get_variables(batch_size, body_model)
    # Optimizer settings for the current batch (formerly read from
    # exp_cfg['optim'])
    optim_cfg = {
        'type': 'trust-ncg',
        'lr': 1.0,
        'gtol': 1e-06,
        'ftol': -1.0,
        'maxiters': 100,
        'lbfgs': {'line_search_fn': 'strong_wolfe', 'max_iter': 50},
        'sgd': {'momentum': 0.9, 'nesterov': True},
        'adam': {'betas': [0.9, 0.999], 'eps': 1e-08, 'amsgrad': False},
        'trust_ncg': {
            'max_trust_radius': 1000.0,
            'initial_trust_radius': 0.05,
            'eta': 0.15,
            'gtol': 1e-05,
        },
    }
def_vertices = apply_deformation_transfer(def_matrix, vertices, faces)
    if mask_ids is None:
        f_sel = np.ones_like(body_model.faces[:, 0], dtype=np.bool_)
    else:
        # Keep only the faces that touch at least one masked vertex
        f_per_v = [[] for _ in range(body_model.get_num_verts())]
        for face_idx, face in enumerate(body_model.faces):
            for vertex_idx in face:
                f_per_v[vertex_idx].append(face_idx)
        f_sel = list(set(
            face_idx for vv in mask_ids for face_idx in f_per_v[vv]))
vpe = get_vertices_per_edge(
body_model.v_template.detach().cpu().numpy(), body_model.faces[f_sel])
def log_closure():
return summary_closure(def_vertices, var_dict, body_model,
mask_ids=mask_ids)
    # Hard-coded fitting settings (formerly read from exp_cfg)
    edge_fitting_cfg = {'per_part': False, 'reduction': 'mean'}
    edge_loss = build_loss(type='vertex-edge', gt_edges=vpe, est_edges=vpe,
                           **edge_fitting_cfg)
    edge_loss = edge_loss.to(device=device)
    vertex_fitting_cfg = {}
    vertex_loss = build_loss(**vertex_fitting_cfg)
    vertex_loss = vertex_loss.to(device=device)
per_part = edge_fitting_cfg.get('per_part', True)
logger.info(f'Per-part: {per_part}')
# Optimize edge-based loss to initialize pose
if per_part:
for key, var in tqdm(var_dict.items(), desc='Parts'):
if 'pose' not in key:
continue
for jidx in tqdm(range(var.shape[1]), desc='Joints'):
part = torch.zeros(
[batch_size, 3], dtype=dtype, device=device,
requires_grad=True)
# Build the optimizer for the current part
optimizer_dict = build_optimizer([part], optim_cfg)
closure = build_edge_closure(
body_model, var_dict, edge_loss, optimizer_dict,
def_vertices, per_part=per_part, part_key=key, jidx=jidx,
part=part)
minimize(optimizer_dict['optimizer'], closure,
params=[part],
summary_closure=log_closure,
summary_steps=summary_steps,
interactive=interactive,
**optim_cfg)
with torch.no_grad():
var[:, jidx] = part
else:
optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg)
closure = build_edge_closure(
body_model, var_dict, edge_loss, optimizer_dict,
def_vertices, per_part=per_part)
minimize(optimizer_dict['optimizer'], closure,
params=var_dict.values(),
summary_closure=log_closure,
summary_steps=summary_steps,
interactive=interactive,
**optim_cfg)
    # Optimize the translation alone with the vertex-based loss. The variable
    # is stored under the key 'transl' (see get_variables).
    if 'transl' in var_dict:
        optimizer_dict = build_optimizer([var_dict['transl']], optim_cfg)
        closure = build_vertex_closure(
            body_model, var_dict,
            optimizer_dict,
            def_vertices,
            vertex_loss=vertex_loss,
            mask_ids=mask_ids,
            per_part=False,
            params_to_opt=[var_dict['transl']],
        )
        minimize(optimizer_dict['optimizer'],
                 closure,
                 params=[var_dict['transl']],
                 summary_closure=log_closure,
                 summary_steps=summary_steps,
                 interactive=interactive,
                 **optim_cfg)
# Optimize all model parameters with vertex-based loss
optimizer_dict = build_optimizer(list(var_dict.values()), optim_cfg)
closure = build_vertex_closure(
body_model, var_dict,
optimizer_dict,
def_vertices,
vertex_loss=vertex_loss,
per_part=False,
mask_ids=mask_ids)
minimize(optimizer_dict['optimizer'], closure,
params=list(var_dict.values()),
summary_closure=log_closure,
summary_steps=summary_steps,
interactive=interactive,
**optim_cfg)
    # Decode the fitted variables and run the body model once more to collect
    # the final outputs
    param_dict = {}
for key, var in var_dict.items():
# Decode the axis-angles
if 'pose' in key or 'orient' in key:
param_dict[key] = batch_rodrigues(
var.reshape(-1, 3)).reshape(len(var), -1, 3, 3)
else:
# Simply pass the variable
param_dict[key] = var
body_model_output = body_model(
return_full_pose=True, get_skin=True, **param_dict)
    keys = ['vertices', 'joints', 'betas', 'global_orient', 'body_pose',
            'left_hand_pose', 'right_hand_pose', 'full_pose']
    for key in keys:
        var_dict[key] = getattr(body_model_output, key)
    var_dict['faces'] = body_model.faces
    # Convert every tensor entry to NumPy for serialization
    for key in var_dict.keys():
        try:
            var_dict[key] = var_dict[key].detach().cpu().numpy()
        except AttributeError:
            # Non-tensor entries (e.g. the faces array) are kept as-is
            pass
return var_dict
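

# A minimal end-to-end sketch (hedged: `load_body_model` and `load_def_matrix`
# are hypothetical placeholders; the real entry point builds these objects
# from the experiment config and a dataloader):
#
#   body_model = load_body_model('smplx')            # hypothetical helper
#   def_matrix = load_def_matrix('smpl2smplx.pkl')   # hypothetical helper
#   batch = {'vertices': vertices, 'faces': faces}   # source-mesh tensors
#   result = run_fitting(batch, body_model, def_matrix, mask_ids=None)
#   # result holds NumPy arrays: 'vertices', 'joints', 'betas', ...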