Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import os | |
import random | |
from dataclasses import dataclass, field | |
import cv2 | |
import numpy as np | |
import pytorch_lightning as pl | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from torch.utils.data import DataLoader, Dataset | |
from ..utils.config import parse_structured | |
from ..utils.geometry import ( | |
get_plucker_embeds_from_cameras, | |
get_plucker_embeds_from_cameras_ortho, | |
get_position_map_from_depth, | |
get_position_map_from_depth_ortho, | |
) | |
from ..utils.typing import * | |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" | |
def _parse_scene_list_single(scene_list_path: str, root_data_dir: str): | |
all_scenes = [] | |
if scene_list_path.endswith(".json"): | |
with open(scene_list_path) as f: | |
for p in json.loads(f.read()): | |
if "/" in p: | |
all_scenes.append(os.path.join(root_data_dir, p)) | |
else: | |
all_scenes.append(os.path.join(root_data_dir, p[:2], p)) | |
elif scene_list_path.endswith(".txt"): | |
with open(scene_list_path) as f: | |
for p in f.readlines(): | |
p = p.strip() | |
if "/" in p: | |
all_scenes.append(os.path.join(root_data_dir, p)) | |
else: | |
all_scenes.append(os.path.join(root_data_dir, p[:2], p)) | |
else: | |
raise NotImplementedError | |
return all_scenes | |
def _parse_scene_list( | |
scene_list_path: Union[str, List[str]], root_data_dir: Union[str, List[str]] | |
): | |
all_scenes = [] | |
if isinstance(scene_list_path, str): | |
scene_list_path = [scene_list_path] | |
if isinstance(root_data_dir, str): | |
root_data_dir = [root_data_dir] | |
for scene_list_path_, root_data_dir_ in zip(scene_list_path, root_data_dir): | |
all_scenes += _parse_scene_list_single(scene_list_path_, root_data_dir_) | |
return all_scenes | |
def _parse_reference_scene_list(reference_scenes: List[str], all_scenes: List[str]): | |
all_ids = set(scene.split("/")[-1] for scene in all_scenes) | |
ref_ids = set(scene.split("/")[-1] for scene in reference_scenes) | |
common_ids = ref_ids.intersection(all_ids) | |
all_scenes = [scene for scene in all_scenes if scene.split("/")[-1] in common_ids] | |
all_ids = {scene.split("/")[-1]: idx for idx, scene in enumerate(all_scenes)} | |
ref_scenes = [ | |
scene for scene in reference_scenes if scene.split("/")[-1] in all_ids | |
] | |
sorted_ref_scenes = sorted(ref_scenes, key=lambda x: all_ids[x.split("/")[-1]]) | |
scene2ref = { | |
scene: ref_scene for scene, ref_scene in zip(all_scenes, sorted_ref_scenes) | |
} | |
return all_scenes, scene2ref | |
@dataclass
class MultiviewDataModuleConfig:
    """Configuration schema for the multi-view dataset and data module.

    The class must be a dataclass: the ``field(default_factory=...)`` defaults
    below only turn into per-instance lists when processed by ``@dataclass``.
    """

    # Root directory (or list of root directories) that scene names are
    # joined onto.
    root_dir: Any = ""
    # Path (or list of paths) to .json / .txt files listing scene names.
    scene_list: Any = ""
    # File extension of the target, normal and reference images.
    image_suffix: str = "webp"
    # Background composited behind the alpha channel: "white", "black",
    # "gray", "random", "random_gray", or a numeric gray level.
    background_color: Union[str, float] = "gray"
    # View names; used in image file names and to look up the camera entry
    # in meta.json whose "index" equals the name.
    image_names: List[str] = field(default_factory=lambda: [])
    # File-name prefix of the target images: "<modality>_<name>.<suffix>".
    image_modality: str = "render"
    # Number of leading views kept by collate() when random_view_list is None.
    num_views: int = 1
    # Candidate lists of view indices; collate() picks one at random per batch.
    random_view_list: Optional[List[List[int]]] = None

    # JSON file mapping a scene directory's base name to its prompt.
    prompt_db_path: Optional[str] = None
    # Whether samples include a "prompts" entry.
    return_prompt: bool = False
    # Use "" as the prompt instead of looking it up in the prompt database.
    use_empty_prompt: bool = False
    # String, or one string per view, prepended to the prompt with a space.
    prompt_prefix: Optional[Any] = None
    # Return only the first view's prompt instead of one prompt per view.
    return_one_prompt: bool = True

    # Camera model: "ORTHO" or "PERSP".
    projection_type: str = "ORTHO"

    # source conditions
    # One of "position", "normal", "plucker", or a list of them; the maps are
    # concatenated along the channel axis.
    source_image_modality: Any = "position"
    # Rotate normal maps into camera space using each view's camera matrix.
    use_camera_space_normal: bool = False
    # Position maps are remapped as (p + offset) / scale, then clamped to [0, 1].
    position_offset: float = 0.5
    position_scale: float = 1.0
    # Plucker embeddings are remapped the same way.
    plucker_offset: float = 1.0
    plucker_scale: float = 2.0

    # reference image
    # Reference images are used only when both of the next two are set.
    reference_root_dir: Optional[Any] = None
    reference_scene_list: Optional[Any] = None
    reference_image_modality: str = "render"
    # Candidate reference view names; one is chosen at random per sample.
    reference_image_names: List[str] = field(default_factory=lambda: [])
    # If set, the reference image is first resized to a random square
    # resolution from this list before being resized to (width, height).
    reference_augment_resolutions: Optional[List[int]] = None
    # Zero out a random rectangle of the reference image's alpha channel.
    reference_mask_aug: bool = False

    # Number of times the train scene list is repeated; only applied when
    # train_indices is set.
    repeat: int = 1  # for debugging purpose

    # (start, stop) slices applied to the scene list for each split.
    train_indices: Optional[Tuple[Any, Any]] = None
    val_indices: Optional[Tuple[Any, Any]] = None
    test_indices: Optional[Tuple[Any, Any]] = None

    # Size that every loaded image and map is resized to.
    height: int = 768
    width: int = 768

    batch_size: int = 1
    eval_batch_size: int = 1
    num_workers: int = 16
class MultiviewDataset(Dataset):
    """Dataset yielding, per scene, multi-view images and conditioning maps.

    Each sample is a dict with:
      * ``"rgb"``: the target views, stacked and permuted to channel-first.
      * ``"c2w"``: the per-view ``transform_matrix`` read from ``meta.json``.
      * ``"source_rgb"``: the configured source modalities (position, normal
        and/or plucker maps) concatenated along the channel axis.
      * ``"prompts"`` (optional): one prompt, or one prompt per view.
      * ``"reference_rgb"`` (optional): one randomly chosen reference view.

    Use :meth:`collate` as the ``DataLoader`` ``collate_fn``; it selects the
    views to keep and flattens the batch and view axes.
    """

    def __init__(self, cfg: Any, split: str = "train") -> None:
        """Resolve the scene list for ``split`` and load the prompt database.

        Args:
            cfg: Configuration object exposing the ``MultiviewDataModuleConfig``
                fields.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
        """
        super().__init__()
        assert split in ["train", "val", "test"]
        self.cfg: MultiviewDataModuleConfig = cfg
        self.all_scenes = _parse_scene_list(self.cfg.scene_list, self.cfg.root_dir)

        if (
            self.cfg.reference_root_dir is not None
            and self.cfg.reference_scene_list is not None
        ):
            reference_scenes = _parse_scene_list(
                self.cfg.reference_scene_list, self.cfg.reference_root_dir
            )
            # Keeps only scenes that have a reference; reference_scenes becomes
            # a dict mapping scene path -> reference scene path.
            self.all_scenes, self.reference_scenes = _parse_reference_scene_list(
                reference_scenes, self.all_scenes
            )
        else:
            self.reference_scenes = None

        self.split = split
        if self.split == "train" and self.cfg.train_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.train_indices[0] : self.cfg.train_indices[1]
            ]
            # NOTE: repeat is only applied together with train_indices.
            self.all_scenes = self.all_scenes * self.cfg.repeat
        elif self.split == "val" and self.cfg.val_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.val_indices[0] : self.cfg.val_indices[1]
            ]
        elif self.split == "test" and self.cfg.test_indices is not None:
            self.all_scenes = self.all_scenes[
                self.cfg.test_indices[0] : self.cfg.test_indices[1]
            ]

        if self.cfg.prompt_db_path is not None:
            # Context manager so the file handle is closed deterministically.
            with open(self.cfg.prompt_db_path) as f:
                self.prompt_db = json.load(f)
        else:
            self.prompt_db = None

    def __len__(self):
        """Return the number of scenes in this split."""
        return len(self.all_scenes)

    def get_bg_color(self, bg_color):
        """Turn a background specification into a float32 RGB array.

        Args:
            bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"``
                (independent random channels), ``"random_gray"`` (one random
                level in [0.3, 0.7] for all channels), a number (gray level),
                or a list/tuple of channel values.

        Returns:
            A ``np.float32`` array with the color.

        Raises:
            NotImplementedError: If the specification is not recognized.
        """
        if bg_color == "white":
            bg_color = np.array([1.0, 1.0, 1.0], dtype=np.float32)
        elif bg_color == "black":
            bg_color = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        elif bg_color == "gray":
            bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        elif bg_color == "random":
            # np.random.rand yields float64; cast so the composited images do
            # not get promoted to float64 only for this background mode.
            bg_color = np.random.rand(3).astype(np.float32)
        elif bg_color == "random_gray":
            bg_color = random.uniform(0.3, 0.7)
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, (int, float)):
            # Accept ints too, so a config value such as 1 is a valid gray level.
            bg_color = np.array([bg_color] * 3, dtype=np.float32)
        elif isinstance(bg_color, (list, tuple)):
            bg_color = np.array(bg_color, dtype=np.float32)
        else:
            raise NotImplementedError(f"Unsupported background color: {bg_color!r}")
        return bg_color

    def load_image(
        self,
        image: "Union[str, Image.Image]",
        height: int,
        width: int,
        background_color: torch.Tensor,
        rescale: bool = False,
        mask_aug: bool = False,
    ):
        """Load an image and composite it over a background color.

        Args:
            image: File path or an already opened PIL image.
            height: Output height.
            width: Output width.
            background_color: Color used where alpha is below 1; must
                broadcast against a ``(height, width, 3)`` tensor.
            rescale: If True, map the result from [0, 1] to [-1, 1].
            mask_aug: If True, zero the alpha of one random rectangle centered
                on a random foreground pixel (alpha > 0.5), with sides between
                1/8 and 1/4 of the image size, before compositing.

        Returns:
            A float tensor of shape ``(height, width, 3)``.
        """
        if isinstance(image, str):
            image = Image.open(image)
        # Guarantee an alpha channel; images without one would otherwise fail
        # in the compositing below. RGBA inputs are unaffected.
        image = image.convert("RGBA").resize((width, height))
        image = torch.from_numpy(np.array(image)).float() / 255.0

        if mask_aug:
            alpha = image[:, :, 3]  # view into image: edits apply in place
            h, w = alpha.shape
            y_indices, x_indices = torch.where(alpha > 0.5)
            if len(y_indices) > 0:
                idx = torch.randint(len(y_indices), (1,)).item()
                y_center = y_indices[idx].item()
                x_center = x_indices[idx].item()
                mask_h = random.randint(h // 8, h // 4)
                mask_w = random.randint(w // 8, w // 4)
                y1 = max(0, y_center - mask_h // 2)
                y2 = min(h, y_center + mask_h // 2)
                x1 = max(0, x_center - mask_w // 2)
                x2 = min(w, x_center + mask_w // 2)
                alpha[y1:y2, x1:x2] = 0.0

        alpha = image[:, :, 3:4]
        image = image[:, :, :3] * alpha + background_color * (1 - alpha)
        if rescale:
            image = image * 2.0 - 1.0
        return image

    def load_normal_image(
        self,
        path,
        height,
        width,
        background_color,
        camera_space: bool = False,
        c2w: Optional[torch.FloatTensor] = None,
    ):
        """Load a normal map and composite it over a background color.

        Args:
            path: Path of the normal image; RGB stores the normal mapped from
                [-1, 1] to [0, 1].
            height: Output height.
            width: Output width.
            background_color: Color used where alpha is below 1.
            camera_space: If True, rotate the normals by the rotation part of
                the inverse of ``c2w`` and re-normalize them.
            c2w: Camera matrix of the view; required when ``camera_space`` is
                True.

        Returns:
            A float tensor of shape ``(height, width, 3)``.

        Raises:
            ValueError: If ``camera_space`` is True but ``c2w`` is None.
        """
        # Nearest-neighbor resampling avoids blending neighboring normals.
        image = (
            Image.open(path)
            .convert("RGBA")
            .resize((width, height), resample=Image.NEAREST)
        )
        image = torch.from_numpy(np.array(image)).float() / 255.0
        alpha = image[:, :, 3:4]
        image = image[:, :, :3]
        if camera_space:
            if c2w is None:
                raise ValueError("c2w is required when camera_space=True")
            w2c = torch.linalg.inv(c2w)[:3, :3]
            # Per-pixel matrix-vector product w2c @ n, written as a broadcasted
            # multiply followed by a sum over the last axis.
            image = (
                F.normalize(((image * 2 - 1)[:, :, None, :] * w2c).sum(-1), dim=-1)
                * 0.5
                + 0.5
            )
        image = image * alpha + background_color * (1 - alpha)
        return image

    def load_depth(self, path, height, width):
        """Load a depth map and its validity mask.

        Args:
            path: Path of the depth image (read with ``cv2.IMREAD_UNCHANGED``).
            height: Output height.
            width: Output width.

        Returns:
            ``(depth, mask)``, both float tensors of shape
            ``(height, width, 1)``. ``mask`` is 0 where the stored depth
            exceeds 1000 and 1 elsewhere; ``depth`` is zeroed where ``mask``
            is 0.

        Raises:
            OSError: If OpenCV cannot read the file.
        """
        depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if depth is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(
                f"cv2.imread could not read depth map {path!r} (missing file, "
                "or OpenEXR support not enabled)"
            )
        depth = cv2.resize(depth, (width, height), interpolation=cv2.INTER_NEAREST)
        if depth.ndim == 2:
            # Single-channel files come back without a channel axis.
            depth = depth[..., None]
        depth = torch.from_numpy(depth[..., 0:1]).float()
        invalid = depth > 1000.0  # depth = 65535 is the invalid value
        mask = torch.ones_like(depth)
        mask[invalid] = 0.0
        depth[invalid] = 0.0
        return depth, mask

    def retrieve_prompt(self, scene_dir):
        """Return the prompt stored for the scene's base name, or ``""``."""
        assert self.prompt_db is not None
        source_id = os.path.basename(scene_dir)
        return self.prompt_db.get(source_id, "")

    def _build_source_images(
        self, scene_dir, c2w, background_color, intrinsics, camera_angle_x, ortho_scale
    ):
        """Build the channel-concatenated, channel-first source condition maps."""
        modalities = self.cfg.source_image_modality
        if isinstance(modalities, str):
            modalities = [modalities]

        source_images = []
        for modality in modalities:
            if modality == "position":
                depth_masks = [
                    self.load_depth(
                        os.path.join(scene_dir, f"depth_{f}.exr"),
                        self.cfg.height,
                        self.cfg.width,
                    )
                    for f in self.cfg.image_names
                ]
                depths = torch.stack([d for d, _ in depth_masks])
                masks = torch.stack([m for _, m in depth_masks])
                # Negate the 2nd and 3rd columns of every camera matrix;
                # presumably a camera-convention change expected by the
                # geometry helpers -- TODO confirm.
                c2w_ = c2w.clone()
                c2w_[:, :, 1:3] *= -1
                if self.cfg.projection_type == "PERSP":
                    position_maps = get_position_map_from_depth(
                        depths,
                        masks,
                        intrinsics,
                        c2w_,
                        image_wh=(self.cfg.width, self.cfg.height),
                    )
                elif self.cfg.projection_type == "ORTHO":
                    position_maps = get_position_map_from_depth_ortho(
                        depths,
                        masks,
                        c2w_,
                        ortho_scale,
                        image_wh=(self.cfg.width, self.cfg.height),
                    )
                else:
                    # Fail with a clear error rather than an unbound variable.
                    raise NotImplementedError(
                        f"Unsupported projection type: {self.cfg.projection_type!r}"
                    )
                position_maps = (
                    (position_maps + self.cfg.position_offset)
                    / self.cfg.position_scale
                ).clamp(0.0, 1.0)
                source_images.append(position_maps)
            elif modality == "normal":
                normal_maps = [
                    self.load_normal_image(
                        os.path.join(
                            scene_dir, f"{modality}_{f}.{self.cfg.image_suffix}"
                        ),
                        height=self.cfg.height,
                        width=self.cfg.width,
                        background_color=background_color,
                        camera_space=self.cfg.use_camera_space_normal,
                        c2w=c,
                    )
                    for c, f in zip(c2w, self.cfg.image_names)
                ]
                source_images.append(torch.stack(normal_maps, dim=0))
            elif modality == "plucker":
                if self.cfg.projection_type == "ORTHO":
                    plucker_embed = get_plucker_embeds_from_cameras_ortho(
                        c2w, [ortho_scale] * len(c2w), self.cfg.width
                    )
                elif self.cfg.projection_type == "PERSP":
                    plucker_embed = get_plucker_embeds_from_cameras(
                        c2w, [camera_angle_x] * len(c2w), self.cfg.width
                    )
                else:
                    raise NotImplementedError(
                        f"Unsupported projection type: {self.cfg.projection_type!r}"
                    )
                # Move channels last so all modalities concatenate on dim -1.
                plucker_embed = plucker_embed.permute(0, 2, 3, 1)
                plucker_embed = (
                    (plucker_embed + self.cfg.plucker_offset) / self.cfg.plucker_scale
                ).clamp(0.0, 1.0)
                source_images.append(plucker_embed)
            else:
                raise NotImplementedError(
                    f"Unsupported source image modality: {modality!r}"
                )
        return torch.cat(source_images, dim=-1).permute(0, 3, 1, 2)

    def _build_prompts(self, scene_dir, num_images):
        """Return one prompt per view, with the configured prefix prepended."""
        if self.cfg.use_empty_prompt:
            prompt = ""
        else:
            prompt = self.retrieve_prompt(scene_dir)
        prompts = [prompt] * num_images

        if self.cfg.prompt_prefix is not None:
            prompt_prefix = self.cfg.prompt_prefix
            if isinstance(prompt_prefix, str):
                prompt_prefix = [prompt_prefix] * num_images
            prompts = [f"{prompt_prefix[i]} {p}" for i, p in enumerate(prompts)]
        return prompts

    def _load_reference_image(self, scene_dir, background_color):
        """Load one randomly chosen reference view as a channel-first tensor."""
        reference_scene_dir = self.reference_scenes[scene_dir]
        reference_image_paths = [
            os.path.join(
                reference_scene_dir,
                f"{self.cfg.reference_image_modality}_{f}.{self.cfg.image_suffix}",
            )
            for f in self.cfg.reference_image_names
        ]
        reference_image = random.choice(reference_image_paths)
        if self.cfg.reference_augment_resolutions is not None:
            # Resolution augmentation: resize to a random square size first;
            # load_image then resizes to the configured (width, height).
            random_resolution = random.choice(self.cfg.reference_augment_resolutions)
            reference_image = Image.open(reference_image).resize(
                (random_resolution, random_resolution)
            )
        return self.load_image(
            reference_image,
            height=self.cfg.height,
            width=self.cfg.width,
            background_color=background_color,
            mask_aug=self.cfg.reference_mask_aug,
        ).permute(2, 0, 1)

    def __getitem__(self, index):
        """Load every configured view and condition for the scene at ``index``.

        Returns:
            A dict with ``"rgb"``, ``"c2w"`` and ``"source_rgb"``, plus
            ``"prompts"`` when ``cfg.return_prompt`` is set and
            ``"reference_rgb"`` when reference scenes are configured.
        """
        background_color = torch.as_tensor(self.get_bg_color(self.cfg.background_color))
        scene_dir = self.all_scenes[index]
        with open(os.path.join(scene_dir, "meta.json")) as f:
            meta = json.load(f)
        name2loc = {loc["index"]: loc for loc in meta["locations"]}
        num_images = len(self.cfg.image_names)

        # target multi-view images
        images = [
            self.load_image(
                os.path.join(
                    scene_dir, f"{self.cfg.image_modality}_{f}.{self.cfg.image_suffix}"
                ),
                height=self.cfg.height,
                width=self.cfg.width,
                background_color=background_color,
            )
            for f in self.cfg.image_names
        ]
        images = torch.stack(images, dim=0).permute(0, 3, 1, 2)

        # camera
        # Force float32 so a matrix written with integer entries does not
        # produce an integer tensor.
        c2w = torch.stack(
            [
                torch.as_tensor(
                    name2loc[name]["transform_matrix"], dtype=torch.float32
                )
                for name in self.cfg.image_names
            ],
            dim=0,
        )

        intrinsics = None
        camera_angle_x = None
        ortho_scale = None
        if self.cfg.projection_type == "PERSP":
            # Scene-level value if present, otherwise the first view's value.
            camera_angle_x = (
                meta.get("camera_angle_x", None)
                or meta["locations"][0]["camera_angle_x"]
            )
            focal_length = 0.5 * self.cfg.width / np.tan(0.5 * camera_angle_x)
            intrinsics = (
                torch.as_tensor(
                    [
                        [focal_length, 0.0, 0.5 * self.cfg.width],
                        [0.0, focal_length, 0.5 * self.cfg.height],
                        [0.0, 0.0, 1.0],
                    ]
                )
                .unsqueeze(0)
                .float()
                .repeat(num_images, 1, 1)
            )
        elif self.cfg.projection_type == "ORTHO":
            ortho_scale = (
                meta.get("ortho_scale", None) or meta["locations"][0]["ortho_scale"]
            )

        # source conditions
        source_images = self._build_source_images(
            scene_dir, c2w, background_color, intrinsics, camera_angle_x, ortho_scale
        )

        rv = {"rgb": images, "c2w": c2w, "source_rgb": source_images}

        # prompt
        if self.cfg.return_prompt:
            prompts = self._build_prompts(scene_dir, num_images)
            if self.cfg.return_one_prompt:
                rv.update({"prompts": prompts[0]})
            else:
                rv.update({"prompts": prompts})

        # reference image
        if self.reference_scenes is not None:
            rv.update(
                {
                    "reference_rgb": self._load_reference_image(
                        scene_dir, background_color
                    )
                }
            )

        return rv

    def collate(self, batch):
        """Collate samples, keep the selected views, and flatten batch x view.

        The kept views are a random entry of ``cfg.random_view_list`` when it
        is set, otherwise the first ``cfg.num_views`` views. ``"rgb"``,
        ``"source_rgb"`` and ``"c2w"`` are indexed with that selection and
        their first two axes merged. Per-view prompts, when present, are
        flattened in the same sample-major order and restricted to the same
        views so they stay aligned with the images.

        Returns:
            The collated dict, extended with ``"num_views"`` and the
            ``"original_size"``, ``"target_size"`` and
            ``"crops_coords_top_left"`` entries.
        """
        batch = torch.utils.data.default_collate(batch)

        if self.cfg.random_view_list is not None:
            indices = list(random.choice(self.cfg.random_view_list))
        else:
            indices = list(range(self.cfg.num_views))
        num_views = len(indices)

        for k in ("rgb", "source_rgb", "c2w"):
            if k in batch:
                selected = batch[k][:, indices]
                batch[k] = selected.reshape(-1, *selected.shape[2:])

        # "prompts" is absent when cfg.return_prompt is False, so guard the
        # lookup. default_collate transposes per-view prompt lists to
        # [view][sample]; zip(*...) restores [sample][view].
        if "prompts" in batch and not self.cfg.return_one_prompt:
            batch["prompts"] = [
                per_sample[i] for per_sample in zip(*batch["prompts"]) for i in indices
            ]

        batch.update(
            {
                "num_views": num_views,
                # For SDXL
                "original_size": (self.cfg.height, self.cfg.width),
                "target_size": (self.cfg.height, self.cfg.width),
                "crops_coords_top_left": (0, 0),
            }
        )
        return batch
class MultiviewDataModule(pl.LightningDataModule):
    """Lightning data module exposing train/val/test ``MultiviewDataset`` loaders."""

    cfg: MultiviewDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        """Parse ``cfg`` against the ``MultiviewDataModuleConfig`` schema."""
        super().__init__()
        self.cfg = parse_structured(MultiviewDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        """Create the datasets required by ``stage`` (all of them when None)."""
        build_all = stage is None
        if build_all or stage == "fit":
            self.train_dataset = MultiviewDataset(self.cfg, "train")
        if build_all or stage in ("fit", "validate"):
            self.val_dataset = MultiviewDataset(self.cfg, "val")
        if build_all or stage in ("test", "predict"):
            self.test_dataset = MultiviewDataset(self.cfg, "test")

    def prepare_data(self):
        """No preparation step is needed."""
        pass

    def _make_loader(self, dataset, batch_size, shuffle) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader that uses the dataset's own collate."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=self.cfg.num_workers,
            shuffle=shuffle,
            collate_fn=dataset.collate,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the train split."""
        return self._make_loader(self.train_dataset, self.cfg.batch_size, True)

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the val split."""
        return self._make_loader(self.val_dataset, self.cfg.eval_batch_size, False)

    def test_dataloader(self) -> DataLoader:
        """Sequential loader over the test split."""
        return self._make_loader(self.test_dataset, self.cfg.eval_batch_size, False)

    def predict_dataloader(self) -> DataLoader:
        """Prediction reuses the test loader."""
        return self.test_dataloader()