import gzip
import json
import os.path as osp
import random
import socket
import time
import warnings

import torch
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from pytorch3d.renderer import PerspectiveCameras
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from scipy import ndimage as nd

from diffusionsfm.utils.distortion import distort_image

HOSTNAME = socket.gethostname()

CO3D_DIR = "../co3d_data"  # update this
CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations")
# Images and depth maps share the same "co3d" root. Note that CO3D_DIR is
# re-pointed here, after the annotation directory has been derived from it.
CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d")

order_path = osp.join(
    CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json"
)
TRAINING_CATEGORIES = [
    "apple",
    "backpack",
    "banana",
    "baseballbat",
    "baseballglove",
    "bench",
    "bicycle",
    "bottle",
    "bowl",
    "broccoli",
    "cake",
    "car",
    "carrot",
    "cellphone",
    "chair",
    "cup",
    "donut",
    "hairdryer",
    "handbag",
    "hydrant",
    "keyboard",
    "laptop",
    "microwave",
    "motorcycle",
    "mouse",
    "orange",
    "parkingmeter",
    "pizza",
    "plant",
    "stopsign",
    "teddybear",
    "toaster",
    "toilet",
    "toybus",
    "toyplane",
    "toytrain",
    "toytruck",
    "tv",
    "umbrella",
    "vase",
    "wineglass",
]

TEST_CATEGORIES = [
    "ball",
    "book",
    "couch",
    "frisbee",
    "hotdog",
    "kite",
    "remote",
    "sandwich",
    "skateboard",
    "suitcase",
]

assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
def fill_depths(data, invalid=None):
    """Fill invalid depth pixels with the value of their nearest valid neighbor.

    Args:
        data: (B, H, W) tensor of depth maps.
        invalid: (B, H, W) mask that is 1 (or True) where the depth is invalid
            and 0 where it is valid.
    """
    data_list = []
    for i in range(data.shape[0]):
        data_item = data[i].numpy()
        # distance_transform_edt returns, for every pixel, the indices of the
        # nearest zero (i.e. valid) pixel in the invalid mask.
        ind = nd.distance_transform_edt(
            invalid[i], return_distances=False, return_indices=True
        )
        data_list.append(torch.tensor(data_item[tuple(ind)]))
    return torch.stack(data_list, dim=0)
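# Example usage (sketch, not part of the original module): given a batch of
# depth maps and a mask marking missing pixels, fill_depths copies the nearest
# valid value into every hole.
#
#   depths = torch.rand(2, 8, 8)
#   invalid = (depths < 0.1).numpy()   # 1/True where depth is missing
#   filled = fill_depths(depths, invalid=invalid)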
def full_scene_scale(batch):
    """Distance from the centroid of the camera centers to the farthest camera."""
    cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda")
    cc = cameras.get_camera_center()
    centroid = torch.mean(cc, dim=0)
    diffs = cc - centroid
    norms = torch.linalg.norm(diffs, dim=1)
    furthest_index = torch.argmax(norms).item()
    scale = norms[furthest_index].item()
    return scale
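# Example usage (sketch, not part of the original module; requires a CUDA
# device since the cameras are constructed on the GPU):
#
#   scale = full_scene_scale({"R": torch.eye(3).repeat(4, 1, 1), "T": torch.rand(4, 3)})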
def square_bbox(bbox, padding=0.0, astype=None, tight=False):
    """
    Computes a square bounding box with optional padding.

    Args:
        bbox: Bounding box in xyxy format (4,).
        padding: Fractional padding added to the half-extent.
        astype: Output dtype; defaults to the type of bbox[0].
        tight: If True, use the smaller extent (no black bars); otherwise use
            the larger extent so the whole box is covered.

    Returns:
        square_bbox in xyxy format (4,).
    """
    if astype is None:
        astype = type(bbox[0])
    bbox = np.array(bbox)
    center = (bbox[:2] + bbox[2:]) / 2
    extents = (bbox[2:] - bbox[:2]) / 2
    s = (min(extents) if tight else max(extents)) * (1 + padding)
    square_bbox = np.array(
        [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
        dtype=astype,
    )
    return square_bbox
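# Example usage (sketch, not part of the original module): a 100x40 box is
# expanded to a 100x100 square around the same center; with tight=True it is
# instead shrunk to 40x40.
#
#   square_bbox(np.array([10, 20, 110, 60]))              # -> [10, -10, 110, 90]
#   square_bbox(np.array([10, 20, 110, 60]), tight=True)  # -> [40, 20, 80, 60]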
def unnormalize_image(image, return_numpy=True, return_int=True):
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()

    if image.ndim == 3:
        if image.shape[0] == 3:
            image = image[None, ...]
        elif image.shape[2] == 3:
            image = image.transpose(2, 0, 1)[None, ...]
        else:
            raise ValueError(f"Unexpected image shape: {image.shape}")
    elif image.ndim == 4:
        if image.shape[1] == 3:
            pass
        elif image.shape[3] == 3:
            image = image.transpose(0, 3, 1, 2)
        else:
            raise ValueError(f"Unexpected batch image shape: {image.shape}")
    else:
        raise ValueError(f"Unsupported input shape: {image.shape}")

    # Undo ImageNet normalization.
    mean = np.array([0.485, 0.456, 0.406])[None, :, None, None]
    std = np.array([0.229, 0.224, 0.225])[None, :, None, None]
    image = image * std + mean

    if return_int:
        image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
    else:
        image = np.clip(image, 0.0, 1.0)

    if image.shape[0] == 1:
        image = image[0]

    if return_numpy:
        return image
    else:
        return torch.from_numpy(image)
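# Example usage (sketch, not part of the original module): convert a normalized
# CHW tensor back to a uint8 array for visualization.
#
#   img = torch.randn(3, 224, 224)   # stand-in for an ImageNet-normalized image
#   vis = unnormalize_image(img)     # uint8 numpy array, shape (3, 224, 224)
#   plt.imshow(vis.transpose(1, 2, 0))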
def unnormalize_image_for_vis(image):
    assert len(image.shape) == 5 and image.shape[2] == 3
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device)
    # Undo ImageNet normalization, then rescale from [0, 1] to [-1, 1].
    image = image * std + mean
    image = (image - 0.5) / 0.5
    return image
def _transform_intrinsic(image, bbox, principal_point, focal_length):
    # Rescale intrinsics to match bbox
    half_box = np.array([image.width, image.height]).astype(np.float32) / 2
    org_scale = min(half_box).astype(np.float32)

    # Pixel coordinates
    principal_point_px = half_box - (np.array(principal_point) * org_scale)
    focal_length_px = np.array(focal_length) * org_scale
    principal_point_px -= bbox[:2]
    new_bbox = (bbox[2:] - bbox[:2]) / 2
    new_scale = min(new_bbox)

    # NDC coordinates
    new_principal_ndc = (new_bbox - principal_point_px) / new_scale
    new_focal_ndc = focal_length_px / new_scale

    principal_point = torch.tensor(new_principal_ndc.astype(np.float32))
    focal_length = torch.tensor(new_focal_ndc.astype(np.float32))
    return principal_point, focal_length
def construct_camera_from_batch(batch, device):
    if isinstance(device, int):
        device = f"cuda:{device}"

    return PerspectiveCameras(
        R=batch["R"].reshape(-1, 3, 3),
        T=batch["T"].reshape(-1, 3),
        focal_length=batch["focal_lengths"].reshape(-1, 2),
        principal_point=batch["principal_points"].reshape(-1, 2),
        image_size=batch["image_sizes"].reshape(-1, 2),
        device=device,
    )
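# Example usage (sketch, not part of the original module; the batch keys and
# shapes below mirror the reshapes above, with focal lengths / principal points
# in NDC and image sizes in pixels):
#
#   batch = {
#       "R": torch.eye(3).repeat(2, 1, 1),        # (N, 3, 3)
#       "T": torch.zeros(2, 3),                   # (N, 3)
#       "focal_lengths": torch.ones(2, 2),
#       "principal_points": torch.zeros(2, 2),
#       "image_sizes": torch.full((2, 2), 224),
#   }
#   cameras = construct_camera_from_batch(batch, "cpu")  # an int maps to "cuda:{i}"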
def save_batch_images(images, fname):
    cmap = plt.get_cmap("hsv")
    num_frames = len(images)
    num_rows = len(images)
    num_cols = 4
    figsize = (num_cols * 2, num_rows * 2)

    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    axs = axs.flatten()

    for i in range(num_rows):
        for j in range(num_cols):
            if i < num_frames:
                axs[i * num_cols + j].imshow(unnormalize_image(images[i][j]))
                for s in ["bottom", "top", "left", "right"]:
                    axs[i * num_cols + j].spines[s].set_color(cmap(i / num_frames))
                    axs[i * num_cols + j].spines[s].set_linewidth(5)
                axs[i * num_cols + j].set_xticks([])
                axs[i * num_cols + j].set_yticks([])
            else:
                axs[i * num_cols + j].axis("off")

    plt.tight_layout()
    plt.savefig(fname)
def jitter_bbox(
    square_bbox,
    jitter_scale=(1.1, 1.2),
    jitter_trans=(-0.07, 0.07),
    direction_from_size=None,
):
    square_bbox = np.array(square_bbox.astype(float))
    s = np.random.uniform(jitter_scale[0], jitter_scale[1])
    tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2)

    # Jitter only one dimension if center cropping
    if direction_from_size is not None:
        if direction_from_size[0] > direction_from_size[1]:
            tx = 0
        else:
            ty = 0

    side_length = square_bbox[2] - square_bbox[0]
    center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length
    extent = side_length / 2 * s
    ul = center - extent
    lr = ul + 2 * extent
    return np.concatenate((ul, lr))
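# Example usage (sketch, not part of the original module): randomly rescale and
# shift a square crop box as a training-time augmentation.
#
#   sq = square_bbox(np.array([10, 20, 110, 60]))   # -> [10, -10, 110, 90]
#   jittered = jitter_bbox(sq, jitter_scale=(1.1, 1.2), jitter_trans=(-0.07, 0.07))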
class Co3dDataset(Dataset):
    def __init__(
        self,
        category=("all_train",),
        split="train",
        transform=None,
        num_images=2,
        img_size=224,
        mask_images=False,
        crop_images=True,
        co3d_dir=None,
        co3d_annotation_dir=None,
        precropped_images=False,
        apply_augmentation=True,
        normalize_cameras=True,
        no_images=False,
        sample_num=None,
        seed=0,
        load_extra_cameras=False,
        distort_image=False,
        load_depths=False,
        center_crop=False,
        depth_size=256,
        mask_holes=False,
        object_mask=True,
    ):
        """
        Args:
            category: CO3D categories to load; "all_train", "all_test", and
                "full" expand to the corresponding category lists.
            num_images: Number of images sampled from each sequence per batch.
            mask_images: If True, black out the background using object masks.
            apply_augmentation: If True, jitter the crop bounding boxes.
            load_extra_cameras: If True, load cameras for 8 frames while
                returning only num_images images.
            load_depths: If True, also load depth maps (and object masks).
            mask_holes: If True, fill invalid depth pixels from their nearest
                valid neighbors.
        """
        start_time = time.time()

        self.category = category
        self.split = split
        self.transform = transform
        self.num_images = num_images
        self.img_size = img_size
        self.mask_images = mask_images
        self.crop_images = crop_images
        self.precropped_images = precropped_images
        self.apply_augmentation = apply_augmentation
        self.normalize_cameras = normalize_cameras
        self.no_images = no_images
        self.sample_num = sample_num
        self.load_extra_cameras = load_extra_cameras
        self.distort = distort_image
        self.load_depths = load_depths
        self.center_crop = center_crop
        self.depth_size = depth_size
        self.mask_holes = mask_holes
        self.object_mask = object_mask

        if self.apply_augmentation:
            if self.center_crop:
                self.jitter_scale = (0.8, 1.1)
                self.jitter_trans = (0.0, 0.0)
            else:
                self.jitter_scale = (1.1, 1.2)
                self.jitter_trans = (-0.07, 0.07)
        else:
            # Note: if the model was trained with augmentation, augmentation
            # should still be applied at test time.
            self.jitter_scale = (1, 1)
            self.jitter_trans = (0.0, 0.0)

        if self.distort:
            self.k1_max = 1.0
            self.k2_max = 1.0

        if co3d_dir is not None:
            self.co3d_dir = co3d_dir
            self.co3d_annotation_dir = co3d_annotation_dir
        else:
            self.co3d_dir = CO3D_DIR
            self.co3d_annotation_dir = CO3D_ANNOTATION_DIR
            self.co3d_depth_dir = CO3D_DEPTH_DIR

        if isinstance(self.category, str):
            self.category = [self.category]

        if "all_train" in self.category:
            self.category = TRAINING_CATEGORIES
        if "all_test" in self.category:
            self.category = TEST_CATEGORIES
        if "full" in self.category:
            self.category = TRAINING_CATEGORIES + TEST_CATEGORIES
        self.category = sorted(self.category)
        self.is_single_category = len(self.category) == 1

        # Fix the random seeds
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        print(f"Co3d ({split}):")
        self.low_quality_translations = [
            "411_55952_107659",
            "427_59915_115716",
            "435_61970_121848",
            "112_13265_22828",
            "110_13069_25642",
            "165_18080_34378",
            "368_39891_78502",
            "391_47029_93665",
            "20_695_1450",
            "135_15556_31096",
            "417_57572_110680",
        ]  # Initialized with sequences with poor depth masks
        self.rotations = {}
        self.category_map = {}
        for c in tqdm(self.category):
            annotation_file = osp.join(
                self.co3d_annotation_dir, f"{c}_{self.split}.jgz"
            )
            with gzip.open(annotation_file, "r") as fin:
                annotation = json.loads(fin.read())

            counter = 0
            for seq_name, seq_data in annotation.items():
                counter += 1
                if len(seq_data) < self.num_images:
                    continue

                filtered_data = []
                self.category_map[seq_name] = c
                bad_seq = False
                for data in seq_data:
                    # Make sure translations are not ridiculous and rotations are valid
                    det = np.linalg.det(data["R"])
                    if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01:
                        bad_seq = True
                        self.low_quality_translations.append(seq_name)
                        break

                    # Ignore all unnecessary information.
                    filtered_data.append(
                        {
                            "filepath": data["filepath"],
                            "bbox": data["bbox"],
                            "R": data["R"],
                            "T": data["T"],
                            "focal_length": data["focal_length"],
                            "principal_point": data["principal_point"],
                        },
                    )

                if not bad_seq:
                    self.rotations[seq_name] = filtered_data

        self.sequence_list = list(self.rotations.keys())

        IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
        IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
        if self.transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(self.img_size, antialias=True),
                    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
                ]
            )

        self.transform_depth = transforms.Compose(
            [
                transforms.Resize(
                    self.depth_size,
                    antialias=False,
                    interpolation=transforms.InterpolationMode.NEAREST_EXACT,
                ),
            ]
        )

        print(
            f"Low quality translation sequences, not used: {self.low_quality_translations}"
        )
        print(f"Data size: {len(self)}")
        print(f"Data loading took {(time.time()-start_time)} seconds.")
    def __len__(self):
        return len(self.sequence_list)

    def __getitem__(self, index):
        num_to_load = self.num_images if not self.load_extra_cameras else 8
        sequence_name = self.sequence_list[index % len(self.sequence_list)]
        metadata = self.rotations[sequence_name]

        if self.sample_num is not None:
            with open(
                order_path.format(sample_num=self.sample_num, category=self.category[0])
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:num_to_load]
        else:
            # Sample with replacement when the sequence has fewer than 8 frames.
            replace = len(metadata) < 8
            ids = np.random.choice(len(metadata), num_to_load, replace=replace)

        return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load)
    def _get_scene_scale(self, sequence_name):
        n = len(self.rotations[sequence_name])
        R = torch.zeros(n, 3, 3)
        T = torch.zeros(n, 3)

        for i, ann in enumerate(self.rotations[sequence_name]):
            R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"])
            T[i, ...] = torch.tensor(self.rotations[sequence_name][i]["T"])

        cameras = PerspectiveCameras(R=R, T=T)
        cc = cameras.get_camera_center()
        centroid = torch.mean(cc, dim=0)
        diff = cc - centroid
        norm = torch.norm(diff, dim=1)
        scale = torch.max(norm).item()
        return scale
    def _crop_image(self, image, bbox):
        image_crop = transforms.functional.crop(
            image,
            top=bbox[1],
            left=bbox[0],
            height=bbox[3] - bbox[1],
            width=bbox[2] - bbox[0],
        )
        return image_crop
    def _transform_intrinsic(self, image, bbox, principal_point, focal_length):
        half_box = np.array([image.width, image.height]).astype(np.float32) / 2
        org_scale = min(half_box).astype(np.float32)

        # Pixel coordinates
        principal_point_px = half_box - (np.array(principal_point) * org_scale)
        focal_length_px = np.array(focal_length) * org_scale
        principal_point_px -= bbox[:2]
        new_bbox = (bbox[2:] - bbox[:2]) / 2
        new_scale = min(new_bbox)

        # NDC coordinates
        new_principal_ndc = (new_bbox - principal_point_px) / new_scale
        new_focal_ndc = focal_length_px / new_scale

        return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32)
    def get_data(
        self,
        index=None,
        sequence_name=None,
        ids=(0, 1),
        no_images=False,
        num_valid_frames=None,
        load_using_order=None,
    ):
        if load_using_order is not None:
            with open(
                order_path.format(sample_num=self.sample_num, category=self.category[0])
            ) as f:
                order = json.load(f)
            ids = order[sequence_name][:load_using_order]

        if sequence_name is None:
            index = index % len(self.sequence_list)
            sequence_name = self.sequence_list[index]
        metadata = self.rotations[sequence_name]
        category = self.category_map[sequence_name]

        # Read image & camera information from annotations
        annos = [metadata[i] for i in ids]
        images = []
        image_sizes = []
        PP = []
        FL = []
        crop_parameters = []
        filenames = []
        distortion_parameters = []
        depths = []
        depth_masks = []
        object_masks = []
        dino_images = []
        for anno in annos:
            filepath = anno["filepath"]

            if not no_images:
                image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB")
                image_size = image.size

                # Optionally mask images with black background
                if self.mask_images:
                    black_image = Image.new("RGB", image_size, (0, 0, 0))
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))
                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = Image.fromarray(np.array(mask) > 125)
                    image = Image.composite(image, black_image, mask)

                if self.object_mask:
                    mask_name = osp.basename(filepath.replace(".jpg", ".png"))
                    mask_path = osp.join(
                        self.co3d_dir, category, sequence_name, "masks", mask_name
                    )
                    mask = Image.open(mask_path).convert("L")

                    if mask.size != image_size:
                        mask = mask.resize(image_size)
                    mask = torch.from_numpy(np.array(mask) > 125)

                # Determine crop; ResNet wants square images
                bbox = np.array(anno["bbox"])
                good_bbox = ((bbox[2:] - bbox[:2]) > 30).all()
                bbox = (
                    anno["bbox"]
                    if not self.center_crop and good_bbox
                    else [0, 0, image.width, image.height]
                )
                # Distort image and bbox if desired
                if self.distort:
                    k1 = random.uniform(0, self.k1_max)
                    k2 = random.uniform(0, self.k2_max)

                    try:
                        image, bbox = distort_image(
                            image, np.array(bbox), k1, k2, modify_bbox=True
                        )
                    except Exception:
                        print("INFO:")
                        print(sequence_name)
                        print(index)
                        print(ids)
                        print(k1)
                        print(k2)

                    distortion_parameters.append(torch.FloatTensor([k1, k2]))

                bbox = square_bbox(np.array(bbox), tight=self.center_crop)
                if self.apply_augmentation:
                    bbox = jitter_bbox(
                        bbox,
                        jitter_scale=self.jitter_scale,
                        jitter_trans=self.jitter_trans,
                        direction_from_size=image.size if self.center_crop else None,
                    )
                bbox = np.around(bbox).astype(int)
                # Crop parameters
                crop_center = (bbox[:2] + bbox[2:]) / 2
                principal_point = torch.tensor(anno["principal_point"])
                focal_length = torch.tensor(anno["focal_length"])

                # Convert crop center to correspond to a "square" image
                width, height = image.size
                length = max(width, height)
                s = length / min(width, height)
                crop_center = crop_center + (length - np.array([width, height])) / 2

                # Convert to NDC
                cc = s - 2 * s * crop_center / length
                crop_width = 2 * s * (bbox[2] - bbox[0]) / length
                crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])

                # Crop and normalize image
                if not self.precropped_images:
                    image = self._crop_image(image, bbox)

                try:
                    image = self.transform(image)
                except Exception:
                    print("INFO:")
                    print(sequence_name)
                    print(index)
                    print(ids)
                    # k1/k2 only exist when distortion was applied
                    if self.distort:
                        print(k1)
                        print(k2)

                images.append(image[:, : self.img_size, : self.img_size])
                crop_parameters.append(crop_params)
                if self.load_depths:
                    # Open depth map
                    depth_name = osp.basename(
                        filepath.replace(".jpg", ".jpg.geometric.png")
                    )
                    depth_path = osp.join(
                        self.co3d_depth_dir,
                        category,
                        sequence_name,
                        "depths",
                        depth_name,
                    )
                    depth_pil = Image.open(depth_path)

                    # CO3D depth PNGs store float16 bits in a uint16 image;
                    # reinterpret the bytes and cast up to float32.
                    depth = torch.tensor(
                        np.frombuffer(
                            np.array(depth_pil, dtype=np.uint16), dtype=np.float16
                        )
                        .astype(np.float32)
                        .reshape((depth_pil.size[1], depth_pil.size[0]))
                    )

                    # Crop and resize as with images
                    if depth_pil.size != image_size:
                        # bbox may have the wrong scale
                        bbox = depth_pil.size[0] * bbox / image_size[0]
                        if self.object_mask:
                            assert mask.shape == depth.shape
                        bbox = np.around(bbox).astype(int)

                    depth = self._crop_image(depth, bbox)

                    # Resize
                    depth = self.transform_depth(depth.unsqueeze(0))[
                        0, : self.depth_size, : self.depth_size
                    ]
                    depths.append(depth)

                    if self.object_mask:
                        mask = self._crop_image(mask, bbox)
                        mask = self.transform_depth(mask.unsqueeze(0))[
                            0, : self.depth_size, : self.depth_size
                        ]
                        object_masks.append(mask)
                PP.append(principal_point)
                FL.append(focal_length)
                image_sizes.append(torch.tensor([self.img_size, self.img_size]))
                filenames.append(filepath)

        if not no_images:
            if self.load_depths:
                depths = torch.stack(depths)
                depth_masks = torch.logical_or(depths <= 0, depths.isinf())
                depth_masks = (~depth_masks).long()

                if self.object_mask:
                    object_masks = torch.stack(object_masks, dim=0)

                if self.mask_holes:
                    depths = fill_depths(depths, depth_masks == 0)

                    # Filling can still leave invalid values behind; mark them with -1
                    new_masks = torch.logical_or(depths <= 0, depths.isinf())
                    new_masks = (~new_masks).long()
                    depths[new_masks == 0] = -1

                assert torch.logical_or(depths > 0, depths == -1).all()
                assert not (depths.isinf()).any()
                assert not (depths.isnan()).any()

            if self.load_extra_cameras:
                # Remove the extra loaded images to save space
                images = images[: self.num_images]

            if self.distort:
                distortion_parameters = torch.stack(distortion_parameters)

            images = torch.stack(images)
            crop_parameters = torch.stack(crop_parameters)
            focal_lengths = torch.stack(FL)
            principal_points = torch.stack(PP)
            image_sizes = torch.stack(image_sizes)
        else:
            images = None
            crop_parameters = None
            distortion_parameters = None
            focal_lengths = []
            principal_points = []
            image_sizes = []
        # Assemble batch info to send back
        R = torch.stack([torch.tensor(anno["R"]) for anno in annos])
        T = torch.stack([torch.tensor(anno["T"]) for anno in annos])

        batch = {
            "model_id": sequence_name,
            "category": category,
            "n": len(metadata),
            "num_valid_frames": num_valid_frames,
            "ind": torch.tensor(ids),
            "image": images,
            "depth": depths,
            "depth_masks": depth_masks,
            "object_masks": object_masks,
            "R": R,
            "T": T,
            "focal_length": focal_lengths,
            "principal_point": principal_points,
            "image_size": image_sizes,
            "crop_parameters": crop_parameters,
            "distortion_parameters": torch.zeros(4),
            "filename": filenames,
            "dataset": "co3d",
        }

        return batch
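# Example usage (sketch, not part of the original module; assumes the CO3D
# annotations and images are available under CO3D_ANNOTATION_DIR / CO3D_DIR):
#
#   dataset = Co3dDataset(category=("hydrant",), split="train", num_images=4)
#   batch = dataset[0]
#   batch["image"].shape            # (4, 3, 224, 224)
#   batch["R"].shape                # (4, 3, 3)
#   batch["crop_parameters"].shape  # (4, 4)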