"""
adopted from SparseFusion
Wrapper for the full CO3Dv2 dataset
#@ Modified from https://github.com/facebookresearch/pytorch3d
"""
import json
import logging
import math
import os
import random
import time
import warnings
from collections import defaultdict
from itertools import islice
from typing import (
Any,
ClassVar,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
)
from einops import rearrange, repeat
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from pytorch3d.utils import opencv_from_cameras_projection
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from sgm.data.json_index_dataset import (
FrameAnnotsEntry,
_bbox_xywh_to_xyxy,
_bbox_xyxy_to_xywh,
_clamp_box_to_image_bounds_and_round,
_crop_around_box,
_get_1d_bounds,
_get_bbox_from_mask,
_get_clamp_bbox,
_load_1bit_png_mask,
_load_16big_png_depth,
_load_depth,
_load_depth_mask,
_load_image,
_load_mask,
_load_pointcloud,
_rescale_bbox,
_safe_as_tensor,
_seq_name_to_seed,
)
from sgm.data.objaverse import video_collate_fn
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
get_available_subset_names,
)
from dataclasses import dataclass, field, fields
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)
CO3D_ALL_CATEGORIES = list(
reversed(
[
"baseballbat",
"banana",
"bicycle",
"microwave",
"tv",
"cellphone",
"toilet",
"hairdryer",
"couch",
"kite",
"pizza",
"umbrella",
"wineglass",
"laptop",
"hotdog",
"stopsign",
"frisbee",
"baseballglove",
"cup",
"parkingmeter",
"backpack",
"toyplane",
"toybus",
"handbag",
"chair",
"keyboard",
"car",
"motorcycle",
"carrot",
"bottle",
"sandwich",
"remote",
"bowl",
"skateboard",
"toaster",
"mouse",
"toytrain",
"book",
"toytruck",
"orange",
"broccoli",
"plant",
"teddybear",
"suitcase",
"bench",
"ball",
"cake",
"vase",
"hydrant",
"apple",
"donut",
]
)
)
CO3D_ALL_TEN = [
"donut",
"apple",
"hydrant",
"vase",
"cake",
"ball",
"bench",
"suitcase",
"teddybear",
"plant",
]
# @ FROM https://github.com/facebookresearch/pytorch3d
@dataclass
class FrameData(Mapping[str, Any]):
"""
A type of the elements returned by indexing the dataset object.
    It can represent both individual frames and batches thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the image in pixels; (height, width) tensor
of shape (2,).
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box tightly enclosing the foreground object in the
format (x0, y0, width, height). The convention assumes that
            `x0+width` and `y0+height` include the boundary of the box.
I.e., to slice out the corresponding crop from an image tensor `I`
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
in the original image coordinates in the format (x0, y0, width, height).
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
from `bbox_xywh` due to padding (which can happen e.g. due to
setting `JsonIndexDataset.box_crop_context > 0`)
        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = ""
depth_map: Optional[torch.Tensor] = torch.zeros(1)
depth_mask: Optional[torch.Tensor] = torch.zeros(1)
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
crop_bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = ""
sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1)
sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1)
frame_type: Union[str, List[str], None] = "" # known | unseen
meta: dict = field(default_factory=lambda: {})
valid_region: Optional[torch.Tensor] = None
category_one_hot: Optional[torch.Tensor] = None
def to(self, *args, **kwargs):
new_params = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[f.name] = value.to(*args, **kwargs)
else:
new_params[f.name] = value
return type(self)(**new_params)
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def __iter__(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
def __len__(self):
return len(fields(self))
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
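            # point clouds are shared between frames of the same sequence; deduplicate
            # them here and remember, per frame, which entry of the batched cloud to use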
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
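# Example (sketch): FrameData behaves both like a dataclass and a Mapping, so a
# collated batch can be moved between devices and unpacked into a call:
#
#   frame = dataset._get_frame(0)              # a single FrameData (`dataset` is a placeholder)
#   batch = FrameData.collate([frame, frame])  # batched FrameData
#   batch = batch.cuda()                       # tensors, cameras and point clouds to GPU
#   model(**batch)                             # Mapping protocol allows ** unpacking (`model` is a placeholder)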
# @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d
class CO3Dv2Wrapper(torch.utils.data.Dataset):
def __init__(
self,
root_dir="/drive/datasets/co3d/",
category="hydrant",
subset="fewview_train",
stage="train",
sample_batch_size=20,
image_size=256,
masked=False,
deprecated_val_region=False,
return_frame_data_list=False,
reso: int = 256,
mask_type: str = "random",
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
fps_id=0.0,
motion_bucket_id=300.0,
num_frames: int = 20,
use_mask: bool = True,
load_pixelnerf: bool = True,
scale_pose: bool = True,
max_n_cond: int = 5,
min_n_cond: int = 2,
cond_on_multi: bool = False,
):
root = root_dir
        from co3d.dataset.data_types import (
            FrameAnnotation,
            SequenceAnnotation,
            load_dataclass_jgzip,
        )
self.dataset_root = root
self.path_manager = None
self.subset = subset
self.stage = stage
self.subset_lists_file: List[str] = [
f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json"
]
self.subsets: Optional[List[str]] = [subset]
self.sample_batch_size = sample_batch_size
self.limit_to: int = 0
self.limit_sequences_to: int = 0
self.pick_sequence: Tuple[str, ...] = ()
self.exclude_sequence: Tuple[str, ...] = ()
self.limit_category_to: Tuple[int, ...] = ()
self.load_images: bool = True
self.load_depths: bool = False
self.load_depth_masks: bool = False
self.load_masks: bool = True
self.load_point_clouds: bool = False
self.max_points: int = 0
self.mask_images: bool = False
self.mask_depths: bool = False
self.image_height: Optional[int] = image_size
self.image_width: Optional[int] = image_size
self.box_crop: bool = True
self.box_crop_mask_thr: float = 0.4
self.box_crop_context: float = 0.3
self.remove_empty_masks: bool = True
self.n_frames_per_sequence: int = -1
self.seed: int = 0
self.sort_frames: bool = False
self.eval_batches: Any = None
self.img_h = self.image_height
self.img_w = self.image_width
self.masked = masked
self.deprecated_val_region = deprecated_val_region
self.return_frame_data_list = return_frame_data_list
self.reso = reso
self.num_frames = num_frames
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
self.fps_id = fps_id
self.motion_bucket_id = motion_bucket_id
self.mask_type = mask_type
self.use_mask = use_mask
self.load_pixelnerf = load_pixelnerf
self.scale_pose = scale_pose
self.max_n_cond = max_n_cond
self.min_n_cond = min_n_cond
self.cond_on_multi = cond_on_multi
if self.cond_on_multi:
assert self.min_n_cond == self.max_n_cond
start_time = time.time()
if "all_" in category or category == "all":
self.category_frame_annotations = []
self.category_sequence_annotations = []
self.subset_lists_file = []
if category == "all":
cats = CO3D_ALL_CATEGORIES
elif category == "all_four":
cats = ["hydrant", "teddybear", "motorcycle", "bench"]
elif category == "all_ten":
cats = [
"donut",
"apple",
"hydrant",
"vase",
"cake",
"ball",
"bench",
"suitcase",
"teddybear",
"plant",
]
elif category == "all_15":
cats = [
"hydrant",
"teddybear",
"motorcycle",
"bench",
"hotdog",
"remote",
"suitcase",
"donut",
"plant",
"toaster",
"keyboard",
"handbag",
"toyplane",
"tv",
"orange",
]
else:
print("UNSPECIFIED CATEGORY SUBSET")
cats = ["hydrant", "teddybear"]
print("loading", cats)
for cat in cats:
self.category_frame_annotations.extend(
load_dataclass_jgzip(
f"{self.dataset_root}/{cat}/frame_annotations.jgz",
List[FrameAnnotation],
)
)
self.category_sequence_annotations.extend(
load_dataclass_jgzip(
f"{self.dataset_root}/{cat}/sequence_annotations.jgz",
List[SequenceAnnotation],
)
)
self.subset_lists_file.append(
f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json"
)
else:
self.category_frame_annotations = load_dataclass_jgzip(
f"{self.dataset_root}/{category}/frame_annotations.jgz",
List[FrameAnnotation],
)
self.category_sequence_annotations = load_dataclass_jgzip(
f"{self.dataset_root}/{category}/sequence_annotations.jgz",
List[SequenceAnnotation],
)
self.subset_to_image_path = None
self._load_frames()
self._load_sequences()
self._sort_frames()
self._load_subset_lists()
self._filter_db() # also computes sequence indices
# self._extract_and_set_eval_batches()
# print(self.eval_batches)
logger.info(str(self))
self.seq_to_frames = {}
for fi, item in enumerate(self.frame_annots):
if item["frame_annotation"].sequence_name in self.seq_to_frames:
self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi)
else:
self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi]
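        # keep only sequences with more than 10 frames, except when evaluating on the
        # fewview_test subset at test time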
if self.stage != "test" or self.subset != "fewview_test":
count = 0
new_seq_to_frames = {}
for item in self.seq_to_frames:
if len(self.seq_to_frames[item]) > 10:
count += 1
new_seq_to_frames[item] = self.seq_to_frames[item]
self.seq_to_frames = new_seq_to_frames
self.seq_list = list(self.seq_to_frames.keys())
        # @ REMOVE A FEW TRAINING SEQUENCES THAT CAUSE BUGS
remove_list = ["411_55952_107659", "376_42884_85882"]
for remove_idx in remove_list:
if remove_idx in self.seq_to_frames:
self.seq_list.remove(remove_idx)
print("removing", remove_idx)
print("total training seq", len(self.seq_to_frames))
print("data loading took", time.time() - start_time, "seconds")
self.all_category_list = list(CO3D_ALL_CATEGORIES)
self.all_category_list.sort()
self.cat_to_idx = {}
for ci, cname in enumerate(self.all_category_list):
self.cat_to_idx[cname] = ci
def __len__(self):
return len(self.seq_list)
def __getitem__(self, index):
seq_index = self.seq_list[index]
if self.subset == "fewview_test" and self.stage == "test":
batch_idx = torch.arange(len(self.seq_to_frames[seq_index]))
elif self.stage == "test":
batch_idx = (
torch.linspace(
0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size
)
.long()
.tolist()
)
else:
rand = torch.randperm(len(self.seq_to_frames[seq_index]))
batch_idx = rand[: min(len(rand), self.sample_batch_size)]
frame_data_list = []
idx_list = []
timestamp_list = []
for idx in batch_idx:
idx_list.append(self.seq_to_frames[seq_index][idx])
timestamp_list.append(
self.frame_annots[self.seq_to_frames[seq_index][idx]][
"frame_annotation"
].frame_timestamp
)
frame_data_list.append(
self._get_frame(int(self.seq_to_frames[seq_index][idx]))
)
time_order = torch.argsort(torch.tensor(timestamp_list))
frame_data_list = [frame_data_list[i] for i in time_order]
frame_data = FrameData.collate(frame_data_list)
image_size = torch.Tensor([self.image_height]).repeat(
frame_data.camera.R.shape[0], 2
)
frame_dict = {
"R": frame_data.camera.R,
"T": frame_data.camera.T,
"f": frame_data.camera.focal_length,
"c": frame_data.camera.principal_point,
"images": frame_data.image_rgb * frame_data.fg_probability
+ (1 - frame_data.fg_probability),
"valid_region": frame_data.mask_crop,
"bbox": frame_data.valid_region,
"image_size": image_size,
"frame_type": frame_data.frame_type,
"idx": seq_index,
"category": frame_data.category_one_hot,
}
if not self.masked:
frame_dict["images_full"] = frame_data.image_rgb
frame_dict["masks"] = frame_data.fg_probability
frame_dict["mask_crop"] = frame_data.mask_crop
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
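        # if the sequence yields fewer than num_frames frames, pad by appending a
        # time-reversed copy and truncating to num_frames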
def _pad(input):
return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[
: self.num_frames
]
if len(frame_dict["images"]) < self.num_frames:
for k in frame_dict:
if isinstance(frame_dict[k], torch.Tensor):
frame_dict[k] = _pad(frame_dict[k])
data = dict()
if "images_full" in frame_dict:
frames = frame_dict["images_full"] * 2 - 1
else:
frames = frame_dict["images"] * 2 - 1
data["frames"] = frames
cond = frames[0]
data["cond_frames_without_noise"] = cond
data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
data["motion_bucket_id"] = torch.as_tensor(
[self.motion_bucket_id] * self.num_frames
)
data["num_video_frames"] = self.num_frames
data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
if self.load_pixelnerf:
data["pixelnerf_input"] = dict()
# Rs = frame_dict["R"].transpose(-1, -2)
# Ts = frame_dict["T"]
# Rs[:, :, 2] *= -1
# Rs[:, :, 0] *= -1
# Ts[:, 2] *= -1
# Ts[:, 0] *= -1
# c2ws = torch.zeros(Rs.shape[0], 4, 4)
# c2ws[:, :3, :3] = Rs
# c2ws[:, :3, 3] = Ts
# c2ws[:, 3, 3] = 1
# c2ws = c2ws.inverse()
# # c2ws[..., 0] *= -1
# # c2ws[..., 2] *= -1
# cx = frame_dict["c"][:, 0]
# cy = frame_dict["c"][:, 1]
# fx = frame_dict["f"][:, 0]
# fy = frame_dict["f"][:, 1]
# intrinsics = torch.zeros(cx.shape[0], 3, 3)
# intrinsics[:, 2, 2] = 1
# intrinsics[:, 0, 0] = fx
# intrinsics[:, 1, 1] = fy
# intrinsics[:, 0, 2] = cx
# intrinsics[:, 1, 2] = cy
scene_cameras = PerspectiveCameras(
R=frame_dict["R"],
T=frame_dict["T"],
focal_length=frame_dict["f"],
principal_point=frame_dict["c"],
image_size=frame_dict["image_size"],
)
R, T, intrinsics = opencv_from_cameras_projection(
scene_cameras, frame_dict["image_size"]
)
c2ws = torch.zeros(R.shape[0], 4, 4)
c2ws[:, :3, :3] = R
c2ws[:, :3, 3] = T
c2ws[:, 3, 3] = 1.0
c2ws = c2ws.inverse()
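            # flip the Y/Z camera axes (OpenCV -> OpenGL-style convention), normalize the
            # intrinsics by the 256 px crop size, and pack each camera as a 25-dim vector:
            # flattened 4x4 c2w (16 values) followed by the flattened 3x3 K (9 values)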
c2ws[..., 1:3] *= -1
intrinsics[:, :2] /= 256
cameras = torch.zeros(c2ws.shape[0], 25)
cameras[..., :16] = c2ws.reshape(-1, 16)
cameras[..., 16:] = intrinsics.reshape(-1, 9)
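            # optionally recenter the camera trajectory at the origin and rescale it so
            # that the farthest camera lies on a sphere of radius 1.5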
if self.scale_pose:
c2ws = cameras[..., :16].reshape(-1, 4, 4)
center = c2ws[:, :3, 3].mean(0)
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
scale = 1.5 / radius
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
cameras[..., :16] = c2ws.reshape(-1, 16)
data["pixelnerf_input"]["frames"] = frames
data["pixelnerf_input"]["cameras"] = cameras
data["pixelnerf_input"]["rgb"] = (
F.interpolate(
frames,
(self.image_width // 8, self.image_height // 8),
mode="bilinear",
align_corners=False,
)
+ 1
) * 0.5
return data
# if self.return_frame_data_list:
# return (frame_dict, frame_data_list)
# return frame_dict
def collate_fn(self, batch):
        # a hack: sample the number of conditioning views once so it is consistent within a batch
if self.max_n_cond > 1:
# TODO implement this
n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
# debug
# source_index = [0]
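            # view 0 is always a conditioning view; max_n_cond - 1 further views are
            # sampled per clip without replacement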
if n_cond > 1:
for b in batch:
source_index = [0] + np.random.choice(
np.arange(1, self.num_frames),
self.max_n_cond - 1,
replace=False,
).tolist()
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
b["pixelnerf_input"]["n_cond"] = n_cond
b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
"cameras"
][source_index]
if self.cond_on_multi:
b["cond_frames_without_noise"] = b["frames"][source_index]
ret = video_collate_fn(batch)
if self.cond_on_multi:
ret["cond_frames_without_noise"] = rearrange(
ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
)
return ret
def _get_frame(self, index):
# if index >= len(self.frame_annots):
# raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
# pyre-ignore[16]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
sequence_name=entry.sequence_name,
sequence_category=self.seq_annots[entry.sequence_name].category,
camera_quality_score=_safe_as_tensor(
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
torch.float,
),
point_cloud_quality_score=_safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
# The rest of the fields are optional
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
clamp_bbox_xyxy,
frame_data.crop_bbox_xywh,
) = self._load_crop_fg_probability(entry)
scale = 1.0
if self.load_images and entry.image is not None:
# original image size
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
(
frame_data.image_rgb,
frame_data.image_path,
frame_data.mask_crop,
scale,
) = self._load_crop_images(
entry, frame_data.fg_probability, clamp_bbox_xyxy
)
# print(frame_data.fg_probability.sum())
# print('scale', scale)
#! INSERT
        if self.deprecated_val_region:
            # normalize the crop bbox corners to [-1, 1] w.r.t. the original image size
            valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float()
            for i, dim in enumerate((1, 0, 1, 0)):
                half = torch.div(
                    frame_data.image_size_hw[dim], 2, rounding_mode="floor"
                )
                valid_bbox[i] = torch.clip((valid_bbox[i] - half) / half, -1.0, 1.0)
            frame_data.valid_region = valid_bbox
        else:
            #! UPDATED VALID BBOX
            if self.stage == "train":
                assert self.image_height == 256 and self.image_width == 256
            # bbox of the valid (non-padded) region of mask_crop, normalized to [-1, 1]
            valid = torch.nonzero(frame_data.mask_crop[0])
            min_y, min_x = valid[:, 0].min(), valid[:, 1].min()
            max_y, max_x = valid[:, 0].max(), valid[:, 1].max()
            valid_bbox = torch.tensor(
                [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
            ).unsqueeze(0)
            half = self.image_height // 2
            valid_bbox = torch.clip((valid_bbox - half) / half, -1.0, 1.0)
            frame_data.valid_region = valid_bbox[0]
#! SET CLASS ONEHOT
frame_data.category_one_hot = torch.zeros(
(len(self.all_category_list)), device=frame_data.image_rgb.device
)
frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1
if self.load_depths and entry.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
if entry.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(
entry,
scale,
clamp_bbox_xyxy,
)
if self.load_point_clouds and point_cloud is not None:
frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
self.dataset_root, point_cloud.path
)
frame_data.sequence_point_cloud = _load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
# for key in frame_data:
# if frame_data[key] == None:
# print(key)
return frame_data
def _extract_and_set_eval_batches(self):
"""
Sets eval_batches based on input eval_batch_index.
"""
if self.eval_batch_index is not None:
if self.eval_batches is not None:
raise ValueError(
"Cannot define both eval_batch_index and eval_batches."
)
self.eval_batches = self.seq_frame_index_to_dataset_index(
self.eval_batch_index
)
def _load_crop_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[
Optional[torch.Tensor],
Optional[str],
Optional[torch.Tensor],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
fg_probability = None
full_path = None
bbox_xywh = None
clamp_bbox_xyxy = None
crop_box_xywh = None
if (self.load_masks or self.box_crop) and entry.mask is not None:
full_path = os.path.join(self.dataset_root, entry.mask.path)
mask = _load_mask(self._local_path(full_path))
if mask.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
)
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
if self.box_crop:
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
_get_clamp_bbox(
bbox_xywh,
image_path=entry.image.path,
box_crop_context=self.box_crop_context,
),
image_size_hw=tuple(mask.shape[-2:]),
)
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
def _load_crop_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = _load_image(self._local_path(path))
if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if self.box_crop:
assert clamp_bbox_xyxy is not None
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability
return image_rgb, path, mask_crop, scale
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
clamp_bbox_xyxy: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
)
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability
if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = _load_depth_mask(self._local_path(mask_path))
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_mask_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
)
depth_mask = _crop_around_box(
depth_mask, depth_mask_bbox_xyxy, mask_path
)
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
else:
depth_mask = torch.ones_like(depth_map)
return depth_map, path, depth_mask
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
scale: float,
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
)
# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")
# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
if self.box_crop:
assert clamp_bbox_xyxy is not None
principal_point_px -= clamp_bbox_xyxy[:2]
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
if self.image_height is None or self.image_width is None:
out_size = list(reversed(entry.image.size))
else:
out_size = [self.image_width, self.image_height]
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()
# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _load_frames(self) -> None:
self.frame_annots = [
FrameAnnotsEntry(frame_annotation=a, subset=None)
for a in self.category_frame_annotations
]
def _load_sequences(self) -> None:
self.seq_annots = {
entry.sequence_name: entry for entry in self.category_sequence_annotations
}
def _load_subset_lists(self) -> None:
logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
if not self.subset_lists_file:
return
frame_path_to_subset = {}
for subset_list_file in self.subset_lists_file:
with open(self._local_path(subset_list_file), "r") as f:
subset_to_seq_frame = json.load(f)
#! PRINT SUBSET_LIST STATS
# if len(self.subset_lists_file) == 1:
# print('train frames', len(subset_to_seq_frame['train']))
# print('val frames', len(subset_to_seq_frame['val']))
# print('test frames', len(subset_to_seq_frame['test']))
for set_ in subset_to_seq_frame:
for _, _, path in subset_to_seq_frame[set_]:
if path in frame_path_to_subset:
frame_path_to_subset[path].add(set_)
else:
frame_path_to_subset[path] = {set_}
# pyre-ignore[16]
for frame in self.frame_annots:
frame["subset"] = frame_path_to_subset.get(
frame["frame_annotation"].image.path, None
)
            if frame["subset"] is None:
                warnings.warn(
                    "Subset lists are given but don't include "
                    + frame["frame_annotation"].image.path
                )
def _sort_frames(self) -> None:
# Sort frames to have them grouped by sequence, ordered by timestamp
# pyre-ignore[16]
self.frame_annots = sorted(
self.frame_annots,
key=lambda f: (
f["frame_annotation"].sequence_name,
f["frame_annotation"].frame_timestamp or 0,
),
)
def _filter_db(self) -> None:
if self.remove_empty_masks:
logger.info("Removing images with empty masks.")
# pyre-ignore[16]
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
mask = frame_annot.mask
if mask is None:
return False
if mask.mass is None:
raise ValueError(msg)
return mask.mass > 1
self.frame_annots = [
frame
for frame in self.frame_annots
if positive_mass(frame["frame_annotation"])
]
logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
# this has to be called after joining with categories!!
subsets = self.subsets
if subsets:
if not self.subset_lists_file:
raise ValueError(
"Subset filter is on but subset_lists_file was not given"
)
logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
# truncate the list of subsets to the valid one
self.frame_annots = [
entry
for entry in self.frame_annots
if (entry["subset"] is not None and self.stage in entry["subset"])
]
if len(self.frame_annots) == 0:
raise ValueError(f"There are no frames in the '{subsets}' subsets!")
self._invalidate_indexes(filter_seq_annots=True)
if len(self.limit_category_to) > 0:
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
# pyre-ignore[16]
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if entry.category in self.limit_category_to
}
# sequence filters
for prefix in ("pick", "exclude"):
orig_len = len(self.seq_annots)
attr = f"{prefix}_sequence"
arr = getattr(self, attr)
if len(arr) > 0:
logger.info(f"{attr}: {str(arr)}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if (name in arr) == (prefix == "pick")
}
logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
if self.limit_sequences_to > 0:
self.seq_annots = dict(
islice(self.seq_annots.items(), self.limit_sequences_to)
)
# retain only frames from retained sequences
self.frame_annots = [
f
for f in self.frame_annots
if f["frame_annotation"].sequence_name in self.seq_annots
]
self._invalidate_indexes()
if self.n_frames_per_sequence > 0:
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
# pyre-ignore[16]
for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences
seed = _seq_name_to_seed(seq) + self.seed
seq_idx_shuffled = random.Random(seed).sample(
sorted(seq_indices), len(seq_indices)
)
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
logger.info(
"... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
)
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
self._invalidate_indexes(filter_seq_annots=False)
# sequences are not decimated, so self.seq_annots is valid
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
logger.info(
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
)
self.frame_annots = self.frame_annots[: self.limit_to]
self._invalidate_indexes(filter_seq_annots=True)
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
# update _seq_to_idx and filter seq_meta according to frame_annots change
        # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
self._invalidate_seq_to_idx()
if filter_seq_annots:
# pyre-ignore[16]
self.seq_annots = {
k: v
for k, v in self.seq_annots.items()
# pyre-ignore[16]
if k in self._seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
# pyre-ignore[16]
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
# pyre-ignore[16]
self._seq_to_idx = seq_to_idx
def _resize_image(
self, image, mode="bilinear"
) -> Tuple[torch.Tensor, float, torch.Tensor]:
image_height, image_width = self.image_height, self.image_width
if image_height is None or image_width is None:
# skip the resizing
imre_ = torch.from_numpy(image)
return imre_, 1.0, torch.ones_like(imre_[:1])
# takes numpy array, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
torch.from_numpy(image)[None],
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
# pyre-fixme[19]: Expected 1 positional argument.
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
mask = torch.zeros(1, self.image_height, self.image_width)
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
return imre_, minscale, mask
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int]
) -> List[Tuple[int, float]]:
out: List[Tuple[int, float]] = []
for idx in idxs:
# pyre-ignore[16]
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
return entry["frame_annotation"].meta["frame_type"]
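# Example (sketch): using the wrapper directly, outside the Lightning datamodule.
# The dataset root below is a placeholder path.
#
#   dataset = CO3Dv2Wrapper(root_dir="/drive/datasets/co3d/", category="hydrant",
#                           stage="train", num_frames=20)
#   sample = dataset[0]     # dict with "frames", "cond_frames", "pixelnerf_input", ...
#   sample["frames"].shape  # (num_frames, 3, 256, 256), values in [-1, 1] with defaults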
class CO3DDataset(LightningDataModule):
def __init__(
self,
root_dir,
batch_size=2,
shuffle=True,
num_workers=10,
prefetch_factor=2,
category="hydrant",
**kwargs,
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.shuffle = shuffle
self.train_dataset = CO3Dv2Wrapper(
root_dir=root_dir,
stage="train",
category=category,
**kwargs,
)
self.test_dataset = CO3Dv2Wrapper(
root_dir=root_dir,
stage="test",
subset="fewview_dev",
category=category,
**kwargs,
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=self.train_dataset.collate_fn,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=self.test_dataset.collate_fn,
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=video_collate_fn,
)
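# Example (sketch): wiring the datamodule into training. Paths and sizes below are
# placeholders; the batch dict layout is produced by CO3Dv2Wrapper.collate_fn /
# video_collate_fn.
#
#   datamodule = CO3DDataset(root_dir="/drive/datasets/co3d/", category="hydrant",
#                            batch_size=2, num_workers=10)
#   batch = next(iter(datamodule.train_dataloader()))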