# Generates an evaluation index (context view pairs + target views) per scene.
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

import torch
from einops import rearrange
from lightning.pytorch import LightningModule
from tqdm import tqdm

from ..geometry.epipolar_lines import project_rays
from ..geometry.projection import get_world_rays, sample_image_grid
from ..misc.image_io import save_image
from ..visualization.annotation import add_label
from ..visualization.layout import add_border, hcat
@dataclass
class EvaluationIndexGeneratorCfg:
    """Configuration for EvaluationIndexGenerator.

    Without the ``@dataclass`` decorator this class had only annotations —
    no fields and no generated ``__init__`` — so it could not be constructed.
    """

    num_target_views: int  # number of target views sampled between the context pair
    min_distance: int  # minimum frame distance from the context view to start scanning
    max_distance: int  # maximum frame distance before the scan gives up
    min_overlap: float  # lower bound on acceptable ray overlap, fraction in [0, 1]
    max_overlap: float  # upper bound on acceptable ray overlap, fraction in [0, 1]
    output_path: Path  # directory where the JSON index (and previews) are written
    save_previews: bool  # if True, save a side-by-side preview image per scene
    seed: int  # seed for the torch.Generator used for all sampling
| class IndexEntry: | |
| context: tuple[int, ...] | |
| target: tuple[int, ...] | |
| overlap: Optional[str | float] = None # choose from ["small", "medium", "large"] or a float number indicates the overlap ratio | |
class EvaluationIndexGenerator(LightningModule):
    """Searches each test scene for a context view pair whose mutual image
    overlap lies in [cfg.min_overlap, cfg.max_overlap] and whose frame
    distance lies in [cfg.min_distance, cfg.max_distance], then samples
    target views between the pair. Results accumulate in ``self.index`` and
    are written to disk by ``save_index``.
    """

    generator: torch.Generator  # seeded RNG so index generation is reproducible
    cfg: EvaluationIndexGeneratorCfg
    # Maps scene name -> chosen views, or None if no valid pair was found.
    index: dict[str, IndexEntry | None]

    def __init__(self, cfg: EvaluationIndexGeneratorCfg) -> None:
        super().__init__()
        self.cfg = cfg
        self.generator = torch.Generator()
        self.generator.manual_seed(cfg.seed)
        self.index = {}

    def test_step(self, batch, batch_idx):
        """Pick context and target views for the single scene in ``batch``.

        Assumes batch size 1. ``batch["target"]["image"]`` is unpacked as
        (b, v, channels, h, w) — v is the number of views in the sequence.
        """
        b, v, _, h, w = batch["target"]["image"].shape
        assert b == 1
        extrinsics = batch["target"]["extrinsics"][0]
        intrinsics = batch["target"]["intrinsics"][0]
        scene = batch["scene"][0]

        # Try starting frames in random (seeded) order until one yields a
        # valid context pair; the for/else below handles total failure.
        context_indices = torch.randperm(v, generator=self.generator)
        for context_index in tqdm(context_indices, "Finding context pair"):
            # Cast a ray through every pixel of the candidate context view.
            xy, _ = sample_image_grid((h, w), self.device)
            context_origins, context_directions = get_world_rays(
                rearrange(xy, "h w xy -> (h w) xy"),
                extrinsics[context_index],
                intrinsics[context_index],
            )

            # Step away from context view until the minimum overlap threshold is met.
            valid_indices = []
            for step in (1, -1):  # scan forward, then backward, in frame order
                min_distance = self.cfg.min_distance
                max_distance = self.cfg.max_distance
                current_index = context_index + step * min_distance
                while 0 <= current_index.item() < v:
                    # Compute overlap.
                    current_origins, current_directions = get_world_rays(
                        rearrange(xy, "h w xy -> (h w) xy"),
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_current = project_rays(
                        context_origins,
                        context_directions,
                        extrinsics[current_index],
                        intrinsics[current_index],
                    )
                    projection_onto_context = project_rays(
                        current_origins,
                        current_directions,
                        extrinsics[context_index],
                        intrinsics[context_index],
                    )
                    # Symmetric overlap: the fraction of rays from each view
                    # that land inside the other view's image plane; use the
                    # worse of the two directions as the pair's overlap.
                    overlap_a = projection_onto_context["overlaps_image"].float().mean()
                    overlap_b = projection_onto_current["overlaps_image"].float().mean()
                    overlap = min(overlap_a, overlap_b)

                    delta = (current_index - context_index).abs()
                    min_overlap = self.cfg.min_overlap
                    max_overlap = self.cfg.max_overlap
                    if min_overlap <= overlap <= max_overlap:
                        valid_indices.append(
                            (current_index.item(), overlap_a, overlap_b)
                        )

                    # Stop once the camera has panned away too much.
                    if overlap < min_overlap or delta > max_distance:
                        break

                    current_index += step

            if valid_indices:
                # Pick a random valid view. Index the resulting views.
                num_options = len(valid_indices)
                chosen = torch.randint(
                    0, num_options, size=tuple(), generator=self.generator
                )
                # Rebinding: ``chosen`` becomes the frame index (an int).
                chosen, overlap_a, overlap_b = valid_indices[chosen]

                context_left = min(chosen, context_index.item())
                context_right = max(chosen, context_index.item())
                delta = context_right - context_left

                # Pick non-repeated random target views.
                # NOTE(review): rejection sampling — assumes num_target_views
                # <= context_right - context_left + 1, otherwise this never
                # terminates; confirm callers guarantee it.
                while True:
                    target_views = torch.randint(
                        context_left,
                        context_right + 1,
                        (self.cfg.num_target_views,),
                        generator=self.generator,
                    )
                    if (target_views.unique(return_counts=True)[1] == 1).all():
                        break

                target = tuple(sorted(target_views.tolist()))
                self.index[scene] = IndexEntry(
                    context=(context_left, context_right),
                    target=target,
                )

                # Optionally, save a preview.
                if self.cfg.save_previews:
                    preview_path = self.cfg.output_path / "previews"
                    preview_path.mkdir(exist_ok=True, parents=True)
                    a = batch["target"]["image"][0, chosen]
                    a = add_label(a, f"Overlap: {overlap_a * 100:.1f}%")
                    b = batch["target"]["image"][0, context_index]
                    b = add_label(b, f"Overlap: {overlap_b * 100:.1f}%")
                    vis = add_border(add_border(hcat(a, b)), 1, 0)
                    vis = add_label(vis, f"Distance: {delta} frames")
                    save_image(add_border(vis), preview_path / f"{scene}.png")
                break
        else:
            # This happens if no starting frame produces a valid evaluation example.
            self.index[scene] = None

    def save_index(self) -> None:
        """Serialize the accumulated index to <output_path>/evaluation_index.json."""
        self.cfg.output_path.mkdir(exist_ok=True, parents=True)
        with (self.cfg.output_path / "evaluation_index.json").open("w") as f:
            json.dump(
                {k: None if v is None else asdict(v) for k, v in self.index.items()}, f
            )