| | from pathlib import Path |
| | from random import randrange |
| | from typing import Optional |
| |
|
| | import numpy as np |
| | import torch |
| | |
| | import swanlab as wandb |
| | from einops import rearrange, reduce, repeat |
| | from jaxtyping import Bool, Float |
| | from torch import Tensor |
| |
|
| | from ....dataset.types import BatchedViews |
| | from ....misc.heterogeneous_pairings import generate_heterogeneous_index |
| | from ....visualization.annotation import add_label |
| | from ....visualization.color_map import apply_color_map, apply_color_map_to_image |
| | from ....visualization.colors import get_distinct_color |
| | from ....visualization.drawing.lines import draw_lines |
| | from ....visualization.drawing.points import draw_points |
| | from ....visualization.layout import add_border, hcat, vcat |
| | |
| | from ..encoder_depthsplat import EncoderDepthSplat |
| | |
| | from .encoder_visualizer import EncoderVisualizer |
| | from .encoder_visualizer_depthsplat_cfg import EncoderVisualizerDepthSplatCfg |
| |
|
| |
|
| | def box( |
| | image: Float[Tensor, "3 height width"], |
| | ) -> Float[Tensor, "3 new_height new_width"]: |
| | return add_border(add_border(image), 1, 0) |
| |
|
| |
|
| | class EncoderVisualizerDepthSplat( |
| | EncoderVisualizer[EncoderVisualizerDepthSplatCfg, EncoderDepthSplat] |
| | ): |
| | def visualize( |
| | self, |
| | context: BatchedViews, |
| | global_step: int, |
| | ) -> dict[str, Float[Tensor, "3 _ _"]]: |
| | |
| | return {} |
| |
|
| | visualization_dump = {} |
| |
|
| | softmax_weights = [] |
| |
|
| | def hook(module, input, output): |
| | softmax_weights.append(output) |
| |
|
| | |
| | handles = [ |
| | layer[0].fn.attend.register_forward_hook(hook) |
| | for layer in self.encoder.epipolar_transformer.transformer.layers |
| | ] |
| |
|
| | result = self.encoder.forward( |
| | context, |
| | global_step, |
| | visualization_dump=visualization_dump, |
| | deterministic=True, |
| | ) |
| |
|
| | |
| | for handle in handles: |
| | handle.remove() |
| |
|
| | softmax_weights = torch.stack(softmax_weights) |
| |
|
| | |
| | context_images = context["image"] |
| | _, _, _, h, w = context_images.shape |
| | length = min(h, w) |
| | min_resolution = self.cfg.min_resolution |
| | scale_multiplier = (min_resolution + length - 1) // length |
| | if scale_multiplier > 1: |
| | context_images = repeat( |
| | context_images, |
| | "b v c h w -> b v c (h rh) (w rw)", |
| | rh=scale_multiplier, |
| | rw=scale_multiplier, |
| | ) |
| |
|
| | |
| | if self.cfg.export_ply and wandb.run is not None: |
| | name = wandb.run._name.split(" ")[0] |
| | ply_path = Path(f"outputs/gaussians/{name}/{global_step:0>6}.ply") |
| | export_ply( |
| | context["extrinsics"][0, 0], |
| | result.means[0], |
| | visualization_dump["scales"][0], |
| | visualization_dump["rotations"][0], |
| | result.harmonics[0], |
| | result.opacities[0], |
| | ply_path, |
| | ) |
| |
|
| | return { |
| | "attention": self.visualize_attention( |
| | context_images, |
| | visualization_dump["sampling"], |
| | softmax_weights, |
| | ), |
| | "epipolar_samples": self.visualize_epipolar_samples( |
| | context_images, |
| | visualization_dump["sampling"], |
| | ), |
| | "epipolar_color_samples": self.visualize_epipolar_color_samples( |
| | context_images, |
| | context, |
| | ), |
| | "gaussians": self.visualize_gaussians( |
| | context["image"], |
| | result.opacities, |
| | result.covariances, |
| | result.harmonics[..., 0], |
| | ), |
| | "overlaps": self.visualize_overlaps( |
| | context["image"], |
| | visualization_dump["sampling"], |
| | visualization_dump.get("is_monocular", None), |
| | ), |
| | "depth": self.visualize_depth( |
| | context, |
| | visualization_dump["depth"], |
| | ), |
| | } |
| |
|
| | def visualize_attention( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | sampling: None, |
| | attention: Float[Tensor, "layer bvr head 1 sample"], |
| | ) -> Float[Tensor, "3 vis_height vis_width"]: |
| | device = context_images.device |
| |
|
| | |
| | b, v, ov, r, s, _ = sampling.xy_sample.shape |
| | rb = randrange(b) |
| | rv = randrange(v) |
| | rov = randrange(ov) |
| | num_samples = self.cfg.num_samples |
| | rr = np.random.choice(r, num_samples, replace=False) |
| | rr = torch.tensor(rr, dtype=torch.int64, device=device) |
| |
|
| | |
| | ray_view = draw_points( |
| | context_images[rb, rv], |
| | sampling.xy_ray[rb, rv, rr], |
| | 0, |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | ray_view = draw_points( |
| | ray_view, |
| | sampling.xy_ray[rb, rv, rr], |
| | [get_distinct_color(i) for i, _ in enumerate(rr)], |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | attention = rearrange( |
| | attention, "l (b v r) hd () s -> l b v r hd s", b=b, v=v, r=r |
| | ) |
| | attention = attention[:, rb, rv, rr, :, :] |
| | num_layers, _, hd, _ = attention.shape |
| |
|
| | vis = [] |
| | for il in range(num_layers): |
| | vis_layer = [] |
| | for ihd in range(hd): |
| | |
| | color = [get_distinct_color(i) for i, _ in enumerate(rr)] |
| | color = torch.tensor(color, device=attention.device) |
| | color = rearrange(color, "r c -> r () c") |
| | attn = rearrange(attention[il, :, ihd], "r s -> r s ()") |
| | color = rearrange(attn * color, "r s c -> (r s ) c") |
| |
|
| | |
| | vis_layer_head = draw_lines( |
| | context_images[rb, self.encoder.sampler.index_v[rv, rov]], |
| | rearrange( |
| | sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy" |
| | ), |
| | rearrange( |
| | sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy" |
| | ), |
| | color, |
| | 3, |
| | cap="butt", |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | vis_layer.append(vis_layer_head) |
| | vis.append(add_label(vcat(*vis_layer), f"Layer {il}")) |
| | vis = add_label(add_border(add_border(hcat(*vis)), 1, 0), "Keys & Values") |
| | vis = add_border(hcat(add_label(ray_view, "ray_view"), vis, align="top")) |
| | return vis |
| |
|
| | def visualize_depth( |
| | self, |
| | context: BatchedViews, |
| | multi_depth: Float[Tensor, "batch view height width surface spp"], |
| | ) -> Float[Tensor, "3 vis_width vis_height"]: |
| | multi_vis = [] |
| | *_, srf, _ = multi_depth.shape |
| | for i in range(srf): |
| | depth = multi_depth[..., i, :] |
| | depth = depth.mean(dim=-1) |
| |
|
| | |
| | near = rearrange(context["near"], "b v -> b v () ()") |
| | far = rearrange(context["far"], "b v -> b v () ()") |
| | relative_depth = (depth - near) / (far - near) |
| | relative_disparity = 1 - (1 / depth - 1 / far) / (1 / near - 1 / far) |
| |
|
| | relative_depth = apply_color_map_to_image(relative_depth, "turbo") |
| | relative_depth = vcat(*[hcat(*x) for x in relative_depth]) |
| | relative_depth = add_label(relative_depth, "Depth") |
| | relative_disparity = apply_color_map_to_image(relative_disparity, "turbo") |
| | relative_disparity = vcat(*[hcat(*x) for x in relative_disparity]) |
| | relative_disparity = add_label(relative_disparity, "Disparity") |
| | multi_vis.append(add_border(hcat(relative_depth, relative_disparity))) |
| |
|
| | return add_border(vcat(*multi_vis)) |
| |
|
| | def visualize_overlaps( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | sampling: None, |
| | is_monocular: Optional[Bool[Tensor, "batch view height width"]] = None, |
| | ) -> Float[Tensor, "3 vis_width vis_height"]: |
| | device = context_images.device |
| | b, v, _, h, w = context_images.shape |
| | green = torch.tensor([0.235, 0.706, 0.294], device=device)[..., None, None] |
| | rb = randrange(b) |
| | valid = sampling.valid[rb].float() |
| | ds = self.encoder.cfg.epipolar_transformer.downscale |
| | valid = repeat( |
| | valid, |
| | "v ov (h w) -> v ov c (h rh) (w rw)", |
| | c=3, |
| | h=h // ds, |
| | w=w // ds, |
| | rh=ds, |
| | rw=ds, |
| | ) |
| |
|
| | if is_monocular is not None: |
| | is_monocular = is_monocular[rb].float() |
| | is_monocular = repeat(is_monocular, "v h w -> v c h w", c=3, h=h, w=w) |
| |
|
| | |
| | context_images = context_images[rb] |
| | index, _ = generate_heterogeneous_index(v) |
| | valid = valid * (green + context_images[index]) / 2 |
| |
|
| | vis = vcat(*(hcat(im, hcat(*v)) for im, v in zip(context_images, valid))) |
| | vis = add_label(vis, "Context Overlaps") |
| |
|
| | if is_monocular is not None: |
| | vis = hcat(vis, add_label(vcat(*is_monocular), "Monocular?")) |
| |
|
| | return add_border(vis) |
| |
|
| | def visualize_gaussians( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | opacities: Float[Tensor, "batch vrspp"], |
| | covariances: Float[Tensor, "batch vrspp 3 3"], |
| | colors: Float[Tensor, "batch vrspp 3"], |
| | ) -> Float[Tensor, "3 vis_height vis_width"]: |
| | b, v, _, h, w = context_images.shape |
| | rb = randrange(b) |
| | context_images = context_images[rb] |
| | opacities = repeat( |
| | opacities[rb], "(v h w spp) -> spp v c h w", v=v, c=3, h=h, w=w |
| | ) |
| | colors = rearrange(colors[rb], "(v h w spp) c -> spp v c h w", v=v, h=h, w=w) |
| |
|
| | |
| | det = covariances[rb].det() |
| | det = apply_color_map(det / det.max(), "inferno") |
| | det = rearrange(det, "(v h w spp) c -> spp v c h w", v=v, h=h, w=w) |
| |
|
| | return add_border( |
| | hcat( |
| | add_label(box(hcat(*context_images)), "Context"), |
| | add_label(box(vcat(*[hcat(*x) for x in opacities])), "Opacities"), |
| | add_label( |
| | box(vcat(*[hcat(*x) for x in (colors * opacities)])), "Colors" |
| | ), |
| | add_label(box(vcat(*[hcat(*x) for x in colors])), "Colors (Raw)"), |
| | add_label(box(vcat(*[hcat(*x) for x in det])), "Determinant"), |
| | ) |
| | ) |
| |
|
| | def visualize_probabilities( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | sampling: None, |
| | pdf: Float[Tensor, "batch view ray sample"], |
| | ) -> Float[Tensor, "3 vis_height vis_width"]: |
| | device = context_images.device |
| |
|
| | |
| | b, v, ov, r, _, _ = sampling.xy_sample.shape |
| | rb = randrange(b) |
| | rv = randrange(v) |
| | rov = randrange(ov) |
| | num_samples = self.cfg.num_samples |
| | rr = np.random.choice(r, num_samples, replace=False) |
| | rr = torch.tensor(rr, dtype=torch.int64, device=device) |
| | colors = [get_distinct_color(i) for i, _ in enumerate(rr)] |
| | colors = torch.tensor(colors, dtype=torch.float32, device=device) |
| |
|
| | |
| | ray_view = draw_points( |
| | context_images[rb, rv], |
| | sampling.xy_ray[rb, rv, rr], |
| | 0, |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | ray_view = draw_points( |
| | ray_view, |
| | sampling.xy_ray[rb, rv, rr], |
| | colors, |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | pdf = pdf[rb, rv, rr] |
| | pdf = rearrange(pdf, "r s -> r s ()") |
| | colors = rearrange(colors, "r c -> r () c") |
| | sample_view = draw_lines( |
| | context_images[rb, self.encoder.sampler.index_v[rv, rov]], |
| | rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(pdf * colors, "r s c -> (r s) c"), |
| | 6, |
| | cap="butt", |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | pdf_magnified = pdf / reduce(pdf, "r s () -> r () ()", "max") |
| | sample_view_magnified = draw_lines( |
| | context_images[rb, self.encoder.sampler.index_v[rv, rov]], |
| | rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(pdf_magnified * colors, "r s c -> (r s) c"), |
| | 6, |
| | cap="butt", |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | return add_border( |
| | hcat( |
| | add_label(ray_view, "Rays"), |
| | add_label(sample_view, "Samples"), |
| | add_label(sample_view_magnified, "Samples (Magnified PDF)"), |
| | ) |
| | ) |
| |
|
| | def visualize_epipolar_samples( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | sampling: None, |
| | ) -> Float[Tensor, "3 vis_height vis_width"]: |
| | device = context_images.device |
| |
|
| | |
| | b, v, ov, r, s, _ = sampling.xy_sample.shape |
| | rb = randrange(b) |
| | rv = randrange(v) |
| | rov = randrange(ov) |
| | num_samples = self.cfg.num_samples |
| | rr = np.random.choice(r, num_samples, replace=False) |
| | rr = torch.tensor(rr, dtype=torch.int64, device=device) |
| |
|
| | |
| | ray_view = draw_points( |
| | context_images[rb, rv], |
| | sampling.xy_ray[rb, rv, rr], |
| | 0, |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | ray_view = draw_points( |
| | ray_view, |
| | sampling.xy_ray[rb, rv, rr], |
| | [get_distinct_color(i) for i, _ in enumerate(rr)], |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | |
| | sample_view = draw_lines( |
| | context_images[rb, self.encoder.sampler.index_v[rv, rov]], |
| | sampling.xy_sample_near[rb, rv, rov, rr, 0], |
| | sampling.xy_sample_far[rb, rv, rov, rr, -1], |
| | 0, |
| | 5, |
| | cap="butt", |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | color = repeat( |
| | torch.tensor([0, 1], device=device), |
| | "ab -> r (s ab) c", |
| | r=len(rr), |
| | s=(s + 1) // 2, |
| | c=3, |
| | ) |
| | color = rearrange(color[:, :s], "r s c -> (r s) c") |
| |
|
| | |
| | sample_view = draw_lines( |
| | sample_view, |
| | rearrange(sampling.xy_sample_near[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(sampling.xy_sample_far[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | color, |
| | 3, |
| | cap="butt", |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | sample_view = draw_points( |
| | sample_view, |
| | rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | 0, |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | sample_view = draw_points( |
| | sample_view, |
| | rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | [get_distinct_color(i // s) for i in range(s * len(rr))], |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | return add_border( |
| | hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View")) |
| | ) |
| |
|
| | def visualize_epipolar_color_samples( |
| | self, |
| | context_images: Float[Tensor, "batch view 3 height width"], |
| | context: BatchedViews, |
| | ) -> Float[Tensor, "3 vis_height vis_width"]: |
| | device = context_images.device |
| |
|
| | sampling = self.encoder.sampler( |
| | context["image"], |
| | context["extrinsics"], |
| | context["intrinsics"], |
| | context["near"], |
| | context["far"], |
| | ) |
| |
|
| | |
| | b, v, ov, r, s, _ = sampling.xy_sample.shape |
| | rb = randrange(b) |
| | rv = randrange(v) |
| | rov = randrange(ov) |
| | num_samples = self.cfg.num_samples |
| | rr = np.random.choice(r, num_samples, replace=False) |
| | rr = torch.tensor(rr, dtype=torch.int64, device=device) |
| |
|
| | |
| | ray_view = draw_points( |
| | context_images[rb, rv], |
| | sampling.xy_ray[rb, rv, rr], |
| | 0, |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | ray_view = draw_points( |
| | ray_view, |
| | sampling.xy_ray[rb, rv, rr], |
| | [get_distinct_color(i) for i, _ in enumerate(rr)], |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | |
| | sample_view = draw_points( |
| | context_images[rb, self.encoder.sampler.index_v[rv, rov]], |
| | rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | [get_distinct_color(i // s) for i in range(s * len(rr))], |
| | radius=4, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| | sample_view = draw_points( |
| | sample_view, |
| | rearrange(sampling.xy_sample[rb, rv, rov, rr], "r s xy -> (r s) xy"), |
| | rearrange(sampling.features[rb, rv, rov, rr], "r s c -> (r s) c"), |
| | radius=3, |
| | x_range=(0, 1), |
| | y_range=(0, 1), |
| | ) |
| |
|
| | return add_border( |
| | hcat(add_label(ray_view, "Ray View"), add_label(sample_view, "Sample View")) |
| | ) |
| |
|