| from dataclasses import dataclass |
| from pathlib import Path |
| import gc |
| import random |
| from typing import Literal, Optional, Protocol, runtime_checkable, Any |
|
|
| import moviepy.editor as mpy |
| import torch |
| import torchvision |
| import wandb |
| from einops import pack, rearrange, repeat |
| from jaxtyping import Float |
| from lightning.pytorch import LightningModule |
| from lightning.pytorch.loggers.wandb import WandbLogger |
| from lightning.pytorch.utilities import rank_zero_only |
| from tabulate import tabulate |
| from torch import Tensor, nn, optim |
| import torch.nn.functional as F |
|
|
| from loss.loss_lpips import LossLpips |
| from loss.loss_mse import LossMse |
| from model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
|
|
| from ..loss.loss_distill import DistillLoss |
| from src.utils.render import generate_path |
| from src.utils.point import get_normal_map |
|
|
| from ..loss.loss_huber import HuberLoss, extri_intri_to_pose_encoding |
|
|
| |
|
|
| from ..dataset.data_module import get_data_shim |
| from ..dataset.types import BatchedExample |
| from ..evaluation.metrics import compute_lpips, compute_psnr, compute_ssim, abs_relative_difference, delta1_acc |
| from ..global_cfg import get_cfg |
| from ..loss import Loss |
| from ..loss.loss_point import Regr3D |
| from ..loss.loss_ssim import ssim |
| from ..misc.benchmarker import Benchmarker |
| from ..misc.cam_utils import update_pose, get_pnp_pose, rotation_6d_to_matrix |
| from ..misc.image_io import prep_image, save_image, save_video |
| from ..misc.LocalLogger import LOG_PATH, LocalLogger |
| from ..misc.nn_module_tools import convert_to_buffer |
| from ..misc.step_tracker import StepTracker |
| from ..misc.utils import inverse_normalize, vis_depth_map, confidence_map, get_overlap_tag |
| from ..visualization.annotation import add_label |
| from ..visualization.camera_trajectory.interpolation import ( |
| interpolate_extrinsics, |
| interpolate_intrinsics, |
| ) |
| from ..visualization.camera_trajectory.wobble import ( |
| generate_wobble, |
| generate_wobble_transformation, |
| ) |
| from ..visualization.color_map import apply_color_map_to_image |
| from ..visualization.layout import add_border, hcat, vcat |
| |
| from .decoder.decoder import Decoder, DepthRenderingMode |
| from .encoder import Encoder |
| from .encoder.visualization.encoder_visualizer import EncoderVisualizer |
| from .ply_export import export_ply |
|
|
@dataclass
class OptimizerCfg:
    """Hyper-parameters consumed by `ModelWrapper.configure_optimizers`."""

    lr: float  # base learning rate for newly initialized heads
    warm_up_steps: int  # length of the linear warm-up phase, in optimizer steps
    backbone_lr_multiplier: float  # scales `lr` for pretrained (backbone) parameters
|
|
|
|
@dataclass
class TestCfg:
    """Configuration for the test loop (`test_step`, `test_step_align`, `on_test_end`)."""

    output_path: Path  # root directory for saved images/videos/benchmark JSON
    align_pose: bool  # if True, optimize target poses before rendering (test_step_align)
    pose_align_steps: int  # number of Adam steps used for pose alignment
    rot_opt_lr: float  # rotation learning rate — NOTE(review): not referenced in this chunk
    trans_opt_lr: float  # translation learning rate — NOTE(review): not referenced in this chunk
    compute_scores: bool  # compute PSNR/SSIM/LPIPS against ground truth
    save_image: bool  # save each rendered target view as a PNG
    save_video: bool  # save a rendered video per scene
    save_compare: bool  # save a side-by-side context/GT/prediction image
    generate_video: bool  # NOTE(review): not referenced in this chunk
    mode: Literal["inference", "evaluation"]  # NOTE(review): not referenced in this chunk
    image_folder: str  # NOTE(review): not referenced in this chunk
|
|
|
|
@dataclass
class TrainCfg:
    """Configuration for training-time behavior and auxiliary loss weights."""

    output_path: Path  # NOTE(review): not referenced in this chunk
    depth_mode: DepthRenderingMode | None  # depth rendering mode, or None to disable
    extended_visualization: bool  # also render the exaggerated interpolation video at val time
    print_log_every_n_steps: int  # period (in steps) of the rank-0 console log line
    distiller: str  # NOTE(review): not referenced in this chunk
    distill_max_steps: int  # NOTE(review): not referenced in this chunk
    pose_loss_alpha: float = 1.0  # alpha passed to HuberLoss for pose supervision
    pose_loss_delta: float = 1.0  # delta passed to HuberLoss and DistillLoss
    cxt_depth_weight: float = 0.01  # weight of the context-depth loss; 0 disables it
    weight_pose: float = 1.0  # pose term weight inside DistillLoss
    weight_depth: float = 1.0  # depth term weight inside DistillLoss
    weight_normal: float = 1.0  # normal term weight inside DistillLoss
    render_ba: bool = False  # NOTE(review): not referenced in this chunk
    render_ba_after_step: int = 0  # NOTE(review): not referenced in this chunk
|
|
|
|
@runtime_checkable
class TrajectoryFn(Protocol):
    """Maps interpolation times to camera poses for video rendering.

    Given a 1-D tensor of times ``t`` (typically in [0, 1]), returns a pair of
    ``(extrinsics, intrinsics)`` tensors shaped ``(batch, view, 4, 4)`` and
    ``(batch, view, 3, 3)`` respectively.
    """

    def __call__(
        self,
        t: Float[Tensor, " t"],
    ) -> tuple[
        Float[Tensor, "batch view 4 4"],
        Float[Tensor, "batch view 3 3"],
    ]:
        pass
|
|
|
|
class ModelWrapper(LightningModule):
    """LightningModule wrapping the encoder/decoder model, its losses, and the
    train/validation/test loops, plus video-rendering visualization helpers."""

    # Attribute annotations; populated in __init__ (logger is set by Lightning).
    logger: Optional[WandbLogger]
    model: nn.Module
    losses: nn.ModuleList
    optimizer_cfg: OptimizerCfg
    test_cfg: TestCfg
    train_cfg: TrainCfg
    step_tracker: StepTracker | None
|
    def __init__(
        self,
        optimizer_cfg: OptimizerCfg,
        test_cfg: TestCfg,
        train_cfg: TrainCfg,
        model: nn.Module,
        losses: list[Loss],
        step_tracker: StepTracker | None
    ) -> None:
        """Store configurations, register losses, and build optional auxiliary losses.

        Args:
            optimizer_cfg: learning-rate / warm-up settings for `configure_optimizers`.
            test_cfg: test-time behavior (pose alignment, outputs, metrics).
            train_cfg: training-time behavior and auxiliary loss weights.
            model: full network; must expose `.encoder` (queried for `pred_pose`
                and `distill` flags) and `.decoder`.
            losses: loss modules applied every training step.
            step_tracker: optional tracker updated with the global step each step.
        """
        super().__init__()
        self.optimizer_cfg = optimizer_cfg
        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        self.step_tracker = step_tracker

        self.encoder_visualizer = None  # no encoder visualizer is configured here
        self.model = model
        self.data_shim = get_data_shim(self.model.encoder)
        # ModuleList so loss parameters/buffers move with the module and checkpoint.
        self.losses = nn.ModuleList(losses)

        # Pose-supervision loss, only when the encoder predicts camera poses.
        if self.model.encoder.pred_pose:
            self.loss_pose = HuberLoss(alpha=self.train_cfg.pose_loss_alpha, delta=self.train_cfg.pose_loss_delta)

        # Distillation loss, only when the encoder runs in distillation mode.
        if self.model.encoder.distill:
            self.loss_distill = DistillLoss(
                delta=self.train_cfg.pose_loss_delta,
                weight_pose=self.train_cfg.weight_pose,
                weight_depth=self.train_cfg.weight_depth,
                weight_normal=self.train_cfg.weight_normal
            )

        self.benchmarker = Benchmarker()
| |
| def on_train_epoch_start(self) -> None: |
| |
| if hasattr(self.trainer.datamodule.train_loader.dataset, "set_epoch"): |
| self.trainer.datamodule.train_loader.dataset.set_epoch(self.current_epoch) |
| if hasattr(self.trainer.datamodule.train_loader.sampler, "set_epoch"): |
| self.trainer.datamodule.train_loader.sampler.set_epoch(self.current_epoch) |
|
|
| def on_validation_epoch_start(self) -> None: |
| print(f"Validation epoch start on rank {self.trainer.global_rank}") |
| |
| if hasattr(self.trainer.datamodule.val_loader.dataset, "set_epoch"): |
| self.trainer.datamodule.val_loader.dataset.set_epoch(self.current_epoch) |
| if hasattr(self.trainer.datamodule.val_loader.sampler, "set_epoch"): |
| self.trainer.datamodule.val_loader.sampler.set_epoch(self.current_epoch) |
| |
    def training_step(self, batch, batch_idx):
        """Run one training step: merge multi-dataloader batches, forward the model,
        accumulate all losses, log metrics, and return the total loss.

        Returns the total loss tensor; a near-zero (``* 1e-10``) loss is returned to
        effectively skip batches with anomalously high loss late in training.
        """
        # With multiple dataloaders, `batch` arrives as a list of per-loader
        # batches; merge them into one (lists are concatenated, dict-of-tensor
        # entries are concatenated along the batch dimension).
        if isinstance(batch, list):
            batch_combined = None
            for batch_per_dl in batch:
                if batch_combined is None:
                    batch_combined = batch_per_dl
                else:
                    for k in batch_combined.keys():
                        if isinstance(batch_combined[k], list):
                            batch_combined[k] += batch_per_dl[k]
                        elif isinstance(batch_combined[k], dict):
                            for kk in batch_combined[k].keys():
                                batch_combined[k][kk] = torch.cat([batch_combined[k][kk], batch_per_dl[k][kk]], dim=0)
                        else:
                            raise NotImplementedError
            batch = batch_combined

        batch: BatchedExample = self.data_shim(batch)
        b, v, c, h, w = batch["context"]["image"].shape
        # Images are stored in [-1, 1]; the model consumes [0, 1].
        context_image = (batch["context"]["image"] + 1) / 2

        visualization_dump = None  # no visualization collected during training

        encoder_output, output = self.model(context_image, self.global_step, visualization_dump=visualization_dump)
        gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
        pred_context_pose = encoder_output.pred_context_pose
        infos = encoder_output.infos
        distill_infos = encoder_output.distill_infos

        num_context_views = pred_context_pose['extrinsic'].shape[1]

        # Index of every context view; consumed downstream by the losses via `batch`.
        using_index = torch.arange(num_context_views, device=gaussians.means.device)
        batch["using_index"] = using_index

        # Training supervises re-rendered *context* views (not held-out targets).
        target_gt = (batch["context"]["image"] + 1) / 2
        scene_scale = infos["scene_scale"]  # NOTE(review): unused local
        self.log("train/scene_scale", infos["scene_scale"])
        self.log("train/voxelize_ratio", infos["voxelize_ratio"])

        # Monitoring-only PSNR between rendered and ground-truth context views.
        psnr_probabilistic = compute_psnr(
            rearrange(target_gt, "b v c h w -> (b v) c h w"),
            rearrange(output.color, "b v c h w -> (b v) c h w"),
        )
        self.log("train/psnr_probabilistic", psnr_probabilistic.mean())

        # Depth-consistency metrics between the rendered depth and the encoder's
        # depth prediction, restricted to confident pixels (conf_mask).
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_absrel", consis_absrel.mean())

        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            rearrange(distill_infos['conf_mask'], "b v h w -> (b v) h w"),
        )
        self.log("train/consis_delta1", consis_delta1.mean())

        total_loss = 0

        # Main losses run in full precision (autocast disabled) for stability.
        depth_dict['distill_infos'] = distill_infos
        with torch.amp.autocast('cuda', enabled=False):
            for loss_fn in self.losses:
                loss = loss_fn.forward(output, batch, gaussians, depth_dict, self.global_step)
                self.log(f"loss/{loss_fn.name}", loss)
                total_loss = total_loss + loss

        # Optional context-depth loss, enabled when a "depth" loss is configured
        # and the weight is positive.
        if depth_dict is not None and "depth" in get_cfg()["loss"].keys() and self.train_cfg.cxt_depth_weight > 0:
            depth_loss_idx = list(get_cfg()["loss"].keys()).index("depth")
            depth_loss_fn = self.losses[depth_loss_idx].ctx_depth_loss
            loss_depth = depth_loss_fn(depth_dict["depth_map"], depth_dict["depth_conf"], batch, cxt_depth_weight=self.train_cfg.cxt_depth_weight)
            self.log("loss/ctx_depth", loss_depth)
            total_loss = total_loss + loss_depth

        # Optional distillation loss (pose + depth + normal terms).
        if distill_infos is not None:
            loss_distill_list = self.loss_distill(distill_infos, pred_pose_enc_list, output, batch)
            self.log("loss/distill", loss_distill_list['loss_distill'])
            self.log("loss/distill_pose", loss_distill_list['loss_pose'])
            self.log("loss/distill_depth", loss_distill_list['loss_depth'])
            self.log("loss/distill_normal", loss_distill_list['loss_normal'])
            total_loss = total_loss + loss_distill_list['loss_distill']

        self.log("loss/total", total_loss)
        print(f"total_loss: {total_loss}")

        # Late in training, effectively skip outlier batches: scaling by 1e-10
        # keeps the graph valid while making the gradient negligible.
        SKIP_AFTER_STEP = 1000
        LOSS_THRESHOLD = 0.2
        if self.global_step > SKIP_AFTER_STEP and total_loss > LOSS_THRESHOLD:
            print(f"Skipping batch with high loss ({total_loss:.6f}) at step {self.global_step} on Rank {self.global_rank}")
            return total_loss * 1e-10

        # Periodic rank-0 console summary.
        if (
            self.global_rank == 0
            and self.global_step % self.train_cfg.print_log_every_n_steps == 0
        ):
            print(
                f"train step {self.global_step}; "
                f"scene = {[x[:20] for x in batch['scene']]}; "
                f"context = {batch['context']['index'].tolist()}; "
                f"loss = {total_loss:.6f}; "
            )

        self.log("info/global_step", self.global_step)

        # Keep the (possibly cross-process) step tracker in sync.
        if self.step_tracker is not None:
            self.step_tracker.set_step(self.global_step)

        # Free the batch and periodically reclaim CUDA memory.
        del batch
        if self.global_step % 50 == 0:
            gc.collect()
            torch.cuda.empty_cache()

        return total_loss
| |
| def on_after_backward(self): |
| total_norm = 0.0 |
| counter = 0 |
| for p in self.parameters(): |
| if p.grad is not None: |
| param_norm = p.grad.detach().data.norm(2) |
| total_norm += param_norm.item() ** 2 |
| counter += 1 |
| total_norm = (total_norm / counter) ** 0.5 |
| self.log("loss/grad_norm", total_norm) |
| |
| def test_step(self, batch, batch_idx): |
| batch: BatchedExample = self.data_shim(batch) |
| b, v, _, h, w = batch["target"]["image"].shape |
| assert b == 1 |
| if batch_idx % 100 == 0: |
| print(f"Test step {batch_idx:0>6}.") |
| |
| |
| with self.benchmarker.time("encoder"): |
| gaussians = self.model.encoder( |
| (batch["context"]["image"]+1)/2, |
| self.global_step, |
| )[0] |
| |
| |
| if self.test_cfg.align_pose: |
| output = self.test_step_align(batch, gaussians) |
| else: |
| with self.benchmarker.time("decoder", num_calls=v): |
| output = self.model.decoder.forward( |
| gaussians, |
| batch["target"]["extrinsics"], |
| batch["target"]["intrinsics"], |
| batch["target"]["near"], |
| batch["target"]["far"], |
| (h, w), |
| ) |
| |
| |
| if self.test_cfg.compute_scores: |
| overlap = batch["context"]["overlap"][0] |
| overlap_tag = get_overlap_tag(overlap) |
|
|
| rgb_pred = output.color[0] |
| rgb_gt = batch["target"]["image"][0] |
| all_metrics = { |
| f"lpips_ours": compute_lpips(rgb_gt, rgb_pred).mean(), |
| f"ssim_ours": compute_ssim(rgb_gt, rgb_pred).mean(), |
| f"psnr_ours": compute_psnr(rgb_gt, rgb_pred).mean(), |
| } |
| methods = ['ours'] |
|
|
| self.log_dict(all_metrics) |
| self.print_preview_metrics(all_metrics, methods, overlap_tag=overlap_tag) |
| |
| |
| (scene,) = batch["scene"] |
| name = get_cfg()["wandb"]["name"] |
| path = self.test_cfg.output_path / name |
| if self.test_cfg.save_image: |
| for index, color in zip(batch["target"]["index"][0], output.color[0]): |
| save_image(color, path / scene / f"color/{index:0>6}.png") |
|
|
| if self.test_cfg.save_video: |
| frame_str = "_".join([str(x.item()) for x in batch["context"]["index"][0]]) |
| save_video( |
| [a for a in output.color[0]], |
| path / "video" / f"{scene}_frame_{frame_str}.mp4", |
| ) |
|
|
| if self.test_cfg.save_compare: |
| |
| context_img = inverse_normalize(batch["context"]["image"][0]) |
| comparison = hcat( |
| add_label(vcat(*context_img), "Context"), |
| add_label(vcat(*rgb_gt), "Target (Ground Truth)"), |
| add_label(vcat(*rgb_pred), "Target (Prediction)"), |
| ) |
| save_image(comparison, path / f"{scene}.png") |
| |
    def test_step_align(self, batch, gaussians):
        """Optimize per-target-view pose corrections, then render with them.

        Freezes the encoder, parameterizes a small SE(3) correction per target
        view (6-D rotation + translation), optimizes it with Adam for
        `test_cfg.pose_align_steps` steps against the configured losses, and
        returns the final rendered output under the corrected extrinsics.
        """
        self.model.encoder.eval()
        # Freeze encoder weights so only the pose deltas receive gradients.
        for param in self.model.encoder.parameters():
            param.requires_grad = False

        b, v, _, h, w = batch["target"]["image"].shape
        output_c2ws = batch["target"]["extrinsics"]
        with torch.set_grad_enabled(True):
            # Rotation delta in the 6-D rotation representation; translation delta in R^3.
            cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=output_c2ws.device))
            cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=output_c2ws.device))
            opt_params = []
            # Identity in the 6-D rotation representation (first two rows of I3).
            # NOTE(review): register_buffer is called on every invocation; it
            # overwrites the existing "identity" buffer, which works but is
            # unusual inside a per-batch method — confirm intended.
            self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]).to(output_c2ws))
            opt_params.append(
                {
                    "params": [cam_rot_delta],
                    "lr": 0.005,
                }
            )
            opt_params.append(
                {
                    "params": [cam_trans_delta],
                    "lr": 0.005,
                }
            )
            # NOTE(review): the Adam lrs above are hard-coded; test_cfg.rot_opt_lr /
            # trans_opt_lr are unused — confirm intended.
            pose_optimizer = torch.optim.Adam(opt_params)
            extrinsics = output_c2ws.clone()
            with self.benchmarker.time("optimize"):
                for i in range(self.test_cfg.pose_align_steps):
                    pose_optimizer.zero_grad()
                    dx, drot = cam_trans_delta, cam_rot_delta
                    # Delta around identity, converted to a rotation matrix.
                    rot = rotation_6d_to_matrix(
                        drot + self.identity.expand(b, v, -1)
                    )

                    # Build the (b, v, 4, 4) correction transform.
                    transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
                    transform[..., :3, :3] = rot
                    transform[..., :3, 3] = dx

                    # Right-multiply: correction applied in the camera frame.
                    new_extrinsics = torch.matmul(extrinsics, transform)
                    output = self.model.decoder.forward(
                        gaussians,
                        new_extrinsics,
                        batch["target"]["intrinsics"],
                        batch["target"]["near"],
                        batch["target"]["far"],
                        (h, w),
                    )

                    # Photometric (and any other configured) losses drive the alignment.
                    # NOTE(review): loss_fn.forward is called here without the
                    # `depth_dict` argument used in training_step — confirm the
                    # loss signatures accept both call shapes.
                    total_loss = 0
                    for loss_fn in self.losses:
                        loss = loss_fn.forward(output, batch, gaussians, self.global_step)
                        total_loss = total_loss + loss

                    total_loss.backward()
                    pose_optimizer.step()

            # Final render with the last optimized extrinsics.
            output = self.model.decoder.forward(
                gaussians,
                new_extrinsics,
                batch["target"]["intrinsics"],
                batch["target"]["near"],
                batch["target"]["far"],
                (h, w),
            )

        return output
|
|
| def on_test_end(self) -> None: |
| name = get_cfg()["wandb"]["name"] |
| self.benchmarker.dump(self.test_cfg.output_path / name / "benchmark.json") |
| self.benchmarker.dump_memory( |
| self.test_cfg.output_path / name / "peak_memory.json" |
| ) |
| self.benchmarker.summarize() |
|
|
    @rank_zero_only
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Validate one scene: forward the model, log image/depth metrics, and
        upload a side-by-side comparison image plus rendered videos.

        Runs on rank 0 only (decorated with @rank_zero_only).
        """
        batch: BatchedExample = self.data_shim(batch)

        if self.global_rank == 0:
            print(
                f"validation step {self.global_step}; "
                f"scene = {batch['scene']}; "
                f"context = {batch['context']['index'].tolist()}"
            )

        b, v, _, h, w = batch["context"]["image"].shape
        assert b == 1  # validation assumes a single scene per batch
        visualization_dump = {}

        # NOTE(review): unlike training_step, the raw [-1, 1] image is passed to
        # the model here (no (x+1)/2 rescale) — confirm intended.
        encoder_output, output = self.model(batch["context"]["image"], self.global_step, visualization_dump=visualization_dump)
        gaussians, pred_pose_enc_list, depth_dict = encoder_output.gaussians, encoder_output.pred_pose_enc_list, encoder_output.depth_dict
        pred_context_pose, distill_infos = encoder_output.pred_context_pose, encoder_output.distill_infos
        infos = encoder_output.infos

        # Estimated number of Gaussians after voxelization.
        GS_num = infos['voxelize_ratio'] * (h*w*v)
        self.log("val/GS_num", GS_num)

        num_context_views = pred_context_pose['extrinsic'].shape[1]  # NOTE(review): unused local
        num_target_views = batch["target"]["extrinsics"].shape[1]  # NOTE(review): unused local
        rgb_pred = output.color[0].float()
        depth_pred = vis_depth_map(output.depth[0])

        # Per-Gaussian depth from the visualization dump; collapse a trailing
        # 3-channel dimension to a scalar if present.
        gaussian_means = visualization_dump["depth"][0].squeeze()  # NOTE(review): unused local
        if gaussian_means.shape[-1] == 3:
            gaussian_means = gaussian_means.mean(dim=-1)

        # Image-quality metrics against the context views (rescaled to [0, 1]).
        rgb_gt = (batch["context"]["image"][0].float() + 1) / 2
        psnr = compute_psnr(rgb_gt, rgb_pred).mean()
        self.log(f"val/psnr", psnr)
        lpips = compute_lpips(rgb_gt, rgb_pred).mean()
        self.log(f"val/lpips", lpips)
        # NOTE(review): this local shadows the `ssim` function imported from
        # ..loss.loss_ssim; harmless here but worth renaming.
        ssim = compute_ssim(rgb_gt, rgb_pred).mean()
        self.log(f"val/ssim", ssim)

        # Depth consistency between rendered depth and the encoder's depth head
        # (no confidence mask here, unlike training_step).
        consis_absrel = abs_relative_difference(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
        )
        self.log("val/consis_absrel", consis_absrel.mean())

        consis_delta1 = delta1_acc(
            rearrange(output.depth, "b v h w -> (b v) h w"),
            rearrange(depth_dict['depth'].squeeze(-1), "b v h w -> (b v) h w"),
            valid_mask=rearrange(torch.ones_like(output.depth, device=output.depth.device, dtype=torch.bool), "b v h w -> (b v) h w"),
        )
        self.log("val/consis_delta1", consis_delta1.mean())

        # Mean absolute depth difference over confident pixels only.
        diff_map = torch.abs(output.depth - depth_dict['depth'].squeeze(-1))
        self.log("val/consis_mse", diff_map[distill_infos['conf_mask']].mean())

        # Assemble the side-by-side comparison figure.
        context_img = inverse_normalize(batch["context"]["image"][0])

        context = []
        for i in range(context_img.shape[0]):
            context.append(context_img[i])

        colored_diff_map = vis_depth_map(diff_map[0], near=torch.tensor(1e-4, device=diff_map.device), far=torch.tensor(1.0, device=diff_map.device))
        model_depth_pred = depth_dict["depth"].squeeze(-1)[0]
        model_depth_pred = vis_depth_map(model_depth_pred)

        # Normal maps derived from rendered depth and from the encoder's depth,
        # both mapped from [-1, 1] to [0, 1] for display.
        render_normal = (get_normal_map(output.depth.flatten(0, 1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.
        pred_normal = (get_normal_map(depth_dict['depth'].flatten(0, 1).squeeze(-1), batch["context"]["intrinsics"].flatten(0, 1)).permute(0, 3, 1, 2) + 1) / 2.

        comparison = hcat(
            add_label(vcat(*context), "Context"),
            add_label(vcat(*rgb_gt), "Target (Ground Truth)"),
            add_label(vcat(*rgb_pred), "Target (Prediction)"),
            add_label(vcat(*depth_pred), "Depth (Prediction)"),
            add_label(vcat(*model_depth_pred), "Depth (VGGT Prediction)"),
            add_label(vcat(*render_normal), "Normal (Prediction)"),
            add_label(vcat(*pred_normal), "Normal (VGGT Prediction)"),
            add_label(vcat(*colored_diff_map), "Diff Map"),
        )

        # Downscale 2x to keep the logged image small.
        comparison = torch.nn.functional.interpolate(
            comparison.unsqueeze(0),
            scale_factor=0.5,
            mode='bicubic',
            align_corners=False
        ).squeeze(0)

        self.logger.log_image(
            "comparison",
            [prep_image(add_border(comparison))],
            step=self.global_step,
            caption=batch["scene"],
        )

        # Optional encoder-specific visualizations (disabled: set to None in __init__).
        if self.encoder_visualizer is not None:
            for k, image in self.encoder_visualizer.visualize(
                batch["context"], self.global_step
            ).items():
                self.logger.log_image(k, [prep_image(image)], step=self.global_step)

        # Rendered camera-trajectory videos.
        self.render_video_interpolation(batch)
        self.render_video_wobble(batch)
        if self.train_cfg.extended_visualization:
            self.render_video_interpolation_exaggerated(batch)
|
|
| @rank_zero_only |
| def render_video_wobble(self, batch: BatchedExample) -> None: |
| |
| _, v, _, _ = batch["context"]["extrinsics"].shape |
| if v != 2: |
| return |
|
|
| def trajectory_fn(t): |
| origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] |
| origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] |
| delta = (origin_a - origin_b).norm(dim=-1) |
| extrinsics = generate_wobble( |
| batch["context"]["extrinsics"][:, 0], |
| delta * 0.25, |
| t, |
| ) |
| intrinsics = repeat( |
| batch["context"]["intrinsics"][:, 0], |
| "b i j -> b v i j", |
| v=t.shape[0], |
| ) |
| return extrinsics, intrinsics |
|
|
| return self.render_video_generic(batch, trajectory_fn, "wobble", num_frames=60) |
|
|
| @rank_zero_only |
| def render_video_interpolation(self, batch: BatchedExample) -> None: |
| _, v, _, _ = batch["context"]["extrinsics"].shape |
|
|
| def trajectory_fn(t): |
| extrinsics = interpolate_extrinsics( |
| batch["context"]["extrinsics"][0, 0], |
| ( |
| batch["context"]["extrinsics"][0, 1] |
| if v == 2 |
| else batch["target"]["extrinsics"][0, 0] |
| ), |
| t, |
| ) |
| intrinsics = interpolate_intrinsics( |
| batch["context"]["intrinsics"][0, 0], |
| ( |
| batch["context"]["intrinsics"][0, 1] |
| if v == 2 |
| else batch["target"]["intrinsics"][0, 0] |
| ), |
| t, |
| ) |
| return extrinsics[None], intrinsics[None] |
|
|
| return self.render_video_generic(batch, trajectory_fn, "rgb") |
|
|
| @rank_zero_only |
| def render_video_interpolation_exaggerated(self, batch: BatchedExample) -> None: |
| |
| _, v, _, _ = batch["context"]["extrinsics"].shape |
| if v != 2: |
| return |
|
|
| def trajectory_fn(t): |
| origin_a = batch["context"]["extrinsics"][:, 0, :3, 3] |
| origin_b = batch["context"]["extrinsics"][:, 1, :3, 3] |
| delta = (origin_a - origin_b).norm(dim=-1) |
| tf = generate_wobble_transformation( |
| delta * 0.5, |
| t, |
| 5, |
| scale_radius_with_t=False, |
| ) |
| extrinsics = interpolate_extrinsics( |
| batch["context"]["extrinsics"][0, 0], |
| ( |
| batch["context"]["extrinsics"][0, 1] |
| if v == 2 |
| else batch["target"]["extrinsics"][0, 0] |
| ), |
| t * 5 - 2, |
| ) |
| intrinsics = interpolate_intrinsics( |
| batch["context"]["intrinsics"][0, 0], |
| ( |
| batch["context"]["intrinsics"][0, 1] |
| if v == 2 |
| else batch["target"]["intrinsics"][0, 0] |
| ), |
| t * 5 - 2, |
| ) |
| return extrinsics @ tf, intrinsics[None] |
|
|
| return self.render_video_generic( |
| batch, |
| trajectory_fn, |
| "interpolation_exagerrated", |
| num_frames=300, |
| smooth=False, |
| loop_reverse=False, |
| ) |
|
|
    @rank_zero_only
    def render_video_generic(
        self,
        batch: BatchedExample,
        trajectory_fn: TrajectoryFn,
        name: str,
        num_frames: int = 30,
        smooth: bool = True,
        loop_reverse: bool = True,
    ) -> None:
        """Render a video along a camera trajectory and log it to wandb
        (falling back to a local MP4 file when wandb logging fails).

        Args:
            batch: current batched example; context views feed the encoder.
            trajectory_fn: maps times in [0, 1] to (extrinsics, intrinsics).
            name: logging key suffix ("video/{name}").
            num_frames: number of frames sampled along the trajectory.
            smooth: apply cosine easing to the time samples.
            loop_reverse: append the reversed frames to make a seamless loop.
        """
        # Encode once; only the Gaussians are needed for rendering.
        encoder_output = self.model.encoder((batch["context"]["image"]+1)/2, self.global_step)
        gaussians, pred_pose_enc_list = encoder_output.gaussians, encoder_output.pred_pose_enc_list

        t = torch.linspace(0, 1, num_frames, dtype=torch.float32, device=self.device)
        if smooth:
            # Cosine easing: slow at the endpoints, fast in the middle.
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2

        extrinsics, intrinsics = trajectory_fn(t)

        _, _, _, h, w = batch["context"]["image"].shape

        # Near/far planes of the first context view, broadcast to every frame.
        near = repeat(batch["context"]["near"][:, 0], "b -> b v", v=num_frames)
        far = repeat(batch["context"]["far"][:, 0], "b -> b v", v=num_frames)
        output = self.model.decoder.forward(
            gaussians, extrinsics, intrinsics, near, far, (h, w), "depth"
        )
        # Stack color over depth per frame.
        images = [
            vcat(rgb, depth)
            for rgb, depth in zip(output.color[0], vis_depth_map(output.depth[0]))
        ]

        video = torch.stack(images)
        video = (video.clip(min=0, max=1) * 255).type(torch.uint8).cpu().numpy()
        if loop_reverse:
            # Ping-pong loop; drop duplicated first/last frames of the reverse pass.
            video = pack([video, video[::-1][1:-1]], "* c h w")[0]
        visualizations = {
            f"video/{name}": wandb.Video(video[None], fps=30, format="mp4")
        }

        # Prefer wandb; if logging fails (e.g. offline run with a LocalLogger),
        # write the video to disk instead.
        try:
            wandb.log(visualizations)
        except Exception:
            assert isinstance(self.logger, LocalLogger)
            for key, value in visualizations.items():
                tensor = value._prepare_video(value.data)
                clip = mpy.ImageSequenceClip(list(tensor), fps=30)
                dir = LOG_PATH / key
                dir.mkdir(exist_ok=True, parents=True)
                clip.write_videofile(
                    str(dir / f"{self.global_step:0>6}.mp4"), logger=None
                )
|
|
| def print_preview_metrics(self, metrics: dict[str, float | Tensor], methods: list[str] | None = None, overlap_tag: str | None = None) -> None: |
| if getattr(self, "running_metrics", None) is None: |
| self.running_metrics = metrics |
| self.running_metric_steps = 1 |
| else: |
| s = self.running_metric_steps |
| self.running_metrics = { |
| k: ((s * v) + metrics[k]) / (s + 1) |
| for k, v in self.running_metrics.items() |
| } |
| self.running_metric_steps += 1 |
|
|
| if overlap_tag is not None: |
| if getattr(self, "running_metrics_sub", None) is None: |
| self.running_metrics_sub = {overlap_tag: metrics} |
| self.running_metric_steps_sub = {overlap_tag: 1} |
| elif overlap_tag not in self.running_metrics_sub: |
| self.running_metrics_sub[overlap_tag] = metrics |
| self.running_metric_steps_sub[overlap_tag] = 1 |
| else: |
| s = self.running_metric_steps_sub[overlap_tag] |
| self.running_metrics_sub[overlap_tag] = {k: ((s * v) + metrics[k]) / (s + 1) |
| for k, v in self.running_metrics_sub[overlap_tag].items()} |
| self.running_metric_steps_sub[overlap_tag] += 1 |
|
|
| metric_list = ["psnr", "lpips", "ssim"] |
|
|
| def print_metrics(runing_metric, methods=None): |
| table = [] |
| if methods is None: |
| methods = ['ours'] |
|
|
| for method in methods: |
| row = [ |
| f"{runing_metric[f'{metric}_{method}']:.3f}" |
| for metric in metric_list |
| ] |
| table.append((method, *row)) |
|
|
| headers = ["Method"] + metric_list |
| table = tabulate(table, headers) |
| print(table) |
|
|
| print("All Pairs:") |
| print_metrics(self.running_metrics, methods) |
| if overlap_tag is not None: |
| for k, v in self.running_metrics_sub.items(): |
| print(f"Overlap: {k}") |
| print_metrics(v, methods) |
|
|
    def configure_optimizers(self):
        """Build the AdamW optimizer and a warm-up + cosine-annealing scheduler.

        Newly initialized parameters (Gaussian parameter heads and intermediate
        layers) train at the base learning rate; all other (pretrained) parameters
        train at `lr * backbone_lr_multiplier`.
        """
        new_params, new_param_names = [], []
        pretrained_params, pretrained_param_names = [], []
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            # Heuristic split by parameter name: heads/intermediate layers are
            # "new", everything else is treated as pretrained backbone.
            if "gaussian_param_head" in name or "interm" in name:
                new_params.append(param)
                new_param_names.append(name)
            else:
                pretrained_params.append(param)
                pretrained_param_names.append(name)

        param_dicts = [
            {
                "params": new_params,
                "lr": self.optimizer_cfg.lr,
            },
            {
                "params": pretrained_params,
                "lr": self.optimizer_cfg.lr * self.optimizer_cfg.backbone_lr_multiplier,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.optimizer_cfg.lr, weight_decay=0.05, betas=(0.9, 0.95))
        warm_up_steps = self.optimizer_cfg.warm_up_steps
        # Linear warm-up from lr/warm_up_steps to the full lr.
        warm_up = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            1 / warm_up_steps,
            1,
            total_iters=warm_up_steps,
        )

        # Cosine decay to 10% of the base lr over the full training run, chained
        # after the warm-up phase.
        # NOTE(review): T_max spans all of max_steps even though the cosine
        # phase only starts after warm-up — confirm intended.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=get_cfg()["trainer"]["max_steps"], eta_min=self.optimizer_cfg.lr * 0.1)
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warm_up, lr_scheduler], milestones=[warm_up_steps])

        # Lightning scheduler config: step the scheduler every optimizer step.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1,
            },
        }
|
|