# V3D: sgm/data/objaverse.py
import numpy as np
from pathlib import Path
from PIL import Image
import json
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, default_collate
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from torchvision.transforms.functional import to_tensor
from pytorch_lightning import LightningDataModule
from einops import rearrange
def read_camera_matrix_single(json_file):
    """Read a single gobjaverse camera JSON into a (3, 4) c2w matrix.

    The y and z axes are negated to convert from OpenCV to OpenGL convention.
    """
    with open(json_file, "r", encoding="utf8") as reader:
        json_content = json.load(reader)
    camera_matrix = torch.zeros(3, 4)
    camera_matrix[:3, 0] = torch.tensor(json_content["x"])
    camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
    camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
    camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
    return camera_matrix
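

# A minimal, illustrative sketch (not part of the original pipeline) of the
# per-view JSON layout this reader expects: "x"/"y"/"z" camera axes plus an
# "origin" translation, as used by the gobjaverse renders. The temporary file
# and the axis values below are made up for demonstration.
def _demo_read_camera_matrix():
    import tempfile

    demo_cam = {
        "x": [1.0, 0.0, 0.0],
        "y": [0.0, 1.0, 0.0],
        "z": [0.0, 0.0, 1.0],
        "origin": [0.0, -2.0, 0.5],
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(demo_cam, f)
    print(read_camera_matrix_single(f.name))  # (3, 4), y/z columns negated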
def read_camera_intrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
    """Read per-view FOVs and return a (3, 2) intrinsics tensor laid out as
    [[fx, fy], [cx, cy], [w, h]].
    """
    with open(json_file, "r", encoding="utf8") as reader:
        json_content = json.load(reader)
    h = int(h * scale)
    w = int(w * scale)
    y_fov = json_content["y_fov"]
    x_fov = json_content["x_fov"]
    fy = h / 2 / np.tan(y_fov / 2)
    fx = w / 2 / np.tan(x_fov / 2)
    cx = w // 2
    cy = h // 2
    intrinsics = torch.tensor(
        [
            [fx, fy],
            [cx, cy],
            [w, h],
        ],
        dtype=torch.float32,
    )
    return intrinsics
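

# Hedged arithmetic check (illustrative values only): the pinhole relation
# used above gives fy = h / (2 * tan(y_fov / 2)), so a 256 px image with a
# 0.8 rad vertical FOV yields a focal length of roughly 302.7 px.
def _demo_intrinsics_math():
    h, y_fov = 256, 0.8
    fy = h / 2 / np.tan(y_fov / 2)
    print(f"fy = {fy:.1f} px")  # ~302.7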
def compose_extrinsic_RT(RT: torch.Tensor):
"""
Compose the standard form extrinsic matrix from RT.
Batched I/O.
"""
return torch.cat(
[
RT,
torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(
RT.shape[0], 1, 1
),
],
dim=1,
)
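

# Quick shape check (illustrative only): compose_extrinsic_RT pads a batch of
# (3, 4) RT matrices with the constant homogeneous row [0, 0, 0, 1], giving
# standard (4, 4) extrinsics.
def _demo_compose_extrinsic():
    rt = torch.rand(5, 3, 4)
    assert compose_extrinsic_RT(rt).shape == (5, 4, 4)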
def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
"""
intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
Return batched fx, fy, cx, cy
"""
fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
fx, fy = fx / width, fy / height
cx, cy = cx / width, cy / height
return fx, fy, cx, cy
def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
"""
RT: (N, 3, 4)
intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
"""
E = compose_extrinsic_RT(RT)
fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
I = torch.stack(
[
torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1),
],
dim=1,
)
return torch.cat(
[
E.reshape(-1, 16),
I.reshape(-1, 9),
],
dim=-1,
)
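

# Illustrative sketch of the packed camera layout produced above: each camera
# becomes a 25-dim vector, the flattened 4x4 extrinsic (16 values) followed by
# the flattened normalized 3x3 intrinsic (9 values). The numbers are made up.
def _demo_build_camera_standard():
    n = 4
    rt = torch.rand(n, 3, 4)
    # [[fx, fy], [cx, cy], [w, h]] for a square 256 px render
    intr = torch.tensor([[128.0, 128.0], [128.0, 128.0], [256.0, 256.0]])
    cameras = build_camera_standard(rt, intr.expand(n, 3, 2))
    assert cameras.shape == (n, 25)  # 16 extrinsic + 9 intrinsic entries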
def calc_elevation(c2w):
    """Elevation angle of the camera position; works for a single or batched
    c2w and assumes the world up axis is (0, 0, 1)."""
    pos = c2w[..., :3, 3]
    return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))
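

# Sanity sketch for the elevation formula: a camera straight above the origin
# at (0, 0, 1) has elevation pi/2, and positions below z = 0 come out negative.
def _demo_calc_elevation():
    c2w = np.eye(4)
    c2w[:3, 3] = [0.0, 0.0, 1.0]
    print(calc_elevation(c2w))  # ~1.5708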
def blend_white_bg(image):
new_image = Image.new("RGB", image.size, (255, 255, 255))
new_image.paste(image, mask=image.split()[3])
return new_image
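

# Illustrative check: blend_white_bg composites an RGBA render onto white via
# its alpha channel, so a fully transparent pixel comes out pure white.
def _demo_blend_white_bg():
    rgba = Image.new("RGBA", (2, 2), (255, 0, 0, 0))  # transparent red
    assert blend_white_bg(rgba).getpixel((0, 0)) == (255, 255, 255)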
def flatten_for_video(tensor):
    return tensor.flatten()
FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]
def video_collate_fn(batch: list[dict], *args, **kwargs):
out = {}
for key in batch[0].keys():
if key in FLATTEN_FIELDS:
out[key] = default_collate([item[key] for item in batch])
out[key] = flatten_for_video(out[key])
elif key == "num_video_frames":
out[key] = batch[0][key]
elif key in ["frames", "latents", "rgb"]:
out[key] = default_collate([item[key] for item in batch])
out[key] = rearrange(out[key], "b t c h w -> (b t) c h w")
else:
out[key] = default_collate([item[key] for item in batch])
if "pixelnerf_input" in out:
out["pixelnerf_input"]["rgb"] = rearrange(
out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
)
return out
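

# Hedged usage sketch with made-up shapes: per-sample "frames" come in as
# (T, C, H, W) and are merged into a single (B*T, C, H, W) batch, while
# per-frame scalar fields such as "cond_aug" are flattened to one value per
# output frame.
def _demo_video_collate():
    sample = {
        "frames": torch.rand(24, 3, 64, 64),
        "cond_aug": torch.zeros(24),
        "num_video_frames": 24,
    }
    out = video_collate_fn([sample, sample])
    assert out["frames"].shape == (48, 3, 64, 64)
    assert out["cond_aug"].shape == (48,)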
class GObjaverse(Dataset):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
fps_id=0.0,
motion_bucket_id=300.0,
use_latents=False,
load_caps=False,
front_view_selection="random",
load_pixelnerf=False,
debug_base_idx=None,
scale_pose: bool = False,
max_n_cond: int = 1,
**unused_kwargs,
):
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.transform = transform
self.use_latents = use_latents
self.ids = json.load(open(self.root_dir / "valid_uids.json", "r"))
self.n_views = 24
self.load_caps = load_caps
if self.load_caps:
self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r"))
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
self.fps_id = fps_id
self.motion_bucket_id = motion_bucket_id
self.load_pixelnerf = load_pixelnerf
self.scale_pose = scale_pose
self.max_n_cond = max_n_cond
if self.use_latents:
self.latents_dir = self.root_dir / "latents256"
self.clip_dir = self.root_dir / "clip_emb256"
self.front_view_selection = front_view_selection
if self.front_view_selection == "random":
pass
elif self.front_view_selection == "fixed":
pass
elif self.front_view_selection.startswith("clip_score"):
self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
self.ids = list(self.clip_scores.keys())
else:
raise ValueError(
f"Unknown front view selection method {self.front_view_selection}"
)
if max_item is not None:
self.ids = self.ids[:max_item]
            # debug: repeat the truncated id list so a debug epoch stays long
            self.ids = self.ids * 10000
if debug_base_idx is not None:
print(f"debug mode with base idx: {debug_base_idx}")
self.debug_base_idx = debug_base_idx
def __getitem__(self, idx: int):
if hasattr(self, "debug_base_idx"):
idx = (idx + self.debug_base_idx) % len(self.ids)
data = {}
idx_list = np.arange(self.n_views)
if self.front_view_selection == "random":
roll_idx = np.random.randint(self.n_views)
idx_list = np.roll(idx_list, roll_idx)
elif self.front_view_selection == "fixed":
pass
elif self.front_view_selection == "clip_score_softmax":
this_clip_score = (
F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
)
roll_idx = np.random.choice(idx_list, p=this_clip_score)
idx_list = np.roll(idx_list, roll_idx)
elif self.front_view_selection == "clip_score_max":
this_clip_score = (
F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
)
roll_idx = np.argmax(this_clip_score)
idx_list = np.roll(idx_list, roll_idx)
frames = []
if not self.use_latents:
try:
for view_idx in idx_list:
frame = Image.open(
self.root_dir
/ "gobjaverse"
/ self.ids[idx]
/ f"{view_idx:05d}/{view_idx:05d}.png"
)
frames.append(self.transform(frame))
            except Exception:
                # Workaround for occasionally broken items in gobjaverse: fall
                # back to idx=0. The repeated item is resolved when gathering
                # results; the valid number of items is len(results).
                idx = 0
                frames = []
                for view_idx in idx_list:
                    frame = Image.open(
                        self.root_dir
                        / "gobjaverse"
                        / self.ids[idx]
                        / f"{view_idx:05d}/{view_idx:05d}.png"
                    )
                    frames.append(self.transform(frame))
frames = torch.stack(frames, dim=0)
cond = frames[0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
data.update(
{
"frames": frames,
"cond_frames_without_noise": cond,
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
"motion_bucket_id": torch.as_tensor(
[self.motion_bucket_id] * self.n_views
),
"num_video_frames": 24,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
)
else:
latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list]
clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0]
cond = latents[0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
data.update(
{
"latents": latents,
"cond_frames_without_noise": clip_emb,
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
"motion_bucket_id": torch.as_tensor(
[self.motion_bucket_id] * self.n_views
),
"num_video_frames": 24,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
)
        if self.condition_on_elevation:
            sample_c2w = read_camera_matrix_single(
                self.root_dir / "gobjaverse" / self.ids[idx] / "00000/00000.json"
            )
            elevation = calc_elevation(sample_c2w)
            data["elevation"] = torch.as_tensor([elevation] * self.n_views)
        if self.load_pixelnerf:
            assert "frames" in data, "pixelnerf cannot work with latents-only mode"
data["pixelnerf_input"] = {}
RTs = []
intrinsics = []
for view_idx in idx_list:
meta = (
self.root_dir
/ "gobjaverse"
/ self.ids[idx]
/ f"{view_idx:05d}/{view_idx:05d}.json"
)
RTs.append(read_camera_matrix_single(meta)[:3])
                intrinsics.append(read_camera_intrinsics_single(meta, 256, 256))
RTs = torch.stack(RTs, dim=0)
intrinsics = torch.stack(intrinsics, dim=0)
cameras = build_camera_standard(RTs, intrinsics)
data["pixelnerf_input"]["cameras"] = cameras
downsampled = []
for view_idx in idx_list:
frame = Image.open(
self.root_dir
/ "gobjaverse"
/ self.ids[idx]
/ f"{view_idx:05d}/{view_idx:05d}.png"
).resize((32, 32))
downsampled.append(to_tensor(blend_white_bg(frame)))
data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0)
data["pixelnerf_input"]["frames"] = data["frames"]
if self.scale_pose:
c2ws = cameras[..., :16].reshape(-1, 4, 4)
center = c2ws[:, :3, 3].mean(0)
radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
scale = 1.5 / radius
c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
cameras[..., :16] = c2ws.reshape(-1, 16)
if self.load_caps:
data["caption"] = self.caps[self.ids[idx]]
data["ids"] = self.ids[idx]
return data
def __len__(self):
return len(self.ids)
def collate_fn(self, batch):
if self.max_n_cond > 1:
n_cond = np.random.randint(1, self.max_n_cond + 1)
if n_cond > 1:
for b in batch:
source_index = [0] + np.random.choice(
np.arange(1, self.n_views),
self.max_n_cond - 1,
replace=False,
).tolist()
b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
b["pixelnerf_input"]["n_cond"] = n_cond
b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
"cameras"
][source_index]
return video_collate_fn(batch)
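

# Hedged usage sketch for GObjaverse; the root directory below is a
# placeholder and must contain valid_uids.json plus the
# gobjaverse/<uid>/<view:05d>/<view:05d>.png render tree assumed by
# __getitem__ above.
def _demo_gobjaverse_loader():
    dataset = GObjaverse(
        root_dir="/path/to/gobjaverse_root",  # placeholder path
        transform=Compose([blend_white_bg, Resize((256, 256)), ToTensor()]),
        load_pixelnerf=True,
        max_n_cond=2,
    )
    return DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)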
class ObjaverseSpiral(Dataset):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
**unused_kwargs,
):
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.transform = transform
self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r"))
self.n_views = 24
valid_ids = []
for idx in self.ids:
if (self.root_dir / idx).exists():
valid_ids.append(idx)
self.ids = valid_ids
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
if max_item is not None:
self.ids = self.ids[:max_item]
            # debug: repeat the truncated id list so a debug epoch stays long
            self.ids = self.ids * 10000
def __getitem__(self, idx: int):
frames = []
idx_list = np.arange(self.n_views)
if self.random_front:
roll_idx = np.random.randint(self.n_views)
idx_list = np.roll(idx_list, roll_idx)
for view_idx in idx_list:
frame = Image.open(
self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png"
)
frames.append(self.transform(frame))
        frames = torch.stack(frames, dim=0)  # (T, C, H, W)
cond = frames[0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
data = {
"frames": frames,
"cond_frames_without_noise": cond,
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([1.0] * self.n_views),
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
"num_video_frames": 24,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
if self.condition_on_elevation:
sample_c2w = read_camera_matrix_single(
                self.root_dir / self.ids[idx] / "00000/00000.json"
)
elevation = calc_elevation(sample_c2w)
data["elevation"] = torch.as_tensor([elevation] * self.n_views)
return data
def __len__(self):
return len(self.ids)
class ObjaverseLVISSpiral(Dataset):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
use_precomputed_latents=False,
**unused_kwargs,
):
print("Using LVIS subset")
self.root_dir = Path(root_dir)
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
self.split = split
self.random_front = random_front
self.transform = transform
self.use_precomputed_latents = use_precomputed_latents
self.ids = json.load(open("./assets/lvis_uids.json", "r"))
self.n_views = 18
valid_ids = []
for idx in self.ids:
if (self.root_dir / idx).exists():
valid_ids.append(idx)
self.ids = valid_ids
print("=" * 30)
print("Number of valid ids: ", len(self.ids))
print("=" * 30)
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
if max_item is not None:
self.ids = self.ids[:max_item]
            # debug: repeat the truncated id list so a debug epoch stays long
            self.ids = self.ids * 10000
def __getitem__(self, idx: int):
frames = []
idx_list = np.arange(self.n_views)
if self.random_front:
roll_idx = np.random.randint(self.n_views)
idx_list = np.roll(idx_list, roll_idx)
for view_idx in idx_list:
frame = Image.open(
self.root_dir
/ self.ids[idx]
/ "elevations_0"
/ f"colors_{view_idx * 2}.png"
)
frames.append(self.transform(frame))
frames = torch.stack(frames, dim=0)
cond = frames[0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
data = {
"frames": frames,
"cond_frames_without_noise": cond,
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([0.0] * self.n_views),
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
"num_video_frames": self.n_views,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
if self.use_precomputed_latents:
data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
        if self.condition_on_elevation:
            raise NotImplementedError(
                "conditioning on elevation is not supported; these renders assume elevation 0"
            )
return data
def __len__(self):
return len(self.ids)
class ObjaverseALLSpiral(ObjaverseLVISSpiral):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
use_precomputed_latents=False,
**unused_kwargs,
):
print("Using ALL objects in Objaverse")
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.transform = transform
self.use_precomputed_latents = use_precomputed_latents
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
self.ids = json.load(open("./assets/all_ids.json", "r"))
self.n_views = 18
valid_ids = []
for idx in self.ids:
if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
valid_ids.append(idx)
self.ids = valid_ids
print("=" * 30)
print("Number of valid ids: ", len(self.ids))
print("=" * 30)
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
if max_item is not None:
self.ids = self.ids[:max_item]
            # debug: repeat the truncated id list so a debug epoch stays long
            self.ids = self.ids * 10000
class ObjaverseWithPose(Dataset):
def __init__(
self,
root_dir,
split="train",
transform=None,
random_front=False,
max_item=None,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
condition_on_elevation=False,
use_precomputed_latents=False,
**unused_kwargs,
):
print("Using Objaverse with poses")
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.transform = transform
self.use_precomputed_latents = use_precomputed_latents
self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
self.ids = json.load(open("./assets/all_ids.json", "r"))
self.n_views = 18
valid_ids = []
for idx in self.ids:
if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
valid_ids.append(idx)
self.ids = valid_ids
print("=" * 30)
print("Number of valid ids: ", len(self.ids))
print("=" * 30)
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
self.condition_on_elevation = condition_on_elevation
def __getitem__(self, idx: int):
frames = []
idx_list = np.arange(self.n_views)
if self.random_front:
roll_idx = np.random.randint(self.n_views)
idx_list = np.roll(idx_list, roll_idx)
for view_idx in idx_list:
frame = Image.open(
self.root_dir
/ self.ids[idx]
/ "elevations_0"
/ f"colors_{view_idx * 2}.png"
)
frames.append(self.transform(frame))
frames = torch.stack(frames, dim=0)
cond = frames[0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
data = {
"frames": frames,
"cond_frames_without_noise": cond,
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([0.0] * self.n_views),
"motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
"num_video_frames": self.n_views,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
if self.use_precomputed_latents:
data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
        if self.condition_on_elevation:
            raise NotImplementedError(
                "conditioning on elevation is not supported; these renders assume elevation 0"
            )
        return data

    def __len__(self):
        return len(self.ids)
class LatentObjaverse(Dataset):
def __init__(
self,
root_dir,
split="train",
random_front=False,
subset="lvis",
fps_id=1.0,
motion_bucket_id=300.0,
cond_aug_mean=-3.0,
cond_aug_std=0.5,
**unused_kwargs,
):
self.root_dir = Path(root_dir)
self.split = split
self.random_front = random_front
self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r"))
self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
self.n_views = 18
self.fps_id = fps_id
self.motion_bucket_id = motion_bucket_id
self.cond_aug_mean = cond_aug_mean
self.cond_aug_std = cond_aug_std
if self.random_front:
print("Using a random view as front view")
valid_ids = []
for idx in self.ids:
if (self.root_dir / f"{idx}.pt").exists() and (
self.clip_emb_dir / f"{idx}.pt"
).exists():
valid_ids.append(idx)
self.ids = valid_ids
print("=" * 30)
print("Number of valid ids: ", len(self.ids))
print("=" * 30)
def __getitem__(self, idx: int):
uid = self.ids[idx]
idx_list = torch.arange(self.n_views)
latents = torch.load(self.root_dir / f"{uid}.pt")
clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
if self.random_front:
idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
latents = latents[idx_list]
clip_emb = clip_emb[idx_list][0]
cond_aug = np.exp(
np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
)
cond = latents[0]
data = {
"latents": latents,
"cond_frames_without_noise": clip_emb,
"cond_frames": cond + cond_aug * torch.randn_like(cond),
"fps_id": torch.as_tensor([self.fps_id] * self.n_views),
"motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views),
"cond_aug": torch.as_tensor([cond_aug] * self.n_views),
"num_video_frames": self.n_views,
"image_only_indicator": torch.as_tensor([0.0] * self.n_views),
}
return data
def __len__(self):
return len(self.ids)
class ObjaverseSpiralDataset(LightningDataModule):
def __init__(
self,
root_dir,
random_front=False,
batch_size=2,
num_workers=10,
prefetch_factor=2,
shuffle=True,
max_item=None,
dataset_cls="richdreamer",
reso: int = 256,
**kwargs,
) -> None:
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.shuffle = shuffle
self.max_item = max_item
self.transform = Compose(
[
blend_white_bg,
Resize((reso, reso)),
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
data_cls = {
"richdreamer": ObjaverseSpiral,
"lvis": ObjaverseLVISSpiral,
"shengshu_all": ObjaverseALLSpiral,
"latent": LatentObjaverse,
"gobjaverse": GObjaverse,
}[dataset_cls]
self.train_dataset = data_cls(
root_dir=root_dir,
split="train",
random_front=random_front,
transform=self.transform,
max_item=self.max_item,
**kwargs,
)
self.test_dataset = data_cls(
root_dir=root_dir,
split="val",
random_front=random_front,
transform=self.transform,
max_item=self.max_item,
**kwargs,
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=video_collate_fn
if not hasattr(self.train_dataset, "collate_fn")
else self.train_dataset.collate_fn,
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=video_collate_fn
if not hasattr(self.test_dataset, "collate_fn")
            else self.test_dataset.collate_fn,
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
num_workers=self.num_workers,
prefetch_factor=self.prefetch_factor,
collate_fn=video_collate_fn
if not hasattr(self.test_dataset, "collate_fn")
            else self.test_dataset.collate_fn,
)
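

# Hedged end-to-end sketch: the root_dir is a placeholder, and "gobjaverse"
# selects the GObjaverse dataset defined above. num_workers should be positive
# because the module always forwards prefetch_factor to the DataLoader, which
# recent PyTorch versions reject when num_workers == 0.
def _demo_datamodule():
    dm = ObjaverseSpiralDataset(
        root_dir="/path/to/renders",  # placeholder path
        dataset_cls="gobjaverse",
        batch_size=2,
        num_workers=2,
        reso=256,
    )
    return dm.train_dataloader()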