import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from termcolor import colored

from lib.renderer.mesh import compute_normal_batch
from lib.dataset.mesh_util import (SMPLX, feat_select, read_smpl_constants,
                                   surface_field_deformation)
from lib.net.NormalNet import NormalNet
from lib.net.MLP import MLP, DeformationMLP, TransformerEncoderLayer, SDF2Density, SDF2Occ
from lib.net.spatial import SpatialEncoder
from lib.dataset.PointFeat import PointFeat
from lib.net.VE import VolumeEncoder
from lib.net.ResBlkPIFuNet import ResnetFilter
from lib.net.UNet import UNet
from lib.net.HGFilters import *
from lib.net.Transformer import ViTVQ
from lib.net.BasePIFuNet import BasePIFuNet
from lib.net.nerf_util import raw2outputs

|
def normalize(tensor):
    """Min-max normalize a tensor to the range [0, 1]."""
    min_val = tensor.min()
    max_val = tensor.max()
    # The small epsilon guards against division by zero for constant tensors.
    return (tensor - min_val) / (max_val - min_val + 1e-8)
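# A minimal usage sketch (hypothetical values):
#   t = torch.tensor([2.0, 4.0, 6.0])
#   normalize(t)  # -> tensor([0.0, 0.5, 1.0]) up to the 1e-8 epsilon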
|
def visualize_feature_map(feature_map, title, filename):
    """Save the first channel of the first sample of a [B, C, H, W] feature map
    as a heatmap image. `title` is accepted for API symmetry but not rendered,
    which keeps the saved image borderless."""
    # [B, C, H, W] -> [B, H, W, C]
    feature_map = feature_map.permute(0, 2, 3, 1)

    sample_index = 0
    sample = feature_map[sample_index]

    # Take the first channel and rescale it to [0, 1] for display.
    channel_index = 0
    channel = normalize(sample[:, :, channel_index])

    plt.imshow(channel.detach().cpu().numpy(), cmap="hot")
    plt.axis("off")
    plt.savefig(filename, dpi=300, bbox_inches="tight", pad_inches=0)
    plt.close()
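# Usage sketch (hypothetical tensor): dump the first channel of a feature map.
#   feat = torch.randn(1, 64, 128, 128)   # [B, C, H, W]
#   visualize_feature_map(feat, title="stack_0", filename="stack_0.png")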
|
|
class HGPIFuNet(BasePIFuNet):
    """
    HGPIFu uses stacked hourglass networks as the image filter.
    It does the following:
        1. Compute the image feature stacks and store them in self.im_feat_list;
           self.im_feat_list[-1] is the last (output) stack.
        2. Calculate the calibration.
        3. During training, index into every intermediate stack;
           during testing, index into the last stack only.
        4. Classification.
        5. During training, the error is calculated on all stacks.
    """
|
    def __init__(self,
                 cfg,
                 projection_mode="orthogonal",
                 error_term=nn.MSELoss()):

        super(HGPIFuNet, self).__init__(projection_mode=projection_mode,
                                        error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()
        self.opt = cfg.net
        self.root = cfg.root
        self.overfit = cfg.overfit

        # Note: channels_IF aliases the config list, so the in-place updates
        # below also modify cfg.net.mlp_dim.
        channels_IF = self.opt.mlp_dim

        self.use_filter = self.opt.use_filter
        self.prior_type = self.opt.prior_type
        self.smpl_feats = self.opt.smpl_feats

        self.smpl_dim = self.opt.smpl_dim
        self.voxel_dim = self.opt.voxel_dim
        self.hourglass_dim = self.opt.hourglass_dim

        self.in_geo = [item[0] for item in self.opt.in_geo]
        self.in_nml = [item[0] for item in self.opt.in_nml]

        self.in_geo_dim = sum([item[1] for item in self.opt.in_geo])
        self.in_nml_dim = sum([item[1] for item in self.opt.in_nml])

        self.in_total = self.in_geo + self.in_nml
        self.smpl_feat_dict = None
        self.smplx_data = SMPLX()

        # Channel indices of the image / front-normal / back-normal slices
        # inside the concatenated input tensor.
        image_lst = [0, 1, 2]
        normal_F_lst = [0, 1, 2] if "image" not in self.in_geo else [3, 4, 5]
        normal_B_lst = [3, 4, 5] if "image" not in self.in_geo else [6, 7, 8]
|
        if self.prior_type in ["icon", "keypoint"]:
            if "image" in self.in_geo:
                self.channels_filter = [
                    image_lst + normal_F_lst,
                    image_lst + normal_B_lst,
                ]
            else:
                self.channels_filter = [normal_F_lst, normal_B_lst]
        else:
            if "image" in self.in_geo:
                self.channels_filter = [image_lst + normal_F_lst + normal_B_lst]
            else:
                self.channels_filter = [normal_F_lst + normal_B_lst]
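        # Example: with in_geo = ["image", "normal_F", "normal_B"], the fused
        # 9-channel input is image [0:3], front normals [3:6], back normals
        # [6:9], so channels_filter selects image+front and image+back slices.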
|
        use_vis = (self.prior_type in ["icon", "keypoint"]) and ("vis" in self.smpl_feats)
        if self.prior_type in ["pamir", "pifu"]:
            use_vis = True

        # Without visibility-based feature selection, front and back image
        # features are both kept, so the image-feature dim doubles (2 - use_vis).
        if self.use_filter:
            channels_IF[0] = self.hourglass_dim * (2 - use_vis)
        else:
            channels_IF[0] = len(self.channels_filter[0]) * (2 - use_vis)

        if self.prior_type in ["icon", "keypoint"]:
            channels_IF[0] += self.smpl_dim
        elif self.prior_type == "pifu":
            channels_IF[0] += 1
        else:
            print(f"prior_type {self.prior_type} is not supported!")
|
        self.base_keys = ["smpl_verts", "smpl_faces"]

        self.icon_keys = self.base_keys + [
            f"smpl_{feat_name}" for feat_name in self.smpl_feats
        ]
        self.keypoint_keys = self.base_keys + [
            f"smpl_{feat_name}" for feat_name in self.smpl_feats
        ]
        self.pamir_keys = ["voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"]
        self.pifu_keys = []
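        # These *_keys lists name the in_tensor_dict entries that each prior
        # type consumes; filter() copies them into self.smpl_feat_dict, with
        # missing keys stored as None.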
|
        self.deform_dim = 64

        # Tri-plane image encoder over the 9-channel image + normal stack.
        self.image_filter = ViTVQ(image_size=512, channels=9)

        # Point-wise decoder conditioned on tri-plane and SMPL features.
        self.mlp = TransformerEncoderLayer(skips=4, multires=6, opt=self.opt)

        self.color_loss = nn.L1Loss()
        self.sp_encoder = SpatialEncoder()
        self.step = 0
        self.features_costume = None
|
        if self.use_filter:
            if self.opt.gtype == "HGPIFuNet":
                self.F_filter = HGFilter(self.opt, self.opt.num_stack,
                                         len(self.channels_filter[0]))
            else:
                print(
                    colored(f"Backbone {self.opt.gtype} is unimplemented",
                            "green"))
|
        summary_log = (f"{self.prior_type.upper()}:\n" +
                       f"w/ Global Image Encoder: {self.use_filter}\n" +
                       f"Image Features used by MLP: {self.in_geo}\n")

        if self.prior_type == "icon":
            summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
            summary_log += f"Dim of Image Features (local): {3 if (use_vis and not self.use_filter) else 6}\n"
            summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n"
        elif self.prior_type == "keypoint":
            summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
            summary_log += f"Dim of Image Features (local): {3 if (use_vis and not self.use_filter) else 6}\n"
            summary_log += f"Dim of Geometry Features (Keypoint): {self.smpl_dim}\n"
        elif self.prior_type == "pamir":
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n"
        else:
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += f"Dim of Geometry Features (PIFu): 1 (z-value)\n"

        summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n"

        print(colored(summary_log, "yellow"))
|
        self.normal_filter = NormalNet(cfg)

        # init_net is expected to come in via the wildcard import from
        # lib.net.HGFilters; it is not imported by name here.
        init_net(self, init_type="normal")
|
    def get_normal(self, in_tensor_dict):
        """Assemble the filter input: optional RGB image plus front/back normals."""
        if (not self.training) and (not self.overfit):
            # During inference, predict the normal maps on the fly when they
            # are not already provided.
            with torch.no_grad():
                feat_lst = []
                if "image" in self.in_geo:
                    feat_lst.append(in_tensor_dict["image"])
                if "normal_F" in self.in_geo and "normal_B" in self.in_geo:
                    if ("normal_F" not in in_tensor_dict.keys()
                            or "normal_B" not in in_tensor_dict.keys()):
                        (nmlF, nmlB) = self.normal_filter(in_tensor_dict)
                    else:
                        nmlF = in_tensor_dict["normal_F"]
                        nmlB = in_tensor_dict["normal_B"]
                    feat_lst.append(nmlF)
                    feat_lst.append(nmlB)
            in_filter = torch.cat(feat_lst, dim=1)
        else:
            in_filter = torch.cat([in_tensor_dict[key] for key in self.in_geo],
                                  dim=1)

        return in_filter
|
    def get_mask(self, in_filter, size=128):
        """Downsample the selected channels and mark non-zero pixels as foreground."""
        mask = (F.interpolate(
            in_filter[:, self.channels_filter[0]],
            size=(size, size),
            mode="bilinear",
            align_corners=True,
        ).abs().sum(dim=1, keepdim=True) != 0.0)

        return mask
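    # Example (hypothetical shapes): for an in_filter of [B, 9, 512, 512],
    # get_mask returns a [B, 1, 128, 128] boolean map that is True wherever
    # the selected channels are not identically zero.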
|
    def filter(self, in_tensor_dict, return_inter=False):
        """
        Filter the input images and store all intermediate features.
        :param in_tensor_dict: dict of [B, C, H, W] image tensors
        """
        in_filter = self.get_normal(in_tensor_dict)
        image = in_tensor_dict["image"]
        fuse_image = torch.cat([image, in_filter], dim=1)
        # Note: the back plane uses the predicted normal_B, while the side
        # planes use the SMPL-rendered T_normal_R / T_normal_L maps.
        smpl_normals = {
            "T_normal_B": in_tensor_dict["normal_B"],
            "T_normal_R": in_tensor_dict["T_normal_R"],
            "T_normal_L": in_tensor_dict["T_normal_L"],
        }
        features_G = []

        if self.prior_type in ["icon", "keypoint"]:
            if self.use_filter:
                # Four feature planes from the ViT encoder, plus stacked
                # hourglass features for the front/back normal channels.
                triplane_features = self.image_filter(fuse_image, smpl_normals)
                features_F = self.F_filter(in_filter[:, self.channels_filter[0]])
                features_B = self.F_filter(in_filter[:, self.channels_filter[1]])
            else:
                assert 0

            F_plane_feat, B_plane_feat, R_plane_feat, L_plane_feat = triplane_features

            refine_F_plane_feat = F_plane_feat
            features_G.append(refine_F_plane_feat)
            features_G.append(B_plane_feat)
            features_G.append(R_plane_feat)
            features_G.append(L_plane_feat)
            features_G.append(torch.cat([features_F[-1], features_B[-1]], dim=1))
        else:
            assert 0

        self.smpl_feat_dict = {
            k: in_tensor_dict[k] if k in in_tensor_dict.keys() else None
            for k in getattr(self, f"{self.prior_type}_keys")
        }
        if 'animated_smpl_verts' not in in_tensor_dict.keys():
            self.point_feat_extractor = PointFeat(self.smpl_feat_dict["smpl_verts"],
                                                  self.smpl_feat_dict["smpl_faces"])
        else:
            assert 0

        self.features_G = features_G

        # The same feature list is returned in both training and inference.
        features_out = features_G

        if return_inter:
            return features_out, in_filter
        else:
            return features_out
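    # Sketch of the expected call order during inference (hypothetical names):
    #   feats = net.filter(in_tensor_dict)                    # [F, B, R, L, hourglass]
    #   occ   = net.query(feats, points, calibs, type='shape')[0]
    #   rgb   = net.query(feats, points, calibs, type='color')[0]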
|
    def query(self, features, points, calibs, transforms=None, type='shape'):
        """Query per-point predictions.

        Args:
            features (list): feature planes produced by filter()
            points (torch.tensor): (B, 3, N) query points
            calibs (torch.tensor): calibration matrices for the projection
            transforms: optional 2D image-space transforms
            type (str): 'shape' for occupancy, 'color' for RGB
        """
        xyz = self.projection(points, calibs, transforms)

        (xy, z) = xyz.split([2, 1], dim=1)
        # zy coordinates index the side (right/left) feature planes.
        zy = torch.cat([xyz[:, 2:3], xyz[:, 1:2]], dim=1)

        # Mask for query points inside the [-1, 1] cube.
        in_cube = (xyz > -1.0) & (xyz < 1.0)
        in_cube = in_cube.all(dim=1, keepdim=True).detach().float()

        preds_list = []

        if self.prior_type in ["icon", "keypoint"]:

            densely_smpl = self.smpl_feat_dict['smpl_verts'].permute(0, 2, 1)
            smpl_vis = self.smpl_feat_dict['smpl_vis'].permute(0, 2, 1)  # currently unused

            (smpl_xy, smpl_z) = densely_smpl.split([2, 1], dim=1)
            smpl_zy = torch.cat([densely_smpl[:, 2:3], densely_smpl[:, 1:2]], dim=1)

            point_feat_out = self.point_feat_extractor.query(
                xyz.permute(0, 2, 1).contiguous(), self.smpl_feat_dict)
            vis = point_feat_out['vis'].permute(0, 2, 1)

            feat_lst = [
                point_feat_out[key] for key in self.smpl_feats
                if key in point_feat_out.keys()
            ]
            smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1)

            if len(features) == 5:
                # Each plane feature is split in half: the first half is
                # sampled at the query points, the second half at the SMPL
                # vertices (for barycentric interpolation below).
                F_plane_feat1, F_plane_feat2 = features[0].chunk(2, dim=1)
                B_plane_feat1, B_plane_feat2 = features[1].chunk(2, dim=1)
                R_plane_feat1, R_plane_feat2 = features[2].chunk(2, dim=1)
                L_plane_feat1, L_plane_feat2 = features[3].chunk(2, dim=1)
                in_feat = features[4]

                # Sample the query points on the four planes; visibility picks
                # between the front and back hourglass features.
                F_feat = self.index(F_plane_feat1, xy)
                B_feat = self.index(B_plane_feat1, xy)
                R_feat = self.index(R_plane_feat1, zy)
                L_feat = self.index(L_plane_feat1, zy)
                normal_feat = feat_select(self.index(in_feat, xy), vis)
                three_plane_feat = (B_feat + R_feat + L_feat) / 3
                triplane_feat = torch.cat([F_feat, three_plane_feat], dim=1)

                # Sample the SMPL vertices on the same planes ...
                smpl_F_feat = self.index(F_plane_feat2, smpl_xy)
                smpl_B_feat = self.index(B_plane_feat2, smpl_xy)
                smpl_R_feat = self.index(R_plane_feat2, smpl_zy)
                smpl_L_feat = self.index(L_plane_feat2, smpl_zy)

                smpl_three_plane_feat = (smpl_B_feat + smpl_R_feat + smpl_L_feat) / 3
                smpl_triplane_feat = torch.cat([smpl_F_feat, smpl_three_plane_feat], dim=1)
                # ... and carry them to the query points via barycentric weights
                # (sic: "barycentirc" is the method name in PointFeat).
                bary_centric_feat = self.point_feat_extractor.query_barycentirc_feats(
                    xyz.permute(0, 2, 1).contiguous(),
                    smpl_triplane_feat.permute(0, 2, 1))

                final_feat = torch.cat(
                    [triplane_feat, bary_centric_feat.permute(0, 2, 1), normal_feat],
                    dim=1)

                if self.features_costume is not None:
                    assert 0

                if type == 'shape':
                    if 'animated_smpl_verts' in self.smpl_feat_dict.keys():
                        animated_smpl = self.smpl_feat_dict['animated_smpl_verts']
                        occ = self.mlp(xyz.permute(0, 2, 1).contiguous(), animated_smpl,
                                       final_feat, smpl_feat,
                                       training=self.training, type=type)
                    else:
                        occ = self.mlp(xyz.permute(0, 2, 1).contiguous(),
                                       densely_smpl.permute(0, 2, 1),
                                       final_feat, smpl_feat,
                                       training=self.training, type=type)
                    # Zero out predictions that fall outside the unit cube.
                    occ = occ * in_cube
                    preds_list.append(occ)

                elif type == 'color':
                    if 'animated_smpl_verts' in self.smpl_feat_dict.keys():
                        animated_smpl = self.smpl_feat_dict['animated_smpl_verts']
                        color_preds = self.mlp(xyz.permute(0, 2, 1).contiguous(), animated_smpl,
                                               final_feat, smpl_feat,
                                               training=self.training, type=type)
                    else:
                        color_preds = self.mlp(xyz.permute(0, 2, 1).contiguous(),
                                               densely_smpl.permute(0, 2, 1),
                                               final_feat, smpl_feat,
                                               training=self.training, type=type)
                    preds_list.append(color_preds)

        return preds_list
|
    def get_error(self, preds_if_list, labels):
        """Calculate the occupancy error.

        Args:
            preds_if_list (list): list of (B, 1, N) torch.tensor predictions,
                one per stack
            labels (torch.tensor): (B, 1, N) ground-truth occupancy

        Returns:
            torch.tensor: binary cross-entropy averaged over the stacks
        """
        error_if = 0

        for pred_id in range(len(preds_if_list)):
            pred_if = preds_if_list[pred_id]
            error_if += F.binary_cross_entropy(pred_if, labels)

        error_if /= len(preds_if_list)

        return error_if
|
    def forward(self, in_tensor_dict):

        sample_tensor = in_tensor_dict["sample"]
        calib_tensor = in_tensor_dict["calib"]
        label_tensor = in_tensor_dict["label"]

        color_sample = in_tensor_dict["sample_color"]
        color_label = in_tensor_dict["color"]

        in_feat = self.filter(in_tensor_dict)

        # Occupancy branch.
        preds_if_list = self.query(in_feat,
                                   sample_tensor,
                                   calib_tensor, type='shape')
        BCEloss = self.get_error(preds_if_list, label_tensor)

        # Color branch.
        color_preds = self.query(in_feat,
                                 color_sample,
                                 calib_tensor, type='color')
        color_loss = self.color_loss(color_preds[0], color_label)

        if self.training:
            self.color3d_loss = color_loss
            error = BCEloss + color_loss
            self.grad_loss = torch.tensor(0.).float().to(BCEloss.device)
        else:
            error = BCEloss

        return preds_if_list[-1].detach(), error
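
if __name__ == "__main__":
    # Hedged smoke test for the plotting helpers only: constructing the full
    # HGPIFuNet requires a project config (cfg), which is not built here, and
    # running this file still needs the repo's `lib` package to be importable.
    dummy_feat = torch.randn(1, 8, 64, 64)  # hypothetical [B, C, H, W] features
    visualize_feature_map(dummy_feat, title="dummy", filename="dummy_feature.png")
    print("saved dummy_feature.png")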
|
|