# VOODOO3D-unofficial/models/voodoo3d_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from additional_modules.deeplabv3.deeplabv3 import DeepLabV3
from additional_modules.eg3d.camera_utils import IntrinsicsSampler, LookAtPoseSampler
from additional_modules.segformer.backbone import Block, OverlapPatchEmbed
from models.utils.face_augmentor import FaceAugmentor
from models.lp3d_model import PositionalEncoder, Lp3D
from utils.registry import MODEL_REGISTRY


class ExpEncoder(PositionalEncoder):
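    """
    Expression-transfer encoder.

    Given a frontalized source image, a frontalized driver image, and the
    source triplane, predicts a residual ("delta") over the triplane that
    transfers the driver's facial expression onto the source. Expected
    shapes, inferred from the layer configuration below: images are
    Bx3x512x512 and triplanes are Bx3x32x256x256.
    """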
def __init__(self, img_size=512, img_channels=3, use_aug=True):
super().__init__(img_size)
self.use_aug = use_aug
        # +2 input channels for the positional encoding added in _calculate_delta
        self.source_feat_extractor = DeepLabV3(input_channels=img_channels + 2)
        self.driver_feat_extractor = DeepLabV3(input_channels=img_channels + 2)
        # Summarize the flattened 96-channel triplane into a 128-channel map at
        # 1/4 resolution (256 -> 64) so it can be concatenated with the image features
        self.triplane_descriptor = nn.Sequential(
nn.Conv2d(96, 96, 3, 2, 1, bias=True),
nn.ReLU(),
nn.Conv2d(96, 96, 3, 1, 1, bias=True),
nn.ReLU(),
nn.Conv2d(96, 128, 3, 2, 1, bias=True),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, 1, bias=True),
nn.ReLU(),
nn.Conv2d(128, 128, 3, 1, 1, bias=True),
)
self.patch_embed = OverlapPatchEmbed(
img_size=img_size // 8, patch_size=3, stride=2, in_chans=256 * 2 + 128, embed_dim=1024
)
self.block1 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
self.block2 = Block(dim=1024, num_heads=4, mlp_ratio=2, sr_ratio=1)
self.up1 = nn.PixelShuffle(upscale_factor=2)
self.up2 = nn.Upsample(scale_factor=2, mode='bilinear')
self.up3 = nn.Upsample(scale_factor=2, mode='bilinear')
self.conv1 = nn.Conv2d(256, 128, 3, 1, 1, bias=True)
self.act1 = nn.ReLU()
self.conv2 = nn.Conv2d(128, 128, 3, 1, 1, bias=True)
self.act2 = nn.ReLU()
self.conv3 = nn.Conv2d(128, 96, 3, 1, 1, bias=True)
self.driver_aug = FaceAugmentor()

    def _calculate_delta(self, xs_img, xd_img, xs_triplane):
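        """
        Predict the triplane residual.

        Pipeline: (1) augment or resize the driver to the source resolution,
        (2) add positional-encoding channels and extract DeepLabV3 features
        for both images, (3) summarize the source triplane with a small conv
        stack, (4) fuse all three feature maps with two SegFormer blocks, and
        (5) upsample the fused features back to a Bx3x32x256x256 delta.
        """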
if self.use_aug:
xd_img = self.driver_aug(
xd_img,
target_size=(512, 512),
apply_color_aug=self.training,
apply_rnd_mask=self.training,
apply_rnd_zoom=self.training,
)
        else:
            # No augmentation: just resize the driver to the source resolution
            xd_img = F.interpolate(xd_img, size=xs_img.shape[-2:])
        # Add positional-encoding channels, then extract per-image features
        xs_img = self._add_positional_encoding(xs_img)
        xd_img = self._add_positional_encoding(xd_img)
        xs_feat = self.source_feat_extractor(xs_img)
        xd_feat = self.driver_feat_extractor(xd_img)

        # Flatten the triplane (Bx3x32x256x256 -> Bx96x256x256) and summarize it
        xs_triplane = xs_triplane.reshape(-1, 96, 256, 256)
        xs_triplane_feat = self.triplane_descriptor(xs_triplane)

        # Fuse source, driver, and triplane features along the channel axis
        x = torch.cat((xs_feat, xd_feat, xs_triplane_feat), dim=1)
        # Tokenize, run two self-attention blocks, and fold back to a feature map
        x, H, W = self.patch_embed(x)
        x = self.block1(x, H, W)
        x = self.block2(x, H, W)
        x = x.reshape(xs_img.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()

        # Upsample back to 256x256 and predict the 96-channel triplane delta
        x = self.up1(x)
        x = self.up2(x)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.up3(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.conv3(x)
        x = x.reshape(-1, 3, 32, 256, 256)
        return x

    def forward(self, xs_img, xd_img, xs_triplane):
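        """Return the source triplane with the driver's expression transferred (source + predicted delta)."""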
delta = self._calculate_delta(xs_img, xd_img, xs_triplane)
xs_triplane = xs_triplane + delta
return xs_triplane


@MODEL_REGISTRY.register()
class Voodoo3D(Lp3D):
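    """
    VOODOO 3D head-reenactment model.

    Extends Lp3D: the source and driver images are lifted into canonical
    triplanes and frontalized with a fixed canonical camera; an ExpEncoder
    then transfers the driver's expression onto the source triplane, which
    is finally rendered from the driver's camera.
    """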
def __init__(
self,
neural_rendering_resolution: int, # Render at this resolution and use superres to upsample to 512x512
triplane_nd: int, # Triplane's number of channels
triplane_h: int, # Triplane height
triplane_w: int, # Triplane width
use_aug: bool,
rendering_kwargs,
superresolution_kwargs,
):
self.use_aug = use_aug
super().__init__(
neural_rendering_resolution,
triplane_nd,
triplane_h,
triplane_w,
rendering_kwargs,
superresolution_kwargs,
)
        # Canonical frontal camera used by frontalize(); registered as buffers below
        lookat_point = torch.tensor(rendering_kwargs['lookat_point']).unsqueeze(0).float()
canonical_cam2world = LookAtPoseSampler.sample(
np.pi / 2, np.pi / 2, rendering_kwargs['camera_radius'],
lookat_point,
np.pi / 2, np.pi / 2, 0.0,
batch_size=1
)
        # Fixed canonical intrinsics (no random jitter)
        canonical_intrinsics = IntrinsicsSampler.sample(
18.837, 0.5,
0, 0,
batch_size=1
)
self.register_buffer('lookat_point', lookat_point)
self.register_buffer('canonical_cam2world', canonical_cam2world)
self.register_buffer('canonical_intrinsics', canonical_intrinsics)

    def _setup_modules(self):
super()._setup_modules()
self.exp_transfer = ExpEncoder(use_aug=self.use_aug)
        # Freeze the pretrained Lp3D components; only the expression-transfer
        # module is left trainable here
        self.triplane_encoder.requires_grad_(False)
        self.superresolution.requires_grad_(False)
        self.decoder.requires_grad_(False)

    def frontalize(self, inp, neural_upsample):
"""
Frontalize the input image/triplane.
Parameters:
- inp (Tensor): Can be either image (Bx3xHxW) or triplane (Bx3x32x256x256)
- neural_upsample (bool): If True, use superresolution to upsample the output. Otherwise, just
upsample it using nearest interpolation.
"""
        if len(inp.shape) == 4:  # A 4D input is an RGB image: lift it to a triplane first
            inp = self.canonicalize(inp)
canonical_cam2world = self.canonical_cam2world.repeat(inp.shape[0], 1, 1)
canonical_intrinsics = self.canonical_intrinsics.repeat(inp.shape[0], 1, 1)
frontalized_data = self.render(
inp, canonical_cam2world, canonical_intrinsics, upsample=neural_upsample
)
return frontalized_data
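
    # Hypothetical usage sketch (names and shapes assumed, not from the repo):
    #   frontal = model.frontalize(imgs, neural_upsample=True)  # imgs: Bx3x512x512
    #   frontal_img = frontal['image']  # superresolved frontal render in [-1, 1]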

    def forward(
self,
xs_data: Dict[str, torch.Tensor],
xd_data: Dict[str, torch.Tensor]
):
"""
Reenact the source image using driver(s).
Parameters:
- xs_data: The source's data. Must have 'image' key in it
- xd_data: The driver' data. Must have 'image, 'cam2world', and 'intrinsics'
"""
        # Lift the source image into a canonical triplane
        xs_triplane = self.canonicalize(xs_data['image'])

        # Frontalize the source (superresolved render)
        with torch.no_grad():
            xs_face_frontal_data = self.frontalize(xs_data['image'], neural_upsample=True)
        xs_face_frontal_hr = (xs_face_frontal_data['image'] + 1) / 2  # Map [-1, 1] -> [0, 1] (legacy range)
        # Frontalize the driver (raw render, resized to 512x512)
        with torch.no_grad():
            xd_triplane = self.canonicalize(xd_data['image'])
            xd_face_frontal_data = self.frontalize(xd_triplane, neural_upsample=False)
        xd_face_frontal_lr = F.interpolate(xd_face_frontal_data['image_raw'], (512, 512))
        xd_face_frontal_lr = (xd_face_frontal_lr + 1) / 2  # Map [-1, 1] -> [0, 1] (legacy range)
        # Transfer the driver's expression onto the source triplane, then render
        # from the driver's camera
        xs_triplane_newExp = self.exp_transfer(xs_face_frontal_hr, xd_face_frontal_lr, xs_triplane)
        driver_out = self.render(xs_triplane_newExp, xd_data['cam2world'], xd_data['intrinsics'])
return driver_out
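

if __name__ == '__main__':
    # Minimal smoke test for ExpEncoder -- a sketch, not part of the original
    # file. It assumes the bundled DeepLabV3/SegFormer modules are importable
    # and that DeepLabV3 yields 256-channel features at 1/8 resolution, as the
    # patch-embed configuration above implies.
    encoder = ExpEncoder(use_aug=False).eval()
    xs_img = torch.randn(1, 3, 512, 512)           # frontalized source image
    xd_img = torch.randn(1, 3, 512, 512)           # frontalized driver image
    xs_triplane = torch.randn(1, 3, 32, 256, 256)  # source triplane
    with torch.no_grad():
        out = encoder(xs_img, xd_img, xs_triplane)
    print(out.shape)  # expected: torch.Size([1, 3, 32, 256, 256])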