import tempfile
from contextlib import contextmanager
from typing import Iterator, Optional, Union

import blobfile as bf
import numpy as np
import torch
from PIL import Image

from shap_e.rendering.blender.render import render_mesh, render_model
from shap_e.rendering.blender.view_data import BlenderViewData
from shap_e.rendering.mesh import TriMesh
from shap_e.rendering.point_cloud import PointCloud
from shap_e.rendering.view_data import ViewData
from shap_e.util.collections import AttrDict
from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize


def load_or_create_multimodal_batch(
    device: torch.device,
    *,
    mesh_path: Optional[str] = None,
    model_path: Optional[str] = None,
    cache_dir: Optional[str] = None,
    point_count: int = 2**14,
    random_sample_count: int = 2**19,
    pc_num_views: int = 40,
    mv_light_mode: Optional[str] = None,
    mv_num_views: int = 20,
    mv_image_size: int = 512,
    mv_alpha_removal: str = "black",
    verbose: bool = False,
) -> AttrDict:
    """
    Build a normalized input batch (point cloud, optionally multiview renders) for one asset.

    Exactly one of ``mesh_path`` / ``model_path`` must be given (asserted downstream in
    ``load_or_create_pc`` / ``load_or_create_multiview``).

    :param device: device the ``points`` tensor is moved to.
    :param cache_dir: if set, rendered multiviews and extracted point clouds are read from /
        written to this directory.
    :param point_count: number of points kept after farthest-point sampling.
    :param random_sample_count: number of points randomly pre-sampled before FPS.
    :param pc_num_views: number of rendered views used to build the point cloud.
    :param mv_light_mode: if truthy, also render ``mv_num_views`` lit views with this light
        mode and attach them to the batch.
    :param mv_image_size: side length the multiview images/depths/cameras are resized to.
    :param mv_alpha_removal: ``mode`` passed to ``remove_alpha`` for the multiview images.
    :return: an AttrDict with ``points`` (a float tensor shaped ``(1, 6, N)``: xyz + RGB,
        channel-first, batch size 1) and, when ``mv_light_mode`` is set, ``views``,
        ``view_alphas``, ``cameras`` and ``depths`` -- each a one-element list holding the
        per-view list. ``depths`` entries are ``None`` for renders without a ``"D"``
        channel. The batch is passed through ``normalize_input_batch`` (coords/cameras/depths
        scaled by 2.0, colors by 1/255).
    """
    if verbose:
        print("creating point cloud...")
    pc = load_or_create_pc(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        random_sample_count=random_sample_count,
        point_count=point_count,
        num_views=pc_num_views,
        verbose=verbose,
    )
    # xyz and RGB side by side per point, then transposed to channel-first with a
    # leading batch dimension of 1.
    raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1)
    encode_me = torch.from_numpy(raw_pc).float().to(device)
    batch = AttrDict(points=encode_me.t()[None])

    if mv_light_mode:
        if verbose:
            print("creating multiview...")
        with load_or_create_multiview(
            mesh_path=mesh_path,
            model_path=model_path,
            cache_dir=cache_dir,
            num_views=mv_num_views,
            extract_material=False,
            light_mode=mv_light_mode,
            verbose=verbose,
        ) as mv:
            # The channel layout is a property of the whole render, so decide it once.
            color_channels = (
                ["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"]
            )
            has_depth = "D" in mv.channel_names
            cameras, views, view_alphas, depths = [], [], [], []
            for view_idx in range(mv.num_views):
                camera, view = mv.load_view(view_idx, color_channels)
                depth = None
                if has_depth:
                    _, depth = mv.load_view(view_idx, ["D"])
                    depth = process_depth(depth, mv_image_size)
                # Views are scaled by 255 and quantized to uint8 before going through
                # the PIL-based image pipeline.
                view, alpha = process_image(
                    np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size
                )
                # Keep the camera consistent with the center-crop + resize applied to the image.
                camera = camera.center_crop().resize_image(mv_image_size, mv_image_size)
                cameras.append(camera)
                views.append(view)
                view_alphas.append(alpha)
                depths.append(depth)
            batch.depths = [depths]
            batch.views = [views]
            batch.view_alphas = [view_alphas]
            batch.cameras = [cameras]
    return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)


def load_or_create_pc(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    random_sample_count: int,
    point_count: int,
    num_views: int,
    verbose: bool = False,
) -> PointCloud:
    """
    Return a point cloud for a mesh or model, loading it from ``cache_dir`` when present.

    On a cache miss the asset is rendered with material extraction (via
    ``load_or_create_multiview``), converted with ``mv_to_pc`` and, if ``cache_dir`` is
    set, saved as an ``.npz`` for next time.

    :raises AssertionError: unless exactly one of ``mesh_path`` / ``model_path`` is given.
    """
    assert (model_path is not None) ^ (
        mesh_path is not None
    ), "must specify exactly one of model_path or mesh_path"
    path = model_path if model_path is not None else mesh_path

    if cache_dir is not None:
        # NOTE(review): the key uses only the basename, so distinct assets sharing a
        # filename collide in the cache.
        cache_path = bf.join(
            cache_dir,
            f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz",
        )
        if bf.exists(cache_path):
            return PointCloud.load(cache_path)
    else:
        cache_path = None

    with load_or_create_multiview(
        mesh_path=mesh_path,
        model_path=model_path,
        cache_dir=cache_dir,
        num_views=num_views,
        verbose=verbose,
    ) as mv:
        if verbose:
            print("extracting point cloud from multiview...")
        pc = mv_to_pc(
            multiview=mv, random_sample_count=random_sample_count, point_count=point_count
        )
        if cache_path is not None:
            pc.save(cache_path)
        return pc


@contextmanager
def load_or_create_multiview(
    *,
    mesh_path: Optional[str],
    model_path: Optional[str],
    cache_dir: Optional[str],
    num_views: int = 20,
    extract_material: bool = True,
    light_mode: Optional[str] = None,
    verbose: bool = False,
) -> Iterator[BlenderViewData]:
    """
    Yield a ``BlenderViewData`` over a zip of rendered views of a mesh or model.

    If ``cache_dir`` holds a matching archive it is opened directly. Otherwise the asset
    is rendered with Blender (``CYCLES`` backend, ``fast_mode``, random camera poses) into
    a temporary directory. When ``cache_dir`` is set, the archive is copied there before
    yielding. The underlying file (and temporary directory) is closed/removed when the
    context exits.

    :param extract_material: render material channels; ``light_mode`` must then be None.
    :param light_mode: lighting for non-material renders; required when
        ``extract_material`` is False.
    :raises AssertionError: on an invalid combination of the arguments above, or unless
        exactly one of ``mesh_path`` / ``model_path`` is given.
    """
    assert (model_path is not None) ^ (
        mesh_path is not None
    ), "must specify exactly one of model_path or mesh_path"
    path = model_path if model_path is not None else mesh_path

    if extract_material:
        assert light_mode is None, "light_mode is ignored when extract_material=True"
    else:
        assert light_mode is not None, "must specify light_mode when extract_material=False"

    if cache_dir is not None:
        # NOTE(review): basename-only key; same-named assets share a cache entry.
        if extract_material:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip")
        else:
            cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip")
        if bf.exists(cache_path):
            with bf.BlobFile(cache_path, "rb") as f:
                yield BlenderViewData(f)
                return
    else:
        cache_path = None

    common_kwargs = dict(
        fast_mode=True,
        extract_material=extract_material,
        camera_pose="random",
        light_mode=light_mode or "uniform",
        verbose=verbose,
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = bf.join(tmp_dir, "out.zip")
        if mesh_path is not None:
            mesh = TriMesh.load(mesh_path)
            render_mesh(
                mesh=mesh,
                output_path=tmp_path,
                num_images=num_views,
                backend="CYCLES",
                **common_kwargs,
            )
        elif model_path is not None:
            render_model(
                model_path,
                output_path=tmp_path,
                num_images=num_views,
                backend="CYCLES",
                **common_kwargs,
            )
        if cache_path is not None:
            bf.copy(tmp_path, cache_path)
        with bf.BlobFile(tmp_path, "rb") as f:
            yield BlenderViewData(f)


def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud:
    """
    Convert multiview RGB-D renders into a point cloud of ``point_count`` points.

    The cloud from ``PointCloud.from_rgbd`` is replaced by a single black point at the
    origin if empty, and repeatedly doubled until it has at least ``point_count`` points.
    It is then jittered with tiny Gaussian noise, randomly subsampled to
    ``random_sample_count`` and reduced with farthest-point sampling (averaging neighbors).
    """
    pc = PointCloud.from_rgbd(multiview)

    # Handle empty samples.
    if len(pc.coords) == 0:
        pc = PointCloud(
            coords=np.zeros([1, 3]),
            channels=dict(zip("RGB", np.zeros([3, 1]))),
        )
    while len(pc.coords) < point_count:
        pc = pc.combine(pc)
        # Prevent duplicate points; some models may not like it.
        pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4

    pc = pc.random_sample(random_sample_count)
    pc = pc.farthest_point_sample(point_count, average_neighbors=True)

    return pc


def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    """
    Return a shallow copy of ``batch`` with geometry and colors rescaled.

    ``points`` channels 0-2 (xyz) are multiplied by ``pc_scale`` and channels 3-5 (RGB)
    by ``color_scale``. If present, every camera is scaled via ``scale_scene(pc_scale)``
    and every depth map by ``pc_scale``. ``None`` depth entries (views rendered without
    a depth channel) are passed through unchanged. ``batch`` itself is not modified.
    """
    res = batch.copy()
    scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device)
    res.points = res.points * scale_vec[:, None]

    if "cameras" in res:
        res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]

    if "depths" in res:
        # load_or_create_multimodal_batch stores depth=None when there is no "D" channel;
        # multiplying None would raise TypeError, so leave those entries alone.
        res.depths = [
            [None if depth is None else depth * pc_scale for depth in depths]
            for depths in res.depths
        ]

    return res


def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray:
    """Center-crop ``depth_img``, resize it to ``image_size`` x ``image_size``, and squeeze singleton axes."""
    depth_img = center_crop(depth_img)
    depth_img = resize(depth_img, width=image_size, height=image_size)
    return np.squeeze(depth_img)


def process_image(
    img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int
):
    """
    Center-crop an image, split off its alpha, and resize both to ``image_size``.

    Accepts either a PIL image or a numpy array (converted with ``Image.fromarray``).
    Single-channel (2-D) inputs are pasted onto a new RGB image first.

    :param alpha_removal: ``mode`` forwarded to ``remove_alpha``.
    :return: ``(image, alpha)``, both resized to ``image_size`` x ``image_size`` with
        bilinear resampling.
    """
    if isinstance(img_or_img_arr, np.ndarray):
        img = Image.fromarray(img_or_img_arr)
        img_arr = img_or_img_arr
    else:
        img = img_or_img_arr
        img_arr = np.array(img)

    if img_arr.ndim == 2:
        # Grayscale: promote to RGB so the downstream alpha/crop helpers see color input.
        rgb = Image.new("RGB", img.size)
        rgb.paste(img)
        img = rgb

    img = center_crop(img)
    alpha = get_alpha(img)
    img = remove_alpha(img, mode=alpha_removal)
    alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR)
    img = img.resize((image_size,) * 2, resample=Image.BILINEAR)
    return img, alpha