heheyas
init
cfb7702
raw
history blame
14.6 kB
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, default_collate
from pathlib import Path
from PIL import Image
from scipy.spatial.transform import Rotation
import rembg
from rembg import remove, new_session
from einops import rearrange
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from torchvision.transforms.functional import to_tensor
from pytorch_lightning import LightningDataModule
from sgm.data.colmap import read_cameras_binary, read_images_binary
from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video
def qvec2rotmat(qvec):
    """Return the 3x3 rotation matrix for a scalar-first quaternion.

    ``qvec`` is read as ``(w, x, y, z) = (qvec[0], qvec[1], qvec[2], qvec[3])``.
    No normalization is performed, so the result is a proper rotation only
    when the input is a unit quaternion.
    """
    w, x, y, z = (qvec[i] for i in range(4))
    first_row = [
        1 - 2 * y ** 2 - 2 * z ** 2,
        2 * x * y - 2 * w * z,
        2 * z * x + 2 * w * y,
    ]
    second_row = [
        2 * x * y + 2 * w * z,
        1 - 2 * x ** 2 - 2 * z ** 2,
        2 * y * z - 2 * w * x,
    ]
    third_row = [
        2 * z * x - 2 * w * y,
        2 * y * z + 2 * w * x,
        1 - 2 * x ** 2 - 2 * y ** 2,
    ]
    return np.array([first_row, second_row, third_row])
def qt2c2w(q, t):
    """Invert a rigid pose ``(q, t)`` into a 4x4 camera-to-world matrix.

    The pose is interpreted as ``x_cam = R(q) @ x_world + t`` with ``R`` from
    :func:`qvec2rotmat`. The inverse is ``[R^T | -R^T t]``. The camera's
    y and z axes (columns 1 and 2) are then negated to move to an OpenGL-style
    camera convention.
    """
    world_to_cam_rot = qvec2rotmat(q)
    cam_to_world_rot = world_to_cam_rot.T
    pose = np.eye(4)
    pose[:3, :3] = cam_to_world_rot
    pose[:3, 3] = -cam_to_world_rot @ t
    # Flip camera y/z axes. The bottom row is [0, 0, 0, 1], so it is unaffected.
    pose[:, 1:3] *= -1
    return pose
def random_crop():
    """Unimplemented random-crop stub; always returns ``None``."""
    return None
class MVImageNet(Dataset):
    """Multi-view sample dataset over an MVImgNet-style directory tree.

    Every two-level subdirectory ``root_dir/<a>/<b>`` is one scene. A scene is
    read from ``images/`` plus a COLMAP reconstruction in ``sparse/0/``
    (``images.bin``, and ``cameras.bin`` when ``load_pixelnerf`` is set).
    Each item holds ``num_frames`` square-cropped, resized views plus
    conditioning fields (see :meth:`__getitem__`).

    Args:
        root_dir: Dataset root.
        split: Stored on the instance but not used to select scenes, so every
            split sees the same ``ids``.
        transform: Callable applied to each resized PIL frame. The downsampled
            RGB target is computed as ``(transform(img) + 1) / 2``, so this
            presumably maps into ``[-1, 1]`` (confirm with the caller).
        reso: Side length the square crop is resized to.
        mask_type: ``"random"`` (random square crop), ``"rembg"`` (square crop
            centred on the rembg foreground, cached as ``*_rembg.png`` next to
            the image) or ``"object"`` (not implemented; raises).
        cond_aug_mean, cond_aug_std: Mean/std of the Gaussian whose ``exp`` is
            the noise level added to the conditioning frame.
        condition_on_elevation, use_mask: Stored only; not read by this class.
        fps_id, motion_bucket_id: Constants repeated once per frame.
        num_frames: Views per item. Scenes with fewer usable views are padded
            by mirroring their tail.
        load_pixelnerf: Also emit cameras and downsampled RGB under
            ``"pixelnerf_input"``.
        scale_pose: Recentre camera positions and scale the farthest one to
            radius 1.5.
        max_n_cond, min_n_cond: Range for the number of conditioning views
            sampled per batch in :meth:`collate_fn`.
        cond_on_multi: Replace the single conditioning frame with the sampled
            source views; requires ``min_n_cond == max_n_cond``.
    """

    def __init__(
        self,
        root_dir,
        split,
        transform,
        reso: int = 256,
        mask_type: str = "random",
        cond_aug_mean=-3.0,
        cond_aug_std=0.5,
        condition_on_elevation=False,
        fps_id=0.0,
        motion_bucket_id=300.0,
        num_frames: int = 24,
        use_mask: bool = True,
        load_pixelnerf: bool = False,
        scale_pose: bool = False,
        max_n_cond: int = 1,
        min_n_cond: int = 1,
        cond_on_multi: bool = False,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.split = split
        # Sorted so that the index -> scene mapping (and the index-0 fallback
        # scene) does not depend on file-system listing order.
        self.ids = sorted(
            str(path.relative_to(self.root_dir))
            for path in self.root_dir.glob("*/*")
            if path.is_dir()
        )
        self.transform = transform
        self.reso = reso
        self.num_frames = num_frames
        self.cond_aug_mean = cond_aug_mean
        self.cond_aug_std = cond_aug_std
        self.condition_on_elevation = condition_on_elevation
        self.fps_id = fps_id
        self.motion_bucket_id = motion_bucket_id
        self.mask_type = mask_type
        self.use_mask = use_mask
        self.load_pixelnerf = load_pixelnerf
        self.scale_pose = scale_pose
        self.max_n_cond = max_n_cond
        self.min_n_cond = min_n_cond
        self.cond_on_multi = cond_on_multi
        if self.cond_on_multi and self.min_n_cond != self.max_n_cond:
            # Every sample in a batch must stack the same number of cond views.
            raise ValueError(
                "cond_on_multi requires min_n_cond == max_n_cond, got "
                f"{self.min_n_cond} and {self.max_n_cond}"
            )
        # The rembg session is only used by mask_type="rembg". Skip loading it
        # otherwise; _rembg_alpha creates one on demand if it is still None.
        self.session = new_session() if self.mask_type == "rembg" else None

    def __getitem__(self, index: int):
        """Load ``num_frames`` views of scene ``index`` as one sample.

        Returns a dict with:
            frames: stacked transformed views.
            cond_frames_without_noise: the first view.
            cond_frames: the first view plus noise scaled by ``cond_aug``.
            cond_aug, fps_id, motion_bucket_id, image_only_indicator:
                tensors with one entry per frame.
            num_video_frames: ``num_frames``.
            pixelnerf_input: present only when ``load_pixelnerf`` is set.
                Holds ``frames``, ``rgb`` (``(transform(x) + 1) / 2`` of
                each view at ``reso // 8``) and ``cameras`` (one 25-vector
                per view, see :meth:`_camera_vector`).
        """
        image_dir, camera_dir, images, view_keys = self._load_views(index)
        view_keys = self._pad_views(view_keys, self.num_frames)
        # One intrinsics read per scene instead of one per frame.
        camera_params = (
            self._read_camera_params(camera_dir) if self.load_pixelnerf else None
        )

        frames, cameras, downsampled_rgb = [], [], []
        # Views are taken in image-name order; only the first num_frames are used.
        for key in view_keys[: self.num_frames]:
            image_name = images[key].name
            with Image.open(image_dir / image_name) as img:
                left, top, image_size = self._crop_box(img, image_dir, image_name)
                frame = img.crop((left, top, left + image_size, top + image_size))
            frame = frame.resize((self.reso, self.reso))
            frames.append(self.transform(frame))
            if self.load_pixelnerf:
                cameras.append(
                    self._camera_vector(
                        images[key], camera_params, left, top, image_size
                    )
                )
                downsampled = frame.resize((self.reso // 8, self.reso // 8))
                downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5)

        cond_aug = np.exp(
            np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
        )
        frames = torch.stack(frames)
        cond = frames[0]

        data = dict()
        data["frames"] = frames
        data["cond_frames_without_noise"] = cond
        data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
        data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
        data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
        data["motion_bucket_id"] = torch.as_tensor(
            [self.motion_bucket_id] * self.num_frames
        )
        data["num_video_frames"] = self.num_frames
        data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)

        if self.load_pixelnerf:
            cameras = torch.from_numpy(np.stack(cameras)).float()
            if self.scale_pose:
                cameras = self._scale_poses(cameras)
            data["pixelnerf_input"] = {
                "frames": frames,
                "rgb": torch.stack(downsampled_rgb),
                "cameras": cameras,
            }
        return data

    def __len__(self):
        return len(self.ids)

    def _load_views(self, index):
        """Return ``(image_dir, camera_dir, images, view_keys)`` for a usable scene.

        ``view_keys`` are the keys of ``images`` (from ``images.bin``) whose
        image file exists, sorted by image name. The scene is usable if it has
        a ``sparse/0`` directory and at least one such view. Otherwise scene 0
        is used instead. Raises RuntimeError if scene 0 is unusable too; the
        previous code looped forever in that case.
        """
        candidates = (index,) if index == 0 else (index, 0)
        for candidate in candidates:
            scene_dir = self.root_dir / self.ids[candidate]
            image_dir = scene_dir / "images"
            camera_dir = scene_dir / "sparse/0"
            if not camera_dir.exists():
                continue
            images = read_images_binary(camera_dir / "images.bin")
            view_keys = [
                key for key in images if (image_dir / images[key].name).exists()
            ]
            if view_keys:
                view_keys.sort(key=lambda key: images[key].name)
                return image_dir, camera_dir, images, view_keys
        raise RuntimeError(
            f"No usable views for scene {self.ids[index]!r} "
            f"or fallback scene {self.ids[0]!r}"
        )

    @staticmethod
    def _pad_views(view_keys, num_frames):
        """Extend ``view_keys`` to at least ``num_frames`` by mirroring its tail.

        Example: ``[1, 2, 3]`` padded to 7 gives ``[1, 2, 3, 3, 2, 1, 1]``.
        Raises ValueError on an empty list, which could never be padded.
        """
        if not view_keys:
            raise ValueError("cannot pad an empty view list")
        padded = list(view_keys)
        while len(padded) < num_frames:
            missing = num_frames - len(padded)
            padded += list(reversed(padded[-missing:]))
        return padded

    @staticmethod
    def _read_camera_params(camera_dir):
        """Return the ``params`` of the scene's single camera in ``cameras.bin``."""
        intrinsics = read_cameras_binary(camera_dir / "cameras.bin")
        if len(intrinsics) != 1:
            raise ValueError(
                f"Expected exactly one camera in {camera_dir}, found {len(intrinsics)}"
            )
        # Take the only entry regardless of its camera id (was hard-coded to 1).
        return next(iter(intrinsics.values())).params

    @staticmethod
    def _camera_vector(extrinsics, camera_params, left, top, image_size):
        """Pack one view's pose and crop-relative intrinsics into a 25-vector.

        ``[:16]`` is the flattened 4x4 matrix from :func:`qt2c2w`. ``[16:]`` is
        the flattened 3x3 intrinsics, with focal length and principal point
        shifted by the crop origin and divided by the crop side.
        """
        c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec)
        # NOTE(review): assumes a 4-parameter model with one focal length, e.g.
        # COLMAP SIMPLE_RADIAL (f, cx, cy, k) - confirm for this dataset.
        f, cx, cy, _ = camera_params
        inv_size = 1 / image_size
        f = f * inv_size
        cx = (cx - left) * inv_size
        cy = (cy - top) * inv_size
        intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])
        vector = np.zeros(25)
        vector[:16] = c2w.reshape(-1)
        vector[16:] = intrinsics.reshape(-1)
        return vector

    def _crop_box(self, frame, image_dir, image_name):
        """Return ``(left, top, image_size)`` of the square crop for ``frame``.

        ``image_size`` is the shorter image side. Raises NotImplementedError
        for ``mask_type="object"`` (previously left the box unset) and
        ValueError for unknown mask types.
        """
        w, h = frame.size
        image_size = min(h, w)
        if self.mask_type == "random":
            left = np.random.randint(0, w - image_size + 1)
            top = np.random.randint(0, h - image_size + 1)
        elif self.mask_type == "rembg":
            mask = self._rembg_alpha(
                frame, image_dir / f"{image_name[:-4]}_rembg.png"
            )
            ys, xs = np.nonzero(mask)
            if len(xs) == 0:
                # No foreground detected: fall back to a centre crop instead
                # of calling int(nan) on the mean of an empty array.
                center_x, center_y = w / 2, h / 2
            else:
                center_x, center_y = xs.mean(), ys.mean()
            top = self._clamp_start(center_y, image_size, h)
            left = self._clamp_start(center_x, image_size, w)
        elif self.mask_type == "object":
            raise NotImplementedError('mask_type="object" is not implemented')
        else:
            raise ValueError(f"Unknown mask type: {self.mask_type}")
        return left, top, image_size

    def _rembg_alpha(self, frame, cached):
        """Return the alpha channel of rembg's output for ``frame``.

        Reads the cached PNG at ``cached`` when it exists and is readable.
        Otherwise runs rembg, overwrites the cache and uses the fresh result.
        """
        if cached.exists():
            try:
                with Image.open(cached, formats=["png"]) as cached_img:
                    return np.asarray(cached_img)[..., 3]
            except Exception:
                # Corrupt or non-RGBA cache: regenerate it below. Unlike the old
                # bare `except:`, this no longer swallows KeyboardInterrupt.
                pass
        if self.session is None:
            self.session = new_session()
        removed = remove(frame, session=self.session)
        removed.save(cached)
        return np.asarray(removed)[..., 3]

    @staticmethod
    def _clamp_start(center, size, extent):
        """Start offset of a ``size`` window centred at ``center`` within ``extent``.

        The window is shifted to stay inside ``[0, extent]``.
        """
        half = size / 2
        if center - half < 0:
            return 0
        if center + half > extent:
            return extent - size
        return int(center - half)

    @staticmethod
    def _scale_poses(cameras, target_radius=1.5):
        """Recentre camera positions and scale the farthest to ``target_radius``.

        Positions are translations ``[:3, 3]`` of each ``cameras[i, :16]`` viewed
        as a 4x4 matrix. Scaling is skipped when all cameras coincide, e.g. a
        single view padded to ``num_frames``; previously this divided by zero
        and produced NaNs.
        """
        c2ws = cameras[..., :16].reshape(-1, 4, 4)
        center = c2ws[:, :3, 3].mean(0)
        radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
        scale = target_radius / radius if radius > 0 else 1.0
        c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
        cameras[..., :16] = c2ws.reshape(-1, 16)
        return cameras

    def collate_fn(self, batch):
        """Collate samples, adding one batch-wide set of conditioning views.

        When ``max_n_cond > 1``, a single ``n_cond`` is drawn for the batch.
        If it exceeds 1, every sample gets ``pixelnerf_input`` entries
        ``source_index`` (view 0 plus random distinct other views), ``n_cond``,
        ``source_images`` and ``source_cameras``. With ``cond_on_multi``, those
        views also replace ``cond_frames_without_noise``, which is flattened to
        ``(batch * views, ...)`` after collation. Requires ``load_pixelnerf``
        whenever source views are drawn.
        """
        stacked_multi_cond = False
        if self.max_n_cond > 1:
            n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
            if n_cond > 1:
                for sample in batch:
                    # NOTE(review): always draws max_n_cond - 1 extra views even
                    # when n_cond < max_n_cond (n_cond is only recorded). Kept
                    # as-is; downstream may rely on fixed-size source tensors.
                    source_index = [0] + np.random.choice(
                        np.arange(1, self.num_frames),
                        self.max_n_cond - 1,
                        replace=False,
                    ).tolist()
                    pix = sample["pixelnerf_input"]
                    pix["source_index"] = torch.as_tensor(source_index)
                    pix["n_cond"] = n_cond
                    pix["source_images"] = sample["frames"][source_index]
                    pix["source_cameras"] = pix["cameras"][source_index]
                    if self.cond_on_multi:
                        sample["cond_frames_without_noise"] = sample["frames"][
                            source_index
                        ]
                stacked_multi_cond = self.cond_on_multi
        ret = video_collate_fn(batch)
        # Only flatten when multi-view conds were actually stacked; otherwise the
        # single cond frame's channel axis would be folded into the batch.
        if stacked_multi_cond:
            ret["cond_frames_without_noise"] = rearrange(
                ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
            )
        return ret
class MVImageNetFixedCond(MVImageNet):
    """Named variant of :class:`MVImageNet` with no behavior of its own.

    Construction arguments go straight to :class:`MVImageNet` (the subclass
    does not override ``__init__``).
    """
def _seed_numpy_worker(worker_id):
    """Seed NumPy in a DataLoader worker from that worker's torch seed.

    Per the PyTorch DataLoader docs, workers get distinct torch seeds
    (``torch.initial_seed()``) but other libraries' RNGs may be duplicated. The
    dataset draws crops and noise levels from ``np.random``. Module level so
    it stays picklable under non-fork worker start methods. ``worker_id`` is
    required by the DataLoader hook signature but unused.
    """
    np.random.seed(torch.initial_seed() % 2**32)


class MVImageNetDataset(LightningDataModule):
    """Lightning data module serving :class:`MVImageNet` loaders.

    ``train_dataset`` and ``test_dataset`` are built from the same ``root_dir``
    and the same extra ``kwargs`` (forwarded to :class:`MVImageNet`).
    Validation reuses ``test_dataset``. Frames are transformed with
    ``ToTensor`` followed by per-channel normalization with mean 0.5 / std 0.5.
    """

    def __init__(
        self,
        root_dir,
        batch_size=2,
        shuffle=True,
        num_workers=10,
        prefetch_factor=2,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.shuffle = shuffle
        self.transform = Compose(
            [
                ToTensor(),
                Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.train_dataset = MVImageNet(
            root_dir=root_dir,
            split="train",
            transform=self.transform,
            **kwargs,
        )
        self.test_dataset = MVImageNet(
            root_dir=root_dir,
            split="test",
            transform=self.transform,
            **kwargs,
        )

    def _make_loader(self, dataset):
        """Build a DataLoader for ``dataset`` using its own ``collate_fn``."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            collate_fn=dataset.collate_fn,
            # Previously defined but never passed, and it would have given every
            # worker the same NumPy seed.
            worker_init_fn=_seed_numpy_worker,
        )
        # prefetch_factor is only valid with worker processes; DataLoader rejects
        # it when num_workers == 0.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

    def val_dataloader(self):
        # Uses the dataset's collate_fn like train/test, so multi-view
        # conditioning fields are present in validation batches too. It is
        # identical to plain video_collate_fn when max_n_cond <= 1.
        return self._make_loader(self.test_dataset)