from PIL import Image
import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.renderer import TexturesVertex
from scripts.utils import meshlab_mesh_to_py3dmesh, py3dmesh_to_meshlab_mesh
import pymeshlab

_MAX_THREAD = 8


# rgb and depth to mesh
def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True, device="cuda"):
    """Orthographic ray grid: origins on the z=0 plane (x in [-1, 1], y scaled by H/W), directions along +z."""
    pixel_center = 0.5 if use_pixel_centers else 0
    i, j = np.meshgrid(
        np.arange(W, dtype=np.float32) + pixel_center,
        np.arange(H, dtype=np.float32) + pixel_center,
        indexing='xy'
    )
    i, j = torch.from_numpy(i).to(device), torch.from_numpy(j).to(device)
    origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2 * H / W, torch.zeros_like(i)], dim=-1)  # (H, W, 3)
    directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1)  # (H, W, 3)
    return origins, directions
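

# Illustrative sketch (not part of the original module): shapes and value ranges of the
# ray grid for a tiny 4x4 image on CPU, following directly from the code above.
#
#   origins, directions = get_ortho_ray_directions_origins(W=4, H=4, device="cpu")
#   origins.shape     -> torch.Size([4, 4, 3])   # x in (-1, 1), y scaled by H/W, z = 0
#   directions.shape  -> torch.Size([4, 4, 3])   # every ray points along +z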
def depth_and_color_to_mesh(rgb_BCHW, pred_HWC, valid_HWC=None, is_back=False):
    if valid_HWC is None:
        valid_HWC = torch.ones_like(pred_HWC).bool()
    H, W = rgb_BCHW.shape[-2:]
    # flip vertically so the top row of the image maps to +y in the mesh
    rgb_BCHW = rgb_BCHW.flip(-2)
    pred_HWC = pred_HWC.flip(0)
    valid_HWC = valid_HWC.flip(0)
    rays_o, rays_d = get_ortho_ray_directions_origins(W, H, device=rgb_BCHW.device)
    verts = rays_o + rays_d * pred_HWC  # [H, W, 3]
    verts = verts.reshape(-1, 3)  # [V, 3]
    indexes = torch.arange(H * W).reshape(H, W).to(rgb_BCHW.device)
    # two triangles per pixel quad; keep a face only if all three of its corners are valid
    faces1 = torch.stack([indexes[:-1, :-1], indexes[:-1, 1:], indexes[1:, :-1]], dim=-1)
    faces1_valid = valid_HWC[:-1, :-1] & valid_HWC[:-1, 1:] & valid_HWC[1:, :-1]
    faces2 = torch.stack([indexes[1:, 1:], indexes[1:, :-1], indexes[:-1, 1:]], dim=-1)
    faces2_valid = valid_HWC[1:, 1:] & valid_HWC[1:, :-1] & valid_HWC[:-1, 1:]
    faces = torch.cat([
        faces1[faces1_valid.expand_as(faces1)].reshape(-1, 3),
        faces2[faces2_valid.expand_as(faces2)].reshape(-1, 3),
    ], dim=0)  # (F, 3)
    colors = (rgb_BCHW[0].permute((1, 2, 0)) / 2 + 0.5).reshape(-1, 3)  # (V, 3), maps [-1, 1] -> [0, 1]
    if is_back:
        # mirror x and z so the back view aligns with the front view
        verts = verts * torch.tensor([-1, 1, -1], dtype=verts.dtype, device=verts.device)

    # drop unreferenced vertices and remap face indices accordingly
    used_verts = faces.unique()
    old_to_new_mapping = torch.zeros_like(verts[..., 0]).long()
    old_to_new_mapping[used_verts] = torch.arange(used_verts.shape[0], device=verts.device)
    new_faces = old_to_new_mapping[faces]
    mesh = Meshes(verts=[verts[used_verts]], faces=[new_faces], textures=TexturesVertex(verts_features=[colors[used_verts]]))
    return mesh
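

# Minimal sketch (assumption, not in the original file): a flat 2x2 quad with mid-gray
# vertex colors. Note the colors are computed as rgb / 2 + 0.5, i.e. a [-1, 1] input
# maps to [0, 1].
#
#   rgb = torch.zeros(1, 3, 2, 2)                 # (B, C, H, W)
#   depth = torch.full((2, 2, 1), 0.1)            # (H, W, 1)
#   quad = depth_and_color_to_mesh(rgb, depth)    # Meshes with 4 vertices and 2 faces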
def normalmap_to_depthmap(normal_np):
    """Integrate a normal map into a height map via scripts.normal_to_height_map."""
    from scripts.normal_to_height_map import estimate_height_map
    height = estimate_height_map(normal_np, raw_values=True, thread_count=_MAX_THREAD, target_iteration_count=96)
    return height
def transform_back_normal_to_front(normal_pil):
    """Flip the x and z components of a back-view normal map so it can be treated as a front view."""
    arr = np.array(normal_pil)  # values in [0, 255]
    arr[..., 0] = 255 - arr[..., 0]
    arr[..., 2] = 255 - arr[..., 2]
    return Image.fromarray(arr.astype(np.uint8))
def calc_w_over_h(normal_pil):
    """Width/height ratio of the foreground bounding box, taken from the alpha channel if present."""
    if isinstance(normal_pil, Image.Image):
        arr = np.array(normal_pil)
    else:
        assert isinstance(normal_pil, np.ndarray)
        arr = normal_pil
    if arr.shape[-1] == 4:
        alpha = arr[..., -1] / 255.
        alpha[alpha >= 0.5] = 1
        alpha[alpha < 0.5] = 0
    else:
        # no alpha channel: treat near-white pixels (all channels >= 250) as background
        alpha = ~(arr.min(axis=-1) >= 250)
    h_min, w_min = np.min(np.where(alpha), axis=1)
    h_max, w_max = np.max(np.where(alpha), axis=1)
    return (w_max - w_min) / (h_max - h_min)
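

# Worked example (assumption, not in the original file): a mask spanning columns 10..49
# and rows 5..84 gives (49 - 10) / (84 - 5) = 39 / 79 ≈ 0.49.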
def build_mesh(normal_pil, rgb_pil, is_back=False, clamp_min=-1, scale=0.3, init_type="std", offset=0):
    if is_back:
        normal_pil = transform_back_normal_to_front(normal_pil)
    normal_img = np.array(normal_pil)
    rgb_img = np.array(rgb_pil)
    if normal_img.shape[-1] == 4:
        valid_HWC = normal_img[..., [3]] / 255
    elif rgb_img.shape[-1] == 4:
        valid_HWC = rgb_img[..., [3]] / 255
    else:
        raise ValueError("invalid input, either normal or rgb should have an alpha channel")
    real_height_pix = np.max(np.where(valid_HWC > 0.5)[0]) - np.min(np.where(valid_HWC > 0.5)[0])

    heights = normalmap_to_depthmap(normal_img)
    rgb_BCHW = torch.from_numpy(rgb_img[..., :3] / 255.).permute((2, 0, 1))[None]
    valid_HWC[valid_HWC < 0.5] = 0
    valid_HWC[valid_HWC >= 0.5] = 1
    valid_HWC = torch.from_numpy(valid_HWC).bool()
    if init_type == "std":
        # accurate but not stable
        pred_HWC = torch.from_numpy(heights / heights.max() * (real_height_pix / heights.shape[0]) * scale * 2).float()[..., None]
    elif init_type == "thin":
        heights = heights - heights.min()
        heights = (heights / heights.max() * 0.2)
        pred_HWC = torch.from_numpy(heights * scale).float()[..., None]
    else:
        # stable but not accurate
        heights = heights - heights.min()
        heights = (heights / heights.max() * (1 - offset)) + offset  # to [offset, 1]
        pred_HWC = torch.from_numpy(heights * scale).float()[..., None]

    # set the border pixels to 0 height
    import cv2
    # edge filter on the validity mask
    edge = cv2.Canny((valid_HWC[..., 0] * 255).numpy().astype(np.uint8), 0, 255)
    edge = torch.from_numpy(edge).bool()[..., None]
    pred_HWC[edge] = 0
    valid_HWC[pred_HWC < clamp_min] = False
    return depth_and_color_to_mesh(rgb_BCHW.cuda(), pred_HWC.cuda(), valid_HWC.cuda(), is_back)
def fix_border_with_pymeshlab_fast(meshes: Meshes, poissson_depth=6, simplification=0):
    ms = pymeshlab.MeshSet()
    ms.add_mesh(py3dmesh_to_meshlab_mesh(meshes), "cube_vcolor_mesh")
    if simplification > 0:
        ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
    # screened Poisson reconstruction rebuilds a watertight surface, closing open borders
    ms.apply_filter('generate_surface_reconstruction_screened_poisson', threads=6, depth=poissson_depth, preclean=True)
    if simplification > 0:
        ms.apply_filter('meshing_decimation_quadric_edge_collapse', targetfacenum=simplification, preservetopology=True)
    return meshlab_mesh_to_py3dmesh(ms.current_mesh())
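

if __name__ == "__main__":
    # Hedged usage sketch, not shipped with the original module: the file names below are
    # placeholders, and the pipeline assumes a CUDA device (build_mesh moves tensors to
    # .cuda()). Build a front-view mesh from an RGBA color image plus its predicted normal
    # map, then run the screened-Poisson border fix defined above.
    rgb_pil = Image.open("front_rgba.png")        # placeholder path, RGBA with object alpha
    normal_pil = Image.open("front_normal.png")   # placeholder path, predicted normal map
    front_mesh = build_mesh(normal_pil, rgb_pil, is_back=False, scale=0.3, init_type="std")
    fixed = fix_border_with_pymeshlab_fast(front_mesh, poissson_depth=6, simplification=0)
    print(fixed.verts_packed().shape, fixed.faces_packed().shape)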