diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4decf3c997cd7bda3fdf18d94f32911f4c3a1b94 --- /dev/null +++ b/app.py @@ -0,0 +1,375 @@ +import gradio as gr +import torch +import numpy as np +from functools import partial +from typing import Optional +from shap_e.diffusion.gaussian_diffusion import diffusion_from_config +from shap_e.diffusion.sample import sample_latents +from shap_e.models.download import load_model, load_config +from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh +import trimesh +import torch.nn as nn +import os +import random +import warnings +from huggingface_hub import hf_hub_download +import hashlib + +import sys + +sys.tracebacklimit = 0 +def set_seed(seed=1024): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def freeze_params(params): + for param in params: + param.requires_grad = False + +class Blocks(gr.Blocks): + + def __init__( + self, + theme: str = "default", + analytics_enabled: Optional[bool] = None, + mode: str = "blocks", + title: str = "Gradio", + css: Optional[str] = None, + **kwargs, + ): + self.extra_configs = { + 'thumbnail': kwargs.pop('thumbnail', ''), + 'url': kwargs.pop('url', 'https://gradio.app/'), + 'creator': kwargs.pop('creator', '@teamGradio'), + } + + super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs) + warnings.filterwarnings("ignore") + + def get_config_file(self): + config = super(Blocks, self).get_config_file() + + for k, v in self.extra_configs.items(): + config[k] = v + + return config +def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed): + state = {} + out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state) + edited_1, edited_2, edited_3, edited_4, state = _3d_editing(xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state) + print(state) + return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4 +def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state): + set_seed(rand_seed) + batch_size = 4 + guidance_scale = 15.0 + xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device) + xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device) + xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max]) + + print("prompt: ", prompt, "rand_seed: ", rand_seed, "state:", state) + latents = sample_latents( + batch_size=batch_size, + model=latent_model, + diffusion=diffusion, + guidance_scale=guidance_scale, + model_kwargs=dict(texts=[prompt] * batch_size), + progress=True, + clip_denoised=True, + use_fp16=True, + use_karras=True, + karras_steps=64, + sigma_min=1e-3, + sigma_max=160, + s_churn=0, + ) + prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed)).encode('utf-8')).hexdigest()) + mesh_path = [] + output_path = './logs' + os.makedirs(os.path.join(output_path, 'source'), exist_ok=True) + state['latent'] = [] + state['prompt'] = prompt + state['rand_seed_1'] = rand_seed + for i, latent in enumerate(latents): + + output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i)) + t_obj = decode_latent_mesh(xm, latent).tri_mesh() 
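+        # Write the decoded mesh to an .obj file, then reload it with trimesh and
+        # rotate it (180 deg about Y, then 90 deg about X) so it renders upright in gr.Model3D.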
+ with open(output_path_tmp, 'w') as f: + t_obj.write_obj(f) + + mesh = trimesh.load_mesh(output_path_tmp) + angle = np.radians(180) + axis = [0, 1, 0] + rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis) + mesh.apply_transform(rotation_matrix) + angle = np.radians(90) + axis = [1, 0, 0] + rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis) + mesh.apply_transform(rotation_matrix) + output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i)) + mesh.export(output_path_tmp) + state['latent'].append(latent.clone().detach()) + mesh_path.append(output_path_tmp) + + return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state + +def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state): + set_seed(rand_seed) + mesh_path = [] + prompt = state['prompt'] + rand_seed_1 = state['rand_seed_1'] + print("prompt: ", prompt, "rand_seed: ", rand_seed, "instruction:", instruction, "state:", state) + prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed_1) + '_' + instruction + '_' + str(rand_seed)).encode('utf-8')).hexdigest()) + if 'santa' in instruction: + e_type = 'santa_hat' + elif 'rainbow' in instruction: + e_type = 'rainbow' + elif 'gold' in instruction: + e_type = 'golden' + elif 'lego' in instruction: + e_type = 'lego' + elif 'wooden' in instruction: + e_type = 'wooden' + elif 'cyber' in instruction: + e_type = 'cyber' + + # import pdb; pdb.set_trace() + model = models[e_type].to(device) + noise_initial = initial_noise[e_type].to(device) + noise_start_t = start_t[e_type] + general_save_path = './logs/edited' + os.makedirs(general_save_path, exist_ok=True) + for i, latent in enumerate(state['latent']): + latent = latent.to(device) + text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction])) + print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction) + ref_latent = latent.clone().unsqueeze(0) + t_1 = torch.randint(noise_start_t, noise_start_t + 1, (1,), device=device).long() + + noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial) + out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True, + model_kwargs=text_embeddings_clip, + condition_latents=ref_latent) + + updated_latents = out_1['pred_xstart'] + + if 'santa' in instruction: + xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.25]).to(device) + xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1]).to(device) + xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max]) + + else: + xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device) + xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device) + xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max]) + + for latent_idx, updated_latent in enumerate(updated_latents): + output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i)) + + t = decode_latent_mesh(xm, updated_latent).tri_mesh() + with open(output_path, 'w') as f: + t.write_obj(f) + mesh = trimesh.load_mesh(output_path) + + angle = np.radians(180) + axis = [0, 1, 0] + + rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis) + mesh.apply_transform(rotation_matrix) + angle = np.radians(90) + axis = [1, 0, 0] + + rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis) + mesh.apply_transform(rotation_matrix) + + output_path = 
os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i)) + mesh.export(output_path) + mesh_path.append(output_path) + return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state +def main(): + + css = """ + #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img + { + height: var(--height) !important; + max-height: var(--height) !important; + min-height: var(--height) !important; + } + #paper-info a { + color:#008AD7; + text-decoration: none; + } + #paper-info a:hover { + cursor: pointer; + text-decoration: none; + } + + .tooltip { + color: #555; + position: relative; + display: inline-block; + cursor: pointer; + } + + .tooltip .tooltiptext { + visibility: hidden; + width: 400px; + background-color: #555; + color: #fff; + text-align: center; + padding: 5px; + border-radius: 5px; + position: absolute; + z-index: 1; /* Set z-index to 1 */ + left: 10px; + top: 100%; + opacity: 0; + transition: opacity 0.3s; + } + + .tooltip:hover .tooltiptext { + visibility: visible; + opacity: 1; + z-index: 9999; /* Set a high z-index value when hovering */ + } + + + """ + + rescale_js = """ + function(x) { + const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app'); + let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0; + const image_width = root.querySelector('#img2img_image').clientWidth; + const target_height = parseInt(image_width * image_scale); + document.body.style.setProperty('--height', `${target_height}px`); + root.querySelectorAll('button.justify-center.rounded')[0].style.display='none'; + root.querySelectorAll('button.justify-center.rounded')[1].style.display='none'; + return x; + } + """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + latent_model = load_model('text300M', device=device) + xm = load_model('transmitter', device=device) + diffusion = diffusion_from_config(load_config('diffusion')) + freeze_params(xm.parameters()) + models = dict() + initial_noise = dict() + noise_start_t = dict() + editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber'] + + for editing_type in editing_types: + tmp_model = load_model('text300M', device=device) + with torch.no_grad(): + new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype) + new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight)) + new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) # + new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias)) + new_proj.bias[:1024].copy_(tmp_model.wrapped.input_proj.bias) + tmp_model.wrapped.input_proj = new_proj + + ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu') + tmp_model.load_state_dict(ckp['model']) + noise_initial = ckp['initial_noise']['noise'].to(device) + initial_noise[editing_type] = noise_initial + noise_start_t[editing_type] = ckp['t_start'] + models[editing_type] = tmp_model + + with Blocks( + css=css, + analytics_enabled=False, + title="SHAPE-EDITOR demo", + ) as demo: + description = """

+            SHAP-EDITOR: Instruction-guided Latent 3D Editing in Seconds
+
+            [Project Page]
+            [Paper]
+            [GitHub]

+ """ + state = gr.State({}) + gr.HTML(description) + with gr.Column(): + with gr.Column(): + gr.HTML('Step 1: generate original 3D object using Shap-E.') + prompt = gr.Textbox( + label="Text prompt for initial 3D generation", lines=1 + ) + gen_btn = gr.Button(value='Generate', scale=1) + + + with gr.Column(): + gr.HTML('Generated 3D objects') + with gr.Row(): + out_gen_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 1)") + out_gen_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 1)") + out_gen_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 1)") + out_gen_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 1)") + + with gr.Column(scale=1): + gr.HTML('Step 2: apply 3D editing with SHAP-EDITOR.') + + editing_choice = gr.Dropdown( + ["Add a santa hat to it", "Make it look like made of gold", "Make the color of it look like rainbow", "Make it in cyberpunk style", "Make it wooden", "Make it look like make of lego"], value='Add a santa hat to it', multiselect=False, label="Editing effects", info="Select specific editing you want to apply!" + ), + apply_btn = gr.Button(value='Editing', scale=1) + + with gr.Column(scale=3): + gr.HTML('Edited 3D objects') + with gr.Row(): + edited_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 2)") + edited_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 2)") + edited_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 2)") + edited_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 2)") + + + with gr.Accordion("Advanced Options", open=False): + rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random seed") + + gen_btn.click( + fn=partial(generate_3d_with_shap_e, xm, diffusion, latent_model, device), + inputs=[prompt, rand_seed, state], + outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state], + queue=False) + + apply_btn.click( + fn=partial(_3d_editing, xm, models, diffusion, initial_noise, noise_start_t, device), + inputs=[ + editing_choice[0], rand_seed, state + ], + outputs=[edited_1, edited_2, edited_3, edited_4, state], + queue=True + ) + print("Generate examples...") + with gr.Column(): + gr.Examples( + examples=[ + [ "a corgi", + "Make the color of it look like rainbow", + 456, + ], + ["a penguin", + "Make it look like make of lego", + 214, + ], + ], + inputs=[prompt, editing_choice[0], rand_seed], + outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4], + fn=partial(optimize_all, xm, models, initial_noise, noise_start_t, diffusion, latent_model, device), + cache_examples=True, + ) + + + demo.queue(max_size=10, api_open=False) + demo.launch(share=True, show_api=False, show_error=True) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b959541524ea2aabc82240978f726239056c35c4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +filelock +pillow +torch +fire +humanize +requests +tqdm +matplot +scikit-image +scipy +numpy +blobfile +clip @ git+https://github.com/openai/CLIP.git +trimesh + +# gradio demo +gradio diff --git a/shap_e/.DS_Store b/shap_e/.DS_Store new file mode 100644 index 
0000000000000000000000000000000000000000..7f150b41acaceb8d1e153f2d14f8965d37f2d077 Binary files /dev/null and b/shap_e/.DS_Store differ diff --git a/shap_e/__init__.py b/shap_e/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/__pycache__/__init__.cpython-39.pyc b/shap_e/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71e1a4c7d4fc27a67e6078ce2fb26204c1b52d48 Binary files /dev/null and b/shap_e/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/diffusion/__init__.py b/shap_e/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/diffusion/__pycache__/__init__.cpython-39.pyc b/shap_e/diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e8fbfdf428533a439362f0ebb379fff7516911b Binary files /dev/null and b/shap_e/diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc b/shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e847808e7a01ea45d2e973b4270e45a26afefbd Binary files /dev/null and b/shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc differ diff --git a/shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc b/shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ec801778849cb2c96cf4e1dea36f10232cad07b Binary files /dev/null and b/shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc differ diff --git a/shap_e/diffusion/__pycache__/sample.cpython-39.pyc b/shap_e/diffusion/__pycache__/sample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec1c0c84a36f6d5d46f0fc2f5290f35986acff4a Binary files /dev/null and b/shap_e/diffusion/__pycache__/sample.cpython-39.pyc differ diff --git a/shap_e/diffusion/gaussian_diffusion.py b/shap_e/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..d151eee92fbabd28027f0be5b2b86d3694dbe842 --- /dev/null +++ b/shap_e/diffusion/gaussian_diffusion.py @@ -0,0 +1,1143 @@ +""" +Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +""" + +import math +from typing import Any, Dict, Iterable, Optional, Sequence, Union + +import blobfile as bf +import numpy as np +import torch as th +import yaml + + +def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion": + if isinstance(config, str): + with bf.BlobFile(config, "rb") as f: + obj = yaml.load(f, Loader=yaml.SafeLoader) + return diffusion_from_config(obj) + + schedule = config["schedule"] + steps = config["timesteps"] + respace = config.get("respacing", None) + mean_type = config.get("mean_type", "epsilon") + betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {})) + channel_scales = config.get("channel_scales", None) + channel_biases = config.get("channel_biases", None) + if channel_scales is not None: + channel_scales = np.array(channel_scales) + if channel_biases is not None: + channel_biases = np.array(channel_biases) + kwargs = dict( + betas=betas, + model_mean_type=mean_type, + model_var_type="learned_range", + loss_type="mse", + channel_scales=channel_scales, + 
channel_biases=channel_biases, + ) + if respace is None: + return GaussianDiffusion(**kwargs) + else: + return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs) + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + elif schedule_name == "inv_parabola": + exponent = extra_args.get("power", 2.0) + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: 1 - t**exponent, + ) + elif schedule_name == "translated_parabola": + exponent = extra_args.get("power", 2.0) + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: (1 - t) ** exponent, + ) + elif schedule_name == "exp": + coefficient = extra_args.get("coefficient", -12.0) + return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient)) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. 
+ :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + elif section_counts.startswith("exact"): + res = set(int(x) for x in section_counts[len("exact") :].split(",")) + for x in res: + if x < 0 or x >= num_timesteps: + raise ValueError(f"timestep out of bounds: {x}") + return res + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D array of betas for each diffusion timestep from T to 1. + :param model_mean_type: a string determining what the model outputs. + :param model_var_type: a string determining how variance is output. + :param loss_type: a string determining the loss function to use. + :param discretized_t0: if True, use discrete gaussian loss for t=0. Only + makes sense for images. + :param channel_scales: a multiplier to apply to x_start in training_losses + and sampling functions. + """ + + def __init__( + self, + *, + betas: Sequence[float], + model_mean_type: str, + model_var_type: str, + loss_type: str, + discretized_t0: bool = False, + channel_scales: Optional[np.ndarray] = None, + channel_biases: Optional[np.ndarray] = None, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.discretized_t0 = discretized_t0 + self.channel_scales = channel_scales + self.channel_biases = channel_biases + + # Use float64 for accuracy. 
+ betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def get_sigmas(self, t): + return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. 
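+
+        Concretely: x_t = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise.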
+ """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None, condition_latents=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) if condition_latents is None else model(x, t, condition_latents, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in ["learned", "learned_range"]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == "learned": + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
+ "fixed_large": ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + "fixed_small": ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == "x_prev": + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in ["x_start", "epsilon"]: + if self.model_mean_type == "x_start": + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **(model_kwargs or {})) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
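+
+        Concretely, the predicted epsilon is shifted by -sqrt(1 - alpha_bar_t) * cond_fn(x, t),
+        and the mean is recomputed from the resulting x_start prediction.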
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {})) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + temp=1.0, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + temp=temp, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + temp=1.0, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. 
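+        Sampling proceeds from t = num_timesteps - 1 down to t = 0.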
+ + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) * temp + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield self.unscale_out_dict(out) + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + temp=1.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). 
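+
+        With eta=0.0 the per-step update is deterministic; larger eta injects Gaussian
+        noise scaled by the DDIM sigma at each step.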
+ """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + temp=temp, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=False, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + temp=1.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) * temp + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield self.unscale_out_dict(out) + img = out["sample"] + + def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + if not self.discretized_t0: + decoder_nll = th.zeros_like(decoder_nll) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return { + "output": output, + "pred_xstart": out["pred_xstart"], + "extra": out["extra"], + } + + def training_losses( + self, model, x_start, t, model_kwargs=None, noise=None + ) -> Dict[str, th.Tensor]: + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. 
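+
+        Note that x_start is first rescaled with scale_channels() before noise is added.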
+ """ + x_start = self.scale_channels(x_start) + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == "kl" or self.loss_type == "rescaled_kl": + vb_terms = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + ) + terms["loss"] = vb_terms["output"] + if self.loss_type == "rescaled_kl": + terms["loss"] *= self.num_timesteps + extra = vb_terms["extra"] + elif self.loss_type == "mse" or self.loss_type == "rescaled_mse": + model_output = model(x_t, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = {} + + if self.model_var_type in [ + "learned", + "learned_range", + ]: + B, C = x_t.shape[:2] + assert model_output.shape == ( + B, + C * 2, + *x_t.shape[2:], + ), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}" + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == "rescaled_mse": + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0], + "x_start": x_start, + "epsilon": noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + if "losses" in extra: + terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()}) + for loss, scale in extra["losses"].values(): + terms["loss"] = terms["loss"] + loss * scale + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. 
+ - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + def scale_channels(self, x: th.Tensor) -> th.Tensor: + if self.channel_scales is not None: + x = x * th.from_numpy(self.channel_scales).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + if self.channel_biases is not None: + x = x + th.from_numpy(self.channel_biases).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + return x + + def unscale_channels(self, x: th.Tensor) -> th.Tensor: + if self.channel_biases is not None: + x = x - th.from_numpy(self.channel_biases).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + if self.channel_scales is not None: + x = x / th.from_numpy(self.channel_scales).to(x).reshape( + [1, -1, *([1] * (len(x.shape) - 2))] + ) + return x + + def unscale_out_dict( + self, out: Dict[str, Union[th.Tensor, Any]] + ) -> Dict[str, Union[th.Tensor, Any]]: + return { + k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items() + } + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: (unordered) timesteps from the original diffusion + process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
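+
+    The betas of the spaced process are recomputed so that its cumulative alpha
+    products match those of the base process at the retained timesteps.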
+ """ + + def __init__(self, use_timesteps: Iterable[int], **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses(self, model, *args, **kwargs): + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + return self.model(x, new_ts, **kwargs) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. 
It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.flatten(1).mean(1) diff --git a/shap_e/diffusion/k_diffusion.py b/shap_e/diffusion/k_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..353d976f0210ac2972c5165f6dd8f41bb1706175 --- /dev/null +++ b/shap_e/diffusion/k_diffusion.py @@ -0,0 +1,426 @@ +""" +Based on: https://github.com/crowsonkb/k-diffusion + +Copyright (c) 2022 Katherine Crowson + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion, mean_flat + + +class KarrasDenoiser: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def get_snr(self, sigmas): + return sigmas**-2 + + def get_sigmas(self, sigmas): + return sigmas + + def get_scalings(self, sigma): + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out, c_in + + def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None): + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + + terms = {} + + dims = x_start.ndim + x_t = x_start + noise * append_dims(sigmas, dims) + c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)] + model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs) + target = (x_start - c_skip * x_t) / c_out + + terms["mse"] = mean_flat((model_output - target) ** 2) + terms["xs_mse"] = mean_flat((denoised - x_start) ** 2) + + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + + return terms + + def denoise(self, model, x_t, sigmas, **model_kwargs): + c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)] + rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44) + model_output = model(c_in * x_t, rescaled_t, **model_kwargs) + denoised = c_out * model_output + c_skip * x_t + return model_output, denoised + + +class GaussianToKarrasDenoiser: + def __init__(self, model, diffusion): + from scipy import interpolate + + self.model = model + self.diffusion = diffusion + self.alpha_cumprod_to_t = interpolate.interp1d( + diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps) + ) + + def sigma_to_t(self, sigma): + alpha_cumprod = 1.0 / (sigma**2 + 1) + if alpha_cumprod > self.diffusion.alphas_cumprod[0]: + return 0 + elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]: + return self.diffusion.num_timesteps - 1 + else: + return float(self.alpha_cumprod_to_t(alpha_cumprod)) + + def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None, condition_latents=None): + t = th.tensor( + [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()], + dtype=th.long, + device=sigmas.device, + ) + c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim) + out = self.diffusion.p_mean_variance( + self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latents + ) + return None, out["pred_xstart"] + + +def karras_sample(*args, **kwargs): + last = None + x_sequence = [] + # print("kraras_sample_model_kwargs", kwargs["model_kwargs"]['embeddings'].shape) + for x in karras_sample_progressive(*args, **kwargs): + last = x["x"] + x_sequence.append(last) + return last, x_sequence + + + +def karras_sample_progressive( + diffusion, + model, + shape, + steps, + clip_denoised=True, + progress=False, + model_kwargs=None, + device=None, + sigma_min=0.002, + sigma_max=80, # higher for highres? 
+ rho=7.0, + sampler="heun", + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, + guidance_scale=0.0, + condition_latent=None, + initial_noise=None, +): + sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device) + # print("sigmas", sigmas.shape, sigmas) + if initial_noise is None: + x_T = th.randn(*shape, device=device) * sigma_max + else: + x_T = initial_noise.clone() * sigma_max + sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[ + sampler + ] + if sampler != "ancestral": + sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise) + else: + sampler_args = {} + + if isinstance(diffusion, KarrasDenoiser): + def denoiser(x_t, sigma): + _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) + if clip_denoised: + denoised = denoised.clamp(-1, 1) + return denoised + + elif isinstance(diffusion, GaussianDiffusion): + model = GaussianToKarrasDenoiser(model, diffusion) + + def denoiser(x_t, sigma): + _, denoised = model.denoise( + x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent + ) + return denoised + + else: + raise NotImplementedError + + if guidance_scale != 0 and guidance_scale != 1: + + def guided_denoiser(x_t, sigma): + x_t = th.cat([x_t, x_t], dim=0) + sigma = th.cat([sigma, sigma], dim=0) + x_0 = denoiser(x_t, sigma) + cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0) + x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0) + return x_0 + + else: + guided_denoiser = denoiser + + for obj in sample_fn( + guided_denoiser, + x_T, + sigmas, + progress=progress, + condition_latent=condition_latent, + **sampler_args, + ): + if isinstance(diffusion, GaussianDiffusion): + # print("is gaussian diffusion", obj) + yield diffusion.unscale_out_dict(obj) + else: + yield obj + + +def karras_sample_progressive_condition( + diffusion, + model, + shape, + steps, + clip_denoised=True, + progress=False, + model_kwargs=None, + device=None, + sigma_min=0.002, + sigma_max=80, # higher for highres? 
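+ # The text/image guidance scales further down in this signature are combined by the
+ # guided denoiser below as:
+ #   uncond + text_scale * (text_cond - image_cond) + image_scale * (image_cond - uncond)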
+ rho=7.0, + sampler="heun", + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, + text_guidance_scale=0.0, + image_guidance_scale=0.0, + condition_latent=None, +): + sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device) + x_T = th.randn(*shape, device=device) * sigma_max + sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[ + sampler + ] + if sampler != "ancestral": + sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise) + else: + sampler_args = {} + + if isinstance(diffusion, KarrasDenoiser): + def denoiser(x_t, sigma): + _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs) + if clip_denoised: + denoised = denoised.clamp(-1, 1) + return denoised + + elif isinstance(diffusion, GaussianDiffusion): + model = GaussianToKarrasDenoiser(model, diffusion) + + def denoiser(x_t, sigma): + _, denoised = model.denoise( + x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent + ) + return denoised + + else: + raise NotImplementedError + + if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (image_guidance_scale != 1.0 and image_guidance_scale != 0.0): + def guided_denoiser(x_t, sigma): + x_t = th.cat([x_t, x_t, x_t], dim=0) + sigma = th.cat([sigma, sigma, sigma], dim=0) + x_0 = denoiser(x_t, sigma) + # import pdb; pdb.set_trace() + cond_x_0_text, cond_x_0_image, uncond_x_0 = th.chunk(x_0, 3, dim=0) + x_0 = uncond_x_0 + text_guidance_scale * (cond_x_0_text - cond_x_0_image) + image_guidance_scale * (cond_x_0_image - uncond_x_0) + return x_0 + + else: + guided_denoiser = denoiser + + for obj in sample_fn( + guided_denoiser, + x_T, + sigmas, + progress=progress, + condition_latent=condition_latent, + **sampler_args, + ): + if isinstance(diffusion, GaussianDiffusion): + yield diffusion.unscale_out_dict(obj) + else: + yield obj +def karras_sample_addition_condition(*args, **kwargs): + last = None + x_sequence = [] + for x in karras_sample_progressive_condition(*args, **kwargs): + last = x["x"] + x_sequence.append(x["pred_xstart"]) + return last, x_sequence + +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): + """Constructs the noise schedule of Karras et al. 
(2022).""" + ramp = th.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +@th.no_grad() +def sample_euler_ancestral(model, x, sigmas, progress=False): + """Ancestral sampling with Euler method steps.""" + s_in = x.new_ones([x.shape[0]]) + indices = range(len(sigmas) - 1) + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + denoised = model(x, sigmas[i] * s_in) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) + yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised} + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + x = x + th.randn_like(x) * sigma_up + yield {"x": x, "pred_xstart": x} + + +@th.no_grad() +def sample_heun( + denoiser, + x, + sigmas, + progress=False, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, + condition_latent=None, +): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + s_in = x.new_ones([x.shape[0]]) + indices = range(len(sigmas) - 1) + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + ) + eps = th.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 + denoised = denoiser(x, sigma_hat * s_in) + d = to_d(x, sigma_hat, denoised) + yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised} + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + yield {"x": x, "pred_xstart": denoised} + + +@th.no_grad() +def sample_dpm( + denoiser, + x, + sigmas, + progress=False, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. 
(2022).""" + s_in = x.new_ones([x.shape[0]]) + indices = range(len(sigmas) - 1) + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + ) + eps = th.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 + denoised = denoiser(x, sigma_hat * s_in) + d = to_d(x, sigma_hat, denoised) + yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised} + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = denoiser(x_2, sigma_mid * s_in) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + yield {"x": x, "pred_xstart": denoised} + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def append_zero(x): + return th.cat([x, x.new_zeros([1])]) diff --git a/shap_e/diffusion/sample.py b/shap_e/diffusion/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f6dc1d3db37277b1ff99f0c5be9695f843eabfa9 --- /dev/null +++ b/shap_e/diffusion/sample.py @@ -0,0 +1,160 @@ +from typing import Any, Callable, Dict, Optional, List + +import torch +import torch.nn as nn + +from .gaussian_diffusion import GaussianDiffusion +from .k_diffusion import karras_sample, karras_sample_addition_condition + +DEFAULT_KARRAS_STEPS = 64 +DEFAULT_KARRAS_SIGMA_MIN = 1e-3 +DEFAULT_KARRAS_SIGMA_MAX = 160 +DEFAULT_KARRAS_S_CHURN = 0.0 + + +def uncond_guide_model( + model: Callable[..., torch.Tensor], scale: float +) -> Callable[..., torch.Tensor]: + + def model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = model(combined, ts, **kwargs) + cond_out, uncond_out = torch.chunk(model_out, 2, dim=0) + cond_out = uncond_out + scale * (cond_out - uncond_out) + return torch.cat([cond_out, cond_out], dim=0) + + return model_fn + + +def sample_latents( + *, + batch_size: int, + model: nn.Module, + diffusion: GaussianDiffusion, + model_kwargs: Dict[str, Any], + guidance_scale: float, + clip_denoised: bool, + use_fp16: bool, + use_karras: bool, + karras_steps: int, + sigma_min: float, + sigma_max: float, + s_churn: float, + device: Optional[torch.device] = None, + progress: bool = False, + initial_noise: Optional[torch.Tensor] = None, +) -> (torch.Tensor, List[torch.Tensor]): + sample_shape = (batch_size, model.d_latent) + + if device is None: + device = next(model.parameters()).device + + if hasattr(model, "cached_model_kwargs"): + model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs) + if guidance_scale != 1.0 and guidance_scale != 0.0: + for k, v in model_kwargs.copy().items(): + # print(k, v.shape) + model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0) + + sample_shape = (batch_size, model.d_latent) + with torch.autocast(device_type=device.type, enabled=use_fp16): + if use_karras: + samples, sample_sequence = karras_sample( + diffusion=diffusion, + model=model, + shape=sample_shape, + steps=karras_steps, + 
clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + device=device, + sigma_min=sigma_min, + sigma_max=sigma_max, + s_churn=s_churn, + guidance_scale=guidance_scale, + progress=progress, + initial_noise=initial_noise, + ) + else: + internal_batch_size = batch_size + if guidance_scale != 1.0: + model = uncond_guide_model(model, guidance_scale) + internal_batch_size *= 2 + samples = diffusion.p_sample_loop( + model, + shape=(internal_batch_size, *sample_shape[1:]), + model_kwargs=model_kwargs, + device=device, + clip_denoised=clip_denoised, + progress=progress, + ) + + return samples + + +def sample_latents_with_additional_latent( + *, + batch_size: int, + model: nn.Module, + diffusion: GaussianDiffusion, + model_kwargs: Dict[str, Any], + text_guidance_scale: float, + image_guidance_scale: float, + clip_denoised: bool, + use_fp16: bool, + use_karras: bool, + karras_steps: int, + sigma_min: float, + sigma_max: float, + s_churn: float, + device: Optional[torch.device] = None, + progress: bool = False, + condition_latent: Optional[torch.Tensor] = None, +) -> (torch.Tensor, List[torch.Tensor]): + + if device is None: + device = next(model.parameters()).device + + if hasattr(model, "cached_model_kwargs"): + model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs) + if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (image_guidance_scale != 1.0 and image_guidance_scale != 0.0): + for k, v in model_kwargs.copy().items(): + # print(k, v.shape) + model_kwargs[k] = torch.cat([v, torch.zeros_like(v), torch.zeros_like(v)], dim=0) + condition_latent = torch.cat([condition_latent, condition_latent, torch.zeros_like(condition_latent)], dim=0) + + sample_shape = (batch_size, model.d_latent) + # print("sample_shape", sample_shape) + with torch.autocast(device_type=device.type, enabled=use_fp16): + if use_karras: + samples, samples_squence = karras_sample_addition_condition( + diffusion=diffusion, + model=model, + shape=sample_shape, + steps=karras_steps, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + device=device, + sigma_min=sigma_min, + sigma_max=sigma_max, + s_churn=s_churn, + text_guidance_scale=text_guidance_scale, + image_guidance_scale=image_guidance_scale, + progress=progress, + condition_latent=condition_latent, + ) + else: + internal_batch_size = batch_size + if text_guidance_scale != 1.0: + model = uncond_guide_model(model, text_guidance_scale) + internal_batch_size *= 2 + samples = diffusion.p_sample_loop( + model, + shape=(internal_batch_size, *sample_shape[1:]), + model_kwargs=model_kwargs, + device=device, + clip_denoised=clip_denoised, + progress=progress, + ) + + return samples \ No newline at end of file diff --git a/shap_e/examples/encode_model.ipynb b/shap_e/examples/encode_model.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..42270bcb5deabae98c635be7eca2c69061859c7e --- /dev/null +++ b/shap_e/examples/encode_model.ipynb @@ -0,0 +1,93 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from shap_e.models.download import load_model\n", + "from shap_e.util.data_util import load_or_create_multimodal_batch\n", + "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "xm = load_model('transmitter', device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"example_data/cactus/object.obj\"\n", + "\n", + "# This may take a few minutes, since it requires rendering the model twice\n", + "# in two different modes.\n", + "batch = load_or_create_multimodal_batch(\n", + " device,\n", + " model_path=model_path,\n", + " mv_light_mode=\"basic\",\n", + " mv_image_size=256,\n", + " cache_dir=\"example_data/cactus/cached\",\n", + " verbose=True, # this will show Blender output during renders\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " latent = xm.encoder.encode_to_bottleneck(batch)\n", + "\n", + " render_mode = 'stf' # you can change this to 'nerf'\n", + " size = 128 # recommended that you lower resolution when using nerf\n", + "\n", + " cameras = create_pan_cameras(size, device)\n", + " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", + " display(gif_widget(images))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/shap_e/examples/sample_image_to_3d.ipynb b/shap_e/examples/sample_image_to_3d.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c224bb728961ef3265e7dfb3ab4226583c7c1692 --- /dev/null +++ b/shap_e/examples/sample_image_to_3d.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "964ccced", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from shap_e.diffusion.sample import sample_latents\n", + "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n", + "from shap_e.models.download import load_model, load_config\n", + "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n", + "from shap_e.util.image_util import load_image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8eed3a76", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d922637", + "metadata": {}, + "outputs": [], + "source": [ + "xm = load_model('transmitter', device=device)\n", + "model = load_model('image300M', device=device)\n", + "diffusion = diffusion_from_config(load_config('diffusion'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53d329d0", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "guidance_scale = 3.0\n", + "\n", + "image = load_image(\"example_data/corgi.png\")\n", + "\n", + "latents = sample_latents(\n", + " batch_size=batch_size,\n", + " model=model,\n", + " diffusion=diffusion,\n", + " guidance_scale=guidance_scale,\n", + " model_kwargs=dict(images=[image] * batch_size),\n", + " progress=True,\n", + " clip_denoised=True,\n", + " use_fp16=True,\n", + " 
use_karras=True,\n", + " karras_steps=64,\n", + " sigma_min=1e-3,\n", + " sigma_max=160,\n", + " s_churn=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "633da2ec", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n", + "size = 64 # this is the size of the renders; higher values take longer to render.\n", + "\n", + "cameras = create_pan_cameras(size, device)\n", + "for i, latent in enumerate(latents):\n", + " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n", + " display(gif_widget(images))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/shap_e/examples/sample_text_to_3d.ipynb b/shap_e/examples/sample_text_to_3d.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ef39d44171c70d5f3fe1390942b67d5fc5b1b2ae --- /dev/null +++ b/shap_e/examples/sample_text_to_3d.ipynb @@ -0,0 +1,124 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "964ccced", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from shap_e.diffusion.sample import sample_latents\n", + "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n", + "from shap_e.models.download import load_model, load_config\n", + "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8eed3a76", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d922637", + "metadata": {}, + "outputs": [], + "source": [ + "xm = load_model('transmitter', device=device)\n", + "model = load_model('text300M', device=device)\n", + "diffusion = diffusion_from_config(load_config('diffusion'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53d329d0", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "guidance_scale = 15.0\n", + "prompt = \"a shark\"\n", + "\n", + "latents = sample_latents(\n", + " batch_size=batch_size,\n", + " model=model,\n", + " diffusion=diffusion,\n", + " guidance_scale=guidance_scale,\n", + " model_kwargs=dict(texts=[prompt] * batch_size),\n", + " progress=True,\n", + " clip_denoised=True,\n", + " use_fp16=True,\n", + " use_karras=True,\n", + " karras_steps=64,\n", + " sigma_min=1e-3,\n", + " sigma_max=160,\n", + " s_churn=0,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "633da2ec", + "metadata": {}, + "outputs": [], + "source": [ + "render_mode = 'nerf' # you can change this to 'stf'\n", + "size = 64 # this is the size of the renders; higher values take longer to render.\n", + "\n", + "cameras = create_pan_cameras(size, device)\n", + "for i, latent in enumerate(latents):\n", + " images = decode_latent_images(xm, 
latent, cameras, rendering_mode=render_mode)\n", + " display(gif_widget(images))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85a4dce4", + "metadata": {}, + "outputs": [], + "source": [ + "# Example of saving the latents as meshes.\n", + "from shap_e.util.notebooks import decode_latent_mesh\n", + "\n", + "for i, latent in enumerate(latents):\n", + " t = decode_latent_mesh(xm, latent).tri_mesh()\n", + " with open(f'example_mesh_{i}.ply', 'wb') as f:\n", + " t.write_ply(f)\n", + " with open(f'example_mesh_{i}.obj', 'w') as f:\n", + " t.write_obj(f)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/shap_e/models/__init__.py b/shap_e/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/models/__pycache__/__init__.cpython-39.pyc b/shap_e/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ccca183a04a82647087385c13f99b4c9efdd0c7 Binary files /dev/null and b/shap_e/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/__pycache__/configs.cpython-39.pyc b/shap_e/models/__pycache__/configs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24a89b6548b38fa4a7e63c0a418a47fd4e46ee8e Binary files /dev/null and b/shap_e/models/__pycache__/configs.cpython-39.pyc differ diff --git a/shap_e/models/__pycache__/download.cpython-39.pyc b/shap_e/models/__pycache__/download.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0095b90a8d255a1a3b78a2a4b4189803d567b238 Binary files /dev/null and b/shap_e/models/__pycache__/download.cpython-39.pyc differ diff --git a/shap_e/models/__pycache__/query.cpython-39.pyc b/shap_e/models/__pycache__/query.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c265b87fa336f67ebe26a203b5757244a8a35c1 Binary files /dev/null and b/shap_e/models/__pycache__/query.cpython-39.pyc differ diff --git a/shap_e/models/__pycache__/renderer.cpython-39.pyc b/shap_e/models/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdc91f471c3df5b5b7dd3f851891de537ad73289 Binary files /dev/null and b/shap_e/models/__pycache__/renderer.cpython-39.pyc differ diff --git a/shap_e/models/__pycache__/volume.cpython-39.pyc b/shap_e/models/__pycache__/volume.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b192d8cf66d1b501e453b4a265a28a52b4a5daa7 Binary files /dev/null and b/shap_e/models/__pycache__/volume.cpython-39.pyc differ diff --git a/shap_e/models/configs.py b/shap_e/models/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..6d115b1d168b39da33a1fa10308fbb4b9a903155 --- /dev/null +++ b/shap_e/models/configs.py @@ -0,0 +1,166 @@ +from typing import Any, Dict, Union + +import blobfile as bf +import torch +import torch.nn as nn +import yaml + +from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion +from shap_e.models.generation.perceiver import PointDiffusionPerceiver +from 
shap_e.models.generation.pooled_mlp import PooledMLP +from shap_e.models.generation.transformer import ( + CLIPImageGridPointDiffusionTransformer, + CLIPImageGridUpsamplePointDiffusionTransformer, + CLIPImagePointDiffusionTransformer, + PointDiffusionTransformer, + UpsamplePointDiffusionTransformer, +) +from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel +from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer +from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel +from shap_e.models.nerstf.renderer import NeRSTFRenderer +from shap_e.models.nn.meta import batch_meta_state_dict +from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel +from shap_e.models.stf.renderer import STFRenderer +from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder +from shap_e.models.transmitter.channels_encoder import ( + PointCloudPerceiverChannelsEncoder, + PointCloudTransformerChannelsEncoder, +) +from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder +from shap_e.models.transmitter.pc_encoder import ( + PointCloudPerceiverEncoder, + PointCloudTransformerEncoder, +) +from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume + + +def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module: + print(config) + if isinstance(config, str): + print("config", config) + with bf.BlobFile(config, "rb") as f: + obj = yaml.load(f, Loader=yaml.SafeLoader) + return model_from_config(obj, device=device) + + config = config.copy() + name = config.pop("name") + + if name == "PointCloudTransformerEncoder": + return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config) + elif name == "PointCloudPerceiverEncoder": + return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config) + elif name == "PointCloudTransformerChannelsEncoder": + return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config) + elif name == "PointCloudPerceiverChannelsEncoder": + return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config) + elif name == "MultiviewTransformerEncoder": + return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config) + elif name == "Transmitter": + renderer = model_from_config(config.pop("renderer"), device=device) + param_shapes = { + k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items() + } + encoder_config = config.pop("encoder").copy() + encoder_config["param_shapes"] = param_shapes + encoder = model_from_config(encoder_config, device=device) + return Transmitter(encoder=encoder, renderer=renderer, **config) + elif name == "VectorDecoder": + renderer = model_from_config(config.pop("renderer"), device=device) + param_shapes = { + k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items() + } + return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config) + elif name == "ChannelsDecoder": + renderer = model_from_config(config.pop("renderer"), device=device) + param_shapes = { + k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items() + } + return ChannelsDecoder( + param_shapes=param_shapes, renderer=renderer, device=device, **config + ) + elif name == "OneStepNeRFRenderer": + config = config.copy() + for field in [ + # Required + "void_model", + "foreground_model", + "volume", + # Optional to use NeRF++ + "background_model", + 
"outer_volume", + ]: + if field in config: + config[field] = model_from_config(config.pop(field).copy(), device) + return OneStepNeRFRenderer(device=device, **config) + elif name == "TwoStepNeRFRenderer": + config = config.copy() + for field in [ + # Required + "void_model", + "coarse_model", + "fine_model", + "volume", + # Optional to use NeRF++ + "coarse_background_model", + "fine_background_model", + "outer_volume", + ]: + if field in config: + config[field] = model_from_config(config.pop(field).copy(), device) + return TwoStepNeRFRenderer(device=device, **config) + elif name == "PooledMLP": + return PooledMLP(device, **config) + elif name == "PointDiffusionTransformer": + return PointDiffusionTransformer(device=device, dtype=torch.float32, **config) + elif name == "PointDiffusionPerceiver": + return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config) + elif name == "CLIPImagePointDiffusionTransformer": + return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config) + elif name == "CLIPImageGridPointDiffusionTransformer": + return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config) + elif name == "UpsamplePointDiffusionTransformer": + return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config) + elif name == "CLIPImageGridUpsamplePointDiffusionTransformer": + return CLIPImageGridUpsamplePointDiffusionTransformer( + device=device, dtype=torch.float32, **config + ) + elif name == "SplitVectorDiffusion": + inner_config = config.pop("inner") + d_latent = config.pop("d_latent") + latent_ctx = config.pop("latent_ctx", 1) + inner_config["input_channels"] = d_latent // latent_ctx + inner_config["n_ctx"] = latent_ctx + inner_config["output_channels"] = d_latent // latent_ctx * 2 + inner_model = model_from_config(inner_config, device) + return SplitVectorDiffusion( + device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent + ) + elif name == "STFRenderer": + config = config.copy() + for field in ["sdf", "tf", "volume"]: + config[field] = model_from_config(config.pop(field), device) + return STFRenderer(device=device, **config) + elif name == "NeRSTFRenderer": + config = config.copy() + for field in ["sdf", "tf", "nerstf", "void", "volume"]: + if field not in config: + continue + config[field] = model_from_config(config.pop(field), device) + config.setdefault("sdf", None) + config.setdefault("tf", None) + config.setdefault("nerstf", None) + return NeRSTFRenderer(device=device, **config) + + model_cls = { + "MLPSDFModel": MLPSDFModel, + "MLPTextureFieldModel": MLPTextureFieldModel, + "MLPNeRFModel": MLPNeRFModel, + "MLPDensitySDFModel": MLPDensitySDFModel, + "MLPNeRSTFModel": MLPNeRSTFModel, + "VoidNeRFModel": VoidNeRFModel, + "BoundingBoxVolume": BoundingBoxVolume, + "SphericalVolume": SphericalVolume, + "UnboundedVolume": UnboundedVolume, + }[name] + return model_cls(device=device, **config) diff --git a/shap_e/models/download.py b/shap_e/models/download.py new file mode 100644 index 0000000000000000000000000000000000000000..00fe9fb6955a1f407e3ff9a7ed155856604ad03e --- /dev/null +++ b/shap_e/models/download.py @@ -0,0 +1,152 @@ +""" +Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py +""" + +import hashlib +import os +from functools import lru_cache +from typing import Dict, Optional + +import requests +import torch +import yaml +from filelock import FileLock +from tqdm.auto import tqdm + +MODEL_PATHS = { 
+ "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt", + "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt", + "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt", + "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt", +} + +CONFIG_PATHS = { + "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml", + "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml", + "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml", + "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml", + "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml", +} + +URL_HASHES = { + "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b", + "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98", + "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4", + "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa", + "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e", + "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c", + "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1", + "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0", + "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57", +} + + +@lru_cache() +def default_cache_dir() -> str: + return os.path.join(os.path.abspath(os.getcwd()), "shap_e_model_cache") + + +def fetch_file_cached( + url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096 +) -> str: + """ + Download the file at the given URL into a local file and return the path. + If cache_dir is specified, it will be used to download the files. + Otherwise, default_cache_dir() is used. + """ + expected_hash = URL_HASHES[url] + + if cache_dir is None: + cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + local_path = os.path.join(cache_dir, url.split("/")[-1]) + if os.path.exists(local_path): + check_hash(local_path, expected_hash) + return local_path + + response = requests.get(url, stream=True) + size = int(response.headers.get("content-length", "0")) + with FileLock(local_path + ".lock"): + if progress: + pbar = tqdm(total=size, unit="iB", unit_scale=True) + tmp_path = local_path + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in response.iter_content(chunk_size): + if progress: + pbar.update(len(chunk)) + f.write(chunk) + os.rename(tmp_path, local_path) + if progress: + pbar.close() + check_hash(local_path, expected_hash) + return local_path + + +def check_hash(path: str, expected_hash: str): + actual_hash = hash_file(path) + if actual_hash != expected_hash: + raise RuntimeError( + f"The file {path} should have hash {expected_hash} but has {actual_hash}. 
" + "Try deleting it and running this call again." + ) + + +def hash_file(path: str) -> str: + sha256_hash = hashlib.sha256() + with open(path, "rb") as file: + while True: + data = file.read(4096) + if not len(data): + break + sha256_hash.update(data) + return sha256_hash.hexdigest() + + +def load_config( + config_name: str, + progress: bool = False, + cache_dir: Optional[str] = None, + chunk_size: int = 4096, +): + if config_name not in CONFIG_PATHS: + raise ValueError( + f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}." + ) + path = fetch_file_cached( + CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size + ) + with open(path, "r") as f: + return yaml.safe_load(f) + + +def load_checkpoint( + checkpoint_name: str, + device: torch.device, + progress: bool = True, + cache_dir: Optional[str] = None, + chunk_size: int = 4096, +) -> Dict[str, torch.Tensor]: + if checkpoint_name not in MODEL_PATHS: + raise ValueError( + f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}." + ) + print(checkpoint_name) + path = fetch_file_cached( + MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size + ) + return torch.load(path, map_location=device) + + +def load_model( + model_name: str, + device: torch.device, + **kwargs, +) -> Dict[str, torch.Tensor]: + from .configs import model_from_config + + model = model_from_config(load_config(model_name, **kwargs), device=device) + # print(model_name, kwargs) + # print(model) + model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs)) + model.eval() + return model diff --git a/shap_e/models/generation/__init__.py b/shap_e/models/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/models/generation/__pycache__/__init__.cpython-39.pyc b/shap_e/models/generation/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25e74e8838df4c594b90f0aee9d04fb4ab56806b Binary files /dev/null and b/shap_e/models/generation/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc b/shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..118b293f5dc03369957afd876244f03ba9c2e8ca Binary files /dev/null and b/shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc b/shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baa9c218e045d42ceec1cf97138c948e38abc51c Binary files /dev/null and b/shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc b/shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e12d835f0b7a44aab59f06e8af70d2b95eeaeaab Binary files /dev/null and b/shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc b/shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7719a021c48f25788eac267565b5c32de90bbcf Binary files 
/dev/null and b/shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/transformer.cpython-39.pyc b/shap_e/models/generation/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62faefdfbdad9c4380eff1c76a4e114e2d28407e Binary files /dev/null and b/shap_e/models/generation/__pycache__/transformer.cpython-39.pyc differ diff --git a/shap_e/models/generation/__pycache__/util.cpython-39.pyc b/shap_e/models/generation/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0e2a1aaf4fdac9ec707b567a15427282ad73481 Binary files /dev/null and b/shap_e/models/generation/__pycache__/util.cpython-39.pyc differ diff --git a/shap_e/models/generation/latent_diffusion.py b/shap_e/models/generation/latent_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca472892a54123d91fd75623ebba336b8af153b --- /dev/null +++ b/shap_e/models/generation/latent_diffusion.py @@ -0,0 +1,32 @@ +from typing import Any, Dict + +import torch +import torch.nn as nn +from typing import Any, Callable, Dict, Optional + + +class SplitVectorDiffusion(nn.Module): + def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int): + super().__init__() + self.device = device + self.n_ctx = n_ctx + self.d_latent = d_latent + self.wrapped = wrapped + + if hasattr(self.wrapped, "cached_model_kwargs"): + self.cached_model_kwargs = self.wrapped.cached_model_kwargs + + def forward(self, x: torch.Tensor, t: torch.Tensor, conditional_latent: Optional[torch.Tensor] = None, **kwargs): + h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1) + if conditional_latent is not None: + conditional_latent = conditional_latent.reshape(conditional_latent.shape[0], self.n_ctx, -1) + h = torch.cat([h.permute(0, 2, 1) , conditional_latent], dim=-1).permute(0, 2, 1) # (batch_size, n_ctx, channel) -> (batch_size, d_latent, n_ctx) + h = self.wrapped(h, t, **kwargs) + eps, var = torch.chunk(h, 2, dim=1) + return torch.cat( + [ + eps.permute(0, 2, 1).flatten(1), + var.permute(0, 2, 1).flatten(1), + ], + dim=1, + ) diff --git a/shap_e/models/generation/perceiver.py b/shap_e/models/generation/perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b579c0ebe7c6b2138b643489b04d8d85e35a15 --- /dev/null +++ b/shap_e/models/generation/perceiver.py @@ -0,0 +1,244 @@ +import math +from typing import Optional + +import torch +import torch.nn as nn + +from shap_e.models.nn.checkpoint import checkpoint + +from .transformer import MLP, Transformer, init_linear +from .util import timestep_embedding + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + n_data: int, + width: int, + heads: int, + init_scale: float, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_ctx = n_ctx + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + 
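+ # Note: SimplePerceiver pre-scales init_scale by sqrt(1/width) before it reaches
+ # these projection layers.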
init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__( + self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int + ): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.n_data = n_data + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + n_data: int, + width: int, + heads: int, + data_width: Optional[int] = None, + init_scale: float = 1.0, + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class SimplePerceiver(nn.Module): + """ + Only does cross attention + """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + n_data: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + init_scale = init_scale * math.sqrt(1.0 / width) + self.resblocks = nn.ModuleList( + [ + ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + n_data=n_data, + width=width, + heads=heads, + init_scale=init_scale, + data_width=data_width, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + for block in self.resblocks: + x = block(x, data) + return x + + +class PointDiffusionPerceiver(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + input_channels: int = 3, + output_channels: int = 3, + n_ctx: int = 1024, + n_latent: int = 128, + width: int = 512, + encoder_layers: int = 12, + latent_layers: int = 12, + decoder_layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + ): + super().__init__() + self.time_embed = MLP( + device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) + ) + self.latent_embed = MLP( + device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) + ) + self.n_latent = n_latent + + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.encoder = 
SimplePerceiver( + device=device, + dtype=dtype, + n_ctx=n_latent, + n_data=n_ctx, + width=width, + layers=encoder_layers, + heads=heads, + init_scale=init_scale, + ) + self.processor = Transformer( + device=device, + dtype=dtype, + n_ctx=n_latent, + width=width, + layers=latent_layers, + heads=heads, + init_scale=init_scale, + ) + self.decoder = SimplePerceiver( + device=device, + dtype=dtype, + n_ctx=n_ctx, + n_data=n_latent, + width=width, + layers=decoder_layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) + with torch.no_grad(): + self.output_proj.weight.zero_() + self.output_proj.bias.zero_() + + def forward(self, x: torch.Tensor, t: torch.Tensor): + """ + :param x: an [N x C x T] tensor. + :param t: an [N] tensor. + :return: an [N x C' x T] tensor. + """ + assert x.shape[-1] == self.decoder.n_ctx + t_embed = self.time_embed(timestep_embedding(t, self.encoder.width)) + data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None] + data = self.ln_pre(data) + + l = torch.arange(self.n_latent).to(x.device) + h = self.latent_embed(timestep_embedding(l, self.decoder.width)) + h = h.unsqueeze(0).repeat(x.shape[0], 1, 1) + + h = self.encoder(h, data) + h = self.processor(h) + h = self.decoder(data, h) + h = self.ln_post(h) + h = self.output_proj(h) + return h.permute(0, 2, 1) diff --git a/shap_e/models/generation/pooled_mlp.py b/shap_e/models/generation/pooled_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..b37ff84826b0a1199121b8a543b3aa42eff2357c --- /dev/null +++ b/shap_e/models/generation/pooled_mlp.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + +from .util import timestep_embedding + + +class PooledMLP(nn.Module): + def __init__( + self, + device: torch.device, + *, + input_channels: int = 3, + output_channels: int = 6, + hidden_size: int = 256, + resblocks: int = 4, + pool_op: str = "max", + ): + super().__init__() + self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device) + self.time_embed = nn.Linear(hidden_size, hidden_size, device=device) + + blocks = [] + for _ in range(resblocks): + blocks.append(ResBlock(hidden_size, pool_op, device=device)) + self.sequence = nn.Sequential(*blocks) + + self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device) + with torch.no_grad(): + self.out.bias.zero_() + self.out.weight.zero_() + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + in_embed = self.input_embed(x) + t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1])) + h = in_embed + t_embed[..., None] + h = self.sequence(h) + h = self.out(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, hidden_size: int, pool_op: str, device: torch.device): + super().__init__() + assert pool_op in ["mean", "max"] + self.pool_op = pool_op + self.body = nn.Sequential( + nn.SiLU(), + nn.LayerNorm((hidden_size,), device=device), + nn.Linear(hidden_size, hidden_size, device=device), + nn.SiLU(), + nn.LayerNorm((hidden_size,), device=device), + nn.Linear(hidden_size, hidden_size, device=device), + ) + self.gate = nn.Sequential( + nn.Linear(hidden_size, hidden_size, device=device), + nn.Tanh(), + ) + + def forward(self, x: torch.Tensor): + N, C, T = x.shape + out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 
1) + pooled = pool(self.pool_op, x) + gate = self.gate(pooled) + return x + out * gate[..., None] + + +def pool(op_name: str, x: torch.Tensor) -> torch.Tensor: + if op_name == "max": + pooled, _ = torch.max(x, dim=-1) + elif op_name == "mean": + # torch.mean returns a single tensor (no indices), unlike torch.max. + pooled = torch.mean(x, dim=-1) + else: + raise ValueError(f"unknown pool op: {op_name}") + return pooled diff --git a/shap_e/models/generation/pretrained_clip.py b/shap_e/models/generation/pretrained_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..29bf40b20104cf9f57a1d50766a4d8fae8f9bcb9 --- /dev/null +++ b/shap_e/models/generation/pretrained_clip.py @@ -0,0 +1,270 @@ +from typing import Iterable, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image + +from shap_e.models.download import default_cache_dir + +ImageType = Union[np.ndarray, torch.Tensor, Image.Image] + + +class ImageCLIP(nn.Module): + """ + A wrapper around a pre-trained CLIP model that automatically handles + batches of texts, images, and embeddings. + """ + + def __init__( + self, + device: torch.device, + dtype: Optional[torch.dtype] = torch.float32, + ensure_used_params: bool = True, + clip_name: str = "ViT-L/14", + cache_dir: Optional[str] = None, + ): + super().__init__() + + assert clip_name in ["ViT-L/14", "ViT-B/32"] + + self.device = device + self.ensure_used_params = ensure_used_params + + # Lazy import because of torchvision. + import clip + + self.clip_model, self.preprocess = clip.load( + clip_name, device=device, download_root=cache_dir or default_cache_dir() + ) + self.clip_name = clip_name + + if dtype is not None: + self.clip_model.to(dtype) + self._tokenize = clip.tokenize + + @property + def feature_dim(self) -> int: + if self.clip_name == "ViT-L/14": + return 768 + else: + return 512 + + @property + def grid_size(self) -> int: + if self.clip_name == "ViT-L/14": + return 16 + else: + return 7 + + @property + def grid_feature_dim(self) -> int: + if self.clip_name == "ViT-L/14": + return 1024 + else: + return 768 + + def forward( + self, + batch_size: int, + images: Optional[Iterable[Optional[ImageType]]] = None, + texts: Optional[Iterable[Optional[str]]] = None, + embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + """ + Generate a batch of embeddings from a mixture of images, texts, + precomputed embeddings, and possibly empty values. + + For each batch element, at most one of images, texts, and embeddings + should have a non-None value. Embeddings from multiple modalities + cannot be mixed for a single batch element. If no modality is provided, + a zero embedding will be used for the batch element.
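+ :return: an [N x feature_dim] tensor of CLIP embeddings (one row per batch element).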
+ """ + image_seq = [None] * batch_size if images is None else list(images) + text_seq = [None] * batch_size if texts is None else list(texts) + embedding_seq = [None] * batch_size if embeddings is None else list(embeddings) + assert len(image_seq) == batch_size, "number of images should match batch size" + assert len(text_seq) == batch_size, "number of texts should match batch size" + assert len(embedding_seq) == batch_size, "number of embeddings should match batch size" + + if self.ensure_used_params: + return self._static_multimodal_embed( + images=image_seq, texts=text_seq, embeddings=embedding_seq + ) + + result = torch.zeros((batch_size, self.feature_dim), device=self.device) + index_images = [] + index_texts = [] + for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)): + assert ( + sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2 + ), "only one modality may be non-None per batch element" + if image is not None: + index_images.append((i, image)) + elif text is not None: + index_texts.append((i, text)) + elif emb is not None: + result[i] = emb.to(result) + + if len(index_images): + embs = self.embed_images((img for _, img in index_images)) + for (i, _), emb in zip(index_images, embs): + result[i] = emb.to(result) + if len(index_texts): + embs = self.embed_text((text for _, text in index_texts)) + for (i, _), emb in zip(index_texts, embs): + result[i] = emb.to(result) + + return result + + def _static_multimodal_embed( + self, + images: List[Optional[ImageType]] = None, + texts: List[Optional[str]] = None, + embeddings: List[Optional[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Like forward(), but always runs all encoders to ensure that + the forward graph looks the same on every rank. + """ + image_emb = self.embed_images(images) + text_emb = self.embed_text(t if t else "" for t in texts) + joined_embs = torch.stack( + [ + emb.to(device=self.device, dtype=torch.float32) + if emb is not None + else torch.zeros(self.feature_dim, device=self.device) + for emb in embeddings + ], + dim=0, + ) + + image_flag = torch.tensor([x is not None for x in images], device=self.device)[ + :, None + ].expand_as(image_emb) + text_flag = torch.tensor([x is not None for x in texts], device=self.device)[ + :, None + ].expand_as(image_emb) + emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[ + :, None + ].expand_as(image_emb) + + return ( + image_flag.float() * image_emb + + text_flag.float() * text_emb + + emb_flag.float() * joined_embs + + self.clip_model.logit_scale * 0 # avoid unused parameters + ) + + def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: + """ + :param xs: N images, stored as numpy arrays, tensors, or PIL images. + :return: an [N x D] tensor of features. + """ + clip_inputs = self.images_to_tensor(xs) + results = self.clip_model.encode_image(clip_inputs).float() + return results / torch.linalg.norm(results, dim=-1, keepdim=True) + + def embed_text(self, prompts: Iterable[str]) -> torch.Tensor: + """ + Embed text prompts as an [N x D] tensor. + """ + enc = self.clip_model.encode_text( + self._tokenize(list(prompts), truncate=True).to(self.device) + ).float() + return enc / torch.linalg.norm(enc, dim=-1, keepdim=True) + + def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: + """ + Embed images into latent grids. + + :param xs: an iterable of images to embed. + :return: a tensor of shape [N x C x L], where L = self.grid_size**2. 
+ """ + if self.ensure_used_params: + extra_value = 0.0 + for p in self.parameters(): + extra_value = extra_value + p.mean() * 0.0 + else: + extra_value = 0.0 + + x = self.images_to_tensor(xs).to(self.clip_model.dtype) + + # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225 + vt = self.clip_model.visual + x = vt.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + vt.class_embedding.to(x.dtype) + + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + vt.positional_embedding.to(x.dtype) + x = vt.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = vt.transformer(x) + x = x.permute(1, 2, 0) # LND -> NDL + + return x[..., 1:].contiguous().float() + extra_value + + def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: + return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device) + + +class FrozenImageCLIP: + def __init__(self, device: torch.device, **kwargs): + self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs) + for parameter in self.model.parameters(): + parameter.requires_grad_(False) + + @property + def feature_dim(self) -> int: + return self.model.feature_dim + + @property + def grid_size(self) -> int: + return self.model.grid_size + + @property + def grid_feature_dim(self) -> int: + return self.model.grid_feature_dim + + def __call__( + self, + batch_size: int, + images: Optional[Iterable[Optional[ImageType]]] = None, + texts: Optional[Iterable[Optional[str]]] = None, + embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, + ) -> torch.Tensor: + # We don't do a no_grad() here so that gradients could still + # flow to the input embeddings argument. + # This behavior is currently not used, but it could be. 
+ return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings) + + def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: + with torch.no_grad(): + return self.model.embed_images(xs) + + def embed_text(self, prompts: Iterable[str]) -> torch.Tensor: + with torch.no_grad(): + return self.model.embed_text(prompts) + + def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor: + with torch.no_grad(): + return self.model.embed_images_grid(xs) + + +def _image_to_pil(obj: Optional[ImageType]) -> Image.Image: + if obj is None: + return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8)) + if isinstance(obj, np.ndarray): + return Image.fromarray(obj.astype(np.uint8)) + elif isinstance(obj, torch.Tensor): + return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8)) + else: + return obj diff --git a/shap_e/models/generation/transformer.py b/shap_e/models/generation/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..889caf09f8cc2a05e7083c765ac5c5cbd9624134 --- /dev/null +++ b/shap_e/models/generation/transformer.py @@ -0,0 +1,494 @@ +import math +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from shap_e.models.nn.checkpoint import checkpoint + +from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType +from .util import timestep_embedding + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class 
ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + ): + super().__init__() + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + init_scale = init_scale * math.sqrt(1.0 / width) + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + +class PointDiffusionTransformer(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + input_channels: int = 3, + output_channels: int = 3, + n_ctx: int = 1024, + width: int = 512, + layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + time_token_cond: bool = False, + use_pos_emb: bool = False, + pos_emb_init_scale: float = 1.0, + pos_emb_n_ctx: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + self.n_ctx = n_ctx + self.time_token_cond = time_token_cond + self.use_pos_emb = use_pos_emb + self.time_embed = MLP( + device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width) + ) + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.backbone = Transformer( + device=device, + dtype=dtype, + n_ctx=n_ctx + int(time_token_cond), + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype) + # with torch.no_grad(): + # self.output_proj.weight.zero_() + # self.output_proj.bias.zero_() + if self.use_pos_emb: + self.register_parameter( + "pos_emb", + nn.Parameter( + pos_emb_init_scale + * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype) + ), + ) + + def forward(self, x: torch.Tensor, t: torch.Tensor): + """ + :param x: an [N x C x T] tensor. + :param t: an [N] tensor. + :return: an [N x C' x T] tensor. 
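+
+        Shape sketch with the default constructor arguments (illustrative only;
+        `device` is an assumed torch.device):
+
+            model = PointDiffusionTransformer(device=device, dtype=torch.float32)
+            x = torch.randn(2, 3, 1024, device=device)      # N=2, C=input_channels=3, T=n_ctx=1024
+            t = torch.randint(0, 1000, (2,), device=device)
+            out = model(x, t)                               # -> [2, 3, 1024]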
+ """ + assert x.shape[-1] == self.n_ctx + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + return self._forward_with_cond(x, [(t_embed, self.time_token_cond)]) + + def _forward_with_cond( + self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]] + ) -> torch.Tensor: + h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC + for emb, as_token in cond_as_token: + if not as_token: + h = h + emb[:, None] + if self.use_pos_emb: + h = h + self.pos_emb + extra_tokens = [ + (emb[:, None] if len(emb.shape) == 2 else emb) + for emb, as_token in cond_as_token + if as_token + ] + if len(extra_tokens): + h = torch.cat(extra_tokens + [h], dim=1) + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + if len(extra_tokens): + h = h[:, sum(h.shape[1] for h in extra_tokens):] + h = self.output_proj(h) + return h.permute(0, 2, 1) # NCL -> NLC + + + + +class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int = 1024, + token_cond: bool = False, + cond_drop_prob: float = 0.0, + frozen_clip: bool = True, + **kwargs, + ): + super().__init__( + device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs + ) + # print("!!!!!", "deivce:", device, "dtype:", dtype, "n_ctx:", n_ctx, "token_cond:", token_cond, "cond_drop_prob:", cond_drop_prob, "frozen_clip:", frozen_clip, "kwargs:", kwargs) + self.n_ctx = n_ctx + self.token_cond = token_cond + self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device) + self.clip_embed = nn.Linear( + self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype + ) + self.cond_drop_prob = cond_drop_prob + + def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + return dict(embeddings=self.clip(batch_size, **model_kwargs)) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + images: Optional[Iterable[Optional[ImageType]]] = None, + texts: Optional[Iterable[Optional[str]]] = None, + embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None, + ): + """ + :param x: an [N x C x T] tensor. + :param t: an [N] tensor. + :param images: a batch of images to condition on. + :param texts: a batch of texts to condition on. + :param embeddings: a batch of CLIP embeddings to condition on. + :return: an [N x C' x T] tensor. 
+ """ + # print("x.shape", x.shape, "t.shape", t.shape, "images", images, "texts", texts, "embeddings", embeddings) + assert x.shape[-1] == self.n_ctx # self.n_ctx = 1024 + + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings) + assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0] + + if self.training: + mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob + clip_out = clip_out * mask[:, None].to(clip_out) + + # Rescale the features to have unit variance + clip_out = math.sqrt(clip_out.shape[1]) * clip_out + + clip_embed = self.clip_embed(clip_out) + + cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)] + return self._forward_with_cond(x, cond) + + +class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int = 1024, + cond_drop_prob: float = 0.0, + frozen_clip: bool = True, + **kwargs, + ): + clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device) + super().__init__( + device=device, + dtype=dtype, + n_ctx=n_ctx + clip.grid_size**2, + pos_emb_n_ctx=n_ctx, + **kwargs, + ) + self.n_ctx = n_ctx + self.clip = clip + self.clip_embed = nn.Sequential( + nn.LayerNorm( + normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype + ), + nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype), + ) + self.cond_drop_prob = cond_drop_prob + + def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + _ = batch_size + with torch.no_grad(): + return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"])) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + images: Optional[Iterable[ImageType]] = None, + embeddings: Optional[Iterable[torch.Tensor]] = None, + ): + """ + :param x: an [N x C x T] tensor. + :param t: an [N] tensor. + :param images: a batch of images to condition on. + :param embeddings: a batch of CLIP latent grids to condition on. + :return: an [N x C' x T] tensor. 
+ """ + assert images is not None or embeddings is not None, "must specify images or embeddings" + assert images is None or embeddings is None, "cannot specify both images and embeddings" + assert x.shape[-1] == self.n_ctx + + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + + if images is not None: + clip_out = self.clip.embed_images_grid(images) + else: + clip_out = embeddings + + if self.training: + mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob + clip_out = clip_out * mask[:, None, None].to(clip_out) + + clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC + clip_embed = self.clip_embed(clip_out) + + cond = [(t_embed, self.time_token_cond), (clip_embed, True)] + return self._forward_with_cond(x, cond) + + +class UpsamplePointDiffusionTransformer(PointDiffusionTransformer): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + cond_input_channels: Optional[int] = None, + cond_ctx: int = 1024, + n_ctx: int = 4096 - 1024, + channel_scales: Optional[Sequence[float]] = None, + channel_biases: Optional[Sequence[float]] = None, + **kwargs, + ): + super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs) + self.n_ctx = n_ctx + self.cond_input_channels = cond_input_channels or self.input_channels + self.cond_point_proj = nn.Linear( + self.cond_input_channels, self.backbone.width, device=device, dtype=dtype + ) + + self.register_buffer( + "channel_scales", + torch.tensor(channel_scales, dtype=dtype, device=device) + if channel_scales is not None + else None, + ) + self.register_buffer( + "channel_biases", + torch.tensor(channel_biases, dtype=dtype, device=device) + if channel_biases is not None + else None, + ) + + def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor): + """ + :param x: an [N x C1 x T] tensor. + :param t: an [N] tensor. + :param low_res: an [N x C2 x T'] tensor of conditioning points. + :return: an [N x C3 x T] tensor. 
+ """ + assert x.shape[-1] == self.n_ctx + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + low_res_embed = self._embed_low_res(low_res) + cond = [(t_embed, self.time_token_cond), (low_res_embed, True)] + return self._forward_with_cond(x, cond) + + def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor: + if self.channel_scales is not None: + x = x * self.channel_scales[None, :, None] + if self.channel_biases is not None: + x = x + self.channel_biases[None, :, None] + return self.cond_point_proj(x.permute(0, 2, 1)) + + +class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int = 4096 - 1024, + cond_drop_prob: float = 0.0, + frozen_clip: bool = True, + **kwargs, + ): + clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device) + super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs) + self.n_ctx = n_ctx + + self.clip = clip + self.clip_embed = nn.Sequential( + nn.LayerNorm( + normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype + ), + nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype), + ) + self.cond_drop_prob = cond_drop_prob + + def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]: + _ = batch_size + with torch.no_grad(): + return dict( + embeddings=self.clip.embed_images_grid(model_kwargs["images"]), + low_res=model_kwargs["low_res"], + ) + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + *, + low_res: torch.Tensor, + images: Optional[Iterable[ImageType]] = None, + embeddings: Optional[Iterable[torch.Tensor]] = None, + ): + """ + :param x: an [N x C1 x T] tensor. + :param t: an [N] tensor. + :param low_res: an [N x C2 x T'] tensor of conditioning points. + :param images: a batch of images to condition on. + :param embeddings: a batch of CLIP latent grids to condition on. + :return: an [N x C3 x T] tensor. + """ + assert x.shape[-1] == self.n_ctx + t_embed = self.time_embed(timestep_embedding(t, self.backbone.width)) + low_res_embed = self._embed_low_res(low_res) + + if images is not None: + clip_out = self.clip.embed_images_grid(images) + else: + clip_out = embeddings + + if self.training: + mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob + clip_out = clip_out * mask[:, None, None].to(clip_out) + + clip_out = clip_out.permute(0, 2, 1) # NCL -> NLC + clip_embed = self.clip_embed(clip_out) + + cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)] + return self._forward_with_cond(x, cond) + diff --git a/shap_e/models/generation/util.py b/shap_e/models/generation/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac30033a5c75b2c917160f661d48c5edad14871 --- /dev/null +++ b/shap_e/models/generation/util.py @@ -0,0 +1,23 @@ +import math + +import torch + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/shap_e/models/nerf/__init__.py b/shap_e/models/nerf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc b/shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19981c42c1140c5d72116fa2d559ecdff4caae65 Binary files /dev/null and b/shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/nerf/__pycache__/model.cpython-39.pyc b/shap_e/models/nerf/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b95c7e1d5ae435ddbda6a16fa76e7ea58f93c9ae Binary files /dev/null and b/shap_e/models/nerf/__pycache__/model.cpython-39.pyc differ diff --git a/shap_e/models/nerf/__pycache__/ray.cpython-39.pyc b/shap_e/models/nerf/__pycache__/ray.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0a7bae86307c1e034c43a7fe9c0c75c6217357b Binary files /dev/null and b/shap_e/models/nerf/__pycache__/ray.cpython-39.pyc differ diff --git a/shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc b/shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..647f774344c5f1ab7f09f482171e52acf8da7aaf Binary files /dev/null and b/shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc differ diff --git a/shap_e/models/nerf/model.py b/shap_e/models/nerf/model.py new file mode 100644 index 0000000000000000000000000000000000000000..372ee76da1a26119c91373e379ad1f6dc3c3cc67 --- /dev/null +++ b/shap_e/models/nerf/model.py @@ -0,0 +1,255 @@ +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from shap_e.models.nn.checkpoint import checkpoint +from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis +from shap_e.models.nn.meta import MetaModule, subdict +from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init +from shap_e.models.nn.utils import ArrayType +from shap_e.models.query import Query +from shap_e.util.collections import AttrDict + + +class NeRFModel(ABC): + """ + Parametric scene representation whose outputs are integrated by NeRFRenderer + """ + + @abstractmethod + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict: + """ + :param query: the points in the field to query. + :param params: Meta parameters + :param options: Optional hyperparameters + :return: An AttrDict containing at least + - density: [batch_size x ... x 1] + - channels: [batch_size x ... x n_channels] + - aux_losses: [batch_size x ... x 1] + """ + + +class VoidNeRFModel(MetaModule, NeRFModel): + """ + Implements the default empty space model where all queries are rendered as + background. 
+ """ + + def __init__( + self, + background: ArrayType, + trainable: bool = False, + channel_scale: float = 255.0, + device: torch.device = torch.device("cuda"), + ): + super().__init__() + background = nn.Parameter( + torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device) + / channel_scale + ) + if trainable: + self.register_parameter("background", background) + else: + self.register_buffer("background", background) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict: + _ = params + default_bg = self.background[None] + background = options.get("background", default_bg) if options is not None else default_bg + + shape = query.position.shape[:-1] + ones = [1] * (len(shape) - 1) + n_channels = background.shape[-1] + background = torch.broadcast_to( + background.view(background.shape[0], *ones, n_channels), [*shape, n_channels] + ) + return background + + +class MLPNeRFModel(MetaModule, NeRFModel): + def __init__( + self, + # Positional encoding parameters + n_levels: int = 10, + # MLP parameters + d_hidden: int = 256, + n_density_layers: int = 4, + n_channel_layers: int = 1, + n_channels: int = 3, + sh_degree: int = 4, + activation: str = "relu", + density_activation: str = "exp", + init: Optional[str] = None, + init_scale: float = 1.0, + output_activation: str = "sigmoid", + meta_parameters: bool = False, + trainable_meta: bool = False, + zero_out: bool = True, + register_freqs: bool = True, + posenc_version: str = "v1", + device: torch.device = torch.device("cuda"), + ): + super().__init__() + + # Positional encoding + if register_freqs: + # not used anymore + self.register_buffer( + "freqs", + 2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels), + ) + + self.posenc_version = posenc_version + dummy = torch.eye(1, 3) + d_input = encode_position(posenc_version, position=dummy).shape[-1] + + self.n_levels = n_levels + + self.sh_degree = sh_degree + d_sh_coeffs = sh_degree**2 + + self.meta_parameters = meta_parameters + + mlp_cls = ( + partial( + MetaMLP, + meta_scale=False, + meta_shift=False, + meta_proj=True, + meta_bias=True, + trainable_meta=trainable_meta, + ) + if meta_parameters + else MLP + ) + + self.density_mlp = mlp_cls( + d_input=d_input, + d_hidden=[d_hidden] * (n_density_layers - 1), + d_output=d_hidden, + act_name=activation, + init_scale=init_scale, + ) + + self.channel_mlp = mlp_cls( + d_input=d_hidden + d_sh_coeffs, + d_hidden=[d_hidden] * n_channel_layers, + d_output=n_channels, + act_name=activation, + init_scale=init_scale, + ) + + self.act = get_act(output_activation) + self.density_act = get_act(density_activation) + + mlp_init( + list(self.density_mlp.affines) + list(self.channel_mlp.affines), + init=init, + init_scale=init_scale, + ) + + if zero_out: + zero_init(self.channel_mlp.affines[-1]) + + self.to(device) + + def encode_position(self, query: Query): + h = encode_position(self.posenc_version, position=query.position) + return h + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict: + params = self.update(params) + + options = AttrDict() if options is None else AttrDict(options) + + query = query.copy() + + h_position = self.encode_position(query) + + if self.meta_parameters: + density_params = subdict(params, "density_mlp") + density_mlp = partial( + self.density_mlp, params=density_params, options=options, 
log_prefix="density_" + ) + density_mlp_parameters = list(density_params.values()) + else: + density_mlp = partial(self.density_mlp, options=options, log_prefix="density_") + density_mlp_parameters = self.density_mlp.parameters() + h_density = checkpoint( + density_mlp, + (h_position,), + density_mlp_parameters, + options.checkpoint_nerf_mlp, + ) + h_direction = maybe_get_spherical_harmonics_basis( + sh_degree=self.sh_degree, + coords_shape=query.position.shape, + coords=query.direction, + device=query.position.device, + ) + + if self.meta_parameters: + channel_params = subdict(params, "channel_mlp") + channel_mlp = partial( + self.channel_mlp, params=channel_params, options=options, log_prefix="channel_" + ) + channel_mlp_parameters = list(channel_params.values()) + else: + channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_") + channel_mlp_parameters = self.channel_mlp.parameters() + h_channel = checkpoint( + channel_mlp, + (torch.cat([h_density, h_direction], dim=-1),), + channel_mlp_parameters, + options.checkpoint_nerf_mlp, + ) + + density_logit = h_density[..., :1] + + res = AttrDict( + density_logit=density_logit, + density=self.density_act(density_logit), + channels=self.act(h_channel), + aux_losses=AttrDict(), + no_weight_grad_aux_losses=AttrDict(), + ) + if options.return_h_density: + res.h_density = h_density + + return res + + +def maybe_get_spherical_harmonics_basis( + sh_degree: int, + coords_shape: Tuple[int], + coords: Optional[torch.Tensor] = None, + device: torch.device = torch.device("cuda"), +) -> torch.Tensor: + """ + :param sh_degree: Spherical harmonics degree + :param coords_shape: [*shape, 3] + :param coords: optional coordinate tensor of coords_shape + """ + if coords is None: + return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device) + + return spherical_harmonics_basis(coords, sh_degree) diff --git a/shap_e/models/nerf/ray.py b/shap_e/models/nerf/ray.py new file mode 100644 index 0000000000000000000000000000000000000000..7346adb3948e81f6ff95f0dfb184ea87a4aee407 --- /dev/null +++ b/shap_e/models/nerf/ray.py @@ -0,0 +1,512 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from shap_e.models.nn.utils import sample_pmf +from shap_e.models.volume import Volume, VolumeRange +from shap_e.util.collections import AttrDict + +from .model import NeRFModel, Query + + +def render_rays( + rays: torch.Tensor, + parts: List["RayVolumeIntegral"], + void_model: NeRFModel, + shared: bool = False, + prev_raw_outputs: Optional[List[AttrDict]] = None, + render_with_direction: bool = True, + importance_sampling_options: Optional[Dict[str, Any]] = None, +) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]: + """ + Perform volumetric rendering over a partition of possible t's in the union + of rendering volumes (written below with some abuse of notations) + + C(r) := sum( + transmittance(t[i]) * + integrate( + lambda t: density(t) * channels(t) * transmittance(t), + [t[i], t[i + 1]], + ) + for i in range(len(parts)) + ) + transmittance(t[-1]) * void_model(t[-1]).channels + + where + + 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the + probability of light passing through the volume specified by [t[0], s]. + (transmittance of 1 means light can pass freely) + 2) density and channels are obtained by evaluating the appropriate + part.model at time t. 
+ 3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects + (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface + of the shell (if bounded). If the ray does not intersect, the integral over + this segment is evaluated as 0 and transmittance(t[i + 1]) := + transmittance(t[i]). + 4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that + is evaluated by the void_model (i.e. we consider this space to be empty). + + :param rays: [batch_size x ... x 2 x 3] origin and direction. + :param parts: disjoint volume integrals. + :param void_model: use this model to integrate over the empty space + :param shared: All RayVolumeIntegrals are calculated with the same model. + :param prev_raw_outputs: Raw outputs from the previous rendering step + + :return: A tuple of + - AttrDict containing the rendered `channels`, `distances`, and the `aux_losses` + - A list of importance samplers for additional fine-grained rendering + - A list of raw output for each interval + """ + if importance_sampling_options is None: + importance_sampling_options = {} + + origin, direc = rays[..., 0, :], rays[..., 1, :] + + if prev_raw_outputs is None: + prev_raw_outputs = [None] * len(parts) + + samplers = [] + raw_outputs = [] + t0 = None + results = None + # import pdb; pdb.set_trace() + for part_i, prev_raw_i in zip(parts, prev_raw_outputs): + + # Integrate over [t[i], t[i + 1]] + results_i = part_i.render_rays( + origin, + direc, + t0=t0, + prev_raw=prev_raw_i, + shared=shared, + render_with_direction=render_with_direction, + ) + + # Create an importance sampler for (optional) fine rendering + samplers.append( + ImportanceRaySampler( + results_i.volume_range, results_i.raw, **importance_sampling_options + ) + ) + raw_outputs.append(results_i.raw) + + # Pass t[i + 1] as the start of integration for the next interval. + t0 = results_i.volume_range.next_t0() + + # Combine the results from [t[0], t[i]] and [t[i], t[i+1]] + results = results_i if results is None else results.combine(results_i) + + # While integrating out [t[-1], math.inf] is the correct thing to do, this + # erases a lot of useful information. Also, void_model is meant to predict + # the channels at t=math.inf. + + # # Add the void background over [t[-1], math.inf] to complete integration. + # results = results.combine( + # RayVolumeIntegralResults( + # output=AttrDict( + # channels=void_model(origin, direc), + # distances=torch.zeros_like(t0), + # aux_losses=AttrDict(), + # ), + # volume_range=VolumeRange( + # t0=t0, + # t1=torch.full_like(t0, math.inf), + # intersected=torch.full_like(results.volume_range.intersected, True), + # ), + # # Void space extends to infinity. It is assumed that no light + # # passes beyond the void. + # transmittance=torch.zeros_like(results_i.transmittance), + # ) + # ) + results.output.channels = results.output.channels + results.transmittance * void_model( + Query(origin, direc) + ) + + return results, samplers, raw_outputs + + +@dataclass +class RayVolumeIntegralResults: + """ + Stores the relevant state and results of + + integrate( + lambda t: density(t) * channels(t) * transmittance(t), + [t0, t1], + ) + """ + + # Rendered output and auxiliary losses + # output.channels has shape [batch_size, *inner_shape, n_channels] + output: AttrDict + + """ + Optional values + """ + + # Raw values contain the sampled `ts`, `density`, `channels`, etc. 
+ raw: Optional[AttrDict] = None + + # Integration + volume_range: Optional[VolumeRange] = None + + # If a ray intersects, the transmittance from t0 to t1 (e.g. the + # probability that the ray passes through this volume). + # has shape [batch_size, *inner_shape, 1] + transmittance: Optional[torch.Tensor] = None + + def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults": + """ + Combines the integration results of `self` over [t0, t1] and + `cur` over [t1, t2] to produce a new set of results over [t0, t2] by + using a similar equation to (4) in NeRF++: + + integrate( + lambda t: density(t) * channels(t) * transmittance(t), + [t0, t2] + ) + + = integrate( + lambda t: density(t) * channels(t) * transmittance(t), + [t0, t1] + ) + transmittance(t1) * integrate( + lambda t: density(t) * channels(t) * transmittance(t), + [t1, t2] + ) + """ + assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0) + + def _combine_fn( + prev_val: Optional[torch.Tensor], + cur_val: Optional[torch.Tensor], + *, + prev_transmittance: torch.Tensor, + ): + assert prev_val is not None + if cur_val is None: + # cur_output.aux_losses are empty for the void_model. + return prev_val + return prev_val + prev_transmittance * cur_val + + output = self.output.combine( + cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance) + ) + + combined = RayVolumeIntegralResults( + output=output, + volume_range=self.volume_range.extend(cur.volume_range), + transmittance=self.transmittance * cur.transmittance, + ) + return combined + + +@dataclass +class RayVolumeIntegral: + model: NeRFModel + volume: Volume + sampler: "RaySampler" + n_samples: int + + def render_rays( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0: Optional[torch.Tensor] = None, + prev_raw: Optional[AttrDict] = None, + shared: bool = False, + render_with_direction: bool = True, + ) -> "RayVolumeIntegralResults": + """ + Perform volumetric rendering over the given volume. + + :param position: [batch_size, *shape, 3] + :param direction: [batch_size, *shape, 3] + :param t0: Optional [batch_size, *shape, 1] + :param prev_raw: the raw outputs when using multiple levels with this model. + :param shared: means the same model is used for all RayVolumeIntegral's + :param render_with_direction: use the incoming ray direction when querying the model. + + :return: RayVolumeIntegralResults + """ + # 1. Intersect the rays with the current volume and sample ts to + # integrate along. + vrange = self.volume.intersect(origin, direction, t0_lower=t0) + ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples) + + if prev_raw is not None and not shared: + # Append the previous ts now before fprop because previous + # rendering used a different model and we can't reuse the output. + ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values + + # Shape sanity checks + batch_size, *_shape, _t0_dim = vrange.t0.shape + _, *ts_shape, _ts_dim = ts.shape + + # 2. 
Get the points along the ray and query the model + directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3]) + positions = origin.unsqueeze(-2) + ts * directions + + optional_directions = directions if render_with_direction else None + mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2 + raw = self.model( + Query( + position=positions, + direction=optional_directions, + t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2), + t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2), + ) + ) + raw.ts = ts + + if prev_raw is not None and shared: + # We can append the additional queries to previous raw outputs + # before integration + copy = prev_raw.copy() + result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2) + merge_results = partial(self._merge_results, dim=-2, indices=result.indices) + raw = raw.combine(copy, merge_results) + raw.ts = result.values + + # 3. Integrate the raw results + output, transmittance = self.integrate_samples(vrange, raw) + + # 4. Clean up results that do not intersect with the volume. + transmittance = torch.where( + vrange.intersected, transmittance, torch.ones_like(transmittance) + ) + + def _mask_fn(_key: str, tensor: torch.Tensor): + return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor)) + + def _is_tensor(_key: str, value: Any): + return isinstance(value, torch.Tensor) + + output = output.map(map_fn=_mask_fn, should_map=_is_tensor) + + return RayVolumeIntegralResults( + output=output, + raw=raw, + volume_range=vrange, + transmittance=transmittance, + ) + + def integrate_samples( + self, + volume_range: VolumeRange, + raw: AttrDict, + ) -> Tuple[AttrDict, torch.Tensor]: + """ + Integrate the raw.channels along with other aux_losses and values to + produce the final output dictionary containing rendered `channels`, + estimated `distances` and `aux_losses`. + + :param volume_range: Specifies the integral range [t0, t1] + :param raw: Contains a dict of function evaluations at ts. Should have + + density: torch.Tensor [batch_size, *shape, n_samples, 1] + channels: torch.Tensor [batch_size, *shape, n_samples, n_channels] + aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key} + no_weight_grad_aux_losses: an optional set of losses for which the weights + should be detached before integration. + + after the call, integrate_samples populates some intermediate calculations + for later use like + + weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density * + transmittance)[i] weight for each rgb output at [..., i, :]. + :returns: a tuple of ( + a dictionary of rendered outputs and aux_losses, + transmittance of this volume, + ) + """ + + # 1. Calculate the weights + _, _, dt = volume_range.partition(raw.ts) + ddensity = raw.density * dt + + mass = torch.cumsum(ddensity, dim=-2) + transmittance = torch.exp(-mass[..., -1, :]) + + alphas = 1.0 - torch.exp(-ddensity) + Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2)) + # This is the probability of light hitting and reflecting off of + # something at depth [..., i, :]. + weights = alphas * Ts + + # 2. 
Integrate all results + def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor): + if key == "density": + # Omit integrating the density, because we don't need it + return None + return torch.sum(samples * weights, dim=-2) + + def _is_tensor(_key: str, value: Any): + return isinstance(value, torch.Tensor) + + if raw.no_weight_grad_aux_losses: + extra_aux_losses = raw.no_weight_grad_aux_losses.map( + partial(_integrate, weights=weights.detach()), should_map=_is_tensor + ) + else: + extra_aux_losses = {} + output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor) + if "no_weight_grad_aux_losses" in output: + del output["no_weight_grad_aux_losses"] + output.aux_losses.update(extra_aux_losses) + + # Integrating the ts yields the distance away from the origin; rename the variable. + output.distances = output.ts + del output["ts"] + del output["density"] + + assert output.distances.shape == (*output.channels.shape[:-1], 1) + assert output.channels.shape[:-1] == raw.channels.shape[:-2] + assert output.channels.shape[-1] == raw.channels.shape[-1] + + # 3. Reduce loss + def _reduce_loss(_key: str, loss: torch.Tensor): + return loss.view(loss.shape[0], -1).sum(dim=-1) + + # 4. Store other useful calculations + raw.weights = weights + + output.aux_losses = output.aux_losses.map(_reduce_loss) + + return output, transmittance + + def _merge_results( + self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor + ): + """ + :param a: [..., n_a, ...]. The other dictionary containing the b's may + contain extra tensors from earlier calculations, so a can be None. + :param b: [..., n_b, ...] + :param dim: dimension to merge + :param indices: how the merged results should be sorted at the end + :return: a concatted and sorted tensor of size [..., n_a + n_b, ...] + """ + if a is None: + return None + + merged = torch.cat([a, b], dim=dim) + return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape)) + + +class RaySampler(ABC): + @abstractmethod + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + :param t0: start time has shape [batch_size, *shape, 1] + :param t1: finish time has shape [batch_size, *shape, 1] + :param n_samples: number of ts to sample + :return: sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + + +class StratifiedRaySampler(RaySampler): + """ + Instead of fixed intervals, a sample is drawn uniformly at random from each + interval. + """ + + def __init__(self, depth_mode: str = "linear"): + """ + :param depth_mode: linear samples ts linearly in depth. harmonic ensures + closer points are sampled more densely. 
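+
+            Roughly, for a fraction s in [0, 1] (ignoring the per-bin jitter
+            and the epsilon clamping used below):
+
+                linear:    t = t0 * (1 - s) + t1 * s
+                geometric: t = exp(log(t0) * (1 - s) + log(t1) * s)
+                harmonic:  t = 1 / ((1 - s) / t0 + s / t1)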
+ """ + self.depth_mode = depth_mode + assert self.depth_mode in ("linear", "geometric", "harmonic") + + def sample( + self, + t0: torch.Tensor, + t1: torch.Tensor, + n_samples: int, + epsilon: float = 1e-3, + ) -> torch.Tensor: + """ + :param t0: start time has shape [batch_size, *shape, 1] + :param t1: finish time has shape [batch_size, *shape, 1] + :param n_samples: number of ts to sample + :return: sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + ones = [1] * (len(t0.shape) - 1) + ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device) + + if self.depth_mode == "linear": + ts = t0 * (1.0 - ts) + t1 * ts + elif self.depth_mode == "geometric": + ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp() + elif self.depth_mode == "harmonic": + # The original NeRF recommends this interpolation scheme for + # spherical scenes, but there could be some weird edge cases when + # the observer crosses from the inner to outer volume. + ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts) + + mids = 0.5 * (ts[..., 1:] + ts[..., :-1]) + upper = torch.cat([mids, t1], dim=-1) + lower = torch.cat([t0, mids], dim=-1) + t_rand = torch.rand_like(ts) + + ts = lower + (upper - lower) * t_rand + return ts.unsqueeze(-1) + + +class ImportanceRaySampler(RaySampler): + """ + Given the initial estimate of densities, this samples more from + regions/bins expected to have objects. + """ + + def __init__( + self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5 + ): + """ + :param volume_range: the range in which a ray intersects the given volume. + :param raw: dictionary of raw outputs from the NeRF models of shape + [batch_size, *shape, n_coarse_samples, 1]. Should at least contain + + :param ts: earlier samples from the coarse rendering step + :param weights: discretized version of density * transmittance + :param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF. + :param alpha: small value to add to weights. 
+ """ + self.volume_range = volume_range + self.ts = raw.ts.clone().detach() + self.weights = raw.weights.clone().detach() + self.blur_pool = blur_pool + self.alpha = alpha + + @torch.no_grad() + def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + :param t0: start time has shape [batch_size, *shape, 1] + :param t1: finish time has shape [batch_size, *shape, 1] + :param n_samples: number of ts to sample + :return: sampled ts of shape [batch_size, *shape, n_samples, 1] + """ + lower, upper, _ = self.volume_range.partition(self.ts) + + batch_size, *shape, n_coarse_samples, _ = self.ts.shape + + weights = self.weights + if self.blur_pool: + padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2) + maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :]) + weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :]) + weights = weights + self.alpha + pmf = weights / weights.sum(dim=-2, keepdim=True) + inds = sample_pmf(pmf, n_samples) + assert inds.shape == (batch_size, *shape, n_samples, 1) + assert (inds >= 0).all() and (inds < n_coarse_samples).all() + + t_rand = torch.rand(inds.shape, device=inds.device) + lower_ = torch.gather(lower, -2, inds) + upper_ = torch.gather(upper, -2, inds) + + ts = lower_ + (upper_ - lower_) * t_rand + ts = torch.sort(ts, dim=-2).values + return ts diff --git a/shap_e/models/nerf/renderer.py b/shap_e/models/nerf/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..c356512bc43ebe0ba0a8a78a8d32a989d7a13056 --- /dev/null +++ b/shap_e/models/nerf/renderer.py @@ -0,0 +1,301 @@ +from functools import partial +from typing import Any, Dict, Optional + +import torch + +from shap_e.models.nn.meta import subdict +from shap_e.models.renderer import RayRenderer +from shap_e.models.volume import Volume +from shap_e.util.collections import AttrDict + +from .model import NeRFModel +from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays + + +class TwoStepNeRFRenderer(RayRenderer): + """ + Coarse and fine-grained rendering as proposed by NeRF. This class + additionally supports background rendering like NeRF++. + """ + + def __init__( + self, + n_coarse_samples: int, + n_fine_samples: int, + void_model: NeRFModel, + fine_model: NeRFModel, + volume: Volume, + coarse_model: Optional[NeRFModel] = None, + coarse_background_model: Optional[NeRFModel] = None, + fine_background_model: Optional[NeRFModel] = None, + outer_volume: Optional[Volume] = None, + foreground_stratified_depth_sampling_mode: str = "linear", + background_stratified_depth_sampling_mode: str = "linear", + importance_sampling_options: Optional[Dict[str, Any]] = None, + channel_scale: float = 255, + device: torch.device = torch.device("cuda"), + **kwargs, + ): + """ + :param outer_volume: is where distant objects are encoded. 
+ """ + super().__init__(**kwargs) + + if coarse_model is None: + assert ( + fine_background_model is None or coarse_background_model is None + ), "models should be shared for both fg and bg" + + self.n_coarse_samples = n_coarse_samples + self.n_fine_samples = n_fine_samples + self.void_model = void_model + self.coarse_model = coarse_model + self.fine_model = fine_model + self.volume = volume + self.coarse_background_model = coarse_background_model + self.fine_background_model = fine_background_model + self.outer_volume = outer_volume + self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode + self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode + self.importance_sampling_options = AttrDict(importance_sampling_options or {}) + self.channel_scale = channel_scale + self.device = device + self.to(device) + + if self.coarse_background_model is not None: + assert self.fine_background_model is not None + assert self.outer_volume is not None + + def render_rays( + self, + batch: Dict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + params = self.update(params) + + batch = AttrDict(batch) + if options is None: + options = AttrDict() + options.setdefault("render_background", True) + options.setdefault("render_with_direction", True) + options.setdefault("n_coarse_samples", self.n_coarse_samples) + options.setdefault("n_fine_samples", self.n_fine_samples) + options.setdefault( + "foreground_stratified_depth_sampling_mode", + self.foreground_stratified_depth_sampling_mode, + ) + options.setdefault( + "background_stratified_depth_sampling_mode", + self.background_stratified_depth_sampling_mode, + ) + + shared = self.coarse_model is None + + # First, render rays using the coarse models with stratified ray samples. + coarse_model, coarse_key = ( + (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model") + ) + coarse_model = partial( + coarse_model, + params=subdict(params, coarse_key), + options=options, + ) + parts = [ + RayVolumeIntegral( + model=coarse_model, + volume=self.volume, + sampler=StratifiedRaySampler( + depth_mode=options.foreground_stratified_depth_sampling_mode, + ), + n_samples=options.n_coarse_samples, + ), + ] + if options.render_background and self.outer_volume is not None: + coarse_background_model, coarse_background_key = ( + (self.fine_background_model, "fine_background_model") + if shared + else (self.coarse_background_model, "coarse_background_model") + ) + coarse_background_model = partial( + coarse_background_model, + params=subdict(params, coarse_background_key), + options=options, + ) + parts.append( + RayVolumeIntegral( + model=coarse_background_model, + volume=self.outer_volume, + sampler=StratifiedRaySampler( + depth_mode=options.background_stratified_depth_sampling_mode, + ), + n_samples=options.n_coarse_samples, + ) + ) + coarse_results, samplers, coarse_raw_outputs = render_rays( + batch.rays, + parts, + partial(self.void_model, options=options), + shared=shared, + render_with_direction=options.render_with_direction, + importance_sampling_options=AttrDict(self.importance_sampling_options), + ) + + # Then, render rays using the fine models with importance-weighted ray samples. 
+ fine_model = partial( + self.fine_model, + params=subdict(params, "fine_model"), + options=options, + ) + parts = [ + RayVolumeIntegral( + model=fine_model, + volume=self.volume, + sampler=samplers[0], + n_samples=options.n_fine_samples, + ), + ] + if options.render_background and self.outer_volume is not None: + fine_background_model = partial( + self.fine_background_model, + params=subdict(params, "fine_background_model"), + options=options, + ) + parts.append( + RayVolumeIntegral( + model=fine_background_model, + volume=self.outer_volume, + sampler=samplers[1], + n_samples=options.n_fine_samples, + ) + ) + fine_results, *_ = render_rays( + batch.rays, + parts, + partial(self.void_model, options=options), + shared=shared, + prev_raw_outputs=coarse_raw_outputs, + render_with_direction=options.render_with_direction, + ) + + # Combine results + aux_losses = fine_results.output.aux_losses.copy() + for key, val in coarse_results.output.aux_losses.items(): + aux_losses[key + "_coarse"] = val + + return AttrDict( + channels=fine_results.output.channels * self.channel_scale, + channels_coarse=coarse_results.output.channels * self.channel_scale, + distances=fine_results.output.distances, + transmittance=fine_results.transmittance, + transmittance_coarse=coarse_results.transmittance, + t0=fine_results.volume_range.t0, + t1=fine_results.volume_range.t1, + intersected=fine_results.volume_range.intersected, + aux_losses=aux_losses, + ) + + +class OneStepNeRFRenderer(RayRenderer): + """ + Renders rays using stratified sampling only unlike vanilla NeRF. + The same setup as NeRF++. + """ + + def __init__( + self, + n_samples: int, + void_model: NeRFModel, + foreground_model: NeRFModel, + volume: Volume, + background_model: Optional[NeRFModel] = None, + outer_volume: Optional[Volume] = None, + foreground_stratified_depth_sampling_mode: str = "linear", + background_stratified_depth_sampling_mode: str = "linear", + channel_scale: float = 255, + device: torch.device = torch.device("cuda"), + **kwargs, + ): + super().__init__(**kwargs) + self.n_samples = n_samples + self.void_model = void_model + self.foreground_model = foreground_model + self.volume = volume + self.background_model = background_model + self.outer_volume = outer_volume + self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode + self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode + self.channel_scale = channel_scale + self.device = device + self.to(device) + + def render_rays( + self, + batch: Dict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + params = self.update(params) + + batch = AttrDict(batch) + if options is None: + options = AttrDict() + options.setdefault("render_background", True) + options.setdefault("render_with_direction", True) + options.setdefault("n_samples", self.n_samples) + options.setdefault( + "foreground_stratified_depth_sampling_mode", + self.foreground_stratified_depth_sampling_mode, + ) + options.setdefault( + "background_stratified_depth_sampling_mode", + self.background_stratified_depth_sampling_mode, + ) + + foreground_model = partial( + self.foreground_model, + params=subdict(params, "foreground_model"), + options=options, + ) + parts = [ + RayVolumeIntegral( + model=foreground_model, + volume=self.volume, + sampler=StratifiedRaySampler( + depth_mode=options.foreground_stratified_depth_sampling_mode + ), + n_samples=options.n_samples, + ), + ] + if options.render_background and self.outer_volume is 
not None: + background_model = partial( + self.background_model, + params=subdict(params, "background_model"), + options=options, + ) + parts.append( + RayVolumeIntegral( + model=background_model, + volume=self.outer_volume, + sampler=StratifiedRaySampler( + depth_mode=options.background_stratified_depth_sampling_mode + ), + n_samples=options.n_samples, + ) + ) + results, *_ = render_rays( + batch.rays, + parts, + self.void_model, + render_with_direction=options.render_with_direction, + ) + + return AttrDict( + channels=results.output.channels * self.channel_scale, + distances=results.output.distances, + transmittance=results.transmittance, + t0=results.volume_range.t0, + t1=results.volume_range.t1, + intersected=results.volume_range.intersected, + aux_losses=results.output.aux_losses, + ) diff --git a/shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc b/shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b74986339d797be9661b816ead87bd90b28d27 Binary files /dev/null and b/shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc differ diff --git a/shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc b/shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..076edc7c2eb7cd389b0b24996c0ac345052d9680 Binary files /dev/null and b/shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc differ diff --git a/shap_e/models/nerstf/mlp.py b/shap_e/models/nerstf/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..47d9e4fb99c8aade08763555431677b191fb3528 --- /dev/null +++ b/shap_e/models/nerstf/mlp.py @@ -0,0 +1,174 @@ +from typing import Any, Dict, Optional, Tuple + +import torch + +from shap_e.models.nn.ops import get_act +from shap_e.models.query import Query +from shap_e.models.stf.mlp import MLPModel +from shap_e.util.collections import AttrDict + + +class MLPDensitySDFModel(MLPModel): + def __init__( + self, + initial_bias: float = -0.1, + sdf_activation="tanh", + density_activation="exp", + **kwargs, + ): + super().__init__( + n_output=2, + output_activation="identity", + **kwargs, + ) + self.mlp[-1].bias[0].data.fill_(initial_bias) + self.sdf_activation = get_act(sdf_activation) + self.density_activation = get_act(density_activation) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + # query.direction is None typically for SDF models and training + h, _h_directionless = self._mlp( + query.position, query.direction, params=params, options=options + ) + h_sdf, h_density = h.split(1, dim=-1) + return AttrDict( + density=self.density_activation(h_density), + signed_distance=self.sdf_activation(h_sdf), + ) + + +class MLPNeRSTFModel(MLPModel): + def __init__( + self, + sdf_activation="tanh", + density_activation="exp", + channel_activation="sigmoid", + direction_dependent_shape: bool = True, # To be able to load old models. Set this to be False in future models. 
+ separate_nerf_channels: bool = False, + separate_coarse_channels: bool = False, + initial_density_bias: float = 0.0, + initial_sdf_bias: float = -0.1, + **kwargs, + ): + h_map, h_directionless_map = indices_for_output_mode( + direction_dependent_shape=direction_dependent_shape, + separate_nerf_channels=separate_nerf_channels, + separate_coarse_channels=separate_coarse_channels, + ) + n_output = index_mapping_max(h_map) + super().__init__( + n_output=n_output, + output_activation="identity", + **kwargs, + ) + self.direction_dependent_shape = direction_dependent_shape + self.separate_nerf_channels = separate_nerf_channels + self.separate_coarse_channels = separate_coarse_channels + self.sdf_activation = get_act(sdf_activation) + self.density_activation = get_act(density_activation) + self.channel_activation = get_act(channel_activation) + self.h_map = h_map + self.h_directionless_map = h_directionless_map + self.mlp[-1].bias.data.zero_() + layer = -1 if self.direction_dependent_shape else self.insert_direction_at + self.mlp[layer].bias[0].data.fill_(initial_sdf_bias) + self.mlp[layer].bias[1].data.fill_(initial_density_bias) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + + options = AttrDict() if options is None else AttrDict(options) + h, h_directionless = self._mlp( + query.position, query.direction, params=params, options=options + ) + activations = map_indices_to_keys(self.h_map, h) + activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless)) + + if options.nerf_level == "coarse": + h_density = activations.density_coarse + else: + h_density = activations.density_fine + + if options.get("rendering_mode", "stf") == "nerf": + if options.nerf_level == "coarse": + h_channels = activations.nerf_coarse + else: + h_channels = activations.nerf_fine + else: + h_channels = activations.stf + return AttrDict( + density=self.density_activation(h_density), + signed_distance=self.sdf_activation(activations.sdf), + channels=self.channel_activation(h_channels), + ) + + +IndexMapping = AttrDict[str, Tuple[int, int]] + + +def indices_for_output_mode( + direction_dependent_shape: bool, + separate_nerf_channels: bool, + separate_coarse_channels: bool, +) -> Tuple[IndexMapping, IndexMapping]: + """ + Get output mappings for (h, h_directionless). 
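+
+    For example (read off the branches below), with direction_dependent_shape=True
+    and both separate_* flags False:
+
+        h_map.sdf == (0, 1)
+        h_map.density_coarse == h_map.density_fine == (1, 2)
+        h_map.stf == h_map.nerf_coarse == h_map.nerf_fine == (2, 5)
+        # h_directionless_map stays empty; index_mapping_max(h_map) == 5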
+ """ + h_map = AttrDict() + h_directionless_map = AttrDict() + if direction_dependent_shape: + h_map.sdf = (0, 1) + if separate_coarse_channels: + assert separate_nerf_channels + h_map.density_coarse = (1, 2) + h_map.density_fine = (2, 3) + h_map.stf = (3, 6) + h_map.nerf_coarse = (6, 9) + h_map.nerf_fine = (9, 12) + else: + h_map.density_coarse = (1, 2) + h_map.density_fine = (1, 2) + if separate_nerf_channels: + h_map.stf = (2, 5) + h_map.nerf_coarse = (5, 8) + h_map.nerf_fine = (5, 8) + else: + h_map.stf = (2, 5) + h_map.nerf_coarse = (2, 5) + h_map.nerf_fine = (2, 5) + else: + h_directionless_map.sdf = (0, 1) + h_directionless_map.density_coarse = (1, 2) + if separate_coarse_channels: + h_directionless_map.density_fine = (2, 3) + else: + h_directionless_map.density_fine = h_directionless_map.density_coarse + h_map.stf = (0, 3) + if separate_coarse_channels: + assert separate_nerf_channels + h_map.nerf_coarse = (3, 6) + h_map.nerf_fine = (6, 9) + else: + if separate_nerf_channels: + h_map.nerf_coarse = (3, 6) + else: + h_map.nerf_coarse = (0, 3) + h_map.nerf_fine = h_map.nerf_coarse + return h_map, h_directionless_map + + +def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]: + return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()}) + + +def index_mapping_max(mapping: IndexMapping) -> int: + return max(end for _, (_, end) in mapping.items()) diff --git a/shap_e/models/nerstf/renderer.py b/shap_e/models/nerstf/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..e6742c94ad7b0ec1c528491a7b60bc381e023a72 --- /dev/null +++ b/shap_e/models/nerstf/renderer.py @@ -0,0 +1,293 @@ +from functools import partial +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch + +from shap_e.models.nerf.model import NeRFModel +from shap_e.models.nerf.ray import RayVolumeIntegral, StratifiedRaySampler, render_rays +from shap_e.models.nn.meta import subdict +from shap_e.models.nn.utils import to_torch +from shap_e.models.query import Query +from shap_e.models.renderer import RayRenderer, render_views_from_rays +from shap_e.models.stf.base import Model +from shap_e.models.stf.renderer import STFRendererBase, render_views_from_stf +from shap_e.models.volume import BoundingBoxVolume, Volume +from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR +from shap_e.util.collections import AttrDict + + +class NeRSTFRenderer(RayRenderer, STFRendererBase): + def __init__( + self, + sdf: Optional[Model], + tf: Optional[Model], + nerstf: Optional[Model], + void: NeRFModel, + volume: Volume, + grid_size: int, + n_coarse_samples: int, + n_fine_samples: int, + importance_sampling_options: Optional[Dict[str, Any]] = None, + separate_shared_samples: bool = False, + texture_channels: Sequence[str] = ("R", "G", "B"), + channel_scale: Sequence[float] = (255.0, 255.0, 255.0), + ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR, + diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR, + specular_color: Union[float, Tuple[float]] = 0.0, + output_srgb: bool = True, + device: torch.device = torch.device("cuda"), + **kwargs, + ): + super().__init__(**kwargs) + assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume" + assert (nerstf is not None) ^ (sdf is not None and tf is not None) + self.sdf = sdf + self.tf = tf + self.nerstf = nerstf + self.void = void + self.volume = volume + self.grid_size = grid_size + self.n_coarse_samples 
= n_coarse_samples + self.n_fine_samples = n_fine_samples + self.importance_sampling_options = AttrDict(importance_sampling_options or {}) + self.separate_shared_samples = separate_shared_samples + self.texture_channels = texture_channels + self.channel_scale = to_torch(channel_scale).to(device) + self.ambient_color = ambient_color + self.diffuse_color = diffuse_color + self.specular_color = specular_color + self.output_srgb = output_srgb + self.device = device + self.patch_size=128 + self.use_patch=False + self.to(device) + + def _query( + self, + query: Query, + params: AttrDict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> AttrDict: + no_dir_query = query.copy() + no_dir_query.direction = None + + if options.get("rendering_mode", "stf") == "stf": + assert query.direction is None + + if self.nerstf is not None: + sdf = tf = self.nerstf( + query, + params=subdict(params, "nerstf"), + options=options, + ) + else: + sdf = self.sdf(no_dir_query, params=subdict(params, "sdf"), options=options) + tf = self.tf(query, params=subdict(params, "tf"), options=options) + + return AttrDict( + density=sdf.density, + signed_distance=sdf.signed_distance, + channels=tf.channels, + aux_losses=dict(), + ) + + def render_rays( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[AttrDict] = None, + ) -> AttrDict: + """ + :param batch: has + + - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. + :param options: Optional[Dict] + """ + params = self.update(params) + options = AttrDict() if options is None else AttrDict(options) + + # Necessary to tell the TF to use specific NeRF channels. + options.rendering_mode = "nerf" + + model = partial(self._query, params=params, options=options) + + # First, render rays with coarse, stratified samples. + options.nerf_level = "coarse" + parts = [ + RayVolumeIntegral( + model=model, + volume=self.volume, + sampler=StratifiedRaySampler(), + n_samples=self.n_coarse_samples, + ), + ] + coarse_results, samplers, coarse_raw_outputs = render_rays( + batch.rays, + parts, + self.void, + shared=not self.separate_shared_samples, + render_with_direction=options.render_with_direction, + importance_sampling_options=self.importance_sampling_options, + ) + + # Then, render with additional importance-weighted ray samples. 
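        # (Illustrative note, not part of the patch: the sampler handed back by the coarse
        # pass, reused below as `sampler=samplers[0]`, importance-samples along each ray
        # according to the coarse-pass weights, so the fine pass concentrates its samples
        # where the coarse pass found most of the ray's mass; this is the usual NeRF-style
        # coarse-to-fine hierarchical sampling.)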
+ options.nerf_level = "fine" + parts = [ + RayVolumeIntegral( + model=model, + volume=self.volume, + sampler=samplers[0], + n_samples=self.n_fine_samples, + ), + ] + fine_results, _, raw_outputs = render_rays( + batch.rays, + parts, + self.void, + shared=not self.separate_shared_samples, + prev_raw_outputs=coarse_raw_outputs, + render_with_direction=options.render_with_direction, + ) + raw = raw_outputs[0] + + aux_losses = fine_results.output.aux_losses.copy() + if self.separate_shared_samples: + for key, val in coarse_results.output.aux_losses.items(): + aux_losses[key + "_coarse"] = val + + channels = fine_results.output.channels + shape = [1] * (channels.ndim - 1) + [len(self.texture_channels)] + channels = channels * self.channel_scale.view(*shape) + + res = AttrDict( + channels=channels, + transmittance=fine_results.transmittance, + raw_signed_distance=raw.signed_distance, + raw_density=raw.density, + distances=fine_results.output.distances, + t0=fine_results.volume_range.t0, + t1=fine_results.volume_range.t1, + intersected=fine_results.volume_range.intersected, + aux_losses=aux_losses, + ) + + if self.separate_shared_samples: + res.update( + dict( + channels_coarse=( + coarse_results.output.channels * self.channel_scale.view(*shape) + ), + distances_coarse=coarse_results.output.distances, + transmittance_coarse=coarse_results.transmittance, + ) + ) + + return res + + def render_views( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[AttrDict] = None, + ) -> AttrDict: + """ + Returns a backproppable rendering of a view + + :param batch: contains either ["poses", "camera"], or ["cameras"]. Can + optionally contain any of ["height", "width", "query_batch_size"] + + :param params: Meta parameters + contains rendering_mode in ["stf", "nerf"] + :param options: controls checkpointing, caching, and rendering. 
+ Can provide a `rendering_mode` in ["stf", "nerf"] + """ + params = self.update(params) + options = AttrDict() if options is None else AttrDict(options) + + if options.cache is None: + created_cache = True + options.cache = AttrDict() + else: + created_cache = False + + rendering_mode = options.get("rendering_mode", "stf") + # import pdb; pdb.set_trace() + if rendering_mode == "nerf": + + output = render_views_from_rays( + self.render_rays, + batch, + params=params, + options=options, + device=self.device, + patch_size=self.patch_size, + use_patch=self.use_patch + ) + + elif rendering_mode == "stf": + + sdf_fn = tf_fn = nerstf_fn = None + if self.nerstf is not None: + nerstf_fn = partial( + self.nerstf.forward_batched, + params=subdict(params, "nerstf"), + options=options, + ) + else: + sdf_fn = partial( + self.sdf.forward_batched, + params=subdict(params, "sdf"), + options=options, + ) + tf_fn = partial( + self.tf.forward_batched, + params=subdict(params, "tf"), + options=options, + ) + output = render_views_from_stf( + batch, + options, + sdf_fn=sdf_fn, + tf_fn=tf_fn, + nerstf_fn=nerstf_fn, + volume=self.volume, + grid_size=self.grid_size, + channel_scale=self.channel_scale, + texture_channels=self.texture_channels, + ambient_color=self.ambient_color, + diffuse_color=self.diffuse_color, + specular_color=self.specular_color, + output_srgb=self.output_srgb, + device=self.device, + ) + + else: + + raise NotImplementedError + + if created_cache: + del options["cache"] + + return output + + def get_signed_distance( + self, + query: Query, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + if self.sdf is not None: + return self.sdf(query, params=subdict(params, "sdf"), options=options).signed_distance + assert self.nerstf is not None + return self.nerstf(query, params=subdict(params, "nerstf"), options=options).signed_distance + + def get_texture( + self, + query: Query, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + if self.tf is not None: + return self.tf(query, params=subdict(params, "tf"), options=options).channels + assert self.nerstf is not None + return self.nerstf(query, params=subdict(params, "nerstf"), options=options).channels diff --git a/shap_e/models/nn/__init__.py b/shap_e/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba38290084daf41297fa2876ad9ffe17d47a3a24 --- /dev/null +++ b/shap_e/models/nn/__init__.py @@ -0,0 +1,2 @@ +from .meta import * +from .ops import * diff --git a/shap_e/models/nn/__pycache__/__init__.cpython-39.pyc b/shap_e/models/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fa13bcc2dddba79a2ee4cdd5892cfcf80ccb8d Binary files /dev/null and b/shap_e/models/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/camera.cpython-39.pyc b/shap_e/models/nn/__pycache__/camera.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d86ac4d56a39db5678851e7683c4a8463e417bc2 Binary files /dev/null and b/shap_e/models/nn/__pycache__/camera.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/checkpoint.cpython-39.pyc b/shap_e/models/nn/__pycache__/checkpoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c376028f997a2167319b92c189a8f83518d9f850 Binary files /dev/null and b/shap_e/models/nn/__pycache__/checkpoint.cpython-39.pyc differ diff --git 
a/shap_e/models/nn/__pycache__/encoding.cpython-39.pyc b/shap_e/models/nn/__pycache__/encoding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34316e112aa4c93ec7e850df147612ec5fbfc996 Binary files /dev/null and b/shap_e/models/nn/__pycache__/encoding.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/meta.cpython-39.pyc b/shap_e/models/nn/__pycache__/meta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9011e4413690208e1109eb82ee1aef80e384ecab Binary files /dev/null and b/shap_e/models/nn/__pycache__/meta.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/ops.cpython-39.pyc b/shap_e/models/nn/__pycache__/ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc05ef613902b32ba211fdd1db7389a9c21ec98f Binary files /dev/null and b/shap_e/models/nn/__pycache__/ops.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/pointnet2_utils.cpython-39.pyc b/shap_e/models/nn/__pycache__/pointnet2_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb8c5e9eee21b8c0beff6fc5dba1c9e3f62a8a5a Binary files /dev/null and b/shap_e/models/nn/__pycache__/pointnet2_utils.cpython-39.pyc differ diff --git a/shap_e/models/nn/__pycache__/utils.cpython-39.pyc b/shap_e/models/nn/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..846efd4319cc6736d8c5476cb4dee13c8013b5b5 Binary files /dev/null and b/shap_e/models/nn/__pycache__/utils.cpython-39.pyc differ diff --git a/shap_e/models/nn/camera.py b/shap_e/models/nn/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..90f871821346bcbfa0aa852b409bf073c786f7fc --- /dev/null +++ b/shap_e/models/nn/camera.py @@ -0,0 +1,208 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from shap_e.rendering.view_data import ProjectiveCamera + + +@dataclass +class DifferentiableCamera(ABC): + """ + An object describing how a camera corresponds to pixels in an image. + """ + + @abstractmethod + def camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + """ + For every (x, y) coordinate in a rendered image, compute the ray of the + corresponding pixel. + + :param coords: an [N x ... x 2] integer array of 2D image coordinates. + :return: an [N x ... x 2 x 3] array of [2 x 3] (origin, direction) tuples. + The direction should always be unit length. + """ + + @abstractmethod + def resize_image(self, width: int, height: int) -> "DifferentiableCamera": + """ + Creates a new camera with the same intrinsics and direction as this one, + but with resized image dimensions. 
+ """ + + +@dataclass +class DifferentiableProjectiveCamera(DifferentiableCamera): + """ + Implements a batch, differentiable, standard pinhole camera + """ + + origin: torch.Tensor # [batch_size x 3] + x: torch.Tensor # [batch_size x 3] + y: torch.Tensor # [batch_size x 3] + z: torch.Tensor # [batch_size x 3] + width: int + height: int + x_fov: float + y_fov: float + + def __post_init__(self): + assert self.x.shape[0] == self.y.shape[0] == self.z.shape[0] == self.origin.shape[0] + assert self.x.shape[1] == self.y.shape[1] == self.z.shape[1] == self.origin.shape[1] == 3 + assert ( + len(self.x.shape) + == len(self.y.shape) + == len(self.z.shape) + == len(self.origin.shape) + == 2 + ) + + def resolution(self): + return torch.from_numpy(np.array([self.width, self.height], dtype=np.float32)) + + def fov(self): + return torch.from_numpy(np.array([self.x_fov, self.y_fov], dtype=np.float32)) + + def image_coords(self) -> torch.Tensor: + """ + :return: coords of shape (width * height, 2) + """ + pixel_indices = torch.arange(self.height * self.width) + coords = torch.stack( + [ + pixel_indices % self.width, + torch.div(pixel_indices, self.width, rounding_mode="trunc"), + ], + axis=1, + ) + return coords + + def camera_rays(self, coords: torch.Tensor) -> torch.Tensor: + # import pdb; pdb.set_trace() + batch_size, *shape, n_coords = coords.shape + assert n_coords == 2 + assert batch_size == self.origin.shape[0] + flat = coords.view(batch_size, -1, 2) + + res = self.resolution().to(flat.device) + fov = self.fov().to(flat.device) + + fracs = (flat.float() / (res - 1)) * 2 - 1 + fracs = fracs * torch.tan(fov / 2) + + fracs = fracs.view(batch_size, -1, 2) + directions = ( + self.z.view(batch_size, 1, 3) + + self.x.view(batch_size, 1, 3) * fracs[:, :, :1] + + self.y.view(batch_size, 1, 3) * fracs[:, :, 1:] + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + rays = torch.stack( + [ + torch.broadcast_to( + self.origin.view(batch_size, 1, 3), [batch_size, directions.shape[1], 3] + ), + directions, + ], + dim=2, + ) + return rays.view(batch_size, *shape, 2, 3) + + def resize_image(self, width: int, height: int) -> "DifferentiableProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." + return DifferentiableProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +@dataclass +class DifferentiableCameraBatch(ABC): + """ + Annotate a differentiable camera with a multi-dimensional batch shape. + """ + + shape: Tuple[int] + flat_camera: DifferentiableCamera + + +def normalize(vec: torch.Tensor) -> torch.Tensor: + return vec / vec.norm(dim=-1, keepdim=True) + + +def project_out(vec1: torch.Tensor, vec2: torch.Tensor) -> torch.Tensor: + """ + Removes the vec2 component from vec1 + """ + vec2 = normalize(vec2) + proj = (vec1 * vec2).sum(dim=-1, keepdim=True) + return vec1 - proj * vec2 + + +def camera_orientation(toward: torch.Tensor, up: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + :param toward: [batch_size x 3] unit vector from camera position to the object + :param up: Optional [batch_size x 3] specifying the physical up direction in the world frame. 
+ :return: [batch_size x 3 x 3] + """ + + if up is None: + up = torch.zeros_like(toward) + up[:, 2] = 1 + + assert len(toward.shape) == 2 + assert toward.shape[1] == 3 + + assert len(up.shape) == 2 + assert up.shape[1] == 3 + + z = toward / toward.norm(dim=-1, keepdim=True) + y = -normalize(project_out(up, toward)) + x = torch.cross(y, z, dim=1) + return torch.stack([x, y, z], dim=1) + + +def projective_camera_frame( + origin: torch.Tensor, + toward: torch.Tensor, + camera_params: Union[ProjectiveCamera, DifferentiableProjectiveCamera], +) -> DifferentiableProjectiveCamera: + """ + Given the origin and the direction of a view, return a differentiable + projective camera with the given parameters. + + TODO: We need to support the rotation of the camera frame about the + `toward` vector to fully implement 6 degrees of freedom. + """ + rot = camera_orientation(toward) + camera = DifferentiableProjectiveCamera( + origin=origin, + x=rot[:, 0], + y=rot[:, 1], + z=rot[:, 2], + width=camera_params.width, + height=camera_params.height, + x_fov=camera_params.x_fov, + y_fov=camera_params.y_fov, + ) + return camera + + +@torch.no_grad() +def get_image_coords(width, height) -> torch.Tensor: + pixel_indices = torch.arange(height * width) + # torch throws warnings for pixel_indices // width + pixel_indices_div = torch.div(pixel_indices, width, rounding_mode="trunc") + coords = torch.stack([pixel_indices % width, pixel_indices_div], dim=1) + return coords diff --git a/shap_e/models/nn/checkpoint.py b/shap_e/models/nn/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ec4b59e8e7e17d4fe897b533ba0d182e4b20b23f --- /dev/null +++ b/shap_e/models/nn/checkpoint.py @@ -0,0 +1,116 @@ +from typing import Callable, Iterable, Sequence, Union + +import torch +from torch.cuda.amp import custom_bwd, custom_fwd + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
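    Example (illustrative sketch only; `block` and `x` are hypothetical stand-ins for a
    module and an input tensor):

        out = checkpoint(block.forward, (x,), block.parameters(), flag=True)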
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.length = length + input_tensors = list(args[:length]) + input_params = list(args[length:]) + ctx.save_for_backward(*input_tensors, *input_params) + with torch.no_grad(): + output_tensors = ctx.run_function(*input_tensors) + return output_tensors + + @staticmethod + @custom_bwd + def backward(ctx, *output_grads): + inputs = ctx.saved_tensors + input_tensors = inputs[: ctx.length] + input_params = inputs[ctx.length :] + res = CheckpointFunctionGradFunction.apply( + ctx.run_function, + len(input_tensors), + len(input_params), + *input_tensors, + *input_params, + *output_grads + ) + return (None, None) + res + + +class CheckpointFunctionGradFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, run_function, length_1, length_2, *args): + ctx.run_function = run_function + ctx.length_1 = length_1 + ctx.length_2 = length_2 + input_tensors = [x.detach().requires_grad_(True) for x in args[:length_1]] + input_params = list(args[length_1 : length_1 + length_2]) + output_grads = list(args[length_1 + length_2 :]) + ctx.save_for_backward(*input_tensors, *input_params, *output_grads) + + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + input_tensors + input_params, + output_grads, + allow_unused=True, + ) + return input_grads + + @staticmethod + @custom_bwd + def backward(ctx, *all_output_grads): + args = ctx.saved_tensors + input_tensors = [x.detach().requires_grad_(True) for x in args[: ctx.length_1]] + input_params = list(args[ctx.length_1 : ctx.length_1 + ctx.length_2]) + output_grads = [ + x.detach().requires_grad_(True) for x in args[ctx.length_1 + ctx.length_2 :] + ] + + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
+ shallow_copies = [x.view_as(x) for x in input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + input_tensors + input_params, + output_grads, + allow_unused=True, + create_graph=True, + retain_graph=True, + ) + input_grads_grads = torch.autograd.grad( + input_grads, + input_tensors + input_params + output_grads, + all_output_grads, + allow_unused=True, + ) + del input_grads + return (None, None, None) + input_grads_grads diff --git a/shap_e/models/nn/encoding.py b/shap_e/models/nn/encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5200b99c0ece8b5319148f16dea940628ca2bf --- /dev/null +++ b/shap_e/models/nn/encoding.py @@ -0,0 +1,458 @@ +import math +from functools import lru_cache +from typing import Optional + +import torch +import torch.nn as nn + + +def encode_position(version: str, *, position: torch.Tensor): + if version == "v1": + freqs = get_scales(0, 10, position.dtype, position.device).view(1, -1) + freqs = position.reshape(-1, 1) * freqs + return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*position.shape[:-1], -1) + elif version == "nerf": + return posenc_nerf(position, min_deg=0, max_deg=15) + else: + raise ValueError(version) + + +def encode_channels(version: str, *, channels: torch.Tensor): + if version == "v1": + freqs = get_scales(0, 10, channels.dtype, channels.device).view(1, -1) + freqs = channels.reshape(-1, 1) * freqs + return torch.cat([freqs.cos(), freqs.sin()], dim=1).reshape(*channels.shape[:-1], -1) + elif version == "nerf": + return posenc_nerf(channels, min_deg=0, max_deg=15) + else: + raise ValueError(version) + + +def position_encoding_channels(version: Optional[str] = None) -> int: + if version is None: + return 1 + return encode_position(version, position=torch.zeros(1, 1)).shape[-1] + + +def channel_encoding_channels(version: Optional[str] = None) -> int: + if version is None: + return 1 + return encode_channels(version, channels=torch.zeros(1, 1)).shape[-1] + + +class PosEmbLinear(nn.Linear): + def __init__( + self, posemb_version: Optional[str], in_features: int, out_features: int, **kwargs + ): + super().__init__( + in_features * position_encoding_channels(posemb_version), + out_features, + **kwargs, + ) + self.posemb_version = posemb_version + + def forward(self, x: torch.Tensor): + if self.posemb_version is not None: + x = encode_position(self.posemb_version, position=x) + return super().forward(x) + + +class MultiviewPoseEmbedding(nn.Conv2d): + def __init__( + self, + posemb_version: Optional[str], + n_channels: int, + out_features: int, + stride: int = 1, + **kwargs, + ): + in_features = ( + n_channels * channel_encoding_channels(version=posemb_version) + + 3 * position_encoding_channels(version=posemb_version) + + 3 * position_encoding_channels(version=posemb_version) + ) + super().__init__( + in_features, + out_features, + kernel_size=3, + stride=stride, + padding=1, + **kwargs, + ) + self.posemb_version = posemb_version + + def forward( + self, channels: torch.Tensor, position: torch.Tensor, direction: torch.Tensor + ) -> torch.Tensor: + """ + :param channels: [batch_shape, inner_batch_shape, n_channels, height, width] + :param position: [batch_shape, inner_batch_shape, 3, height, width] + :param direction: [batch_shape, inner_batch_shape, 3, height, width] + :return: [*batch_shape, out_features, height, width] + """ + + if self.posemb_version is not None: + channels = channels.permute(0, 1, 3, 4, 2) + position = position.permute(0, 1, 3, 
4, 2) + direction = direction.permute(0, 1, 3, 4, 2) + channels = encode_channels(self.posemb_version, channels=channels).permute( + 0, 1, 4, 2, 3 + ) + direction = maybe_encode_direction( + self.posemb_version, position=position, direction=direction + ).permute(0, 1, 4, 2, 3) + position = encode_position(self.posemb_version, position=position).permute( + 0, 1, 4, 2, 3 + ) + x = torch.cat([channels, position, direction], dim=-3) + *batch_shape, in_features, height, width = x.shape + return ( + super() + .forward(x.view(-1, in_features, height, width)) + .view(*batch_shape, -1, height, width) + ) + + +class MultiviewPointCloudEmbedding(nn.Conv2d): + def __init__( + self, + posemb_version: Optional[str], + n_channels: int, + out_features: int, + stride: int = 1, + **kwargs, + ): + in_features = ( + n_channels * channel_encoding_channels(version=posemb_version) + + 3 * position_encoding_channels(version=posemb_version) + + 3 * position_encoding_channels(version=posemb_version) + ) + super().__init__( + in_features, + out_features, + kernel_size=3, + stride=stride, + padding=1, + **kwargs, + ) + self.posemb_version = posemb_version + self.register_parameter( + "unk_token", nn.Parameter(torch.randn(in_features, **kwargs) * 0.01) + ) + self.unk_token: torch.Tensor + + def forward( + self, + channels: torch.Tensor, + origin: torch.Tensor, + position: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + """ + :param channels: [batch_shape, inner_batch_shape, n_channels, height, width] + :param origin: [batch_shape, inner_batch_shape, 3, height, width] + :param position: [batch_shape, inner_batch_shape, 3, height, width] + :return: [*batch_shape, out_features, height, width] + """ + + if self.posemb_version is not None: + channels = channels.permute(0, 1, 3, 4, 2) + origin = origin.permute(0, 1, 3, 4, 2) + position = position.permute(0, 1, 3, 4, 2) + channels = encode_channels(self.posemb_version, channels=channels).permute( + 0, 1, 4, 2, 3 + ) + origin = encode_position(self.posemb_version, position=origin).permute(0, 1, 4, 2, 3) + position = encode_position(self.posemb_version, position=position).permute( + 0, 1, 4, 2, 3 + ) + x = torch.cat([channels, origin, position], dim=-3) + unk_token = torch.broadcast_to(self.unk_token.view(1, 1, -1, 1, 1), x.shape) + x = torch.where(mask, x, unk_token) + *batch_shape, in_features, height, width = x.shape + return ( + super() + .forward(x.view(-1, in_features, height, width)) + .view(*batch_shape, -1, height, width) + ) + + +def maybe_encode_direction( + version: str, + *, + position: torch.Tensor, + direction: Optional[torch.Tensor] = None, +): + + if version == "v1": + sh_degree = 4 + if direction is None: + return torch.zeros(*position.shape[:-1], sh_degree**2).to(position) + return spherical_harmonics_basis(direction, sh_degree=sh_degree) + elif version == "nerf": + if direction is None: + return torch.zeros_like(posenc_nerf(position, min_deg=0, max_deg=8)) + return posenc_nerf(direction, min_deg=0, max_deg=8) + else: + raise ValueError(version) + + +def posenc_nerf(x: torch.Tensor, min_deg: int = 0, max_deg: int = 15) -> torch.Tensor: + """ + Concatenate x and its positional encodings, following NeRF. 
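    (Illustrative note: for an input of size d and degrees [min_deg, max_deg), the output has
    size d + 2 * d * (max_deg - min_deg), i.e. the input concatenated with sin and cos
    features at each of the max_deg - min_deg frequency scales.)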
+ + Reference: https://arxiv.org/pdf/2210.04628.pdf + """ + if min_deg == max_deg: + return x + scales = get_scales(min_deg, max_deg, x.dtype, x.device) + *shape, dim = x.shape + xb = (x.reshape(-1, 1, dim) * scales.view(1, -1, 1)).reshape(*shape, -1) + assert xb.shape[-1] == dim * (max_deg - min_deg) + emb = torch.cat([xb, xb + math.pi / 2.0], axis=-1).sin() + return torch.cat([x, emb], dim=-1) + + +@lru_cache +def get_scales( + min_deg: int, + max_deg: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + return 2.0 ** torch.arange(min_deg, max_deg, device=device, dtype=dtype) + + +def spherical_harmonics_basis( + coords: torch.Tensor, + sh_degree: int, +) -> torch.Tensor: + """ + Calculate the spherical harmonics basis + + :param coords: [batch_size, *shape, 3] of unit norm + :param sh_degree: Spherical harmonics degree + :return: [batch_size, *shape, sh_degree**2] + """ + if sh_degree > 8: + raise NotImplementedError + + batch_size, *shape, _ = coords.shape + x, y, z = coords.reshape(-1, 3).split(1, dim=-1) + x = x.squeeze(dim=-1) + y = y.squeeze(dim=-1) + z = z.squeeze(dim=-1) + + xy, xz, yz = x * y, x * z, y * z + x2, y2, z2 = x * x, y * y, z * z + x4, y4, z4 = x2 * x2, y2 * y2, z2 * z2 + x6, y6, z6 = x4 * x2, y4 * y2, z4 * z2 + xyz = xy * z + + # https://github.com/NVlabs/tiny-cuda-nn/blob/8575542682cb67cddfc748cc3d3cfc12593799aa/include/tiny-cuda-nn/encodings/spherical_harmonics.h#L76 + + out = torch.zeros(x.shape[0], sh_degree**2, dtype=x.dtype, device=x.device) + + def _sh(): + out[:, 0] = 0.28209479177387814 # 1/(2*sqrt(pi)) + if sh_degree <= 1: + return + out[:, 1] = -0.48860251190291987 * y # -sqrt(3)*y/(2*sqrt(pi)) + out[:, 2] = 0.48860251190291987 * z # sqrt(3)*z/(2*sqrt(pi)) + out[:, 3] = -0.48860251190291987 * x # -sqrt(3)*x/(2*sqrt(pi)) + if sh_degree <= 2: + return + out[:, 4] = 1.0925484305920792 * xy # sqrt(15)*xy/(2*sqrt(pi)) + out[:, 5] = -1.0925484305920792 * yz # -sqrt(15)*yz/(2*sqrt(pi)) + out[:, 6] = ( + 0.94617469575755997 * z2 - 0.31539156525251999 + ) # sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + out[:, 7] = -1.0925484305920792 * xz # -sqrt(15)*xz/(2*sqrt(pi)) + out[:, 8] = ( + 0.54627421529603959 * x2 - 0.54627421529603959 * y2 + ) # sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if sh_degree <= 3: + return + out[:, 9] = ( + 0.59004358992664352 * y * (-3.0 * x2 + y2) + ) # sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + out[:, 10] = 2.8906114426405538 * xy * z # sqrt(105)*xy*z/(2*sqrt(pi)) + out[:, 11] = ( + 0.45704579946446572 * y * (1.0 - 5.0 * z2) + ) # sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + out[:, 12] = 0.3731763325901154 * z * (5.0 * z2 - 3.0) # sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + out[:, 13] = ( + 0.45704579946446572 * x * (1.0 - 5.0 * z2) + ) # sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + out[:, 14] = 1.4453057213202769 * z * (x2 - y2) # sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + out[:, 15] = ( + 0.59004358992664352 * x * (-x2 + 3.0 * y2) + ) # sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if sh_degree <= 4: + return + out[:, 16] = 2.5033429417967046 * xy * (x2 - y2) # 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + out[:, 17] = ( + 1.7701307697799304 * yz * (-3.0 * x2 + y2) + ) # 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + out[:, 18] = ( + 0.94617469575756008 * xy * (7.0 * z2 - 1.0) + ) # 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + out[:, 19] = ( + 0.66904654355728921 * yz * (3.0 - 7.0 * z2) + ) # 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + out[:, 20] = ( + -3.1735664074561294 * z2 + 3.7024941420321507 * z4 + 0.31735664074561293 + ) # 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + out[:, 21] = ( + 
0.66904654355728921 * xz * (3.0 - 7.0 * z2) + ) # 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + out[:, 22] = ( + 0.47308734787878004 * (x2 - y2) * (7.0 * z2 - 1.0) + ) # 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + out[:, 23] = ( + 1.7701307697799304 * xz * (-x2 + 3.0 * y2) + ) # 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + out[:, 24] = ( + -3.7550144126950569 * x2 * y2 + 0.62583573544917614 * x4 + 0.62583573544917614 * y4 + ) # 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if sh_degree <= 5: + return + out[:, 25] = ( + 0.65638205684017015 * y * (10.0 * x2 * y2 - 5.0 * x4 - y4) + ) # 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + out[:, 26] = ( + 8.3026492595241645 * xy * z * (x2 - y2) + ) # 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + out[:, 27] = ( + -0.48923829943525038 * y * (3.0 * x2 - y2) * (9.0 * z2 - 1.0) + ) # -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + out[:, 28] = ( + 4.7935367849733241 * xy * z * (3.0 * z2 - 1.0) + ) # sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + out[:, 29] = ( + 0.45294665119569694 * y * (14.0 * z2 - 21.0 * z4 - 1.0) + ) # sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + out[:, 30] = ( + 0.1169503224534236 * z * (-70.0 * z2 + 63.0 * z4 + 15.0) + ) # sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + out[:, 31] = ( + 0.45294665119569694 * x * (14.0 * z2 - 21.0 * z4 - 1.0) + ) # sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + out[:, 32] = ( + 2.3967683924866621 * z * (x2 - y2) * (3.0 * z2 - 1.0) + ) # sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + out[:, 33] = ( + -0.48923829943525038 * x * (x2 - 3.0 * y2) * (9.0 * z2 - 1.0) + ) # -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + out[:, 34] = ( + 2.0756623148810411 * z * (-6.0 * x2 * y2 + x4 + y4) + ) # 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + out[:, 35] = ( + 0.65638205684017015 * x * (10.0 * x2 * y2 - x4 - 5.0 * y4) + ) # 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if sh_degree <= 6: + return + out[:, 36] = ( + 1.3663682103838286 * xy * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) + ) # sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + out[:, 37] = ( + 2.3666191622317521 * yz * (10.0 * x2 * y2 - 5.0 * x4 - y4) + ) # 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + out[:, 38] = ( + 2.0182596029148963 * xy * (x2 - y2) * (11.0 * z2 - 1.0) + ) # 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + out[:, 39] = ( + -0.92120525951492349 * yz * (3.0 * x2 - y2) * (11.0 * z2 - 3.0) + ) # -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + out[:, 40] = ( + 0.92120525951492349 * xy * (-18.0 * z2 + 33.0 * z4 + 1.0) + ) # sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + out[:, 41] = ( + 0.58262136251873131 * yz * (30.0 * z2 - 33.0 * z4 - 5.0) + ) # sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + out[:, 42] = ( + 6.6747662381009842 * z2 + - 20.024298714302954 * z4 + + 14.684485723822165 * z6 + - 0.31784601133814211 + ) # sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + out[:, 43] = ( + 0.58262136251873131 * xz * (30.0 * z2 - 33.0 * z4 - 5.0) + ) # sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + out[:, 44] = ( + 0.46060262975746175 * (x2 - y2) * (11.0 * z2 * (3.0 * z2 - 1.0) - 7.0 * z2 + 1.0) + ) # sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + out[:, 45] = ( + -0.92120525951492349 * xz * (x2 - 3.0 * y2) * (11.0 * z2 - 3.0) + ) # -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + out[:, 46] = ( + 0.50456490072872406 * (11.0 * z2 - 1.0) * (-6.0 * x2 * y2 + x4 + y4) + ) # 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + out[:, 47] = ( + 
2.3666191622317521 * xz * (10.0 * x2 * y2 - x4 - 5.0 * y4) + ) # 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + out[:, 48] = ( + 10.247761577878714 * x2 * y4 + - 10.247761577878714 * x4 * y2 + + 0.6831841051919143 * x6 + - 0.6831841051919143 * y6 + ) # sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if sh_degree <= 7: + return + out[:, 49] = ( + 0.70716273252459627 * y * (-21.0 * x2 * y4 + 35.0 * x4 * y2 - 7.0 * x6 + y6) + ) # 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + out[:, 50] = ( + 5.2919213236038001 * xy * z * (-10.0 * x2 * y2 + 3.0 * x4 + 3.0 * y4) + ) # 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + out[:, 51] = ( + -0.51891557872026028 * y * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + 5.0 * x4 + y4) + ) # -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + out[:, 52] = ( + 4.1513246297620823 * xy * z * (x2 - y2) * (13.0 * z2 - 3.0) + ) # 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + out[:, 53] = ( + -0.15645893386229404 + * y + * (3.0 * x2 - y2) + * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) + ) # -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + out[:, 54] = ( + 0.44253269244498261 * xy * z * (-110.0 * z2 + 143.0 * z4 + 15.0) + ) # 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + out[:, 55] = ( + 0.090331607582517306 * y * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) + ) # sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + out[:, 56] = ( + 0.068284276912004949 * z * (315.0 * z2 - 693.0 * z4 + 429.0 * z6 - 35.0) + ) # sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + out[:, 57] = ( + 0.090331607582517306 * x * (-135.0 * z2 + 495.0 * z4 - 429.0 * z6 + 5.0) + ) # sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + out[:, 58] = ( + 0.07375544874083044 + * z + * (x2 - y2) + * (143.0 * z2 * (3.0 * z2 - 1.0) - 187.0 * z2 + 45.0) + ) # sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + out[:, 59] = ( + -0.15645893386229404 + * x + * (x2 - 3.0 * y2) + * (13.0 * z2 * (11.0 * z2 - 3.0) - 27.0 * z2 + 3.0) + ) # -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + out[:, 60] = ( + 1.0378311574405206 * z * (13.0 * z2 - 3.0) * (-6.0 * x2 * y2 + x4 + y4) + ) # 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + out[:, 61] = ( + -0.51891557872026028 * x * (13.0 * z2 - 1.0) * (-10.0 * x2 * y2 + x4 + 5.0 * y4) + ) # -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + out[:, 62] = ( + 2.6459606618019 * z * (15.0 * x2 * y4 - 15.0 * x4 * y2 + x6 - y6) + ) # 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + out[:, 63] = ( + 0.70716273252459627 * x * (-35.0 * x2 * y4 + 21.0 * x4 * y2 - x6 + 7.0 * y6) + ) # 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + + _sh() + return out.view(batch_size, *shape, sh_degree**2) diff --git a/shap_e/models/nn/meta.py b/shap_e/models/nn/meta.py new file mode 100644 index 0000000000000000000000000000000000000000..4f36bd13768af913a1b827a51c2ad5abf73c5b05 --- /dev/null +++ b/shap_e/models/nn/meta.py @@ -0,0 +1,234 @@ +""" +Meta-learning modules based on: https://github.com/tristandeleu/pytorch-meta + +MIT License + +Copyright (c) 2019-2020 Tristan Deleu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import itertools +import re +from collections import OrderedDict + +import torch.nn as nn + +from shap_e.util.collections import AttrDict + +__all__ = [ + "MetaModule", + "subdict", + "superdict", + "leveldict", + "leveliter", + "batch_meta_parameters", + "batch_meta_state_dict", +] + + +def subdict(dictionary, key=None): + if dictionary is None: + return None + if (key is None) or (key == ""): + return dictionary + key_re = re.compile(r"^{0}\.(.+)".format(re.escape(key))) + return AttrDict( + OrderedDict( + (key_re.sub(r"\1", k), value) + for (k, value) in dictionary.items() + if key_re.match(k) is not None + ) + ) + + +def superdict(dictionary, key=None): + if dictionary is None: + return None + if (key is None) or (key == ""): + return dictionary + return AttrDict(OrderedDict((key + "." + k, value) for (k, value) in dictionary.items())) + + +def leveldict(dictionary, depth=0): + return AttrDict(leveliter(dictionary, depth=depth)) + + +def leveliter(dictionary, depth=0): + """ + depth == 0 is root + """ + for key, value in dictionary.items(): + if key.count(".") == depth: + yield key, value + + +class MetaModule(nn.Module): + """ + Base class for PyTorch meta-learning modules. These modules accept an + additional argument `params` in their `forward` method. + + Notes + ----- + Objects inherited from `MetaModule` are fully compatible with PyTorch + modules from `torch.nn.Module`. The argument `params` is a dictionary of + tensors, with full support of the computation graph (for differentiation). + + Based on SIREN's torchmeta with some additional features/changes. + + All meta weights must not have the batch dimension, as they are later tiled + to the given batch size after unsqueezing the first dimension (e.g. a + weight of dimension [d_out x d_in] is tiled to have the dimension [batch x + d_out x d_in]). Requiring all meta weights to have a batch dimension of 1 + (e.g. [1 x d_out x d_in] from the earlier example) could be a more natural + choice, but this results in silent failures. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._meta_state_dict = set() + self._meta_params = set() + + def register_meta_buffer(self, name: str, param: nn.Parameter): + """ + Registers a trainable or nontrainable parameter as a meta buffer. This + can be later retrieved by meta_state_dict + """ + self.register_buffer(name, param) + self._meta_state_dict.add(name) + + def register_meta_parameter(self, name: str, parameter: nn.Parameter): + """ + Registers a meta parameter so it is included in named_meta_parameters + and meta_state_dict. 
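        Example (illustrative; `n_out` is a hypothetical feature dimension):

            self.register_meta_parameter("scale", nn.Parameter(torch.ones(n_out)))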
+ """ + self.register_parameter(name, parameter) + self._meta_params.add(name) + self._meta_state_dict.add(name) + + def register_meta(self, name: str, parameter: nn.Parameter, trainable: bool = True): + if trainable: + self.register_meta_parameter(name, parameter) + else: + self.register_meta_buffer(name, parameter) + + def register(self, name: str, parameter: nn.Parameter, meta: bool, trainable: bool = True): + if meta: + if trainable: + self.register_meta_parameter(name, parameter) + else: + self.register_meta_buffer(name, parameter) + else: + if trainable: + self.register_parameter(name, parameter) + else: + self.register_buffer(name, parameter) + + def named_meta_parameters(self, prefix="", recurse=True): + """ + Returns an iterator over all the names and meta parameters + """ + + def meta_iterator(module): + meta = module._meta_params if isinstance(module, MetaModule) else set() + for name, param in module._parameters.items(): + if name in meta: + yield name, param + + gen = self._named_members( + meta_iterator, + prefix=prefix, + recurse=recurse, + ) + for name, param in gen: + yield name, param + + def named_nonmeta_parameters(self, prefix="", recurse=True): + def _iterator(module): + meta = module._meta_params if isinstance(module, MetaModule) else set() + for name, param in module._parameters.items(): + if name not in meta: + yield name, param + + gen = self._named_members( + _iterator, + prefix=prefix, + recurse=recurse, + ) + for name, param in gen: + yield name, param + + def nonmeta_parameters(self, prefix="", recurse=True): + for _, param in self.named_nonmeta_parameters(prefix=prefix, recurse=recurse): + yield param + + def meta_state_dict(self, prefix="", recurse=True): + """ + Returns an iterator over all the names and meta parameters/buffers. + + One difference between module.state_dict() is that this preserves + requires_grad, because we may want to compute the gradient w.r.t. meta + buffers, but don't necessarily update them automatically. + """ + + def meta_iterator(module): + meta = module._meta_state_dict if isinstance(module, MetaModule) else set() + for name, param in itertools.chain(module._buffers.items(), module._parameters.items()): + if name in meta: + yield name, param + + gen = self._named_members( + meta_iterator, + prefix=prefix, + recurse=recurse, + ) + return dict(gen) + + def update(self, params=None): + """ + Updates the parameter list before the forward prop so that if `params` + is None or doesn't have a certain key, the module uses the default + parameter/buffer registered in the module. 
+ """ + # import pdb; pdb.set_trace() + if params is None: + params = AttrDict() + params = AttrDict(params) + named_params = set([name for name, _ in self.named_parameters()]) + for name, param in self.named_parameters(): + params.setdefault(name, param) + for name, param in self.state_dict().items(): + if name not in named_params: + params.setdefault(name, param) + return params + + +def batch_meta_parameters(net, batch_size): + params = AttrDict() + for name, param in net.named_meta_parameters(): + params[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape)) + return params + + +def batch_meta_state_dict(net, batch_size): + state_dict = AttrDict() + meta_parameters = set([name for name, _ in net.named_meta_parameters()]) + for name, param in net.meta_state_dict().items(): + state_dict[name] = param.clone().unsqueeze(0).repeat(batch_size, *[1] * len(param.shape)) + return state_dict diff --git a/shap_e/models/nn/ops.py b/shap_e/models/nn/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3fac80f644f44001fa447d017a8f5ab802a623d5 --- /dev/null +++ b/shap_e/models/nn/ops.py @@ -0,0 +1,410 @@ +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from shap_e.util.collections import AttrDict + +from .meta import MetaModule, subdict +from .pointnet2_utils import sample_and_group, sample_and_group_all + + +def gelu(x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def swish(x): + return x * torch.sigmoid(x) + + +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +def torch_gelu(x): + return torch.nn.functional.gelu(x) + + +def geglu(x): + v, gates = x.chunk(2, dim=-1) + return v * gelu(gates) + + +class SirenSin: + def __init__(self, w0=30.0): + self.w0 = w0 + + def __call__(self, x): + return torch.sin(self.w0 * x) + + +def get_act(name): + return { + "relu": torch.nn.functional.relu, + "leaky_relu": torch.nn.functional.leaky_relu, + "swish": swish, + "tanh": torch.tanh, + "gelu": gelu, + "quick_gelu": quick_gelu, + "torch_gelu": torch_gelu, + "gelu2": quick_gelu, + "geglu": geglu, + "sigmoid": torch.sigmoid, + "sin": torch.sin, + "sin30": SirenSin(w0=30.0), + "softplus": F.softplus, + "exp": torch.exp, + "identity": lambda x: x, + }[name] + + +def zero_init(affine): + nn.init.constant_(affine.weight, 0.0) + if affine.bias is not None: + nn.init.constant_(affine.bias, 0.0) + + +def siren_init_first_layer(affine, init_scale: float = 1.0): + n_input = affine.weight.shape[1] + u = init_scale / n_input + nn.init.uniform_(affine.weight, -u, u) + if affine.bias is not None: + nn.init.constant_(affine.bias, 0.0) + + +def siren_init(affine, coeff=1.0, init_scale: float = 1.0): + n_input = affine.weight.shape[1] + u = init_scale * np.sqrt(6.0 / n_input) / coeff + nn.init.uniform_(affine.weight, -u, u) + if affine.bias is not None: + nn.init.constant_(affine.bias, 0.0) + + +def siren_init_30(affine, init_scale: float = 1.0): + siren_init(affine, coeff=30.0, init_scale=init_scale) + + +def std_init(affine, init_scale: float = 1.0): + n_in = affine.weight.shape[1] + stddev = init_scale / math.sqrt(n_in) + nn.init.normal_(affine.weight, std=stddev) + if affine.bias is not None: + nn.init.constant_(affine.bias, 0.0) + + +def mlp_init(affines, init: Optional[str] = None, init_scale: float = 1.0): + if init == "siren30": + for idx, affine in enumerate(affines): + init = siren_init_first_layer if 
idx == 0 else siren_init_30 + init(affine, init_scale=init_scale) + elif init == "siren": + for idx, affine in enumerate(affines): + init = siren_init_first_layer if idx == 0 else siren_init + init(affine, init_scale=init_scale) + elif init is None: + for affine in affines: + std_init(affine, init_scale=init_scale) + else: + raise NotImplementedError(init) + + +class MetaLinear(MetaModule): + def __init__( + self, + n_in, + n_out, + bias: bool = True, + meta_scale: bool = True, + meta_shift: bool = True, + meta_proj: bool = False, + meta_bias: bool = False, + trainable_meta: bool = False, + **kwargs, + ): + super().__init__() + # n_in, n_out, bias=bias) + register_meta_fn = ( + self.register_meta_parameter if trainable_meta else self.register_meta_buffer + ) + if meta_scale: + register_meta_fn("scale", nn.Parameter(torch.ones(n_out, **kwargs))) + if meta_shift: + register_meta_fn("shift", nn.Parameter(torch.zeros(n_out, **kwargs))) + + register_proj_fn = self.register_parameter if not meta_proj else register_meta_fn + register_proj_fn("weight", nn.Parameter(torch.empty((n_out, n_in), **kwargs))) + + if not bias: + self.register_parameter("bias", None) + else: + register_bias_fn = self.register_parameter if not meta_bias else register_meta_fn + register_bias_fn("bias", nn.Parameter(torch.empty(n_out, **kwargs))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + + # from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear + + # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with + # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see + # https://github.com/pytorch/pytorch/issues/57109 + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + + def _bcast(self, op, left, right): + if right.ndim == 2: + # Has dimension [batch x d_output] + right = right.unsqueeze(1) + return op(left, right) + + def forward(self, x, params=None): + params = self.update(params) + + batch_size, *shape, d_in = x.shape + x = x.view(batch_size, -1, d_in) + + if params.weight.ndim == 2: + h = torch.einsum("bni,oi->bno", x, params.weight) + elif params.weight.ndim == 3: + h = torch.einsum("bni,boi->bno", x, params.weight) + + if params.bias is not None: + h = self._bcast(torch.add, h, params.bias) + + if params.scale is not None: + h = self._bcast(torch.mul, h, params.scale) + + if params.shift is not None: + h = self._bcast(torch.add, h, params.shift) + + h = h.view(batch_size, *shape, -1) + return h + + +def Conv(n_dim, d_in, d_out, kernel, stride=1, padding=0, dilation=1, **kwargs): + cls = { + 1: nn.Conv1d, + 2: nn.Conv2d, + 3: nn.Conv3d, + }[n_dim] + return cls(d_in, d_out, kernel, stride=stride, padding=padding, dilation=dilation, **kwargs) + + +def flatten(x): + batch_size, *shape, n_channels = x.shape + n_ctx = np.prod(shape) + return x.view(batch_size, n_ctx, n_channels), AttrDict( + shape=shape, n_ctx=n_ctx, n_channels=n_channels + ) + + +def unflatten(x, info): + batch_size = x.shape[0] + return x.view(batch_size, *info.shape, info.n_channels) + + +def torchify(x): + extent = list(range(1, x.ndim - 1)) + return x.permute([0, x.ndim - 1, *extent]) + + +def untorchify(x): + extent = list(range(2, x.ndim)) + return x.permute([0, *extent, 1]) + + +class MLP(nn.Module): + def __init__( + self, + d_input: int, + d_hidden: List[int], + d_output: 
int, + act_name: str = "quick_gelu", + bias: bool = True, + init: Optional[str] = None, + init_scale: float = 1.0, + zero_out: bool = False, + ): + """ + Required: d_input, d_hidden, d_output + Optional: act_name, bias + """ + super().__init__() + + ds = [d_input] + d_hidden + [d_output] + affines = [nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(ds[:-1], ds[1:])] + self.d = ds + self.affines = nn.ModuleList(affines) + self.act = get_act(act_name) + + mlp_init(self.affines, init=init, init_scale=init_scale) + if zero_out: + zero_init(affines[-1]) + + def forward(self, h, options: Optional[AttrDict] = None, log_prefix: str = ""): + options = AttrDict() if options is None else AttrDict(options) + *hid, out = self.affines + for i, f in enumerate(hid): + h = self.act(f(h)) + h = out(h) + return h + + +class MetaMLP(MetaModule): + def __init__( + self, + d_input: int, + d_hidden: List[int], + d_output: int, + act_name: str = "quick_gelu", + bias: bool = True, + meta_scale: bool = True, + meta_shift: bool = True, + meta_proj: bool = False, + meta_bias: bool = False, + trainable_meta: bool = False, + init: Optional[str] = None, + init_scale: float = 1.0, + zero_out: bool = False, + ): + super().__init__() + ds = [d_input] + d_hidden + [d_output] + affines = [ + MetaLinear( + d_in, + d_out, + bias=bias, + meta_scale=meta_scale, + meta_shift=meta_shift, + meta_proj=meta_proj, + meta_bias=meta_bias, + trainable_meta=trainable_meta, + ) + for d_in, d_out in zip(ds[:-1], ds[1:]) + ] + self.d = ds + self.affines = nn.ModuleList(affines) + self.act = get_act(act_name) + + mlp_init(affines, init=init, init_scale=init_scale) + if zero_out: + zero_init(affines[-1]) + + def forward(self, h, params=None, options: Optional[AttrDict] = None, log_prefix: str = ""): + options = AttrDict() if options is None else AttrDict(options) + params = self.update(params) + *hid, out = self.affines + for i, layer in enumerate(hid): + h = self.act(layer(h, params=subdict(params, f"{log_prefix}affines.{i}"))) + last = len(self.affines) - 1 + h = out(h, params=subdict(params, f"{log_prefix}affines.{last}")) + return h + + +class LayerNorm(nn.LayerNorm): + def __init__( + self, norm_shape: Union[int, Tuple[int]], eps: float = 1e-5, elementwise_affine: bool = True + ): + super().__init__(norm_shape, eps=eps, elementwise_affine=elementwise_affine) + self.width = np.prod(norm_shape) + self.max_numel = 65535 * self.width + + def forward(self, input): + if input.numel() > self.max_numel: + return F.layer_norm( + input.float(), self.normalized_shape, self.weight, self.bias, self.eps + ).type_as(input) + else: + return super(LayerNorm, self).forward(input.float()).type_as(input) + + +class PointSetEmbedding(nn.Module): + def __init__( + self, + *, + radius: float, + n_point: int, + n_sample: int, + d_input: int, + d_hidden: List[int], + patch_size: int = 1, + stride: int = 1, + activation: str = "swish", + group_all: bool = False, + padding_mode: str = "zeros", + fps_method: str = "fps", + **kwargs, + ): + super().__init__() + self.n_point = n_point + self.radius = radius + self.n_sample = n_sample + self.mlp_convs = nn.ModuleList() + self.act = get_act(activation) + self.patch_size = patch_size + self.stride = stride + last_channel = d_input + 3 + for out_channel in d_hidden: + self.mlp_convs.append( + nn.Conv2d( + last_channel, + out_channel, + kernel_size=(patch_size, 1), + stride=(stride, 1), + padding=(patch_size // 2, 0), + padding_mode=padding_mode, + **kwargs, + ) + ) + last_channel = out_channel + self.group_all = 
group_all + self.fps_method = fps_method + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_points: sample points feature data, [B, d_hidden[-1], n_point] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + if self.group_all: + new_xyz, new_points = sample_and_group_all(xyz, points) + else: + new_xyz, new_points = sample_and_group( + self.n_point, + self.radius, + self.n_sample, + xyz, + points, + deterministic=not self.training, + fps_method=self.fps_method, + ) + # new_xyz: sampled points position data, [B, n_point, C] + # new_points: sampled points data, [B, n_point, n_sample, C+D] + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, n_sample, n_point] + for i, conv in enumerate(self.mlp_convs): + new_points = self.act(self.apply_conv(new_points, conv)) + + new_points = new_points.mean(dim=2) + return new_points + + def apply_conv(self, points: torch.Tensor, conv: nn.Module): + batch, channels, n_samples, _ = points.shape + # Shuffle the representations + if self.patch_size > 1: + # TODO shuffle deterministically when not self.training + _, indices = torch.rand(batch, channels, n_samples, 1, device=points.device).sort(dim=2) + points = torch.gather(points, 2, torch.broadcast_to(indices, points.shape)) + return conv(points) diff --git a/shap_e/models/nn/pointnet2_utils.py b/shap_e/models/nn/pointnet2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73c63cfb18b2cc6543f9805db74ca8e26d90e4e1 --- /dev/null +++ b/shap_e/models/nn/pointnet2_utils.py @@ -0,0 +1,370 @@ +""" +Based on https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet2_utils.py + +MIT License + +Copyright (c) 2019 benny + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from time import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def timeit(tag, t): + print("{}: {}s".format(tag, time() - t)) + return time() + + +def pc_normalize(pc): + l = pc.shape[0] + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc**2, axis=1))) + pc = pc / m + return pc + + +def square_distance(src, dst): + """ + Calculate Euclid distance between each two points. 
+ + src^T * dst = xn * xm + yn * ym + zn * zm; + sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; + sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; + dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 + = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst + + Input: + src: source points, [B, N, C] + dst: target points, [B, M, C] + Output: + dist: per-point square distance, [B, N, M] + """ + B, N, _ = src.shape + _, M, _ = dst.shape + dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) + dist += torch.sum(src**2, -1).view(B, N, 1) + dist += torch.sum(dst**2, -1).view(B, 1, M) + return dist + + +def index_points(points, idx): + """ + + Input: + points: input points data, [B, N, C] + idx: sample index data, [B, S] + Return: + new_points:, indexed points data, [B, S, C] + """ + device = points.device + B = points.shape[0] + view_shape = list(idx.shape) + view_shape[1:] = [1] * (len(view_shape) - 1) + repeat_shape = list(idx.shape) + repeat_shape[0] = 1 + batch_indices = ( + torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) + ) + new_points = points[batch_indices, idx, :] + return new_points + + +def farthest_point_sample(xyz, npoint, deterministic=False): + """ + Input: + xyz: pointcloud data, [B, N, 3] + npoint: number of samples + Return: + centroids: sampled pointcloud index, [B, npoint] + """ + device = xyz.device + B, N, C = xyz.shape + centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) + distance = torch.ones(B, N).to(device) * 1e10 + if deterministic: + farthest = torch.arange(0, B, dtype=torch.long).to(device) + else: + farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) + batch_indices = torch.arange(B, dtype=torch.long).to(device) + for i in range(npoint): + centroids[:, i] = farthest + centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) + dist = torch.sum((xyz - centroid) ** 2, -1) + mask = dist < distance + distance[mask] = dist[mask] + farthest = torch.max(distance, -1)[1] + return centroids + + +def query_ball_point(radius, nsample, xyz, new_xyz): + """ + Input: + radius: local region radius + nsample: max sample number in local region + xyz: all points, [B, N, 3] + new_xyz: query points, [B, S, 3] + Return: + group_idx: grouped points index, [B, S, nsample] + """ + device = xyz.device + B, N, C = xyz.shape + _, S, _ = new_xyz.shape + group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) + sqrdists = square_distance(new_xyz, xyz) + group_idx[sqrdists > radius**2] = N + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] + group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) + mask = group_idx == N + group_idx[mask] = group_first[mask] + return group_idx + + +def sample_and_group( + npoint, + radius, + nsample, + xyz, + points, + returnfps=False, + deterministic=False, + fps_method: str = "fps", +): + """ + Input: + npoint: + radius: + nsample: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, npoint, nsample, 3] + new_points: sampled points data, [B, npoint, nsample, 3+D] + """ + B, N, C = xyz.shape + S = npoint + if fps_method == "fps": + fps_idx = farthest_point_sample(xyz, npoint, deterministic=deterministic) # [B, npoint, C] + elif fps_method == "first": + fps_idx = torch.arange(npoint)[None].repeat(B, 1) + else: + raise ValueError(f"Unknown FPS method: {fps_method}") + new_xyz = index_points(xyz, fps_idx) + idx = query_ball_point(radius, nsample, xyz, new_xyz) + grouped_xyz = 
index_points(xyz, idx) # [B, npoint, nsample, C] + grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) + + if points is not None: + grouped_points = index_points(points, idx) + new_points = torch.cat( + [grouped_xyz_norm, grouped_points], dim=-1 + ) # [B, npoint, nsample, C+D] + else: + new_points = grouped_xyz_norm + if returnfps: + return new_xyz, new_points, grouped_xyz, fps_idx + else: + return new_xyz, new_points + + +def sample_and_group_all(xyz, points): + """ + Input: + xyz: input points position data, [B, N, 3] + points: input points data, [B, N, D] + Return: + new_xyz: sampled points position data, [B, 1, 3] + new_points: sampled points data, [B, 1, N, 3+D] + """ + device = xyz.device + B, N, C = xyz.shape + new_xyz = torch.zeros(B, 1, C).to(device) + grouped_xyz = xyz.view(B, 1, N, C) + if points is not None: + new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) + else: + new_points = grouped_xyz + return new_xyz, new_points + + +class PointNetSetAbstraction(nn.Module): + def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): + super(PointNetSetAbstraction, self).__init__() + self.npoint = npoint + self.radius = radius + self.nsample = nsample + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) + self.mlp_bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.group_all = group_all + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_xyz: sampled points position data, [B, C, S] + new_points_concat: sample points feature data, [B, D', S] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + if self.group_all: + new_xyz, new_points = sample_and_group_all(xyz, points) + else: + new_xyz, new_points = sample_and_group( + self.npoint, self.radius, self.nsample, xyz, points, deterministic=not self.training + ) + # new_xyz: sampled points position data, [B, npoint, C] + # new_points: sampled points data, [B, npoint, nsample, C+D] + new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + + new_points = torch.max(new_points, 2)[0] + new_xyz = new_xyz.permute(0, 2, 1) + return new_xyz, new_points + + +class PointNetSetAbstractionMsg(nn.Module): + def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): + super(PointNetSetAbstractionMsg, self).__init__() + self.npoint = npoint + self.radius_list = radius_list + self.nsample_list = nsample_list + self.conv_blocks = nn.ModuleList() + self.bn_blocks = nn.ModuleList() + for i in range(len(mlp_list)): + convs = nn.ModuleList() + bns = nn.ModuleList() + last_channel = in_channel + 3 + for out_channel in mlp_list[i]: + convs.append(nn.Conv2d(last_channel, out_channel, 1)) + bns.append(nn.BatchNorm2d(out_channel)) + last_channel = out_channel + self.conv_blocks.append(convs) + self.bn_blocks.append(bns) + + def forward(self, xyz, points): + """ + Input: + xyz: input points position data, [B, C, N] + points: input points data, [B, D, N] + Return: + new_xyz: sampled points position data, [B, C, S] + new_points_concat: sample points feature data, [B, D', S] + """ + xyz = xyz.permute(0, 2, 1) + if points is not None: + points = points.permute(0, 2, 1) + + B, N, C = 
xyz.shape + S = self.npoint + new_xyz = index_points(xyz, farthest_point_sample(xyz, S, deterministic=not self.training)) + new_points_list = [] + for i, radius in enumerate(self.radius_list): + K = self.nsample_list[i] + group_idx = query_ball_point(radius, K, xyz, new_xyz) + grouped_xyz = index_points(xyz, group_idx) + grouped_xyz -= new_xyz.view(B, S, 1, C) + if points is not None: + grouped_points = index_points(points, group_idx) + grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) + else: + grouped_points = grouped_xyz + + grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] + for j in range(len(self.conv_blocks[i])): + conv = self.conv_blocks[i][j] + bn = self.bn_blocks[i][j] + grouped_points = F.relu(bn(conv(grouped_points))) + new_points = torch.max(grouped_points, 2)[0] # [B, D', S] + new_points_list.append(new_points) + + new_xyz = new_xyz.permute(0, 2, 1) + new_points_concat = torch.cat(new_points_list, dim=1) + return new_xyz, new_points_concat + + +class PointNetFeaturePropagation(nn.Module): + def __init__(self, in_channel, mlp): + super(PointNetFeaturePropagation, self).__init__() + self.mlp_convs = nn.ModuleList() + self.mlp_bns = nn.ModuleList() + last_channel = in_channel + for out_channel in mlp: + self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) + self.mlp_bns.append(nn.BatchNorm1d(out_channel)) + last_channel = out_channel + + def forward(self, xyz1, xyz2, points1, points2): + """ + Input: + xyz1: input points position data, [B, C, N] + xyz2: sampled input points position data, [B, C, S] + points1: input points data, [B, D, N] + points2: input points data, [B, D, S] + Return: + new_points: upsampled points data, [B, D', N] + """ + xyz1 = xyz1.permute(0, 2, 1) + xyz2 = xyz2.permute(0, 2, 1) + + points2 = points2.permute(0, 2, 1) + B, N, C = xyz1.shape + _, S, _ = xyz2.shape + + if S == 1: + interpolated_points = points2.repeat(1, N, 1) + else: + dists = square_distance(xyz1, xyz2) + dists, idx = dists.sort(dim=-1) + dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] + + dist_recip = 1.0 / (dists + 1e-8) + norm = torch.sum(dist_recip, dim=2, keepdim=True) + weight = dist_recip / norm + interpolated_points = torch.sum( + index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2 + ) + + if points1 is not None: + points1 = points1.permute(0, 2, 1) + new_points = torch.cat([points1, interpolated_points], dim=-1) + else: + new_points = interpolated_points + + new_points = new_points.permute(0, 2, 1) + for i, conv in enumerate(self.mlp_convs): + bn = self.mlp_bns[i] + new_points = F.relu(bn(conv(new_points))) + return new_points diff --git a/shap_e/models/nn/utils.py b/shap_e/models/nn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76998a3311b1c2a5e1ff4c0fc8681af1d21ce9de --- /dev/null +++ b/shap_e/models/nn/utils.py @@ -0,0 +1,37 @@ +from typing import Iterable, Union + +import numpy as np +import torch + +ArrayType = Union[np.ndarray, Iterable[int], torch.Tensor] + + +def to_torch(arr: ArrayType, dtype=torch.float): + if isinstance(arr, torch.Tensor): + return arr + return torch.from_numpy(np.array(arr)).to(dtype) + + +def sample_pmf(pmf: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + Sample from the given discrete probability distribution with replacement. + + The i-th bin is assumed to have mass pmf[i]. 
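+    Sampling uses an inverse-CDF lookup: the pmf is flattened to
+    [-1, support_size], cumulatively summed, and uniform random draws are
+    mapped to bin indices with torch.searchsorted. For example, a pmf of
+    shape [2, 64, 1] with n_samples=16 yields indices of shape [2, 16, 1]
+    with values in [0, 64).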
+ + :param pmf: [batch_size, *shape, n_samples, 1] where (pmf.sum(dim=-2) == 1).all() + :param n_samples: number of samples + + :return: indices sampled with replacement + """ + + *shape, support_size, last_dim = pmf.shape + assert last_dim == 1 + + cdf = torch.cumsum(pmf.view(-1, support_size), dim=1) + inds = torch.searchsorted(cdf, torch.rand(cdf.shape[0], n_samples, device=cdf.device)) + + return inds.view(*shape, n_samples, 1).clamp(0, support_size - 1) + + +def safe_divide(a, b, epsilon=1e-6): + return a / torch.where(b < 0, b - epsilon, b + epsilon) diff --git a/shap_e/models/query.py b/shap_e/models/query.py new file mode 100644 index 0000000000000000000000000000000000000000..a95fcbb2c698cec2fcb3a5b6d79eb763bec39b32 --- /dev/null +++ b/shap_e/models/query.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Callable, Optional + +import torch + + +@dataclass +class Query: + # Both of these are of shape [batch_size x ... x 3] + position: torch.Tensor + direction: Optional[torch.Tensor] = None + + t_min: Optional[torch.Tensor] = None + t_max: Optional[torch.Tensor] = None + + def copy(self) -> "Query": + return Query( + position=self.position, + direction=self.direction, + t_min=self.t_min, + t_max=self.t_max, + ) + + def map_tensors(self, f: Callable[[torch.Tensor], torch.Tensor]) -> "Query": + return Query( + position=f(self.position), + direction=f(self.direction) if self.direction is not None else None, + t_min=f(self.t_min) if self.t_min is not None else None, + t_max=f(self.t_max) if self.t_max is not None else None, + ) diff --git a/shap_e/models/renderer.py b/shap_e/models/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..45b72fb9dfc4feb40917857091b5e255cabaadc7 --- /dev/null +++ b/shap_e/models/renderer.py @@ -0,0 +1,387 @@ +from abc import abstractmethod +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch + +from shap_e.models.nn.camera import ( + DifferentiableCamera, + DifferentiableProjectiveCamera, + get_image_coords, + projective_camera_frame, +) +from shap_e.models.nn.meta import MetaModule +from shap_e.util.collections import AttrDict + + +class Renderer(MetaModule): + """ + A rendering abstraction that can render rays and views by calling the + appropriate models. The models are instantiated outside but registered in + this module. + """ + + @abstractmethod + def render_views( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + """ + Returns a backproppable rendering of a view + + :param batch: contains + - height: Optional[int] + - width: Optional[int] + - inner_batch_size or ray_batch_size: Optional[int] defaults to 4096 rays + + And additionally, to specify poses with a default up direction: + - poses: [batch_size x *shape x 2 x 3] where poses[:, ..., 0, :] are the camera + positions, and poses[:, ..., 1, :] are the z-axis (toward the object) of + the camera frame. + - camera: DifferentiableCamera. Assumes the same camera position + across batch for simplicity. Could eventually support + batched cameras. + + or to specify a batch of arbitrary poses: + - cameras: DifferentiableCameraBatch of shape [batch_size x *shape]. + + :param params: Meta parameters + :param options: Optional[Dict] + """ + + +class RayRenderer(Renderer): + """ + A rendering abstraction that can render rays and views by calling the + appropriate models. The models are instantiated outside but registered in + this module. 
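+    Subclasses only need to implement render_rays; render_views casts one
+    ray per image pixel and delegates to it (see render_views_from_rays).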
+ """ + + @abstractmethod + def render_rays( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + """ + :param batch: has + - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. + - radii (optional): [batch_size x ... x 1] the "thickness" of each ray. + :param options: Optional[Dict] + """ + + def render_views( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + device: torch.device = torch.device("cuda"), + ) -> AttrDict: + output = render_views_from_rays( + self.render_rays, + batch, + params=params, + options=options, + device=self.device, + ) + return output + + def forward( + self, + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + """ + :param batch: must contain either + + - rays: [batch_size x ... x 2 x 3] specify the origin and direction of each ray. + + or + + - poses: [batch_size x 2 x 3] where poses[:, 0] are the camera + positions, and poses[:, 1] are the z-axis (toward the object) of + the camera frame. + - camera: an instance of Camera that implements camera_rays + + or + + - cameras: DifferentiableCameraBatch of shape [batch_size x *shape]. + + For both of the above two options, these may be specified. + - height: Optional[int] + - width: Optional[int] + - ray_batch_size or inner_batch_size: Optional[int] defaults to 4096 rays + + :param params: a dictionary of optional meta parameters. + :param options: A Dict of other hyperparameters that could be + related to rendering or debugging + + :return: a dictionary containing + + - channels: [batch_size, *shape, n_channels] + - distances: [batch_size, *shape, 1] + - transmittance: [batch_size, *shape, 1] + - aux_losses: Dict[str, torch.Tensor] + """ + + if "rays" in batch: + for key in ["poses", "camera", "height", "width"]: + assert key not in batch + return self.render_rays(batch, params=params, options=options) + elif "poses" in batch or "cameras" in batch: + assert "rays" not in batch + if "poses" in batch: + assert "camera" in batch + else: + assert "camera" not in batch + return self.render_views(batch, params=params, options=options) + + raise NotImplementedError + + +def get_camera_from_batch(batch: AttrDict) -> Tuple[DifferentiableCamera, int, Tuple[int]]: + if "poses" in batch: + assert not "cameras" in batch + batch_size, *inner_shape, n_vecs, spatial_dim = batch.poses.shape + assert n_vecs == 2 and spatial_dim == 3 + inner_batch_size = int(np.prod(inner_shape)) + poses = batch.poses.view(batch_size * inner_batch_size, 2, 3) + position, direction = poses[:, 0], poses[:, 1] + camera = projective_camera_frame(position, direction, batch.camera) + elif "cameras" in batch: + assert not "camera" in batch + batch_size, *inner_shape = batch.cameras.shape + camera = batch.cameras.flat_camera + else: + raise ValueError(f'neither "poses" nor "cameras" found in keys: {batch.keys()}') + if "height" in batch and "width" in batch: + camera = camera.resize_image(batch.width, batch.height) + return camera, batch_size, inner_shape + + +def append_tensor(val_list: Optional[List[torch.Tensor]], output: Optional[torch.Tensor]): + if val_list is None: + return [output] + return val_list + [output] + + +def render_views_from_rays( + render_rays: Callable[[AttrDict, AttrDict, AttrDict], AttrDict], + batch: AttrDict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + device: torch.device = torch.device("cuda"), + patch_size: Optional[int] = 
128, + use_patch: bool = False, +) -> AttrDict: + # import pdb; pdb.set_trace() + camera, batch_size, inner_shape = get_camera_from_batch(batch) + inner_batch_size = int(np.prod(inner_shape)) + + coords = get_image_coords(camera.width, camera.height).to(device) + coords = torch.broadcast_to(coords.unsqueeze(0), [batch_size * inner_batch_size, *coords.shape]) + rays = camera.camera_rays(coords) + + # mip-NeRF radii calculation from: https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/datasets.py#L193-L200 + directions = rays.view(batch_size, inner_batch_size, camera.height, camera.width, 2, 3)[ + ..., 1, : + ] + neighbor_dists = torch.linalg.norm(directions[:, :, :, 1:] - directions[:, :, :, :-1], dim=-1) + neighbor_dists = torch.cat([neighbor_dists, neighbor_dists[:, :, :, -2:-1]], dim=3) + radii = (neighbor_dists * 2 / np.sqrt(12)).view(batch_size, -1, 1) + # do the patching + if use_patch: + print("use_patch") + assert patch_size < camera.height + H, W = camera.height, camera.width + # import pdb; pdb.set_trace() + down_scale_factor = min(H // patch_size, 4) + rays = rays.view(batch_size*inner_batch_size, camera.height, camera.width, 2, 3) + rays_o = rays[..., 0, :] + rays_d = rays[..., 1, :] + global_rays_o = torch.nn.functional.interpolate(rays_o.permute(0, 3, 1, 2), + (H // down_scale_factor, W // down_scale_factor), + mode="bilinear").permute(0, 2, 3, 1) + global_rays_d = torch.nn.functional.interpolate(rays_d.permute(0, 3, 1, 2), + (H // down_scale_factor, W // down_scale_factor), + mode="bilinear").permute(0, 2, 3, 1) + global_rays = torch.stack([global_rays_o, global_rays_d], dim=-2) + + global_rays = global_rays.view(batch_size, inner_batch_size * camera.height * camera.width // (down_scale_factor **2) , 2, 3) + + + + # rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3) + + if isinstance(camera, DifferentiableProjectiveCamera): + # Compute the camera z direction corresponding to every ray's pixel. + # Used for depth computations below. 
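+            # (Depth below is computed as distances * z_dots: the marched
+            # distance along each ray scaled by its dot product with this
+            # unit z-axis, i.e. the cosine to the view axis for unit-length
+            # ray directions, which converts ray distance into planar depth.)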
+ z_directions = ( + (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True)) + .reshape([batch_size, inner_batch_size, 1, 3]) + .repeat(1, 1, camera.width * camera.height // down_scale_factor **2 , 1) + .reshape(1, inner_batch_size * camera.height * camera.width // down_scale_factor ** 2, 3) + ) + + ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 1024)) + + assert global_rays.shape[1] % ray_batch_size == 0 + n_batches = global_rays.shape[1] // ray_batch_size + + output_list_global = AttrDict(aux_losses=dict()) + for idx in range(n_batches): + rays_batch = AttrDict( + rays=global_rays[:, idx * ray_batch_size: (idx + 1) * ray_batch_size], + radii=global_rays[:, idx * ray_batch_size: (idx + 1) * ray_batch_size], + ) + output_global = render_rays(rays_batch, params=params, options=options) + + # output.channels.register_hook(lambda grad: print("render_rays", grad)) + + if isinstance(camera, DifferentiableProjectiveCamera): + z_batch = z_directions[:, idx * ray_batch_size: (idx + 1) * ray_batch_size] + ray_directions = rays_batch.rays[:, :, 1] + z_dots = (ray_directions * z_batch).sum(-1, keepdim=True) + output_global.depth = output_global.distances * z_dots + + output_list_global = output_list_global.combine(output_global, append_tensor) + + PS = patch_size + patch_x = torch.randint(0, W - PS, (1,)).item() + patch_y = torch.randint(0, H - PS, (1,)).item() + + patch_rays = rays[..., patch_y: patch_y + PS, patch_x: patch_x + PS, :, :] + patch_rays = patch_rays.reshape(batch_size, inner_batch_size * PS * PS , 2, 3) + + if isinstance(camera, DifferentiableProjectiveCamera): + # Compute the camera z direction corresponding to every ray's pixel. + # Used for depth computations below. + z_directions_patch = ( + (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True)) + .reshape([batch_size, inner_batch_size, 1, 3]) + .repeat(1, 1, PS * PS , 1) + .reshape(1, inner_batch_size * PS * PS, 3) + ) + + # ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096)) + print(ray_batch_size, patch_rays.shape[1]) + assert patch_rays.shape[1] % ray_batch_size == 0 + n_batches = patch_rays.shape[1] // ray_batch_size + + output_list = AttrDict(aux_losses=dict()) + for idx in range(n_batches): + rays_batch = AttrDict( + rays=patch_rays[:, idx * ray_batch_size: (idx + 1) * ray_batch_size], + radii=patch_rays[:, idx * ray_batch_size: (idx + 1) * ray_batch_size], + ) + output_patch = render_rays(rays_batch, params=params, options=options) + + # output.channels.register_hook(lambda grad: print("render_rays", grad)) + + if isinstance(camera, DifferentiableProjectiveCamera): + z_batch = z_directions_patch[:, idx * ray_batch_size: (idx + 1) * ray_batch_size] + ray_directions = rays_batch.rays[:, :, 1] + z_dots = (ray_directions * z_batch).sum(-1, keepdim=True) + output_patch.depth = output_patch.distances * z_dots + + output_list = output_list.combine(output_patch, append_tensor) + def _resize(val_list: List[torch.Tensor], H, W): + val = torch.cat(val_list, dim=1) + assert val.shape[1] == inner_batch_size * H * W + return val.view(batch_size, *inner_shape, H, W, -1) + + def _avg(_key: str, loss_list: List[torch.Tensor]): + return sum(loss_list) / n_batches + + output_global = AttrDict( + {name: _resize(val_list, camera.width // down_scale_factor, camera.height // down_scale_factor) for name, val_list in output_list_global.items() if name != "aux_losses"} + ) + output_global.aux_losses = output_list_global.aux_losses.map(_avg) + + output = AttrDict( + {name: 
_resize(val_list, PS, PS) for name, val_list in output_list.items() if name != "aux_losses"} + ) + output.aux_losses = output_list.aux_losses.map(_avg) + + + + valid_patch_key = [] + for key in output: + if torch.is_tensor(output[key]): + print(key, output[key].shape, output["channels"].shape) + if len(output[key].shape) == len(output["channels"].shape): + if output[key][..., 0].shape == output["channels"][..., 0].shape: + valid_patch_key.append(key) + # import pdb; pdb.set_trace() + for key in valid_patch_key: + if output_global[key].dtype != torch.bool: + output_global[key] = torch.nn.functional.interpolate( + output_global[key].view(inner_batch_size*batch_size, camera.width // down_scale_factor, camera.height //down_scale_factor, -1).permute(0, 3, 1, 2), (H, W), mode="bilinear" + ).permute(0, 2, 3, 1).view(batch_size, inner_batch_size, H, W, -1) + else: + output_global[key] = torch.nn.functional.interpolate( + output_global[key].view(inner_batch_size*batch_size, camera.width // down_scale_factor, camera.height //down_scale_factor, -1).permute(0, 3, 1, 2).to(torch.float32), (H, W), mode="nearest" + ).permute(0, 2, 3, 1).view(batch_size, inner_batch_size, H, W, -1).to(torch.bool) + output_global[key] = output_global[key].detach() + output_global[key][ + ..., patch_y: patch_y + PS, patch_x: patch_x + PS, : + ] = output[key] + output = output_global + + + return output + + else: + rays = rays.view(batch_size, inner_batch_size * camera.height * camera.width, 2, 3) + + if isinstance(camera, DifferentiableProjectiveCamera): + # Compute the camera z direction corresponding to every ray's pixel. + # Used for depth computations below. + z_directions = ( + (camera.z / torch.linalg.norm(camera.z, dim=-1, keepdim=True)) + .reshape([batch_size, inner_batch_size, 1, 3]) + .repeat(1, 1, camera.width * camera.height, 1) + .reshape(1, inner_batch_size * camera.height * camera.width, 3) + ) + + ray_batch_size = batch.get("ray_batch_size", batch.get("inner_batch_size", 4096)) + assert rays.shape[1] % ray_batch_size == 0 + n_batches = rays.shape[1] // ray_batch_size + + output_list = AttrDict(aux_losses=dict()) + for idx in range(n_batches): + rays_batch = AttrDict( + rays=rays[:, idx * ray_batch_size : (idx + 1) * ray_batch_size], + radii=radii[:, idx * ray_batch_size : (idx + 1) * ray_batch_size], + ) + output = render_rays(rays_batch, params=params, options=options) + + # output.channels.register_hook(lambda grad: print("render_rays", grad)) + + if isinstance(camera, DifferentiableProjectiveCamera): + z_batch = z_directions[:, idx * ray_batch_size : (idx + 1) * ray_batch_size] + ray_directions = rays_batch.rays[:, :, 1] + z_dots = (ray_directions * z_batch).sum(-1, keepdim=True) + output.depth = output.distances * z_dots + + output_list = output_list.combine(output, append_tensor) + # for key in params: + # if params[key].requires_grad: + # params[key].register_hook(lambda grad: print("params", key, grad)) + def _resize(val_list: List[torch.Tensor]): + val = torch.cat(val_list, dim=1) + assert val.shape[1] == inner_batch_size * camera.height * camera.width + return val.view(batch_size, *inner_shape, camera.height, camera.width, -1) + + def _avg(_key: str, loss_list: List[torch.Tensor]): + return sum(loss_list) / n_batches + + output = AttrDict( + {name: _resize(val_list) for name, val_list in output_list.items() if name != "aux_losses"} + ) + output.aux_losses = output_list.aux_losses.map(_avg) + return output diff --git a/shap_e/models/stf/__init__.py b/shap_e/models/stf/__init__.py new file mode 100644 
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/models/stf/__pycache__/__init__.cpython-39.pyc b/shap_e/models/stf/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62cd0f26590da3777e083c960bfaba449eba8b82 Binary files /dev/null and b/shap_e/models/stf/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/stf/__pycache__/base.cpython-39.pyc b/shap_e/models/stf/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73dbfbab5af39520dcd68f1ee607a9dc111bb4b8 Binary files /dev/null and b/shap_e/models/stf/__pycache__/base.cpython-39.pyc differ diff --git a/shap_e/models/stf/__pycache__/mlp.cpython-39.pyc b/shap_e/models/stf/__pycache__/mlp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..852f579980f4cb7bdcbe3d7a7fbec6500a81ecec Binary files /dev/null and b/shap_e/models/stf/__pycache__/mlp.cpython-39.pyc differ diff --git a/shap_e/models/stf/__pycache__/renderer.cpython-39.pyc b/shap_e/models/stf/__pycache__/renderer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f64b48652a6b0afe3edda467b1b564508e6d0d6 Binary files /dev/null and b/shap_e/models/stf/__pycache__/renderer.cpython-39.pyc differ diff --git a/shap_e/models/stf/base.py b/shap_e/models/stf/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a222dbab9b3f74c0e8695665756e67e61317ca1a --- /dev/null +++ b/shap_e/models/stf/base.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import torch + +from shap_e.models.query import Query +from shap_e.models.renderer import append_tensor +from shap_e.util.collections import AttrDict + + +class Model(ABC): + @abstractmethod + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + """ + Predict an attribute given position + """ + + def forward_batched( + self, + query: Query, + query_batch_size: int = 4096, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + if not query.position.numel(): + # Avoid torch.cat() of zero tensors. 
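+            # (With no positions the chunking loop below would run zero
+            # times, leaving nothing to concatenate, so the query is
+            # forwarded in a single call instead.)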
+ return self(query, params=params, options=options) + + if options.cache is None: + created_cache = True + options.cache = AttrDict() + else: + created_cache = False + + results_list = AttrDict() + for i in range(0, query.position.shape[1], query_batch_size): + out = self( + query=query.map_tensors(lambda x, i=i: x[:, i : i + query_batch_size]), + params=params, + options=options, + ) + results_list = results_list.combine(out, append_tensor) + + if created_cache: + del options["cache"] + + return results_list.map(lambda key, tensor_list: torch.cat(tensor_list, dim=1)) diff --git a/shap_e/models/stf/mlp.py b/shap_e/models/stf/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..d13d21c154e99a181502f6dd49c51f6e8312f1be --- /dev/null +++ b/shap_e/models/stf/mlp.py @@ -0,0 +1,213 @@ +from functools import partial +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +from shap_e.models.nn.checkpoint import checkpoint +from shap_e.models.nn.encoding import encode_position, maybe_encode_direction +from shap_e.models.nn.meta import MetaModule, subdict +from shap_e.models.nn.ops import MetaLinear, get_act, mlp_init +from shap_e.models.query import Query +from shap_e.util.collections import AttrDict + +from .base import Model + + +class MLPModel(MetaModule, Model): + def __init__( + self, + n_output: int, + output_activation: str, + # Positional encoding parameters + posenc_version: str = "v1", + # Direction related channel prediction + insert_direction_at: Optional[int] = None, + # MLP parameters + d_hidden: int = 256, + n_hidden_layers: int = 4, + activation: str = "relu", + init: Optional[str] = None, + init_scale: float = 1.0, + meta_parameters: bool = False, + trainable_meta: bool = False, + meta_proj: bool = True, + meta_bias: bool = True, + meta_start: int = 0, + meta_stop: Optional[int] = None, + n_meta_layers: Optional[int] = None, + register_freqs: bool = False, + device: torch.device = torch.device("cuda"), + ): + super().__init__() + + if register_freqs: + self.register_buffer("freqs", 2.0 ** torch.arange(10, device=device).view(1, 10)) + + # Positional encoding + self.posenc_version = posenc_version + dummy = torch.eye(1, 3) + d_posenc_pos = encode_position(posenc_version, position=dummy).shape[-1] + d_posenc_dir = maybe_encode_direction(posenc_version, position=dummy).shape[-1] + + # Instantiate the MLP + mlp_widths = [d_hidden] * n_hidden_layers + input_widths = [d_posenc_pos, *mlp_widths] + output_widths = mlp_widths + [n_output] + + self.meta_parameters = meta_parameters + + # When this model is used jointly to express NeRF, it may have to + # process directions as well in which case we simply concatenate + # the direction representation at the specified layer. 
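+        # (The input width of the layer at `insert_direction_at` is widened
+        # by d_posenc_dir below; _run_mlp concatenates the encoded direction
+        # onto the hidden state just before that layer.)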
+ self.insert_direction_at = insert_direction_at + if insert_direction_at is not None: + input_widths[self.insert_direction_at] += d_posenc_dir + + linear_cls = lambda meta: ( + partial( + MetaLinear, + meta_scale=False, + meta_shift=False, + meta_proj=meta_proj, + meta_bias=meta_bias, + trainable_meta=trainable_meta, + ) + if meta + else nn.Linear + ) + + if meta_stop is None: + if n_meta_layers is not None: + assert n_meta_layers > 0 + meta_stop = meta_start + n_meta_layers - 1 + else: + meta_stop = n_hidden_layers + + if meta_parameters: + metas = [meta_start <= layer <= meta_stop for layer in range(n_hidden_layers + 1)] + else: + metas = [False] * (n_hidden_layers + 1) + + self.mlp = nn.ModuleList( + [ + linear_cls(meta)(d_in, d_out, device=device) + for meta, d_in, d_out in zip(metas, input_widths, output_widths) + ] + ) + + mlp_init(self.mlp, init=init, init_scale=init_scale) + + self.activation = get_act(activation) + self.output_activation = get_act(output_activation) + + self.device = device + self.to(device) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict: + """ + :param position: [batch_size x ... x 3] + :param params: Meta parameters + :param options: Optional hyperparameters + """ + + # query.direction is None typically for SDF models and training + h_final, _h_directionless = self._mlp( + query.position, query.direction, params=params, options=options + ) + return self.output_activation(h_final) + + def _run_mlp( + self, position: torch.Tensor, direction: torch.Tensor, params: AttrDict[str, torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :return: the final and directionless activations at the given query + """ + h_preact = h = encode_position(self.posenc_version, position=position) + h_directionless = None + for i, layer in enumerate(self.mlp): + if i == self.insert_direction_at: + h_directionless = h_preact + h_direction = maybe_encode_direction( + self.posenc_version, position=position, direction=direction + ) + h = torch.cat([h, h_direction], dim=-1) + if isinstance(layer, MetaLinear): + h = layer(h, params=subdict(params, f"mlp.{i}")) + else: + h = layer(h) + h_preact = h + if i < len(self.mlp) - 1: + h = self.activation(h) + h_final = h + if h_directionless is None: + h_directionless = h_preact + return h_final, h_directionless + + def _mlp( + self, + position: torch.Tensor, + direction: Optional[torch.Tensor] = None, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + :param position: [batch_size x ... 
x 3] + :param params: Meta parameters + :param options: Optional hyperparameters + :return: the final and directionless activations at the given query + """ + params = self.update(params) + options = AttrDict() if options is None else AttrDict(options) + + mlp = partial(self._run_mlp, direction=direction, params=params) + parameters = [] + for i, layer in enumerate(self.mlp): + if isinstance(layer, MetaLinear): + parameters.extend(list(subdict(params, f"mlp.{i}").values())) + else: + parameters.extend(layer.parameters()) + + h_final, h_directionless = checkpoint( + mlp, (position,), parameters, options.checkpoint_stf_model + ) + + return h_final, h_directionless + + +class MLPSDFModel(MLPModel): + def __init__(self, initial_bias: float = -0.1, **kwargs): + super().__init__(n_output=1, output_activation="identity", **kwargs) + self.mlp[-1].bias.data.fill_(initial_bias) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + signed_distance = super().forward(query=query, params=params, options=options) + return AttrDict(signed_distance=signed_distance) + + +class MLPTextureFieldModel(MLPModel): + def __init__( + self, + n_channels: int = 3, + **kwargs, + ): + super().__init__(n_output=n_channels, output_activation="sigmoid", **kwargs) + + def forward( + self, + query: Query, + params: Optional[Dict[str, torch.Tensor]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AttrDict[str, Any]: + channels = super().forward(query=query, params=params, options=options) + return AttrDict(channels=channels) diff --git a/shap_e/models/stf/renderer.py b/shap_e/models/stf/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..099de74b21492c28aaadc3dd220b64b01fc6647f --- /dev/null +++ b/shap_e/models/stf/renderer.py @@ -0,0 +1,507 @@ +import warnings +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +from shap_e.models.nn.camera import DifferentiableCamera, DifferentiableProjectiveCamera +from shap_e.models.nn.meta import subdict +from shap_e.models.nn.utils import to_torch +from shap_e.models.query import Query +from shap_e.models.renderer import Renderer, get_camera_from_batch +from shap_e.models.volume import BoundingBoxVolume, Volume +from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR +from shap_e.rendering.mc import marching_cubes +from shap_e.rendering.torch_mesh import TorchMesh +from shap_e.rendering.view_data import ProjectiveCamera +from shap_e.util.collections import AttrDict + +from .base import Model + + +class STFRendererBase(ABC): + @abstractmethod + def get_signed_distance( + self, + position: torch.Tensor, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + pass + + @abstractmethod + def get_texture( + self, + position: torch.Tensor, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + pass + + +class STFRenderer(Renderer, STFRendererBase): + def __init__( + self, + sdf: Model, + tf: Model, + volume: Volume, + grid_size: int, + texture_channels: Sequence[str] = ("R", "G", "B"), + channel_scale: Sequence[float] = (255.0, 255.0, 255.0), + ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR, + diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR, + 
specular_color: Union[float, Tuple[float]] = 0.0, + output_srgb: bool = True, + device: torch.device = torch.device("cuda"), + **kwargs, + ): + super().__init__(**kwargs) + assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume" + self.sdf = sdf + self.tf = tf + self.volume = volume + self.grid_size = grid_size + self.texture_channels = texture_channels + self.channel_scale = to_torch(channel_scale).to(device) + self.ambient_color = ambient_color + self.diffuse_color = diffuse_color + self.specular_color = specular_color + self.output_srgb = output_srgb + self.device = device + self.to(device) + + def render_views( + self, + batch: Dict, + params: Optional[Dict] = None, + options: Optional[Dict] = None, + ) -> AttrDict: + params = self.update(params) + options = AttrDict() if not options else AttrDict(options) + + sdf_fn = partial(self.sdf.forward_batched, params=subdict(params, "sdf")) + tf_fn = partial(self.tf.forward_batched, params=subdict(params, "tf")) + nerstf_fn = None + + return render_views_from_stf( + batch, + options, + sdf_fn=sdf_fn, + tf_fn=tf_fn, + nerstf_fn=nerstf_fn, + volume=self.volume, + grid_size=self.grid_size, + channel_scale=self.channel_scale, + texture_channels=self.texture_channels, + ambient_color=self.ambient_color, + diffuse_color=self.diffuse_color, + specular_color=self.specular_color, + output_srgb=self.output_srgb, + device=self.device, + ) + + def get_signed_distance( + self, + query: Query, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + return self.sdf( + query, + params=subdict(params, "sdf"), + options=options, + ).signed_distance + + def get_texture( + self, + query: Query, + params: Dict[str, torch.Tensor], + options: AttrDict[str, Any], + ) -> torch.Tensor: + return self.tf( + query, + params=subdict(params, "tf"), + options=options, + ).channels + + +def render_views_from_stf( + batch: Dict, + options: AttrDict[str, Any], + *, + sdf_fn: Optional[Callable], + tf_fn: Optional[Callable], + nerstf_fn: Optional[Callable], + volume: BoundingBoxVolume, + grid_size: int, + channel_scale: torch.Tensor, + texture_channels: Sequence[str] = ("R", "G", "B"), + ambient_color: Union[float, Tuple[float]] = 0.0, + diffuse_color: Union[float, Tuple[float]] = 1.0, + specular_color: Union[float, Tuple[float]] = 0.2, + output_srgb: bool = False, + device: torch.device = torch.device("cuda"), +) -> AttrDict: + """ + :param batch: contains either ["poses", "camera"], or ["cameras"]. Can + optionally contain any of ["height", "width", "query_batch_size"] + :param options: controls checkpointing, caching, and rendering + :param sdf_fn: returns [batch_size, query_batch_size, n_output] where + n_output >= 1. 
+ :param tf_fn: returns [batch_size, query_batch_size, n_channels] + :param volume: AABB volume + :param grid_size: SDF sampling resolution + :param texture_channels: what texture to predict + :param channel_scale: how each channel is scaled + :return: at least + channels: [batch_size, len(cameras), height, width, 3] + transmittance: [batch_size, len(cameras), height, width, 1] + aux_losses: AttrDict[str, torch.Tensor] + """ + camera, batch_size, inner_shape = get_camera_from_batch(batch) + inner_batch_size = int(np.prod(inner_shape)) + assert camera.width == camera.height, "only square views are supported" + assert camera.x_fov == camera.y_fov, "only square views are supported" + assert isinstance(camera, DifferentiableProjectiveCamera) + + device = camera.origin.device + device_type = device.type + + TO_CACHE = ["fields", "raw_meshes", "raw_signed_distance", "raw_density", "mesh_mask", "meshes"] + if options.cache is not None and all(key in options.cache for key in TO_CACHE): + fields = options.cache.fields + raw_meshes = options.cache.raw_meshes + raw_signed_distance = options.cache.raw_signed_distance + raw_density = options.cache.raw_density + mesh_mask = options.cache.mesh_mask + else: + query_batch_size = batch.get("query_batch_size", batch.get("ray_batch_size", 4096)) + query_points = volume_query_points(volume, grid_size) + fn = nerstf_fn if sdf_fn is None else sdf_fn + sdf_out = fn( + query=Query(position=query_points[None].repeat(batch_size, 1, 1)), + query_batch_size=query_batch_size, + options=options, + ) + raw_signed_distance = sdf_out.signed_distance + raw_density = None + if "density" in sdf_out: + raw_density = sdf_out.density + with torch.autocast(device_type, enabled=False): + fields = sdf_out.signed_distance.float() + raw_signed_distance = sdf_out.signed_distance + assert ( + len(fields.shape) == 3 and fields.shape[-1] == 1 + ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}" + fields = fields.reshape(batch_size, *([grid_size] * 3)) + + # Force a negative border around the SDFs to close off all the models. + full_grid = torch.zeros( + batch_size, + grid_size + 2, + grid_size + 2, + grid_size + 2, + device=fields.device, + dtype=fields.dtype, + ) + full_grid.fill_(-1.0) + full_grid[:, 1:-1, 1:-1, 1:-1] = fields + fields = full_grid + + raw_meshes = [] + mesh_mask = [] + for field in fields: + raw_mesh = marching_cubes(field, volume.bbox_min, volume.bbox_max - volume.bbox_min) + if len(raw_mesh.faces) == 0: + # DDP deadlocks when there are unused parameters on some ranks + # and not others, so we make sure the field is a dependency in + # the graph regardless of empty meshes. + vertex_dependency = field.mean() + raw_mesh = TorchMesh( + verts=torch.zeros(3, 3, device=device) + vertex_dependency, + faces=torch.tensor([[0, 1, 2]], dtype=torch.long, device=device), + ) + # Make sure we only feed back zero gradients to the field + # by masking out the final renderings of this mesh. 
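+                # (The mask is applied after rendering: channels are zeroed
+                # and transmittance forced to one wherever mesh_mask is False.)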
+ mesh_mask.append(False) + else: + mesh_mask.append(True) + raw_meshes.append(raw_mesh) + mesh_mask = torch.tensor(mesh_mask, device=device) + + max_vertices = max(len(m.verts) for m in raw_meshes) + + fn = nerstf_fn if tf_fn is None else tf_fn + tf_out = fn( + query=Query( + position=torch.stack( + [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes], + dim=0, + ) + ), + query_batch_size=query_batch_size, + options=options, + ) + + if "cache" in options: + options.cache.fields = fields + options.cache.raw_meshes = raw_meshes + options.cache.raw_signed_distance = raw_signed_distance + options.cache.raw_density = raw_density + options.cache.mesh_mask = mesh_mask + + if output_srgb: + tf_out.channels = _convert_srgb_to_linear(tf_out.channels) + + # Make sure the raw meshes have colors. + with torch.autocast(device_type, enabled=False): + textures = tf_out.channels.float() + assert len(textures.shape) == 3 and textures.shape[-1] == len( + texture_channels + ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}" + for m, texture in zip(raw_meshes, textures): + texture = texture[: len(m.verts)] + m.vertex_channels = {name: ch for name, ch in zip(texture_channels, texture.unbind(-1))} + + args = dict( + options=options, + texture_channels=texture_channels, + ambient_color=ambient_color, + diffuse_color=diffuse_color, + specular_color=specular_color, + camera=camera, + batch_size=batch_size, + inner_batch_size=inner_batch_size, + inner_shape=inner_shape, + raw_meshes=raw_meshes, + tf_out=tf_out, + ) + + try: + out = _render_with_pytorch3d(**args) + except ModuleNotFoundError as exc: + warnings.warn(f"exception rendering with PyTorch3D: {exc}") + warnings.warn( + "falling back on native PyTorch renderer, which does not support full gradients" + ) + out = _render_with_raycast(**args) + + # Apply mask to prevent gradients for empty meshes. + reshaped_mask = mesh_mask.view([-1] + [1] * (len(out.channels.shape) - 1)) + out.channels = torch.where(reshaped_mask, out.channels, torch.zeros_like(out.channels)) + out.transmittance = torch.where( + reshaped_mask, out.transmittance, torch.ones_like(out.transmittance) + ) + + if output_srgb: + out.channels = _convert_linear_to_srgb(out.channels) + out.channels = out.channels * (1 - out.transmittance) * channel_scale.view(-1) + + # This might be useful information to have downstream + out.raw_meshes = raw_meshes + out.fields = fields + out.mesh_mask = mesh_mask + out.raw_signed_distance = raw_signed_distance + out.aux_losses = AttrDict(cross_entropy=cross_entropy_sdf_loss(fields)) + if raw_density is not None: + out.raw_density = raw_density + + return out + + +def _render_with_pytorch3d( + options: AttrDict, + texture_channels: Sequence[str], + ambient_color: Union[float, Tuple[float]], + diffuse_color: Union[float, Tuple[float]], + specular_color: Union[float, Tuple[float]], + camera: DifferentiableCamera, + batch_size: int, + inner_shape: Sequence[int], + inner_batch_size: int, + raw_meshes: List[TorchMesh], + tf_out: AttrDict, +): + _ = tf_out + + # Lazy import because pytorch3d is installed lazily. 
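+    # (If the import fails, render_views_from_stf catches the
+    # ModuleNotFoundError and falls back to _render_with_raycast.)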
+ from shap_e.rendering.pytorch3d_util import ( + blender_uniform_lights, + convert_cameras_torch, + convert_meshes, + render_images, + ) + + n_channels = len(texture_channels) + device = camera.origin.device + device_type = device.type + + with torch.autocast(device_type, enabled=False): + meshes = convert_meshes(raw_meshes) + + lights = blender_uniform_lights( + batch_size, + device, + ambient_color=ambient_color, + diffuse_color=diffuse_color, + specular_color=specular_color, + ) + + # Separate camera intrinsics for each view, so that we can + # create a new camera for each batch of views. + cam_shape = [batch_size, inner_batch_size, -1] + position = camera.origin.reshape(cam_shape) + x = camera.x.reshape(cam_shape) + y = camera.y.reshape(cam_shape) + z = camera.z.reshape(cam_shape) + + results = [] + for i in range(inner_batch_size): + sub_cams = convert_cameras_torch( + position[:, i], x[:, i], y[:, i], z[:, i], fov=camera.x_fov + ) + imgs = render_images( + camera.width, + meshes, + sub_cams, + lights, + use_checkpoint=options.checkpoint_render, + **options.get("render_options", {}), + ) + results.append(imgs) + views = torch.stack(results, dim=1) + views = views.view(batch_size, *inner_shape, camera.height, camera.width, n_channels + 1) + + out = AttrDict( + channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels] + transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1] + meshes=meshes, + ) + + return out + + +def _render_with_raycast( + options: AttrDict, + texture_channels: Sequence[str], + ambient_color: Union[float, Tuple[float]], + diffuse_color: Union[float, Tuple[float]], + specular_color: Union[float, Tuple[float]], + camera: DifferentiableCamera, + batch_size: int, + inner_shape: Sequence[int], + inner_batch_size: int, + raw_meshes: List[TorchMesh], + tf_out: AttrDict, +): + assert np.mean(np.array(specular_color)) == 0 + + from shap_e.rendering.raycast.render import render_diffuse_mesh + from shap_e.rendering.raycast.types import TriMesh as TorchTriMesh + + device = camera.origin.device + device_type = device.type + + cam_shape = [batch_size, inner_batch_size, -1] + origin = camera.origin.reshape(cam_shape) + x = camera.x.reshape(cam_shape) + y = camera.y.reshape(cam_shape) + z = camera.z.reshape(cam_shape) + + with torch.autocast(device_type, enabled=False): + all_meshes = [] + for i, mesh in enumerate(raw_meshes): + all_meshes.append( + TorchTriMesh( + faces=mesh.faces.long(), + vertices=mesh.verts.float(), + vertex_colors=tf_out.channels[i, : len(mesh.verts)].float(), + ) + ) + all_images = [] + for i, mesh in enumerate(all_meshes): + for j in range(inner_batch_size): + all_images.append( + render_diffuse_mesh( + camera=ProjectiveCamera( + origin=origin[i, j].detach().cpu().numpy(), + x=x[i, j].detach().cpu().numpy(), + y=y[i, j].detach().cpu().numpy(), + z=z[i, j].detach().cpu().numpy(), + width=camera.width, + height=camera.height, + x_fov=camera.x_fov, + y_fov=camera.y_fov, + ), + mesh=mesh, + diffuse=float(np.array(diffuse_color).mean()), + ambient=float(np.array(ambient_color).mean()), + ray_batch_size=16, # low memory usage + checkpoint=options.checkpoint_render, + ) + ) + + n_channels = len(texture_channels) + views = torch.stack(all_images).view( + batch_size, *inner_shape, camera.height, camera.width, n_channels + 1 + ) + return AttrDict( + channels=views[..., :-1], # [batch_size, *inner_shape, height, width, n_channels] + transmittance=1 - views[..., -1:], # [batch_size, *inner_shape, height, width, 1] + 
meshes=all_meshes, + ) + + +def _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor: + return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4) + + +def _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor: + return torch.where(u <= 0.0031308, 12.92 * u, 1.055 * (u ** (1 / 2.4)) - 0.055) + + +def cross_entropy_sdf_loss(fields: torch.Tensor): + logits = F.logsigmoid(fields) + signs = (fields > 0).float() + + losses = [] + for dim in range(1, 4): + n = logits.shape[dim] + for (t_start, t_end, p_start, p_end) in [(0, -1, 1, n), (1, n, 0, -1)]: + targets = slice_fields(signs, dim, t_start, t_end) + preds = slice_fields(logits, dim, p_start, p_end) + losses.append( + F.binary_cross_entropy_with_logits(preds, targets, reduction="none") + .flatten(1) + .mean() + ) + return torch.stack(losses, dim=-1).sum() + + +def slice_fields(fields: torch.Tensor, dim: int, start: int, end: int): + if dim == 1: + return fields[:, start:end] + elif dim == 2: + return fields[:, :, start:end] + elif dim == 3: + return fields[:, :, :, start:end] + else: + raise ValueError(f"cannot slice dimension {dim}") + + +def volume_query_points( + volume: Volume, + grid_size: int, +): + assert isinstance(volume, BoundingBoxVolume) + indices = torch.arange(grid_size**3, device=volume.bbox_min.device) + zs = indices % grid_size + ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size + xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size + combined = torch.stack([xs, ys, zs], dim=1) + return (combined.float() / (grid_size - 1)) * ( + volume.bbox_max - volume.bbox_min + ) + volume.bbox_min diff --git a/shap_e/models/transmitter/__init__.py b/shap_e/models/transmitter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/models/transmitter/__pycache__/__init__.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b775d8572db7bb0eac47f870ba993e4a5ac08fd5 Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/base.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c450825c9b6b704c60464693ba199532ab92101 Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/base.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/bottleneck.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/bottleneck.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca53d935abc9477840ca3229132590392782ff6 Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/bottleneck.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/channels_encoder.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/channels_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3a913c81be7a35d0890e2bef80fb0c39d160b58 Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/channels_encoder.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/multiview_encoder.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/multiview_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8098d0aa158dd85e29cf17e54cda7dd5b21bd15a Binary files /dev/null and 
b/shap_e/models/transmitter/__pycache__/multiview_encoder.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/params_proj.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/params_proj.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6fff93f7e9b6bd7b4286493fb349d6812f9fbe Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/params_proj.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/__pycache__/pc_encoder.cpython-39.pyc b/shap_e/models/transmitter/__pycache__/pc_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e6fb75e8cf68af1a44d83bd482a7fb08a72109e Binary files /dev/null and b/shap_e/models/transmitter/__pycache__/pc_encoder.cpython-39.pyc differ diff --git a/shap_e/models/transmitter/base.py b/shap_e/models/transmitter/base.py new file mode 100644 index 0000000000000000000000000000000000000000..0490f9bc43b6fca469a657b5112433e45eeed473 --- /dev/null +++ b/shap_e/models/transmitter/base.py @@ -0,0 +1,200 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple + +import torch.nn as nn +from torch import torch + +from shap_e.models.renderer import Renderer +from shap_e.util.collections import AttrDict + +from .bottleneck import latent_bottleneck_from_config, latent_warp_from_config +from .params_proj import flatten_param_shapes, params_proj_from_config + + +class Encoder(nn.Module, ABC): + def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]]): + """ + Instantiate the encoder with information about the renderer's input + parameters. This information can be used to create output layers to + generate the necessary latents. + """ + super().__init__() + self.param_shapes = param_shapes + self.device = device + + @abstractmethod + def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict: + """ + Encode a batch of data into a batch of latent information. + """ + + +class VectorEncoder(Encoder): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + d_latent: int, + latent_bottleneck: Optional[Dict[str, Any]] = None, + latent_warp: Optional[Dict[str, Any]] = None, + ): + super().__init__(device=device, param_shapes=param_shapes) + if latent_bottleneck is None: + latent_bottleneck = dict(name="identity") + if latent_warp is None: + latent_warp = dict(name="identity") + self.d_latent = d_latent + self.params_proj = params_proj_from_config( + params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent + ) + self.latent_bottleneck = latent_bottleneck_from_config( + latent_bottleneck, device=device, d_latent=d_latent + ) + self.latent_warp = latent_warp_from_config(latent_warp, device=device) + + def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict: + h = self.encode_to_bottleneck(batch, options=options) + return self.bottleneck_to_params(h, options=options) + + def encode_to_bottleneck( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> torch.Tensor: + return self.latent_warp.warp( + self.latent_bottleneck(self.encode_to_vector(batch, options=options), options=options), + options=options, + ) + + @abstractmethod + def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + """ + Encode the batch into a single latent vector. 
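+        The resulting vector is then passed through the latent bottleneck
+        and warp in encode_to_bottleneck.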
+ """ + + def bottleneck_to_params( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> AttrDict: + _ = options + return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options) + + +class ChannelsEncoder(VectorEncoder): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + d_latent: int, + latent_bottleneck: Optional[Dict[str, Any]] = None, + latent_warp: Optional[Dict[str, Any]] = None, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + d_latent=d_latent, + latent_bottleneck=latent_bottleneck, + latent_warp=latent_warp, + ) + self.flat_shapes = flatten_param_shapes(param_shapes) + self.latent_ctx = sum(flat[0] for flat in self.flat_shapes.values()) + + @abstractmethod + def encode_to_channels( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> torch.Tensor: + """ + Encode the batch into a per-data-point set of latents. + :return: [batch_size, latent_ctx, latent_width] + """ + + def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + return self.encode_to_channels(batch, options=options).flatten(1) + + def bottleneck_to_channels( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> torch.Tensor: + _ = options + return vector.view(vector.shape[0], self.latent_ctx, -1) + + def bottleneck_to_params( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> AttrDict: + _ = options + # if vector.requires_grad: + # vector.register_hook(lambda grad: print("latent grad", grad.min(), grad.max())) + return self.params_proj( + self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options + ) + + +class Transmitter(nn.Module): + def __init__(self, encoder: Encoder, renderer: Renderer): + super().__init__() + self.encoder = encoder + self.renderer = renderer + + def forward(self, batch: AttrDict, options: Optional[AttrDict] = None) -> AttrDict: + """ + Transmit the batch through the encoder and then the renderer. 
+ """ + params = self.encoder(batch, options=options) + return self.renderer(batch, params=params, options=options) + + +class VectorDecoder(nn.Module): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + d_latent: int, + latent_warp: Optional[Dict[str, Any]] = None, + renderer: Renderer, + ): + super().__init__() + self.device = device + self.param_shapes = param_shapes + + if latent_warp is None: + latent_warp = dict(name="identity") + self.d_latent = d_latent + self.params_proj = params_proj_from_config( + params_proj, device=device, param_shapes=param_shapes, d_latent=d_latent + ) + self.latent_warp = latent_warp_from_config(latent_warp, device=device) + self.renderer = renderer + + def bottleneck_to_params( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> AttrDict: + _ = options + return self.params_proj(self.latent_warp.unwarp(vector, options=options), options=options) + + +class ChannelsDecoder(VectorDecoder): + def __init__( + self, + *, + latent_ctx: int, + **kwargs, + ): + super().__init__(**kwargs) + self.latent_ctx = latent_ctx + + def bottleneck_to_channels( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> torch.Tensor: + _ = options + return vector.view(vector.shape[0], self.latent_ctx, -1) + + def bottleneck_to_params( + self, vector: torch.Tensor, options: Optional[AttrDict] = None + ) -> AttrDict: + _ = options + return self.params_proj( + self.bottleneck_to_channels(self.latent_warp.unwarp(vector)), options=options + ) diff --git a/shap_e/models/transmitter/bottleneck.py b/shap_e/models/transmitter/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..9d44e63c7a9cc5254b4cf103a642a25fea625aee --- /dev/null +++ b/shap_e/models/transmitter/bottleneck.py @@ -0,0 +1,127 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import numpy as np +import torch.nn as nn +from torch import torch + +from shap_e.diffusion.gaussian_diffusion import diffusion_from_config +from shap_e.util.collections import AttrDict + + +class LatentBottleneck(nn.Module, ABC): + def __init__(self, *, device: torch.device, d_latent: int): + super().__init__() + self.device = device + self.d_latent = d_latent + + @abstractmethod + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + pass + + +class LatentWarp(nn.Module, ABC): + def __init__(self, *, device: torch.device): + super().__init__() + self.device = device + + @abstractmethod + def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + pass + + @abstractmethod + def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + pass + + +class IdentityLatentWarp(LatentWarp): + def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + return x + + def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + return x + + +class Tan2LatentWarp(LatentWarp): + def __init__(self, *, coeff1: float = 1.0, device: torch.device): + super().__init__(device=device) + self.coeff1 = coeff1 + self.scale = np.tan(np.tan(1.0) * coeff1) + + def warp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + return ((x.float().tan() * self.coeff1).tan() / self.scale).to(x.dtype) + + def unwarp(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + return ((x.float() * 
self.scale).arctan() / self.coeff1).arctan().to(x.dtype) + + +class IdentityLatentBottleneck(LatentBottleneck): + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + return x + + +class ClampNoiseBottleneck(LatentBottleneck): + def __init__(self, *, device: torch.device, d_latent: int, noise_scale: float): + super().__init__(device=device, d_latent=d_latent) + self.noise_scale = noise_scale + + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + x = x.tanh() + if not self.training: + return x + return x + torch.randn_like(x) * self.noise_scale + + +class ClampDiffusionNoiseBottleneck(LatentBottleneck): + def __init__( + self, + *, + device: torch.device, + d_latent: int, + diffusion: Dict[str, Any], + diffusion_prob: float = 1.0, + ): + super().__init__(device=device, d_latent=d_latent) + self.diffusion = diffusion_from_config(diffusion) + self.diffusion_prob = diffusion_prob + + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + _ = options + x = x.tanh() + if not self.training: + return x + t = torch.randint(low=0, high=self.diffusion.num_timesteps, size=(len(x),), device=x.device) + t = torch.where( + torch.rand(len(x), device=x.device) < self.diffusion_prob, t, torch.zeros_like(t) + ) + return self.diffusion.q_sample(x, t) + + +def latent_bottleneck_from_config(config: Dict[str, Any], device: torch.device, d_latent: int): + name = config.pop("name") + if name == "clamp_noise": + return ClampNoiseBottleneck(**config, device=device, d_latent=d_latent) + elif name == "identity": + return IdentityLatentBottleneck(**config, device=device, d_latent=d_latent) + elif name == "clamp_diffusion_noise": + return ClampDiffusionNoiseBottleneck(**config, device=device, d_latent=d_latent) + else: + raise ValueError(f"unknown latent bottleneck: {name}") + + +def latent_warp_from_config(config: Dict[str, Any], device: torch.device): + name = config.pop("name") + if name == "identity": + print("indentity warp") + return IdentityLatentWarp(**config, device=device) + elif name == "tan2": + print("tan2 warp") + return Tan2LatentWarp(**config, device=device) + else: + raise ValueError(f"unknown latent warping function: {name}") diff --git a/shap_e/models/transmitter/channels_encoder.py b/shap_e/models/transmitter/channels_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c39cd3980100354b973cbaa7e3b5e26e5729b288 --- /dev/null +++ b/shap_e/models/transmitter/channels_encoder.py @@ -0,0 +1,959 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch import torch + +from shap_e.models.generation.perceiver import SimplePerceiver +from shap_e.models.generation.transformer import Transformer +from shap_e.models.nn.camera import DifferentiableProjectiveCamera +from shap_e.models.nn.encoding import ( + MultiviewPointCloudEmbedding, + MultiviewPoseEmbedding, + PosEmbLinear, +) +from shap_e.models.nn.ops import PointSetEmbedding +from shap_e.rendering.point_cloud import PointCloud +from shap_e.rendering.view_data import ProjectiveCamera +from shap_e.util.collections import AttrDict + +from .base import ChannelsEncoder + + +class TransformerChannelsEncoder(ChannelsEncoder, ABC): + """ + Encode point 
clouds using a transformer model with an extra output + token used to extract a latent vector. + """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + d_latent: int = 512, + latent_bottleneck: Optional[Dict[str, Any]] = None, + latent_warp: Optional[Dict[str, Any]] = None, + n_ctx: int = 1024, + width: int = 512, + layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + latent_scale: float = 1.0, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + d_latent=d_latent, + latent_bottleneck=latent_bottleneck, + latent_warp=latent_warp, + ) + self.width = width + self.device = device + self.dtype = dtype + + self.n_ctx = n_ctx + + self.backbone = Transformer( + device=device, + dtype=dtype, + n_ctx=n_ctx + self.latent_ctx, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.register_parameter( + "output_tokens", + nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)), + ) + self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype) + self.latent_scale = latent_scale + + @abstractmethod + def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + pass + + def encode_to_channels( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> torch.Tensor: + h = self.encode_input(batch, options=options) + h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1) + h = self.ln_pre(h) + h = self.backbone(h) + h = h[:, -self.latent_ctx :] + h = self.ln_post(h) + h = self.output_proj(h) + return h + + +class PerceiverChannelsEncoder(ChannelsEncoder, ABC): + """ + Encode point clouds using a perceiver model with an extra output + token used to extract a latent vector. 
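+
+    Rough sketch of the unrolled encoding loop implemented by
+    `encode_to_channels` below (pseudocode; the multi-encoder tuple case is
+    omitted):
+
+        h, data_it = self.get_h_and_iterator(batch)   # [B, data_ctx + latent_ctx, width]
+        for _ in range(self.get_n_unrolls()):
+            h = self.encoder(h, next(data_it))        # cross-attend to a fresh data chunk
+            h = self.processor(h)                     # self-attention over all tokens
+        latents = self.output_proj(self.ln_post(h[:, -self.latent_ctx:]))
+        # latents: [B, latent_ctx, d_latent]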
+ """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + min_unrolls: int, + max_unrolls: int, + d_latent: int = 512, + latent_bottleneck: Optional[Dict[str, Any]] = None, + latent_warp: Optional[Dict[str, Any]] = None, + width: int = 512, + layers: int = 12, + xattn_layers: int = 1, + heads: int = 8, + init_scale: float = 0.25, + # Training hparams + inner_batch_size: Union[int, List[int]] = 1, + data_ctx: int = 1, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + d_latent=d_latent, + latent_bottleneck=latent_bottleneck, + latent_warp=latent_warp, + ) + self.width = width + self.device = device + self.dtype = dtype + + if isinstance(inner_batch_size, int): + inner_batch_size = [inner_batch_size] + self.inner_batch_size = inner_batch_size + self.data_ctx = data_ctx + self.min_unrolls = min_unrolls + self.max_unrolls = max_unrolls + + encoder_fn = lambda inner_batch_size: SimplePerceiver( + device=device, + dtype=dtype, + n_ctx=self.data_ctx + self.latent_ctx, + n_data=inner_batch_size, + width=width, + layers=xattn_layers, + heads=heads, + init_scale=init_scale, + ) + self.encoder = ( + encoder_fn(self.inner_batch_size[0]) + if len(self.inner_batch_size) == 1 + else nn.ModuleList([encoder_fn(inner_bsz) for inner_bsz in self.inner_batch_size]) + ) + self.processor = Transformer( + device=device, + dtype=dtype, + n_ctx=self.data_ctx + self.latent_ctx, + layers=layers - xattn_layers, + width=width, + heads=heads, + init_scale=init_scale, + ) + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.register_parameter( + "output_tokens", + nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)), + ) + self.output_proj = nn.Linear(width, d_latent, device=device, dtype=dtype) + + @abstractmethod + def get_h_and_iterator( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> Tuple[torch.Tensor, Iterable[Union[torch.Tensor, Tuple]]]: + """ + :return: a tuple of ( + the initial output tokens of size [batch_size, data_ctx + latent_ctx, width], + an iterator over the given data + ) + """ + + def encode_to_channels( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> torch.Tensor: + h, it = self.get_h_and_iterator(batch, options=options) + n_unrolls = self.get_n_unrolls() + + for _ in range(n_unrolls): + data = next(it) + if isinstance(data, tuple): + for data_i, encoder_i in zip(data, self.encoder): + h = encoder_i(h, data_i) + else: + h = self.encoder(h, data) + h = self.processor(h) + + h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :])) + return h + + def get_n_unrolls(self): + if self.training: + n_unrolls = torch.randint( + self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device + ) + dist.broadcast(n_unrolls, 0) + n_unrolls = n_unrolls.item() + else: + n_unrolls = self.max_unrolls + return n_unrolls + + +@dataclass +class DatasetIterator: + + embs: torch.Tensor # [batch_size, dataset_size, *shape] + batch_size: int + + def __iter__(self): + self._reset() + return self + + def __next__(self): + _outer_batch_size, dataset_size, *_shape = self.embs.shape + + while True: + start = self.idx + self.idx += self.batch_size + end = self.idx + if end <= dataset_size: + break + self._reset() + + return self.embs[:, start:end] + + def _reset(self): + self._shuffle() + self.idx = 0 # pylint: 
disable=attribute-defined-outside-init + + def _shuffle(self): + outer_batch_size, dataset_size, *shape = self.embs.shape + idx = torch.stack( + [ + torch.randperm(dataset_size, device=self.embs.device) + for _ in range(outer_batch_size) + ], + dim=0, + ) + idx = idx.view(outer_batch_size, dataset_size, *([1] * len(shape))) + idx = torch.broadcast_to(idx, self.embs.shape) + self.embs = torch.gather(self.embs, 1, idx) + + +class PointCloudTransformerChannelsEncoder(TransformerChannelsEncoder): + """ + Encode point clouds using a transformer model with an extra output + token used to extract a latent vector. + """ + + def __init__( + self, + *, + input_channels: int = 6, + **kwargs, + ): + super().__init__(**kwargs) + self.input_channels = input_channels + self.input_proj = nn.Linear( + input_channels, self.width, device=self.device, dtype=self.dtype + ) + + def encode_input(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + _ = options + points = batch.points + h = self.input_proj(points.permute(0, 2, 1)) # NCL -> NLC + return h + + +class PointCloudPerceiverChannelsEncoder(PerceiverChannelsEncoder): + """ + Encode point clouds using a transformer model with an extra output + token used to extract a latent vector. + """ + + def __init__( + self, + *, + cross_attention_dataset: str = "pcl", + fps_method: str = "fps", + # point cloud hyperparameters + input_channels: int = 6, + pos_emb: Optional[str] = None, + # multiview hyperparameters + image_size: int = 256, + patch_size: int = 32, + pose_dropout: float = 0.0, + use_depth: bool = False, + max_depth: float = 5.0, + # point conv hyperparameters + pointconv_radius: float = 0.5, + pointconv_samples: int = 32, + pointconv_hidden: Optional[List[int]] = None, + pointconv_patch_size: int = 1, + pointconv_stride: int = 1, + pointconv_padding_mode: str = "zeros", + use_pointconv: bool = False, + # other hyperparameters + **kwargs, + ): + super().__init__(**kwargs) + assert cross_attention_dataset in ( + "pcl", + "multiview", + "dense_pose_multiview", + "multiview_pcl", + "pcl_and_multiview_pcl", + "incorrect_multiview_pcl", + "pcl_and_incorrect_multiview_pcl", + ) + assert fps_method in ("fps", "first") + self.cross_attention_dataset = cross_attention_dataset + self.fps_method = fps_method + self.input_channels = input_channels + self.input_proj = PosEmbLinear( + pos_emb, + input_channels, + self.width, + device=self.device, + dtype=self.dtype, + ) + self.use_pointconv = use_pointconv + if use_pointconv: + if pointconv_hidden is None: + pointconv_hidden = [self.width] + self.point_conv = PointSetEmbedding( + n_point=self.data_ctx, + radius=pointconv_radius, + n_sample=pointconv_samples, + d_input=self.input_proj.weight.shape[0], + d_hidden=pointconv_hidden, + patch_size=pointconv_patch_size, + stride=pointconv_stride, + padding_mode=pointconv_padding_mode, + fps_method=fps_method, + device=self.device, + dtype=self.dtype, + ) + if self.cross_attention_dataset == "multiview": + self.image_size = image_size + self.patch_size = patch_size + self.pose_dropout = pose_dropout + self.use_depth = use_depth + self.max_depth = max_depth + pos_ctx = (image_size // patch_size) ** 2 + self.register_parameter( + "pos_emb", + nn.Parameter( + torch.randn( + pos_ctx * self.inner_batch_size, + self.width, + device=self.device, + dtype=self.dtype, + ) + ), + ) + self.patch_emb = nn.Conv2d( + in_channels=3 if not use_depth else 4, + out_channels=self.width, + kernel_size=patch_size, + stride=patch_size, + device=self.device, + 
dtype=self.dtype, + ) + self.camera_emb = nn.Sequential( + nn.Linear( + 3 * 4 + 1, self.width, device=self.device, dtype=self.dtype + ), # input size is for origin+x+y+z+fov + nn.GELU(), + nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype), + ) + elif self.cross_attention_dataset == "dense_pose_multiview": + # The number of output features is halved, because a patch_size of + # 32 ends up with a large patch_emb weight. + self.view_pose_width = self.width // 2 + self.image_size = image_size + self.patch_size = patch_size + self.use_depth = use_depth + self.max_depth = max_depth + self.mv_pose_embed = MultiviewPoseEmbedding( + posemb_version="nerf", + n_channels=4 if self.use_depth else 3, + out_features=self.view_pose_width, + device=self.device, + dtype=self.dtype, + ) + pos_ctx = (image_size // patch_size) ** 2 + # Positional embedding is unnecessary because pose information is baked into each pixel + self.patch_emb = nn.Conv2d( + in_channels=self.view_pose_width, + out_channels=self.width, + kernel_size=patch_size, + stride=patch_size, + device=self.device, + dtype=self.dtype, + ) + + elif ( + self.cross_attention_dataset == "multiview_pcl" + or self.cross_attention_dataset == "incorrect_multiview_pcl" + ): + self.view_pose_width = self.width // 2 + self.image_size = image_size + self.patch_size = patch_size + self.max_depth = max_depth + assert use_depth + self.mv_pcl_embed = MultiviewPointCloudEmbedding( + posemb_version="nerf", + n_channels=3, + out_features=self.view_pose_width, + device=self.device, + dtype=self.dtype, + ) + self.patch_emb = nn.Conv2d( + in_channels=self.view_pose_width, + out_channels=self.width, + kernel_size=patch_size, + stride=patch_size, + device=self.device, + dtype=self.dtype, + ) + + elif ( + self.cross_attention_dataset == "pcl_and_multiview_pcl" + or self.cross_attention_dataset == "pcl_and_incorrect_multiview_pcl" + ): + self.view_pose_width = self.width // 2 + self.image_size = image_size + self.patch_size = patch_size + self.max_depth = max_depth + assert use_depth + self.mv_pcl_embed = MultiviewPointCloudEmbedding( + posemb_version="nerf", + n_channels=3, + out_features=self.view_pose_width, + device=self.device, + dtype=self.dtype, + ) + self.patch_emb = nn.Conv2d( + in_channels=self.view_pose_width, + out_channels=self.width, + kernel_size=patch_size, + stride=patch_size, + device=self.device, + dtype=self.dtype, + ) + + def get_h_and_iterator( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> Tuple[torch.Tensor, Iterable]: + """ + :return: a tuple of ( + the initial output tokens of size [batch_size, data_ctx + latent_ctx, width], + an iterator over the given data + ) + """ + options = AttrDict() if options is None else options + + # Build the initial query embeddings + points = batch.points.permute(0, 2, 1) # NCL -> NLC + if self.use_pointconv: + points = self.input_proj(points).permute(0, 2, 1) # NLC -> NCL + xyz = batch.points[:, :3] + data_tokens = self.point_conv(xyz, points).permute(0, 2, 1) # NCL -> NLC + else: + fps_samples = self.sample_pcl_fps(points) + data_tokens = self.input_proj(fps_samples) + batch_size = points.shape[0] + latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1) + h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1)) + assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width) + + # Build the dataset embedding iterator + dataset_fn = { + "pcl": self.get_pcl_dataset, + "multiview": self.get_multiview_dataset, + "dense_pose_multiview": 
self.get_dense_pose_multiview_dataset, + "pcl_and_multiview_pcl": self.get_pcl_and_multiview_pcl_dataset, + "multiview_pcl": self.get_multiview_pcl_dataset, + }[self.cross_attention_dataset] + it = dataset_fn(batch, options=options) + + return h, it + + def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor: + return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method) + + def get_pcl_dataset( + self, + batch: AttrDict, + options: Optional[AttrDict[str, Any]] = None, + inner_batch_size: Optional[int] = None, + ) -> Iterable: + _ = options + if inner_batch_size is None: + inner_batch_size = self.inner_batch_size[0] + points = batch.points.permute(0, 2, 1) # NCL -> NLC + dataset_emb = self.input_proj(points) + assert dataset_emb.shape[1] >= inner_batch_size + return iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size)) + + def get_multiview_dataset( + self, + batch: AttrDict, + options: Optional[AttrDict] = None, + inner_batch_size: Optional[int] = None, + ) -> Iterable: + _ = options + + if inner_batch_size is None: + inner_batch_size = self.inner_batch_size[0] + + dataset_emb = self.encode_views(batch) + batch_size, num_views, n_patches, width = dataset_emb.shape + + assert num_views >= inner_batch_size + + it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size)) + + def gen(): + while True: + examples = next(it) + assert examples.shape == (batch_size, self.inner_batch_size, n_patches, self.width) + views = examples.reshape(batch_size, -1, width) + self.pos_emb + yield views + + return gen() + + def get_dense_pose_multiview_dataset( + self, + batch: AttrDict, + options: Optional[AttrDict] = None, + inner_batch_size: Optional[int] = None, + ) -> Iterable: + _ = options + + if inner_batch_size is None: + inner_batch_size = self.inner_batch_size[0] + + dataset_emb = self.encode_dense_pose_views(batch) + batch_size, num_views, n_patches, width = dataset_emb.shape + + assert num_views >= inner_batch_size + + it = iter(DatasetIterator(dataset_emb, batch_size=inner_batch_size)) + + def gen(): + while True: + examples = next(it) + assert examples.shape == (batch_size, inner_batch_size, n_patches, self.width) + views = examples.reshape(batch_size, -1, width) + yield views + + return gen() + + def get_pcl_and_multiview_pcl_dataset( + self, + batch: AttrDict, + options: Optional[AttrDict] = None, + use_distance: bool = True, + ) -> Iterable: + _ = options + + pcl_it = self.get_pcl_dataset( + batch, options=options, inner_batch_size=self.inner_batch_size[0] + ) + multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance) + batch_size, num_views, n_patches, width = multiview_pcl_emb.shape + + assert num_views >= self.inner_batch_size[1] + + multiview_pcl_it = iter( + DatasetIterator(multiview_pcl_emb, batch_size=self.inner_batch_size[1]) + ) + + def gen(): + while True: + pcl = next(pcl_it) + multiview_pcl = next(multiview_pcl_it) + assert multiview_pcl.shape == ( + batch_size, + self.inner_batch_size[1], + n_patches, + self.width, + ) + yield pcl, multiview_pcl.reshape(batch_size, -1, width) + + return gen() + + def get_multiview_pcl_dataset( + self, + batch: AttrDict, + options: Optional[AttrDict] = None, + inner_batch_size: Optional[int] = None, + use_distance: bool = True, + ) -> Iterable: + _ = options + + if inner_batch_size is None: + inner_batch_size = self.inner_batch_size[0] + + multiview_pcl_emb = self.encode_multiview_pcl(batch, use_distance=use_distance) + batch_size, num_views, n_patches, width = multiview_pcl_emb.shape 
+ + assert num_views >= inner_batch_size + + multiview_pcl_it = iter(DatasetIterator(multiview_pcl_emb, batch_size=inner_batch_size)) + + def gen(): + while True: + multiview_pcl = next(multiview_pcl_it) + assert multiview_pcl.shape == ( + batch_size, + inner_batch_size, + n_patches, + self.width, + ) + yield multiview_pcl.reshape(batch_size, -1, width) + + return gen() + + def encode_views(self, batch: AttrDict) -> torch.Tensor: + """ + :return: [batch_size, num_views, n_patches, width] + """ + all_views = self.views_to_tensor(batch.views).to(self.device) + if self.use_depth: + all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2) + all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device) + + batch_size, num_views, _, _, _ = all_views.shape + + views_proj = self.patch_emb( + all_views.reshape([batch_size * num_views, *all_views.shape[2:]]) + ) + views_proj = ( + views_proj.reshape([batch_size, num_views, self.width, -1]) + .permute(0, 1, 3, 2) + .contiguous() + ) # [batch_size x num_views x n_patches x width] + + # [batch_size, num_views, 1, 2 * width] + camera_proj = self.camera_emb(all_cameras).reshape( + [batch_size, num_views, 1, self.width * 2] + ) + pose_dropout = self.pose_dropout if self.training else 0.0 + mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout + camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj)) + scale, shift = camera_proj.chunk(2, dim=3) + views_proj = views_proj * (scale + 1.0) + shift + return views_proj + + def encode_dense_pose_views(self, batch: AttrDict) -> torch.Tensor: + """ + :return: [batch_size, num_views, n_patches, width] + """ + all_views = self.views_to_tensor(batch.views).to(self.device) + if self.use_depth: + depths = self.depths_to_tensor(batch.depths) + all_views = torch.cat([all_views, depths], dim=2) + + dense_poses, _ = self.dense_pose_cameras_to_tensor(batch.cameras) + dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3) + position, direction = dense_poses[:, :, 0], dense_poses[:, :, 1] + all_view_poses = self.mv_pose_embed(all_views, position, direction) + + batch_size, num_views, _, _, _ = all_view_poses.shape + + views_proj = self.patch_emb( + all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]]) + ) + views_proj = ( + views_proj.reshape([batch_size, num_views, self.width, -1]) + .permute(0, 1, 3, 2) + .contiguous() + ) # [batch_size x num_views x n_patches x width] + + return views_proj + + def encode_multiview_pcl(self, batch: AttrDict, use_distance: bool = True) -> torch.Tensor: + """ + :return: [batch_size, num_views, n_patches, width] + """ + all_views = self.views_to_tensor(batch.views).to(self.device) + depths = self.raw_depths_to_tensor(batch.depths) + all_view_alphas = self.view_alphas_to_tensor(batch.view_alphas).to(self.device) + mask = all_view_alphas >= 0.999 + + dense_poses, camera_z = self.dense_pose_cameras_to_tensor(batch.cameras) + dense_poses = dense_poses.permute(0, 1, 4, 5, 2, 3) + + origin, direction = dense_poses[:, :, 0], dense_poses[:, :, 1] + if use_distance: + ray_depth_factor = torch.sum(direction * camera_z[..., None, None], dim=2, keepdim=True) + depths = depths / ray_depth_factor + position = origin + depths * direction + all_view_poses = self.mv_pcl_embed(all_views, origin, position, mask) + + batch_size, num_views, _, _, _ = all_view_poses.shape + + views_proj = self.patch_emb( + all_view_poses.reshape([batch_size * num_views, *all_view_poses.shape[2:]]) + ) + views_proj = ( + 
views_proj.reshape([batch_size, num_views, self.width, -1]) + .permute(0, 1, 3, 2) + .contiguous() + ) # [batch_size x num_views x n_patches x width] + + return views_proj + + def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor: + """ + Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1]. + """ + if isinstance(views, torch.Tensor): + return views + + tensor_batch = [] + num_views = len(views[0]) + for inner_list in views: + assert len(inner_list) == num_views + inner_batch = [] + for img in inner_list: + img = img.resize((self.image_size,) * 2).convert("RGB") + inner_batch.append( + torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32) + / 127.5 + - 1 + ) + tensor_batch.append(torch.stack(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3) + + def depths_to_tensor( + self, depths: Union[torch.Tensor, List[List[Image.Image]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1]. + """ + if isinstance(depths, torch.Tensor): + return depths + + tensor_batch = [] + num_views = len(depths[0]) + for inner_list in depths: + assert len(inner_list) == num_views + inner_batch = [] + for arr in inner_list: + tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth + tensor = tensor * 2 - 1 + tensor = F.interpolate( + tensor[None, None], + (self.image_size,) * 2, + mode="nearest", + ) + inner_batch.append(tensor.to(device=self.device, dtype=torch.float32)) + tensor_batch.append(torch.cat(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0) + + def view_alphas_to_tensor( + self, view_alphas: Union[torch.Tensor, List[List[Image.Image]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 1 x size x size] tensor in the range [0, 1]. + """ + if isinstance(view_alphas, torch.Tensor): + return view_alphas + + tensor_batch = [] + num_views = len(view_alphas[0]) + for inner_list in view_alphas: + assert len(inner_list) == num_views + inner_batch = [] + for img in inner_list: + tensor = ( + torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32) + / 255.0 + ) + tensor = F.interpolate( + tensor[None, None], + (self.image_size,) * 2, + mode="nearest", + ) + inner_batch.append(tensor) + tensor_batch.append(torch.cat(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0) + + def raw_depths_to_tensor( + self, depths: Union[torch.Tensor, List[List[Image.Image]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 1 x size x size] tensor + """ + if isinstance(depths, torch.Tensor): + return depths + + tensor_batch = [] + num_views = len(depths[0]) + for inner_list in depths: + assert len(inner_list) == num_views + inner_batch = [] + for arr in inner_list: + tensor = torch.from_numpy(arr).clamp(max=self.max_depth) + tensor = F.interpolate( + tensor[None, None], + (self.image_size,) * 2, + mode="nearest", + ) + inner_batch.append(tensor.to(device=self.device, dtype=torch.float32)) + tensor_batch.append(torch.cat(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0) + + def cameras_to_tensor( + self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 3*4+1] tensor of camera information. 
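+
+        Packing order of the last dimension, as built in the loop below
+        (3 * 4 + 1 = 13 floats per camera):
+
+            [*x_axis, *y_axis, *z_axis, *origin, x_fov]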
+ """ + if isinstance(cameras, torch.Tensor): + return cameras + outer_batch = [] + for inner_list in cameras: + inner_batch = [] + for camera in inner_list: + inner_batch.append( + np.array( + [ + *camera.x, + *camera.y, + *camera.z, + *camera.origin, + camera.x_fov, + ] + ) + ) + outer_batch.append(np.stack(inner_batch, axis=0)) + return torch.from_numpy(np.stack(outer_batch, axis=0)).float() + + def dense_pose_cameras_to_tensor( + self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns a tuple of (rays, z_directions) where + - rays: [batch, num_views, height, width, 2, 3] tensor of camera information. + - z_directions: [batch, num_views, 3] tensor of camera z directions. + """ + if isinstance(cameras, torch.Tensor): + raise NotImplementedError + + for inner_list in cameras: + assert len(inner_list) == len(cameras[0]) + + camera = cameras[0][0] + flat_camera = DifferentiableProjectiveCamera( + origin=torch.from_numpy( + np.stack( + [cam.origin for inner_list in cameras for cam in inner_list], + axis=0, + ) + ).to(self.device), + x=torch.from_numpy( + np.stack( + [cam.x for inner_list in cameras for cam in inner_list], + axis=0, + ) + ).to(self.device), + y=torch.from_numpy( + np.stack( + [cam.y for inner_list in cameras for cam in inner_list], + axis=0, + ) + ).to(self.device), + z=torch.from_numpy( + np.stack( + [cam.z for inner_list in cameras for cam in inner_list], + axis=0, + ) + ).to(self.device), + width=camera.width, + height=camera.height, + x_fov=camera.x_fov, + y_fov=camera.y_fov, + ) + batch_size = len(cameras) * len(cameras[0]) + coords = ( + flat_camera.image_coords() + .to(flat_camera.origin.device) + .unsqueeze(0) + .repeat(batch_size, 1, 1) + ) + rays = flat_camera.camera_rays(coords) + return ( + rays.view(len(cameras), len(cameras[0]), camera.height, camera.width, 2, 3).to( + self.device + ), + flat_camera.z.view(len(cameras), len(cameras[0]), 3).to(self.device), + ) + + +def sample_pcl_fps(points: torch.Tensor, data_ctx: int, method: str = "fps") -> torch.Tensor: + """ + Run farthest-point sampling on a batch of point clouds. + + :param points: batch of shape [N x num_points]. + :param data_ctx: subsample count. + :param method: either 'fps' or 'first'. Using 'first' assumes that the + points are already sorted according to FPS sampling. + :return: batch of shape [N x min(num_points, data_ctx)]. 
+ """ + n_points = points.shape[1] + if n_points == data_ctx: + return points + if method == "first": + return points[:, :data_ctx] + elif method == "fps": + batch = points.cpu().split(1, dim=0) + fps = [sample_fps(x, n_samples=data_ctx) for x in batch] + return torch.cat(fps, dim=0).to(points.device) + else: + raise ValueError(f"unsupported farthest-point sampling method: {method}") + + +def sample_fps(example: torch.Tensor, n_samples: int) -> torch.Tensor: + """ + :param example: [1, n_points, 3 + n_channels] + :return: [1, n_samples, 3 + n_channels] + """ + points = example.cpu().squeeze(0).numpy() + coords, raw_channels = points[:, :3], points[:, 3:] + n_points, n_channels = raw_channels.shape + assert n_samples <= n_points + channels = {str(idx): raw_channels[:, idx] for idx in range(n_channels)} + max_points = min(32768, n_points) + fps_pcl = ( + PointCloud(coords=coords, channels=channels) + .random_sample(max_points) + .farthest_point_sample(n_samples) + ) + fps_channels = np.stack([fps_pcl.channels[str(idx)] for idx in range(n_channels)], axis=1) + fps = np.concatenate([fps_pcl.coords, fps_channels], axis=1) + fps = torch.from_numpy(fps).unsqueeze(0) + assert fps.shape == (1, n_samples, 3 + n_channels) + return fps diff --git a/shap_e/models/transmitter/multiview_encoder.py b/shap_e/models/transmitter/multiview_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..904e36eb2c6bd2db4e920e631137c2ee4bde85b7 --- /dev/null +++ b/shap_e/models/transmitter/multiview_encoder.py @@ -0,0 +1,201 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image + +from shap_e.models.generation.transformer import Transformer +from shap_e.rendering.view_data import ProjectiveCamera +from shap_e.util.collections import AttrDict + +from .base import VectorEncoder + + +class MultiviewTransformerEncoder(VectorEncoder): + """ + Encode cameras and views using a transformer model with extra output + token(s) used to extract a latent vector. 
+ """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + latent_bottleneck: Optional[Dict[str, Any]] = None, + d_latent: int = 512, + latent_ctx: int = 1, + num_views: int = 20, + image_size: int = 256, + patch_size: int = 32, + use_depth: bool = False, + max_depth: float = 5.0, + width: int = 512, + layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + pos_emb_init_scale: float = 1.0, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + latent_bottleneck=latent_bottleneck, + d_latent=d_latent, + ) + self.num_views = num_views + self.image_size = image_size + self.patch_size = patch_size + self.use_depth = use_depth + self.max_depth = max_depth + self.n_ctx = num_views * (1 + (image_size // patch_size) ** 2) + self.latent_ctx = latent_ctx + self.width = width + + assert d_latent % latent_ctx == 0 + + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.backbone = Transformer( + device=device, + dtype=dtype, + n_ctx=self.n_ctx + latent_ctx, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.register_parameter( + "output_tokens", + nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)), + ) + self.register_parameter( + "pos_emb", + nn.Parameter( + pos_emb_init_scale * torch.randn(self.n_ctx, width, device=device, dtype=dtype) + ), + ) + self.patch_emb = nn.Conv2d( + in_channels=3 if not use_depth else 4, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + device=device, + dtype=dtype, + ) + self.camera_emb = nn.Sequential( + nn.Linear( + 3 * 4 + 1, width, device=device, dtype=dtype + ), # input size is for origin+x+y+z+fov + nn.GELU(), + nn.Linear(width, width, device=device, dtype=dtype), + ) + self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype) + + def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + _ = options + + all_views = self.views_to_tensor(batch.views).to(self.device) + if self.use_depth: + all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2) + all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device) + + batch_size, num_views, _, _, _ = all_views.shape + + views_proj = self.patch_emb( + all_views.reshape([batch_size * num_views, *all_views.shape[2:]]) + ) + views_proj = ( + views_proj.reshape([batch_size, num_views, self.width, -1]) + .permute(0, 1, 3, 2) + .contiguous() + ) # [batch_size x num_views x n_patches x width] + + cameras_proj = self.camera_emb(all_cameras).reshape([batch_size, num_views, 1, self.width]) + + h = torch.cat([views_proj, cameras_proj], dim=2).reshape([batch_size, -1, self.width]) + h = h + self.pos_emb + h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1) + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + h = h[:, self.n_ctx :] + h = self.output_proj(h).flatten(1) + + return h + + def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor: + """ + Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1]. 
+ """ + if isinstance(views, torch.Tensor): + return views + + tensor_batch = [] + for inner_list in views: + assert len(inner_list) == self.num_views + inner_batch = [] + for img in inner_list: + img = img.resize((self.image_size,) * 2).convert("RGB") + inner_batch.append( + torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32) + / 127.5 + - 1 + ) + tensor_batch.append(torch.stack(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3) + + def depths_to_tensor( + self, depths: Union[torch.Tensor, List[List[Image.Image]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1]. + """ + if isinstance(depths, torch.Tensor): + return depths + + tensor_batch = [] + for inner_list in depths: + assert len(inner_list) == self.num_views + inner_batch = [] + for arr in inner_list: + tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth + tensor = tensor * 2 - 1 + tensor = F.interpolate( + tensor[None, None], + (self.image_size,) * 2, + mode="nearest", + ) + inner_batch.append(tensor.to(device=self.device, dtype=torch.float32)) + tensor_batch.append(torch.cat(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0) + + def cameras_to_tensor( + self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 3*4+1] tensor of camera information. + """ + if isinstance(cameras, torch.Tensor): + return cameras + outer_batch = [] + for inner_list in cameras: + inner_batch = [] + for camera in inner_list: + inner_batch.append( + np.array( + [ + *camera.x, + *camera.y, + *camera.z, + *camera.origin, + camera.x_fov, + ] + ) + ) + outer_batch.append(np.stack(inner_batch, axis=0)) + return torch.from_numpy(np.stack(outer_batch, axis=0)).float() diff --git a/shap_e/models/transmitter/params_proj.py b/shap_e/models/transmitter/params_proj.py new file mode 100644 index 0000000000000000000000000000000000000000..5097e58888465b3eb30c7842f023530e98bba9ac --- /dev/null +++ b/shap_e/models/transmitter/params_proj.py @@ -0,0 +1,199 @@ +import math +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Dict, Optional, Tuple + +import numpy as np +import torch.nn as nn +from torch import torch + +from shap_e.util.collections import AttrDict + + +def flatten_param_shapes(param_shapes: Dict[str, Tuple[int]]): + flat_shapes = OrderedDict( + (name, (int(np.prod(shape)) // shape[-1], shape[-1])) + for name, shape in param_shapes.items() + ) + return flat_shapes + + +class ParamsProj(nn.Module, ABC): + def __init__(self, *, device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int): + super().__init__() + self.device = device + self.param_shapes = param_shapes + self.d_latent = d_latent + + @abstractmethod + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + pass + + +class LinearParamsProj(ParamsProj): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + d_latent: int, + init_scale: Optional[float] = None, + ): + super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent) + self.param_shapes = param_shapes + self.projections = nn.ModuleDict({}) + for k, v in param_shapes.items(): + self.projections[_sanitize_name(k)] = nn.Linear( + d_latent, int(np.prod(v)), device=device + ) + if init_scale is not None: + scale = init_scale / math.sqrt(d_latent) + mod = self.projections[_sanitize_name(k)] + 
nn.init.normal_(mod.weight, std=scale) + nn.init.zeros_(mod.bias) + + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + out = AttrDict() + for k in self.param_shapes.keys(): + proj = self.projections[_sanitize_name(k)] + out[k] = proj(x).reshape([len(x), *self.param_shapes[k]]) + return out + + +class MLPParamsProj(ParamsProj): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + d_latent: int, + hidden_size: Optional[int] = None, + ): + super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent) + if hidden_size is None: + hidden_size = d_latent + self.param_shapes = param_shapes + self.projections = nn.ModuleDict({}) + for k, v in param_shapes.items(): + self.projections[_sanitize_name(k)] = nn.Sequential( + nn.Linear(d_latent, hidden_size, device=device), + nn.GELU(), + nn.Linear(hidden_size, int(np.prod(v)), device=device), + ) + + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + out = AttrDict() + for k in self.param_shapes.keys(): + proj = self.projections[_sanitize_name(k)] + out[k] = proj(x).reshape([len(x), *self.param_shapes[k]]) + return out + + +class ChannelsProj(nn.Module): + def __init__( + self, + *, + device: torch.device, + vectors: int, + channels: int, + d_latent: int, + init_scale: float = 1.0, + learned_scale: Optional[float] = None, + use_ln: bool = False, + ): + super().__init__() + self.proj = nn.Linear(d_latent, vectors * channels, device=device) + self.use_ln = use_ln + self.learned_scale = learned_scale + if use_ln: + self.norm = nn.LayerNorm(normalized_shape=(channels,), device=device) + if learned_scale is not None: + self.norm.weight.data.fill_(learned_scale) + scale = init_scale / math.sqrt(d_latent) + elif learned_scale is not None: + gain = torch.ones((channels,), device=device) * learned_scale + self.register_parameter("gain", nn.Parameter(gain)) + scale = init_scale / math.sqrt(d_latent) + else: + scale = init_scale / math.sqrt(d_latent * channels) + nn.init.normal_(self.proj.weight, std=scale) + nn.init.zeros_(self.proj.bias) + self.d_latent = d_latent + self.vectors = vectors + self.channels = channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_bvd = x + w_vcd = self.proj.weight.view(self.vectors, self.channels, self.d_latent) + b_vc = self.proj.bias.view(1, self.vectors, self.channels) + h = torch.einsum("bvd,vcd->bvc", x_bvd, w_vcd) + if self.use_ln: + h = self.norm(h) + elif self.learned_scale is not None: + h = h * self.gain.view(1, 1, -1) + h = h + b_vc + return h + + +class ChannelsParamsProj(ParamsProj): + def __init__( + self, + *, + device: torch.device, + param_shapes: Dict[str, Tuple[int]], + d_latent: int, + init_scale: float = 1.0, + learned_scale: Optional[float] = None, + use_ln: bool = False, + ): + super().__init__(device=device, param_shapes=param_shapes, d_latent=d_latent) + self.param_shapes = param_shapes + self.projections = nn.ModuleDict({}) + self.flat_shapes = flatten_param_shapes(param_shapes) + self.learned_scale = learned_scale + self.use_ln = use_ln + for k, (vectors, channels) in self.flat_shapes.items(): + self.projections[_sanitize_name(k)] = ChannelsProj( + device=device, + vectors=vectors, + channels=channels, + d_latent=d_latent, + init_scale=init_scale, + learned_scale=learned_scale, + use_ln=use_ln, + ) + + def forward(self, x: torch.Tensor, options: Optional[AttrDict] = None) -> AttrDict: + out = AttrDict() + start = 0 + for k, shape in self.param_shapes.items(): + 
vectors, _ = self.flat_shapes[k] + end = start + vectors + x_bvd = x[:, start:end] + # print("x.shape", x.shape) + # print("x_bvd.shape", x_bvd.shape) + out[k] = self.projections[_sanitize_name(k)](x_bvd).reshape(len(x), *shape) + start = end + return out + + +def params_proj_from_config( + config: Dict[str, Any], device: torch.device, param_shapes: Dict[str, Tuple[int]], d_latent: int +): + name = config.pop("name") + if name == "linear": + return LinearParamsProj( + **config, device=device, param_shapes=param_shapes, d_latent=d_latent + ) + elif name == "mlp": + return MLPParamsProj(**config, device=device, param_shapes=param_shapes, d_latent=d_latent) + elif name == "channels": + return ChannelsParamsProj( + **config, device=device, param_shapes=param_shapes, d_latent=d_latent + ) + else: + raise ValueError(f"unknown params proj: {name}") + + +def _sanitize_name(x: str) -> str: + return x.replace(".", "__") diff --git a/shap_e/models/transmitter/pc_encoder.py b/shap_e/models/transmitter/pc_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f62b7f091068b772887c6ce1b986cf45491ebe2 --- /dev/null +++ b/shap_e/models/transmitter/pc_encoder.py @@ -0,0 +1,426 @@ +from abc import abstractmethod +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torch import torch + +from shap_e.models.generation.perceiver import SimplePerceiver +from shap_e.models.generation.transformer import Transformer +from shap_e.models.nn.encoding import PosEmbLinear +from shap_e.rendering.view_data import ProjectiveCamera +from shap_e.util.collections import AttrDict + +from .base import VectorEncoder +from .channels_encoder import DatasetIterator, sample_pcl_fps + + +class PointCloudTransformerEncoder(VectorEncoder): + """ + Encode point clouds using a transformer model with an extra output + token used to extract a latent vector. 
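+
+    Sketch of how the latent vector is read out by `encode_to_vector` below
+    (pseudocode; shapes assume the defaults above):
+
+        h = input_proj(points)                                  # [B, n_ctx, width]
+        h = backbone(ln_pre(cat([h, output_tokens])))           # append latent_ctx tokens
+        latent = output_proj(ln_post(h)[:, n_ctx:]).flatten(1)  # [B, d_latent]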
+ """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + latent_bottleneck: Optional[Dict[str, Any]] = None, + d_latent: int = 512, + latent_ctx: int = 1, + input_channels: int = 6, + n_ctx: int = 1024, + width: int = 512, + layers: int = 12, + heads: int = 8, + init_scale: float = 0.25, + pos_emb: Optional[str] = None, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + latent_bottleneck=latent_bottleneck, + d_latent=d_latent, + ) + self.input_channels = input_channels + self.n_ctx = n_ctx + self.latent_ctx = latent_ctx + + assert d_latent % latent_ctx == 0 + + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.backbone = Transformer( + device=device, + dtype=dtype, + n_ctx=n_ctx + latent_ctx, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + ) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.register_parameter( + "output_tokens", + nn.Parameter(torch.randn(latent_ctx, width, device=device, dtype=dtype)), + ) + + self.input_proj = PosEmbLinear(pos_emb, input_channels, width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, d_latent // latent_ctx, device=device, dtype=dtype) + + def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + _ = options + points = batch.points.permute(0, 2, 1) # NCL -> NLC + h = self.input_proj(points) + h = torch.cat([h, self.output_tokens[None].repeat(len(h), 1, 1)], dim=1) + h = self.ln_pre(h) + h = self.backbone(h) + h = self.ln_post(h) + h = h[:, self.n_ctx :] + h = self.output_proj(h).flatten(1) + return h + + +class PerceiverEncoder(VectorEncoder): + """ + Encode point clouds using a perceiver model with an extra output + token used to extract a latent vector. 
+ """ + + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + param_shapes: Dict[str, Tuple[int]], + params_proj: Dict[str, Any], + latent_bottleneck: Optional[Dict[str, Any]] = None, + d_latent: int = 512, + latent_ctx: int = 1, + width: int = 512, + layers: int = 12, + xattn_layers: int = 1, + heads: int = 8, + init_scale: float = 0.25, + # Training hparams + inner_batch_size: int = 1, + data_ctx: int = 1, + min_unrolls: int, + max_unrolls: int, + ): + super().__init__( + device=device, + param_shapes=param_shapes, + params_proj=params_proj, + latent_bottleneck=latent_bottleneck, + d_latent=d_latent, + ) + self.width = width + self.device = device + self.dtype = dtype + self.latent_ctx = latent_ctx + + self.inner_batch_size = inner_batch_size + self.data_ctx = data_ctx + self.min_unrolls = min_unrolls + self.max_unrolls = max_unrolls + + self.encoder = SimplePerceiver( + device=device, + dtype=dtype, + n_ctx=self.data_ctx + self.latent_ctx, + n_data=self.inner_batch_size, + width=width, + layers=xattn_layers, + heads=heads, + init_scale=init_scale, + ) + self.processor = Transformer( + device=device, + dtype=dtype, + n_ctx=self.data_ctx + self.latent_ctx, + layers=layers - xattn_layers, + width=width, + heads=heads, + init_scale=init_scale, + ) + self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.register_parameter( + "output_tokens", + nn.Parameter(torch.randn(self.latent_ctx, width, device=device, dtype=dtype)), + ) + self.output_proj = nn.Linear(width, d_latent // self.latent_ctx, device=device, dtype=dtype) + + @abstractmethod + def get_h_and_iterator( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> Tuple[torch.Tensor, Iterable]: + """ + :return: a tuple of ( + the initial output tokens of size [batch_size, data_ctx + latent_ctx, width], + an iterator over the given data + ) + """ + + def encode_to_vector(self, batch: AttrDict, options: Optional[AttrDict] = None) -> torch.Tensor: + h, it = self.get_h_and_iterator(batch, options=options) + n_unrolls = self.get_n_unrolls() + + for _ in range(n_unrolls): + data = next(it) + h = self.encoder(h, data) + h = self.processor(h) + + h = self.output_proj(self.ln_post(h[:, -self.latent_ctx :])) + return h.flatten(1) + + def get_n_unrolls(self): + if self.training: + n_unrolls = torch.randint( + self.min_unrolls, self.max_unrolls + 1, size=(), device=self.device + ) + dist.broadcast(n_unrolls, 0) + n_unrolls = n_unrolls.item() + else: + n_unrolls = self.max_unrolls + return n_unrolls + + +class PointCloudPerceiverEncoder(PerceiverEncoder): + """ + Encode point clouds using a transformer model with an extra output + token used to extract a latent vector. 
+ """ + + def __init__( + self, + *, + cross_attention_dataset: str = "pcl", + fps_method: str = "fps", + # point cloud hyperparameters + input_channels: int = 6, + pos_emb: Optional[str] = None, + # multiview hyperparameters + image_size: int = 256, + patch_size: int = 32, + pose_dropout: float = 0.0, + use_depth: bool = False, + max_depth: float = 5.0, + # other hyperparameters + **kwargs, + ): + super().__init__(**kwargs) + assert cross_attention_dataset in ("pcl", "multiview") + assert fps_method in ("fps", "first") + self.cross_attention_dataset = cross_attention_dataset + self.fps_method = fps_method + self.input_channels = input_channels + self.input_proj = PosEmbLinear( + pos_emb, input_channels, self.width, device=self.device, dtype=self.dtype + ) + if self.cross_attention_dataset == "multiview": + self.image_size = image_size + self.patch_size = patch_size + self.pose_dropout = pose_dropout + self.use_depth = use_depth + self.max_depth = max_depth + pos_ctx = (image_size // patch_size) ** 2 + self.register_parameter( + "pos_emb", + nn.Parameter( + torch.randn( + pos_ctx * self.inner_batch_size, + self.width, + device=self.device, + dtype=self.dtype, + ) + ), + ) + self.patch_emb = nn.Conv2d( + in_channels=3 if not use_depth else 4, + out_channels=self.width, + kernel_size=patch_size, + stride=patch_size, + device=self.device, + dtype=self.dtype, + ) + self.camera_emb = nn.Sequential( + nn.Linear( + 3 * 4 + 1, self.width, device=self.device, dtype=self.dtype + ), # input size is for origin+x+y+z+fov + nn.GELU(), + nn.Linear(self.width, 2 * self.width, device=self.device, dtype=self.dtype), + ) + + def get_h_and_iterator( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> Tuple[torch.Tensor, Iterable]: + """ + :return: a tuple of ( + the initial output tokens of size [batch_size, data_ctx + latent_ctx, width], + an iterator over the given data + ) + """ + options = AttrDict() if options is None else options + + # Build the initial query embeddings + points = batch.points.permute(0, 2, 1) # NCL -> NLC + fps_samples = self.sample_pcl_fps(points) + batch_size = points.shape[0] + data_tokens = self.input_proj(fps_samples) + latent_tokens = self.output_tokens.unsqueeze(0).repeat(batch_size, 1, 1) + h = self.ln_pre(torch.cat([data_tokens, latent_tokens], dim=1)) + assert h.shape == (batch_size, self.data_ctx + self.latent_ctx, self.width) + + # Build the dataset embedding iterator + dataset_fn = { + "pcl": self.get_pcl_dataset, + "multiview": self.get_multiview_dataset, + }[self.cross_attention_dataset] + it = dataset_fn(batch, options=options) + + return h, it + + def sample_pcl_fps(self, points: torch.Tensor) -> torch.Tensor: + return sample_pcl_fps(points, data_ctx=self.data_ctx, method=self.fps_method) + + def get_pcl_dataset( + self, batch: AttrDict, options: Optional[AttrDict[str, Any]] = None + ) -> Iterable: + _ = options + dataset_emb = self.input_proj(batch.points.permute(0, 2, 1)) # NCL -> NLC + assert dataset_emb.shape[1] >= self.inner_batch_size + return iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size)) + + def get_multiview_dataset( + self, batch: AttrDict, options: Optional[AttrDict] = None + ) -> Iterable: + _ = options + + dataset_emb = self.encode_views(batch) + batch_size, num_views, n_patches, width = dataset_emb.shape + + assert num_views >= self.inner_batch_size + + it = iter(DatasetIterator(dataset_emb, batch_size=self.inner_batch_size)) + + def gen(): + while True: + examples = next(it) + assert examples.shape == (batch_size, 
self.inner_batch_size, n_patches, self.width) + views = examples.reshape(batch_size, -1, width) + self.pos_emb + yield views + + return gen() + + def encode_views(self, batch: AttrDict) -> torch.Tensor: + """ + :return: [batch_size, num_views, n_patches, width] + """ + all_views = self.views_to_tensor(batch.views).to(self.device) + if self.use_depth: + all_views = torch.cat([all_views, self.depths_to_tensor(batch.depths)], dim=2) + all_cameras = self.cameras_to_tensor(batch.cameras).to(self.device) + + batch_size, num_views, _, _, _ = all_views.shape + + views_proj = self.patch_emb( + all_views.reshape([batch_size * num_views, *all_views.shape[2:]]) + ) + views_proj = ( + views_proj.reshape([batch_size, num_views, self.width, -1]) + .permute(0, 1, 3, 2) + .contiguous() + ) # [batch_size x num_views x n_patches x width] + + # [batch_size, num_views, 1, 2 * width] + camera_proj = self.camera_emb(all_cameras).reshape( + [batch_size, num_views, 1, self.width * 2] + ) + pose_dropout = self.pose_dropout if self.training else 0.0 + mask = torch.rand(batch_size, 1, 1, 1, device=views_proj.device) >= pose_dropout + camera_proj = torch.where(mask, camera_proj, torch.zeros_like(camera_proj)) + scale, shift = camera_proj.chunk(2, dim=3) + views_proj = views_proj * (scale + 1.0) + shift + return views_proj + + def views_to_tensor(self, views: Union[torch.Tensor, List[List[Image.Image]]]) -> torch.Tensor: + """ + Returns a [batch x num_views x 3 x size x size] tensor in the range [-1, 1]. + """ + if isinstance(views, torch.Tensor): + return views + + tensor_batch = [] + num_views = len(views[0]) + for inner_list in views: + assert len(inner_list) == num_views + inner_batch = [] + for img in inner_list: + img = img.resize((self.image_size,) * 2).convert("RGB") + inner_batch.append( + torch.from_numpy(np.array(img)).to(device=self.device, dtype=torch.float32) + / 127.5 + - 1 + ) + tensor_batch.append(torch.stack(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0).permute(0, 1, 4, 2, 3) + + def depths_to_tensor( + self, depths: Union[torch.Tensor, List[List[Image.Image]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 1 x size x size] tensor in the range [-1, 1]. + """ + if isinstance(depths, torch.Tensor): + return depths + + tensor_batch = [] + num_views = len(depths[0]) + for inner_list in depths: + assert len(inner_list) == num_views + inner_batch = [] + for arr in inner_list: + tensor = torch.from_numpy(arr).clamp(max=self.max_depth) / self.max_depth + tensor = tensor * 2 - 1 + tensor = F.interpolate( + tensor[None, None], + (self.image_size,) * 2, + mode="nearest", + ) + inner_batch.append(tensor.to(device=self.device, dtype=torch.float32)) + tensor_batch.append(torch.cat(inner_batch, dim=0)) + return torch.stack(tensor_batch, dim=0) + + def cameras_to_tensor( + self, cameras: Union[torch.Tensor, List[List[ProjectiveCamera]]] + ) -> torch.Tensor: + """ + Returns a [batch x num_views x 3*4+1] tensor of camera information. 
+ """ + if isinstance(cameras, torch.Tensor): + return cameras + outer_batch = [] + for inner_list in cameras: + inner_batch = [] + for camera in inner_list: + inner_batch.append( + np.array( + [ + *camera.x, + *camera.y, + *camera.z, + *camera.origin, + camera.x_fov, + ] + ) + ) + outer_batch.append(np.stack(inner_batch, axis=0)) + return torch.from_numpy(np.stack(outer_batch, axis=0)).float() diff --git a/shap_e/models/volume.py b/shap_e/models/volume.py new file mode 100644 index 0000000000000000000000000000000000000000..d78e132194c75baa72dc80cd0f098e158b01f515 --- /dev/null +++ b/shap_e/models/volume.py @@ -0,0 +1,255 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch + +from shap_e.models.nn.meta import MetaModule +from shap_e.models.nn.utils import ArrayType, safe_divide, to_torch + + +@dataclass +class VolumeRange: + t0: torch.Tensor + t1: torch.Tensor + intersected: torch.Tensor + + def __post_init__(self): + assert self.t0.shape == self.t1.shape == self.intersected.shape + + def next_t0(self): + """ + Given convex volume1 and volume2, where volume1 is contained in + volume2, this function returns the t0 at which rays leave volume1 and + intersect with volume2 \\ volume1. + """ + return self.t1 * self.intersected.float() + + def extend(self, another: "VolumeRange") -> "VolumeRange": + """ + The ranges at which rays intersect with either one, or both, or none of + the self and another are merged together. + """ + return VolumeRange( + t0=torch.where(self.intersected, self.t0, another.t0), + t1=torch.where(another.intersected, another.t1, self.t1), + intersected=torch.logical_or(self.intersected, another.intersected), + ) + + def partition(self, ts) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Partitions t0 and t1 into n_samples intervals. + + :param ts: [batch_size, *shape, n_samples, 1] + :return: a tuple of ( + lower: [batch_size, *shape, n_samples, 1] + upper: [batch_size, *shape, n_samples, 1] + delta: [batch_size, *shape, n_samples, 1] + ) where + + ts \\in [lower, upper] + deltas = upper - lower + """ + mids = (ts[..., 1:, :] + ts[..., :-1, :]) * 0.5 + lower = torch.cat([self.t0[..., None, :], mids], dim=-2) + upper = torch.cat([mids, self.t1[..., None, :]], dim=-2) + delta = upper - lower + assert lower.shape == upper.shape == delta.shape == ts.shape + return lower, upper, delta + + +class Volume(ABC): + """ + An abstraction of rendering volume. + """ + + @abstractmethod + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + params: Optional[Dict] = None, + epsilon: float = 1e-6, + ) -> VolumeRange: + """ + :param origin: [batch_size, *shape, 3] + :param direction: [batch_size, *shape, 3] + :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + :param params: Optional meta parameters in case Volume is parametric + :param epsilon: to stabilize calculations + + :return: A tuple of (t0, t1, intersected) where each has a shape + [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is + in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed + to be on the boundary of the volume. + """ + + +class BoundingBoxVolume(MetaModule, Volume): + """ + Axis-aligned bounding box defined by the two opposite corners. 
+ """ + + def __init__( + self, + *, + bbox_min: ArrayType, + bbox_max: ArrayType, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + device: torch.device = torch.device("cuda"), + ): + """ + :param bbox_min: the left/bottommost corner of the bounding box + :param bbox_max: the other corner of the bounding box + :param min_dist: all rays should start at least this distance away from the origin. + """ + super().__init__() + + self.bbox_min = to_torch(bbox_min).to(device) + self.bbox_max = to_torch(bbox_max).to(device) + self.min_dist = min_dist + self.min_t_range = min_t_range + self.bbox = torch.stack([self.bbox_min, self.bbox_max]) + assert self.bbox.shape == (2, 3) + assert self.min_dist >= 0.0 + assert self.min_t_range > 0.0 + self.device = device + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + params: Optional[Dict] = None, + epsilon=1e-6, + ) -> VolumeRange: + """ + :param origin: [batch_size, *shape, 3] + :param direction: [batch_size, *shape, 3] + :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + :param params: Optional meta parameters in case Volume is parametric + :param epsilon: to stabilize calculations + + :return: A tuple of (t0, t1, intersected) where each has a shape + [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is + in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed + to be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + ones = [1] * len(shape) + bbox = self.bbox.view(1, *ones, 2, 3) + # import pdb; pdb.set_trace() + ts = safe_divide(bbox - origin[..., None, :], direction[..., None, :], epsilon=epsilon) + + # Cases to think about: + # + # 1. t1 <= t0: the ray does not pass through the AABB. + # 2. t0 < t1 <= 0: the ray intersects but the BB is behind the origin. + # 3. t0 <= 0 <= t1: the ray starts from inside the BB + # 4. 0 <= t0 < t1: the ray is not inside and intersects with the BB twice. + # + # 1 and 4 are clearly handled from t0 < t1 below. + # Making t0 at least min_dist (>= 0) takes care of 2 and 3. + t0 = ts.min(dim=-2).values.max(dim=-1, keepdim=True).values.clamp(self.min_dist) + t1 = ts.max(dim=-2).values.min(dim=-1, keepdim=True).values + assert t0.shape == t1.shape == (batch_size, *shape, 1) + if t0_lower is not None: + assert t0.shape == t0_lower.shape + t0 = torch.maximum(t0, t0_lower) + + intersected = t0 + self.min_t_range < t1 + t0 = torch.where(intersected, t0, torch.zeros_like(t0)) + t1 = torch.where(intersected, t1, torch.ones_like(t1)) + + return VolumeRange(t0=t0, t1=t1, intersected=intersected) + + +class UnboundedVolume(MetaModule, Volume): + """ + Originally used in NeRF. Unbounded volume but with a limited visibility + when rendering (e.g. 
objects that are farther away than the max_dist from + the ray origin are not considered) + """ + + def __init__( + self, + *, + max_dist: float, + min_dist: float = 0.0, + min_t_range: float = 1e-3, + device: torch.device = torch.device("cuda"), + ): + super().__init__() + self.max_dist = max_dist + self.min_dist = min_dist + self.min_t_range = min_t_range + assert self.min_dist >= 0.0 + assert self.min_t_range > 0.0 + self.device = device + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + params: Optional[Dict] = None, + ) -> VolumeRange: + """ + :param origin: [batch_size, *shape, 3] + :param direction: [batch_size, *shape, 3] + :param t0_lower: Optional [batch_size, *shape, 1] lower bound of t0 when intersecting this volume. + :param params: Optional meta parameters in case Volume is parametric + :param epsilon: to stabilize calculations + + :return: A tuple of (t0, t1, intersected) where each has a shape + [batch_size, *shape, 1]. If a ray intersects with the volume, `o + td` is + in the volume for all t in [t0, t1]. If the volume is bounded, t1 is guaranteed + to be on the boundary of the volume. + """ + + batch_size, *shape, _ = origin.shape + t0 = torch.zeros(batch_size, *shape, 1, dtype=origin.dtype, device=origin.device) + if t0_lower is not None: + t0 = torch.maximum(t0, t0_lower) + t1 = t0 + self.max_dist + t0 = t0.clamp(self.min_dist) + return VolumeRange(t0=t0, t1=t1, intersected=t0 + self.min_t_range < t1) + + +class SphericalVolume(MetaModule, Volume): + """ + Used in NeRF++ but will not be used probably unless we want to reproduce + their results. + """ + + def __init__( + self, + *, + radius: float, + center: ArrayType = (0.0, 0.0, 0.0), + min_dist: float = 0.0, + min_t_range: float = 1e-3, + device: torch.device = torch.device("cuda"), + ): + super().__init__() + + self.radius = radius + self.center = to_torch(center).to(device) + self.min_dist = min_dist + self.min_t_range = min_t_range + assert self.min_dist >= 0.0 + assert self.min_t_range > 0.0 + self.device = device + + def intersect( + self, + origin: torch.Tensor, + direction: torch.Tensor, + t0_lower: Optional[torch.Tensor] = None, + params: Optional[Dict] = None, + epsilon=1e-6, + ) -> VolumeRange: + raise NotImplementedError diff --git a/shap_e/rendering/__init__.py b/shap_e/rendering/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/rendering/__pycache__/__init__.cpython-39.pyc b/shap_e/rendering/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..478c50af71f9f2096f61e760c5c43bc524fec3a9 Binary files /dev/null and b/shap_e/rendering/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/_mc_table.cpython-39.pyc b/shap_e/rendering/__pycache__/_mc_table.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5031860a9c9d08d9fe9ed2d3f82c0577220fe83 Binary files /dev/null and b/shap_e/rendering/__pycache__/_mc_table.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/mc.cpython-39.pyc b/shap_e/rendering/__pycache__/mc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b4c0c83b9bfcd4782c9e3cbd7d7870ce84dba3a Binary files /dev/null and b/shap_e/rendering/__pycache__/mc.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/mesh.cpython-39.pyc 
b/shap_e/rendering/__pycache__/mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0599eb1d06af102d0b45ff61f04f663e45fdf0be Binary files /dev/null and b/shap_e/rendering/__pycache__/mesh.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/ply_util.cpython-39.pyc b/shap_e/rendering/__pycache__/ply_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a1ea8345e9daea00ce00f4c969b7952d2e03d9 Binary files /dev/null and b/shap_e/rendering/__pycache__/ply_util.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/point_cloud.cpython-39.pyc b/shap_e/rendering/__pycache__/point_cloud.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91fae1f5bd3ec88da9903352a72c294328810c87 Binary files /dev/null and b/shap_e/rendering/__pycache__/point_cloud.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/torch_mesh.cpython-39.pyc b/shap_e/rendering/__pycache__/torch_mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeca01a4c01c1bd6e74efb8b2d0293ec5ff136d8 Binary files /dev/null and b/shap_e/rendering/__pycache__/torch_mesh.cpython-39.pyc differ diff --git a/shap_e/rendering/__pycache__/view_data.cpython-39.pyc b/shap_e/rendering/__pycache__/view_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec4652cbd7704c25eace152744f7249f327c9a60 Binary files /dev/null and b/shap_e/rendering/__pycache__/view_data.cpython-39.pyc differ diff --git a/shap_e/rendering/_mc_table.py b/shap_e/rendering/_mc_table.py new file mode 100644 index 0000000000000000000000000000000000000000..c3f6ab6df6256b2207fd902c91ba9c2f5ca7529a --- /dev/null +++ b/shap_e/rendering/_mc_table.py @@ -0,0 +1,482 @@ +# Treat a cube as a bitmap, and create the index into this array in order of +# ZYX (note Z is the most significant digit). +# The resulting object is an array of triangles, where each triangle is 6 +# indices. Each consecutive pair of indices within this triangle represents an +# edge spanning two corners (identified by the indices). 
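+# For example, MC_TABLE[0b00000001] (only corner 0's bit is set) is
+# [[0, 1, 0, 2, 0, 4]]: a single triangle whose three vertices lie on the
+# edges 0-1, 0-2 and 0-4, i.e. the three edges leaving corner 0 (see the
+# corner indexing below).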
+# +# The corners of a cube are indexed as follows +# +# (0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), +# (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1) +# +# Here is a visualization of the cube indices: +# +# 6 + -----------------------+ 7 +# /| /| +# / | / | +# / | / | +# 4 +------------------------+ 5 | +# | | | | +# | | | | +# | | | | +# | | 2 | | 3 +# | +--------------------|---+ +# | / | / +# | / | / +# |/ |/ +# +------------------------+ +# 0 1 +# +# Derived using model3d, in particular this function: +# https://github.com/unixpickle/model3d/blob/7a3adb982c154c80c1a22032b5a0695160a7f96d/model3d/mc.go#L434 +# +MC_TABLE = [ + [], + [[0, 1, 0, 2, 0, 4]], + [[1, 0, 1, 5, 1, 3]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4]], + [[1, 0, 1, 5, 1, 3], [2, 6, 0, 2, 3, 2]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4]], + [[3, 1, 3, 7, 3, 2]], + [[0, 2, 0, 4, 0, 1], [3, 7, 2, 3, 1, 3]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7]], + [[0, 4, 1, 5, 3, 7], [0, 4, 3, 7, 2, 6]], + [[4, 0, 4, 6, 4, 5]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1]], + [[1, 5, 1, 3, 1, 0], [4, 6, 5, 4, 0, 4]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2]], + [[2, 0, 2, 3, 2, 6], [4, 5, 0, 4, 6, 4]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1]], + [[2, 6, 2, 0, 3, 2], [1, 0, 1, 5, 3, 1], [6, 4, 5, 4, 0, 4]], + [[1, 3, 5, 4, 1, 5], [1, 3, 4, 6, 5, 4], [1, 3, 3, 2, 4, 6], [3, 2, 2, 6, 4, 6]], + [[3, 1, 3, 7, 3, 2], [6, 4, 5, 4, 0, 4]], + [[4, 5, 0, 1, 4, 6], [0, 1, 0, 2, 4, 6], [7, 3, 2, 3, 1, 3]], + [[3, 2, 1, 0, 3, 7], [1, 0, 1, 5, 3, 7], [6, 4, 5, 4, 0, 4]], + [[3, 7, 3, 2, 1, 5], [3, 2, 6, 4, 1, 5], [1, 5, 6, 4, 5, 4], [3, 2, 2, 0, 6, 4]], + [[3, 7, 2, 6, 3, 1], [2, 6, 2, 0, 3, 1], [5, 4, 0, 4, 6, 4]], + [[1, 0, 1, 3, 5, 4], [1, 3, 2, 6, 5, 4], [1, 3, 3, 7, 2, 6], [5, 4, 2, 6, 4, 6]], + [[0, 1, 1, 5, 0, 2], [0, 2, 1, 5, 2, 6], [2, 6, 1, 5, 3, 7], [4, 5, 0, 4, 4, 6]], + [[6, 2, 4, 6, 4, 5], [4, 5, 5, 1, 6, 2], [6, 2, 5, 1, 7, 3]], + [[5, 1, 5, 4, 5, 7]], + [[0, 1, 0, 2, 0, 4], [5, 7, 1, 5, 4, 5]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3]], + [[2, 0, 2, 3, 2, 6], [7, 5, 1, 5, 4, 5]], + [[2, 6, 0, 4, 2, 3], [0, 4, 0, 1, 2, 3], [7, 5, 1, 5, 4, 5]], + [[5, 7, 1, 3, 5, 4], [1, 3, 1, 0, 5, 4], [6, 2, 0, 2, 3, 2]], + [[3, 1, 3, 2, 7, 5], [3, 2, 0, 4, 7, 5], [3, 2, 2, 6, 0, 4], [7, 5, 0, 4, 5, 4]], + [[3, 7, 3, 2, 3, 1], [5, 4, 7, 5, 1, 5]], + [[0, 4, 0, 1, 2, 0], [3, 1, 3, 7, 2, 3], [4, 5, 7, 5, 1, 5]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0]], + [[0, 4, 2, 3, 0, 2], [0, 4, 3, 7, 2, 3], [0, 4, 4, 5, 3, 7], [4, 5, 5, 7, 3, 7]], + [[2, 0, 3, 1, 2, 6], [3, 1, 3, 7, 2, 6], [4, 5, 7, 5, 1, 5]], + [[1, 3, 3, 7, 1, 0], [1, 0, 3, 7, 0, 4], [0, 4, 3, 7, 2, 6], [5, 7, 1, 5, 5, 4]], + [[2, 6, 2, 0, 3, 7], [2, 0, 4, 5, 3, 7], [3, 7, 4, 5, 7, 5], [2, 0, 0, 1, 4, 5]], + [[4, 0, 5, 4, 5, 7], [5, 7, 7, 3, 4, 0], [4, 0, 7, 3, 6, 2]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7]], + [[0, 2, 4, 6, 5, 7], [0, 2, 5, 7, 1, 3]], + [[5, 1, 4, 0, 5, 7], [4, 0, 4, 6, 5, 7], [3, 2, 6, 2, 0, 2]], + [[2, 3, 
2, 6, 0, 1], [2, 6, 7, 5, 0, 1], [0, 1, 7, 5, 1, 5], [2, 6, 6, 4, 7, 5]], + [[0, 4, 4, 6, 0, 1], [0, 1, 4, 6, 1, 3], [1, 3, 4, 6, 5, 7], [2, 6, 0, 2, 2, 3]], + [[3, 1, 2, 3, 2, 6], [2, 6, 6, 4, 3, 1], [3, 1, 6, 4, 7, 5]], + [[4, 6, 5, 7, 4, 0], [5, 7, 5, 1, 4, 0], [2, 3, 1, 3, 7, 3]], + [[1, 0, 0, 2, 1, 5], [1, 5, 0, 2, 5, 7], [5, 7, 0, 2, 4, 6], [3, 2, 1, 3, 3, 7]], + [[0, 1, 0, 4, 2, 3], [0, 4, 5, 7, 2, 3], [0, 4, 4, 6, 5, 7], [2, 3, 5, 7, 3, 7]], + [[7, 5, 3, 7, 3, 2], [3, 2, 2, 0, 7, 5], [7, 5, 2, 0, 6, 4]], + [[0, 4, 4, 6, 5, 7], [0, 4, 5, 7, 1, 5], [0, 2, 1, 3, 3, 7], [3, 7, 2, 6, 0, 2]], + [ + [3, 1, 7, 3, 6, 2], + [6, 2, 0, 1, 3, 1], + [6, 4, 0, 1, 6, 2], + [6, 4, 5, 1, 0, 1], + [6, 4, 7, 5, 5, 1], + ], + [ + [4, 0, 6, 4, 7, 5], + [7, 5, 1, 0, 4, 0], + [7, 3, 1, 0, 7, 5], + [7, 3, 2, 0, 1, 0], + [7, 3, 6, 2, 2, 0], + ], + [[7, 3, 6, 2, 6, 4], [7, 5, 7, 3, 6, 4]], + [[6, 2, 6, 7, 6, 4]], + [[0, 4, 0, 1, 0, 2], [6, 7, 4, 6, 2, 6]], + [[1, 0, 1, 5, 1, 3], [7, 6, 4, 6, 2, 6]], + [[1, 3, 0, 2, 1, 5], [0, 2, 0, 4, 1, 5], [7, 6, 4, 6, 2, 6]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3]], + [[6, 4, 2, 0, 6, 7], [2, 0, 2, 3, 6, 7], [5, 1, 3, 1, 0, 1]], + [[1, 5, 1, 3, 0, 4], [1, 3, 7, 6, 0, 4], [0, 4, 7, 6, 4, 6], [1, 3, 3, 2, 7, 6]], + [[3, 2, 3, 1, 3, 7], [6, 4, 2, 6, 7, 6]], + [[3, 7, 3, 2, 1, 3], [0, 2, 0, 4, 1, 0], [7, 6, 4, 6, 2, 6]], + [[1, 5, 3, 7, 1, 0], [3, 7, 3, 2, 1, 0], [4, 6, 2, 6, 7, 6]], + [[2, 0, 0, 4, 2, 3], [2, 3, 0, 4, 3, 7], [3, 7, 0, 4, 1, 5], [6, 4, 2, 6, 6, 7]], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0]], + [[0, 1, 4, 6, 0, 4], [0, 1, 6, 7, 4, 6], [0, 1, 1, 3, 6, 7], [1, 3, 3, 7, 6, 7]], + [[0, 2, 0, 1, 4, 6], [0, 1, 3, 7, 4, 6], [0, 1, 1, 5, 3, 7], [4, 6, 3, 7, 6, 7]], + [[7, 3, 6, 7, 6, 4], [6, 4, 4, 0, 7, 3], [7, 3, 4, 0, 5, 1]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5]], + [[6, 7, 4, 5, 6, 2], [4, 5, 4, 0, 6, 2], [3, 1, 0, 1, 5, 1]], + [[2, 0, 2, 6, 3, 1], [2, 6, 4, 5, 3, 1], [2, 6, 6, 7, 4, 5], [3, 1, 4, 5, 1, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7]], + [[0, 1, 2, 3, 6, 7], [0, 1, 6, 7, 4, 5]], + [[0, 2, 2, 3, 0, 4], [0, 4, 2, 3, 4, 5], [4, 5, 2, 3, 6, 7], [1, 3, 0, 1, 1, 5]], + [[5, 4, 1, 5, 1, 3], [1, 3, 3, 2, 5, 4], [5, 4, 3, 2, 7, 6]], + [[4, 0, 6, 2, 4, 5], [6, 2, 6, 7, 4, 5], [1, 3, 7, 3, 2, 3]], + [[2, 6, 6, 7, 2, 0], [2, 0, 6, 7, 0, 1], [0, 1, 6, 7, 4, 5], [3, 7, 2, 3, 3, 1]], + [[0, 1, 1, 5, 3, 7], [0, 1, 3, 7, 2, 3], [0, 4, 2, 6, 6, 7], [6, 7, 4, 5, 0, 4]], + [ + [6, 2, 7, 6, 5, 4], + [5, 4, 0, 2, 6, 2], + [5, 1, 0, 2, 5, 4], + [5, 1, 3, 2, 0, 2], + [5, 1, 7, 3, 3, 2], + ], + [[3, 1, 3, 7, 2, 0], [3, 7, 5, 4, 2, 0], [2, 0, 5, 4, 0, 4], [3, 7, 7, 6, 5, 4]], + [[1, 0, 3, 1, 3, 7], [3, 7, 7, 6, 1, 0], [1, 0, 7, 6, 5, 4]], + [ + [1, 0, 5, 1, 7, 3], + [7, 3, 2, 0, 1, 0], + [7, 6, 2, 0, 7, 3], + [7, 6, 4, 0, 2, 0], + [7, 6, 5, 4, 4, 0], + ], + [[7, 6, 5, 4, 5, 1], [7, 3, 7, 6, 5, 1]], + [[5, 7, 5, 1, 5, 4], [6, 2, 7, 6, 4, 6]], + [[0, 2, 0, 4, 1, 0], [5, 4, 5, 7, 1, 5], [2, 6, 7, 6, 4, 6]], + [[1, 0, 5, 4, 1, 3], [5, 4, 5, 7, 1, 3], [2, 6, 7, 6, 4, 6]], + [[4, 5, 5, 7, 4, 0], [4, 0, 5, 7, 0, 2], [0, 2, 5, 7, 1, 3], [6, 7, 4, 6, 6, 2]], + [[2, 3, 6, 7, 2, 0], [6, 7, 6, 4, 2, 0], [1, 5, 4, 5, 7, 5]], + [[4, 0, 0, 1, 4, 6], [4, 6, 0, 1, 6, 7], [6, 7, 0, 1, 2, 3], [5, 1, 4, 5, 5, 7]], + [[0, 2, 2, 3, 6, 7], [0, 2, 6, 7, 4, 6], [0, 1, 4, 5, 5, 7], [5, 7, 1, 3, 0, 1]], + [ + [5, 4, 7, 5, 3, 
1], + [3, 1, 0, 4, 5, 4], + [3, 2, 0, 4, 3, 1], + [3, 2, 6, 4, 0, 4], + [3, 2, 7, 6, 6, 4], + ], + [[5, 4, 5, 7, 1, 5], [3, 7, 3, 2, 1, 3], [4, 6, 2, 6, 7, 6]], + [[1, 0, 0, 2, 0, 4], [1, 5, 5, 4, 5, 7], [3, 2, 1, 3, 3, 7], [2, 6, 7, 6, 4, 6]], + [[7, 3, 3, 2, 7, 5], [7, 5, 3, 2, 5, 4], [5, 4, 3, 2, 1, 0], [6, 2, 7, 6, 6, 4]], + [ + [0, 4, 2, 3, 0, 2], + [0, 4, 3, 7, 2, 3], + [0, 4, 4, 5, 3, 7], + [4, 5, 5, 7, 3, 7], + [6, 7, 4, 6, 2, 6], + ], + [[7, 6, 6, 4, 7, 3], [7, 3, 6, 4, 3, 1], [3, 1, 6, 4, 2, 0], [5, 4, 7, 5, 5, 1]], + [ + [0, 1, 4, 6, 0, 4], + [0, 1, 6, 7, 4, 6], + [0, 1, 1, 3, 6, 7], + [1, 3, 3, 7, 6, 7], + [5, 7, 1, 5, 4, 5], + ], + [ + [6, 7, 4, 6, 0, 2], + [0, 2, 3, 7, 6, 7], + [0, 1, 3, 7, 0, 2], + [0, 1, 5, 7, 3, 7], + [0, 1, 4, 5, 5, 7], + ], + [[4, 0, 6, 7, 4, 6], [4, 0, 7, 3, 6, 7], [4, 0, 5, 7, 7, 3], [4, 5, 5, 7, 4, 0]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0]], + [[0, 2, 1, 5, 0, 1], [0, 2, 5, 7, 1, 5], [0, 2, 2, 6, 5, 7], [2, 6, 6, 7, 5, 7]], + [[1, 3, 1, 0, 5, 7], [1, 0, 2, 6, 5, 7], [5, 7, 2, 6, 7, 6], [1, 0, 0, 4, 2, 6]], + [[2, 0, 6, 2, 6, 7], [6, 7, 7, 5, 2, 0], [2, 0, 7, 5, 3, 1]], + [[0, 4, 0, 2, 1, 5], [0, 2, 6, 7, 1, 5], [0, 2, 2, 3, 6, 7], [1, 5, 6, 7, 5, 7]], + [[7, 6, 5, 7, 5, 1], [5, 1, 1, 0, 7, 6], [7, 6, 1, 0, 3, 2]], + [ + [2, 0, 3, 2, 7, 6], + [7, 6, 4, 0, 2, 0], + [7, 5, 4, 0, 7, 6], + [7, 5, 1, 0, 4, 0], + [7, 5, 3, 1, 1, 0], + ], + [[7, 5, 3, 1, 3, 2], [7, 6, 7, 5, 3, 2]], + [[7, 5, 5, 1, 7, 6], [7, 6, 5, 1, 6, 2], [6, 2, 5, 1, 4, 0], [3, 1, 7, 3, 3, 2]], + [ + [0, 2, 1, 5, 0, 1], + [0, 2, 5, 7, 1, 5], + [0, 2, 2, 6, 5, 7], + [2, 6, 6, 7, 5, 7], + [3, 7, 2, 3, 1, 3], + ], + [ + [3, 7, 2, 3, 0, 1], + [0, 1, 5, 7, 3, 7], + [0, 4, 5, 7, 0, 1], + [0, 4, 6, 7, 5, 7], + [0, 4, 2, 6, 6, 7], + ], + [[2, 0, 3, 7, 2, 3], [2, 0, 7, 5, 3, 7], [2, 0, 6, 7, 7, 5], [2, 6, 6, 7, 2, 0]], + [ + [5, 7, 1, 5, 0, 4], + [0, 4, 6, 7, 5, 7], + [0, 2, 6, 7, 0, 4], + [0, 2, 3, 7, 6, 7], + [0, 2, 1, 3, 3, 7], + ], + [[1, 0, 5, 7, 1, 5], [1, 0, 7, 6, 5, 7], [1, 0, 3, 7, 7, 6], [1, 3, 3, 7, 1, 0]], + [[0, 2, 0, 1, 0, 4], [3, 7, 6, 7, 5, 7]], + [[7, 5, 7, 3, 7, 6]], + [[7, 3, 7, 5, 7, 6]], + [[0, 1, 0, 2, 0, 4], [6, 7, 3, 7, 5, 7]], + [[1, 3, 1, 0, 1, 5], [7, 6, 3, 7, 5, 7]], + [[0, 4, 1, 5, 0, 2], [1, 5, 1, 3, 0, 2], [6, 7, 3, 7, 5, 7]], + [[2, 6, 2, 0, 2, 3], [7, 5, 6, 7, 3, 7]], + [[0, 1, 2, 3, 0, 4], [2, 3, 2, 6, 0, 4], [5, 7, 6, 7, 3, 7]], + [[1, 5, 1, 3, 0, 1], [2, 3, 2, 6, 0, 2], [5, 7, 6, 7, 3, 7]], + [[3, 2, 2, 6, 3, 1], [3, 1, 2, 6, 1, 5], [1, 5, 2, 6, 0, 4], [7, 6, 3, 7, 7, 5]], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2]], + [[7, 6, 3, 2, 7, 5], [3, 2, 3, 1, 7, 5], [4, 0, 1, 0, 2, 0]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2]], + [[2, 3, 2, 0, 6, 7], [2, 0, 1, 5, 6, 7], [2, 0, 0, 4, 1, 5], [6, 7, 1, 5, 7, 5]], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1]], + [[0, 4, 0, 1, 2, 6], [0, 1, 5, 7, 2, 6], [2, 6, 5, 7, 6, 7], [0, 1, 1, 3, 5, 7]], + [[1, 5, 0, 2, 1, 0], [1, 5, 2, 6, 0, 2], [1, 5, 5, 7, 2, 6], [5, 7, 7, 6, 2, 6]], + [[5, 1, 7, 5, 7, 6], [7, 6, 6, 2, 5, 1], [5, 1, 6, 2, 4, 0]], + [[4, 5, 4, 0, 4, 6], [7, 3, 5, 7, 6, 7]], + [[0, 2, 4, 6, 0, 1], [4, 6, 4, 5, 0, 1], [3, 7, 5, 7, 6, 7]], + [[4, 6, 4, 5, 0, 4], [1, 5, 1, 3, 0, 1], [6, 7, 3, 7, 5, 7]], + [[5, 1, 1, 3, 5, 4], [5, 4, 1, 3, 4, 6], [4, 6, 1, 3, 0, 2], [7, 3, 5, 7, 7, 6]], + [[2, 3, 2, 6, 0, 2], [4, 6, 4, 5, 0, 4], [3, 7, 5, 7, 6, 7]], + [[6, 4, 4, 5, 6, 2], [6, 2, 4, 5, 2, 3], [2, 3, 4, 5, 0, 1], [7, 5, 6, 7, 7, 3]], + [[0, 1, 1, 5, 1, 3], [0, 
2, 2, 3, 2, 6], [4, 5, 0, 4, 4, 6], [5, 7, 6, 7, 3, 7]], + [ + [1, 3, 5, 4, 1, 5], + [1, 3, 4, 6, 5, 4], + [1, 3, 3, 2, 4, 6], + [3, 2, 2, 6, 4, 6], + [7, 6, 3, 7, 5, 7], + ], + [[3, 1, 7, 5, 3, 2], [7, 5, 7, 6, 3, 2], [0, 4, 6, 4, 5, 4]], + [[1, 0, 0, 2, 4, 6], [1, 0, 4, 6, 5, 4], [1, 3, 5, 7, 7, 6], [7, 6, 3, 2, 1, 3]], + [[5, 7, 7, 6, 5, 1], [5, 1, 7, 6, 1, 0], [1, 0, 7, 6, 3, 2], [4, 6, 5, 4, 4, 0]], + [ + [7, 5, 6, 7, 2, 3], + [2, 3, 1, 5, 7, 5], + [2, 0, 1, 5, 2, 3], + [2, 0, 4, 5, 1, 5], + [2, 0, 6, 4, 4, 5], + ], + [[6, 2, 2, 0, 6, 7], [6, 7, 2, 0, 7, 5], [7, 5, 2, 0, 3, 1], [4, 0, 6, 4, 4, 5]], + [ + [4, 6, 5, 4, 1, 0], + [1, 0, 2, 6, 4, 6], + [1, 3, 2, 6, 1, 0], + [1, 3, 7, 6, 2, 6], + [1, 3, 5, 7, 7, 6], + ], + [ + [1, 5, 0, 2, 1, 0], + [1, 5, 2, 6, 0, 2], + [1, 5, 5, 7, 2, 6], + [5, 7, 7, 6, 2, 6], + [4, 6, 5, 4, 0, 4], + ], + [[5, 1, 4, 6, 5, 4], [5, 1, 6, 2, 4, 6], [5, 1, 7, 6, 6, 2], [5, 7, 7, 6, 5, 1]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1]], + [[7, 3, 5, 1, 7, 6], [5, 1, 5, 4, 7, 6], [2, 0, 4, 0, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4]], + [[0, 2, 0, 4, 1, 3], [0, 4, 6, 7, 1, 3], [1, 3, 6, 7, 3, 7], [0, 4, 4, 5, 6, 7]], + [[5, 4, 7, 6, 5, 1], [7, 6, 7, 3, 5, 1], [0, 2, 3, 2, 6, 2]], + [[1, 5, 5, 4, 7, 6], [1, 5, 7, 6, 3, 7], [1, 0, 3, 2, 2, 6], [2, 6, 0, 4, 1, 0]], + [[3, 1, 1, 0, 3, 7], [3, 7, 1, 0, 7, 6], [7, 6, 1, 0, 5, 4], [2, 0, 3, 2, 2, 6]], + [ + [2, 3, 6, 2, 4, 0], + [4, 0, 1, 3, 2, 3], + [4, 5, 1, 3, 4, 0], + [4, 5, 7, 3, 1, 3], + [4, 5, 6, 7, 7, 3], + ], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6]], + [[1, 5, 5, 4, 1, 3], [1, 3, 5, 4, 3, 2], [3, 2, 5, 4, 7, 6], [0, 4, 1, 0, 0, 2]], + [[1, 0, 5, 4, 7, 6], [1, 0, 7, 6, 3, 2]], + [[2, 3, 0, 2, 0, 4], [0, 4, 4, 5, 2, 3], [2, 3, 4, 5, 6, 7]], + [[1, 3, 1, 5, 0, 2], [1, 5, 7, 6, 0, 2], [1, 5, 5, 4, 7, 6], [0, 2, 7, 6, 2, 6]], + [ + [5, 1, 4, 5, 6, 7], + [6, 7, 3, 1, 5, 1], + [6, 2, 3, 1, 6, 7], + [6, 2, 0, 1, 3, 1], + [6, 2, 4, 0, 0, 1], + ], + [[6, 7, 2, 6, 2, 0], [2, 0, 0, 1, 6, 7], [6, 7, 0, 1, 4, 5]], + [[6, 2, 4, 0, 4, 5], [6, 7, 6, 2, 4, 5]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1]], + [[1, 5, 1, 0, 3, 7], [1, 0, 4, 6, 3, 7], [1, 0, 0, 2, 4, 6], [3, 7, 4, 6, 7, 6]], + [[1, 0, 3, 7, 1, 3], [1, 0, 7, 6, 3, 7], [1, 0, 0, 4, 7, 6], [0, 4, 4, 6, 7, 6]], + [[6, 4, 7, 6, 7, 3], [7, 3, 3, 1, 6, 4], [6, 4, 3, 1, 2, 0]], + [[6, 7, 7, 3, 6, 4], [6, 4, 7, 3, 4, 0], [4, 0, 7, 3, 5, 1], [2, 3, 6, 2, 2, 0]], + [ + [7, 6, 3, 7, 1, 5], + [1, 5, 4, 6, 7, 6], + [1, 0, 4, 6, 1, 5], + [1, 0, 2, 6, 4, 6], + [1, 0, 3, 2, 2, 6], + ], + [ + [1, 0, 3, 7, 1, 3], + [1, 0, 7, 6, 3, 7], + [1, 0, 0, 4, 7, 6], + [0, 4, 4, 6, 7, 6], + [2, 6, 0, 2, 3, 2], + ], + [[3, 1, 7, 6, 3, 7], [3, 1, 6, 4, 7, 6], [3, 1, 2, 6, 6, 4], [3, 2, 2, 6, 3, 1]], + [[3, 2, 3, 1, 7, 6], [3, 1, 0, 4, 7, 6], [7, 6, 0, 4, 6, 4], [3, 1, 1, 5, 0, 4]], + [ + [0, 1, 2, 0, 6, 4], + [6, 4, 5, 1, 0, 1], + [6, 7, 5, 1, 6, 4], + [6, 7, 3, 1, 5, 1], + [6, 7, 2, 3, 3, 1], + ], + [[0, 1, 4, 0, 4, 6], [4, 6, 6, 7, 0, 1], [0, 1, 6, 7, 2, 3]], + [[6, 7, 2, 3, 2, 0], [6, 4, 6, 7, 2, 0]], + [ + [2, 6, 0, 2, 1, 3], + [1, 3, 7, 6, 2, 6], + [1, 5, 7, 6, 1, 3], + [1, 5, 4, 6, 7, 6], + [1, 5, 0, 4, 4, 6], + ], + [[1, 5, 1, 0, 1, 3], [4, 6, 7, 6, 2, 6]], + [[0, 1, 2, 6, 0, 2], [0, 1, 6, 7, 2, 6], [0, 1, 4, 6, 6, 7], [0, 4, 4, 6, 0, 1]], + [[6, 7, 6, 2, 6, 4]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4]], + [[7, 5, 6, 4, 7, 3], [6, 4, 6, 2, 7, 3], [1, 0, 2, 0, 4, 0]], + [[6, 2, 7, 3, 6, 4], [7, 3, 7, 5, 6, 4], [0, 1, 5, 1, 
3, 1]], + [[2, 0, 0, 4, 1, 5], [2, 0, 1, 5, 3, 1], [2, 6, 3, 7, 7, 5], [7, 5, 6, 4, 2, 6]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4]], + [[3, 2, 3, 7, 1, 0], [3, 7, 6, 4, 1, 0], [3, 7, 7, 5, 6, 4], [1, 0, 6, 4, 0, 4]], + [[3, 7, 7, 5, 3, 2], [3, 2, 7, 5, 2, 0], [2, 0, 7, 5, 6, 4], [1, 5, 3, 1, 1, 0]], + [ + [7, 3, 5, 7, 4, 6], + [4, 6, 2, 3, 7, 3], + [4, 0, 2, 3, 4, 6], + [4, 0, 1, 3, 2, 3], + [4, 0, 5, 1, 1, 3], + ], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5]], + [[2, 3, 3, 1, 2, 6], [2, 6, 3, 1, 6, 4], [6, 4, 3, 1, 7, 5], [0, 1, 2, 0, 0, 4]], + [[1, 0, 1, 5, 3, 2], [1, 5, 4, 6, 3, 2], [3, 2, 4, 6, 2, 6], [1, 5, 5, 7, 4, 6]], + [ + [0, 2, 4, 0, 5, 1], + [5, 1, 3, 2, 0, 2], + [5, 7, 3, 2, 5, 1], + [5, 7, 6, 2, 3, 2], + [5, 7, 4, 6, 6, 2], + ], + [[2, 0, 3, 1, 7, 5], [2, 0, 7, 5, 6, 4]], + [[4, 6, 0, 4, 0, 1], [0, 1, 1, 3, 4, 6], [4, 6, 1, 3, 5, 7]], + [[0, 2, 1, 0, 1, 5], [1, 5, 5, 7, 0, 2], [0, 2, 5, 7, 4, 6]], + [[5, 7, 4, 6, 4, 0], [5, 1, 5, 7, 4, 0]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2]], + [[0, 1, 0, 2, 4, 5], [0, 2, 3, 7, 4, 5], [4, 5, 3, 7, 5, 7], [0, 2, 2, 6, 3, 7]], + [[5, 4, 4, 0, 5, 7], [5, 7, 4, 0, 7, 3], [7, 3, 4, 0, 6, 2], [1, 0, 5, 1, 1, 3]], + [ + [1, 5, 3, 1, 2, 0], + [2, 0, 4, 5, 1, 5], + [2, 6, 4, 5, 2, 0], + [2, 6, 7, 5, 4, 5], + [2, 6, 3, 7, 7, 5], + ], + [[2, 3, 0, 4, 2, 0], [2, 3, 4, 5, 0, 4], [2, 3, 3, 7, 4, 5], [3, 7, 7, 5, 4, 5]], + [[3, 2, 7, 3, 7, 5], [7, 5, 5, 4, 3, 2], [3, 2, 5, 4, 1, 0]], + [ + [2, 3, 0, 4, 2, 0], + [2, 3, 4, 5, 0, 4], + [2, 3, 3, 7, 4, 5], + [3, 7, 7, 5, 4, 5], + [1, 5, 3, 1, 0, 1], + ], + [[3, 2, 1, 5, 3, 1], [3, 2, 5, 4, 1, 5], [3, 2, 7, 5, 5, 4], [3, 7, 7, 5, 3, 2]], + [[2, 6, 2, 3, 0, 4], [2, 3, 7, 5, 0, 4], [2, 3, 3, 1, 7, 5], [0, 4, 7, 5, 4, 5]], + [ + [3, 2, 1, 3, 5, 7], + [5, 7, 6, 2, 3, 2], + [5, 4, 6, 2, 5, 7], + [5, 4, 0, 2, 6, 2], + [5, 4, 1, 0, 0, 2], + ], + [ + [4, 5, 0, 4, 2, 6], + [2, 6, 7, 5, 4, 5], + [2, 3, 7, 5, 2, 6], + [2, 3, 1, 5, 7, 5], + [2, 3, 0, 1, 1, 5], + ], + [[2, 3, 2, 0, 2, 6], [1, 5, 7, 5, 4, 5]], + [[5, 7, 4, 5, 4, 0], [4, 0, 0, 2, 5, 7], [5, 7, 0, 2, 1, 3]], + [[5, 4, 1, 0, 1, 3], [5, 7, 5, 4, 1, 3]], + [[0, 2, 4, 5, 0, 4], [0, 2, 5, 7, 4, 5], [0, 2, 1, 5, 5, 7], [0, 1, 1, 5, 0, 2]], + [[5, 4, 5, 1, 5, 7]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3]], + [[4, 6, 6, 2, 4, 5], [4, 5, 6, 2, 5, 1], [5, 1, 6, 2, 7, 3], [0, 2, 4, 0, 0, 1]], + [[3, 7, 3, 1, 2, 6], [3, 1, 5, 4, 2, 6], [3, 1, 1, 0, 5, 4], [2, 6, 5, 4, 6, 4]], + [ + [6, 4, 2, 6, 3, 7], + [3, 7, 5, 4, 6, 4], + [3, 1, 5, 4, 3, 7], + [3, 1, 0, 4, 5, 4], + [3, 1, 2, 0, 0, 4], + ], + [[2, 0, 2, 3, 6, 4], [2, 3, 1, 5, 6, 4], [6, 4, 1, 5, 4, 5], [2, 3, 3, 7, 1, 5]], + [ + [0, 4, 1, 0, 3, 2], + [3, 2, 6, 4, 0, 4], + [3, 7, 6, 4, 3, 2], + [3, 7, 5, 4, 6, 4], + [3, 7, 1, 5, 5, 4], + ], + [ + [1, 3, 0, 1, 4, 5], + [4, 5, 7, 3, 1, 3], + [4, 6, 7, 3, 4, 5], + [4, 6, 2, 3, 7, 3], + [4, 6, 0, 2, 2, 3], + ], + [[3, 7, 3, 1, 3, 2], [5, 4, 6, 4, 0, 4]], + [[3, 1, 2, 6, 3, 2], [3, 1, 6, 4, 2, 6], [3, 1, 1, 5, 6, 4], [1, 5, 5, 4, 6, 4]], + [ + [3, 1, 2, 6, 3, 2], + [3, 1, 6, 4, 2, 6], + [3, 1, 1, 5, 6, 4], + [1, 5, 5, 4, 6, 4], + [0, 4, 1, 0, 2, 0], + ], + [[4, 5, 6, 4, 6, 2], [6, 2, 2, 3, 4, 5], [4, 5, 2, 3, 0, 1]], + [[2, 3, 6, 4, 2, 6], [2, 3, 4, 5, 6, 4], [2, 3, 0, 4, 4, 5], [2, 0, 0, 4, 2, 3]], + [[1, 3, 5, 1, 5, 4], [5, 4, 4, 6, 1, 3], [1, 3, 4, 6, 0, 2]], + [[1, 3, 0, 4, 1, 0], [1, 3, 4, 6, 0, 4], [1, 3, 5, 4, 4, 6], [1, 5, 5, 4, 1, 3]], + [[4, 6, 0, 2, 0, 1], [4, 5, 4, 6, 0, 1]], 
+ [[4, 6, 4, 0, 4, 5]], + [[4, 0, 6, 2, 7, 3], [4, 0, 7, 3, 5, 1]], + [[1, 5, 0, 1, 0, 2], [0, 2, 2, 6, 1, 5], [1, 5, 2, 6, 3, 7]], + [[3, 7, 1, 3, 1, 0], [1, 0, 0, 4, 3, 7], [3, 7, 0, 4, 2, 6]], + [[3, 1, 2, 0, 2, 6], [3, 7, 3, 1, 2, 6]], + [[0, 4, 2, 0, 2, 3], [2, 3, 3, 7, 0, 4], [0, 4, 3, 7, 1, 5]], + [[3, 7, 1, 5, 1, 0], [3, 2, 3, 7, 1, 0]], + [[0, 4, 1, 3, 0, 1], [0, 4, 3, 7, 1, 3], [0, 4, 2, 3, 3, 7], [0, 2, 2, 3, 0, 4]], + [[3, 7, 3, 1, 3, 2]], + [[2, 6, 3, 2, 3, 1], [3, 1, 1, 5, 2, 6], [2, 6, 1, 5, 0, 4]], + [[1, 5, 3, 2, 1, 3], [1, 5, 2, 6, 3, 2], [1, 5, 0, 2, 2, 6], [1, 0, 0, 2, 1, 5]], + [[2, 3, 0, 1, 0, 4], [2, 6, 2, 3, 0, 4]], + [[2, 3, 2, 0, 2, 6]], + [[1, 5, 0, 4, 0, 2], [1, 3, 1, 5, 0, 2]], + [[1, 5, 1, 0, 1, 3]], + [[0, 2, 0, 1, 0, 4]], + [], +] diff --git a/shap_e/rendering/blender/__init__.py b/shap_e/rendering/blender/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..043ec3e58e587fe8b2fd96e4c60c50bc4bd937cd --- /dev/null +++ b/shap_e/rendering/blender/__init__.py @@ -0,0 +1,4 @@ +from .render import render_mesh, render_model +from .view_data import BlenderViewData + +__all__ = ["BlenderViewData", "render_model"] diff --git a/shap_e/rendering/blender/__pycache__/__init__.cpython-39.pyc b/shap_e/rendering/blender/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e320fd68140fb89dbdd7fbc30817c0d8c60ea03 Binary files /dev/null and b/shap_e/rendering/blender/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/rendering/blender/__pycache__/constants.cpython-39.pyc b/shap_e/rendering/blender/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8113c4ba37afff1a2824924ebbaa22146438e050 Binary files /dev/null and b/shap_e/rendering/blender/__pycache__/constants.cpython-39.pyc differ diff --git a/shap_e/rendering/blender/__pycache__/render.cpython-39.pyc b/shap_e/rendering/blender/__pycache__/render.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e384af05c7a0734c4a4675294214095df55ab1dd Binary files /dev/null and b/shap_e/rendering/blender/__pycache__/render.cpython-39.pyc differ diff --git a/shap_e/rendering/blender/__pycache__/view_data.cpython-39.pyc b/shap_e/rendering/blender/__pycache__/view_data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d85c8281d8de032db00ad97391a5f222dcc837f1 Binary files /dev/null and b/shap_e/rendering/blender/__pycache__/view_data.cpython-39.pyc differ diff --git a/shap_e/rendering/blender/blender_script.py b/shap_e/rendering/blender/blender_script.py new file mode 100644 index 0000000000000000000000000000000000000000..924a768157f2bd45b391933705a8f73b1e32cfb3 --- /dev/null +++ b/shap_e/rendering/blender/blender_script.py @@ -0,0 +1,676 @@ +""" +Script to run within blender. + +Provide arguments after `--`. +For example: `blender -b -P blender_script.py -- --help` +""" + +import argparse +import json +import math +import os +import random +import sys + +import bpy +from mathutils import Vector +from mathutils.noise import random_unit_vector + +MAX_DEPTH = 5.0 +FORMAT_VERSION = 6 + +# Set by main(), these constants are passed to the script to avoid +# duplicating them across multiple files. 
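+# They arrive via the --uniform_light_direction, --basic_ambient and
+# --basic_diffuse arguments parsed in main() at the bottom of this file.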
+UNIFORM_LIGHT_DIRECTION = None +BASIC_AMBIENT_COLOR = None +BASIC_DIFFUSE_COLOR = None + + +def clear_scene(): + bpy.ops.object.select_all(action="SELECT") + bpy.ops.object.delete() + + +def clear_lights(): + bpy.ops.object.select_all(action="DESELECT") + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, bpy.types.Light): + obj.select_set(True) + bpy.ops.object.delete() + + +def import_model(path): + clear_scene() + _, ext = os.path.splitext(path) + ext = ext.lower() + if ext == ".obj": + bpy.ops.import_scene.obj(filepath=path) + elif ext in [".glb", ".gltf"]: + bpy.ops.import_scene.gltf(filepath=path) + elif ext == ".stl": + bpy.ops.import_mesh.stl(filepath=path) + elif ext == ".fbx": + bpy.ops.import_scene.fbx(filepath=path) + elif ext == ".dae": + bpy.ops.wm.collada_import(filepath=path) + elif ext == ".ply": + bpy.ops.import_mesh.ply(filepath=path) + else: + raise RuntimeError(f"unexpected extension: {ext}") + + +def scene_root_objects(): + for obj in bpy.context.scene.objects.values(): + if not obj.parent: + yield obj + + +def scene_bbox(single_obj=None, ignore_matrix=False): + bbox_min = (math.inf,) * 3 + bbox_max = (-math.inf,) * 3 + found = False + for obj in scene_meshes() if single_obj is None else [single_obj]: + found = True + for coord in obj.bound_box: + coord = Vector(coord) + if not ignore_matrix: + coord = obj.matrix_world @ coord + bbox_min = tuple(min(x, y) for x, y in zip(bbox_min, coord)) + bbox_max = tuple(max(x, y) for x, y in zip(bbox_max, coord)) + if not found: + raise RuntimeError("no objects in scene to compute bounding box for") + return Vector(bbox_min), Vector(bbox_max) + + +def scene_meshes(): + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, (bpy.types.Mesh)): + yield obj + + +def normalize_scene(): + bbox_min, bbox_max = scene_bbox() + scale = 1 / max(bbox_max - bbox_min) + + for obj in scene_root_objects(): + obj.scale = obj.scale * scale + + # Apply scale to matrix_world. 
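+    # (view_layer.update() recomputes matrix_world, so the second scene_bbox()
+    # call below measures the rescaled scene before it is re-centered.)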
+ bpy.context.view_layer.update() + + bbox_min, bbox_max = scene_bbox() + offset = -(bbox_min + bbox_max) / 2 + for obj in scene_root_objects(): + obj.matrix_world.translation += offset + + bpy.ops.object.select_all(action="DESELECT") + + +def create_camera(): + # https://b3d.interplanety.org/en/how-to-create-camera-through-the-blender-python-api/ + camera_data = bpy.data.cameras.new(name="Camera") + camera_object = bpy.data.objects.new("Camera", camera_data) + bpy.context.scene.collection.objects.link(camera_object) + bpy.context.scene.camera = camera_object + + +def set_camera(direction, camera_dist=2.0): + camera_pos = -camera_dist * direction + bpy.context.scene.camera.location = camera_pos + + # https://blender.stackexchange.com/questions/5210/pointing-the-camera-in-a-particular-direction-programmatically + rot_quat = direction.to_track_quat("-Z", "Y") + bpy.context.scene.camera.rotation_euler = rot_quat.to_euler() + + bpy.context.view_layer.update() + + +def randomize_camera(camera_dist=2.0): + direction = random_unit_vector() + set_camera(direction, camera_dist=camera_dist) + + +def pan_camera(time, axis="Z", camera_dist=2.0, elevation=0.1): + angle = time * math.pi * 2 + direction = [-math.cos(angle), -math.sin(angle), elevation] + assert axis in ["X", "Y", "Z"] + if axis == "X": + direction = [direction[2], *direction[:2]] + elif axis == "Y": + direction = [direction[0], elevation, direction[1]] + direction = Vector(direction).normalized() + set_camera(direction, camera_dist=camera_dist) + + +def place_camera(time, camera_pose_mode="random", camera_dist_min=2.0, camera_dist_max=2.0): + camera_dist = random.uniform(camera_dist_min, camera_dist_max) + if camera_pose_mode == "random": + randomize_camera(camera_dist=camera_dist) + elif camera_pose_mode == "z-circular": + pan_camera(time, axis="Z", camera_dist=camera_dist) + elif camera_pose_mode == "z-circular-elevated": + pan_camera(time, axis="Z", camera_dist=camera_dist, elevation=-0.2617993878) + else: + raise ValueError(f"Unknown camera pose mode: {camera_pose_mode}") + + +def create_light(location, energy=1.0, angle=0.5 * math.pi / 180): + # https://blender.stackexchange.com/questions/215624/how-to-create-a-light-with-the-python-api-in-blender-2-92 + light_data = bpy.data.lights.new(name="Light", type="SUN") + light_data.energy = energy + light_data.angle = angle + light_object = bpy.data.objects.new(name="Light", object_data=light_data) + + direction = -location + rot_quat = direction.to_track_quat("-Z", "Y") + light_object.rotation_euler = rot_quat.to_euler() + bpy.context.view_layer.update() + + bpy.context.collection.objects.link(light_object) + light_object.location = location + + +def create_random_lights(count=4, distance=2.0, energy=1.5): + clear_lights() + for _ in range(count): + create_light(random_unit_vector() * distance, energy=energy) + + +def create_camera_light(): + clear_lights() + create_light(bpy.context.scene.camera.location, energy=5.0) + + +def create_uniform_light(backend): + clear_lights() + # Random direction to decorrelate axis-aligned sides. + pos = Vector(UNIFORM_LIGHT_DIRECTION) + angle = 0.0092 if backend == "CYCLES" else math.pi + create_light(pos, energy=5.0, angle=angle) + create_light(-pos, energy=5.0, angle=angle) + + +def create_vertex_color_shaders(): + # By default, Blender will ignore vertex colors in both the + # Eevee and Cycles backends, since these colors aren't + # associated with a material. 
+ # + # What we do here is create a simple material shader and link + # the vertex color to the material color. + for obj in bpy.context.scene.objects.values(): + if not isinstance(obj.data, (bpy.types.Mesh)): + continue + + if len(obj.data.materials): + # We don't want to override any existing materials. + continue + + color_keys = (obj.data.vertex_colors or {}).keys() + if not len(color_keys): + # Many objects will have no materials *or* vertex colors. + continue + + mat = bpy.data.materials.new(name="VertexColored") + mat.use_nodes = True + + # There should be a Principled BSDF by default. + bsdf_node = None + for node in mat.node_tree.nodes: + if node.type == "BSDF_PRINCIPLED": + bsdf_node = node + assert bsdf_node is not None, "material has no Principled BSDF node to modify" + + socket_map = {} + for input in bsdf_node.inputs: + socket_map[input.name] = input + + # Make sure nothing lights the object except for the diffuse color. + socket_map["Specular"].default_value = 0.0 + socket_map["Roughness"].default_value = 1.0 + + v_color = mat.node_tree.nodes.new("ShaderNodeVertexColor") + v_color.layer_name = color_keys[0] + + mat.node_tree.links.new(v_color.outputs[0], socket_map["Base Color"]) + + obj.data.materials.append(mat) + + +def create_default_materials(): + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, (bpy.types.Mesh)): + if not len(obj.data.materials): + mat = bpy.data.materials.new(name="DefaultMaterial") + mat.use_nodes = True + obj.data.materials.append(mat) + + +def find_materials(): + all_materials = set() + for obj in bpy.context.scene.objects.values(): + if not isinstance(obj.data, bpy.types.Mesh): + continue + for mat in obj.data.materials: + all_materials.add(mat) + return all_materials + + +def delete_all_materials(): + for obj in bpy.context.scene.objects.values(): + if isinstance(obj.data, bpy.types.Mesh): + # https://blender.stackexchange.com/questions/146714/removing-all-material-slots-in-one-go + obj.data.materials.clear() + + +def setup_material_extraction_shaders(capturing_material_alpha: bool): + """ + Change every material to emit texture colors (or alpha) rather than having + an actual reflective color. Returns a function to undo the changes to the + materials. + """ + # Objects can share materials, so we first find all of the + # materials in the project, and then modify them each once. + undo_fns = [] + for mat in find_materials(): + undo_fn = setup_material_extraction_shader_for_material(mat, capturing_material_alpha) + if undo_fn is not None: + undo_fns.append(undo_fn) + return lambda: [undo_fn() for undo_fn in undo_fns] + + +def setup_material_extraction_shader_for_material(mat, capturing_material_alpha: bool): + mat.use_nodes = True + + # By default, most imported models should use the regular + # "Principled BSDF" material, so we should always find this. + # If not, this shader manipulation logic won't work. 
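+    # (The assert below fails loudly in that case rather than silently
+    # producing wrong colors.)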
+ bsdf_node = None + for node in mat.node_tree.nodes: + if node.type == "BSDF_PRINCIPLED": + bsdf_node = node + assert bsdf_node is not None, "material has no Principled BSDF node to modify" + + socket_map = {} + for input in bsdf_node.inputs: + socket_map[input.name] = input + for name in ["Base Color", "Emission", "Emission Strength", "Alpha", "Specular"]: + assert name in socket_map.keys(), f"{name} not in {list(socket_map.keys())}" + + old_base_color = get_socket_value(mat.node_tree, socket_map["Base Color"]) + old_alpha = get_socket_value(mat.node_tree, socket_map["Alpha"]) + old_emission = get_socket_value(mat.node_tree, socket_map["Emission"]) + old_emission_strength = get_socket_value(mat.node_tree, socket_map["Emission Strength"]) + old_specular = get_socket_value(mat.node_tree, socket_map["Specular"]) + + # Make sure the base color of all objects is black and the opacity + # is 1, so that we are effectively just telling the shader what color + # to make the pixels. + clear_socket_input(mat.node_tree, socket_map["Base Color"]) + socket_map["Base Color"].default_value = [0, 0, 0, 1] + clear_socket_input(mat.node_tree, socket_map["Alpha"]) + socket_map["Alpha"].default_value = 1 + clear_socket_input(mat.node_tree, socket_map["Specular"]) + socket_map["Specular"].default_value = 0.0 + + old_blend_method = mat.blend_method + mat.blend_method = "OPAQUE" + + if capturing_material_alpha: + set_socket_value(mat.node_tree, socket_map["Emission"], old_alpha) + else: + set_socket_value(mat.node_tree, socket_map["Emission"], old_base_color) + clear_socket_input(mat.node_tree, socket_map["Emission Strength"]) + socket_map["Emission Strength"].default_value = 1.0 + + def undo_fn(): + mat.blend_method = old_blend_method + set_socket_value(mat.node_tree, socket_map["Base Color"], old_base_color) + set_socket_value(mat.node_tree, socket_map["Alpha"], old_alpha) + set_socket_value(mat.node_tree, socket_map["Emission"], old_emission) + set_socket_value(mat.node_tree, socket_map["Emission Strength"], old_emission_strength) + set_socket_value(mat.node_tree, socket_map["Specular"], old_specular) + + return undo_fn + + +def get_socket_value(tree, socket): + default = socket.default_value + if not isinstance(default, float): + default = list(default) + for link in tree.links: + if link.to_socket == socket: + return (link.from_socket, default) + return (None, default) + + +def clear_socket_input(tree, socket): + for link in list(tree.links): + if link.to_socket == socket: + tree.links.remove(link) + + +def set_socket_value(tree, socket, socket_and_default): + clear_socket_input(tree, socket) + old_source_socket, default = socket_and_default + if isinstance(default, float) and not isinstance(socket.default_value, float): + # Codepath for setting Emission to a previous alpha value. + socket.default_value = [default] * 3 + [1.0] + else: + socket.default_value = default + if old_source_socket is not None: + tree.links.new(old_source_socket, socket) + + +def setup_nodes(output_path, capturing_material_alpha: bool = False, basic_lighting: bool = False): + tree = bpy.context.scene.node_tree + links = tree.links + + for node in tree.nodes: + tree.nodes.remove(node) + + # Helpers to perform math on links and constants. 
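+    # Each helper adds a CompositorNodeMath node and returns its output socket,
+    # so calls can be nested, e.g. node_add(x, node_mul(y, z)).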
+ def node_op(op: str, *args, clamp=False): + node = tree.nodes.new(type="CompositorNodeMath") + node.operation = op + if clamp: + node.use_clamp = True + for i, arg in enumerate(args): + if isinstance(arg, (int, float)): + node.inputs[i].default_value = arg + else: + links.new(arg, node.inputs[i]) + return node.outputs[0] + + def node_clamp(x, maximum=1.0): + return node_op("MINIMUM", x, maximum) + + def node_mul(x, y, **kwargs): + return node_op("MULTIPLY", x, y, **kwargs) + + def node_add(x, y, **kwargs): + return node_op("ADD", x, y, **kwargs) + + def node_abs(x, **kwargs): + return node_op("ABSOLUTE", x, **kwargs) + + input_node = tree.nodes.new(type="CompositorNodeRLayers") + input_node.scene = bpy.context.scene + + input_sockets = {} + for output in input_node.outputs: + input_sockets[output.name] = output + + if capturing_material_alpha: + color_socket = input_sockets["Image"] + else: + raw_color_socket = input_sockets["Image"] + if basic_lighting: + # Compute diffuse lighting + normal_xyz = tree.nodes.new(type="CompositorNodeSeparateXYZ") + tree.links.new(input_sockets["Normal"], normal_xyz.inputs[0]) + normal_x, normal_y, normal_z = [normal_xyz.outputs[i] for i in range(3)] + dot = node_add( + node_mul(UNIFORM_LIGHT_DIRECTION[0], normal_x), + node_add( + node_mul(UNIFORM_LIGHT_DIRECTION[1], normal_y), + node_mul(UNIFORM_LIGHT_DIRECTION[2], normal_z), + ), + ) + diffuse = node_abs(dot) + # Compute ambient + diffuse lighting + brightness = node_add(BASIC_AMBIENT_COLOR, node_mul(BASIC_DIFFUSE_COLOR, diffuse)) + # Modulate the RGB channels using the total brightness. + rgba_node = tree.nodes.new(type="CompositorNodeSepRGBA") + tree.links.new(raw_color_socket, rgba_node.inputs[0]) + combine_node = tree.nodes.new(type="CompositorNodeCombRGBA") + for i in range(3): + tree.links.new(node_mul(rgba_node.outputs[i], brightness), combine_node.inputs[i]) + tree.links.new(rgba_node.outputs[3], combine_node.inputs[3]) + raw_color_socket = combine_node.outputs[0] + + # We apply sRGB here so that our fixed-point depth map and material + # alpha values are not sRGB, and so that we perform ambient+diffuse + # lighting in linear RGB space. + color_node = tree.nodes.new(type="CompositorNodeConvertColorSpace") + color_node.from_color_space = "Linear" + color_node.to_color_space = "sRGB" + tree.links.new(raw_color_socket, color_node.inputs[0]) + color_socket = color_node.outputs[0] + split_node = tree.nodes.new(type="CompositorNodeSepRGBA") + tree.links.new(color_socket, split_node.inputs[0]) + # Create separate file output nodes for every channel we care about. + # The process calling this script must decide how to recombine these + # channels, possibly into a single image. + for i, channel in enumerate("rgba") if not capturing_material_alpha else [(0, "MatAlpha")]: + output_node = tree.nodes.new(type="CompositorNodeOutputFile") + output_node.base_path = f"{output_path}_{channel}" + links.new(split_node.outputs[i], output_node.inputs[0]) + + if capturing_material_alpha: + # No need to re-write depth here. + return + + depth_out = node_clamp(node_mul(input_sockets["Depth"], 1 / MAX_DEPTH)) + output_node = tree.nodes.new(type="CompositorNodeOutputFile") + output_node.base_path = f"{output_path}_depth" + links.new(depth_out, output_node.inputs[0]) + + +def render_scene(output_path, fast_mode: bool, extract_material: bool, basic_lighting: bool): + use_workbench = bpy.context.scene.render.engine == "BLENDER_WORKBENCH" + if use_workbench: + # We must use a different engine to compute depth maps. 
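+        # (Only the depth image from this EEVEE pass is kept; the RGBA channels
+        # are re-rendered with Workbench further down and the EEVEE color
+        # outputs are discarded.)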
+ print("Switching to EEVEE for depth map computation.") + bpy.context.scene.render.engine = "BLENDER_EEVEE" + bpy.context.scene.eevee.taa_render_samples = 1 # faster, since we discard image. + if fast_mode: + if bpy.context.scene.render.engine == "BLENDER_EEVEE": + bpy.context.scene.eevee.taa_render_samples = 1 + elif bpy.context.scene.render.engine == "CYCLES": + bpy.context.scene.cycles.samples = 256 + else: + if bpy.context.scene.render.engine == "CYCLES": + # We should still impose a per-frame time limit + # so that we don't timeout completely. + bpy.context.scene.cycles.time_limit = 40 + bpy.context.view_layer.update() + bpy.context.scene.use_nodes = True + bpy.context.scene.view_layers["ViewLayer"].use_pass_z = True + if basic_lighting: + bpy.context.scene.view_layers["ViewLayer"].use_pass_normal = True + bpy.context.scene.view_settings.view_transform = "Raw" # sRGB done in graph nodes + bpy.context.scene.render.film_transparent = True + bpy.context.scene.render.resolution_x = 512 + bpy.context.scene.render.resolution_y = 512 + bpy.context.scene.render.image_settings.file_format = "PNG" + bpy.context.scene.render.image_settings.color_mode = "BW" + bpy.context.scene.render.image_settings.color_depth = "16" + bpy.context.scene.render.filepath = output_path + if extract_material: + for do_alpha in [False, True]: + undo_fn = setup_material_extraction_shaders(capturing_material_alpha=do_alpha) + setup_nodes(output_path, capturing_material_alpha=do_alpha) + bpy.ops.render.render(write_still=True) + undo_fn() + else: + setup_nodes(output_path, basic_lighting=basic_lighting) + bpy.ops.render.render(write_still=True) + + # The output images must be moved from their own sub-directories, or + # discarded if we are using workbench for the color. + for channel_name in ["r", "g", "b", "a", "depth", *(["MatAlpha"] if extract_material else [])]: + sub_dir = f"{output_path}_{channel_name}" + image_path = os.path.join(sub_dir, os.listdir(sub_dir)[0]) + name, ext = os.path.splitext(output_path) + if channel_name == "depth" or not use_workbench: + os.rename(image_path, f"{name}_{channel_name}{ext}") + else: + os.remove(image_path) + os.removedirs(sub_dir) + + if use_workbench: + # Re-render RGBA using workbench with texture mode, since this seems + # to show the most reasonable colors when lighting is broken. + bpy.context.scene.use_nodes = False + bpy.context.scene.render.engine = "BLENDER_WORKBENCH" + bpy.context.scene.render.image_settings.color_mode = "RGBA" + bpy.context.scene.render.image_settings.color_depth = "8" + bpy.context.scene.display.shading.color_type = "TEXTURE" + bpy.context.scene.display.shading.light = "FLAT" + if fast_mode: + # Single pass anti-aliasing. 
+ bpy.context.scene.display.render_aa = "FXAA" + os.remove(output_path) + bpy.ops.render.render(write_still=True) + bpy.context.scene.render.image_settings.color_mode = "BW" + bpy.context.scene.render.image_settings.color_depth = "16" + + +def scene_fov(): + x_fov = bpy.context.scene.camera.data.angle_x + y_fov = bpy.context.scene.camera.data.angle_y + width = bpy.context.scene.render.resolution_x + height = bpy.context.scene.render.resolution_y + if bpy.context.scene.camera.data.angle == x_fov: + y_fov = 2 * math.atan(math.tan(x_fov / 2) * height / width) + else: + x_fov = 2 * math.atan(math.tan(y_fov / 2) * width / height) + return x_fov, y_fov + + +def write_camera_metadata(path): + x_fov, y_fov = scene_fov() + bbox_min, bbox_max = scene_bbox() + matrix = bpy.context.scene.camera.matrix_world + with open(path, "w") as f: + json.dump( + dict( + format_version=FORMAT_VERSION, + max_depth=MAX_DEPTH, + bbox=[list(bbox_min), list(bbox_max)], + origin=list(matrix.col[3])[:3], + x_fov=x_fov, + y_fov=y_fov, + x=list(matrix.col[0])[:3], + y=list(-matrix.col[1])[:3], + z=list(-matrix.col[2])[:3], + ), + f, + ) + + +def save_rendering_dataset( + input_path: str, + output_path: str, + num_images: int, + backend: str, + light_mode: str, + camera_pose: str, + camera_dist_min: float, + camera_dist_max: float, + fast_mode: bool, + extract_material: bool, + delete_material: bool, +): + assert light_mode in ["random", "uniform", "camera", "basic"] + assert camera_pose in ["random", "z-circular", "z-circular-elevated"] + + basic_lighting = light_mode == "basic" + assert not (basic_lighting and extract_material), "cannot extract material with basic lighting" + assert not (delete_material and extract_material), "cannot extract material and delete it" + + import_model(input_path) + bpy.context.scene.render.engine = backend + normalize_scene() + if light_mode == "random": + create_random_lights() + elif light_mode == "uniform": + create_uniform_light(backend) + create_camera() + create_vertex_color_shaders() + if delete_material: + delete_all_materials() + if extract_material or basic_lighting: + create_default_materials() + if basic_lighting: + # Make sure materials are uniformly lit, so that we can light + # them in the output shader. + setup_material_extraction_shaders(capturing_material_alpha=False) + for i in range(num_images): + t = i / max(num_images - 1, 1) # same as np.linspace(0, 1, num_images) + place_camera( + t, + camera_pose_mode=camera_pose, + camera_dist_min=camera_dist_min, + camera_dist_max=camera_dist_max, + ) + if light_mode == "camera": + create_camera_light() + render_scene( + os.path.join(output_path, f"{i:05}.png"), + fast_mode=fast_mode, + extract_material=extract_material, + basic_lighting=basic_lighting, + ) + write_camera_metadata(os.path.join(output_path, f"{i:05}.json")) + with open(os.path.join(output_path, "info.json"), "w") as f: + info = dict( + backend=backend, + light_mode=light_mode, + fast_mode=fast_mode, + extract_material=extract_material, + format_version=FORMAT_VERSION, + channels=["R", "G", "B", "A", "D", *(["MatAlpha"] if extract_material else [])], + scale=0.5, # The scene is bounded by [-scale, scale]. 
+ ) + json.dump(info, f) + + +def main(): + global UNIFORM_LIGHT_DIRECTION, BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR + + try: + dash_index = sys.argv.index("--") + except ValueError as exc: + raise ValueError("arguments must be preceded by '--'") from exc + + raw_args = sys.argv[dash_index + 1 :] + parser = argparse.ArgumentParser() + parser.add_argument("--input_path", required=True, type=str) + parser.add_argument("--output_path", required=True, type=str) + parser.add_argument("--num_images", required=True, type=int) + parser.add_argument("--backend", type=str, default="BLENDER_EEVEE") + parser.add_argument("--light_mode", type=str, default="random") + parser.add_argument("--camera_pose", type=str, default="random") + parser.add_argument("--camera_dist_min", type=float, default=2.0) + parser.add_argument("--camera_dist_max", type=float, default=2.0) + parser.add_argument("--fast_mode", action="store_true") + parser.add_argument("--extract_material", action="store_true") + parser.add_argument("--delete_material", action="store_true") + + # Prevent constants from being repeated. + parser.add_argument("--uniform_light_direction", required=True, type=float, nargs="+") + parser.add_argument("--basic_ambient", required=True, type=float) + parser.add_argument("--basic_diffuse", required=True, type=float) + args = parser.parse_args(raw_args) + + UNIFORM_LIGHT_DIRECTION = args.uniform_light_direction + BASIC_AMBIENT_COLOR = args.basic_ambient + BASIC_DIFFUSE_COLOR = args.basic_diffuse + + save_rendering_dataset( + input_path=args.input_path, + output_path=args.output_path, + num_images=args.num_images, + backend=args.backend, + light_mode=args.light_mode, + camera_pose=args.camera_pose, + camera_dist_min=args.camera_dist_min, + camera_dist_max=args.camera_dist_max, + fast_mode=args.fast_mode, + extract_material=args.extract_material, + delete_material=args.delete_material, + ) + + +main() diff --git a/shap_e/rendering/blender/constants.py b/shap_e/rendering/blender/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..14aaf07a562873ab24fbb7b89cf600f176fb5c4d --- /dev/null +++ b/shap_e/rendering/blender/constants.py @@ -0,0 +1,3 @@ +UNIFORM_LIGHT_DIRECTION = [0.09387503, -0.63953443, -0.7630093] +BASIC_AMBIENT_COLOR = 0.3 +BASIC_DIFFUSE_COLOR = 0.7 diff --git a/shap_e/rendering/blender/render.py b/shap_e/rendering/blender/render.py new file mode 100644 index 0000000000000000000000000000000000000000..d54efb7cfa3e57e80f87ebe12ca40aed0eb95c99 --- /dev/null +++ b/shap_e/rendering/blender/render.py @@ -0,0 +1,147 @@ +import os +import platform +import subprocess +import tempfile +import zipfile + +import blobfile as bf +import numpy as np +from PIL import Image + +from shap_e.rendering.mesh import TriMesh + +from .constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION + +SCRIPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "blender_script.py") + + +def render_model( + model_path: str, + output_path: str, + num_images: int, + backend: str = "BLENDER_EEVEE", + light_mode: str = "random", + camera_pose: str = "random", + camera_dist_min: float = 2.0, + camera_dist_max: float = 2.0, + fast_mode: bool = False, + extract_material: bool = False, + delete_material: bool = False, + verbose: bool = False, + timeout: float = 15 * 60, +): + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_in = model_path + tmp_out = os.path.join(tmp_dir, "out") + zip_out = tmp_out + ".zip" + os.mkdir(tmp_out) + args = [] + if platform.system() == 
"Linux": + # Needed to enable Eevee backend on headless linux. + print("Needed to enable Eevee backend on headless linux.") + args = ["xvfb-run", "-a"] + args.extend( + [ + _blender_binary_path(), + "-b", + "-P", + SCRIPT_PATH, + "--", + "--input_path", + tmp_in, + "--output_path", + tmp_out, + "--num_images", + str(num_images), + "--backend", + backend, + "--light_mode", + light_mode, + "--camera_pose", + camera_pose, + "--camera_dist_min", + str(camera_dist_min), + "--camera_dist_max", + str(camera_dist_max), + "--uniform_light_direction", + *[str(x) for x in UNIFORM_LIGHT_DIRECTION], + "--basic_ambient", + str(BASIC_AMBIENT_COLOR), + "--basic_diffuse", + str(BASIC_DIFFUSE_COLOR), + ] + ) + if fast_mode: + args.append("--fast_mode") + if extract_material: + args.append("--extract_material") + if delete_material: + args.append("--delete_material") + if verbose: + print("args:", args) + subprocess.check_call(args) + else: + try: + output = subprocess.check_output(args, stderr=subprocess.STDOUT, timeout=timeout) + except subprocess.CalledProcessError as exc: + raise RuntimeError(f"{exc}: {exc.output}") from exc + if not os.path.exists(os.path.join(tmp_out, "info.json")): + if verbose: + # There is no output available, since it was + # logged directly to stdout/stderr. + raise RuntimeError(f"render failed: output file missing") + else: + raise RuntimeError(f"render failed: output file missing. Output: {output}") + _combine_rgba(tmp_out) + with zipfile.ZipFile(zip_out, mode="w") as zf: + for name in os.listdir(tmp_out): + zf.write(os.path.join(tmp_out, name), name) + bf.copy(zip_out, output_path, overwrite=True) + + +def render_mesh( + mesh: TriMesh, + output_path: str, + num_images: int, + backend: str = "BLENDER_EEVEE", + **kwargs, +): + if mesh.has_vertex_colors() and backend not in ["BLENDER_EEVEE", "CYCLES"]: + raise ValueError(f"backend does not support vertex colors: {backend}") + + with tempfile.TemporaryDirectory() as tmp_dir: + ply_path = os.path.join(tmp_dir, "out.ply") + with open(ply_path, "wb") as f: + mesh.write_ply(f) + render_model( + ply_path, output_path=output_path, num_images=num_images, backend=backend, **kwargs + ) + + +def _combine_rgba(out_dir: str): + i = 0 + while True: + paths = [os.path.join(out_dir, f"{i:05}_{ch}.png") for ch in "rgba"] + if not os.path.exists(paths[0]): + break + joined = np.stack( + [(np.array(Image.open(path)) >> 8).astype(np.uint8) for path in paths], axis=-1 + ) + Image.fromarray(joined).save(os.path.join(out_dir, f"{i:05}.png")) + for path in paths: + os.remove(path) + i += 1 + + +def _blender_binary_path() -> str: + path = os.getenv("BLENDER_PATH", None) + if path is not None: + return path + + if os.path.exists("/Applications/Blender.app/Contents/MacOS/Blender"): + return "/Applications/Blender.app/Contents/MacOS/Blender" + + raise EnvironmentError( + "To render 3D models, install Blender version 3.3.1 or higher and " + "set the environment variable `BLENDER_PATH` to the path of the Blender executable." 
+ ) diff --git a/shap_e/rendering/blender/view_data.py b/shap_e/rendering/blender/view_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2f726ec40f8b5fbf713af145282c9d7475dfb2ea --- /dev/null +++ b/shap_e/rendering/blender/view_data.py @@ -0,0 +1,84 @@ +import itertools +import json +import zipfile +from typing import BinaryIO, List, Tuple + +import numpy as np +from PIL import Image + +from shap_e.rendering.view_data import Camera, ProjectiveCamera, ViewData + + +class BlenderViewData(ViewData): + """ + Interact with a dataset zipfile exported by view_data.py. + """ + + def __init__(self, f_obj: BinaryIO): + self.zipfile = zipfile.ZipFile(f_obj, mode="r") + self.infos = [] + with self.zipfile.open("info.json", "r") as f: + self.info = json.load(f) + self.channels = list(self.info.get("channels", "RGBAD")) + assert set("RGBA").issubset( + set(self.channels) + ), "The blender output should at least have RGBA images." + names = set(x.filename for x in self.zipfile.infolist()) + for i in itertools.count(): + name = f"{i:05}.json" + if name not in names: + break + with self.zipfile.open(name, "r") as f: + self.infos.append(json.load(f)) + + @property + def num_views(self) -> int: + return len(self.infos) + + @property + def channel_names(self) -> List[str]: + return list(self.channels) + + def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]: + for ch in channels: + if ch not in self.channel_names: + raise ValueError(f"unsupported channel: {ch}") + + # Gather (a superset of) the requested channels. + channel_map = {} + if any(x in channels for x in "RGBA"): + with self.zipfile.open(f"{index:05}.png", "r") as f: + rgba = np.array(Image.open(f)).astype(np.float32) / 255.0 + channel_map.update(zip("RGBA", rgba.transpose([2, 0, 1]))) + if "D" in channels: + with self.zipfile.open(f"{index:05}_depth.png", "r") as f: + # Decode a 16-bit fixed-point number. + fp = np.array(Image.open(f)) + inf_dist = fp == 0xFFFF + channel_map["D"] = np.where( + inf_dist, + np.inf, + self.infos[index]["max_depth"] * (fp.astype(np.float32) / 65536), + ) + if "MatAlpha" in channels: + with self.zipfile.open(f"{index:05}_MatAlpha.png", "r") as f: + channel_map["MatAlpha"] = np.array(Image.open(f)).astype(np.float32) / 65536 + + # The order of channels is user-specified. + combined = np.stack([channel_map[k] for k in channels], axis=-1) + + h, w, _ = combined.shape + return self.camera(index, w, h), combined + + def camera(self, index: int, width: int, height: int) -> ProjectiveCamera: + info = self.infos[index] + return ProjectiveCamera( + origin=np.array(info["origin"], dtype=np.float32), + x=np.array(info["x"], dtype=np.float32), + y=np.array(info["y"], dtype=np.float32), + z=np.array(info["z"], dtype=np.float32), + width=width, + height=height, + x_fov=info["x_fov"], + y_fov=info["y_fov"], + ) diff --git a/shap_e/rendering/mc.py b/shap_e/rendering/mc.py new file mode 100644 index 0000000000000000000000000000000000000000..128070755e0af76d657bbc7e137557fdddec45e1 --- /dev/null +++ b/shap_e/rendering/mc.py @@ -0,0 +1,253 @@ +from dataclasses import dataclass +from functools import lru_cache +from typing import Tuple + +import torch + +from ._mc_table import MC_TABLE +from .torch_mesh import TorchMesh + + +def marching_cubes( + field: torch.Tensor, + min_point: torch.Tensor, + size: torch.Tensor, +) -> TorchMesh: + """ + For a signed distance field, produce a mesh using marching cubes. 
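+
+    The surface is extracted at the zero level set of the field: each output
+    vertex lies on a grid edge whose two corner samples change sign, and is
+    placed on that edge by linear interpolation between the two samples.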
+ + :param field: a 3D tensor of field values, where negative values correspond + to the outside of the shape. The dimensions correspond to the + x, y, and z directions, respectively. + :param min_point: a tensor of shape [3] containing the point corresponding + to (0, 0, 0) in the field. + :param size: a tensor of shape [3] containing the per-axis distance from the + (0, 0, 0) field corner and the (-1, -1, -1) field corner. + """ + assert len(field.shape) == 3, "input must be a 3D scalar field" + dev = field.device + + grid_size = field.shape + grid_size_tensor = torch.tensor(grid_size).to(size) + lut = _lookup_table(dev) + + # Create bitmasks between 0 and 255 (inclusive) indicating the state + # of the eight corners of each cube. + bitmasks = (field > 0).to(torch.uint8) + bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1) + bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2) + bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4) + + # Compute corner coordinates across the entire grid. + corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype) + corner_coords[range(grid_size[0]), :, :, 0] = torch.arange( + grid_size[0], device=dev, dtype=field.dtype + )[:, None, None] + corner_coords[:, range(grid_size[1]), :, 1] = torch.arange( + grid_size[1], device=dev, dtype=field.dtype + )[:, None] + corner_coords[:, :, range(grid_size[2]), 2] = torch.arange( + grid_size[2], device=dev, dtype=field.dtype + ) + + # Compute all vertices across all edges in the grid, even though we will + # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices. + # These are all midpoints, and don't account for interpolation (which is + # done later based on the used edge midpoints). + edge_midpoints = torch.cat( + [ + ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3), + ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3), + ], + dim=0, + ) + + # Create a flat array of [X, Y, Z] indices for each cube. + cube_indices = torch.zeros( + grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long + ) + cube_indices[range(grid_size[0] - 1), :, :, 0] = torch.arange(grid_size[0] - 1, device=dev)[ + :, None, None + ] + cube_indices[:, range(grid_size[1] - 1), :, 1] = torch.arange(grid_size[1] - 1, device=dev)[ + :, None + ] + cube_indices[:, :, range(grid_size[2] - 1), 2] = torch.arange(grid_size[2] - 1, device=dev) + flat_cube_indices = cube_indices.reshape(-1, 3) + + # Create a flat array mapping each cube to 12 global edge indices. + edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size) + + # Apply the LUT to figure out the triangles. + flat_bitmasks = bitmasks.reshape( + -1 + ).long() # must cast to long for indexing to believe this not a mask + local_tris = lut.cases[flat_bitmasks] + local_masks = lut.masks[flat_bitmasks] + # Compute the global edge indices for the triangles. + global_tris = torch.gather( + edge_indices, 1, local_tris.reshape(local_tris.shape[0], -1) + ).reshape(local_tris.shape) + # Select the used triangles for each cube. + selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)] + + # Now we have a bunch of indices into the full list of possible vertices, + # but we want to reduce this list to only the used vertices. 
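+    # For example (illustrative numbers): if the surviving triangles reference
+    # global edge ids {7, 42, 99}, torch.unique returns those ids in sorted
+    # order and they are remapped to the compact ids {0, 1, 2}; only those
+    # edge midpoints are kept and interpolated below.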
+ used_vertex_indices = torch.unique(selected_tris.view(-1)) + used_edge_midpoints = edge_midpoints[used_vertex_indices] + old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long) + old_index_to_new_index[used_vertex_indices] = torch.arange( + len(used_vertex_indices), device=dev, dtype=torch.long + ) + + # Rewrite the triangles to use the new indices + selected_tris = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape( + selected_tris.shape + ) + + # Compute the actual interpolated coordinates corresponding to edge midpoints. + v1 = torch.floor(used_edge_midpoints).to(torch.long) + v2 = torch.ceil(used_edge_midpoints).to(torch.long) + s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]] + s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]] + p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point + p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point + # The signs of s1 and s2 should be different. We want to find + # t such that t*s2 + (1-t)*s1 = 0. + t = (s1 / (s1 - s2))[:, None] + verts = t * p2 + (1 - t) * p1 + + return TorchMesh(verts=verts, faces=selected_tris) + + +def _create_flat_edge_indices( + flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int] +) -> torch.Tensor: + num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] + y_offset = num_xs + num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] + z_offset = num_xs + num_ys + return torch.stack( + [ + # Edges spanning x-axis. + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2], + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] + + (flat_cube_indices[:, 1] + 1) * grid_size[2] + + flat_cube_indices[:, 2] + + 1, + # Edges spanning y-axis. + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + ), + ( + y_offset + + flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + ( + y_offset + + (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] + + flat_cube_indices[:, 1] * grid_size[2] + + flat_cube_indices[:, 2] + + 1 + ), + # Edges spanning z-axis. 
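+            # These mirror the x/y cases above: z-spanning edges are stored after
+            # all x- and y-spanning edges (hence z_offset). With Y = grid_size[1]
+            # and Z = grid_size[2], the z-edge touching cube (x, y, z) at corner
+            # offset (dx, dy) sits at flat index
+            #   (x + dx) * Y * (Z - 1) + (y + dy) * (Z - 1) + z.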
+ ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + flat_cube_indices[:, 1] * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ( + z_offset + + (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) + + (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) + + flat_cube_indices[:, 2] + ), + ], + dim=-1, + ) + + +@dataclass +class McLookupTable: + # Coordinates in triangles are represented as edge indices from 0-12 + # Here is an MC cell with both corner and edge indices marked. + # 6 + ---------- 3 ----------+ 7 + # /| /| + # 6 | 7 | + # / | / | + # 4 +--------- 2 ------------+ 5 | + # | 10 | | + # | | | 11 + # | | | | + # 8 | 2 9 | 3 + # | +--------- 1 --------|---+ + # | / | / + # | 4 | 5 + # |/ |/ + # +---------- 0 -----------+ + # 0 1 + cases: torch.Tensor # [256 x 5 x 3] long tensor + masks: torch.Tensor # [256 x 5] bool tensor + + +@lru_cache(maxsize=9) # if there's more than 8 GPUs and a CPU, don't bother caching +def _lookup_table(device: torch.device) -> McLookupTable: + cases = torch.zeros(256, 5, 3, device=device, dtype=torch.long) + masks = torch.zeros(256, 5, device=device, dtype=torch.bool) + + edge_to_index = { + (0, 1): 0, + (2, 3): 1, + (4, 5): 2, + (6, 7): 3, + (0, 2): 4, + (1, 3): 5, + (4, 6): 6, + (5, 7): 7, + (0, 4): 8, + (1, 5): 9, + (2, 6): 10, + (3, 7): 11, + } + + for i, case in enumerate(MC_TABLE): + for j, tri in enumerate(case): + for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])): + cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)] + masks[i, j] = True + return McLookupTable(cases=cases, masks=masks) diff --git a/shap_e/rendering/mesh.py b/shap_e/rendering/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..f120ddf26123f5e814403708221640f233609a9d --- /dev/null +++ b/shap_e/rendering/mesh.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass, field +from typing import BinaryIO, Dict, Optional, Union + +import blobfile as bf +import numpy as np + +from .ply_util import write_ply + + +@dataclass +class TriMesh: + """ + A 3D triangle mesh with optional data at the vertices and faces. + """ + + # [N x 3] array of vertex coordinates. + verts: np.ndarray + + # [M x 3] array of triangles, pointing to indices in verts. + faces: np.ndarray + + # [P x 3] array of normal vectors per face. + normals: Optional[np.ndarray] = None + + # Extra data per vertex and face. + vertex_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict) + face_channels: Optional[Dict[str, np.ndarray]] = field(default_factory=dict) + + @classmethod + def load(cls, f: Union[str, BinaryIO]) -> "TriMesh": + """ + Load the mesh from a .npz file. 
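+
+        Arrays saved under a `v_` prefix are restored into `vertex_channels`,
+        arrays under an `f_` prefix into `face_channels`, and `normals` is
+        loaded if present.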
+ """ + if isinstance(f, str): + with bf.BlobFile(f, "rb") as reader: + return cls.load(reader) + else: + obj = np.load(f) + keys = list(obj.keys()) + verts = obj["verts"] + faces = obj["faces"] + normals = obj["normals"] if "normals" in keys else None + vertex_channels = {} + face_channels = {} + for key in keys: + if key.startswith("v_"): + vertex_channels[key[2:]] = obj[key] + elif key.startswith("f_"): + face_channels[key[2:]] = obj[key] + return cls( + verts=verts, + faces=faces, + normals=normals, + vertex_channels=vertex_channels, + face_channels=face_channels, + ) + + def save(self, f: Union[str, BinaryIO]): + """ + Save the mesh to a .npz file. + """ + if isinstance(f, str): + with bf.BlobFile(f, "wb") as writer: + self.save(writer) + else: + obj_dict = dict(verts=self.verts, faces=self.faces) + if self.normals is not None: + obj_dict["normals"] = self.normals + for k, v in self.vertex_channels.items(): + obj_dict[f"v_{k}"] = v + for k, v in self.face_channels.items(): + obj_dict[f"f_{k}"] = v + np.savez(f, **obj_dict) + + def has_vertex_colors(self) -> bool: + return self.vertex_channels is not None and all(x in self.vertex_channels for x in "RGB") + + def write_ply(self, raw_f: BinaryIO): + write_ply( + raw_f, + coords=self.verts, + rgb=( + np.stack([self.vertex_channels[x] for x in "RGB"], axis=1) + if self.has_vertex_colors() + else None + ), + faces=self.faces, + ) + + def write_obj(self, raw_f: BinaryIO): + if self.has_vertex_colors(): + vertex_colors = np.stack([self.vertex_channels[x] for x in "RGB"], axis=1) + vertices = [ + "{} {} {} {} {} {}".format(*coord, *color) + for coord, color in zip(self.verts.tolist(), vertex_colors.tolist()) + ] + else: + vertices = ["{} {} {}".format(*coord) for coord in self.verts.tolist()] + + faces = [ + "f {} {} {}".format(str(tri[0] + 1), str(tri[1] + 1), str(tri[2] + 1)) + for tri in self.faces.tolist() + ] + + combined_data = ["v " + vertex for vertex in vertices] + faces + + raw_f.writelines("\n".join(combined_data)) diff --git a/shap_e/rendering/ply_util.py b/shap_e/rendering/ply_util.py new file mode 100644 index 0000000000000000000000000000000000000000..0500b64e783b77d71134e8cd419d5905c0019d54 --- /dev/null +++ b/shap_e/rendering/ply_util.py @@ -0,0 +1,58 @@ +import struct +from typing import BinaryIO, Optional + +import numpy as np + +from shap_e.util.io import buffered_writer + + +def write_ply( + raw_f: BinaryIO, + coords: np.ndarray, + rgb: Optional[np.ndarray] = None, + faces: Optional[np.ndarray] = None, +): + """ + Write a PLY file for a mesh or a point cloud. + + :param coords: an [N x 3] array of floating point coordinates. + :param rgb: an [N x 3] array of vertex colors, in the range [0.0, 1.0]. + :param faces: an [N x 3] array of triangles encoded as integer indices. 
+ """ + with buffered_writer(raw_f) as f: + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(bytes(f"element vertex {len(coords)}\n", "ascii")) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + if rgb is not None: + f.write(b"property uchar red\n") + f.write(b"property uchar green\n") + f.write(b"property uchar blue\n") + if faces is not None: + f.write(bytes(f"element face {len(faces)}\n", "ascii")) + f.write(b"property list uchar int vertex_index\n") + f.write(b"end_header\n") + + if rgb is not None: + rgb = (rgb * 255.499).round().astype(int) + vertices = [ + (*coord, *rgb) + for coord, rgb in zip( + coords.tolist(), + rgb.tolist(), + ) + ] + format = struct.Struct("<3f3B") + for item in vertices: + f.write(format.pack(*item)) + else: + format = struct.Struct("<3f") + for vertex in coords.tolist(): + f.write(format.pack(*vertex)) + + if faces is not None: + format = struct.Struct(" "PointCloud": + """ + Construct a point cloud from the given view data. + + The data must have a depth channel. All other channels will be stored + in the `channels` attribute of the result. + + Pixels in the rendered views are not converted into points in the cloud + if they have infinite depth or less than 1.0 alpha. + """ + channel_names = vd.channel_names + if "D" not in channel_names: + raise ValueError(f"view data must have depth channel") + depth_index = channel_names.index("D") + + all_coords = [] + all_channels = defaultdict(list) + + if num_views is None: + num_views = vd.num_views + for i in range(num_views): + camera, channel_values = vd.load_view(i, channel_names) + flat_values = channel_values.reshape([-1, len(channel_names)]) + + # Create an array of integer (x, y) image coordinates for Camera methods. + image_coords = camera.image_coords() + + # Select subset of pixels that have meaningful depth/color. + image_mask = np.isfinite(flat_values[:, depth_index]) + if "A" in channel_names: + image_mask = image_mask & (flat_values[:, channel_names.index("A")] >= 1 - 1e-5) + image_coords = image_coords[image_mask] + flat_values = flat_values[image_mask] + + # Use the depth and camera information to compute the coordinates + # corresponding to every visible pixel. + camera_rays = camera.camera_rays(image_coords) + camera_origins = camera_rays[:, 0] + camera_directions = camera_rays[:, 1] + depth_dirs = camera.depth_directions(image_coords) + ray_scales = flat_values[:, depth_index] / np.sum( + camera_directions * depth_dirs, axis=-1 + ) + coords = camera_origins + camera_directions * ray_scales[:, None] + + all_coords.append(coords) + for j, name in enumerate(channel_names): + if name != "D": + all_channels[name].append(flat_values[:, j]) + + if len(all_coords) == 0: + return cls(coords=np.zeros([0, 3], dtype=np.float32), channels={}) + + return cls( + coords=np.concatenate(all_coords, axis=0), + channels={k: np.concatenate(v, axis=0) for k, v in all_channels.items()}, + ) + + @classmethod + def load(cls, f: Union[str, BinaryIO]) -> "PointCloud": + """ + Load the point cloud from a .npz file. + """ + if isinstance(f, str): + with bf.BlobFile(f, "rb") as reader: + return cls.load(reader) + else: + obj = np.load(f) + keys = list(obj.keys()) + return PointCloud( + coords=obj["coords"], + channels={k: obj[k] for k in keys if k != "coords"}, + ) + + def save(self, f: Union[str, BinaryIO]): + """ + Save the point cloud to a .npz file. 
+ """ + if isinstance(f, str): + with bf.BlobFile(f, "wb") as writer: + self.save(writer) + else: + np.savez(f, coords=self.coords, **self.channels) + + def write_ply(self, raw_f: BinaryIO): + write_ply( + raw_f, + coords=self.coords, + rgb=( + np.stack([self.channels[x] for x in "RGB"], axis=1) + if all(x in self.channels for x in "RGB") + else None + ), + ) + + def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud": + """ + Sample a random subset of this PointCloud. + + :param num_points: maximum number of points to sample. + :param subsample_kwargs: arguments to self.subsample(). + :return: a reduced PointCloud, or self if num_points is not less than + the current number of points. + """ + if len(self.coords) <= num_points: + return self + indices = np.random.choice(len(self.coords), size=(num_points,), replace=False) + return self.subsample(indices, **subsample_kwargs) + + def farthest_point_sample( + self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs + ) -> "PointCloud": + """ + Sample a subset of the point cloud that is evenly distributed in space. + + First, a random point is selected. Then each successive point is chosen + such that it is furthest from the currently selected points. + + The time complexity of this operation is O(NM), where N is the original + number of points and M is the reduced number. Therefore, performance + can be improved by randomly subsampling points with random_sample() + before running farthest_point_sample(). + + :param num_points: maximum number of points to sample. + :param init_idx: if specified, the first point to sample. + :param subsample_kwargs: arguments to self.subsample(). + :return: a reduced PointCloud, or self if num_points is not less than + the current number of points. + """ + if len(self.coords) <= num_points: + return self + init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx + indices = np.zeros([num_points], dtype=np.int64) + indices[0] = init_idx + sq_norms = np.sum(self.coords**2, axis=-1) + + def compute_dists(idx: int): + # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B). + return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx]) + + cur_dists = compute_dists(init_idx) + for i in range(1, num_points): + idx = np.argmax(cur_dists) + indices[i] = idx + + # Without this line, we may duplicate an index more than once if + # there are duplicate points, due to rounding errors. + cur_dists[idx] = -1 + + cur_dists = np.minimum(cur_dists, compute_dists(idx)) + + return self.subsample(indices, **subsample_kwargs) + + def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud": + if not average_neighbors: + return PointCloud( + coords=self.coords[indices], + channels={k: v[indices] for k, v in self.channels.items()}, + ) + + new_coords = self.coords[indices] + neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords) + + # Make sure every point points to itself, which might not + # be the case if points are duplicated or there is rounding + # error. 
+ neighbor_indices[indices] = np.arange(len(indices)) + + new_channels = {} + for k, v in self.channels.items(): + v_sum = np.zeros_like(v[: len(indices)]) + v_count = np.zeros_like(v[: len(indices)]) + np.add.at(v_sum, neighbor_indices, v) + np.add.at(v_count, neighbor_indices, 1) + new_channels[k] = v_sum / v_count + return PointCloud(coords=new_coords, channels=new_channels) + + def select_channels(self, channel_names: List[str]) -> np.ndarray: + data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1) + return data + + def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray: + """ + For each point in another set of points, compute the point in this + pointcloud which is closest. + + :param points: an [N x 3] array of points. + :param batch_size: the number of neighbor distances to compute at once. + Smaller values save memory, while larger values may + make the computation faster. + :return: an [N] array of indices into self.coords. + """ + norms = np.sum(self.coords**2, axis=-1) + all_indices = [] + for i in range(0, len(points), batch_size): + batch = points[i : i + batch_size] + dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T) + all_indices.append(np.argmin(dists, axis=-1)) + return np.concatenate(all_indices, axis=0) + + def combine(self, other: "PointCloud") -> "PointCloud": + assert self.channels.keys() == other.channels.keys() + return PointCloud( + coords=np.concatenate([self.coords, other.coords], axis=0), + channels={ + k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items() + }, + ) diff --git a/shap_e/rendering/pytorch3d_util.py b/shap_e/rendering/pytorch3d_util.py new file mode 100644 index 0000000000000000000000000000000000000000..772036a5bddbe4f6a7557846109aeb46b83942dd --- /dev/null +++ b/shap_e/rendering/pytorch3d_util.py @@ -0,0 +1,248 @@ +import copy +import inspect +from typing import Any, Callable, List, Sequence, Tuple, Union + +import numpy as np +import torch +from pytorch3d.renderer import ( + BlendParams, + DirectionalLights, + FoVPerspectiveCameras, + MeshRasterizer, + MeshRenderer, + RasterizationSettings, + SoftPhongShader, + TexturesVertex, +) +from pytorch3d.renderer.utils import TensorProperties +from pytorch3d.structures import Meshes + +from shap_e.models.nn.checkpoint import checkpoint + +from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION +from .torch_mesh import TorchMesh +from .view_data import ProjectiveCamera + +# Using a lower value like 1e-4 seems to result in weird issues +# for our high-poly meshes. +DEFAULT_RENDER_SIGMA = 1e-5 + +DEFAULT_RENDER_GAMMA = 1e-4 + + +def render_images( + image_size: int, + meshes: Meshes, + cameras: Any, + lights: Any, + sigma: float = DEFAULT_RENDER_SIGMA, + gamma: float = DEFAULT_RENDER_GAMMA, + max_faces_per_bin=100000, + faces_per_pixel=50, + bin_size=None, + use_checkpoint: bool = False, +) -> torch.Tensor: + if use_checkpoint: + # Decompose all of our arguments into a bunch of tensor lists + # so that autograd can keep track of what the op depends on. 
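+        # The checkpoint op only sees plain tensors, so the Meshes, lights, and
+        # cameras are flattened into tensor lists here and rebuilt inside
+        # ckpt_fn when the forward pass is recomputed during backprop.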
+ verts_list = meshes.verts_list() + faces_list = meshes.faces_list() + assert isinstance(meshes.textures, TexturesVertex) + assert isinstance(lights, BidirectionalLights) + textures = meshes.textures.verts_features_padded() + light_vecs, light_fn = _deconstruct_tensor_props(lights) + camera_vecs, camera_fn = _deconstruct_tensor_props(cameras) + + def ckpt_fn( + *args: torch.Tensor, + num_verts=len(verts_list), + num_light_vecs=len(light_vecs), + num_camera_vecs=len(camera_vecs), + light_fn=light_fn, + camera_fn=camera_fn, + faces_list=faces_list + ): + args = list(args) + verts_list = args[:num_verts] + del args[:num_verts] + light_vecs = args[:num_light_vecs] + del args[:num_light_vecs] + camera_vecs = args[:num_camera_vecs] + del args[:num_camera_vecs] + textures = args.pop(0) + + meshes = Meshes(verts=verts_list, faces=faces_list, textures=TexturesVertex(textures)) + lights = light_fn(light_vecs) + cameras = camera_fn(camera_vecs) + return render_images( + image_size=image_size, + meshes=meshes, + cameras=cameras, + lights=lights, + sigma=sigma, + gamma=gamma, + max_faces_per_bin=max_faces_per_bin, + faces_per_pixel=faces_per_pixel, + bin_size=bin_size, + use_checkpoint=False, + ) + + result = checkpoint(ckpt_fn, (*verts_list, *light_vecs, *camera_vecs, textures), (), True) + else: + raster_settings_soft = RasterizationSettings( + image_size=image_size, + blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, + faces_per_pixel=faces_per_pixel, + max_faces_per_bin=max_faces_per_bin, + bin_size=bin_size, + perspective_correct=False, + ) + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings_soft), + shader=SoftPhongShader( + device=meshes.device, + cameras=cameras, + lights=lights, + blend_params=BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0)), + ), + ) + result = renderer(meshes) + + return result + + +def _deconstruct_tensor_props( + props: TensorProperties, +) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], TensorProperties]]: + vecs = [] + names = [] + other_props = {} + for k in dir(props): + if k.startswith("__"): + continue + v = getattr(props, k) + if inspect.ismethod(v): + continue + if torch.is_tensor(v): + vecs.append(v) + names.append(k) + else: + other_props[k] = v + + def recreate_fn(vecs_arg): + other = type(props)(device=props.device) + for k, v in other_props.items(): + setattr(other, k, copy.deepcopy(v)) + for name, vec in zip(names, vecs_arg): + setattr(other, name, vec) + return other + + return vecs, recreate_fn + + + +def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes: + meshes = Meshes( + verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes] + ) + rgbs = [] + for mesh in raw_meshes: + if mesh.vertex_channels and all(k in mesh.vertex_channels for k in "RGB"): + rgbs.append(torch.stack([mesh.vertex_channels[k] for k in "RGB"], axis=-1)) + else: + rgbs.append( + torch.ones( + len(mesh.verts) * default_brightness, + 3, + device=mesh.verts.device, + dtype=mesh.verts.dtype, + ) + ) + meshes.textures = TexturesVertex(verts_features=rgbs) + return meshes + + +def convert_cameras( + cameras: Sequence[ProjectiveCamera], device: torch.device +) -> FoVPerspectiveCameras: + Rs = [] + Ts = [] + for camera in cameras: + assert ( + camera.width == camera.height and camera.x_fov == camera.y_fov + ), "viewports must be square" + assert camera.x_fov == cameras[0].x_fov, "all cameras must have same field-of-view" + R = np.stack([-camera.x, -camera.y, 
camera.z], axis=0).T + T = -R.T @ camera.origin + Rs.append(R) + Ts.append(T) + return FoVPerspectiveCameras( + R=np.stack(Rs, axis=0), + T=np.stack(Ts, axis=0), + fov=cameras[0].x_fov, + degrees=False, + device=device, + ) + + +def convert_cameras_torch( + origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float +) -> FoVPerspectiveCameras: + Rs = [] + Ts = [] + for origin, x, y, z in zip(origins, xs, ys, zs): + R = torch.stack([-x, -y, z], axis=0).T + T = -R.T @ origin + Rs.append(R) + Ts.append(T) + return FoVPerspectiveCameras( + R=torch.stack(Rs, dim=0), + T=torch.stack(Ts, dim=0), + fov=fov, + degrees=False, + device=origins.device, + ) + + +def blender_uniform_lights( + batch_size: int, + device: torch.device, + ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR, + diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR, + specular_color: Union[float, Tuple[float]] = 0.0, +) -> "BidirectionalLights": + """ + Create a light that attempts to match the light used by the Blender + renderer when run with `--light_mode basic`. + """ + if isinstance(ambient_color, float): + ambient_color = (ambient_color,) * 3 + if isinstance(diffuse_color, float): + diffuse_color = (diffuse_color,) * 3 + if isinstance(specular_color, float): + specular_color = (specular_color,) * 3 + return BidirectionalLights( + ambient_color=(ambient_color,) * batch_size, + diffuse_color=(diffuse_color,) * batch_size, + specular_color=(specular_color,) * batch_size, + direction=(UNIFORM_LIGHT_DIRECTION,) * batch_size, + device=device, + ) + + +class BidirectionalLights(DirectionalLights): + """ + Adapted from here, but effectively shines the light in both positive and negative directions: + https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159 + """ + + def diffuse(self, normals, points=None) -> torch.Tensor: + return torch.maximum( + super().diffuse(normals, points=points), super().diffuse(-normals, points=points) + ) + + def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: + return torch.maximum( + super().specular(normals, points, camera_position, shininess), + super().specular(-normals, points, camera_position, shininess), + ) diff --git a/shap_e/rendering/raycast/__init__.py b/shap_e/rendering/raycast/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/rendering/raycast/_utils.py b/shap_e/rendering/raycast/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61661861fe756a8435c44b23a65df15c1c6e3018 --- /dev/null +++ b/shap_e/rendering/raycast/_utils.py @@ -0,0 +1,16 @@ +import torch + + +def normalize(v: torch.Tensor) -> torch.Tensor: + return v / torch.linalg.norm(v, dim=-1, keepdim=True) + + +def cross_product(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor: + return torch.stack( + [ + v1[..., 1] * v2[..., 2] - v2[..., 1] * v1[..., 2], + -(v1[..., 0] * v2[..., 2] - v2[..., 0] * v1[..., 2]), + v1[..., 0] * v2[..., 1] - v2[..., 0] * v1[..., 1], + ], + dim=-1, + ) diff --git a/shap_e/rendering/raycast/cast.py b/shap_e/rendering/raycast/cast.py new file mode 100644 index 0000000000000000000000000000000000000000..7ef357128e3930e1f5c8d5b0f91d011faf9b2d1f --- /dev/null +++ b/shap_e/rendering/raycast/cast.py @@ -0,0 +1,132 @@ +from typing import Iterator, Optional, Tuple + +import numpy as np +import torch + +from shap_e.rendering.view_data import 
ProjectiveCamera + +from ._utils import cross_product +from .types import RayCollisions, Rays, TriMesh + + +def cast_camera( + camera: ProjectiveCamera, + mesh: TriMesh, + ray_batch_size: Optional[int] = None, + checkpoint: Optional[bool] = None, +) -> Iterator[RayCollisions]: + pixel_indices = np.arange(camera.width * camera.height) + image_coords = np.stack([pixel_indices % camera.width, pixel_indices // camera.width], axis=1) + rays = camera.camera_rays(image_coords) + batch_size = ray_batch_size or len(rays) + checkpoint = checkpoint if checkpoint is not None else batch_size < len(rays) + for i in range(0, len(rays), batch_size): + sub_rays = rays[i : i + batch_size] + origins = torch.from_numpy(sub_rays[:, 0]).to(mesh.vertices) + directions = torch.from_numpy(sub_rays[:, 1]).to(mesh.vertices) + yield cast_rays(Rays(origins=origins, directions=directions), mesh, checkpoint=checkpoint) + + +def cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> RayCollisions: + """ + Cast a batch of rays onto a mesh. + """ + if checkpoint: + collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply( + rays.origins, rays.directions, mesh.faces, mesh.vertices + ) + return RayCollisions( + collides=collides, + ray_dists=ray_dists, + tri_indices=tri_indices, + barycentric=barycentric, + normals=normals, + ) + + # https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98 + normals = mesh.normals() # [N x 3] + directions = rays.directions # [M x 3] + collides = (directions @ normals.T).abs() > 1e-8 # [N x M] + + tris = mesh.vertices[mesh.faces] # [N x 3 x 3] + v1 = tris[:, 1] - tris[:, 0] + v2 = tris[:, 2] - tris[:, 0] + + cross1 = cross_product(directions[:, None], v2[None]) # [N x M x 3] + det = torch.sum(cross1 * v1[None], dim=-1) # [N x M] + collides = torch.logical_and(collides, det.abs() > 1e-8) + + invDet = 1 / det # [N x M] + o = rays.origins[:, None] - tris[None, :, 0] # [N x M x 3] + bary1 = invDet * torch.sum(o * cross1, dim=-1) # [N x M] + collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1)) + + cross2 = cross_product(o, v1[None]) # [N x M x 3] + bary2 = invDet * torch.sum(directions[:, None] * cross2, dim=-1) # [N x M] + collides = torch.logical_and(collides, torch.logical_and(bary2 >= 0, bary2 <= 1)) + + bary0 = 1 - (bary1 + bary2) + + # Make sure this is in the positive part of the ray. 
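+    # `scale` below is the ray parameter t of this cross-product intersection
+    # test: the hit point is origin + t * direction, so only t > 0 counts as a
+    # collision in front of the ray origin.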
+ scale = invDet * torch.sum(v2 * cross2, dim=-1) + collides = torch.logical_and(collides, scale > 0) + + # Select the nearest collision + ray_dists, tri_indices = torch.min( + torch.where(collides, scale, torch.tensor(torch.inf).to(scale)), dim=-1 + ) # [N] + nearest_bary = torch.stack( + [ + bary0[range(len(tri_indices)), tri_indices], + bary1[range(len(tri_indices)), tri_indices], + bary2[range(len(tri_indices)), tri_indices], + ], + dim=-1, + ) + + return RayCollisions( + collides=torch.any(collides, dim=-1), + ray_dists=ray_dists, + tri_indices=tri_indices, + barycentric=nearest_bary, + normals=normals[tri_indices], + ) + + +class RayCollisionFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, origins, directions, faces, vertices + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ctx.save_for_backward(origins, directions, faces, vertices) + with torch.no_grad(): + res = cast_rays( + Rays(origins=origins, directions=directions), + TriMesh(faces=faces, vertices=vertices), + checkpoint=False, + ) + return (res.collides, res.ray_dists, res.tri_indices, res.barycentric, res.normals) + + @staticmethod + def backward( + ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad + ): + origins, directions, faces, vertices = ctx.input_tensors + + origins = origins.detach().requires_grad_(True) + directions = directions.detach().requires_grad_(True) + vertices = vertices.detach().requires_grad_(True) + + with torch.enable_grad(): + outputs = cast_rays( + Rays(origins=origins, directions=directions), + TriMesh(faces=faces, vertices=vertices), + checkpoint=False, + ) + + origins_grad, directions_grad, vertices_grad = torch.autograd.grad( + (outputs.ray_dists, outputs.barycentric, outputs.normals), + (origins, directions, vertices), + (ray_dists_grad, barycentric_grad, normals_grad), + ) + return (origins_grad, directions_grad, None, vertices_grad) diff --git a/shap_e/rendering/raycast/render.py b/shap_e/rendering/raycast/render.py new file mode 100644 index 0000000000000000000000000000000000000000..d99461c6b5f92de706b4797e139a9cc3dc7df6db --- /dev/null +++ b/shap_e/rendering/raycast/render.py @@ -0,0 +1,57 @@ +from typing import Optional, Sequence + +import torch + +from shap_e.rendering.blender.constants import ( + BASIC_AMBIENT_COLOR, + BASIC_DIFFUSE_COLOR, + UNIFORM_LIGHT_DIRECTION, +) +from shap_e.rendering.view_data import ProjectiveCamera + +from .cast import cast_camera +from .types import RayCollisions, TriMesh + + +def render_diffuse_mesh( + camera: ProjectiveCamera, + mesh: TriMesh, + light_direction: Sequence[float] = tuple(UNIFORM_LIGHT_DIRECTION), + diffuse: float = BASIC_DIFFUSE_COLOR, + ambient: float = BASIC_AMBIENT_COLOR, + ray_batch_size: Optional[int] = None, + checkpoint: Optional[bool] = None, +) -> torch.Tensor: + """ + Return an [H x W x 4] RGBA tensor of the rendered image. + The pixels are floating points, with alpha in the range [0, 1] and the + other colors matching the scale used by the mesh's vertex colors. 
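+
+    Shading is a simple two-sided Lambertian model (ambient + diffuse * |n . l|);
+    rays that miss the mesh produce zero color and zero alpha.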
+ """ + light_direction = torch.tensor( + light_direction, device=mesh.vertices.device, dtype=mesh.vertices.dtype + ) + + all_collisions = RayCollisions.collect( + cast_camera( + camera=camera, + mesh=mesh, + ray_batch_size=ray_batch_size, + checkpoint=checkpoint, + ) + ) + num_rays = len(all_collisions.normals) + if mesh.vertex_colors is None: + vertex_colors = torch.tensor([[0.8, 0.8, 0.8]]).to(mesh.vertices).repeat(num_rays, 1) + else: + vertex_colors = mesh.vertex_colors + + light_coeffs = ambient + ( + diffuse * torch.sum(all_collisions.normals * light_direction, dim=-1).abs() + ) + vertex_colors = mesh.vertex_colors[mesh.faces[all_collisions.tri_indices]] + bary_products = torch.sum(vertex_colors * all_collisions.barycentric[..., None], axis=-2) + out_colors = bary_products * light_coeffs[..., None] + res = torch.where(all_collisions.collides[:, None], out_colors, torch.zeros_like(out_colors)) + return torch.cat([res, all_collisions.collides[:, None].float()], dim=-1).view( + camera.height, camera.width, 4 + ) diff --git a/shap_e/rendering/raycast/types.py b/shap_e/rendering/raycast/types.py new file mode 100644 index 0000000000000000000000000000000000000000..a2cba7a7acc3f7aeff637ae415ea2b3207d11a91 --- /dev/null +++ b/shap_e/rendering/raycast/types.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass +from typing import Iterable, Optional + +import numpy as np +import torch + +import shap_e.rendering.mesh + +from ._utils import cross_product, normalize + + +@dataclass +class Rays: + """ + A ray in ray casting. + """ + + origins: torch.Tensor # [N x 3] float tensor + directions: torch.Tensor # [N x 3] float tensor + + def normalized_directions(self) -> torch.Tensor: + return normalize(self.directions) + + +@dataclass +class RayCollisions: + """ + The result of casting N rays onto a mesh. + """ + + collides: torch.Tensor # [N] boolean tensor + ray_dists: torch.Tensor # [N] float tensor + tri_indices: torch.Tensor # [N] long tensor + barycentric: torch.Tensor # [N x 3] float tensor + normals: torch.Tensor # [N x 3] float tensor + + @classmethod + def collect(cls, it: Iterable["RayCollisions"]) -> "RayCollisions": + res = None + for x in it: + if res is None: + res = x + else: + res = cls( + collides=torch.cat([res.collides, x.collides]), + ray_dists=torch.cat([res.ray_dists, x.ray_dists]), + tri_indices=torch.cat([res.tri_indices, x.tri_indices]), + barycentric=torch.cat([res.barycentric, x.barycentric]), + normals=torch.cat([res.normals, x.normals]), + ) + if res is None: + raise ValueError("cannot collect an empty iterable of RayCollisions") + return res + + +@dataclass +class TriMesh: + faces: torch.Tensor # [N x 3] long tensor + vertices: torch.Tensor # [N x 3] float tensor + + vertex_colors: Optional[torch.Tensor] = None + + def normals(self) -> torch.Tensor: + """ + Returns an [N x 3] batch of normal vectors per triangle assuming the + right-hand rule. 
+ """ + tris = self.vertices[self.faces] + v1 = tris[:, 1] - tris[:, 0] + v2 = tris[:, 2] - tris[:, 0] + return normalize(cross_product(v1, v2)) + + @classmethod + def from_numpy(cls, x: shap_e.rendering.mesh.TriMesh) -> "TriMesh": + vertex_colors = None + if all(ch in x.vertex_channels for ch in "RGB"): + vertex_colors = torch.from_numpy( + np.stack([x.vertex_channels[ch] for ch in "RGB"], axis=-1) + ) + return cls( + faces=torch.from_numpy(x.faces), + vertices=torch.from_numpy(x.verts), + vertex_colors=vertex_colors, + ) + + def to(self, *args, **kwargs) -> "TriMesh": + return TriMesh( + faces=self.faces.to(*args, **kwargs), + vertices=self.vertices.to(*args, **kwargs), + vertex_colors=None + if self.vertex_colors is None + else self.vertex_colors.to(*args, **kwargs), + ) diff --git a/shap_e/rendering/torch_mesh.py b/shap_e/rendering/torch_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..49c6894c9046ac0e0884ceba450b65b2bb847534 --- /dev/null +++ b/shap_e/rendering/torch_mesh.py @@ -0,0 +1,42 @@ +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch + +from .mesh import TriMesh + + +@dataclass +class TorchMesh: + """ + A 3D triangle mesh with optional data at the vertices and faces. + """ + + # [N x 3] array of vertex coordinates. + verts: torch.Tensor + + # [M x 3] array of triangles, pointing to indices in verts. + faces: torch.Tensor + + # Extra data per vertex and face. + vertex_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) + face_channels: Optional[Dict[str, torch.Tensor]] = field(default_factory=dict) + + def tri_mesh(self) -> TriMesh: + """ + Create a CPU version of the mesh. + """ + return TriMesh( + verts=self.verts.detach().cpu().numpy(), + faces=self.faces.cpu().numpy(), + vertex_channels=( + {k: v.detach().cpu().numpy() for k, v in self.vertex_channels.items()} + if self.vertex_channels is not None + else None + ), + face_channels=( + {k: v.detach().cpu().numpy() for k, v in self.face_channels.items()} + if self.face_channels is not None + else None + ), + ) diff --git a/shap_e/rendering/view_data.py b/shap_e/rendering/view_data.py new file mode 100644 index 0000000000000000000000000000000000000000..fded23d8410b3c2add60e084c59eda102f8beaf7 --- /dev/null +++ b/shap_e/rendering/view_data.py @@ -0,0 +1,206 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np + + +@dataclass +class Camera(ABC): + """ + An object describing how a camera corresponds to pixels in an image. + """ + + @abstractmethod + def image_coords(self) -> np.ndarray: + """ + :return: ([self.height, self.width, 2]).reshape(self.height * self.width, 2) image coordinates + """ + + @abstractmethod + def camera_rays(self, coords: np.ndarray) -> np.ndarray: + """ + For every (x, y) coordinate in a rendered image, compute the ray of the + corresponding pixel. + + :param coords: an [N x 2] integer array of 2D image coordinates. + :return: an [N x 2 x 3] array of [2 x 3] (origin, direction) tuples. + The direction should always be unit length. + """ + + def depth_directions(self, coords: np.ndarray) -> np.ndarray: + """ + For every (x, y) coordinate in a rendered image, get the direction that + corresponds to "depth" in an RGBD rendering. + + This may raise an exception if there is no "D" channel in the + corresponding ViewData. + + :param coords: an [N x 2] integer array of 2D image coordinates. 
+ :return: an [N x 3] array of normalized depth directions. + """ + _ = coords + raise NotImplementedError + + @abstractmethod + def center_crop(self) -> "Camera": + """ + Creates a new camera with the same intrinsics and direction as this one, + but with a center crop to a square of the smaller dimension. + """ + + @abstractmethod + def resize_image(self, width: int, height: int) -> "Camera": + """ + Creates a new camera with the same intrinsics and direction as this one, + but with resized image dimensions. + """ + + @abstractmethod + def scale_scene(self, factor: float) -> "Camera": + """ + Creates a new camera with the same intrinsics and direction as this one, + but with the scene rescaled by the given factor. + """ + + +@dataclass +class ProjectiveCamera(Camera): + """ + A Camera implementation for a standard pinhole camera. + + The camera rays shoot away from the origin in the z direction, with the x + and y directions corresponding to the positive horizontal and vertical axes + in image space. + """ + + origin: np.ndarray + x: np.ndarray + y: np.ndarray + z: np.ndarray + width: int + height: int + x_fov: float + y_fov: float + + def image_coords(self) -> np.ndarray: + ind = np.arange(self.width * self.height) + coords = np.stack([ind % self.width, ind // self.width], axis=1).astype(np.float32) + return coords + + def camera_rays(self, coords: np.ndarray) -> np.ndarray: + fracs = (coords / (np.array([self.width, self.height], dtype=np.float32) - 1)) * 2 - 1 + fracs = fracs * np.tan(np.array([self.x_fov, self.y_fov]) / 2) + directions = self.z + self.x * fracs[:, :1] + self.y * fracs[:, 1:] + directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True) + return np.stack([np.broadcast_to(self.origin, directions.shape), directions], axis=1) + + def depth_directions(self, coords: np.ndarray) -> np.ndarray: + return np.tile((self.z / np.linalg.norm(self.z))[None], [len(coords), 1]) + + def resize_image(self, width: int, height: int) -> "ProjectiveCamera": + """ + Creates a new camera for the resized view assuming the aspect ratio does not change. + """ + assert width * self.height == height * self.width, "The aspect ratio should not change." + return ProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=width, + height=height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + def center_crop(self) -> "ProjectiveCamera": + """ + Creates a new camera for the center-cropped view + """ + size = min(self.width, self.height) + fov = min(self.x_fov, self.y_fov) + return ProjectiveCamera( + origin=self.origin, + x=self.x, + y=self.y, + z=self.z, + width=size, + height=size, + x_fov=fov, + y_fov=fov, + ) + + def scale_scene(self, factor: float) -> "ProjectiveCamera": + """ + Creates a new camera with the same intrinsics and direction as this one, + but with the camera frame rescaled by the given factor. + """ + return ProjectiveCamera( + origin=self.origin * factor, + x=self.x, + y=self.y, + z=self.z, + width=self.width, + height=self.height, + x_fov=self.x_fov, + y_fov=self.y_fov, + ) + + +class ViewData(ABC): + """ + A collection of rendered camera views of a scene or object. + + This is a generalization of a NeRF dataset, since NeRF datasets only encode + RGB or RGBA data, whereas this dataset supports arbitrary channels. + """ + + @property + @abstractmethod + def num_views(self) -> int: + """ + The number of rendered views. 
+ """ + + @property + @abstractmethod + def channel_names(self) -> List[str]: + """ + Get all of the supported channels available for the views. + + This can be arbitrary, but there are some standard names: + "R", "G", "B", "A" (alpha), and "D" (depth). + """ + + @abstractmethod + def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]: + """ + Load the given channels from the view at the given index. + + :return: a tuple (camera_view, data), where data is a float array of + shape [height x width x num_channels]. + """ + + +class MemoryViewData(ViewData): + """ + A ViewData that is implemented in memory. + """ + + def __init__(self, channels: Dict[str, np.ndarray], cameras: List[Camera]): + assert all(v.shape[0] == len(cameras) for v in channels.values()) + self.channels = channels + self.cameras = cameras + + @property + def num_views(self) -> int: + return len(self.cameras) + + @property + def channel_names(self) -> List[str]: + return list(self.channels.keys()) + + def load_view(self, index: int, channels: List[str]) -> Tuple[Camera, np.ndarray]: + outputs = [self.channels[channel][index] for channel in channels] + return self.cameras[index], np.stack(outputs, axis=-1) diff --git a/shap_e/util/__init__.py b/shap_e/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/shap_e/util/__pycache__/__init__.cpython-39.pyc b/shap_e/util/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6e27a9ec86c857fa748738a223f430eed0c744 Binary files /dev/null and b/shap_e/util/__pycache__/__init__.cpython-39.pyc differ diff --git a/shap_e/util/__pycache__/collections.cpython-39.pyc b/shap_e/util/__pycache__/collections.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7861920c3f8a467a6155abdaa07dfc004797a0eb Binary files /dev/null and b/shap_e/util/__pycache__/collections.cpython-39.pyc differ diff --git a/shap_e/util/__pycache__/io.cpython-39.pyc b/shap_e/util/__pycache__/io.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18a2bb5ee1fc931f74cff860688adcd647228bff Binary files /dev/null and b/shap_e/util/__pycache__/io.cpython-39.pyc differ diff --git a/shap_e/util/__pycache__/notebooks.cpython-39.pyc b/shap_e/util/__pycache__/notebooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1b98f160531e8b0f5b71cbe25e055ec1799449 Binary files /dev/null and b/shap_e/util/__pycache__/notebooks.cpython-39.pyc differ diff --git a/shap_e/util/collections.py b/shap_e/util/collections.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d823a2ea7582c09ebcc916267e36f8c90174d5 --- /dev/null +++ b/shap_e/util/collections.py @@ -0,0 +1,136 @@ +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional + + +class AttrDict(OrderedDict): + """ + An attribute dictionary that automatically handles nested keys joined by "/". 
+ + Originally copied from: https://stackoverflow.com/questions/3031219/recursively-access-dict-via-attributes-as-well-as-index-access + """ + + MARKER = object() + + # pylint: disable=super-init-not-called + def __init__(self, *args, **kwargs): + if len(args) == 0: + for key, value in kwargs.items(): + self.__setitem__(key, value) + else: + assert len(args) == 1 + assert isinstance(args[0], (dict, AttrDict)) + for key, value in args[0].items(): + self.__setitem__(key, value) + + def __contains__(self, key): + if "/" in key: + keys = key.split("/") + key, next_key = keys[0], "/".join(keys[1:]) + return key in self and next_key in self[key] + return super(AttrDict, self).__contains__(key) + + def __setitem__(self, key, value): + if "/" in key: + keys = key.split("/") + key, next_key = keys[0], "/".join(keys[1:]) + if key not in self: + self[key] = AttrDict() + self[key].__setitem__(next_key, value) + return + + if isinstance(value, dict) and not isinstance(value, AttrDict): + value = AttrDict(**value) + if isinstance(value, list): + value = [AttrDict(val) if isinstance(val, dict) else val for val in value] + super(AttrDict, self).__setitem__(key, value) + + def __getitem__(self, key): + if "/" in key: + keys = key.split("/") + key, next_key = keys[0], "/".join(keys[1:]) + val = self[key] + if not isinstance(val, AttrDict): + raise ValueError + return val.__getitem__(next_key) + + return self.get(key, None) + + def all_keys( + self, + leaves_only: bool = False, + parent: Optional[str] = None, + ) -> List[str]: + keys = [] + for key in self.keys(): + cur = key if parent is None else f"{parent}/{key}" + if not leaves_only or not isinstance(self[key], dict): + keys.append(cur) + if isinstance(self[key], dict): + keys.extend(self[key].all_keys(leaves_only=leaves_only, parent=cur)) + return keys + + def dumpable(self, strip=True): + """ + Casts into OrderedDict and removes internal attributes + """ + + def _dump(val): + if isinstance(val, AttrDict): + return val.dumpable() + elif isinstance(val, list): + return [_dump(v) for v in val] + return val + + if strip: + return {k: _dump(v) for k, v in self.items() if not k.startswith("_")} + return {k: _dump(v if not k.startswith("_") else repr(v)) for k, v in self.items()} + + def map( + self, + map_fn: Callable[[Any, Any], Any], + should_map: Optional[Callable[[Any, Any], bool]] = None, + ) -> "AttrDict": + """ + Creates a copy of self where some or all values are transformed by + map_fn. + + :param should_map: If provided, only those values that evaluate to true + are converted; otherwise, all values are mapped. + """ + + def _apply(key, val): + if isinstance(val, AttrDict): + return val.map(map_fn, should_map) + elif should_map is None or should_map(key, val): + return map_fn(key, val) + return val + + return AttrDict({k: _apply(k, v) for k, v in self.items()}) + + def __eq__(self, other): + return self.keys() == other.keys() and all(self[k] == other[k] for k in self.keys()) + + def combine( + self, + other: Dict[str, Any], + combine_fn: Callable[[Optional[Any], Optional[Any]], Any], + ) -> "AttrDict": + """ + Some values may be missing, but the dictionary structures must be the + same. + + :param combine_fn: a (possibly non-commutative) function to combine the + values + """ + + def _apply(val, other_val): + if val is not None and isinstance(val, AttrDict): + assert isinstance(other_val, AttrDict) + return val.combine(other_val, combine_fn) + return combine_fn(val, other_val) + + # TODO nit: this changes the ordering.. 
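+        # (the set union of key views below is unordered; keys missing from this
+        # AttrDict resolve to None through __getitem__'s default)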
+ keys = self.keys() | other.keys() + return AttrDict({k: _apply(self[k], other[k]) for k in keys}) + + __setattr__, __getattr__ = __setitem__, __getitem__ diff --git a/shap_e/util/data_util.py b/shap_e/util/data_util.py new file mode 100644 index 0000000000000000000000000000000000000000..93f0dc833b83f9c5ef7ccf2bca393af7f312f1bc --- /dev/null +++ b/shap_e/util/data_util.py @@ -0,0 +1,256 @@ +import tempfile +from contextlib import contextmanager +from typing import Iterator, Optional, Union + +import blobfile as bf +import numpy as np +import torch +from PIL import Image + +from shap_e.rendering.blender.render import render_mesh, render_model +from shap_e.rendering.blender.view_data import BlenderViewData +from shap_e.rendering.mesh import TriMesh +from shap_e.rendering.point_cloud import PointCloud +from shap_e.rendering.view_data import ViewData +from shap_e.util.collections import AttrDict +from shap_e.util.image_util import center_crop, get_alpha, remove_alpha, resize + + +def load_or_create_multimodal_batch( + device: torch.device, + *, + mesh_path: Optional[str] = None, + model_path: Optional[str] = None, + cache_dir: Optional[str] = None, + point_count: int = 2**14, + random_sample_count: int = 2**19, + pc_num_views: int = 40, + mv_light_mode: Optional[str] = None, + mv_num_views: int = 20, + mv_image_size: int = 512, + mv_alpha_removal: str = "black", + verbose: bool = False, +) -> AttrDict: + if verbose: + print("creating point cloud...") + pc = load_or_create_pc( + mesh_path=mesh_path, + model_path=model_path, + cache_dir=cache_dir, + random_sample_count=random_sample_count, + point_count=point_count, + num_views=pc_num_views, + verbose=verbose, + ) + raw_pc = np.concatenate([pc.coords, pc.select_channels(["R", "G", "B"])], axis=-1) + encode_me = torch.from_numpy(raw_pc).float().to(device) + batch = AttrDict(points=encode_me.t()[None]) + if mv_light_mode: + if verbose: + print("creating multiview...") + with load_or_create_multiview( + mesh_path=mesh_path, + model_path=model_path, + cache_dir=cache_dir, + num_views=mv_num_views, + extract_material=False, + light_mode=mv_light_mode, + verbose=verbose, + ) as mv: + cameras, views, view_alphas, depths = [], [], [], [] + for view_idx in range(mv.num_views): + camera, view = mv.load_view( + view_idx, + ["R", "G", "B", "A"] if "A" in mv.channel_names else ["R", "G", "B"], + ) + depth = None + if "D" in mv.channel_names: + _, depth = mv.load_view(view_idx, ["D"]) + depth = process_depth(depth, mv_image_size) + view, alpha = process_image( + np.round(view * 255.0).astype(np.uint8), mv_alpha_removal, mv_image_size + ) + camera = camera.center_crop().resize_image(mv_image_size, mv_image_size) + cameras.append(camera) + views.append(view) + view_alphas.append(alpha) + depths.append(depth) + batch.depths = [depths] + batch.views = [views] + batch.view_alphas = [view_alphas] + batch.cameras = [cameras] + return normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0) + + +def load_or_create_pc( + *, + mesh_path: Optional[str], + model_path: Optional[str], + cache_dir: Optional[str], + random_sample_count: int, + point_count: int, + num_views: int, + verbose: bool = False, +) -> PointCloud: + + assert (model_path is not None) ^ ( + mesh_path is not None + ), "must specify exactly one of model_path or mesh_path" + path = model_path if model_path is not None else mesh_path + + if cache_dir is not None: + cache_path = bf.join( + cache_dir, + f"pc_{bf.basename(path)}_mat_{num_views}_{random_sample_count}_{point_count}.npz", + ) + if 
bf.exists(cache_path): + return PointCloud.load(cache_path) + else: + cache_path = None + + with load_or_create_multiview( + mesh_path=mesh_path, + model_path=model_path, + cache_dir=cache_dir, + num_views=num_views, + verbose=verbose, + ) as mv: + if verbose: + print("extracting point cloud from multiview...") + pc = mv_to_pc( + multiview=mv, random_sample_count=random_sample_count, point_count=point_count + ) + if cache_path is not None: + pc.save(cache_path) + return pc + + +@contextmanager +def load_or_create_multiview( + *, + mesh_path: Optional[str], + model_path: Optional[str], + cache_dir: Optional[str], + num_views: int = 20, + extract_material: bool = True, + light_mode: Optional[str] = None, + verbose: bool = False, +) -> Iterator[BlenderViewData]: + + assert (model_path is not None) ^ ( + mesh_path is not None + ), "must specify exactly one of model_path or mesh_path" + path = model_path if model_path is not None else mesh_path + + if extract_material: + assert light_mode is None, "light_mode is ignored when extract_material=True" + else: + assert light_mode is not None, "must specify light_mode when extract_material=False" + + if cache_dir is not None: + if extract_material: + cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_mat_{num_views}.zip") + else: + cache_path = bf.join(cache_dir, f"mv_{bf.basename(path)}_{light_mode}_{num_views}.zip") + if bf.exists(cache_path): + with bf.BlobFile(cache_path, "rb") as f: + yield BlenderViewData(f) + return + else: + cache_path = None + + common_kwargs = dict( + fast_mode=True, + extract_material=extract_material, + camera_pose="random", + light_mode=light_mode or "uniform", + verbose=verbose, + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = bf.join(tmp_dir, "out.zip") + if mesh_path is not None: + mesh = TriMesh.load(mesh_path) + render_mesh( + mesh=mesh, + output_path=tmp_path, + num_images=num_views, + backend="CYCLES", + **common_kwargs, + ) + elif model_path is not None: + render_model( + model_path, + output_path=tmp_path, + num_images=num_views, + backend="CYCLES", + **common_kwargs, + ) + if cache_path is not None: + bf.copy(tmp_path, cache_path) + with bf.BlobFile(tmp_path, "rb") as f: + yield BlenderViewData(f) + + +def mv_to_pc(multiview: ViewData, random_sample_count: int, point_count: int) -> PointCloud: + pc = PointCloud.from_rgbd(multiview) + + # Handle empty samples. + if len(pc.coords) == 0: + pc = PointCloud( + coords=np.zeros([1, 3]), + channels=dict(zip("RGB", np.zeros([3, 1]))), + ) + while len(pc.coords) < point_count: + pc = pc.combine(pc) + # Prevent duplicate points; some models may not like it. 
+ pc.coords += np.random.normal(size=pc.coords.shape) * 1e-4 + + pc = pc.random_sample(random_sample_count) + pc = pc.farthest_point_sample(point_count, average_neighbors=True) + + return pc + + +def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict: + res = batch.copy() + scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device) + res.points = res.points * scale_vec[:, None] + + if "cameras" in res: + res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras] + + if "depths" in res: + res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths] + + return res + + +def process_depth(depth_img: np.ndarray, image_size: int) -> np.ndarray: + depth_img = center_crop(depth_img) + depth_img = resize(depth_img, width=image_size, height=image_size) + return np.squeeze(depth_img) + + +def process_image( + img_or_img_arr: Union[Image.Image, np.ndarray], alpha_removal: str, image_size: int +): + if isinstance(img_or_img_arr, np.ndarray): + img = Image.fromarray(img_or_img_arr) + img_arr = img_or_img_arr + else: + img = img_or_img_arr + img_arr = np.array(img) + if len(img_arr.shape) == 2: + # Grayscale + rgb = Image.new("RGB", img.size) + rgb.paste(img) + img = rgb + img_arr = np.array(img) + + img = center_crop(img) + alpha = get_alpha(img) + img = remove_alpha(img, mode=alpha_removal) + alpha = alpha.resize((image_size,) * 2, resample=Image.BILINEAR) + img = img.resize((image_size,) * 2, resample=Image.BILINEAR) + return img, alpha diff --git a/shap_e/util/image_util.py b/shap_e/util/image_util.py new file mode 100644 index 0000000000000000000000000000000000000000..fb30518211d4d44c88c90bfe247d68cdfbb2d119 --- /dev/null +++ b/shap_e/util/image_util.py @@ -0,0 +1,170 @@ +import random +from typing import Any, List, Optional, Union + +import blobfile as bf +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + + +def center_crop( + img: Union[Image.Image, torch.Tensor, np.ndarray] +) -> Union[Image.Image, torch.Tensor, np.ndarray]: + """ + Center crops an image. 
+ """ + if isinstance(img, (np.ndarray, torch.Tensor)): + height, width = img.shape[:2] + else: + width, height = img.size + size = min(width, height) + left, top = (width - size) // 2, (height - size) // 2 + right, bottom = left + size, top + size + if isinstance(img, (np.ndarray, torch.Tensor)): + img = img[top:bottom, left:right] + else: + img = img.crop((left, top, right, bottom)) + return img + + +def resize( + img: Union[Image.Image, torch.Tensor, np.ndarray], + *, + height: int, + width: int, + min_value: Optional[Any] = None, + max_value: Optional[Any] = None, +) -> Union[Image.Image, torch.Tensor, np.ndarray]: + """ + :param: img: image in HWC order + :return: currently written for downsampling + """ + + orig, cls = img, type(img) + if isinstance(img, Image.Image): + img = np.array(img) + dtype = img.dtype + if isinstance(img, np.ndarray): + img = torch.from_numpy(img) + ndim = img.ndim + if img.ndim == 2: + img = img.unsqueeze(-1) + + if min_value is None and max_value is None: + # .clamp throws an error when both are None + min_value = -np.inf + + img = img.permute(2, 0, 1) + size = (height, width) + img = ( + F.interpolate(img[None].float(), size=size, mode="area")[0] + .clamp(min_value, max_value) + .to(img.dtype) + .permute(1, 2, 0) + ) + + if ndim < img.ndim: + img = img.squeeze(-1) + if not isinstance(orig, torch.Tensor): + img = img.numpy() + img = img.astype(dtype) + if isinstance(orig, Image.Image): + img = Image.fromarray(img) + + return img + + +def get_alpha(img: Image.Image) -> Image.Image: + """ + :return: the alpha channel separated out as a grayscale image + """ + img_arr = np.asarray(img) + if img_arr.shape[2] == 4: + alpha = img_arr[:, :, 3] + else: + alpha = np.full(img_arr.shape[:2], 255, dtype=np.uint8) + alpha = Image.fromarray(alpha) + return alpha + + +def remove_alpha(img: Image.Image, mode: str = "random") -> Image.Image: + """ + No op if the image doesn't have an alpha channel. 
+ + :param: mode: Defaults to "random" but has an option to use a "black" or + "white" background + + :return: image with alpha removed + """ + img_arr = np.asarray(img) + if img_arr.shape[2] == 4: + # Add bg to get rid of alpha channel + if mode == "random": + height, width = img_arr.shape[:2] + bg = Image.fromarray( + random.choice([_black_bg, _gray_bg, _checker_bg, _noise_bg])(height, width) + ) + bg.paste(img, mask=img) + img = bg + elif mode == "black" or mode == "white": + img_arr = img_arr.astype(float) + rgb, alpha = img_arr[:, :, :3], img_arr[:, :, -1:] / 255 + background = np.zeros((1, 1, 3)) if mode == "black" else np.full((1, 1, 3), 255) + rgb = rgb * alpha + background * (1 - alpha) + img = Image.fromarray(np.round(rgb).astype(np.uint8)) + return img + + +def _black_bg(h: int, w: int) -> np.ndarray: + return np.zeros([h, w, 3], dtype=np.uint8) + + +def _gray_bg(h: int, w: int) -> np.ndarray: + return (np.zeros([h, w, 3]) + np.random.randint(low=0, high=256)).astype(np.uint8) + + +def _checker_bg(h: int, w: int) -> np.ndarray: + checker_size = np.ceil(np.exp(np.random.uniform() * np.log(min(h, w)))) + c1 = np.random.randint(low=0, high=256) + c2 = np.random.randint(low=0, high=256) + + xs = np.arange(w)[None, :, None] + np.random.randint(low=0, high=checker_size + 1) + ys = np.arange(h)[:, None, None] + np.random.randint(low=0, high=checker_size + 1) + + fields = np.logical_xor((xs // checker_size) % 2 == 0, (ys // checker_size) % 2 == 0) + return np.where(fields, np.array([c1] * 3), np.array([c2] * 3)).astype(np.uint8) + + +def _noise_bg(h: int, w: int) -> np.ndarray: + return np.random.randint(low=0, high=256, size=[h, w, 3]).astype(np.uint8) + + +def load_image(image_path: str) -> Image.Image: + with bf.BlobFile(image_path, "rb") as thefile: + img = Image.open(thefile) + img.load() + return img + + +def make_tile(images: List[Union[np.ndarray, Image.Image]], columns=8) -> Image.Image: + """ + to test, run + >>> display(make_tile([(np.zeros((128, 128, 3)) + c).astype(np.uint8) for c in np.linspace(0, 255, 15)])) + """ + images = list(map(np.array, images)) + size = images[0].shape[0] + n = round_up(len(images), columns) + n_blanks = n - len(images) + images.extend([np.zeros((size, size, 3), dtype=np.uint8)] * n_blanks) + images = ( + np.array(images) + .reshape(n // columns, columns, size, size, 3) + .transpose([0, 2, 1, 3, 4]) + .reshape(n // columns * size, columns * size, 3) + ) + return Image.fromarray(images) + + +def round_up(n: int, b: int) -> int: + return (n + b - 1) // b * b diff --git a/shap_e/util/io.py b/shap_e/util/io.py new file mode 100644 index 0000000000000000000000000000000000000000..aead1186b37b43974cbb3d7fb50d27925aea9a01 --- /dev/null +++ b/shap_e/util/io.py @@ -0,0 +1,34 @@ +import io +from contextlib import contextmanager +from typing import Any, BinaryIO, Iterator, Union + +import blobfile as bf +import yaml + +from shap_e.util.collections import AttrDict + + +def read_config(path_or_file: Union[str, io.IOBase]) -> Any: + if isinstance(path_or_file, io.IOBase): + obj = yaml.load(path_or_file, Loader=yaml.SafeLoader) + else: + with bf.BlobFile(path_or_file, "rb") as f: + try: + obj = yaml.load(f, Loader=yaml.SafeLoader) + except Exception as exc: + with bf.BlobFile(path_or_file, "rb") as f: + print(f.read()) + raise exc + if isinstance(obj, dict): + return AttrDict(obj) + return obj + + +@contextmanager +def buffered_writer(raw_f: BinaryIO) -> Iterator[io.BufferedIOBase]: + if isinstance(raw_f, io.BufferedIOBase): + yield raw_f + else: + f = 
io.BufferedWriter(raw_f)
+        yield f
+        f.flush()
diff --git a/shap_e/util/notebooks.py b/shap_e/util/notebooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0db4127c4e14f839fa010b6ce0b5ec00ad984b
--- /dev/null
+++ b/shap_e/util/notebooks.py
@@ -0,0 +1,79 @@
+import base64
+import io
+from typing import Union, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+
+from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
+from shap_e.models.transmitter.base import Transmitter, VectorDecoder
+from shap_e.rendering.torch_mesh import TorchMesh
+from shap_e.util.collections import AttrDict
+
+
+def create_pan_cameras(size: int, device: torch.device, batch_size: Optional[int] = 1, dist: int = 4) -> DifferentiableCameraBatch:
+    """Build a batch of 20 cameras panning in a ring around the origin at distance `dist`."""
+    origins = []
+    xs = []
+    ys = []
+    zs = []
+    for theta in np.linspace(0, 2 * np.pi, num=20):
+        z = np.array([np.sin(theta), np.cos(theta), -0.5])
+        z /= np.sqrt(np.sum(z**2))
+        origin = -z * dist
+        x = np.array([np.cos(theta), -np.sin(theta), 0.0])
+        y = np.cross(z, x)
+        origins.append(origin)
+        xs.append(x)
+        ys.append(y)
+        zs.append(z)
+    return DifferentiableCameraBatch(
+        shape=(batch_size, len(xs)),
+        flat_camera=DifferentiableProjectiveCamera(
+            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device).repeat(batch_size, 1),
+            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device).repeat(batch_size, 1),
+            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device).repeat(batch_size, 1),
+            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device).repeat(batch_size, 1),
+            width=size,
+            height=size,
+            x_fov=0.7,
+            y_fov=0.7,
+        ),
+    )
+
+@torch.no_grad()
+def decode_latent_images(
+    xm: Union[Transmitter, VectorDecoder],
+    latent: torch.Tensor,
+    cameras: DifferentiableCameraBatch,
+    rendering_mode: str = "stf",
+):
+    """Render a latent from the given cameras and return one PIL image per view."""
+    decoded = xm.renderer.render_views(
+        AttrDict(cameras=cameras),
+        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
+            latent[None]
+        ),
+        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
+    )
+    arr = decoded.channels.clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+    return [Image.fromarray(x) for x in arr]
+
+
+@torch.no_grad()
+def decode_latent_mesh(
+    xm: Union[Transmitter, VectorDecoder],
+    latent: torch.Tensor,
+) -> TorchMesh:
+    """Decode a latent into a triangle mesh via a low-resolution STF rendering pass."""
+    decoded = xm.renderer.render_views(
+        AttrDict(cameras=create_pan_cameras(2, latent.device)),  # lowest resolution possible
+        params=(xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
+            latent[None]
+        ),
+        options=AttrDict(rendering_mode="stf", render_with_direction=False),
+    )
+    return decoded.raw_meshes[0]
+
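
A minimal usage sketch for the notebook utilities added above, assuming a transmitter model loaded via load_model and a single latent (for example, one row of the tensor returned by sample_latents) saved to a hypothetical latent.pt; the resolution, render mode, and output paths are illustrative rather than prescribed by this patch:

import torch

from shap_e.models.download import load_model
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, decode_latent_mesh

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The "transmitter" checkpoint maps latents to renderable / meshable parameters.
xm = load_model("transmitter", device=device)

# Hypothetical file holding one Shap-E latent produced elsewhere (e.g. by sample_latents).
latent = torch.load("latent.pt", map_location=device)

# Render a 20-view turntable at 128x128 and save it as an animated GIF.
cameras = create_pan_cameras(128, device)
images = decode_latent_images(xm, latent, cameras, rendering_mode="nerf")
images[0].save("turntable.gif", save_all=True, append_images=images[1:], duration=100, loop=0)

# Extract a triangle mesh from the same latent and write it as an OBJ file.
mesh = decode_latent_mesh(xm, latent).tri_mesh()
with open("mesh.obj", "w") as f:
    mesh.write_obj(f)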