|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import torch |
|
import math |
|
from modules.real3d.segformer import SegFormerImg2PlaneBackbone, SegFormerSECC2PlaneBackbone |
|
from modules.real3d.img2plane_baseline import OSAvatar_Img2plane |
|
from modules.img2plane.img2plane_model import Img2PlaneModel |
|
from utils.commons.hparams import hparams |
|
|
|
|
|
class OSAvatarSECC_Img2plane(OSAvatar_Img2plane):

    """One-shot avatar image-to-triplane model with a SECC/PNCC-conditioned branch.

    Extends ``OSAvatar_Img2plane`` with a second backbone that maps PNCC
    condition maps (canonical / source / target) to plane-space features,
    which are fused with the *canonical* planes predicted from the source
    image.  NOTE(review): the precise semantics of "SECC" and "PNCC" come
    from the wider project (``modules.real3d``) — confirm there; this class
    only shows that they are image-like tensors concatenated on dim=1.
    """

    def __init__(self, hp=None):

        super().__init__(hp=hp)

        # Intentionally shadows the module-level `hparams` import with this
        # instance's own hyperparameter dict for the rest of __init__.
        hparams = self.hparams



        # Repurpose the parent's backbone as the *canonical* plane predictor;
        # `del` unregisters the old attribute so the same submodule is not
        # registered (and therefore saved/optimized) under two names.
        self.cano_img2plane_backbone = self.img2plane_backbone

        del self.img2plane_backbone

        # Second backbone: maps the concatenated PNCC condition maps to plane
        # features with the same channel layout (3 planes x hid_dim x depth)
        # as the canonical planes, so the two can be fused elementwise.
        self.secc_img2plane_backbone = SegFormerSECC2PlaneBackbone(mode=hparams['secc_segformer_scale'], out_channels=3*self.triplane_hid_dim*self.triplane_depth, pncc_cond_mode=hparams['pncc_cond_mode'])

        # Loss weights kept as non-trainable Parameters so they follow
        # `.to(device)` and are checkpointed with the model; presumably
        # adjusted by an external training schedule — TODO confirm.
        # ("pertube" is a historical typo — renaming would break existing
        # checkpoints' state_dict keys.)
        self.lambda_pertube_blink_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)

        self.lambda_pertube_secc = torch.nn.Parameter(torch.tensor([0.001]), requires_grad=False)



    def on_train_full_model(self):

        """Enable gradients on every submodule (joint training phase)."""

        self.requires_grad_(True)



    def on_train_nerf(self):

        """Train the NeRF parts (both plane backbones + decoder); freeze SR."""

        self.cano_img2plane_backbone.requires_grad_(True)

        self.secc_img2plane_backbone.requires_grad_(True)

        self.decoder.requires_grad_(True)

        self.superresolution.requires_grad_(False)



    def on_train_superresolution(self):

        """Train only the super-resolution module; freeze backbones and decoder."""

        self.cano_img2plane_backbone.requires_grad_(False)

        self.secc_img2plane_backbone.requires_grad_(False)

        self.decoder.requires_grad_(False)

        self.superresolution.requires_grad_(True)



    def cal_cano_plane(self, img, cond=None, **kwargs):

        """Predict canonical planes from the source image.

        Returns planes shaped (B, 3, hid_dim*depth, H, W).  For the
        'trigrid_v2' feature type the raw planes are additionally refined by
        ``self.plane2grid_module`` (defined in the parent class — see
        ``OSAvatar_Img2plane``) before being reshaped back.
        """

        hparams = self.hparams  # shadows module-level import, same as __init__

        planes = cano_planes = self.cano_img2plane_backbone(img, cond, **kwargs)

        if hparams.get("triplane_feature_type", "triplane") in ['triplane', 'trigrid']:

            # Flat channel output -> (B, 3 planes, hid_dim*depth, H, W).
            planes = planes.view(len(planes), 3, self.triplane_hid_dim*self.triplane_depth, planes.shape[-2], planes.shape[-1])

        elif hparams.get("triplane_feature_type", "triplane") in ['trigrid_v2']:

            # Backbone already returns 5-D planes here; fold the plane axis
            # into channels for the refinement module, then restore it.
            b, k, cd, h, w = planes.shape

            planes = planes.reshape([b, k*cd, h, w])

            planes = self.plane2grid_module(planes)

            planes = planes.reshape([b, k, cd, h, w])

        else:

            raise NotImplementedError()

        return planes



    def cal_secc_plane(self, cond):

        """Predict plane features from PNCC condition maps.

        `cond` is a dict with keys 'cond_cano', 'cond_src', 'cond_tgt'
        (image-like tensors).  The source PNCC is included in the
        channel-concatenated input only in 'cano_src_tgt' mode.
        """

        cano_pncc, src_pncc, tgt_pncc = cond['cond_cano'], cond['cond_src'], cond['cond_tgt']

        if self.hparams.get("pncc_cond_mode", "cano_tgt") == 'cano_src_tgt':

            inp_pncc = torch.cat([cano_pncc, src_pncc, tgt_pncc], dim=1)

        else:

            inp_pncc = torch.cat([cano_pncc, tgt_pncc], dim=1)

        secc_planes = self.secc_img2plane_backbone(inp_pncc)

        return secc_planes



    def cal_plane_given_cano(self, cano_planes, cond=None):

        """Fuse pre-computed canonical planes with the SECC planes.

        Fusion is elementwise add (default) or mul, selected by the
        'phase1_plane_fusion_mode' hparam.
        """



        secc_planes = self.cal_secc_plane(cond)

        if self.hparams.get("phase1_plane_fusion_mode", "add") == 'add':

            planes = cano_planes + secc_planes

        elif self.hparams.get("phase1_plane_fusion_mode", "add") == 'mul':

            planes = cano_planes * secc_planes

        else: raise NotImplementedError()

        return planes



    def cal_plane(self, img, cond, ret=None, **kwargs):

        """Full plane prediction: returns (fused planes, canonical planes)."""

        cano_planes = self.cal_cano_plane(img, cond, **kwargs)

        planes = self.cal_plane_given_cano(cano_planes, cond)

        return planes, cano_planes



    def sample(self, coordinates, directions, img, cond=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, ref_camera=None, **synthesis_kwargs):

        """Query the radiance field at explicit 3D coordinates.

        Recomputes planes on every call (no backbone caching); presumably
        used for mesh/shape extraction — TODO confirm against callers.
        """



        planes, _ = self.cal_plane(img, cond, ret={}, ref_camera=ref_camera)

        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)



    def synthesis(self, img, camera, cond=None, ret=None, update_emas=False, cache_backbone=True, use_cached_backbone=False, **synthesis_kwargs):

        """Render an image of the avatar from `img` under `camera` and `cond`.

        `camera` is a flat tensor per sample: [:16] is the 4x4 cam2world
        matrix, [16:25] the 3x3 intrinsics.  With `use_cached_backbone` the
        heavy canonical branch is skipped and `self._last_cano_planes` (set
        on a previous call with `cache_backbone=True`) is reused, so only
        the SECC branch runs per frame.  Results (raw RGB, depth, SR image,
        extra feature channels, planes) are written into and returned via
        `ret`.
        """

        if ret is None: ret = {}

        cam2world_matrix = camera[:, :16].view(-1, 4, 4)

        intrinsics = camera[:, 16:25].view(-1, 3, 3)



        neural_rendering_resolution = self.neural_rendering_resolution





        # One ray per pixel of the low-res neural rendering.
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)





        N, M, _ = ray_origins.shape

        if use_cached_backbone:



            cano_planes = self._last_cano_planes

            planes = self.cal_plane_given_cano(cano_planes, cond)

        else:

            planes, cano_planes = self.cal_plane(img, cond, ret, **synthesis_kwargs)

            if cache_backbone:

                # Side effect: cache canonical planes for subsequent frames.
                self._last_cano_planes = cano_planes





        # Volume rendering: per-ray feature/depth/weight samples plus a
        # validity flag per ray.
        feature_samples, depth_samples, weights_samples, is_ray_valid = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs)





        # Reshape ray-major samples (N, H*W, C) into image tensors (N, C, H, W).
        H = W = self.neural_rendering_resolution

        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()

        weights_image = weights_samples.permute(0, 2, 1).reshape(N,1,H,W).contiguous()

        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)



        if self.hparams.get("mask_invalid_rays", False):

            is_ray_valid_mask = is_ray_valid.reshape([feature_samples.shape[0], 1,self.neural_rendering_resolution,self.neural_rendering_resolution])

            # In-place: invalid-ray features forced to -1 (background value).
            feature_image[~is_ray_valid_mask.repeat([1,feature_image.shape[1],1,1])] = -1





            # Fill invalid depths with the minimum valid depth.
            # NOTE(review): raises if *no* ray is valid (empty .min()) —
            # presumably never happens in practice; confirm.
            depth_image[~is_ray_valid_mask] = depth_image[is_ray_valid_mask].min().item()





        # First 3 feature channels are the low-res RGB prediction.
        rgb_image = feature_image[:, :3]

        ret['weights_img'] = weights_image

        sr_image = self._forward_sr(rgb_image, feature_image, cond, ret, **synthesis_kwargs)

        rgb_image = rgb_image.clamp(-1,1)

        sr_image = sr_image.clamp(-1,1)

        ret.update({'image_raw': rgb_image, 'image_depth': depth_image, 'image': sr_image, 'image_feature': feature_image[:, 3:], 'plane': planes})

        return ret
|
|