# ECON/lib/pixielib/pixie.py
# Yuliang's picture
# Support TEXTure
# 487ee6d
# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at pixie@tue.mpg.de
# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from skimage.io import imread
from .models.encoders import MLP, HRNEncoder, ResnetEncoder
from .models.moderators import TempSoftmaxFusion
from .models.SMPLX import SMPLX
from .utils import rotation_converter as converter
from .utils import tensor_cropper, util
from .utils.config import cfg
class PIXIE(object):
    """Inference wrapper that regresses SMPL-X parameters from images.

    The instance owns four families of sub-networks, each kept in a plain dict
    keyed by the names found under ``cfg.network``:

    * ``Encoder``   -- image encoders (``ResnetEncoder`` / ``HRNEncoder``),
    * ``Regressor`` -- MLPs mapping a feature vector to a flat parameter code,
    * ``Extractor`` -- MLPs deriving head / left-hand / right-hand features
      from the body feature,
    * ``Moderator`` -- ``TempSoftmaxFusion`` modules that weight the body
      feature against a part feature,

    plus the ``SMPLX`` model used by :meth:`decode`. Pretrained weights are
    loaded at construction time and the four network families are put in eval
    mode.
    """

    def __init__(self, config=None, device="cuda:0"):
        """Build the networks, load pretrained weights and set up the croppers.

        Args:
            config: configuration node; the module-level ``cfg`` is used when
                ``None``.
            device: device every sub-network is moved to.

        Raises:
            FileNotFoundError: if ``cfg.pretrained_modelpath`` does not exist.
        """
        self.cfg = cfg if config is None else config
        self.device = device

        # parameters setting: for every parameter list in cfg.params, map each
        # parameter name to its size cfg.model.n_<name>.
        # Read from self.cfg (not the global default) so that a user-supplied
        # config is actually honoured.
        self.param_list_dict = {}
        for lst in self.cfg.params.keys():
            param_list = self.cfg.params.get(lst)
            self.param_list_dict[lst] = {
                i: self.cfg.model.get("n_" + i) for i in param_list
            }

        # Build the models
        self._create_model()
        # Set up the cropping modules used to generate face/hand crops from the body predictions
        self._setup_cropper()

    def forward(self, data):
        """Run :meth:`encode` followed by :meth:`decode` on a body image.

        Args:
            data: body image batch, passed to :meth:`encode` as
                ``{"body": {"image": data}}``.

        Returns:
            The prediction dict produced by :meth:`decode` for the body.
        """
        # encode + decode
        param_dict = self.encode(
            {"body": {"image": data}},
            threthold=True,
            keep_local=True,
            copy_and_paste=False,
        )
        opdict = self.decode(param_dict["body"], param_type="body")
        return opdict

    def _setup_cropper(self):
        """Create one ``tensor_cropper.Cropper`` for heads and one for hands.

        Each cropper uses a single fixed scale, the midpoint of the
        ``scale_min`` / ``scale_max`` values configured for that part, and
        ``trans_scale=0``.
        """
        self.Cropper = {}
        for crop_part in ["head", "hand"]:
            data_cfg = self.cfg.dataset[crop_part]
            scale_size = (data_cfg.scale_min + data_cfg.scale_max) * 0.5
            self.Cropper[crop_part] = tensor_cropper.Cropper(
                crop_size=data_cfg.image_size,
                scale=[scale_size, scale_size],
                trans_scale=0,
            )

    def _create_model(self):
        """Instantiate all sub-networks and load their pretrained weights.

        ``self.model_dict`` collects the ``state_dict()`` of every network
        under the key ``"<Family>_<name>"``; those entries are then filled
        from the checkpoint entries with the same keys.

        Raises:
            FileNotFoundError: if ``cfg.pretrained_modelpath`` does not exist.
        """
        self.model_dict = {}

        # Build all image encoders
        # Hand encoder only works for right hand, for left hand, flip inputs and flip the results back
        self.Encoder = {}
        for key in self.cfg.network.encoder.keys():
            if self.cfg.network.encoder.get(key).type == "resnet50":
                self.Encoder[key] = ResnetEncoder().to(self.device)
            elif self.cfg.network.encoder.get(key).type == "hrnet":
                self.Encoder[key] = HRNEncoder().to(self.device)
            self.model_dict[f"Encoder_{key}"] = self.Encoder[key].state_dict()

        # Build the parameter regressors
        self.Regressor = {}
        for key in self.cfg.network.regressor.keys():
            # output width = total size of all parameters in "<key>_list"
            n_output = sum(self.param_list_dict[f"{key}_list"].values())
            channels = [2048] + self.cfg.network.regressor.get(key).channels + [n_output]
            if self.cfg.network.regressor.get(key).type == "mlp":
                self.Regressor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()

        # Build the extractors
        # to extract separate head/left hand/right hand feature from body feature
        self.Extractor = {}
        for key in self.cfg.network.extractor.keys():
            channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048]
            if self.cfg.network.extractor.get(key).type == "mlp":
                self.Extractor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()

        # Build the moderators
        self.Moderator = {}
        for key in self.cfg.network.moderator.keys():
            detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
            detach_feature = self.cfg.network.moderator.get(key).detach_feature
            # input is two concatenated 2048-d features, output is 2 weights
            channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2]
            self.Moderator[key] = TempSoftmaxFusion(
                detach_inputs=detach_inputs,
                detach_feature=detach_feature,
                channels=channels,
            ).to(self.device)
            self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()

        # Build the SMPL-X body model, which we also use to represent faces and
        # hands, using the relevant parts only
        self.smplx = SMPLX(self.cfg.model).to(self.device)
        self.part_indices = self.smplx.part_indices

        # -- resume model
        model_path = self.cfg.pretrained_modelpath
        if not os.path.exists(model_path):
            # Raise instead of calling the interactive ``exit()`` helper, which
            # is only injected by ``site`` and would terminate with status 0.
            raise FileNotFoundError(
                f"pixie trained model path: {model_path} does not exist!"
            )
        # map_location keeps loading independent of the device the checkpoint
        # was saved from (e.g. a CUDA checkpoint on a CPU-only machine).
        checkpoint = torch.load(model_path, map_location=self.device)
        for key in self.model_dict.keys():
            util.copy_state_dict(self.model_dict[key], checkpoint[key])

        # eval mode
        for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
            for net in module.values():
                net.eval()

    def decompose_code(self, code, num_dict):
        """Convert a flattened parameter vector to a dictionary of parameters.

        Args:
            code: 2-D tensor-like whose second axis is sliced consecutively.
            num_dict: mapping ``name -> width``; iteration order defines the
                order of the slices.

        Returns:
            dict mapping every name in ``num_dict`` to ``code[:, start:end]``.
        """
        code_dict = {}
        start = 0
        for key in num_dict:
            end = start + int(num_dict[key])
            code_dict[key] = code[:, start:end]
            start = end
        return code_dict

    def part_from_body(self, image, part_key, points_dict, crop_joints=None):
        """Crop a part (head/left_hand/right_hand) out of a body image.

        Args:
            image: body image batch; its last two dimensions are used as the
                point scale for cropping.
            part_key: one of ``"head"``, ``"left_hand"``, ``"right_hand"``.
            points_dict: dict of point sets to transform along with the crop;
                must contain ``"smplx_kpt"``.
            crop_joints: optional joints used to locate the crop instead of
                ``points_dict["smplx_kpt"]``.

        Returns:
            ``(cropped_image, cropped_points_dict)`` where the second item has
            the same keys as ``points_dict`` with points mapped into the crop.
        """
        assert part_key in ["head", "left_hand", "right_hand"]
        assert "smplx_kpt" in points_dict.keys()
        if part_key == "head":
            # use face 68 kpts for cropping head image
            indices_key = "face"
        elif part_key == "left_hand":
            indices_key = "left_hand"
        elif part_key == "right_hand":
            indices_key = "right_hand"

        # get points for cropping
        part_indices = self.part_indices[indices_key]
        if crop_joints is not None:
            points_for_crop = crop_joints[:, part_indices]
        else:
            points_for_crop = points_dict["smplx_kpt"][:, part_indices]

        # crop (both hands share the single "hand" cropper)
        cropper_key = "hand" if "hand" in part_key else part_key
        points_scale = image.shape[-2:]
        cropped_image, tform = self.Cropper[cropper_key].crop(
            image, points_for_crop, points_scale
        )

        # transform points accordingly (input points must be normalized;
        # presumably to [-1, 1] -- confirm against tensor_cropper)
        cropped_points_dict = {}
        for points_key in points_dict.keys():
            points = points_dict[points_key]
            cropped_points = self.Cropper[cropper_key].transform_points(
                points, tform, points_scale, normalize=True
            )
            cropped_points_dict[points_key] = cropped_points
        return cropped_image, cropped_points_dict

    @torch.no_grad()
    def encode(
        self,
        data,
        threthold=True,
        keep_local=True,
        copy_and_paste=False,
        body_only=False,
    ):
        """Encode images to smplx parameters.

        Args:
            data: dict
                key: image_type (body/head/hand)
                value:
                    image: [bz, 3, 224, 224], range [0,1]
                    head_image: optional, well-cropped head from body image
                    left_hand_image: optional, well-cropped left hand from body image
                    right_hand_image: optional, well-cropped right hand from body image
                For body input, if any of the three crops is missing, all
                three are generated from a 1024-resized copy of ``image`` and
                written back into ``data["body"]`` (the input dict is mutated).
            threthold: if True, a hand part weight above 0.7 is raised to 1.0
                before fusing body and hand features.
            keep_local: if True, expression and finger poses of the body
                result are taken from the part-only predictions.
            copy_and_paste: if True, always use the part feature instead of
                the moderator-fused feature.
            body_only: if True and part crops have to be estimated, return the
                un-fused body estimate right away.

        Returns:
            param_dict: dict
                key: image_type (body/head/hand); for body input additionally
                    ``body_head``, ``body_left_hand``, ``body_right_hand`` and
                    ``moderator_weight``
                value: param_dict
        """
        for key in data.keys():
            assert key in ["body", "head", "hand"]

        feature = {}
        param_dict = {}

        # Encode features
        for key in data.keys():
            part = key
            # encode feature
            feature[key] = {}
            feature[key][part] = self.Encoder[part](data[key]["image"])

            # for head/hand image
            if key == "head" or key == "hand":
                # predict head/hand-only parameters from part feature
                part_dict = self.decompose_code(
                    self.Regressor[part](feature[key][part]),
                    self.param_list_dict[f"{part}_list"],
                )
                # if input is part data, skip feature fusion: share feature is the same as part feature
                # then predict share parameters
                feature[key][f"{key}_share"] = feature[key][key]
                share_dict = self.decompose_code(
                    self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
                    self.param_list_dict[f"{part}_share_list"],
                )
                # compose parameters
                param_dict[key] = {**share_dict, **part_dict}

            # for body image
            if key == "body":
                fusion_weight = {}
                f_body = feature["body"]["body"]
                # extract part feature
                for part_name in ["head", "left_hand", "right_hand"]:
                    feature["body"][f"{part_name}_share"] = self.Extractor[
                        f"{part_name}_share"](f_body)

                # body-only parameters. They are needed for the final result
                # no matter whether the part crops were supplied, so compute
                # them up front (previously they were only defined when crops
                # were missing, which broke the user-supplied-crops path).
                body_dict = self.decompose_code(
                    self.Regressor["body"](feature[key]["body"]),
                    self.param_list_dict["body" + "_list"],
                )

                # -- check if part crops are given, if not, crop parts by coarse body estimation
                # NOTE: if any one crop is missing, all three are regenerated
                # and overwrite whatever the caller supplied.
                if (
                    "head_image" not in data[key].keys() or
                    "left_hand_image" not in data[key].keys() or
                    "right_hand_image" not in data[key].keys()
                ):
                    # - run without fusion to get coarse estimation, for cropping parts
                    # head share
                    head_share_dict = self.decompose_code(
                        self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
                        self.param_list_dict["head" + "_share_list"],
                    )
                    # right hand share
                    right_hand_share_dict = self.decompose_code(
                        self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
                        self.param_list_dict["hand" + "_share_list"],
                    )
                    # left hand share
                    left_hand_share_dict = self.decompose_code(
                        self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
                        self.param_list_dict["hand" + "_share_list"],
                    )
                    # change the dict name from right to left
                    left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
                        "right_hand_pose"
                    )
                    left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
                        "right_wrist_pose"
                    )
                    param_dict[key] = {
                        **body_dict,
                        **head_share_dict,
                        **left_hand_share_dict,
                        **right_hand_share_dict,
                    }
                    if body_only:
                        param_dict["moderator_weight"] = None
                        return param_dict
                    prediction_body_only = self.decode(param_dict[key], param_type="body")

                    # crop; the high-res image and the points are the same for
                    # every part, so build them once
                    points_dict = {
                        "smplx_kpt": prediction_body_only["smplx_kpt"],
                        "trans_verts": prediction_body_only["transformed_vertices"],
                    }
                    image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"])
                    for part_name in ["head", "left_hand", "right_hand"]:
                        cropped_image, _ = self.part_from_body(
                            image_hd, part_name, points_dict
                        )
                        data[key][part_name + "_image"] = cropped_image

                # -- encode features from part crops, then fuse feature using the weight from moderator
                for part_name in ["head", "left_hand", "right_hand"]:
                    # "head" -> "head", "left_hand"/"right_hand" -> "hand"
                    part = part_name.split("_")[-1]
                    cropped_image = data[key][part_name + "_image"]
                    # if left hand, flip it as if it is right hand
                    if part_name == "left_hand":
                        cropped_image = torch.flip(cropped_image, dims=(-1, ))
                    # run part regressor
                    f_part = self.Encoder[part](cropped_image)
                    part_dict = self.decompose_code(
                        self.Regressor[part](f_part),
                        self.param_list_dict[f"{part}_list"],
                    )
                    part_share_dict = self.decompose_code(
                        self.Regressor[f"{part}_share"](f_part),
                        self.param_list_dict[f"{part}_share_list"],
                    )
                    param_dict["body_" + part_name] = {**part_dict, **part_share_dict}

                    # moderator to assign weight, then integrate features
                    f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
                        feature["body"][f"{part_name}_share"], f_part, work=True
                    )
                    if copy_and_paste:
                        # copy and paste strategy always trusts the results from part
                        feature["body"][f"{part_name}_share"] = f_part
                    elif threthold and part == "hand":
                        # for hand, if part weight > 0.7 (very confident, then fully trust part)
                        part_w = f_weight[:, [1]]
                        part_w[part_w > 0.7] = 1.0
                        f_body_out = (
                            feature["body"][f"{part_name}_share"] * (1.0 - part_w) +
                            f_part * part_w
                        )
                        feature["body"][f"{part_name}_share"] = f_body_out
                    else:
                        feature["body"][f"{part_name}_share"] = f_body_out
                    fusion_weight[part_name] = f_weight

                # save weights from moderator, that can be further used for optimization/running specific tasks on parts
                param_dict["moderator_weight"] = fusion_weight

                # -- predict parameters from fused body feature
                # head share
                head_share_dict = self.decompose_code(
                    self.Regressor["head" + "_share"](feature[key]["head" + "_share"]),
                    self.param_list_dict["head" + "_share_list"],
                )
                # right hand share
                right_hand_share_dict = self.decompose_code(
                    self.Regressor["hand" + "_share"](feature[key]["right_hand" + "_share"]),
                    self.param_list_dict["hand" + "_share_list"],
                )
                # left hand share
                left_hand_share_dict = self.decompose_code(
                    self.Regressor["hand" + "_share"](feature[key]["left_hand" + "_share"]),
                    self.param_list_dict["hand" + "_share_list"],
                )
                # change the dict name from right to left
                left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
                    "right_hand_pose"
                )
                left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
                    "right_wrist_pose"
                )
                param_dict["body"] = {
                    **body_dict,
                    **head_share_dict,
                    **left_hand_share_dict,
                    **right_hand_share_dict,
                }
                # copy tex param from head param dict to body param dict
                param_dict["body"]["tex"] = param_dict["body_head"]["tex"]
                param_dict["body"]["light"] = param_dict["body_head"]["light"]

                if keep_local:
                    # for local change that will not affect whole body and produce unnatural pose, trust part
                    param_dict[key]["exp"] = param_dict["body_head"]["exp"]
                    param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
                        "right_hand_pose"]
                    # the left-hand crop was flipped, so its prediction is
                    # stored under the right-hand key
                    param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
                        "right_hand_pose"]

        return param_dict

    def convert_pose(self, param_dict, param_type):
        """Convert pose parameters to rotation matrices, in place.

        Every entry whose key contains ``"pose"`` (except the jaw) is
        converted with ``converter.batch_cont2matrix``; the jaw pose is
        converted with ``converter.batch_euler2matrix`` for body/head input.
        Parameters that the given ``param_type`` does not predict are filled
        in from the defaults stored on ``self.smplx``.

        Args:
            param_dict: smplx parameters (modified in place)
            param_type: should be one of body/head/hand

        Returns:
            param_dict: the same dict object, with converted/completed entries

        Raises:
            ValueError: if ``param_type`` is not body/head/hand (only
                reachable when assertions are disabled).
        """
        assert param_type in ["body", "head", "hand"]

        # convert pose representations: the output from network are continuous repr or axis angle,
        # while the input pose for smplx need to be rotation matrix
        for key in param_dict:
            if "pose" in key and "jaw" not in key:
                param_dict[key] = converter.batch_cont2matrix(param_dict[key])
        if param_type == "body" or param_type == "head":
            param_dict["jaw_pose"] = converter.batch_euler2matrix(
                param_dict["jaw_pose"]
            )[:, None, :, :]

        # complement params if it's not in given param dict
        if param_type == "head":
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["global_pose"] = param_dict["head_pose"]
            param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
            param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            # NOTE(review): wrists reuse smplx.neck_pose as their default --
            # presumably a single-joint rest rotation; confirm against SMPLX
            param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
        elif param_type == "hand":
            batch_size = param_dict["right_hand_pose"].shape[0]
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            dtype = param_dict["right_hand_pose"].dtype
            device = param_dict["right_hand_pose"].device
            # diag(1, -1, -1): rotation by 180 degrees about the x axis
            x_180_pose = (
                torch.eye(3, dtype=dtype, device=device).unsqueeze(0).repeat(1, 1, 1)
            )
            x_180_pose[0, 2, 2] = -1.0
            x_180_pose[0, 1, 1] = -1.0
            param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
            param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
            param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
            param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
            param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
        elif param_type == "body":
            # the prediction from the head and hand share regressor is always absolute pose
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
            # the body-hand share regressor is working for right hand
            # so we assume body network get the flipped feature for the left hand. then get the parameters
            # then we need to flip it back to left, which matches the input left hand
            param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
            param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
        else:
            # Only reachable with assertions disabled; fail loudly rather than
            # silently terminating the interpreter.
            raise ValueError(
                f"param_type must be one of body/head/hand, got {param_type!r}"
            )

        return param_dict

    def decode(self, param_dict, param_type):
        """Decode model parameters to smplx vertices & joints & texture.

        ``param_dict`` is modified in place: poses are converted to rotation
        matrices when still in network output form, and ``"body_pose"`` is
        added.

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand

        Returns:
            predictions: smplx predictions (vertices, projected vertices and
                keypoints, 3-D joints, camera) merged with ``param_dict``
        """
        # poses are still raw network output -> convert them first
        if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2:
            self.convert_pose(param_dict, param_type)
        elif param_dict["right_wrist_pose"].shape[-1] == 6:
            self.convert_pose(param_dict, param_type)

        # concatenate body pose
        partbody_pose = param_dict["partbody_pose"]
        param_dict["body_pose"] = torch.cat(
            [
                partbody_pose[:, :11],
                param_dict["neck_pose"],
                partbody_pose[:, 11:11 + 2],
                param_dict["head_pose"],
                partbody_pose[:, 13:13 + 4],
                param_dict["left_wrist_pose"],
                param_dict["right_wrist_pose"],
            ],
            dim=1,
        )

        # change absolute head&hand pose to relative pose according to rest body pose
        if param_type == "head" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
            )
        if param_type == "hand" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"],
                param_dict["body_pose"],
                abs_joint="left_wrist",
            )
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"],
                param_dict["body_pose"],
                abs_joint="right_wrist",
            )

        if self.cfg.model.check_pose:
            # check if pose is natural (relative rotation), if not, set relative to 0 (especially for head pose)
            # xyz: pitch(positive for looking down), yaw(positive for looking left), roll(rolling chin to left)
            for pose_ind in [14]:    # head [15-1, 20-1, 21-1]:
                curr_pose = param_dict["body_pose"][:, pose_ind]
                euler_pose = converter._compute_euler_from_matrix(curr_pose)
                for i, max_angle in enumerate([20, 70, 10]):
                    # view into euler_pose: the masked write below edits it in place;
                    # any angle outside [-max_angle, max_angle] degrees is zeroed
                    euler_pose_curr = euler_pose[:, i]
                    euler_pose_curr[euler_pose_curr != torch.clamp(
                        euler_pose_curr,
                        min=-max_angle * np.pi / 180,
                        max=max_angle * np.pi / 180,
                    )] = 0.0
                param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)

        # SMPLX
        verts, landmarks, joints = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            global_pose=param_dict["global_pose"],
            body_pose=param_dict["body_pose"],
            jaw_pose=param_dict["jaw_pose"],
            left_hand_pose=param_dict["left_hand_pose"],
            right_hand_pose=param_dict["right_hand_pose"],
        )
        smplx_kpt3d = joints.clone()

        # projection
        cam = param_dict[param_type + "_cam"]
        trans_verts = util.batch_orth_proj(verts, cam)
        predicted_landmarks = util.batch_orth_proj(landmarks, cam)[:, :, :2]
        predicted_joints = util.batch_orth_proj(joints, cam)[:, :, :2]

        prediction = {
            "vertices": verts,
            "transformed_vertices": trans_verts,
            "face_kpt": predicted_landmarks,
            "smplx_kpt": predicted_joints,
            "smplx_kpt3d": smplx_kpt3d,
            "joints": joints,
            "cam": cam,
        }

        # change the order of face keypoints, to be the same as "standard" 68 keypoints
        # (the last 17 landmarks are moved to the front)
        prediction["face_kpt"] = torch.cat(
            [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]],
            dim=1,
        )

        prediction.update(param_dict)

        return prediction

    def decode_Tpose(self, param_dict):
        """Return body mesh in T pose, support body and head param dict only.

        Only shape, expression and jaw pose are passed to the SMPL-X model; all
        other pose arguments are left at the model's defaults.
        """
        verts, _, _ = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            jaw_pose=param_dict["jaw_pose"],
        )
        return verts