# heheyas — ortho multi-view dataset loader (repository snapshot: init, commit cfb7702)
import os
import json
import math
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
import datasets
from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
from utils.misc import get_rank
from glob import glob
import PIL.Image
def camNormal2worldNormal(rot_c2w, camNormal):
    """Rotate an (H, W, 3) camera-space normal map into world space.

    Applies ``rot_c2w`` to every per-pixel normal vector and returns a map
    of the same shape.
    """
    h, w, _ = camNormal.shape
    # (N, 3) @ rot^T is the batched equivalent of rotating each normal.
    rotated = camNormal.reshape(-1, 3) @ rot_c2w.T
    return rotated.reshape(h, w, 3)
def worldNormal2camNormal(rot_w2c, worldNormal):
    """Rotate an (H, W, 3) world-space normal map into a camera frame.

    Applies ``rot_w2c`` to every per-pixel normal vector and returns a map
    of the same shape.
    """
    h, w, _ = worldNormal.shape
    # Batched rotation: flatten to (N, 3), rotate, restore the image layout.
    rotated = worldNormal.reshape(-1, 3) @ rot_w2c.T
    return rotated.reshape(h, w, 3)
def trans_normal(normal, RT_w2c, RT_w2c_target):
    """Re-express a normal map from one camera frame in another.

    ``normal`` is given in the camera frame of ``RT_w2c`` (a world->cam
    matrix); the result is the same normals in the camera frame of
    ``RT_w2c_target``.
    """
    # cam -> world via the inverse rotation of the source world->cam pose.
    rot_c2w = np.linalg.inv(RT_w2c[:3, :3])
    world_normal = camNormal2worldNormal(rot_c2w, normal)
    # world -> target camera frame.
    return worldNormal2camNormal(RT_w2c_target[:3, :3], world_normal)
def img2normal(img):
    """Map a [0, 255] image array to normal components in [-1, 1]."""
    unit = img / 255.
    return unit * 2 - 1
def normal2img(normal):
    """Map normal components in [-1, 1] to a uint8 image in [0, 255]."""
    half = normal * 0.5 + 0.5
    return np.uint8(half * 255)
def norm_normalize(normal, dim=-1):
    """Normalize vectors along ``dim`` to unit length.

    A small epsilon keeps the division finite for zero-length vectors.
    """
    magnitude = np.linalg.norm(normal, axis=dim, keepdims=True)
    return normal / (magnitude + 1e-6)
def RT_opengl2opencv(RT):
    """Convert a 3x4 world->camera [R|t] from OpenGL/Blender to OpenCV axes.

    Left-multiplies rotation and translation by the fixed bcam->cv flip
    (y and z negated):
        R_world2cv = R_bcam2cv @ R_world2bcam
        T_world2cv = R_bcam2cv @ T_world2bcam
    """
    flip = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
    rot_cv = flip @ RT[:3, :3]
    trans_cv = flip @ RT[:3, 3]
    # Reassemble the 3x4 pose with the translation as the last column.
    return np.concatenate([rot_cv, trans_cv[:, None]], 1)
def normal_opengl2opencv(normal):
    """Convert an (H, W, 3) normal map from OpenGL to OpenCV camera axes.

    The axis change negates the y and z components of every normal while
    keeping x unchanged.

    Args:
        normal: (H, W, 3) array of per-pixel normal vectors.

    Returns:
        (H, W, 3) array with the y and z components flipped.
    """
    # Fix: removed the leftover debug print of the result shape and the
    # unused shape unpacking / commented-out reshape.
    flip = np.array([1, -1, -1], np.float32)
    return normal * flip[None, None, :]
def inv_RT(RT):
    """Invert a 3x4 rigid pose by lifting it to 4x4 homogeneous form.

    Returns the top 3x4 rows of the inverse (e.g. world2cam -> cam2world).
    """
    bottom_row = np.array([[0, 0, 0, 1]])
    homogeneous = np.concatenate([RT, bottom_row], axis=0)
    inverse = np.linalg.inv(homogeneous)
    return inverse[:3, :]
def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
                      normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
    """Load per-view predicted normal (and optionally color) maps for one object.

    Args:
        root_dir: directory containing per-object prediction folders.
        test_object: object folder name under ``root_dir``.
        imSize: (width, height) every loaded image is resized to.
        view_types: view-name suffixes to load (e.g. 'front', 'back').
        load_color: if True, also load RGB images and color masks.
        cam_pose_dir: directory holding fixed ``000_<view>_RT.txt`` world->cam poses.
        normal_system: 'front' if normals live in the front view's camera frame,
            'self' if each view's normals live in its own camera frame.
        erode_mask: accepted but currently unused — the erosion code is commented out.
        camera_type: 'ortho' or 'pinhole' ray generation.
        cam_params: [fx, fy, cx, cy] intrinsics, used only for 'pinhole'.

    Returns:
        Tuple of stacked arrays: images, fg masks, camera-space normals,
        world-space normals, cam2world poses, world2cam poses, ray origins,
        ray directions, color masks.
    """
    all_images = []
    all_normals = []
    all_normals_world = []
    all_masks = []
    all_color_masks = []
    all_poses = []
    all_w2cs = []
    directions = []
    ray_origins = []

    # Front-view pose; anchors the 'front' normal system below.
    RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0])   # world2cam matrix
    RT_front_cv = RT_opengl2opencv(RT_front)   # convert normal from opengl to opencv

    for idx, view in enumerate(view_types):
        print(os.path.join(root_dir,test_object))
        normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view))
        # Load key frame
        if load_color:  # use bgr
            image = np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]

        # Normal PNG: RGB channels hold the normal, alpha holds the foreground mask.
        normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
        mask = normal[:, :, 3]
        normal = normal[:, :, :3]

        # Color mask: alpha of the masked color render; a pixel is invalid only
        # when it is BOTH outside that mask AND near-white in the RGB image.
        # NOTE(review): `image` is referenced unconditionally below, so this
        # function effectively requires load_color=True — confirm with callers
        # (the dataset in this file always passes load_color=True).
        color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3]
        invalid_color_mask = color_mask < 255*0.5
        threshold = np.ones_like(image[:, :, 0]) * 250
        invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold)
        invalid_color_mask_final = invalid_color_mask & invalid_white_mask
        color_mask = (1 - invalid_color_mask_final) > 0

        # if erode_mask:
        #     kernel = np.ones((3, 3), np.uint8)
        #     mask = cv2.erode(mask, kernel, iterations=1)

        RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view)))   # world2cam matrix

        normal = img2normal(normal)   # [0, 255] -> [-1, 1]
        normal[mask==0] = [0,0,0]     # zero out background normals
        mask = mask> (0.5*255)        # binarize the alpha mask
        if load_color:
            all_images.append(image)

        all_masks.append(mask)
        all_color_masks.append(color_mask)

        RT_cv = RT_opengl2opencv(RT)   # convert normal from opengl to opencv
        all_poses.append(inv_RT(RT_cv))   # cam2world
        all_w2cs.append(RT_cv)

        # Flip normals from OpenGL to OpenCV camera axes before lifting to world.
        normal_cam_cv = normal_opengl2opencv(normal)

        # NOTE(review): normal_world is left unbound (NameError on append) if
        # normal_system is neither 'front' nor 'self'.
        if normal_system == 'front':
            print("the loaded normals are defined in the system of front view")
            normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
        elif normal_system == 'self':
            print("the loaded normals are in their independent camera systems")
            normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
        all_normals.append(normal_cam_cv)
        all_normals_world.append(normal_world)

        # Camera-space rays; identical for every view at a fixed image size.
        if camera_type == 'ortho':
            origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
        elif camera_type == 'pinhole':
            dirs = get_ray_directions(W=imSize[0], H=imSize[1],
                                      fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
            origins = dirs   # occupy a position
        else:
            raise Exception("not support camera type")
        ray_origins.append(origins)
        directions.append(dirs)

    if not load_color:
        # Visualize world-space normals as RGB when no color was requested.
        # NOTE(review): all_color_masks stays empty in this branch, so the
        # final np.stack would fail — confirm load_color=False is never used.
        all_images = [normal2img(x) for x in all_normals_world]

    return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
        np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
class OrthoDatasetBase():
    """Mixin with the shared setup logic for the ortho datasets.

    Eagerly loads every view of one object via :func:`load_a_prediction` and
    moves all tensors to the current device; subclasses decide only how
    items are served (map-style vs. iterable).
    """

    def setup(self, config, split):
        self.config = config
        self.split = split
        self.rank = get_rank()  # device index; target for all tensors below

        self.data_dir = self.config.root_dir
        self.object_name = self.config.scene
        self.scene = self.config.scene
        self.imSize = self.config.imSize
        self.load_color = True  # color is always requested from load_a_prediction
        self.img_wh = [self.imSize[0], self.imSize[1]]
        self.w = self.img_wh[0]
        self.h = self.img_wh[1]
        self.camera_type = self.config.camera_type
        self.camera_params = self.config.camera_params  # [fx, fy, cx, cy]

        # Fixed six-view rig; order must match the pose files on disk and
        # config.view_weights.
        self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']

        # Per-view weights broadcast to one (H, W) map per view.
        self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
        self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w)

        if self.config.cam_pose_dir is None:
            self.cam_pose_dir = "./datasets/fixed_poses"
        else:
            self.cam_pose_dir = self.config.cam_pose_dir

        self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
            self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
                self.data_dir, self.object_name, self.imSize, self.view_types,
                self.load_color, self.cam_pose_dir, normal_system='front',
                camera_type=self.camera_type, cam_params=self.camera_params)

        self.has_mask = True
        self.apply_mask = self.config.apply_mask

        # numpy -> torch; images rescaled from uint8 [0,255] to [0,1].
        self.all_c2w = torch.from_numpy(self.pose_all_np)
        self.all_images = torch.from_numpy(self.images_np) / 255.
        self.all_fg_masks = torch.from_numpy(self.masks_np)
        self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
        self.all_normals_world = torch.from_numpy(self.normals_world_np)
        self.origins = torch.from_numpy(self.origins_np)
        self.directions = torch.from_numpy(self.directions_np)

        # Cast to float and move everything to the target device.
        self.directions = self.directions.float().to(self.rank)
        self.origins = self.origins.float().to(self.rank)
        self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
        self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
            self.all_c2w.float().to(self.rank), \
            self.all_images.float().to(self.rank), \
            self.all_fg_masks.float().to(self.rank), \
            self.all_normals_world.float().to(self.rank)
class OrthoDataset(Dataset, OrthoDatasetBase):
    """Map-style view over the loaded images; items carry only the view index."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # One item per loaded view image.
        return len(self.all_images)

    def __getitem__(self, index):
        # Tensors live as attributes on the dataset; only the index travels.
        return dict(index=index)
class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
    """Endless iterable dataset for training; every item is an empty dict.

    The training loop samples rays from the dataset's attributes directly,
    so the yielded items carry no payload.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # One fresh empty dict per training step, forever.
        while True:
            yield dict()
@datasets.register('ortho')
class OrthoDataModule(pl.LightningDataModule):
    """Lightning data module wiring the ortho datasets into single-item loaders."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def prepare_data(self):
        # All loading happens per-process in setup(); nothing to download.
        pass

    def setup(self, stage=None):
        # stage is None when Lightning wants every split prepared.
        if stage in (None, 'fit'):
            self.train_dataset = OrthoIterableDataset(self.config, 'train')
        if stage in (None, 'fit', 'validate'):
            self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
        if stage in (None, 'test'):
            self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
        if stage in (None, 'predict'):
            self.predict_dataset = OrthoDataset(self.config, 'train')

    def general_loader(self, dataset, batch_size):
        # Shared loader factory; default (no) sampler, pinned host memory.
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)