from lib.net.FBNet import define_G
from lib.net.net_util import init_net, VGGLoss
from lib.net.HGFilters import *
from lib.net.BasePIFuNet import BasePIFuNet

import torch
import torch.nn as nn


class Hallucinator(BasePIFuNet):
    '''
    Hallucinator predicts normal maps for the "B" and "L" branches with two
    image-to-image generators (define_G), conditioned on the tensors listed
    in cfg.net.in_nml.

    It does the following:
        1. Gather the inputs of each branch (keys containing '_B' or '_L',
           plus the input 'image') and concatenate them channel-wise.
        2. Feed each concatenation to its own generator (netB / netL).
        3. Normalize the predictions to unit-length normals; the "B"
           prediction is additionally masked by the silhouette of the
           input image.
        4. During training, the per-branch error is a SmoothL1 term plus
           a VGG perceptual term (see get_norm_error).
    '''

    def __init__(self, cfg, error_term=nn.SmoothL1Loss()):

        super(Hallucinator, self).__init__(error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()

        self.opt = cfg.net

        # keep VGGLoss in a plain list so it is not registered as a submodule
        if self.training:
            self.vgg_loss = [VGGLoss()]

        # per-branch input keys and total channel counts:
        # entries whose key contains '_B' / '_L', plus the input 'image'
        self.in_nmlB = [
            item[0] for item in self.opt.in_nml
            if '_B' in item[0] or item[0] == 'image'
        ]
        self.in_nmlL = [
            item[0] for item in self.opt.in_nml
            if '_L' in item[0] or item[0] == 'image'
        ]
        self.in_nmlB_dim = sum([
            item[1] for item in self.opt.in_nml
            if '_B' in item[0] or item[0] == 'image'
        ])
        self.in_nmlL_dim = sum([
            item[1] for item in self.opt.in_nml
            if '_L' in item[0] or item[0] == 'image'
        ])

        # one "global" image-to-image generator per branch (built with
        # define_G), each predicting a 3-channel normal map
        self.netB = define_G(self.in_nmlB_dim, 3, 64, "global", 4, 9, 1, 3,
                             "instance")
        self.netL = define_G(self.in_nmlL_dim, 3, 64, "global", 4, 9, 1, 3,
                             "instance")

        init_net(self)

    def forward(self, in_tensor):

        # gather the conditioning tensors for each branch
        inB_list = []
        inL_list = []

        for name in self.in_nmlB:
            inB_list.append(in_tensor[name])
        for name in self.in_nmlL:
            inL_list.append(in_tensor[name])

        # concatenate channel-wise and predict a normal map per branch
        nmlB = self.netB(torch.cat(inB_list, dim=1))
        nmlL = self.netL(torch.cat(inL_list, dim=1))

        # normalize the predictions to unit-length normals per pixel
        nmlB = nmlB / torch.norm(nmlB, dim=1, keepdim=True)
        nmlL = nmlL / torch.norm(nmlL, dim=1, keepdim=True)

        # silhouette mask from the input image (non-zero pixels);
        # it is applied to the "B" prediction only
        mask = (in_tensor['image'].abs().sum(dim=1, keepdim=True) !=
                0.0).detach().float()

        nmlB = nmlB * mask

        return nmlB, nmlL

    def get_norm_error(self, prd_B, prd_L, tgt):
        """Calculate the normal-map loss for both branches.

        Args:
            prd_B (torch.tensor): predicted "B" normals, [B, 3, 512, 512]
            prd_L (torch.tensor): predicted "L" normals, [B, 3, 512, 512]
            tgt (dict): ground truth with keys 'render_B' and 'render_L'
        """

        tgt_B, tgt_L = tgt['render_B'], tgt['render_L']

        l1_B_loss = self.l1_loss(prd_B, tgt_B)
        l1_L_loss = self.l1_loss(prd_L, tgt_L)

        # VGG perceptual term, computed without gradients
        with torch.no_grad():
            vgg_B_loss = self.vgg_loss[0](prd_B, tgt_B)
            vgg_L_loss = self.vgg_loss[0](prd_L, tgt_L)

        # per-branch total: 5 * SmoothL1 + VGG
        total_loss = [
            5.0 * l1_B_loss + vgg_B_loss, 5.0 * l1_L_loss + vgg_L_loss
        ]

        return total_loss

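
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training code.
    # The cfg layout and the input keys below are assumptions: the class only
    # requires cfg.net.in_nml to be an iterable of (key, channel_count) pairs
    # whose keys contain '_B' / '_L' or equal 'image'. Note that building the
    # module in training mode also instantiates VGGLoss.
    from types import SimpleNamespace

    dummy_cfg = SimpleNamespace(net=SimpleNamespace(
        in_nml=[('image', 3), ('T_normal_B', 3), ('T_normal_L', 3)]))

    net = Hallucinator(dummy_cfg)

    in_tensor = {
        'image': torch.randn(1, 3, 512, 512),
        'T_normal_B': torch.randn(1, 3, 512, 512),
        'T_normal_L': torch.randn(1, 3, 512, 512),
    }
    nmlB, nmlL = net.forward(in_tensor)  # each: [1, 3, 512, 512]

    tgt = {
        'render_B': torch.randn(1, 3, 512, 512),
        'render_L': torch.randn(1, 3, 512, 512),
    }
    loss_B, loss_L = net.get_norm_error(nmlB, nmlL, tgt)
    print(nmlB.shape, nmlL.shape, loss_B.item(), loss_L.item())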