Spaces · silentchen committed commit 19c4ddf
Parent(s): f633cbe

first commit

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- app.py +375 -0
- requirements.txt +17 -0
- shap_e/.DS_Store +0 -0
- shap_e/__init__.py +0 -0
- shap_e/__pycache__/__init__.cpython-39.pyc +0 -0
- shap_e/diffusion/__init__.py +0 -0
- shap_e/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc +0 -0
- shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc +0 -0
- shap_e/diffusion/__pycache__/sample.cpython-39.pyc +0 -0
- shap_e/diffusion/gaussian_diffusion.py +1143 -0
- shap_e/diffusion/k_diffusion.py +426 -0
- shap_e/diffusion/sample.py +160 -0
- shap_e/examples/encode_model.ipynb +93 -0
- shap_e/examples/sample_image_to_3d.ipynb +125 -0
- shap_e/examples/sample_text_to_3d.ipynb +124 -0
- shap_e/models/__init__.py +0 -0
- shap_e/models/__pycache__/__init__.cpython-39.pyc +0 -0
- shap_e/models/__pycache__/configs.cpython-39.pyc +0 -0
- shap_e/models/__pycache__/download.cpython-39.pyc +0 -0
- shap_e/models/__pycache__/query.cpython-39.pyc +0 -0
- shap_e/models/__pycache__/renderer.cpython-39.pyc +0 -0
- shap_e/models/__pycache__/volume.cpython-39.pyc +0 -0
- shap_e/models/configs.py +166 -0
- shap_e/models/download.py +152 -0
- shap_e/models/generation/__init__.py +0 -0
- shap_e/models/generation/__pycache__/__init__.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/transformer.cpython-39.pyc +0 -0
- shap_e/models/generation/__pycache__/util.cpython-39.pyc +0 -0
- shap_e/models/generation/latent_diffusion.py +32 -0
- shap_e/models/generation/perceiver.py +244 -0
- shap_e/models/generation/pooled_mlp.py +74 -0
- shap_e/models/generation/pretrained_clip.py +270 -0
- shap_e/models/generation/transformer.py +494 -0
- shap_e/models/generation/util.py +23 -0
- shap_e/models/nerf/__init__.py +0 -0
- shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc +0 -0
- shap_e/models/nerf/__pycache__/model.cpython-39.pyc +0 -0
- shap_e/models/nerf/__pycache__/ray.cpython-39.pyc +0 -0
- shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc +0 -0
- shap_e/models/nerf/model.py +255 -0
- shap_e/models/nerf/ray.py +512 -0
- shap_e/models/nerf/renderer.py +301 -0
- shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc +0 -0
- shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc +0 -0
- shap_e/models/nerstf/mlp.py +174 -0
app.py
ADDED
@@ -0,0 +1,375 @@
import gradio as gr
import torch
import numpy as np
from functools import partial
from typing import Optional
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
import trimesh
import torch.nn as nn
import os
import random
import warnings
from huggingface_hub import hf_hub_download
import hashlib

import sys

sys.tracebacklimit = 0


def set_seed(seed=1024):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


def freeze_params(params):
    for param in params:
        param.requires_grad = False


class Blocks(gr.Blocks):

    def __init__(
        self,
        theme: str = "default",
        analytics_enabled: Optional[bool] = None,
        mode: str = "blocks",
        title: str = "Gradio",
        css: Optional[str] = None,
        **kwargs,
    ):
        self.extra_configs = {
            'thumbnail': kwargs.pop('thumbnail', ''),
            'url': kwargs.pop('url', 'https://gradio.app/'),
            'creator': kwargs.pop('creator', '@teamGradio'),
        }

        super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
        warnings.filterwarnings("ignore")

    def get_config_file(self):
        config = super(Blocks, self).get_config_file()

        for k, v in self.extra_configs.items():
            config[k] = v

        return config


def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed):
    state = {}
    out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(
        xm, diffusion, latent_model, device, prompt, rand_seed, state)
    edited_1, edited_2, edited_3, edited_4, state = _3d_editing(
        xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state)
    print(state)
    return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4


def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state):
    set_seed(rand_seed)
    batch_size = 4
    guidance_scale = 15.0
    xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
    xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
    xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])

    print("prompt: ", prompt, "rand_seed: ", rand_seed, "state:", state)
    latents = sample_latents(
        batch_size=batch_size,
        model=latent_model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
        sigma_min=1e-3,
        sigma_max=160,
        s_churn=0,
    )
    prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
    mesh_path = []
    output_path = './logs'
    os.makedirs(os.path.join(output_path, 'source'), exist_ok=True)
    state['latent'] = []
    state['prompt'] = prompt
    state['rand_seed_1'] = rand_seed
    for i, latent in enumerate(latents):
        output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
        t_obj = decode_latent_mesh(xm, latent).tri_mesh()
        with open(output_path_tmp, 'w') as f:
            t_obj.write_obj(f)

        # Rotate the exported mesh into the orientation expected by the gr.Model3D viewer.
        mesh = trimesh.load_mesh(output_path_tmp)
        angle = np.radians(180)
        axis = [0, 1, 0]
        rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
        mesh.apply_transform(rotation_matrix)
        angle = np.radians(90)
        axis = [1, 0, 0]
        rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
        mesh.apply_transform(rotation_matrix)
        mesh.export(output_path_tmp)
        state['latent'].append(latent.clone().detach())
        mesh_path.append(output_path_tmp)

    return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state


def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
    set_seed(rand_seed)
    mesh_path = []
    prompt = state['prompt']
    rand_seed_1 = state['rand_seed_1']
    print("prompt: ", prompt, "rand_seed: ", rand_seed, "instruction:", instruction, "state:", state)
    prompt_hash = str(hashlib.sha256(
        (prompt + '_' + str(rand_seed_1) + '_' + instruction + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
    # Map the free-form instruction onto one of the pretrained single-effect editors.
    if 'santa' in instruction:
        e_type = 'santa_hat'
    elif 'rainbow' in instruction:
        e_type = 'rainbow'
    elif 'gold' in instruction:
        e_type = 'golden'
    elif 'lego' in instruction:
        e_type = 'lego'
    elif 'wooden' in instruction:
        e_type = 'wooden'
    elif 'cyber' in instruction:
        e_type = 'cyber'

    model = models[e_type].to(device)
    noise_initial = initial_noise[e_type].to(device)
    noise_start_t = start_t[e_type]
    general_save_path = './logs/edited'
    os.makedirs(general_save_path, exist_ok=True)
    for i, latent in enumerate(state['latent']):
        latent = latent.to(device)
        text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
        print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
        ref_latent = latent.clone().unsqueeze(0)
        t_1 = torch.randint(noise_start_t, noise_start_t + 1, (1,), device=device).long()

        noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
        out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
                                          model_kwargs=text_embeddings_clip,
                                          condition_latents=ref_latent)

        updated_latents = out_1['pred_xstart']

        # The santa-hat edit grows the object upward, so enlarge the rendering bbox along +z.
        if 'santa' in instruction:
            xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.25]).to(device)
            xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
        else:
            xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
            xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
        xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])

        for latent_idx, updated_latent in enumerate(updated_latents):
            output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))

            t = decode_latent_mesh(xm, updated_latent).tri_mesh()
            with open(output_path, 'w') as f:
                t.write_obj(f)
            mesh = trimesh.load_mesh(output_path)

            angle = np.radians(180)
            axis = [0, 1, 0]
            rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
            mesh.apply_transform(rotation_matrix)
            angle = np.radians(90)
            axis = [1, 0, 0]
            rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
            mesh.apply_transform(rotation_matrix)

            mesh.export(output_path)
            mesh_path.append(output_path)
    return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state


def main():

    css = """
    #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
    {
        height: var(--height) !important;
        max-height: var(--height) !important;
        min-height: var(--height) !important;
    }
    #paper-info a {
        color: #008AD7;
        text-decoration: none;
    }
    #paper-info a:hover {
        cursor: pointer;
        text-decoration: none;
    }

    .tooltip {
        color: #555;
        position: relative;
        display: inline-block;
        cursor: pointer;
    }

    .tooltip .tooltiptext {
        visibility: hidden;
        width: 400px;
        background-color: #555;
        color: #fff;
        text-align: center;
        padding: 5px;
        border-radius: 5px;
        position: absolute;
        z-index: 1; /* Set z-index to 1 */
        left: 10px;
        top: 100%;
        opacity: 0;
        transition: opacity 0.3s;
    }

    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 1;
        z-index: 9999; /* Set a high z-index value when hovering */
    }
    """

    rescale_js = """
    function(x) {
        const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
        let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
        const image_width = root.querySelector('#img2img_image').clientWidth;
        const target_height = parseInt(image_width * image_scale);
        document.body.style.setProperty('--height', `${target_height}px`);
        root.querySelectorAll('button.justify-center.rounded')[0].style.display = 'none';
        root.querySelectorAll('button.justify-center.rounded')[1].style.display = 'none';
        return x;
    }
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    latent_model = load_model('text300M', device=device)
    xm = load_model('transmitter', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    freeze_params(xm.parameters())
    models = dict()
    initial_noise = dict()
    noise_start_t = dict()
    editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']

    for editing_type in editing_types:
        # Widen the text300M input projection from 1024 to 2048 features so the editor
        # can also take the condition latent; the new columns start at zero, so before
        # the checkpoint is loaded the model behaves exactly like plain text300M.
        tmp_model = load_model('text300M', device=device)
        with torch.no_grad():
            new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
            new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
            new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight)
            new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
            new_proj.bias[:1024].copy_(tmp_model.wrapped.input_proj.bias)
            tmp_model.wrapped.input_proj = new_proj

        ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single',
                                         filename='{}.pt'.format(editing_type)), map_location='cpu')
        tmp_model.load_state_dict(ckp['model'])
        noise_initial = ckp['initial_noise']['noise'].to(device)
        initial_noise[editing_type] = noise_initial
        noise_start_t[editing_type] = ckp['t_start']
        models[editing_type] = tmp_model

    with Blocks(
        css=css,
        analytics_enabled=False,
        title="SHAP-EDITOR demo",
    ) as demo:
        description = """<p style="text-align: center; font-weight: bold;">
            <span style="font-size: 28px"> <span style="font-size: 140%">S</span>HAP-<span style="font-size: 140%">E</span>DITOR: Instruction-guided <br> Latent 3D Editing in Seconds</span>
            <br>
            <span style="font-size: 18px" id="paper-info">
                [<a href=" " target="_blank">Project Page</a>]
                [<a href=" " target="_blank">Paper</a>]
                [<a href=" " target="_blank">GitHub</a>]
            </span>
        </p>
        """
        state = gr.State({})
        gr.HTML(description)
        with gr.Column():
            with gr.Column():
                gr.HTML('<span style="font-size: 20px; font-weight: bold">Step 1: generate the original 3D object using Shap-E.</span>')
                prompt = gr.Textbox(
                    label="Text prompt for initial 3D generation", lines=1
                )
                gen_btn = gr.Button(value='Generate', scale=1)

            with gr.Column():
                gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated 3D objects</span>')
                with gr.Row():
                    out_gen_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 1)")
                    out_gen_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 1)")
                    out_gen_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 1)")
                    out_gen_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 1)")

            with gr.Column(scale=1):
                gr.HTML('<span style="font-size: 20px; font-weight: bold">Step 2: apply 3D editing with SHAP-EDITOR.</span>')

                editing_choice = gr.Dropdown(
                    ["Add a santa hat to it", "Make it look like made of gold", "Make the color of it look like rainbow",
                     "Make it in cyberpunk style", "Make it wooden", "Make it look like made of lego"],
                    value='Add a santa hat to it', multiselect=False, label="Editing effects",
                    info="Select the specific editing you want to apply!"
                )
                apply_btn = gr.Button(value='Edit', scale=1)

            with gr.Column(scale=3):
                gr.HTML('<span style="font-size: 20px; font-weight: bold">Edited 3D objects</span>')
                with gr.Row():
                    edited_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 2)")
                    edited_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 2)")
                    edited_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 2)")
                    edited_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 2)")

            with gr.Accordion("Advanced Options", open=False):
                rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random seed")

            gen_btn.click(
                fn=partial(generate_3d_with_shap_e, xm, diffusion, latent_model, device),
                inputs=[prompt, rand_seed, state],
                outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
                queue=False)

            apply_btn.click(
                fn=partial(_3d_editing, xm, models, diffusion, initial_noise, noise_start_t, device),
                inputs=[editing_choice, rand_seed, state],
                outputs=[edited_1, edited_2, edited_3, edited_4, state],
                queue=True
            )
            print("Generate examples...")
            with gr.Column():
                gr.Examples(
                    examples=[
                        ["a corgi",
                         "Make the color of it look like rainbow",
                         456,
                         ],
                        ["a penguin",
                         "Make it look like made of lego",
                         214,
                         ],
                    ],
                    inputs=[prompt, editing_choice, rand_seed],
                    outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4],
                    fn=partial(optimize_all, xm, models, initial_noise, noise_start_t, diffusion, latent_model, device),
                    cache_examples=True,
                )

    demo.queue(max_size=10, api_open=False)
    demo.launch(share=True, show_api=False, show_error=True)


if __name__ == '__main__':
    main()
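For readers skimming the diff: the whole editing step in `_3d_editing` above is a single noise-and-denoise round trip on the Shap-E latent. A minimal sketch of that core, reusing the names bound in app.py (`diffusion`, `model`, `ref_latent`, `noise_initial`, `noise_start_t`, `text_embeddings_clip`):

    # Noise the latent up to the editor's fixed start timestep...
    t = torch.full((1,), noise_start_t, device=device, dtype=torch.long)
    x_t = diffusion.q_sample(ref_latent, t, noise=noise_initial)
    # ...then take one denoising step, conditioned on the unedited latent.
    out = diffusion.p_mean_variance(
        model, x_t, t,
        clip_denoised=True,
        model_kwargs=text_embeddings_clip,
        condition_latents=ref_latent,
    )
    edited_latent = out["pred_xstart"]  # decoded to a mesh by decode_latent_mesh above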
requirements.txt
ADDED
@@ -0,0 +1,17 @@
filelock
pillow
torch
fire
humanize
requests
tqdm
matplotlib
scikit-image
scipy
numpy
blobfile
clip @ git+https://github.com/openai/CLIP.git
trimesh

# gradio demo
gradio
shap_e/.DS_Store
ADDED
Binary file (6.15 kB)

shap_e/__init__.py
ADDED
File without changes

shap_e/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (156 Bytes)

shap_e/diffusion/__init__.py
ADDED
File without changes

shap_e/diffusion/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (166 Bytes)

shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc
ADDED
Binary file (33.9 kB)

shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc
ADDED
Binary file (12.4 kB)

shap_e/diffusion/__pycache__/sample.cpython-39.pyc
ADDED
Binary file (3.71 kB)
shap_e/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,1143 @@
"""
Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
"""

import math
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import blobfile as bf
import numpy as np
import torch as th
import yaml


def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion":
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return diffusion_from_config(obj)

    schedule = config["schedule"]
    steps = config["timesteps"]
    respace = config.get("respacing", None)
    mean_type = config.get("mean_type", "epsilon")
    betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {}))
    channel_scales = config.get("channel_scales", None)
    channel_biases = config.get("channel_biases", None)
    if channel_scales is not None:
        channel_scales = np.array(channel_scales)
    if channel_biases is not None:
        channel_biases = np.array(channel_biases)
    kwargs = dict(
        betas=betas,
        model_mean_type=mean_type,
        model_var_type="learned_range",
        loss_type="mse",
        channel_scales=channel_scales,
        channel_biases=channel_biases,
    )
    if respace is None:
        return GaussianDiffusion(**kwargs)
    else:
        return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)

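# Editor's note (not part of the original file): the keys parsed above imply a
# config of roughly this shape -- values here are illustrative, not from the repo:
#
#     {
#         "schedule": "cosine",    # name understood by get_named_beta_schedule()
#         "timesteps": 1024,       # length of the full diffusion chain
#         "respacing": "64",       # optional; builds a SpacedDiffusion instead
#         "mean_type": "epsilon",  # what the network is trained to predict
#     }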
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """
    This is the deprecated API for creating beta schedules.

    See get_named_beta_schedule() for the new library of schedules.
    """
    if beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):
    """
    Get a pre-defined beta schedule for the given name.

    The beta schedule library consists of beta schedules which remain similar
    in the limit of num_diffusion_timesteps.
    Beta schedules may be added, but should not be removed or changed once
    they are committed, to maintain backwards compatibility.
    """
    if schedule_name == "linear":
        # Linear schedule from Ho et al, extended to work for any number of
        # diffusion steps.
        scale = 1000 / num_diffusion_timesteps
        return get_beta_schedule(
            "linear",
            beta_start=scale * 0.0001,
            beta_end=scale * 0.02,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )
    elif schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    elif schedule_name == "inv_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: 1 - t**exponent,
        )
    elif schedule_name == "translated_parabola":
        exponent = extra_args.get("power", 2.0)
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: (1 - t) ** exponent,
        )
    elif schedule_name == "exp":
        coefficient = extra_args.get("coefficient", -12.0)
        return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient))
    else:
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        elif section_counts.startswith("exact"):
            res = set(int(x) for x in section_counts[len("exact") :].split(","))
            for x in res:
                if x < 0 or x >= num_timesteps:
                    raise ValueError(f"timestep out of bounds: {x}")
            return res
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)

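# Editor's note (not in the original file): a quick check of the docstring
# example above. Sections [0,100), [100,200), [200,300) are strided to 10, 15,
# and 20 steps respectively, so space_timesteps(300, [10, 15, 20]) returns 45
# unique steps. "ddimN" instead searches for an integer stride, e.g.
# space_timesteps(1024, "ddim64") finds stride 16 and returns 64 steps.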
class GaussianDiffusion:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D array of betas for each diffusion timestep from T to 1.
    :param model_mean_type: a string determining what the model outputs.
    :param model_var_type: a string determining how variance is output.
    :param loss_type: a string determining the loss function to use.
    :param discretized_t0: if True, use discrete gaussian loss for t=0. Only
                           makes sense for images.
    :param channel_scales: a multiplier to apply to x_start in training_losses
                           and sampling functions.
    """

    def __init__(
        self,
        *,
        betas: Sequence[float],
        model_mean_type: str,
        model_var_type: str,
        loss_type: str,
        discretized_t0: bool = False,
        channel_scales: Optional[np.ndarray] = None,
        channel_biases: Optional[np.ndarray] = None,
    ):
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.discretized_t0 = discretized_t0
        self.channel_scales = channel_scales
        self.channel_biases = channel_biases

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
        )

    def get_sigmas(self, t):
        return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert noise.shape == x_start.shape
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

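    # Editor's note (not in the original file): q_sample() above is the closed-form
    # forward process
    #     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
    # reading sqrt_alphas_cumprod / sqrt_one_minus_alphas_cumprod from the tables
    # precomputed in __init__, so any x_t is reachable in a single call.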
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(
        self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None, condition_latents=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """
        if model_kwargs is None:
            model_kwargs = {}
        B, C = x.shape[:2]
        assert t.shape == (B,)
        model_output = model(x, t, **model_kwargs) if condition_latents is None else model(x, t, condition_latents, **model_kwargs)
        if isinstance(model_output, tuple):
            model_output, extra = model_output
        else:
            extra = None

        if self.model_var_type in ["learned", "learned_range"]:
            assert model_output.shape == (B, C * 2, *x.shape[2:])
            model_output, model_var_values = th.split(model_output, C, dim=1)
            if self.model_var_type == "learned":
                model_log_variance = model_var_values
                model_variance = th.exp(model_log_variance)
            else:
                min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
                max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
                # The model_var_values is [-1, 1] for [min_var, max_var].
                frac = (model_var_values + 1) / 2
                model_log_variance = frac * max_log + (1 - frac) * min_log
                model_variance = th.exp(model_log_variance)
        else:
            model_variance, model_log_variance = {
                # for fixedlarge, we set the initial (log-)variance like so
                # to get a better decoder log likelihood.
                "fixed_large": (
                    np.append(self.posterior_variance[1], self.betas[1:]),
                    np.log(np.append(self.posterior_variance[1], self.betas[1:])),
                ),
                "fixed_small": (
                    self.posterior_variance,
                    self.posterior_log_variance_clipped,
                ),
            }[self.model_var_type]
            model_variance = _extract_into_tensor(model_variance, t, x.shape)
            model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)

        def process_xstart(x):
            if denoised_fn is not None:
                x = denoised_fn(x)
            if clip_denoised:
                return x.clamp(-1, 1)
            return x

        if self.model_mean_type == "x_prev":
            pred_xstart = process_xstart(
                self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
            )
            model_mean = model_output
        elif self.model_mean_type in ["x_start", "epsilon"]:
            if self.model_mean_type == "x_start":
                pred_xstart = process_xstart(model_output)
            else:
                pred_xstart = process_xstart(
                    self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
                )
            model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
        else:
            raise NotImplementedError(self.model_mean_type)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return {
            "mean": model_mean,
            "variance": model_variance,
            "log_variance": model_log_variance,
            "pred_xstart": pred_xstart,
            "extra": extra,
        }

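    # Editor's note (not in the original file): the condition_latents argument is
    # an extension over the upstream GLIDE port. When it is given, the model is
    # called as model(x, t, condition_latents, **model_kwargs), which is how the
    # editing models in app.py condition on the unedited Shap-E latent.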
    def _predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
            - _extract_into_tensor(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
            )
            * x_t
        )

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
        """
        gradient = cond_fn(x, t, **(model_kwargs or {}))
        new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).
        """
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        :param model: the model to sample from.
        :param x: the current tensor at x_{t-1}.
        :param t: the value of t, starting at 0 for the first diffusion step.
        :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict containing the following keys:
                 - 'sample': a random sample from the model.
                 - 'pred_xstart': a prediction of x_0.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        noise = th.randn_like(x)
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        if cond_fn is not None:
            out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
        sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temp=1.0,
    ):
        """
        Generate samples from the model.

        :param model: the model module.
        :param shape: the shape of the samples, (N, C, H, W).
        :param noise: if specified, the noise from the encoder to sample.
                      Should be of the same shape as `shape`.
        :param clip_denoised: if True, clip x_start predictions to [-1, 1].
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample.
        :param cond_fn: if not None, this is a gradient function that acts
                        similarly to the model.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param device: if specified, the device to create the samples on.
                       If not specified, use a model parameter's device.
        :param progress: if True, show a tqdm progress bar.
        :return: a non-differentiable batch of samples.
        """
        final = None
        for sample in self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            temp=temp,
        ):
            final = sample
        return final["sample"]

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        temp=1.0,
    ):
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device) * temp
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.p_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                )
                yield self.unscale_out_dict(out)
                img = out["sample"]

    def ddim_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t-1} from the model using DDIM.

        Same usage as p_sample().
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)

        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])

        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
        alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
        sigma = (
            eta
            * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * th.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        # Equation 12.
        noise = th.randn_like(x)
        mean_pred = (
            out["pred_xstart"] * th.sqrt(alpha_bar_prev)
            + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
        )
        nonzero_mask = (
            (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )  # no noise when t == 0
        sample = mean_pred + nonzero_mask * sigma * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"]}

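    # Editor's note (not in the original file): "Equation 12" in ddim_sample()
    # refers to Song et al., "Denoising Diffusion Implicit Models" (2020);
    # eta = 0 gives the deterministic DDIM update, while eta = 1 recovers
    # ancestral (DDPM-style) sampling.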
    def ddim_reverse_sample(
        self,
        model,
        x,
        t,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        eta=0.0,
    ):
        """
        Sample x_{t+1} from the model using DDIM reverse ODE.
        """
        assert eta == 0.0, "Reverse ODE only for deterministic path"
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )
        if cond_fn is not None:
            out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
        # Usually our model outputs epsilon, but we re-derive it
        # in case we used x_start or x_prev prediction.
        eps = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
            - out["pred_xstart"]
        ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
        alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)

        # Equation 12. reversed
        mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps

        return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}

    def ddim_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        temp=1.0,
    ):
        """
        Generate samples from the model using DDIM.

        Same usage as p_sample_loop().
        """
        final = None
        for sample in self.ddim_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            eta=eta,
            temp=temp,
        ):
            final = sample
        return final["sample"]

    def ddim_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=False,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        eta=0.0,
        temp=1.0,
    ):
        """
        Use DDIM to sample from the model and yield intermediate samples from
        each timestep of DDIM.

        Same usage as p_sample_loop_progressive().
        """
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list))
        if noise is not None:
            img = noise
        else:
            img = th.randn(*shape, device=device) * temp
        indices = list(range(self.num_timesteps))[::-1]

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm

            indices = tqdm(indices)

        for i in indices:
            t = th.tensor([i] * shape[0], device=device)
            with th.no_grad():
                out = self.ddim_sample(
                    model,
                    img,
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    eta=eta,
                )
                yield self.unscale_out_dict(out)
                img = out["sample"]

    def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):
        """
        Get a term for the variational lower-bound.

        The resulting units are bits (rather than nats, as one might expect).
        This allows for comparison to other papers.

        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        out = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
        kl = mean_flat(kl) / np.log(2.0)

        decoder_nll = -discretized_gaussian_log_likelihood(
            x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
        )
        if not self.discretized_t0:
            decoder_nll = th.zeros_like(decoder_nll)
        assert decoder_nll.shape == x_start.shape
        decoder_nll = mean_flat(decoder_nll) / np.log(2.0)

        # At the first timestep return the decoder NLL,
        # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
        output = th.where((t == 0), decoder_nll, kl)
        return {
            "output": output,
            "pred_xstart": out["pred_xstart"],
            "extra": out["extra"],
        }

    def training_losses(
        self, model, x_start, t, model_kwargs=None, noise=None
    ) -> Dict[str, th.Tensor]:
        """
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        x_start = self.scale_channels(x_start)
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}

        if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
            vb_terms = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
|
841 |
+
model_kwargs=model_kwargs,
|
842 |
+
)
|
843 |
+
terms["loss"] = vb_terms["output"]
|
844 |
+
if self.loss_type == "rescaled_kl":
|
845 |
+
terms["loss"] *= self.num_timesteps
|
846 |
+
extra = vb_terms["extra"]
|
847 |
+
elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
|
848 |
+
model_output = model(x_t, t, **model_kwargs)
|
849 |
+
if isinstance(model_output, tuple):
|
850 |
+
model_output, extra = model_output
|
851 |
+
else:
|
852 |
+
extra = {}
|
853 |
+
|
854 |
+
if self.model_var_type in [
|
855 |
+
"learned",
|
856 |
+
"learned_range",
|
857 |
+
]:
|
858 |
+
B, C = x_t.shape[:2]
|
859 |
+
assert model_output.shape == (
|
860 |
+
B,
|
861 |
+
C * 2,
|
862 |
+
*x_t.shape[2:],
|
863 |
+
), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}"
|
864 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
865 |
+
# Learn the variance using the variational bound, but don't let
|
866 |
+
# it affect our mean prediction.
|
867 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
868 |
+
terms["vb"] = self._vb_terms_bpd(
|
869 |
+
model=lambda *args, r=frozen_out: r,
|
870 |
+
x_start=x_start,
|
871 |
+
x_t=x_t,
|
872 |
+
t=t,
|
873 |
+
clip_denoised=False,
|
874 |
+
)["output"]
|
875 |
+
if self.loss_type == "rescaled_mse":
|
876 |
+
# Divide by 1000 for equivalence with initial implementation.
|
877 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
878 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
879 |
+
|
880 |
+
target = {
|
881 |
+
"x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
|
882 |
+
"x_start": x_start,
|
883 |
+
"epsilon": noise,
|
884 |
+
}[self.model_mean_type]
|
885 |
+
assert model_output.shape == target.shape == x_start.shape
|
886 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
887 |
+
if "vb" in terms:
|
888 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
889 |
+
else:
|
890 |
+
terms["loss"] = terms["mse"]
|
891 |
+
else:
|
892 |
+
raise NotImplementedError(self.loss_type)
|
893 |
+
|
894 |
+
if "losses" in extra:
|
895 |
+
terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
|
896 |
+
for loss, scale in extra["losses"].values():
|
897 |
+
terms["loss"] = terms["loss"] + loss * scale
|
898 |
+
|
899 |
+
return terms
|
900 |
+
|
901 |
+
def _prior_bpd(self, x_start):
|
902 |
+
"""
|
903 |
+
Get the prior KL term for the variational lower-bound, measured in
|
904 |
+
bits-per-dim.
|
905 |
+
|
906 |
+
This term can't be optimized, as it only depends on the encoder.
|
907 |
+
|
908 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
909 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
910 |
+
"""
|
911 |
+
batch_size = x_start.shape[0]
|
912 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
913 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
914 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
915 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
916 |
+
|
917 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
|
918 |
+
"""
|
919 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
920 |
+
as well as other related quantities.
|
921 |
+
|
922 |
+
:param model: the model to evaluate loss on.
|
923 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
924 |
+
:param clip_denoised: if True, clip denoised samples.
|
925 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
926 |
+
pass to the model. This can be used for conditioning.
|
927 |
+
|
928 |
+
:return: a dict containing the following keys:
|
929 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
930 |
+
- prior_bpd: the prior term in the lower-bound.
|
931 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
932 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
933 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
934 |
+
"""
|
935 |
+
device = x_start.device
|
936 |
+
batch_size = x_start.shape[0]
|
937 |
+
|
938 |
+
vb = []
|
939 |
+
xstart_mse = []
|
940 |
+
mse = []
|
941 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
942 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
943 |
+
noise = th.randn_like(x_start)
|
944 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
945 |
+
# Calculate VLB term at the current timestep
|
946 |
+
with th.no_grad():
|
947 |
+
out = self._vb_terms_bpd(
|
948 |
+
model,
|
949 |
+
x_start=x_start,
|
950 |
+
x_t=x_t,
|
951 |
+
t=t_batch,
|
952 |
+
clip_denoised=clip_denoised,
|
953 |
+
model_kwargs=model_kwargs,
|
954 |
+
)
|
955 |
+
vb.append(out["output"])
|
956 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
957 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
958 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
959 |
+
|
960 |
+
vb = th.stack(vb, dim=1)
|
961 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
962 |
+
mse = th.stack(mse, dim=1)
|
963 |
+
|
964 |
+
prior_bpd = self._prior_bpd(x_start)
|
965 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
966 |
+
return {
|
967 |
+
"total_bpd": total_bpd,
|
968 |
+
"prior_bpd": prior_bpd,
|
969 |
+
"vb": vb,
|
970 |
+
"xstart_mse": xstart_mse,
|
971 |
+
"mse": mse,
|
972 |
+
}
|
973 |
+
|
974 |
+
def scale_channels(self, x: th.Tensor) -> th.Tensor:
|
975 |
+
if self.channel_scales is not None:
|
976 |
+
x = x * th.from_numpy(self.channel_scales).to(x).reshape(
|
977 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
978 |
+
)
|
979 |
+
if self.channel_biases is not None:
|
980 |
+
x = x + th.from_numpy(self.channel_biases).to(x).reshape(
|
981 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
982 |
+
)
|
983 |
+
return x
|
984 |
+
|
985 |
+
def unscale_channels(self, x: th.Tensor) -> th.Tensor:
|
986 |
+
if self.channel_biases is not None:
|
987 |
+
x = x - th.from_numpy(self.channel_biases).to(x).reshape(
|
988 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
989 |
+
)
|
990 |
+
if self.channel_scales is not None:
|
991 |
+
x = x / th.from_numpy(self.channel_scales).to(x).reshape(
|
992 |
+
[1, -1, *([1] * (len(x.shape) - 2))]
|
993 |
+
)
|
994 |
+
return x
|
995 |
+
|
996 |
+
def unscale_out_dict(
|
997 |
+
self, out: Dict[str, Union[th.Tensor, Any]]
|
998 |
+
) -> Dict[str, Union[th.Tensor, Any]]:
|
999 |
+
return {
|
1000 |
+
k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()
|
1001 |
+
}
|
1002 |
+
|
1003 |
+
|
1004 |
+
class SpacedDiffusion(GaussianDiffusion):
|
1005 |
+
"""
|
1006 |
+
A diffusion process which can skip steps in a base diffusion process.
|
1007 |
+
:param use_timesteps: (unordered) timesteps from the original diffusion
|
1008 |
+
process to retain.
|
1009 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
1010 |
+
"""
|
1011 |
+
|
1012 |
+
def __init__(self, use_timesteps: Iterable[int], **kwargs):
|
1013 |
+
self.use_timesteps = set(use_timesteps)
|
1014 |
+
self.timestep_map = []
|
1015 |
+
self.original_num_steps = len(kwargs["betas"])
|
1016 |
+
|
1017 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
1018 |
+
last_alpha_cumprod = 1.0
|
1019 |
+
new_betas = []
|
1020 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
1021 |
+
if i in self.use_timesteps:
|
1022 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
1023 |
+
last_alpha_cumprod = alpha_cumprod
|
1024 |
+
self.timestep_map.append(i)
|
1025 |
+
kwargs["betas"] = np.array(new_betas)
|
1026 |
+
super().__init__(**kwargs)
|
1027 |
+
|
1028 |
+
def p_mean_variance(self, model, *args, **kwargs):
|
1029 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
1030 |
+
|
1031 |
+
def training_losses(self, model, *args, **kwargs):
|
1032 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
1033 |
+
|
1034 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
1035 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
1036 |
+
|
1037 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
1038 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
1039 |
+
|
1040 |
+
def _wrap_model(self, model):
|
1041 |
+
if isinstance(model, _WrappedModel):
|
1042 |
+
return model
|
1043 |
+
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
1044 |
+
|
1045 |
+
|
1046 |
+
class _WrappedModel:
|
1047 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
1048 |
+
self.model = model
|
1049 |
+
self.timestep_map = timestep_map
|
1050 |
+
self.original_num_steps = original_num_steps
|
1051 |
+
|
1052 |
+
def __call__(self, x, ts, **kwargs):
|
1053 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
1054 |
+
new_ts = map_tensor[ts]
|
1055 |
+
return self.model(x, new_ts, **kwargs)
|
1056 |
+
|
1057 |
+
|
1058 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
1059 |
+
"""
|
1060 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
1061 |
+
|
1062 |
+
:param arr: the 1-D numpy array.
|
1063 |
+
:param timesteps: a tensor of indices into the array to extract.
|
1064 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
1065 |
+
dimension equal to the length of timesteps.
|
1066 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
1067 |
+
"""
|
1068 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
1069 |
+
while len(res.shape) < len(broadcast_shape):
|
1070 |
+
res = res[..., None]
|
1071 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
1072 |
+
|
1073 |
+
|
1074 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
1075 |
+
"""
|
1076 |
+
Compute the KL divergence between two gaussians.
|
1077 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
1078 |
+
scalars, among other use cases.
|
1079 |
+
"""
|
1080 |
+
tensor = None
|
1081 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
1082 |
+
if isinstance(obj, th.Tensor):
|
1083 |
+
tensor = obj
|
1084 |
+
break
|
1085 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
1086 |
+
|
1087 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
1088 |
+
# Tensors, but it does not work for th.exp().
|
1089 |
+
logvar1, logvar2 = [
|
1090 |
+
x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)
|
1091 |
+
]
|
1092 |
+
|
1093 |
+
return 0.5 * (
|
1094 |
+
-1.0
|
1095 |
+
+ logvar2
|
1096 |
+
- logvar1
|
1097 |
+
+ th.exp(logvar1 - logvar2)
|
1098 |
+
+ ((mean1 - mean2) ** 2) * th.exp(-logvar2)
|
1099 |
+
)
|
1100 |
+
|
1101 |
+
|
1102 |
+
def approx_standard_normal_cdf(x):
|
1103 |
+
"""
|
1104 |
+
A fast approximation of the cumulative distribution function of the
|
1105 |
+
standard normal.
|
1106 |
+
"""
|
1107 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
1108 |
+
|
1109 |
+
|
1110 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
1111 |
+
"""
|
1112 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
1113 |
+
given image.
|
1114 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
1115 |
+
rescaled to the range [-1, 1].
|
1116 |
+
:param means: the Gaussian mean Tensor.
|
1117 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
1118 |
+
:return: a tensor like x of log probabilities (in nats).
|
1119 |
+
"""
|
1120 |
+
assert x.shape == means.shape == log_scales.shape
|
1121 |
+
centered_x = x - means
|
1122 |
+
inv_stdv = th.exp(-log_scales)
|
1123 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
1124 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
1125 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
1126 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
1127 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
1128 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
1129 |
+
cdf_delta = cdf_plus - cdf_min
|
1130 |
+
log_probs = th.where(
|
1131 |
+
x < -0.999,
|
1132 |
+
log_cdf_plus,
|
1133 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
1134 |
+
)
|
1135 |
+
assert log_probs.shape == x.shape
|
1136 |
+
return log_probs
|
1137 |
+
|
1138 |
+
|
1139 |
+
def mean_flat(tensor):
|
1140 |
+
"""
|
1141 |
+
Take the mean over all non-batch dimensions.
|
1142 |
+
"""
|
1143 |
+
return tensor.flatten(1).mean(1)
|
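
A minimal usage sketch of the respacing machinery above, assuming the `GaussianDiffusion` constructor defined earlier in this file accepts these keywords; `base_model`, the betas array, and the sample shape are illustrative placeholders, not part of the commit:

# Hypothetical sketch -- `base_model` and all constructor kwargs are placeholders.
import numpy as np

base_betas = np.linspace(1e-4, 0.02, 1024)  # illustrative 1024-step base schedule
respaced = SpacedDiffusion(
    use_timesteps=range(0, 1024, 16),  # keep every 16th step -> 64 DDIM steps
    betas=base_betas,
    model_mean_type="epsilon",
    model_var_type="learned_range",
    loss_type="mse",
)
# ddim_sample_loop wraps base_model in _WrappedModel, which remaps the respaced
# indices 0..63 back to the original timesteps before each forward pass, so the
# model only ever sees timesteps it was trained on.
samples = respaced.ddim_sample_loop(base_model, shape=(4, 8), eta=0.0)
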
shap_e/diffusion/k_diffusion.py
ADDED
@@ -0,0 +1,426 @@
"""
Based on: https://github.com/crowsonkb/k-diffusion

Copyright (c) 2022 Katherine Crowson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion, mean_flat


class KarrasDenoiser:
    def __init__(self, sigma_data: float = 0.5):
        self.sigma_data = sigma_data

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_scalings(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in

    def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
        if model_kwargs is None:
            model_kwargs = {}
        if noise is None:
            noise = th.randn_like(x_start)

        terms = {}

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
        model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
        target = (x_start - c_skip * x_t) / c_out

        terms["mse"] = mean_flat((model_output - target) ** 2)
        terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)

        if "vb" in terms:
            terms["loss"] = terms["mse"] + terms["vb"]
        else:
            terms["loss"] = terms["mse"]

        return terms

    def denoise(self, model, x_t, sigmas, **model_kwargs):
        c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised


class GaussianToKarrasDenoiser:
    def __init__(self, model, diffusion):
        from scipy import interpolate

        self.model = model
        self.diffusion = diffusion
        self.alpha_cumprod_to_t = interpolate.interp1d(
            diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
        )

    def sigma_to_t(self, sigma):
        alpha_cumprod = 1.0 / (sigma**2 + 1)
        if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
            return 0
        elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
            return self.diffusion.num_timesteps - 1
        else:
            return float(self.alpha_cumprod_to_t(alpha_cumprod))

    def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None, condition_latents=None):
        t = th.tensor(
            [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
            dtype=th.long,
            device=sigmas.device,
        )
        c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
        out = self.diffusion.p_mean_variance(
            self.model,
            x_t * c_in,
            t,
            clip_denoised=clip_denoised,
            model_kwargs=model_kwargs,
            condition_latents=condition_latents,
        )
        return None, out["pred_xstart"]


def karras_sample(*args, **kwargs):
    last = None
    x_sequence = []
    for x in karras_sample_progressive(*args, **kwargs):
        last = x["x"]
        x_sequence.append(last)
    return last, x_sequence


def karras_sample_progressive(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    guidance_scale=0.0,
    condition_latent=None,
    initial_noise=None,
):
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    if initial_noise is None:
        x_T = th.randn(*shape, device=device) * sigma_max
    else:
        x_T = initial_noise.clone() * sigma_max
    sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
        sampler
    ]
    if sampler != "ancestral":
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
    else:
        sampler_args = {}

    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised

    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(
                x_t,
                sigma,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                condition_latents=condition_latent,
            )
            return denoised

    else:
        raise NotImplementedError

    if guidance_scale != 0 and guidance_scale != 1:

        def guided_denoiser(x_t, sigma):
            # Double the batch so one forward pass yields both the conditional
            # and unconditional predictions.
            x_t = th.cat([x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
            x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
            return x_0

    else:
        guided_denoiser = denoiser

    for obj in sample_fn(
        guided_denoiser,
        x_T,
        sigmas,
        progress=progress,
        condition_latent=condition_latent,
        **sampler_args,
    ):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj


def karras_sample_progressive_condition(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=False,
    model_kwargs=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    text_guidance_scale=0.0,
    image_guidance_scale=0.0,
    condition_latent=None,
):
    sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
    x_T = th.randn(*shape, device=device) * sigma_max
    sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
        sampler
    ]
    if sampler != "ancestral":
        sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
    else:
        sampler_args = {}

    if isinstance(diffusion, KarrasDenoiser):

        def denoiser(x_t, sigma):
            _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
            if clip_denoised:
                denoised = denoised.clamp(-1, 1)
            return denoised

    elif isinstance(diffusion, GaussianDiffusion):
        model = GaussianToKarrasDenoiser(model, diffusion)

        def denoiser(x_t, sigma):
            _, denoised = model.denoise(
                x_t,
                sigma,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                condition_latents=condition_latent,
            )
            return denoised

    else:
        raise NotImplementedError

    if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (
        image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    ):

        def guided_denoiser(x_t, sigma):
            # Triple the batch: [text+image cond, image-only cond, uncond].
            x_t = th.cat([x_t, x_t, x_t], dim=0)
            sigma = th.cat([sigma, sigma, sigma], dim=0)
            x_0 = denoiser(x_t, sigma)
            cond_x_0_text, cond_x_0_image, uncond_x_0 = th.chunk(x_0, 3, dim=0)
            x_0 = (
                uncond_x_0
                + text_guidance_scale * (cond_x_0_text - cond_x_0_image)
                + image_guidance_scale * (cond_x_0_image - uncond_x_0)
            )
            return x_0

    else:
        guided_denoiser = denoiser

    for obj in sample_fn(
        guided_denoiser,
        x_T,
        sigmas,
        progress=progress,
        condition_latent=condition_latent,
        **sampler_args,
    ):
        if isinstance(diffusion, GaussianDiffusion):
            yield diffusion.unscale_out_dict(obj)
        else:
            yield obj


def karras_sample_addition_condition(*args, **kwargs):
    last = None
    x_sequence = []
    for x in karras_sample_progressive_condition(*args, **kwargs):
        last = x["x"]
        x_sequence.append(x["pred_xstart"])
    return last, x_sequence


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


@th.no_grad()
def sample_euler_ancestral(model, x, sigmas, progress=False, condition_latent=None):
    """Ancestral sampling with Euler method steps."""
    # `condition_latent` is unused here; it is accepted so that the sampler
    # dispatch in karras_sample_progressive can pass it to every sampler.
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + th.randn_like(x) * sigma_up
    yield {"x": x, "pred_xstart": x}


@th.no_grad()
def sample_heun(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    condition_latent=None,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    yield {"x": x, "pred_xstart": denoised}


@th.no_grad()
def sample_dpm(
    denoiser,
    x,
    sigmas,
    progress=False,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    condition_latent=None,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    # `condition_latent` is unused here; it is accepted so that the sampler
    # dispatch in karras_sample_progressive can pass it to every sampler.
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
        )
        eps = th.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    yield {"x": x, "pred_xstart": denoised}


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x):
    return th.cat([x, x.new_zeros([1])])
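
A quick sanity sketch for `get_sigmas_karras`, using only the helpers defined above: the schedule interpolates between `sigma_max` and `sigma_min` in `sigma**(1/rho)` space and appends a terminal zero, which is what lets `sample_heun` fall back to a plain Euler step on the last iteration. The parameter values below are illustrative.

import torch as th

sigmas = get_sigmas_karras(n=64, sigma_min=1e-3, sigma_max=160.0, rho=7.0)
assert th.isclose(sigmas[0], th.tensor(160.0), rtol=1e-3)  # starts at sigma_max
assert th.isclose(sigmas[-2], th.tensor(1e-3), rtol=1e-3)  # ends at sigma_min...
assert sigmas[-1] == 0                                     # ...plus the appended zero
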
shap_e/diffusion/sample.py
ADDED
@@ -0,0 +1,160 @@
from typing import Any, Callable, Dict, Optional

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
from .k_diffusion import karras_sample, karras_sample_addition_condition

DEFAULT_KARRAS_STEPS = 64
DEFAULT_KARRAS_SIGMA_MIN = 1e-3
DEFAULT_KARRAS_SIGMA_MAX = 160
DEFAULT_KARRAS_S_CHURN = 0.0


def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    def model_fn(x_t, ts, **kwargs):
        half = x_t[: len(x_t) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = model(combined, ts, **kwargs)
        cond_out, uncond_out = torch.chunk(model_out, 2, dim=0)
        cond_out = uncond_out + scale * (cond_out - uncond_out)
        return torch.cat([cond_out, cond_out], dim=0)

    return model_fn


def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if guidance_scale != 1.0 and guidance_scale != 0.0:
        # Duplicate each conditioning tensor with a zeroed copy so the
        # conditional and unconditional passes share one batch.
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, sample_sequence = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
                initial_noise=initial_noise,
            )
        else:
            internal_batch_size = batch_size
            if guidance_scale != 1.0:
                model = uncond_guide_model(model, guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples


def sample_latents_with_additional_latent(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    text_guidance_scale: float,
    image_guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
    if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (
        image_guidance_scale != 1.0 and image_guidance_scale != 0.0
    ):
        # Triple the batch: [text+image cond, image-only cond, uncond].
        for k, v in model_kwargs.copy().items():
            model_kwargs[k] = torch.cat([v, torch.zeros_like(v), torch.zeros_like(v)], dim=0)
        condition_latent = torch.cat(
            [condition_latent, condition_latent, torch.zeros_like(condition_latent)], dim=0
        )

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            samples, sample_sequence = karras_sample_addition_condition(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                progress=progress,
                condition_latent=condition_latent,
            )
        else:
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )

    return samples
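
One convention worth spelling out from the two samplers above: when guidance is active, every conditioning tensor is concatenated with a zeroed copy (or two, in the text+image case) so the conditional and unconditional passes share one batch, and `uncond_guide_model` undoes the doubling on the model output. A small shape sketch, with all values as illustrative placeholders:

import torch

emb = torch.randn(4, 1024)                                # hypothetical conditioning embeddings
doubled = torch.cat([emb, torch.zeros_like(emb)], dim=0)  # (8, 1024): cond rows, then uncond rows
model_out = torch.randn(8, 16)                            # stand-in for a doubled model output
cond, uncond = torch.chunk(model_out, 2, dim=0)
guided = uncond + 3.0 * (cond - uncond)                   # the same blend uncond_guide_model applies
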
shap_e/examples/encode_model.ipynb
ADDED
@@ -0,0 +1,93 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.models.download import load_model\n",
    "from shap_e.util.data_util import load_or_create_multimodal_batch\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"example_data/cactus/object.obj\"\n",
    "\n",
    "# This may take a few minutes, since it requires rendering the model twice\n",
    "# in two different modes.\n",
    "batch = load_or_create_multimodal_batch(\n",
    "    device,\n",
    "    model_path=model_path,\n",
    "    mv_light_mode=\"basic\",\n",
    "    mv_image_size=256,\n",
    "    cache_dir=\"example_data/cactus/cached\",\n",
    "    verbose=True,  # this will show Blender output during renders\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    latent = xm.encoder.encode_to_bottleneck(batch)\n",
    "\n",
    "    render_mode = 'stf'  # you can change this to 'nerf'\n",
    "    size = 128  # recommended that you lower resolution when using nerf\n",
    "\n",
    "    cameras = create_pan_cameras(size, device)\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
shap_e/examples/sample_image_to_3d.ipynb
ADDED
@@ -0,0 +1,125 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ccced",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.diffusion.sample import sample_latents\n",
    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
    "from shap_e.models.download import load_model, load_config\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n",
    "from shap_e.util.image_util import load_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eed3a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d922637",
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)\n",
    "model = load_model('image300M', device=device)\n",
    "diffusion = diffusion_from_config(load_config('diffusion'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d329d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "guidance_scale = 3.0\n",
    "\n",
    "image = load_image(\"example_data/corgi.png\")\n",
    "\n",
    "latents = sample_latents(\n",
    "    batch_size=batch_size,\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    guidance_scale=guidance_scale,\n",
    "    model_kwargs=dict(images=[image] * batch_size),\n",
    "    progress=True,\n",
    "    clip_denoised=True,\n",
    "    use_fp16=True,\n",
    "    use_karras=True,\n",
    "    karras_steps=64,\n",
    "    sigma_min=1e-3,\n",
    "    sigma_max=160,\n",
    "    s_churn=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633da2ec",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "render_mode = 'nerf'  # you can change this to 'stf' for mesh rendering\n",
    "size = 64  # this is the size of the renders; higher values take longer to render.\n",
    "\n",
    "cameras = create_pan_cameras(size, device)\n",
    "for i, latent in enumerate(latents):\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
shap_e/examples/sample_text_to_3d.ipynb
ADDED
@@ -0,0 +1,124 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ccced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from shap_e.diffusion.sample import sample_latents\n",
    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
    "from shap_e.models.download import load_model, load_config\n",
    "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eed3a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d922637",
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)\n",
    "model = load_model('text300M', device=device)\n",
    "diffusion = diffusion_from_config(load_config('diffusion'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d329d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "guidance_scale = 15.0\n",
    "prompt = \"a shark\"\n",
    "\n",
    "latents = sample_latents(\n",
    "    batch_size=batch_size,\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    guidance_scale=guidance_scale,\n",
    "    model_kwargs=dict(texts=[prompt] * batch_size),\n",
    "    progress=True,\n",
    "    clip_denoised=True,\n",
    "    use_fp16=True,\n",
    "    use_karras=True,\n",
    "    karras_steps=64,\n",
    "    sigma_min=1e-3,\n",
    "    sigma_max=160,\n",
    "    s_churn=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633da2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "render_mode = 'nerf'  # you can change this to 'stf'\n",
    "size = 64  # this is the size of the renders; higher values take longer to render.\n",
    "\n",
    "cameras = create_pan_cameras(size, device)\n",
    "for i, latent in enumerate(latents):\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
    "    display(gif_widget(images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a4dce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of saving the latents as meshes.\n",
    "from shap_e.util.notebooks import decode_latent_mesh\n",
    "\n",
    "for i, latent in enumerate(latents):\n",
    "    t = decode_latent_mesh(xm, latent).tri_mesh()\n",
    "    with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
    "        t.write_ply(f)\n",
    "    with open(f'example_mesh_{i}.obj', 'w') as f:\n",
    "        t.write_obj(f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
shap_e/models/__init__.py
ADDED
File without changes

shap_e/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (163 Bytes)

shap_e/models/__pycache__/configs.cpython-39.pyc
ADDED
Binary file (5 kB)

shap_e/models/__pycache__/download.cpython-39.pyc
ADDED
Binary file (5.17 kB)

shap_e/models/__pycache__/query.cpython-39.pyc
ADDED
Binary file (1.05 kB)

shap_e/models/__pycache__/renderer.cpython-39.pyc
ADDED
Binary file (10.8 kB)

shap_e/models/__pycache__/volume.cpython-39.pyc
ADDED
Binary file (7.64 kB)
shap_e/models/configs.py
ADDED
@@ -0,0 +1,166 @@
from typing import Any, Dict, Union

import blobfile as bf
import torch
import torch.nn as nn
import yaml

from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion
from shap_e.models.generation.perceiver import PointDiffusionPerceiver
from shap_e.models.generation.pooled_mlp import PooledMLP
from shap_e.models.generation.transformer import (
    CLIPImageGridPointDiffusionTransformer,
    CLIPImageGridUpsamplePointDiffusionTransformer,
    CLIPImagePointDiffusionTransformer,
    PointDiffusionTransformer,
    UpsamplePointDiffusionTransformer,
)
from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel
from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer
from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel
from shap_e.models.nerstf.renderer import NeRSTFRenderer
from shap_e.models.nn.meta import batch_meta_state_dict
from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel
from shap_e.models.stf.renderer import STFRenderer
from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder
from shap_e.models.transmitter.channels_encoder import (
    PointCloudPerceiverChannelsEncoder,
    PointCloudTransformerChannelsEncoder,
)
from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder
from shap_e.models.transmitter.pc_encoder import (
    PointCloudPerceiverEncoder,
    PointCloudTransformerEncoder,
)
from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume


def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
    # If given a path, load the YAML config and recurse with the parsed dict.
    if isinstance(config, str):
        with bf.BlobFile(config, "rb") as f:
            obj = yaml.load(f, Loader=yaml.SafeLoader)
        return model_from_config(obj, device=device)

    config = config.copy()
    name = config.pop("name")

    if name == "PointCloudTransformerEncoder":
        return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverEncoder":
        return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudTransformerChannelsEncoder":
        return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "PointCloudPerceiverChannelsEncoder":
        return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
    elif name == "MultiviewTransformerEncoder":
        return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
    elif name == "Transmitter":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        encoder_config = config.pop("encoder").copy()
        encoder_config["param_shapes"] = param_shapes
        encoder = model_from_config(encoder_config, device=device)
        return Transmitter(encoder=encoder, renderer=renderer, **config)
    elif name == "VectorDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
    elif name == "ChannelsDecoder":
        renderer = model_from_config(config.pop("renderer"), device=device)
        param_shapes = {
            k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
        }
        return ChannelsDecoder(
            param_shapes=param_shapes, renderer=renderer, device=device, **config
        )
    elif name == "OneStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "foreground_model",
            "volume",
            # Optional to use NeRF++
            "background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return OneStepNeRFRenderer(device=device, **config)
    elif name == "TwoStepNeRFRenderer":
        config = config.copy()
        for field in [
            # Required
            "void_model",
            "coarse_model",
            "fine_model",
            "volume",
            # Optional to use NeRF++
            "coarse_background_model",
            "fine_background_model",
            "outer_volume",
        ]:
            if field in config:
                config[field] = model_from_config(config.pop(field).copy(), device)
        return TwoStepNeRFRenderer(device=device, **config)
    elif name == "PooledMLP":
        return PooledMLP(device, **config)
    elif name == "PointDiffusionTransformer":
        return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "PointDiffusionPerceiver":
        return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImagePointDiffusionTransformer":
        return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridPointDiffusionTransformer":
        return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "UpsamplePointDiffusionTransformer":
        return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
    elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
        return CLIPImageGridUpsamplePointDiffusionTransformer(
            device=device, dtype=torch.float32, **config
        )
    elif name == "SplitVectorDiffusion":
        inner_config = config.pop("inner")
|
130 |
+
d_latent = config.pop("d_latent")
|
131 |
+
latent_ctx = config.pop("latent_ctx", 1)
|
132 |
+
inner_config["input_channels"] = d_latent // latent_ctx
|
133 |
+
inner_config["n_ctx"] = latent_ctx
|
134 |
+
inner_config["output_channels"] = d_latent // latent_ctx * 2
|
135 |
+
inner_model = model_from_config(inner_config, device)
|
136 |
+
return SplitVectorDiffusion(
|
137 |
+
device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
|
138 |
+
)
|
139 |
+
elif name == "STFRenderer":
|
140 |
+
config = config.copy()
|
141 |
+
for field in ["sdf", "tf", "volume"]:
|
142 |
+
config[field] = model_from_config(config.pop(field), device)
|
143 |
+
return STFRenderer(device=device, **config)
|
144 |
+
elif name == "NeRSTFRenderer":
|
145 |
+
config = config.copy()
|
146 |
+
for field in ["sdf", "tf", "nerstf", "void", "volume"]:
|
147 |
+
if field not in config:
|
148 |
+
continue
|
149 |
+
config[field] = model_from_config(config.pop(field), device)
|
150 |
+
config.setdefault("sdf", None)
|
151 |
+
config.setdefault("tf", None)
|
152 |
+
config.setdefault("nerstf", None)
|
153 |
+
return NeRSTFRenderer(device=device, **config)
|
154 |
+
|
155 |
+
model_cls = {
|
156 |
+
"MLPSDFModel": MLPSDFModel,
|
157 |
+
"MLPTextureFieldModel": MLPTextureFieldModel,
|
158 |
+
"MLPNeRFModel": MLPNeRFModel,
|
159 |
+
"MLPDensitySDFModel": MLPDensitySDFModel,
|
160 |
+
"MLPNeRSTFModel": MLPNeRSTFModel,
|
161 |
+
"VoidNeRFModel": VoidNeRFModel,
|
162 |
+
"BoundingBoxVolume": BoundingBoxVolume,
|
163 |
+
"SphericalVolume": SphericalVolume,
|
164 |
+
"UnboundedVolume": UnboundedVolume,
|
165 |
+
}[name]
|
166 |
+
return model_cls(device=device, **config)
|
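
For reference, a minimal sketch of how model_from_config is typically driven (illustrative only, not part of the committed file): the config is a dict whose "name" key selects the class and whose remaining keys become constructor kwargs, and it usually comes from one of the YAML files fetched by shap_e/models/download.py below.

import torch
from shap_e.models.configs import model_from_config
from shap_e.models.download import load_config

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = load_config("transmitter")  # a dict with a "name" key, e.g. "Transmitter"
model = model_from_config(config, device=device)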
shap_e/models/download.py
ADDED
@@ -0,0 +1,152 @@
"""
Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
"""

import hashlib
import os
from functools import lru_cache
from typing import Dict, Optional

import requests
import torch
import torch.nn as nn
import yaml
from filelock import FileLock
from tqdm.auto import tqdm

MODEL_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
}

CONFIG_PATHS = {
    "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
    "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
    "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
    "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
    "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
}

URL_HASHES = {
    "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
    "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
    "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
    "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
    "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
    "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
}


@lru_cache()
def default_cache_dir() -> str:
    return os.path.join(os.path.abspath(os.getcwd()), "shap_e_model_cache")


def fetch_file_cached(
    url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
) -> str:
    """
    Download the file at the given URL into a local file and return the path.
    If cache_dir is specified, it will be used to download the files.
    Otherwise, default_cache_dir() is used.
    """
    expected_hash = URL_HASHES[url]

    if cache_dir is None:
        cache_dir = default_cache_dir()
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, url.split("/")[-1])
    if os.path.exists(local_path):
        check_hash(local_path, expected_hash)
        return local_path

    response = requests.get(url, stream=True)
    size = int(response.headers.get("content-length", "0"))
    with FileLock(local_path + ".lock"):
        if progress:
            pbar = tqdm(total=size, unit="iB", unit_scale=True)
        tmp_path = local_path + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size):
                if progress:
                    pbar.update(len(chunk))
                f.write(chunk)
        os.rename(tmp_path, local_path)
        if progress:
            pbar.close()
        check_hash(local_path, expected_hash)
        return local_path


def check_hash(path: str, expected_hash: str):
    actual_hash = hash_file(path)
    if actual_hash != expected_hash:
        raise RuntimeError(
            f"The file {path} should have hash {expected_hash} but has {actual_hash}. "
            "Try deleting it and running this call again."
        )


def hash_file(path: str) -> str:
    sha256_hash = hashlib.sha256()
    with open(path, "rb") as file:
        while True:
            data = file.read(4096)
            if not len(data):
                break
            sha256_hash.update(data)
    return sha256_hash.hexdigest()


def load_config(
    config_name: str,
    progress: bool = False,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
):
    if config_name not in CONFIG_PATHS:
        raise ValueError(
            f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
        )
    path = fetch_file_cached(
        CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    with open(path, "r") as f:
        return yaml.safe_load(f)


def load_checkpoint(
    checkpoint_name: str,
    device: torch.device,
    progress: bool = True,
    cache_dir: Optional[str] = None,
    chunk_size: int = 4096,
) -> Dict[str, torch.Tensor]:
    if checkpoint_name not in MODEL_PATHS:
        raise ValueError(
            f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
        )
    print(checkpoint_name)
    path = fetch_file_cached(
        MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
    )
    return torch.load(path, map_location=device)


def load_model(
    model_name: str,
    device: torch.device,
    **kwargs,
) -> nn.Module:  # returns the constructed model, not a state dict
    from .configs import model_from_config

    model = model_from_config(load_config(model_name, **kwargs), device=device)
    # print(model_name, kwargs)
    # print(model)
    model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
    model.eval()
    return model
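
The hash-checked download path above is normally reached through load_model; an illustrative sketch (fetches the checkpoints into the cache directory on first call):

import torch
from shap_e.models.download import load_config, load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
xm = load_model("transmitter", device=device)       # latent decoder / renderer
text_model = load_model("text300M", device=device)  # text-conditional diffusion model
diffusion_config = load_config("diffusion")         # plain dict parsed from YAML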
shap_e/models/generation/__init__.py
ADDED
File without changes
shap_e/models/generation/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (174 Bytes)
shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc
ADDED
Binary file (1.44 kB)
shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc
ADDED
Binary file (6.73 kB)
shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc
ADDED
Binary file (2.72 kB)
shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc
ADDED
Binary file (9.69 kB)
shap_e/models/generation/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (19.3 kB)
shap_e/models/generation/__pycache__/util.cpython-39.pyc
ADDED
Binary file (1.06 kB)
shap_e/models/generation/latent_diffusion.py
ADDED
@@ -0,0 +1,32 @@
from typing import Optional

import torch
import torch.nn as nn


class SplitVectorDiffusion(nn.Module):
    def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
        super().__init__()
        self.device = device
        self.n_ctx = n_ctx
        self.d_latent = d_latent
        self.wrapped = wrapped

        if hasattr(self.wrapped, "cached_model_kwargs"):
            self.cached_model_kwargs = self.wrapped.cached_model_kwargs

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        conditional_latent: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)
        if conditional_latent is not None:
            conditional_latent = conditional_latent.reshape(
                conditional_latent.shape[0], self.n_ctx, -1
            )
            # (batch_size, n_ctx, channel) -> (batch_size, d_latent, n_ctx)
            h = torch.cat([h.permute(0, 2, 1), conditional_latent], dim=-1).permute(0, 2, 1)
        h = self.wrapped(h, t, **kwargs)
        eps, var = torch.chunk(h, 2, dim=1)
        return torch.cat(
            [
                eps.permute(0, 2, 1).flatten(1),
                var.permute(0, 2, 1).flatten(1),
            ],
            dim=1,
        )
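
An illustrative shape check for SplitVectorDiffusion (DummyInner is a hypothetical stand-in for the wrapped transformer, which maps an [N x C x T] input to an [N x 2C x T] eps/variance prediction):

import torch
import torch.nn as nn
from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion

class DummyInner(nn.Module):
    def forward(self, h, t):
        return torch.cat([h, h], dim=1)  # [N, C, T] -> [N, 2C, T]

model = SplitVectorDiffusion(
    device=torch.device("cpu"), wrapped=DummyInner(), n_ctx=4, d_latent=32
)
out = model(torch.randn(2, 32), torch.tensor([0, 1]))
assert out.shape == (2, 64)  # eps and variance, flattened and concatenated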
shap_e/models/generation/perceiver.py
ADDED
@@ -0,0 +1,244 @@
import math
from typing import Optional

import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint

from .transformer import MLP, Transformer, init_linear
from .util import timestep_embedding


class MultiheadCrossAttention(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        heads: int,
        init_scale: float,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_data = n_data
        self.width = width
        self.heads = heads
        self.data_width = width if data_width is None else data_width
        self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
        self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)
        self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
        self.attention = QKVMultiheadCrossAttention(
            device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data
        )
        init_linear(self.c_q, init_scale)
        init_linear(self.c_kv, init_scale)
        init_linear(self.c_proj, init_scale)

    def forward(self, x, data):
        x = self.c_q(x)
        data = self.c_kv(data)
        x = checkpoint(self.attention, (x, data), (), True)
        x = self.c_proj(x)
        return x


class QKVMultiheadCrossAttention(nn.Module):
    def __init__(
        self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int
    ):
        super().__init__()
        self.device = device
        self.dtype = dtype
        self.heads = heads
        self.n_ctx = n_ctx
        self.n_data = n_data

    def forward(self, q, kv):
        _, n_ctx, _ = q.shape
        bs, n_data, width = kv.shape
        attn_ch = width // self.heads // 2
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        q = q.view(bs, n_ctx, self.heads, -1)
        kv = kv.view(bs, n_data, self.heads, -1)
        k, v = torch.split(kv, attn_ch, dim=-1)
        weight = torch.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        wdtype = weight.dtype
        weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
        return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)


class ResidualCrossAttentionBlock(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        heads: int,
        data_width: Optional[int] = None,
        init_scale: float = 1.0,
    ):
        super().__init__()

        if data_width is None:
            data_width = width

        self.attn = MultiheadCrossAttention(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            n_data=n_data,
            width=width,
            heads=heads,
            data_width=data_width,
            init_scale=init_scale,
        )
        self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
        self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
        self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
        self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        x = x + self.attn(self.ln_1(x), self.ln_2(data))
        x = x + self.mlp(self.ln_3(x))
        return x


class SimplePerceiver(nn.Module):
    """
    Only does cross attention
    """

    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int,
        n_data: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        data_width: Optional[int] = None,
    ):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        init_scale = init_scale * math.sqrt(1.0 / width)
        self.resblocks = nn.ModuleList(
            [
                ResidualCrossAttentionBlock(
                    device=device,
                    dtype=dtype,
                    n_ctx=n_ctx,
                    n_data=n_data,
                    width=width,
                    heads=heads,
                    init_scale=init_scale,
                    data_width=data_width,
                )
                for _ in range(layers)
            ]
        )

    def forward(self, x: torch.Tensor, data: torch.Tensor):
        for block in self.resblocks:
            x = block(x, data)
        return x


class PointDiffusionPerceiver(nn.Module):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        input_channels: int = 3,
        output_channels: int = 3,
        n_ctx: int = 1024,
        n_latent: int = 128,
        width: int = 512,
        encoder_layers: int = 12,
        latent_layers: int = 12,
        decoder_layers: int = 12,
        heads: int = 8,
        init_scale: float = 0.25,
    ):
        super().__init__()
        self.time_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.latent_embed = MLP(
            device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
        )
        self.n_latent = n_latent

        self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
        self.encoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=n_latent,
            n_data=n_ctx,
            width=width,
            layers=encoder_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.processor = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=n_latent,
            width=width,
            layers=latent_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.decoder = SimplePerceiver(
            device=device,
            dtype=dtype,
            n_ctx=n_ctx,
            n_data=n_latent,
            width=width,
            layers=decoder_layers,
            heads=heads,
            init_scale=init_scale,
        )
        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
        with torch.no_grad():
            self.output_proj.weight.zero_()
            self.output_proj.bias.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        :param x: an [N x C x T] tensor.
        :param t: an [N] tensor.
        :return: an [N x C' x T] tensor.
        """
        assert x.shape[-1] == self.decoder.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.encoder.width))
        data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None]
        data = self.ln_pre(data)

        l = torch.arange(self.n_latent).to(x.device)
        h = self.latent_embed(timestep_embedding(l, self.decoder.width))
        h = h.unsqueeze(0).repeat(x.shape[0], 1, 1)

        h = self.encoder(h, data)
        h = self.processor(h)
        h = self.decoder(data, h)
        h = self.ln_post(h)
        h = self.output_proj(h)
        return h.permute(0, 2, 1)
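
An illustrative CPU smoke test for PointDiffusionPerceiver (sizes are arbitrary; assumes the full shap_e package from this commit is importable):

import torch
from shap_e.models.generation.perceiver import PointDiffusionPerceiver

model = PointDiffusionPerceiver(
    device=torch.device("cpu"), dtype=torch.float32,
    input_channels=3, output_channels=6, n_ctx=64, n_latent=16,
    width=64, encoder_layers=1, latent_layers=1, decoder_layers=1, heads=4,
)
x = torch.randn(2, 3, 64)          # [N x C x T]
t = torch.randint(0, 1000, (2,))   # [N]
assert model(x, t).shape == (2, 6, 64)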
shap_e/models/generation/pooled_mlp.py
ADDED
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn

from .util import timestep_embedding


class PooledMLP(nn.Module):
    def __init__(
        self,
        device: torch.device,
        *,
        input_channels: int = 3,
        output_channels: int = 6,
        hidden_size: int = 256,
        resblocks: int = 4,
        pool_op: str = "max",
    ):
        super().__init__()
        self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
        self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)

        blocks = []
        for _ in range(resblocks):
            blocks.append(ResBlock(hidden_size, pool_op, device=device))
        self.sequence = nn.Sequential(*blocks)

        self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
        with torch.no_grad():
            self.out.bias.zero_()
            self.out.weight.zero_()

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        in_embed = self.input_embed(x)
        t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))
        h = in_embed + t_embed[..., None]
        h = self.sequence(h)
        h = self.out(h)
        return h


class ResBlock(nn.Module):
    def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
        super().__init__()
        assert pool_op in ["mean", "max"]
        self.pool_op = pool_op
        self.body = nn.Sequential(
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.SiLU(),
            nn.LayerNorm((hidden_size,), device=device),
            nn.Linear(hidden_size, hidden_size, device=device),
        )
        self.gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, device=device),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        N, C, T = x.shape
        out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)
        pooled = pool(self.pool_op, x)
        gate = self.gate(pooled)
        return x + out * gate[..., None]


def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
    if op_name == "max":
        pooled, _ = torch.max(x, dim=-1)
    elif op_name == "mean":
        # torch.mean returns a plain tensor (no indices), unlike torch.max,
        # so it must not be tuple-unpacked; the original `pooled, _ = ...`
        # would raise at runtime.
        pooled = torch.mean(x, dim=-1)
    else:
        raise ValueError(f"unknown pool op: {op_name}")
    return pooled
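
An illustrative smoke test for PooledMLP (sizes arbitrary):

import torch
from shap_e.models.generation.pooled_mlp import PooledMLP

model = PooledMLP(torch.device("cpu"), input_channels=3, output_channels=6,
                  hidden_size=32, resblocks=2)
x = torch.randn(2, 3, 128)        # [N x C x T] point cloud
t = torch.randint(0, 1000, (2,))
assert model(x, t).shape == (2, 6, 128)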
shap_e/models/generation/pretrained_clip.py
ADDED
@@ -0,0 +1,270 @@
from typing import Iterable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from shap_e.models.download import default_cache_dir

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]


class ImageCLIP(nn.Module):
    """
    A wrapper around a pre-trained CLIP model that automatically handles
    batches of texts, images, and embeddings.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: Optional[torch.dtype] = torch.float32,
        ensure_used_params: bool = True,
        clip_name: str = "ViT-L/14",
        cache_dir: Optional[str] = None,
    ):
        super().__init__()

        assert clip_name in ["ViT-L/14", "ViT-B/32"]

        self.device = device
        self.ensure_used_params = ensure_used_params

        # Lazy import because of torchvision.
        import clip

        self.clip_model, self.preprocess = clip.load(
            clip_name, device=device, download_root=cache_dir or default_cache_dir()
        )
        self.clip_name = clip_name

        if dtype is not None:
            self.clip_model.to(dtype)
        self._tokenize = clip.tokenize

    @property
    def feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 768
        else:
            return 512

    @property
    def grid_size(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 16
        else:
            return 7

    @property
    def grid_feature_dim(self) -> int:
        if self.clip_name == "ViT-L/14":
            return 1024
        else:
            return 768

    def forward(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Generate a batch of embeddings from a mixture of images, texts,
        precomputed embeddings, and possibly empty values.

        For each batch element, at most one of images, texts, and embeddings
        should have a non-None value. Embeddings from multiple modalities
        cannot be mixed for a single batch element. If no modality is provided,
        a zero embedding will be used for the batch element.
        """
        image_seq = [None] * batch_size if images is None else list(images)
        text_seq = [None] * batch_size if texts is None else list(texts)
        embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
        assert len(image_seq) == batch_size, "number of images should match batch size"
        assert len(text_seq) == batch_size, "number of texts should match batch size"
        assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"

        if self.ensure_used_params:
            return self._static_multimodal_embed(
                images=image_seq, texts=text_seq, embeddings=embedding_seq
            )

        result = torch.zeros((batch_size, self.feature_dim), device=self.device)
        index_images = []
        index_texts = []
        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
            assert (
                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
            ), "only one modality may be non-None per batch element"
            if image is not None:
                index_images.append((i, image))
            elif text is not None:
                index_texts.append((i, text))
            elif emb is not None:
                result[i] = emb.to(result)

        if len(index_images):
            embs = self.embed_images((img for _, img in index_images))
            for (i, _), emb in zip(index_images, embs):
                result[i] = emb.to(result)
        if len(index_texts):
            embs = self.embed_text((text for _, text in index_texts))
            for (i, _), emb in zip(index_texts, embs):
                result[i] = emb.to(result)

        return result

    def _static_multimodal_embed(
        self,
        images: List[Optional[ImageType]] = None,
        texts: List[Optional[str]] = None,
        embeddings: List[Optional[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Like forward(), but always runs all encoders to ensure that
        the forward graph looks the same on every rank.
        """
        image_emb = self.embed_images(images)
        text_emb = self.embed_text(t if t else "" for t in texts)
        joined_embs = torch.stack(
            [
                emb.to(device=self.device, dtype=torch.float32)
                if emb is not None
                else torch.zeros(self.feature_dim, device=self.device)
                for emb in embeddings
            ],
            dim=0,
        )

        image_flag = torch.tensor([x is not None for x in images], device=self.device)[
            :, None
        ].expand_as(image_emb)
        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
            :, None
        ].expand_as(image_emb)
        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
            :, None
        ].expand_as(image_emb)

        return (
            image_flag.float() * image_emb
            + text_flag.float() * text_emb
            + emb_flag.float() * joined_embs
            + self.clip_model.logit_scale * 0  # avoid unused parameters
        )

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        :param xs: N images, stored as numpy arrays, tensors, or PIL images.
        :return: an [N x D] tensor of features.
        """
        clip_inputs = self.images_to_tensor(xs)
        results = self.clip_model.encode_image(clip_inputs).float()
        return results / torch.linalg.norm(results, dim=-1, keepdim=True)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """
        Embed text prompts as an [N x D] tensor.
        """
        enc = self.clip_model.encode_text(
            self._tokenize(list(prompts), truncate=True).to(self.device)
        ).float()
        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        Embed images into latent grids.

        :param xs: an iterable of images to embed.
        :return: a tensor of shape [N x C x L], where L = self.grid_size**2.
        """
        if self.ensure_used_params:
            extra_value = 0.0
            for p in self.parameters():
                extra_value = extra_value + p.mean() * 0.0
        else:
            extra_value = 0.0

        x = self.images_to_tensor(xs).to(self.clip_model.dtype)

        # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
        vt = self.clip_model.visual
        x = vt.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat(
            [
                vt.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + vt.positional_embedding.to(x.dtype)
        x = vt.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vt.transformer(x)
        x = x.permute(1, 2, 0)  # LND -> NDL

        return x[..., 1:].contiguous().float() + extra_value

    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)


class FrozenImageCLIP:
    def __init__(self, device: torch.device, **kwargs):
        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

    @property
    def feature_dim(self) -> int:
        return self.model.feature_dim

    @property
    def grid_size(self) -> int:
        return self.model.grid_size

    @property
    def grid_feature_dim(self) -> int:
        return self.model.grid_feature_dim

    def __call__(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # We don't do a no_grad() here so that gradients could still
        # flow to the input embeddings argument.
        # This behavior is currently not used, but it could be.
        return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images(xs)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_text(prompts)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images_grid(xs)


def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
    if obj is None:
        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
    if isinstance(obj, np.ndarray):
        return Image.fromarray(obj.astype(np.uint8))
    elif isinstance(obj, torch.Tensor):
        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
    else:
        return obj
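
An illustrative sketch of the CLIP wrapper (downloads the ViT-L/14 weights through the clip package on first use, so it needs network access):

import torch
from shap_e.models.generation.pretrained_clip import FrozenImageCLIP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = FrozenImageCLIP(device)
emb = clip_model(batch_size=2, texts=["a red chair", None])  # [2 x 768]; None -> zero row
grid = clip_model.embed_images_grid([None, None])            # [2 x 1024 x 256]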
shap_e/models/generation/transformer.py
ADDED
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from shap_e.models.nn.checkpoint import checkpoint
|
8 |
+
|
9 |
+
from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
|
10 |
+
from .util import timestep_embedding
|
11 |
+
|
12 |
+
def init_linear(l, stddev):
|
13 |
+
nn.init.normal_(l.weight, std=stddev)
|
14 |
+
if l.bias is not None:
|
15 |
+
nn.init.constant_(l.bias, 0.0)
|
16 |
+
|
17 |
+
|
18 |
+
class MultiheadAttention(nn.Module):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
*,
|
22 |
+
device: torch.device,
|
23 |
+
dtype: torch.dtype,
|
24 |
+
n_ctx: int,
|
25 |
+
width: int,
|
26 |
+
heads: int,
|
27 |
+
init_scale: float,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.n_ctx = n_ctx
|
31 |
+
self.width = width
|
32 |
+
self.heads = heads
|
33 |
+
self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
|
34 |
+
self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
35 |
+
self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
|
36 |
+
init_linear(self.c_qkv, init_scale)
|
37 |
+
init_linear(self.c_proj, init_scale)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = self.c_qkv(x)
|
41 |
+
x = checkpoint(self.attention, (x,), (), True)
|
42 |
+
x = self.c_proj(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
|
46 |
+
class MLP(nn.Module):
|
47 |
+
def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
|
48 |
+
super().__init__()
|
49 |
+
self.width = width
|
50 |
+
self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
|
51 |
+
self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
|
52 |
+
self.gelu = nn.GELU()
|
53 |
+
init_linear(self.c_fc, init_scale)
|
54 |
+
init_linear(self.c_proj, init_scale)
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
return self.c_proj(self.gelu(self.c_fc(x)))
|
58 |
+
|
59 |
+
|
60 |
+
class QKVMultiheadAttention(nn.Module):
|
61 |
+
def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
|
62 |
+
super().__init__()
|
63 |
+
self.device = device
|
64 |
+
self.dtype = dtype
|
65 |
+
self.heads = heads
|
66 |
+
self.n_ctx = n_ctx
|
67 |
+
|
68 |
+
def forward(self, qkv):
|
69 |
+
bs, n_ctx, width = qkv.shape
|
70 |
+
attn_ch = width // self.heads // 3
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(attn_ch))
|
72 |
+
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
73 |
+
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
74 |
+
weight = torch.einsum(
|
75 |
+
"bthc,bshc->bhts", q * scale, k * scale
|
76 |
+
) # More stable with f16 than dividing afterwards
|
77 |
+
wdtype = weight.dtype
|
78 |
+
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
|
79 |
+
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
80 |
+
|
81 |
+
|
82 |
+
class ResidualAttentionBlock(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
*,
|
86 |
+
device: torch.device,
|
87 |
+
dtype: torch.dtype,
|
88 |
+
n_ctx: int,
|
89 |
+
width: int,
|
90 |
+
heads: int,
|
91 |
+
init_scale: float = 1.0,
|
92 |
+
):
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
self.attn = MultiheadAttention(
|
96 |
+
device=device,
|
97 |
+
dtype=dtype,
|
98 |
+
n_ctx=n_ctx,
|
99 |
+
width=width,
|
100 |
+
heads=heads,
|
101 |
+
init_scale=init_scale,
|
102 |
+
)
|
103 |
+
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
|
104 |
+
self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
|
105 |
+
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
|
106 |
+
|
107 |
+
def forward(self, x: torch.Tensor):
|
108 |
+
x = x + self.attn(self.ln_1(x))
|
109 |
+
x = x + self.mlp(self.ln_2(x))
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class Transformer(nn.Module):
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
*,
|
117 |
+
device: torch.device,
|
118 |
+
dtype: torch.dtype,
|
119 |
+
n_ctx: int,
|
120 |
+
width: int,
|
121 |
+
layers: int,
|
122 |
+
heads: int,
|
123 |
+
init_scale: float = 0.25,
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
self.n_ctx = n_ctx
|
127 |
+
self.width = width
|
128 |
+
self.layers = layers
|
129 |
+
init_scale = init_scale * math.sqrt(1.0 / width)
|
130 |
+
self.resblocks = nn.ModuleList(
|
131 |
+
[
|
132 |
+
ResidualAttentionBlock(
|
133 |
+
device=device,
|
134 |
+
dtype=dtype,
|
135 |
+
n_ctx=n_ctx,
|
136 |
+
width=width,
|
137 |
+
heads=heads,
|
138 |
+
init_scale=init_scale,
|
139 |
+
)
|
140 |
+
for _ in range(layers)
|
141 |
+
]
|
142 |
+
)
|
143 |
+
|
144 |
+
def forward(self, x: torch.Tensor):
|
145 |
+
for block in self.resblocks:
|
146 |
+
x = block(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
|
150 |
+
class PointDiffusionTransformer(nn.Module):
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
*,
|
154 |
+
device: torch.device,
|
155 |
+
dtype: torch.dtype,
|
156 |
+
input_channels: int = 3,
|
157 |
+
output_channels: int = 3,
|
158 |
+
n_ctx: int = 1024,
|
159 |
+
width: int = 512,
|
160 |
+
layers: int = 12,
|
161 |
+
heads: int = 8,
|
162 |
+
init_scale: float = 0.25,
|
163 |
+
time_token_cond: bool = False,
|
164 |
+
use_pos_emb: bool = False,
|
165 |
+
pos_emb_init_scale: float = 1.0,
|
166 |
+
pos_emb_n_ctx: Optional[int] = None,
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
self.input_channels = input_channels
|
170 |
+
self.output_channels = output_channels
|
171 |
+
self.n_ctx = n_ctx
|
172 |
+
self.time_token_cond = time_token_cond
|
173 |
+
self.use_pos_emb = use_pos_emb
|
174 |
+
self.time_embed = MLP(
|
175 |
+
device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
|
176 |
+
)
|
177 |
+
self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
|
178 |
+
self.backbone = Transformer(
|
179 |
+
device=device,
|
180 |
+
dtype=dtype,
|
181 |
+
n_ctx=n_ctx + int(time_token_cond),
|
182 |
+
width=width,
|
183 |
+
layers=layers,
|
184 |
+
heads=heads,
|
185 |
+
init_scale=init_scale,
|
186 |
+
)
|
187 |
+
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
|
188 |
+
self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
|
189 |
+
self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
|
190 |
+
# with torch.no_grad():
|
191 |
+
# self.output_proj.weight.zero_()
|
192 |
+
# self.output_proj.bias.zero_()
|
193 |
+
if self.use_pos_emb:
|
194 |
+
self.register_parameter(
|
195 |
+
"pos_emb",
|
196 |
+
nn.Parameter(
|
197 |
+
pos_emb_init_scale
|
198 |
+
* torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
|
199 |
+
),
|
200 |
+
)
|
201 |
+
|
202 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor):
|
203 |
+
"""
|
204 |
+
:param x: an [N x C x T] tensor.
|
205 |
+
:param t: an [N] tensor.
|
206 |
+
:return: an [N x C' x T] tensor.
|
207 |
+
"""
|
208 |
+
assert x.shape[-1] == self.n_ctx
|
209 |
+
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
|
210 |
+
return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
|
211 |
+
|
212 |
+
def _forward_with_cond(
|
213 |
+
self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
|
214 |
+
) -> torch.Tensor:
|
215 |
+
h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
|
216 |
+
for emb, as_token in cond_as_token:
|
217 |
+
if not as_token:
|
218 |
+
h = h + emb[:, None]
|
219 |
+
if self.use_pos_emb:
|
220 |
+
h = h + self.pos_emb
|
221 |
+
extra_tokens = [
|
222 |
+
(emb[:, None] if len(emb.shape) == 2 else emb)
|
223 |
+
for emb, as_token in cond_as_token
|
224 |
+
if as_token
|
225 |
+
]
|
226 |
+
if len(extra_tokens):
|
227 |
+
h = torch.cat(extra_tokens + [h], dim=1)
|
228 |
+
h = self.ln_pre(h)
|
229 |
+
h = self.backbone(h)
|
230 |
+
h = self.ln_post(h)
|
231 |
+
if len(extra_tokens):
|
232 |
+
h = h[:, sum(h.shape[1] for h in extra_tokens):]
|
233 |
+
h = self.output_proj(h)
|
234 |
+
return h.permute(0, 2, 1) # NCL -> NLC
|
235 |
+
|
236 |
+
|
237 |
+
|
238 |
+
|
239 |
+
class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
|
240 |
+
def __init__(
|
241 |
+
self,
|
242 |
+
*,
|
243 |
+
device: torch.device,
|
244 |
+
dtype: torch.dtype,
|
245 |
+
n_ctx: int = 1024,
|
246 |
+
token_cond: bool = False,
|
247 |
+
cond_drop_prob: float = 0.0,
|
248 |
+
frozen_clip: bool = True,
|
249 |
+
**kwargs,
|
250 |
+
):
|
251 |
+
super().__init__(
|
252 |
+
device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
|
253 |
+
)
|
254 |
+
# print("!!!!!", "deivce:", device, "dtype:", dtype, "n_ctx:", n_ctx, "token_cond:", token_cond, "cond_drop_prob:", cond_drop_prob, "frozen_clip:", frozen_clip, "kwargs:", kwargs)
|
255 |
+
self.n_ctx = n_ctx
|
256 |
+
self.token_cond = token_cond
|
257 |
+
self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
|
258 |
+
self.clip_embed = nn.Linear(
|
259 |
+
self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
|
260 |
+
)
|
261 |
+
self.cond_drop_prob = cond_drop_prob
|
262 |
+
|
263 |
+
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
264 |
+
with torch.no_grad():
|
265 |
+
return dict(embeddings=self.clip(batch_size, **model_kwargs))
|
266 |
+
|
267 |
+
def forward(
|
268 |
+
self,
|
269 |
+
x: torch.Tensor,
|
270 |
+
t: torch.Tensor,
|
271 |
+
images: Optional[Iterable[Optional[ImageType]]] = None,
|
272 |
+
texts: Optional[Iterable[Optional[str]]] = None,
|
273 |
+
embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
|
274 |
+
):
|
275 |
+
"""
|
276 |
+
:param x: an [N x C x T] tensor.
|
277 |
+
:param t: an [N] tensor.
|
278 |
+
:param images: a batch of images to condition on.
|
279 |
+
:param texts: a batch of texts to condition on.
|
280 |
+
:param embeddings: a batch of CLIP embeddings to condition on.
|
281 |
+
:return: an [N x C' x T] tensor.
|
282 |
+
"""
|
283 |
+
# print("x.shape", x.shape, "t.shape", t.shape, "images", images, "texts", texts, "embeddings", embeddings)
|
284 |
+
assert x.shape[-1] == self.n_ctx # self.n_ctx = 1024
|
285 |
+
|
286 |
+
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
|
287 |
+
clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
|
288 |
+
assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
|
289 |
+
|
290 |
+
if self.training:
|
291 |
+
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
|
292 |
+
clip_out = clip_out * mask[:, None].to(clip_out)
|
293 |
+
|
294 |
+
# Rescale the features to have unit variance
|
295 |
+
clip_out = math.sqrt(clip_out.shape[1]) * clip_out
|
296 |
+
|
297 |
+
clip_embed = self.clip_embed(clip_out)
|
298 |
+
|
299 |
+
cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
|
300 |
+
return self._forward_with_cond(x, cond)
|
301 |
+
|
302 |
+
|
303 |
+
class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
|
304 |
+
def __init__(
|
305 |
+
self,
|
306 |
+
*,
|
307 |
+
device: torch.device,
|
308 |
+
dtype: torch.dtype,
|
309 |
+
n_ctx: int = 1024,
|
310 |
+
cond_drop_prob: float = 0.0,
|
311 |
+
frozen_clip: bool = True,
|
312 |
+
**kwargs,
|
313 |
+
):
|
314 |
+
clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
|
315 |
+
super().__init__(
|
316 |
+
device=device,
|
317 |
+
dtype=dtype,
|
318 |
+
n_ctx=n_ctx + clip.grid_size**2,
|
319 |
+
pos_emb_n_ctx=n_ctx,
|
320 |
+
**kwargs,
|
321 |
+
)
|
322 |
+
self.n_ctx = n_ctx
|
323 |
+
self.clip = clip
|
324 |
+
self.clip_embed = nn.Sequential(
|
325 |
+
nn.LayerNorm(
|
326 |
+
normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
|
327 |
+
),
|
328 |
+
nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
|
329 |
+
)
|
330 |
+
self.cond_drop_prob = cond_drop_prob
|
331 |
+
|
332 |
+
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
333 |
+
_ = batch_size
|
334 |
+
with torch.no_grad():
|
335 |
+
return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))
|
336 |
+
|
337 |
+
def forward(
|
338 |
+
self,
|
339 |
+
x: torch.Tensor,
|
340 |
+
t: torch.Tensor,
|
341 |
+
images: Optional[Iterable[ImageType]] = None,
|
342 |
+
embeddings: Optional[Iterable[torch.Tensor]] = None,
|
343 |
+
):
|
344 |
+
"""
|
345 |
+
:param x: an [N x C x T] tensor.
|
346 |
+
:param t: an [N] tensor.
|
347 |
+
:param images: a batch of images to condition on.
|
348 |
+
:param embeddings: a batch of CLIP latent grids to condition on.
|
349 |
+
:return: an [N x C' x T] tensor.
|
350 |
+
"""
        assert images is not None or embeddings is not None, "must specify images or embeddings"
        assert images is None or embeddings is None, "cannot specify both images and embeddings"
        assert x.shape[-1] == self.n_ctx

        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
        return self._forward_with_cond(x, cond)


class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        cond_input_channels: Optional[int] = None,
        cond_ctx: int = 1024,
        n_ctx: int = 4096 - 1024,
        channel_scales: Optional[Sequence[float]] = None,
        channel_biases: Optional[Sequence[float]] = None,
        **kwargs,
    ):
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
        self.n_ctx = n_ctx
        self.cond_input_channels = cond_input_channels or self.input_channels
        self.cond_point_proj = nn.Linear(
            self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
        )

        self.register_buffer(
            "channel_scales",
            torch.tensor(channel_scales, dtype=dtype, device=device)
            if channel_scales is not None
            else None,
        )
        self.register_buffer(
            "channel_biases",
            torch.tensor(channel_biases, dtype=dtype, device=device)
            if channel_biases is not None
            else None,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)
        cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)

    def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
        if self.channel_scales is not None:
            x = x * self.channel_scales[None, :, None]
        if self.channel_biases is not None:
            x = x + self.channel_biases[None, :, None]
        return self.cond_point_proj(x.permute(0, 2, 1))


class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
    def __init__(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        n_ctx: int = 4096 - 1024,
        cond_drop_prob: float = 0.0,
        frozen_clip: bool = True,
        **kwargs,
    ):
        clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
        super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
        self.n_ctx = n_ctx

        self.clip = clip
        self.clip_embed = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
            ),
            nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
        )
        self.cond_drop_prob = cond_drop_prob

    def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        _ = batch_size
        with torch.no_grad():
            return dict(
                embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
                low_res=model_kwargs["low_res"],
            )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        *,
        low_res: torch.Tensor,
        images: Optional[Iterable[ImageType]] = None,
        embeddings: Optional[Iterable[torch.Tensor]] = None,
    ):
        """
        :param x: an [N x C1 x T] tensor.
        :param t: an [N] tensor.
        :param low_res: an [N x C2 x T'] tensor of conditioning points.
        :param images: a batch of images to condition on.
        :param embeddings: a batch of CLIP latent grids to condition on.
        :return: an [N x C3 x T] tensor.
        """
        assert x.shape[-1] == self.n_ctx
        t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
        low_res_embed = self._embed_low_res(low_res)

        if images is not None:
            clip_out = self.clip.embed_images_grid(images)
        else:
            clip_out = embeddings

        if self.training:
            mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
            clip_out = clip_out * mask[:, None, None].to(clip_out)

        clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
        clip_embed = self.clip_embed(clip_out)

        cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
        return self._forward_with_cond(x, cond)
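
The `if self.training:` branch above is the conditioning dropout used for classifier-free guidance: each example's CLIP grid is zeroed out with probability `cond_drop_prob`, so the network also learns an unconditional mode. A minimal standalone sketch of the same masking trick (the batch size, channel count, and grid size below are made up for illustration and are not the module's real state):

import torch

cond_drop_prob = 0.1                     # hypothetical dropout probability
clip_out = torch.randn(4, 256, 16 * 16)  # [N x C x L] CLIP latent grid, sizes illustrative

# Per-example Bernoulli keep-mask; broadcasting zeroes entire grids, not single features.
mask = torch.rand(size=[clip_out.shape[0]]) >= cond_drop_prob
clip_out = clip_out * mask[:, None, None].to(clip_out)
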
shap_e/models/generation/util.py
ADDED
@@ -0,0 +1,23 @@
import math

import torch


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
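
As a quick sanity check of `timestep_embedding` (this snippet is illustrative and not part of the commit; it only assumes the function added above):

import torch

from shap_e.models.generation.util import timestep_embedding

t = torch.tensor([0.0, 10.0, 999.5])  # fractional timesteps are allowed
emb = timestep_embedding(t, dim=128)
print(emb.shape)  # torch.Size([3, 128]); the first 64 columns are cosines, the last 64 sines
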
shap_e/models/nerf/__init__.py
ADDED
File without changes

shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (168 Bytes)

shap_e/models/nerf/__pycache__/model.cpython-39.pyc
ADDED
Binary file (6.51 kB)

shap_e/models/nerf/__pycache__/ray.cpython-39.pyc
ADDED
Binary file (15.3 kB)

shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc
ADDED
Binary file (5.62 kB)
shap_e/models/nerf/model.py
ADDED
@@ -0,0 +1,255 @@
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn

from shap_e.models.nn.checkpoint import checkpoint
from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
from shap_e.models.nn.meta import MetaModule, subdict
from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
from shap_e.models.nn.utils import ArrayType
from shap_e.models.query import Query
from shap_e.util.collections import AttrDict


class NeRFModel(ABC):
    """
    Parametric scene representation whose outputs are integrated by NeRFRenderer
    """

    @abstractmethod
    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        """
        :param query: the points in the field to query.
        :param params: Meta parameters
        :param options: Optional hyperparameters
        :return: An AttrDict containing at least
            - density: [batch_size x ... x 1]
            - channels: [batch_size x ... x n_channels]
            - aux_losses: [batch_size x ... x 1]
        """


class VoidNeRFModel(MetaModule, NeRFModel):
    """
    Implements the default empty space model where all queries are rendered as
    background.
    """

    def __init__(
        self,
        background: ArrayType,
        trainable: bool = False,
        channel_scale: float = 255.0,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()
        background = nn.Parameter(
            torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
            / channel_scale
        )
        if trainable:
            self.register_parameter("background", background)
        else:
            self.register_buffer("background", background)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        _ = params
        default_bg = self.background[None]
        background = options.get("background", default_bg) if options is not None else default_bg

        shape = query.position.shape[:-1]
        ones = [1] * (len(shape) - 1)
        n_channels = background.shape[-1]
        background = torch.broadcast_to(
            background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
        )
        return background


class MLPNeRFModel(MetaModule, NeRFModel):
    def __init__(
        self,
        # Positional encoding parameters
        n_levels: int = 10,
        # MLP parameters
        d_hidden: int = 256,
        n_density_layers: int = 4,
        n_channel_layers: int = 1,
        n_channels: int = 3,
        sh_degree: int = 4,
        activation: str = "relu",
        density_activation: str = "exp",
        init: Optional[str] = None,
        init_scale: float = 1.0,
        output_activation: str = "sigmoid",
        meta_parameters: bool = False,
        trainable_meta: bool = False,
        zero_out: bool = True,
        register_freqs: bool = True,
        posenc_version: str = "v1",
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__()

        # Positional encoding
        if register_freqs:
            # not used anymore
            self.register_buffer(
                "freqs",
                2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
            )

        self.posenc_version = posenc_version
        dummy = torch.eye(1, 3)
        d_input = encode_position(posenc_version, position=dummy).shape[-1]

        self.n_levels = n_levels

        self.sh_degree = sh_degree
        d_sh_coeffs = sh_degree**2

        self.meta_parameters = meta_parameters

        mlp_cls = (
            partial(
                MetaMLP,
                meta_scale=False,
                meta_shift=False,
                meta_proj=True,
                meta_bias=True,
                trainable_meta=trainable_meta,
            )
            if meta_parameters
            else MLP
        )

        self.density_mlp = mlp_cls(
            d_input=d_input,
            d_hidden=[d_hidden] * (n_density_layers - 1),
            d_output=d_hidden,
            act_name=activation,
            init_scale=init_scale,
        )

        self.channel_mlp = mlp_cls(
            d_input=d_hidden + d_sh_coeffs,
            d_hidden=[d_hidden] * n_channel_layers,
            d_output=n_channels,
            act_name=activation,
            init_scale=init_scale,
        )

        self.act = get_act(output_activation)
        self.density_act = get_act(density_activation)

        mlp_init(
            list(self.density_mlp.affines) + list(self.channel_mlp.affines),
            init=init,
            init_scale=init_scale,
        )

        if zero_out:
            zero_init(self.channel_mlp.affines[-1])

        self.to(device)

    def encode_position(self, query: Query):
        h = encode_position(self.posenc_version, position=query.position)
        return h

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict:
        params = self.update(params)

        options = AttrDict() if options is None else AttrDict(options)

        query = query.copy()

        h_position = self.encode_position(query)

        if self.meta_parameters:
            density_params = subdict(params, "density_mlp")
            density_mlp = partial(
                self.density_mlp, params=density_params, options=options, log_prefix="density_"
            )
            density_mlp_parameters = list(density_params.values())
        else:
            density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
            density_mlp_parameters = self.density_mlp.parameters()
        h_density = checkpoint(
            density_mlp,
            (h_position,),
            density_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )
        h_direction = maybe_get_spherical_harmonics_basis(
            sh_degree=self.sh_degree,
            coords_shape=query.position.shape,
            coords=query.direction,
            device=query.position.device,
        )

        if self.meta_parameters:
            channel_params = subdict(params, "channel_mlp")
            channel_mlp = partial(
                self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
            )
            channel_mlp_parameters = list(channel_params.values())
        else:
            channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
            channel_mlp_parameters = self.channel_mlp.parameters()
        h_channel = checkpoint(
            channel_mlp,
            (torch.cat([h_density, h_direction], dim=-1),),
            channel_mlp_parameters,
            options.checkpoint_nerf_mlp,
        )

        density_logit = h_density[..., :1]

        res = AttrDict(
            density_logit=density_logit,
            density=self.density_act(density_logit),
            channels=self.act(h_channel),
            aux_losses=AttrDict(),
            no_weight_grad_aux_losses=AttrDict(),
        )
        if options.return_h_density:
            res.h_density = h_density

        return res


def maybe_get_spherical_harmonics_basis(
    sh_degree: int,
    coords_shape: Tuple[int],
    coords: Optional[torch.Tensor] = None,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    :param sh_degree: Spherical harmonics degree
    :param coords_shape: [*shape, 3]
    :param coords: optional coordinate tensor of coords_shape
    """
    if coords is None:
        return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)

    return spherical_harmonics_basis(coords, sh_degree)
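
The density/channel split in `MLPNeRFModel` is the standard NeRF factorization: position alone determines density, while the view direction, expanded into a degree-`sh_degree` spherical-harmonics basis of `sh_degree**2` coefficients, only influences color. A toy re-wiring of that structure with plain `torch.nn` layers, purely illustrative (the real model uses `MLP`/`MetaMLP`, a positional-encoding input width, and activation checkpointing):

import torch
import torch.nn as nn

d_hidden, sh_degree, n_channels = 256, 4, 3
d_input = 3  # stand-in for the positional-encoding width computed from encode_position

density_mlp = nn.Sequential(
    nn.Linear(d_input, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d_hidden)
)
channel_mlp = nn.Sequential(
    nn.Linear(d_hidden + sh_degree**2, d_hidden), nn.ReLU(), nn.Linear(d_hidden, n_channels)
)

position = torch.randn(8, d_input)
sh_basis = torch.zeros(8, sh_degree**2)  # what maybe_get_spherical_harmonics_basis yields for direction=None

h_density = density_mlp(position)
density = torch.exp(h_density[..., :1])  # density_activation="exp" on the first hidden unit
rgb = torch.sigmoid(channel_mlp(torch.cat([h_density, sh_basis], dim=-1)))  # output_activation="sigmoid"
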
shap_e/models/nerf/ray.py
ADDED
@@ -0,0 +1,512 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Tuple

import torch

from shap_e.models.nn.utils import sample_pmf
from shap_e.models.volume import Volume, VolumeRange
from shap_e.util.collections import AttrDict

from .model import NeRFModel, Query


def render_rays(
    rays: torch.Tensor,
    parts: List["RayVolumeIntegral"],
    void_model: NeRFModel,
    shared: bool = False,
    prev_raw_outputs: Optional[List[AttrDict]] = None,
    render_with_direction: bool = True,
    importance_sampling_options: Optional[Dict[str, Any]] = None,
) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]:
    """
    Perform volumetric rendering over a partition of possible t's in the union
    of rendering volumes (written below with some abuse of notation)

        C(r) := sum(
            transmittance(t[i]) *
            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t[i], t[i + 1]],
            )
            for i in range(len(parts))
        ) + transmittance(t[-1]) * void_model(t[-1]).channels

    where

    1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the
       probability of light passing through the volume specified by [t[0], s].
       (transmittance of 1 means light can pass freely)
    2) density and channels are obtained by evaluating the appropriate
       part.model at time t.
    3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
       (parts[i].volume \\ union(part.volume for part in parts[:i])) at the
       surface of the shell (if bounded). If the ray does not intersect, the
       integral over this segment is evaluated as 0 and
       transmittance(t[i + 1]) := transmittance(t[i]).
    4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that
       is evaluated by the void_model (i.e. we consider this space to be empty).

    :param rays: [batch_size x ... x 2 x 3] origin and direction.
    :param parts: disjoint volume integrals.
    :param void_model: use this model to integrate over the empty space
    :param shared: All RayVolumeIntegrals are calculated with the same model.
    :param prev_raw_outputs: Raw outputs from the previous rendering step

    :return: A tuple of
        - AttrDict containing the rendered `channels`, `distances`, and the `aux_losses`
        - A list of importance samplers for additional fine-grained rendering
        - A list of raw outputs for each interval
    """
    if importance_sampling_options is None:
        importance_sampling_options = {}

    origin, direc = rays[..., 0, :], rays[..., 1, :]

    if prev_raw_outputs is None:
        prev_raw_outputs = [None] * len(parts)

    samplers = []
    raw_outputs = []
    t0 = None
    results = None
    for part_i, prev_raw_i in zip(parts, prev_raw_outputs):

        # Integrate over [t[i], t[i + 1]]
        results_i = part_i.render_rays(
            origin,
            direc,
            t0=t0,
            prev_raw=prev_raw_i,
            shared=shared,
            render_with_direction=render_with_direction,
        )

        # Create an importance sampler for (optional) fine rendering
        samplers.append(
            ImportanceRaySampler(
                results_i.volume_range, results_i.raw, **importance_sampling_options
            )
        )
        raw_outputs.append(results_i.raw)

        # Pass t[i + 1] as the start of integration for the next interval.
        t0 = results_i.volume_range.next_t0()

        # Combine the results from [t[0], t[i]] and [t[i], t[i+1]]
        results = results_i if results is None else results.combine(results_i)

    # While integrating out [t[-1], math.inf] is the correct thing to do, this
    # erases a lot of useful information. Also, void_model is meant to predict
    # the channels at t=math.inf, so we simply composite it on top of the
    # accumulated result using the remaining transmittance.
    results.output.channels = results.output.channels + results.transmittance * void_model(
        Query(origin, direc)
    )

    return results, samplers, raw_outputs


@dataclass
class RayVolumeIntegralResults:
    """
    Stores the relevant state and results of

        integrate(
            lambda t: density(t) * channels(t) * transmittance(t),
            [t0, t1],
        )
    """

    # Rendered output and auxiliary losses
    # output.channels has shape [batch_size, *inner_shape, n_channels]
    output: AttrDict

    """
    Optional values
    """

    # Raw values contain the sampled `ts`, `density`, `channels`, etc.
    raw: Optional[AttrDict] = None

    # Integration
    volume_range: Optional[VolumeRange] = None

    # If a ray intersects, the transmittance from t0 to t1 (e.g. the
    # probability that the ray passes through this volume).
    # has shape [batch_size, *inner_shape, 1]
    transmittance: Optional[torch.Tensor] = None

    def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
        """
        Combines the integration results of `self` over [t0, t1] and
        `cur` over [t1, t2] to produce a new set of results over [t0, t2] by
        using a similar equation to (4) in NeRF++:

            integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t2]
            )

            = integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t0, t1]
            ) + transmittance(t1) * integrate(
                lambda t: density(t) * channels(t) * transmittance(t),
                [t1, t2]
            )
        """
        assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)

        def _combine_fn(
            prev_val: Optional[torch.Tensor],
            cur_val: Optional[torch.Tensor],
            *,
            prev_transmittance: torch.Tensor,
        ):
            assert prev_val is not None
            if cur_val is None:
                # cur_output.aux_losses are empty for the void_model.
                return prev_val
            return prev_val + prev_transmittance * cur_val

        output = self.output.combine(
            cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)
        )

        combined = RayVolumeIntegralResults(
            output=output,
            volume_range=self.volume_range.extend(cur.volume_range),
            transmittance=self.transmittance * cur.transmittance,
        )
        return combined


@dataclass
class RayVolumeIntegral:
    model: NeRFModel
    volume: Volume
    sampler: "RaySampler"
    n_samples: int

    def render_rays(
        self,
        origin: torch.Tensor,
        direction: torch.Tensor,
        t0: Optional[torch.Tensor] = None,
        prev_raw: Optional[AttrDict] = None,
        shared: bool = False,
        render_with_direction: bool = True,
    ) -> "RayVolumeIntegralResults":
        """
        Perform volumetric rendering over the given volume.

        :param origin: [batch_size, *shape, 3]
        :param direction: [batch_size, *shape, 3]
        :param t0: Optional [batch_size, *shape, 1]
        :param prev_raw: the raw outputs when using multiple levels with this model.
        :param shared: means the same model is used for all RayVolumeIntegral's
        :param render_with_direction: use the incoming ray direction when querying the model.

        :return: RayVolumeIntegralResults
        """
        # 1. Intersect the rays with the current volume and sample ts to
        # integrate along.
        vrange = self.volume.intersect(origin, direction, t0_lower=t0)
        ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)

        if prev_raw is not None and not shared:
            # Append the previous ts now before fprop because previous
            # rendering used a different model and we can't reuse the output.
            ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values

        # Shape sanity checks
        batch_size, *_shape, _t0_dim = vrange.t0.shape
        _, *ts_shape, _ts_dim = ts.shape

        # 2. Get the points along the ray and query the model
        directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
        positions = origin.unsqueeze(-2) + ts * directions

        optional_directions = directions if render_with_direction else None
        mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2
        raw = self.model(
            Query(
                position=positions,
                direction=optional_directions,
                t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),
                t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),
            )
        )
        raw.ts = ts

        if prev_raw is not None and shared:
            # We can append the additional queries to previous raw outputs
            # before integration
            copy = prev_raw.copy()
            result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2)
            merge_results = partial(self._merge_results, dim=-2, indices=result.indices)
            raw = raw.combine(copy, merge_results)
            raw.ts = result.values

        # 3. Integrate the raw results
        output, transmittance = self.integrate_samples(vrange, raw)

        # 4. Clean up results that do not intersect with the volume.
        transmittance = torch.where(
            vrange.intersected, transmittance, torch.ones_like(transmittance)
        )

        def _mask_fn(_key: str, tensor: torch.Tensor):
            return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        output = output.map(map_fn=_mask_fn, should_map=_is_tensor)

        return RayVolumeIntegralResults(
            output=output,
            raw=raw,
            volume_range=vrange,
            transmittance=transmittance,
        )

    def integrate_samples(
        self,
        volume_range: VolumeRange,
        raw: AttrDict,
    ) -> Tuple[AttrDict, torch.Tensor]:
        """
        Integrate the raw.channels along with other aux_losses and values to
        produce the final output dictionary containing rendered `channels`,
        estimated `distances` and `aux_losses`.

        :param volume_range: Specifies the integral range [t0, t1]
        :param raw: Contains a dict of function evaluations at ts. Should have

            density: torch.Tensor [batch_size, *shape, n_samples, 1]
            channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
            aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}
            no_weight_grad_aux_losses: an optional set of losses for which the weights
                should be detached before integration.

            After the call, integrate_samples populates some intermediate
            calculations for later use, like

            weights: torch.Tensor [batch_size, *shape, n_samples, 1]
                (density * transmittance)[i] weight for each rgb output at [..., i, :].
        :returns: a tuple of (
            a dictionary of rendered outputs and aux_losses,
            transmittance of this volume,
        )
        """

        # 1. Calculate the weights
        _, _, dt = volume_range.partition(raw.ts)
        ddensity = raw.density * dt

        mass = torch.cumsum(ddensity, dim=-2)
        transmittance = torch.exp(-mass[..., -1, :])

        alphas = 1.0 - torch.exp(-ddensity)
        Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
        # This is the probability of light hitting and reflecting off of
        # something at depth [..., i, :].
        weights = alphas * Ts

        # 2. Integrate all results
        def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
            if key == "density":
                # Omit integrating the density, because we don't need it
                return None
            return torch.sum(samples * weights, dim=-2)

        def _is_tensor(_key: str, value: Any):
            return isinstance(value, torch.Tensor)

        if raw.no_weight_grad_aux_losses:
            extra_aux_losses = raw.no_weight_grad_aux_losses.map(
                partial(_integrate, weights=weights.detach()), should_map=_is_tensor
            )
        else:
            extra_aux_losses = {}
        output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
        if "no_weight_grad_aux_losses" in output:
            del output["no_weight_grad_aux_losses"]
        output.aux_losses.update(extra_aux_losses)

        # Integrating the ts yields the distance away from the origin; rename the variable.
        output.distances = output.ts
        del output["ts"]
        del output["density"]

        assert output.distances.shape == (*output.channels.shape[:-1], 1)
        assert output.channels.shape[:-1] == raw.channels.shape[:-2]
        assert output.channels.shape[-1] == raw.channels.shape[-1]

        # 3. Reduce loss
        def _reduce_loss(_key: str, loss: torch.Tensor):
            return loss.view(loss.shape[0], -1).sum(dim=-1)

        # 4. Store other useful calculations
        raw.weights = weights

        output.aux_losses = output.aux_losses.map(_reduce_loss)

        return output, transmittance

    def _merge_results(
        self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor
    ):
        """
        :param a: [..., n_a, ...]. The other dictionary containing the b's may
            contain extra tensors from earlier calculations, so a can be None.
        :param b: [..., n_b, ...]
        :param dim: dimension to merge
        :param indices: how the merged results should be sorted at the end
        :return: a concatenated and sorted tensor of size [..., n_a + n_b, ...]
        """
        if a is None:
            return None

        merged = torch.cat([a, b], dim=dim)
        return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))


class RaySampler(ABC):
    @abstractmethod
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """


class StratifiedRaySampler(RaySampler):
    """
    Instead of fixed intervals, a sample is drawn uniformly at random from each
    interval.
    """

    def __init__(self, depth_mode: str = "linear"):
        """
        :param depth_mode: linear samples ts linearly in depth. harmonic ensures
            closer points are sampled more densely.
        """
        self.depth_mode = depth_mode
        assert self.depth_mode in ("linear", "geometric", "harmonic")

    def sample(
        self,
        t0: torch.Tensor,
        t1: torch.Tensor,
        n_samples: int,
        epsilon: float = 1e-3,
    ) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        ones = [1] * (len(t0.shape) - 1)
        ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)

        if self.depth_mode == "linear":
            ts = t0 * (1.0 - ts) + t1 * ts
        elif self.depth_mode == "geometric":
            ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
        elif self.depth_mode == "harmonic":
            # The original NeRF recommends this interpolation scheme for
            # spherical scenes, but there could be some weird edge cases when
            # the observer crosses from the inner to outer volume.
            ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)

        mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
        upper = torch.cat([mids, t1], dim=-1)
        lower = torch.cat([t0, mids], dim=-1)
        t_rand = torch.rand_like(ts)

        ts = lower + (upper - lower) * t_rand
        return ts.unsqueeze(-1)


class ImportanceRaySampler(RaySampler):
    """
    Given the initial estimate of densities, this samples more from
    regions/bins expected to have objects.
    """

    def __init__(
        self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5
    ):
        """
        :param volume_range: the range in which a ray intersects the given volume.
        :param raw: dictionary of raw outputs from the NeRF models of shape
            [batch_size, *shape, n_coarse_samples, 1]. Should at least contain

            :param ts: earlier samples from the coarse rendering step
            :param weights: discretized version of density * transmittance
        :param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
        :param alpha: small value to add to weights.
        """
        self.volume_range = volume_range
        self.ts = raw.ts.clone().detach()
        self.weights = raw.weights.clone().detach()
        self.blur_pool = blur_pool
        self.alpha = alpha

    @torch.no_grad()
    def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
        """
        :param t0: start time has shape [batch_size, *shape, 1]
        :param t1: finish time has shape [batch_size, *shape, 1]
        :param n_samples: number of ts to sample
        :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
        """
        lower, upper, _ = self.volume_range.partition(self.ts)

        batch_size, *shape, n_coarse_samples, _ = self.ts.shape

        weights = self.weights
        if self.blur_pool:
            padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
            maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
            weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
        weights = weights + self.alpha
        pmf = weights / weights.sum(dim=-2, keepdim=True)
        inds = sample_pmf(pmf, n_samples)
        assert inds.shape == (batch_size, *shape, n_samples, 1)
        assert (inds >= 0).all() and (inds < n_coarse_samples).all()

        t_rand = torch.rand(inds.shape, device=inds.device)
        lower_ = torch.gather(lower, -2, inds)
        upper_ = torch.gather(upper, -2, inds)

        ts = lower_ + (upper_ - lower_) * t_rand
        ts = torch.sort(ts, dim=-2).values
        return ts
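
`integrate_samples` above is the standard discretized volume-rendering quadrature: per-sample optical thickness `density * dt` gives alphas, the exclusive cumulative sum gives per-sample transmittances `Ts`, and `weights = alphas * Ts` are the compositing weights, which telescope exactly to `1 - transmittance`. A tiny standalone numeric check (the values are hypothetical and independent of the classes above):

import torch

density = torch.tensor([[0.0], [2.0], [5.0]])  # [n_samples, 1] densities along one ray
dt = torch.full_like(density, 0.1)             # segment lengths, as from VolumeRange.partition

ddensity = density * dt
mass = torch.cumsum(ddensity, dim=-2)
transmittance = torch.exp(-mass[-1, 0])        # light surviving the whole volume (scalar)

alphas = 1.0 - torch.exp(-ddensity)
Ts = torch.exp(torch.cat([torch.zeros_like(mass[:1]), -mass[:-1]], dim=-2))
weights = alphas * Ts

# The compositing weights sum exactly to the absorbed fraction of light.
assert torch.allclose(weights.sum(), 1.0 - transmittance)
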
shap_e/models/nerf/renderer.py
ADDED
@@ -0,0 +1,301 @@
from functools import partial
from typing import Any, Dict, Optional

import torch

from shap_e.models.nn.meta import subdict
from shap_e.models.renderer import RayRenderer
from shap_e.models.volume import Volume
from shap_e.util.collections import AttrDict

from .model import NeRFModel
from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays


class TwoStepNeRFRenderer(RayRenderer):
    """
    Coarse and fine-grained rendering as proposed by NeRF. This class
    additionally supports background rendering like NeRF++.
    """

    def __init__(
        self,
        n_coarse_samples: int,
        n_fine_samples: int,
        void_model: NeRFModel,
        fine_model: NeRFModel,
        volume: Volume,
        coarse_model: Optional[NeRFModel] = None,
        coarse_background_model: Optional[NeRFModel] = None,
        fine_background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        importance_sampling_options: Optional[Dict[str, Any]] = None,
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        """
        :param outer_volume: is where distant objects are encoded.
        """
        super().__init__(**kwargs)

        if coarse_model is None:
            assert (
                fine_background_model is None or coarse_background_model is None
            ), "models should be shared for both fg and bg"

        self.n_coarse_samples = n_coarse_samples
        self.n_fine_samples = n_fine_samples
        self.void_model = void_model
        self.coarse_model = coarse_model
        self.fine_model = fine_model
        self.volume = volume
        self.coarse_background_model = coarse_background_model
        self.fine_background_model = fine_background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.importance_sampling_options = AttrDict(importance_sampling_options or {})
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

        if self.coarse_background_model is not None:
            assert self.fine_background_model is not None
            assert self.outer_volume is not None

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        params = self.update(params)

        batch = AttrDict(batch)
        if options is None:
            options = AttrDict()
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_coarse_samples", self.n_coarse_samples)
        options.setdefault("n_fine_samples", self.n_fine_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        shared = self.coarse_model is None

        # First, render rays using the coarse models with stratified ray samples.
        coarse_model, coarse_key = (
            (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
        )
        coarse_model = partial(
            coarse_model,
            params=subdict(params, coarse_key),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=coarse_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode,
                ),
                n_samples=options.n_coarse_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            coarse_background_model, coarse_background_key = (
                (self.fine_background_model, "fine_background_model")
                if shared
                else (self.coarse_background_model, "coarse_background_model")
            )
            coarse_background_model = partial(
                coarse_background_model,
                params=subdict(params, coarse_background_key),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=coarse_background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode,
                    ),
                    n_samples=options.n_coarse_samples,
                )
            )
        coarse_results, samplers, coarse_raw_outputs = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            render_with_direction=options.render_with_direction,
            importance_sampling_options=AttrDict(self.importance_sampling_options),
        )

        # Then, render rays using the fine models with importance-weighted ray samples.
        fine_model = partial(
            self.fine_model,
            params=subdict(params, "fine_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=fine_model,
                volume=self.volume,
                sampler=samplers[0],
                n_samples=options.n_fine_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            fine_background_model = partial(
                self.fine_background_model,
                params=subdict(params, "fine_background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=fine_background_model,
                    volume=self.outer_volume,
                    sampler=samplers[1],
                    n_samples=options.n_fine_samples,
                )
            )
        fine_results, *_ = render_rays(
            batch.rays,
            parts,
            partial(self.void_model, options=options),
            shared=shared,
            prev_raw_outputs=coarse_raw_outputs,
            render_with_direction=options.render_with_direction,
        )

        # Combine results
        aux_losses = fine_results.output.aux_losses.copy()
        for key, val in coarse_results.output.aux_losses.items():
            aux_losses[key + "_coarse"] = val

        return AttrDict(
            channels=fine_results.output.channels * self.channel_scale,
            channels_coarse=coarse_results.output.channels * self.channel_scale,
            distances=fine_results.output.distances,
            transmittance=fine_results.transmittance,
            transmittance_coarse=coarse_results.transmittance,
            t0=fine_results.volume_range.t0,
            t1=fine_results.volume_range.t1,
            intersected=fine_results.volume_range.intersected,
            aux_losses=aux_losses,
        )


class OneStepNeRFRenderer(RayRenderer):
    """
    Renders rays using stratified sampling only, unlike vanilla NeRF.
    The same setup as NeRF++.
    """

    def __init__(
        self,
        n_samples: int,
        void_model: NeRFModel,
        foreground_model: NeRFModel,
        volume: Volume,
        background_model: Optional[NeRFModel] = None,
        outer_volume: Optional[Volume] = None,
        foreground_stratified_depth_sampling_mode: str = "linear",
        background_stratified_depth_sampling_mode: str = "linear",
        channel_scale: float = 255,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.n_samples = n_samples
        self.void_model = void_model
        self.foreground_model = foreground_model
        self.volume = volume
        self.background_model = background_model
        self.outer_volume = outer_volume
        self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
        self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
        self.channel_scale = channel_scale
        self.device = device
        self.to(device)

    def render_rays(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        params = self.update(params)

        batch = AttrDict(batch)
        if options is None:
            options = AttrDict()
        options.setdefault("render_background", True)
        options.setdefault("render_with_direction", True)
        options.setdefault("n_samples", self.n_samples)
        options.setdefault(
            "foreground_stratified_depth_sampling_mode",
            self.foreground_stratified_depth_sampling_mode,
        )
        options.setdefault(
            "background_stratified_depth_sampling_mode",
            self.background_stratified_depth_sampling_mode,
        )

        foreground_model = partial(
            self.foreground_model,
            params=subdict(params, "foreground_model"),
            options=options,
        )
        parts = [
            RayVolumeIntegral(
                model=foreground_model,
                volume=self.volume,
                sampler=StratifiedRaySampler(
                    depth_mode=options.foreground_stratified_depth_sampling_mode
                ),
                n_samples=options.n_samples,
            ),
        ]
        if options.render_background and self.outer_volume is not None:
            background_model = partial(
                self.background_model,
                params=subdict(params, "background_model"),
                options=options,
            )
            parts.append(
                RayVolumeIntegral(
                    model=background_model,
                    volume=self.outer_volume,
                    sampler=StratifiedRaySampler(
                        depth_mode=options.background_stratified_depth_sampling_mode
                    ),
                    n_samples=options.n_samples,
                )
            )
        results, *_ = render_rays(
            batch.rays,
            parts,
            self.void_model,
            render_with_direction=options.render_with_direction,
        )

        return AttrDict(
            channels=results.output.channels * self.channel_scale,
            distances=results.output.distances,
            transmittance=results.transmittance,
            t0=results.volume_range.t0,
            t1=results.volume_range.t1,
            intersected=results.volume_range.intersected,
            aux_losses=results.output.aux_losses,
        )
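
Both renderers resolve per-call options against constructor defaults with `setdefault`, and `TwoStepNeRFRenderer` merges coarse aux losses under a `_coarse` suffix. Both patterns in isolation, with plain dicts standing in for `AttrDict` and made-up values:

# Per-call overrides win; setdefault only fills in what the caller omitted.
options = {"n_fine_samples": 64}
for key, default in [("render_background", True), ("n_coarse_samples", 64), ("n_fine_samples", 128)]:
    options.setdefault(key, default)
assert options["n_fine_samples"] == 64  # the caller's value survives

# Fine losses keep their names; coarse losses are suffixed so both survive in one dict.
fine_losses, coarse_losses = {"nerf_loss": 0.5}, {"nerf_loss": 0.9}
aux_losses = dict(fine_losses)
for key, val in coarse_losses.items():
    aux_losses[key + "_coarse"] = val
print(aux_losses)  # {'nerf_loss': 0.5, 'nerf_loss_coarse': 0.9}
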
shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc
ADDED
Binary file (4.74 kB)

shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc
ADDED
Binary file (6.65 kB)
shap_e/models/nerstf/mlp.py
ADDED
@@ -0,0 +1,174 @@
from typing import Any, Dict, Optional, Tuple

import torch

from shap_e.models.nn.ops import get_act
from shap_e.models.query import Query
from shap_e.models.stf.mlp import MLPModel
from shap_e.util.collections import AttrDict


class MLPDensitySDFModel(MLPModel):
    def __init__(
        self,
        initial_bias: float = -0.1,
        sdf_activation="tanh",
        density_activation="exp",
        **kwargs,
    ):
        super().__init__(
            n_output=2,
            output_activation="identity",
            **kwargs,
        )
        self.mlp[-1].bias[0].data.fill_(initial_bias)
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:
        # query.direction is typically None for SDF models and training
        h, _h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        h_sdf, h_density = h.split(1, dim=-1)
        return AttrDict(
            density=self.density_activation(h_density),
            signed_distance=self.sdf_activation(h_sdf),
        )


class MLPNeRSTFModel(MLPModel):
    def __init__(
        self,
        sdf_activation="tanh",
        density_activation="exp",
        channel_activation="sigmoid",
        direction_dependent_shape: bool = True,  # To be able to load old models. Set this to False in future models.
        separate_nerf_channels: bool = False,
        separate_coarse_channels: bool = False,
        initial_density_bias: float = 0.0,
        initial_sdf_bias: float = -0.1,
        **kwargs,
    ):
        h_map, h_directionless_map = indices_for_output_mode(
            direction_dependent_shape=direction_dependent_shape,
            separate_nerf_channels=separate_nerf_channels,
            separate_coarse_channels=separate_coarse_channels,
        )
        n_output = index_mapping_max(h_map)
        super().__init__(
            n_output=n_output,
            output_activation="identity",
            **kwargs,
        )
        self.direction_dependent_shape = direction_dependent_shape
        self.separate_nerf_channels = separate_nerf_channels
        self.separate_coarse_channels = separate_coarse_channels
        self.sdf_activation = get_act(sdf_activation)
        self.density_activation = get_act(density_activation)
        self.channel_activation = get_act(channel_activation)
        self.h_map = h_map
        self.h_directionless_map = h_directionless_map
        self.mlp[-1].bias.data.zero_()
        layer = -1 if self.direction_dependent_shape else self.insert_direction_at
        self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)
        self.mlp[layer].bias[1].data.fill_(initial_density_bias)

    def forward(
        self,
        query: Query,
        params: Optional[Dict[str, torch.Tensor]] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> AttrDict[str, Any]:

        options = AttrDict() if options is None else AttrDict(options)
        h, h_directionless = self._mlp(
            query.position, query.direction, params=params, options=options
        )
        activations = map_indices_to_keys(self.h_map, h)
        activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))

        if options.nerf_level == "coarse":
            h_density = activations.density_coarse
        else:
            h_density = activations.density_fine

        if options.get("rendering_mode", "stf") == "nerf":
            if options.nerf_level == "coarse":
                h_channels = activations.nerf_coarse
            else:
                h_channels = activations.nerf_fine
        else:
            h_channels = activations.stf
        return AttrDict(
            density=self.density_activation(h_density),
            signed_distance=self.sdf_activation(activations.sdf),
            channels=self.channel_activation(h_channels),
        )


IndexMapping = AttrDict[str, Tuple[int, int]]


def indices_for_output_mode(
    direction_dependent_shape: bool,
    separate_nerf_channels: bool,
    separate_coarse_channels: bool,
) -> Tuple[IndexMapping, IndexMapping]:
    """
    Get output mappings for (h, h_directionless).
    """
    h_map = AttrDict()
    h_directionless_map = AttrDict()
    if direction_dependent_shape:
        h_map.sdf = (0, 1)
        if separate_coarse_channels:
            assert separate_nerf_channels
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (2, 3)
            h_map.stf = (3, 6)
            h_map.nerf_coarse = (6, 9)
            h_map.nerf_fine = (9, 12)
        else:
            h_map.density_coarse = (1, 2)
            h_map.density_fine = (1, 2)
            if separate_nerf_channels:
                h_map.stf = (2, 5)
                h_map.nerf_coarse = (5, 8)
                h_map.nerf_fine = (5, 8)
            else:
                h_map.stf = (2, 5)
                h_map.nerf_coarse = (2, 5)
                h_map.nerf_fine = (2, 5)
    else:
        h_directionless_map.sdf = (0, 1)
        h_directionless_map.density_coarse = (1, 2)
        if separate_coarse_channels:
            h_directionless_map.density_fine = (2, 3)
        else:
            h_directionless_map.density_fine = h_directionless_map.density_coarse
        h_map.stf = (0, 3)
        if separate_coarse_channels:
            assert separate_nerf_channels
            h_map.nerf_coarse = (3, 6)
            h_map.nerf_fine = (6, 9)
        else:
            if separate_nerf_channels:
                h_map.nerf_coarse = (3, 6)
            else:
                h_map.nerf_coarse = (0, 3)
            h_map.nerf_fine = h_map.nerf_coarse
    return h_map, h_directionless_map


def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
    return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()})


def index_mapping_max(mapping: IndexMapping) -> int:
    return max(end for _, (_, end) in mapping.items())
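
`indices_for_output_mode` packs several heads (SDF, coarse/fine density, STF and NeRF color channels) into a single MLP output vector by assigning each key a `(start, end)` slice, and `map_indices_to_keys` splits a forward pass back out. A quick check of the default configuration, runnable against the functions above:

import torch

from shap_e.models.nerstf.mlp import (
    index_mapping_max,
    indices_for_output_mode,
    map_indices_to_keys,
)

h_map, h_directionless_map = indices_for_output_mode(
    direction_dependent_shape=True,
    separate_nerf_channels=False,
    separate_coarse_channels=False,
)
n_output = index_mapping_max(h_map)  # 5: sdf, one shared density, one shared 3-channel head
h = torch.randn(2, n_output)
acts = map_indices_to_keys(h_map, h)
print(sorted(acts.keys()))  # ['density_coarse', 'density_fine', 'nerf_coarse', 'nerf_fine', 'sdf', 'stf']
print(acts.stf.shape)       # torch.Size([2, 3]); nerf_coarse and nerf_fine alias the same slice
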