|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
import kiui |
|
from kiui.lpips import LPIPS |
|
|
|
from core.unet import UNet |
|
from core.options import Options |
|
from core.gs import GaussianRenderer |
|
|
|
|
|
class LGM(nn.Module):
    """Predict pixel-aligned 3D Gaussians from multi-view inputs and render them.

    A UNet (followed by a 1x1 conv) maps every input view to a 14-channel
    feature map; each pixel of that map is decoded into one Gaussian
    (position, opacity, scale, rotation, RGB). ``forward`` renders the
    Gaussians with ``GaussianRenderer`` into the supervision views and
    compares them with the ground truth via MSE on images and alphas plus,
    when ``opt.lambda_lpips > 0``, a weighted LPIPS term.
    """

    def __init__(
        self,
        opt: "Options",
    ):
        super().__init__()

        self.opt = opt

        # NOTE(review): presumably UNet(in_channels, out_channels); 9 input
        # channels would match 3 colour channels + the 6-D Plucker ray
        # embedding built by prepare_default_rays -- confirm with core.unet.
        self.unet = UNet(
            9, 14,
            down_channels=self.opt.down_channels,
            down_attention=self.opt.down_attention,
            mid_attention=self.opt.mid_attention,
            up_channels=self.opt.up_channels,
            up_attention=self.opt.up_attention,
        )

        # 1x1 convolution on top of the UNet output (14 -> 14 channels).
        self.conv = nn.Conv2d(14, 14, kernel_size=1)

        self.gs = GaussianRenderer(opt)

        # LPIPS is a frozen, loss-only network; state_dict() strips its weights.
        if self.opt.lambda_lpips > 0:
            self.lpips_loss = LPIPS(net='vgg')
            self.lpips_loss.requires_grad_(False)

    # Activations decoding the raw per-Gaussian channels. They are static
    # methods (not lambdas assigned in __init__) so the module can be pickled;
    # they remain reachable as ``model.pos_act`` etc.
    @staticmethod
    def pos_act(x):
        """Clamp positions to [-1, 1]."""
        return x.clamp(-1, 1)

    @staticmethod
    def scale_act(x):
        """Map raw values to positive scales: 0.1 * softplus(x)."""
        return 0.1 * F.softplus(x)

    @staticmethod
    def opacity_act(x):
        """Map raw values to opacities in (0, 1) with a sigmoid."""
        return torch.sigmoid(x)

    @staticmethod
    def rot_act(x):
        """Normalise each 4-component rotation to unit L2 norm.

        Normalisation runs over the LAST axis, i.e. per Gaussian. The plain
        ``F.normalize`` default (``dim=1``) would normalise across Gaussians
        for a ``[B, N, 4]`` input.
        """
        return F.normalize(x, dim=-1)

    @staticmethod
    def rgb_act(x):
        """Map raw values to colours in (0, 1): 0.5 * tanh(x) + 0.5."""
        return 0.5 * torch.tanh(x) + 0.5

    def state_dict(self, *args, **kwargs):
        """Return the module state without any ``lpips_loss`` entries.

        All arguments are forwarded to ``nn.Module.state_dict``; afterwards
        every key containing ``'lpips_loss'`` is deleted from the result.
        """
        state_dict = super().state_dict(*args, **kwargs)
        for key in [k for k in state_dict if 'lpips_loss' in k]:
            del state_dict[key]
        return state_dict

    def prepare_default_rays(self, device, elevation=0, azimuths=(0, 90, 180, 270)):
        """Build Plucker ray embeddings for orbit cameras around the origin.

        Args:
            device: device the returned tensor is moved to.
            elevation: camera elevation passed to ``orbit_camera``.
            azimuths: one camera per azimuth (passed to ``orbit_camera``);
                the default reproduces the original four views.

        Returns:
            Tensor stacked over cameras and permuted channel-first; each
            camera contributes ``cat([cross(o, d), d])`` (6 channels).
            Presumably ``[len(azimuths), 6, input_size, input_size]`` if
            ``get_rays`` returns ``[H, W, 3]`` rays -- TODO confirm.
        """
        from kiui.cam import orbit_camera
        from core.utils import get_rays

        cam_poses = torch.from_numpy(np.stack(
            [orbit_camera(elevation, azimuth, radius=self.opt.cam_radius) for azimuth in azimuths],
            axis=0,
        ))

        rays_embeddings = []
        for cam_pose in cam_poses:
            rays_o, rays_d = get_rays(cam_pose, self.opt.input_size, self.opt.input_size, self.opt.fovy)
            # Plucker coordinates: moment (o x d) followed by direction d.
            rays_embeddings.append(torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1))

        return torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous().to(device)

    def forward_gaussians(self, images):
        """Decode one Gaussian per output pixel of every input view.

        Args:
            images: tensor of shape ``[B, V, C, H, W]``; C must match the
                UNet's input channels (presumably 9, see ``__init__``).

        Returns:
            Tensor of shape ``[B, V * splat_size * splat_size, 14]`` laid out
            as position (0:3), opacity (3:4), scale (4:7), rotation (7:11)
            and RGB (11:14).
        """
        B, V, C, H, W = images.shape
        # reshape (not view) also accepts non-contiguous inputs.
        x = self.unet(images.reshape(B * V, C, H, W))
        x = self.conv(x)

        # Split the batch back into views; the view count comes from the
        # input instead of being fixed to 4.
        x = x.reshape(B, V, 14, self.opt.splat_size, self.opt.splat_size)
        # [B, V, 14, s, s] -> [B, V * s * s, 14]: one row per Gaussian.
        x = x.permute(0, 1, 3, 4, 2).reshape(B, -1, 14)

        pos = self.pos_act(x[..., 0:3])
        opacity = self.opacity_act(x[..., 3:4])
        scale = self.scale_act(x[..., 4:7])
        rotation = self.rot_act(x[..., 7:11])
        rgbs = self.rgb_act(x[..., 11:])

        # Change of coordinates diag(1, -1, -1): flip the y and z axes. This is
        # the same as right-multiplying homogeneous positions by
        # diag(1, -1, -1, 1), but element-wise it is exact (not subject to
        # reduced-precision matmul modes such as TF32) and needs no 4x4 matrix.
        pos = pos * pos.new_tensor([1.0, -1.0, -1.0])
        # NOTE(review): scaling the rotation by (1, -1, -1, 1) equals
        # conjugating it by the 180-degree x-rotation above only if its
        # components are ordered (x, y, z, w). If GaussianRenderer expects
        # (w, x, y, z), the matching flip is (1, 1, -1, -1) -- confirm.
        rotation = rotation * rotation.new_tensor([1.0, -1.0, -1.0, 1.0])

        return torch.cat([pos, opacity, scale, rotation, rgbs], dim=-1)

    def forward(self, data, step_ratio=1):
        """Predict Gaussians, render them into the target views and compute losses.

        Args:
            data: dict with ``'input'`` (fed to ``forward_gaussians``),
                ``'cam_view'``, ``'cam_view_proj'``, ``'cam_pos'`` (fed to the
                renderer) and ground truth ``'images_output'`` /
                ``'masks_output'``.
            step_ratio: not used by this method.

        Returns:
            dict with every entry returned by ``self.gs.render`` plus
            ``'gaussians'``, ``'images_pred'``, ``'alphas_pred'``, ``'loss'``,
            ``'psnr'`` and, when ``opt.lambda_lpips > 0``, ``'loss_lpips'``.
        """
        results = {}
        loss = 0

        gaussians = self.forward_gaussians(data['input'])
        results['gaussians'] = gaussians

        # Random background while training, white otherwise.
        if self.training:
            bg_color = torch.rand(3, dtype=torch.float32, device=gaussians.device)
        else:
            bg_color = torch.ones(3, dtype=torch.float32, device=gaussians.device)

        # Merge the render outputs instead of rebinding `results`, so the
        # 'gaussians' entry stored above is kept.
        results.update(self.gs.render(
            gaussians, data['cam_view'], data['cam_view_proj'], data['cam_pos'],
            bg_color=bg_color,
        ))
        pred_images = results['image']
        pred_alphas = results['alpha']

        results['images_pred'] = pred_images
        results['alphas_pred'] = pred_alphas

        gt_masks = data['masks_output']
        # Composite the ground truth over the same background used for rendering.
        gt_images = data['images_output'] * gt_masks + bg_color.view(1, 1, 3, 1, 1) * (1 - gt_masks)

        loss_mse = F.mse_loss(pred_images, gt_images) + F.mse_loss(pred_alphas, gt_masks)
        loss = loss + loss_mse

        if self.opt.lambda_lpips > 0:
            size = self.opt.output_size

            def to_lpips_input(img):
                # Flatten batch/views, affinely map [0, 1] -> [-1, 1], resize to 256x256.
                flat = img.reshape(-1, 3, size, size) * 2 - 1
                return F.interpolate(flat, (256, 256), mode='bilinear', align_corners=False)

            loss_lpips = self.lpips_loss(to_lpips_input(gt_images), to_lpips_input(pred_images)).mean()
            results['loss_lpips'] = loss_lpips
            loss = loss + self.opt.lambda_lpips * loss_lpips

        results['loss'] = loss

        # PSNR assuming a peak signal value of 1.
        with torch.no_grad():
            results['psnr'] = -10 * torch.log10(torch.mean((pred_images.detach() - gt_images) ** 2))

        return results