L4GM-demo / core /models.py
fffiloni's picture
Migrated from GitHub
2cdb96e verified
raw
history blame
7.35 kB
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import kiui
from kiui.lpips import LPIPS
from core.unet import UNet
from core.options import Options
from core.gs import GaussianRenderer
class LGM(nn.Module):
    """Feed-forward model that predicts 3D Gaussians from multi-view, multi-frame images.

    A UNet predicts ``14 * opt.gaussian_perpixel`` channels per output pixel.
    Every consecutive group of 14 channels is decoded into one Gaussian, laid
    out as ``[pos(3), opacity(1), scale(3), rotation(4), rgb(3)]``. The
    Gaussians are rendered with ``GaussianRenderer`` for supervision.
    """

    def __init__(
        self,
        opt: "Options",
    ):
        super().__init__()

        self.opt = opt

        # unet: 9 input channels per view -> 14 * gaussian_perpixel output channels
        self.unet = UNet(
            9, 14 * self.opt.gaussian_perpixel,
            down_channels=self.opt.down_channels,
            down_attention=self.opt.down_attention,
            mid_attention=self.opt.mid_attention,
            up_channels=self.opt.up_channels,
            up_attention=self.opt.up_attention,
            num_views=self.opt.num_input_views,
            num_frames=self.opt.num_frames,
            use_temp_attn=self.opt.use_temp_attn,
        )

        # last conv
        self.conv = nn.Conv2d(14 * self.opt.gaussian_perpixel, 14 * self.opt.gaussian_perpixel, kernel_size=1)  # NOTE: maybe remove it if train again

        # Gaussian Renderer
        self.gs = GaussianRenderer(opt)

        # activations applied to the raw per-Gaussian channels
        self.pos_act = lambda x: x.clamp(-1, 1)
        self.scale_act = lambda x: 0.1 * F.softplus(x)
        self.opacity_act = lambda x: torch.sigmoid(x)
        self.rot_act = lambda x: F.normalize(x, dim=-1)
        self.rgb_act = lambda x: 0.5 * torch.tanh(x) + 0.5  # NOTE: may use sigmoid if train again

        # LPIPS loss: frozen, and stripped from state_dict() below
        if self.opt.lambda_lpips > 0:
            self.lpips_loss = LPIPS(net='vgg')
            self.lpips_loss.requires_grad_(False)

    def state_dict(self, *args, **kwargs):
        """Return the module state dict without the frozen LPIPS weights.

        Accepts the same arguments as ``torch.nn.Module.state_dict`` and
        forwards them unchanged. Every key containing ``'lpips_loss'`` is
        removed, so checkpoints do not store the perceptual-loss network.
        """
        state_dict = super().state_dict(*args, **kwargs)
        for k in [k for k in state_dict if 'lpips_loss' in k]:
            del state_dict[k]
        return state_dict

    def prepare_default_rays(self, device, elevation=0):
        """Build Plucker ray embeddings for four orbit cameras.

        The cameras sit at ``elevation`` and azimuths 0/90/180/270 degrees, at
        radius ``opt.cam_radius``. Rays are generated at ``opt.input_size`` with
        field of view ``opt.fovy``.

        Args:
            device: device the returned tensor is moved to.
            elevation: elevation passed to ``kiui.cam.orbit_camera``.

        Returns:
            Tensor ``[4, 6, h, w]``. For each view, the channels are
            ``cross(rays_o, rays_d)`` followed by ``rays_d``.
        """
        from kiui.cam import orbit_camera
        from core.utils import get_rays

        cam_poses = np.stack(
            [orbit_camera(elevation, azimuth, radius=self.opt.cam_radius) for azimuth in (0, 90, 180, 270)],
            axis=0,
        )  # [4, 4, 4]
        cam_poses = torch.from_numpy(cam_poses)

        rays_embeddings = []
        for cam_pose in cam_poses:
            rays_o, rays_d = get_rays(cam_pose, self.opt.input_size, self.opt.input_size, self.opt.fovy)  # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)  # [h, w, 6]
            rays_embeddings.append(rays_plucker)

        return torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device)  # [V, 6, h, w]

    def forward_gaussians(self, images):
        """Predict activated 3D Gaussians for every frame from its input views.

        Args:
            images: tensor ``[B, T*V, C, H, W]``, where ``T`` is
                ``opt.num_frames`` and ``V`` is the number of input views per
                frame. ``C`` must match the UNet's 9 input channels (presumably
                RGB + 6-D Plucker rays; TODO confirm against the dataloader).

        Returns:
            Tensor ``[B*T, N, 14]`` with ``N = V * splat_size**2 *
            gaussian_perpixel``. The last dim is ``[pos(3), opacity(1),
            scale(3), rotation(4), rgb(3)]``, with activations applied.

        Raises:
            ValueError: if ``images.shape[1]`` is not a multiple of ``num_frames``.
        """
        B, TV, C, H, W = images.shape
        T = self.opt.num_frames
        if TV % T != 0:
            raise ValueError(
                f"expected images.shape[1] ({TV}) to be a multiple of num_frames ({T})"
            )
        V = TV // T

        # fold batch, frames and views into one axis for the 2D UNet;
        # reshape (unlike view) also accepts non-contiguous inputs
        images = images.reshape(B * T * V, C, H, W)

        # UNet output spatial size must equal splat_size for the reshape below
        x = self.unet(images)  # [B*T*V, 14*g, splat_size, splat_size]
        x = self.conv(x)

        x = x.reshape(B * T, V, 14 * self.opt.gaussian_perpixel, self.opt.splat_size, self.opt.splat_size)
        # channels last, then every 14 consecutive channels form one Gaussian
        x = x.permute(0, 1, 3, 4, 2).reshape(B * T, -1, 14).contiguous()

        pos = self.pos_act(x[..., 0:3])  # [B*T, N, 3]
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        gaussians = torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)  # [B*T, N, 14]
        return gaussians

    def forward(self, data, step_ratio=1):
        """Predict Gaussians, render the supervision views, and compute losses and PSNR.

        Args:
            data: dataloader batch. The method reads ``'input'``
                (``[B, T*V_in, C, H, W]``) plus ``'cam_view'``,
                ``'cam_view_proj'``, ``'cam_pos'``, ``'images_output'`` and
                ``'masks_output'``. The last five are presumably
                ``[B, T*V_out, ...]``. NOTE: those five entries are reshaped
                *in place* to ``[B*T, V_out, ...]``.
            step_ratio: unused here; kept for interface compatibility.

        Returns:
            dict containing:
                - ``'gaussians'`` (``[B*T, N, 14]``);
                - the renderer's outputs (at least ``'image'`` and ``'alpha'``);
                - ``'images_pred'`` and ``'alphas_pred'``;
                - ``'loss'`` and ``'psnr'``;
                - ``'loss_lpips'``, when ``opt.lambda_lpips > 0``.
        """
        results = {}
        loss = 0

        images = data['input']  # [B, T*V_in, C, H, W], input features
        B = images.shape[0]
        T = self.opt.num_frames

        # predict Gaussians from all input views of every frame
        gaussians = self.forward_gaussians(images)  # [B*T, N, 14]
        results['gaussians'] = gaussians

        # always use white bg
        bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)

        # fold frames into the batch axis so each frame's Gaussians are
        # rendered with that frame's cameras
        for key in ('cam_view', 'cam_view_proj', 'cam_pos'):
            data[key] = data[key].reshape(B * T, -1, *data[key].shape[2:])

        # merge rather than replace, so 'gaussians' is kept in the output
        results.update(
            self.gs.render(gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'], bg_color=bg_color)
        )
        pred_images = results['image']  # expected [B*T, V_out, 3, output_size, output_size]
        pred_alphas = results['alpha']  # expected [B*T, V_out, 1, output_size, output_size]

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas

        for key in ('images_output', 'masks_output'):
            data[key] = data[key].reshape(B * T, -1, *data[key].shape[2:])
        gt_images = data['images_output']  # ground-truth novel views
        gt_masks = data['masks_output']  # ground-truth masks

        # composite ground truth onto the same white background used for rendering
        gt_images = gt_images * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

        loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
        loss = loss + loss_mse

        if self.opt.lambda_lpips > 0:
            size = self.opt.output_size
            # Always resized to 256x256 to bound memory cost.
            # `* 2 - 1` maps images assumed in [0, 1] to [-1, 1].
            loss_lpips = self.lpips_loss(
                F.interpolate(gt_images.reshape(-1, 3, size, size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
                F.interpolate(pred_images.reshape(-1, 3, size, size) * 2 - 1, (256, 256), mode='bilinear', align_corners=False),
            ).mean()
            results['loss_lpips'] = loss_lpips
            loss = loss + self.opt.lambda_lpips * loss_lpips

        results['loss'] = loss

        # metric
        with torch.no_grad():
            psnr = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))
            results['psnr'] = psnr

        return results