import os
import random
from typing import Union

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from megfile import smart_path_join, smart_open

from .cam_utils import build_camera_standard, build_camera_principle, camera_normalization_objaverse
from .objaverse import ObjaverseDataset
from .back_transform.back_transform import transform_back_image
from ..utils.proxy import no_proxy

__all__ = ['GobjaverseDataset']


def opposite_view(i):
    """Return the index of the view opposite to view `i` on its orbit."""
    if 0 <= i <= 24:
        return (i + 12) % 24
    elif 27 <= i <= 39:
        return ((i - 27) + 6) % 12 + 27
    else:
        raise ValueError("Input number must be between 0-24 or 27-39.")


def get_random_views(rgba_dir, num_views=4):
    """Randomly sample `num_views` distinct view indices from the PNG filenames in `rgba_dir`."""
    all_files = [f for f in os.listdir(rgba_dir) if f.endswith('.png')]
    view_numbers = [int(os.path.splitext(f)[0]) for f in all_files]
    selected_views = random.sample(view_numbers, num_views)
    return np.array(selected_views)


class GobjaverseDataset(ObjaverseDataset):

    def __init__(self, root_dirs: list[str], meta_path: str,
                 sample_side_views: int,
                 render_image_res_low: int, render_image_res_high: int, render_region_size: int,
                 source_image_res: int, normalize_camera: bool,
                 normed_dist_to_center: Union[float, str] = None, num_all_views: int = 32):
        super().__init__(
            root_dirs, meta_path,
            sample_side_views,
            render_image_res_low,
            render_image_res_high,
            render_region_size,
            source_image_res,
            normalize_camera,
            normed_dist_to_center,
            num_all_views,
        )
        self.back_transforms = transform_back_image()

    @staticmethod
    def _load_pose_txt(file_path):
        """Load a 4x4 camera pose from a text file, convert it from the OpenGL to the OpenCV convention, and return the top [3, 4] block as a tensor."""
        with open(file_path, 'r') as file:
            lines = file.readlines()
        pose_data = np.array([list(map(float, line.split())) for line in lines], dtype=np.float32)
        pose = pose_data.reshape(4, 4)
        # flip the camera y and z axes to convert from the OpenGL to the OpenCV camera convention
        opengl2opencv = np.array([
            [1, 0, 0, 0],
            [0, -1, 0, 0],
            [0, 0, -1, 0],
            [0, 0, 0, 1],
        ], dtype=np.float32)
        pose = np.matmul(pose, opengl2opencv)
        return torch.from_numpy(pose[:3, :])

    @staticmethod
    def _load_rgba_image_transform(file_path, bg_color: float = 1.0, extra_transforms=None):
        '''Load an RGBA image and blend it over a constant background color, returning RGB scaled to [0, 1].'''
        rgba = np.array(Image.open(smart_open(file_path, 'rb')))
        rgba = torch.from_numpy(rgba).float() / 255.0
        rgba = rgba.permute(2, 0, 1).unsqueeze(0)
        rgb = rgba[:, :3, :, :] * rgba[:, 3:4, :, :] + bg_color * (1 - rgba[:, 3:, :, :])
        if extra_transforms is not None:
            rgb = extra_transforms(
                transforms.ToPILImage()(rgb.squeeze())
            ).unsqueeze(0)
        return rgb

    @no_proxy
    def inner_get_item(self, idx):
        """
        Loaded contents:
            rgbs: [M, 3, H, W]
            poses: [M, 3, 4], [R|t]
            intrinsics: [3, 2], [[fx, fy], [cx, cy], [width, height]]
        """
        uid = self.uids[idx]
        root_dir = self._locate_datadir(self.root_dirs, uid, locator="pose")

        pose_dir = os.path.join(root_dir, uid, 'pose')
        rgba_dir = os.path.join(root_dir, uid, 'rgb')

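        # fixed pinhole intrinsics for the 512x512 renders: focal length (384, 384), principal point (256, 256)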
        intrinsics = torch.tensor([[384, 384], [256, 256], [512, 512]], dtype=torch.float)

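        # sample the side views, then insert the opposite ("back") view of the first sampled view at index 1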
        sample_views = get_random_views(rgba_dir, num_views=self.sample_side_views)
        source_image_view_back = opposite_view(sample_views[0])
        sample_views = np.insert(sample_views, 1, source_image_view_back)

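        # load the pose and RGBA image of each selected view, compositing over a randomly chosen background color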
        poses, rgbs, bg_colors = [], [], []
        source_image = None
        for view in sample_views:
            pose_path = smart_path_join(pose_dir, f'{view:03d}.txt')
            rgba_path = smart_path_join(rgba_dir, f'{view:03d}.png')
            pose = self._load_pose_txt(pose_path)
            bg_color = random.choice([0.0, 0.5, 1.0])
            rgb = self._load_rgba_image(rgba_path, bg_color=bg_color)
            poses.append(pose)
            rgbs.append(rgb)
            bg_colors.append(bg_color)
            if source_image is None:
                source_image = self._load_rgba_image(rgba_path, bg_color=1.0)
        assert source_image is not None, "Really bad luck!"
        poses = torch.stack(poses, dim=0)
        rgbs = torch.cat(rgbs, dim=0)

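        # back view of the source image, composited with the background color left over from the last loop iteration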
        source_image_back = self._load_rgba_image_transform(smart_path_join(rgba_dir, f'{sample_views[1]:03d}.png'), bg_color=bg_color)

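        # optionally normalize camera poses to a canonical distance, then build the source and render cameras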
        if self.normalize_camera:
            poses = camera_normalization_objaverse(self.normed_dist_to_center, poses)

        source_camera = build_camera_principle(poses[:1], intrinsics.unsqueeze(0)).squeeze(0)
        render_camera = build_camera_standard(poses, intrinsics.repeat(poses.shape[0], 1, 1))

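        # resize the front and back source images to the conditioning resolution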
        source_image = torch.nn.functional.interpolate(
            source_image, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
        source_image = torch.clamp(source_image, 0, 1)

        source_image_back = torch.nn.functional.interpolate(
            source_image_back, size=(self.source_image_res, self.source_image_res), mode='bicubic', align_corners=True).squeeze(0)
        source_image_back = torch.clamp(source_image_back, 0, 1)

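        # resize the render targets to a random resolution and crop a square render_region_size patch per view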
        render_image_res = np.random.randint(self.render_image_res_low, self.render_image_res_high + 1)
        render_image = torch.nn.functional.interpolate(
            rgbs, size=(render_image_res, render_image_res), mode='bicubic', align_corners=True)
        render_image = torch.clamp(render_image, 0, 1)
        anchors = torch.randint(
            0, render_image_res - self.render_region_size + 1, size=(self.sample_side_views + 1, 2))
        crop_indices = torch.arange(0, self.render_region_size, device=render_image.device)
        index_i = (anchors[:, 0].unsqueeze(1) + crop_indices).view(-1, self.render_region_size, 1)
        index_j = (anchors[:, 1].unsqueeze(1) + crop_indices).view(-1, 1, self.render_region_size)
        batch_indices = torch.arange(self.sample_side_views + 1, device=render_image.device).view(-1, 1, 1)
        cropped_render_image = render_image[batch_indices, :, index_i, index_j].permute(0, 3, 1, 2)

        return {
            'uid': uid,
            'source_camera': source_camera,
            'render_camera': render_camera,
            'source_image': source_image,
            'render_image': cropped_render_image,
            'source_image_back': source_image_back,
            'render_anchors': anchors,
            'render_full_resolutions': torch.tensor([[render_image_res]], dtype=torch.float32).repeat(self.sample_side_views + 1, 1),
            'render_bg_colors': torch.tensor(bg_colors, dtype=torch.float32).unsqueeze(-1),
        }