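# Lightning system for NeuS-style surface reconstruction with pinhole-camera rays:
# preprocess_data samples per-pixel rays, training_step combines RGB, normal, mask,
# eikonal and regularization losses, and export() saves the extracted mesh.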
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_efficient_distloss import flatten_eff_distloss

import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug

import models
from models.utils import cleanup
from models.ray_utils import get_rays
import systems
from systems.base import BaseSystem
from systems.criterions import PSNR, binary_cross_entropy

def ranking_loss(error, penalize_ratio=0.7, extra_weights=None, type='mean'):
    """Robustly reduce a 1D error tensor: keep only the smallest `penalize_ratio`
    fraction of the entries (discarding the largest outliers) and reduce them."""
    error, indices = torch.sort(error)
    num_kept = int(penalize_ratio * error.shape[0])
    # `error` is already sorted ascending, so its prefix holds the smallest errors.
    s_error = error[:num_kept]
    if extra_weights is not None:
        # Map the kept (sorted) positions back to the original ordering of the weights.
        weights = torch.index_select(extra_weights, 0, index=indices[:num_kept])
        s_error = s_error * weights

    if type == 'mean':
        return torch.mean(s_error)
    elif type == 'sum':
        return torch.sum(s_error)
    else:
        raise ValueError(f"Unknown reduction type: {type}")
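# Illustrative (hypothetical) example of ranking_loss's truncation:
#   ranking_loss(torch.tensor([0.1, 0.2, 5.0, 9.0]), penalize_ratio=0.5)
# keeps only the two smallest errors and returns their mean, 0.15.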
@systems.register('pinhole-neus-system')
class PinholeNeuSSystem(BaseSystem):
    """
    Two ways to print to console:
    1. self.print: correctly handles the progress bar
    2. rank_zero_info: uses the logging module
    """
    def prepare(self):
        self.criterions = {
            'psnr': PSNR()
        }
        # Total number of (foreground + optional background) samples per training batch,
        # used in training_step to adapt the ray count when dynamic ray sampling is enabled.
        self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0))
        self.train_num_rays = self.config.model.train_num_rays
        self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)

    def forward(self, batch):
        return self.model(batch['rays'])

    def preprocess_data(self, batch, stage):
        if 'index' in batch:
            index = batch['index']
        else:
            if self.config.model.batch_image_sampling:
                index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device)
            else:
                index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device)
        if stage in ['train']:
            c2w = self.dataset.all_c2w[index]
            x = torch.randint(
                0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device
            )
            y = torch.randint(
                0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device
            )
            if self.dataset.directions.ndim == 3:  # directions shared across views
                directions = self.dataset.directions[y, x]
            elif self.dataset.directions.ndim == 4:  # per-view directions
                directions = self.dataset.directions[index, y, x]
            rays_o, rays_d = get_rays(directions, c2w)
            rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank)
            normal = self.dataset.all_normals_world[index, y, x].view(-1, self.dataset.all_normals_world.shape[-1]).to(self.rank)
            fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank)
            rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank)
            view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank)
        else:
            c2w = self.dataset.all_c2w[index][0]
            if self.dataset.directions.ndim == 3:  # directions shared across views
                directions = self.dataset.directions
            elif self.dataset.directions.ndim == 4:  # per-view directions
                directions = self.dataset.directions[index][0]
            rays_o, rays_d = get_rays(directions, c2w)
            rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank)
            normal = self.dataset.all_normals_world[index].view(-1, self.dataset.all_normals_world.shape[-1]).to(self.rank)
            fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank)
            rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank)
            view_weights = None

        # Cosine between ray direction and GT world-space normal; used in training_step
        # to down-weight grazing / back-facing pixels in the normal loss.
        cosines = self.cos(rays_d, normal)
        rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1)

        if stage in ['train']:
            if self.config.model.background_color == 'white':
                self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)
            elif self.config.model.background_color == 'black':
                self.model.background_color = torch.zeros((3,), dtype=torch.float32, device=self.rank)
            elif self.config.model.background_color == 'random':
                self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank)
            else:
                raise NotImplementedError
        else:
            self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)

        if self.dataset.apply_mask:
            rgb = rgb * fg_mask[..., None] + self.model.background_color * (1 - fg_mask[..., None])

        batch.update({
            'rays': rays,
            'rgb': rgb,
            'normal': normal,
            'fg_mask': fg_mask,
            'rgb_mask': rgb_mask,
            'cosines': cosines,
            'view_weights': view_weights
        })

    def training_step(self, batch, batch_idx):
        out = self(batch)

        cosines = batch['cosines']
        fg_mask = batch['fg_mask']
        rgb_mask = batch['rgb_mask']
        view_weights = batch['view_weights']

        # Zero the cosines of grazing / back-facing pixels, and keep for the normal loss
        # only foreground rays whose GT normal faces the camera (ray . normal < -0.1).
        cosines[cosines > -0.1] = 0
        mask = ((fg_mask > 0) & (cosines < -0.1))
        rgb_mask = out['rays_valid_full'][..., 0] & (rgb_mask > 0)

        # Cosine between ray direction and predicted normal (currently unused below).
        grad_cosines = self.cos(batch['rays'][..., 3:], out['comp_normal']).detach()

        loss = 0.

        if self.config.model.dynamic_ray_sampling:
            # Adapt the ray count so the total number of samples per batch stays roughly
            # constant, smoothed with an exponential moving average and capped.
            train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item()))
            self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays)
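            # Hypothetical numbers: with train_num_rays=256, train_num_samples=1_048_576
            # and 2_097_152 samples actually produced, the target is 128 rays and the EMA
            # gives int(256 * 0.9 + 128 * 0.1) = 243, capped at max_train_num_rays.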

        errors_rgb_mse = F.mse_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none')
        loss_rgb_mse = ranking_loss(errors_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='mean')
        self.log('train/loss_rgb_mse', loss_rgb_mse, prog_bar=True, rank_zero_only=True)
        loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse)

        loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none')
        loss_rgb_l1 = ranking_loss(loss_rgb_l1.sum(dim=1),
                                   extra_weights=view_weights[rgb_mask],
                                   penalize_ratio=0.8)
        self.log('train/loss_rgb', loss_rgb_l1)
        loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1)

        normal_errors = 1 - F.cosine_similarity(out['comp_normal'], batch['normal'], dim=1)
        # Weight each pixel's normal error by how frontal its GT normal is, normalized over the batch.
        normal_errors = normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum()
        loss_normal = ranking_loss(normal_errors[mask], penalize_ratio=0.8, type='sum')
        self.log('train/loss_normal', loss_normal, prog_bar=True, rank_zero_only=True)
        loss += loss_normal * self.C(self.config.system.loss.lambda_normal)

        loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean()
        self.log('train/loss_eikonal', loss_eikonal, prog_bar=True, rank_zero_only=True)
        loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal)

        opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3)
        loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float(), reduction='none')
        loss_mask = ranking_loss(loss_mask, penalize_ratio=0.9, extra_weights=view_weights)
        self.log('train/loss_mask', loss_mask, prog_bar=True, rank_zero_only=True)
        loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0)

        loss_opaque = binary_cross_entropy(opacity, opacity)
        self.log('train/loss_opaque', loss_opaque)
        loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque)

        loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * out['random_sdf'].abs()).mean()
        self.log('train/loss_sparsity', loss_sparsity, prog_bar=True, rank_zero_only=True)
        loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity)

        if self.C(self.config.system.loss.lambda_curvature) > 0:
            assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples"
            loss_curvature = out['sdf_laplace_samples'].abs().mean()
            self.log('train/loss_curvature', loss_curvature)
            loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature)

        if self.C(self.config.system.loss.lambda_distortion) > 0:
            loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices'])
            self.log('train/loss_distortion', loss_distortion)
            loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion)

        if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0:
            loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg'])
            self.log('train/loss_distortion_bg', loss_distortion_bg)
            loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg)

        if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0:
            if "random_sdf_grad" not in out:
                raise ValueError(
                    "random_sdf_grad is required for the normal smoothness loss, but it was not found in the output."
                )
            if "normal_perturb" not in out:
                raise ValueError(
                    "normal_perturb is required for the normal smoothness loss, but it was not found in the output."
                )
            normals_3d = out["random_sdf_grad"]
            normals_perturb_3d = out["normal_perturb"]
            loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean()
            self.log('train/loss_3d_normal_smooth', loss_3d_normal_smooth, prog_bar=True)
            loss += loss_3d_normal_smooth * self.C(self.config.system.loss.lambda_3d_normal_smooth)

        losses_model_reg = self.model.regularizations(out)
        for name, value in losses_model_reg.items():
            self.log(f'train/loss_{name}', value)
            loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"])
            loss += loss_

        self.log('train/inv_s', out['inv_s'], prog_bar=True)

        for name, value in self.config.system.loss.items():
            if name.startswith('lambda'):
                self.log(f'train_params/{name}', self.C(value))

        self.log('train/num_rays', float(self.train_num_rays), prog_bar=True)

        return {
            'loss': loss
        }

""" |
|
|
# aggregate outputs from different devices (DP) |
|
|
def training_step_end(self, out): |
|
|
pass |
|
|
""" |
|
|
|
|
|
""" |
|
|
# aggregate outputs from different iterations |
|
|
def training_epoch_end(self, out): |
|
|
pass |
|
|
""" |
|
|
|
|
|
def validation_step(self, batch, batch_idx): |
|
|
out = self(batch) |
|
|
psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) |
|
|
W, H = self.dataset.img_wh |
|
|
self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ |
|
|
{'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, |
|
|
{'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} |
|
|
] + ([ |
|
|
{'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, |
|
|
{'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, |
|
|
] if self.config.model.learned_background else []) + [ |
|
|
{'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, |
|
|
{'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} |
|
|
]) |
|
|
return { |
|
|
'psnr': psnr, |
|
|
'index': batch['index'] |
|
|
} |
|
|
|
|
|
|
|
|
""" |
|
|
# aggregate outputs from different devices when using DP |
|
|
def validation_step_end(self, out): |
|
|
pass |
|
|
""" |
|
|
|
|
|
    def validation_epoch_end(self, out):
        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            out_set = {}
            for step_out in out:
                # single-device (DP) outputs
                if step_out['index'].ndim == 1:
                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
                # gathered multi-device (DDP) outputs
                else:
                    for oi, index in enumerate(step_out['index']):
                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
            psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
            self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)
            self.export()

    def test_step(self, batch, batch_idx):
        out = self(batch)
        psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb'])
        W, H = self.dataset.img_wh
        self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [
            {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
            {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}
        ] + ([
            {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
            {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
        ] if self.config.model.learned_background else []) + [
            {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}},
            {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}}
        ])
        return {
            'psnr': psnr,
            'index': batch['index']
        }

    def test_epoch_end(self, out):
        """
        Synchronize devices.
        Generate an image sequence from the test outputs.
        """
        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            out_set = {}
            for step_out in out:
                # single-device (DP) outputs
                if step_out['index'].ndim == 1:
                    out_set[step_out['index'].item()] = {'psnr': step_out['psnr']}
                # gathered multi-device (DDP) outputs
                else:
                    for oi, index in enumerate(step_out['index']):
                        out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]}
            psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
            self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)

            self.save_img_sequence(
                f"it{self.global_step}-test",
                f"it{self.global_step}-test",
                r'(\d+)\.png',
                save_format='mp4',
                fps=30
            )

            self.export()

    def export(self):
        mesh = self.model.export(self.config.export)
        self.save_mesh(
            f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj",
            ortho_scale=self.config.export.ortho_scale,
            **mesh
        )