seawolf2357's picture
Upload folder using huggingface_hub
684943d verified
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import copy
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from field_construction.utils.general_utils import PILtoTorch
from field_construction.utils.graphics_utils import (
fov2focal, getProjectionMatrix, getProjectionMatrixCenterShift,
getWorld2View2)
def dilate(bin_img, ksize=6):
pad = (ksize - 1) // 2
bin_img = F.pad(bin_img, pad=[pad, pad, pad, pad], mode='reflect')
out = F.max_pool2d(bin_img, kernel_size=ksize, stride=1, padding=0)
return out
def erode(bin_img, ksize=12):
out = 1 - dilate(1 - bin_img, ksize)
return out
def process_image(image_path, resolution, ncc_scale):
image = Image.open(image_path)
if len(image.split()) > 3:
resized_image_rgb = torch.cat([PILtoTorch(im, resolution) for im in image.split()[:3]], dim=0)
loaded_mask = PILtoTorch(image.split()[3], resolution)
gt_image = resized_image_rgb
if ncc_scale != 1.0:
ncc_resolution = (int(resolution[0]/ncc_scale), int(resolution[1]/ncc_scale))
resized_image_rgb = torch.cat([PILtoTorch(im, ncc_resolution) for im in image.split()[:3]], dim=0)
else:
resized_image_rgb = PILtoTorch(image, resolution)
loaded_mask = None
gt_image = resized_image_rgb
if ncc_scale != 1.0:
ncc_resolution = (int(resolution[0]/ncc_scale), int(resolution[1]/ncc_scale))
resized_image_rgb = PILtoTorch(image, ncc_resolution)
gray_image = (0.299 * resized_image_rgb[0] + 0.587 * resized_image_rgb[1] + 0.114 * resized_image_rgb[2])[None]
return gt_image, gray_image, loaded_mask
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy,
image_width, image_height,
image_path, image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0,
ncc_scale=1.0,
preload_img=True, data_device = "cuda"
):
super(Camera, self).__init__()
self.uid = uid
self.nearest_id = []
self.nearest_names = []
self.colmap_id = colmap_id
self.R = R
self.T = T
self.FoVx = FoVx
self.FoVy = FoVy
self.image_name = image_name
self.image_path = image_path
self.image_width = image_width
self.image_height = image_height
self.resolution = (image_width, image_height)
self.Fx = fov2focal(FoVx, self.image_width)
self.Fy = fov2focal(FoVy, self.image_height)
self.Cx = 0.5 * self.image_width
self.Cy = 0.5 * self.image_height
base_image_path = "/".join(self.image_path.split("/")[:-2])
self.normal_path = os.path.join(base_image_path, "normal", self.image_path.split("/")[-1])
try:
self.data_device = torch.device(data_device)
except Exception as e:
print(e)
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.original_image, self.image_gray, self.mask = None, None, None
self.preload_img = preload_img
self.ncc_scale = ncc_scale
if self.preload_img:
gt_image, gray_image, loaded_mask = process_image(self.image_path, self.resolution, ncc_scale)
self.original_image = gt_image.to(self.data_device)
self.original_image_gray = gray_image.to(self.data_device)
self.mask = loaded_mask
self.zfar = 100.0
self.znear = 0.01
self.trans = trans
self.scale = scale
self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
self.plane_mask, self.non_plane_mask = None, None
def get_image(self):
if self.preload_img:
return self.original_image.cuda(), self.original_image_gray.cuda()
else:
gt_image, gray_image, _ = process_image(self.image_path, self.resolution, self.ncc_scale)
return gt_image.cuda(), gray_image.cuda()
def get_normal(self):
_normal = Image.open(self.normal_path)
resized_normal = PILtoTorch(_normal, self.resolution)
resized_normal = resized_normal[:3]
_normal = - (resized_normal * 2 - 1).cuda()
# normalize normal
_normal = _normal.permute(1, 2, 0) @ (torch.linalg.inv(torch.as_tensor(self.R).float()).cuda())
normal_gt = _normal.permute(2, 0, 1)
normal_norm = torch.norm(normal_gt, dim=0, keepdim=True)
normal_mask = ~((normal_norm > 1.1) | (normal_norm < 0.9))
normal_gt /= normal_norm
return normal_gt, normal_mask
def get_language_feature(self, language_feature_dir):
language_feature_name = os.path.join(language_feature_dir, self.image_name)
feature_map = torch.from_numpy(np.load(language_feature_name + '_f.npy')).to(self.data_device)
if len(feature_map.shape) < 4:
feature_map = feature_map[None]
point_feature = F.interpolate(feature_map, (self.image_height, self.image_width), mode="bilinear", align_corners=False)
seg_map = torch.from_numpy(np.load(language_feature_name + "_s.npy")).to(self.data_device) # (h, w)
seg_map = seg_map.long()
mask = seg_map != -1
# perform mask_pooling:
point_feature = point_feature.squeeze(0) # (feat_dim, h, w)
# for color_id in range(seg_map.max() + 1):
# point_feature[:, seg_map == color_id] = point_feature[:, seg_map == color_id].mean(dim=-1, keepdim=True)
return point_feature, mask, seg_map
def get_calib_matrix_nerf(self, scale=1.0):
intrinsic_matrix = torch.tensor([[self.Fx/scale, 0, self.Cx/scale], [0, self.Fy/scale, self.Cy/scale], [0, 0, 1]]).float()
extrinsic_matrix = self.world_view_transform.transpose(0,1).contiguous() # cam2world
return intrinsic_matrix, extrinsic_matrix
def get_rays(self, scale=1.0):
W, H = int(self.image_width/scale), int(self.image_height/scale)
ix, iy = torch.meshgrid(
torch.arange(W), torch.arange(H), indexing='xy')
rays_d = torch.stack(
[(ix-self.Cx/scale) / self.Fx * scale,
(iy-self.Cy/scale) / self.Fy * scale,
torch.ones_like(ix)], -1).float().cuda()
return rays_d
def get_k(self, scale=1.0):
K = torch.tensor([[self.Fx / scale, 0, self.Cx / scale],
[0, self.Fy / scale, self.Cy / scale],
[0, 0, 1]]).cuda()
return K
def get_inv_k(self, scale=1.0):
K_T = torch.tensor([[scale/self.Fx, 0, -self.Cx/self.Fx],
[0, scale/self.Fy, -self.Cy/self.Fy],
[0, 0, 1]]).cuda()
return K_T
class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
self.image_width = width
self.image_height = height
self.FoVy = fovy
self.FoVx = fovx
self.znear = znear
self.zfar = zfar
self.world_view_transform = world_view_transform
self.full_proj_transform = full_proj_transform
view_inv = torch.inverse(self.world_view_transform)
self.camera_center = view_inv[3][:3]
def sample_cam(cam_l: Camera, cam_r: Camera):
cam = copy.copy(cam_l)
Rt = np.zeros((4, 4))
Rt[:3, :3] = cam_l.R.transpose()
Rt[:3, 3] = cam_l.T
Rt[3, 3] = 1.0
Rt2 = np.zeros((4, 4))
Rt2[:3, :3] = cam_r.R.transpose()
Rt2[:3, 3] = cam_r.T
Rt2[3, 3] = 1.0
C2W = np.linalg.inv(Rt)
C2W2 = np.linalg.inv(Rt2)
w = np.random.rand()
pose_c2w_at_unseen = w * C2W + (1 - w) * C2W2
Rt = np.linalg.inv(pose_c2w_at_unseen)
cam.R = Rt[:3, :3]
cam.T = Rt[:3, 3]
cam.world_view_transform = torch.tensor(getWorld2View2(cam.R, cam.T, cam.trans, cam.scale)).transpose(0, 1).cuda()
cam.projection_matrix = getProjectionMatrix(znear=cam.znear, zfar=cam.zfar, fovX=cam.FoVx, fovY=cam.FoVy).transpose(0,1).cuda()
cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0)
cam.camera_center = cam.world_view_transform.inverse()[3, :3]
return cam