silentchen committed
Commit 19c4ddf
1 Parent(s): f633cbe

first commit

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +375 -0
  2. requirements.txt +17 -0
  3. shap_e/.DS_Store +0 -0
  4. shap_e/__init__.py +0 -0
  5. shap_e/__pycache__/__init__.cpython-39.pyc +0 -0
  6. shap_e/diffusion/__init__.py +0 -0
  7. shap_e/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  8. shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc +0 -0
  9. shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc +0 -0
  10. shap_e/diffusion/__pycache__/sample.cpython-39.pyc +0 -0
  11. shap_e/diffusion/gaussian_diffusion.py +1143 -0
  12. shap_e/diffusion/k_diffusion.py +426 -0
  13. shap_e/diffusion/sample.py +160 -0
  14. shap_e/examples/encode_model.ipynb +93 -0
  15. shap_e/examples/sample_image_to_3d.ipynb +125 -0
  16. shap_e/examples/sample_text_to_3d.ipynb +124 -0
  17. shap_e/models/__init__.py +0 -0
  18. shap_e/models/__pycache__/__init__.cpython-39.pyc +0 -0
  19. shap_e/models/__pycache__/configs.cpython-39.pyc +0 -0
  20. shap_e/models/__pycache__/download.cpython-39.pyc +0 -0
  21. shap_e/models/__pycache__/query.cpython-39.pyc +0 -0
  22. shap_e/models/__pycache__/renderer.cpython-39.pyc +0 -0
  23. shap_e/models/__pycache__/volume.cpython-39.pyc +0 -0
  24. shap_e/models/configs.py +166 -0
  25. shap_e/models/download.py +152 -0
  26. shap_e/models/generation/__init__.py +0 -0
  27. shap_e/models/generation/__pycache__/__init__.cpython-39.pyc +0 -0
  28. shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc +0 -0
  29. shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc +0 -0
  30. shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc +0 -0
  31. shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc +0 -0
  32. shap_e/models/generation/__pycache__/transformer.cpython-39.pyc +0 -0
  33. shap_e/models/generation/__pycache__/util.cpython-39.pyc +0 -0
  34. shap_e/models/generation/latent_diffusion.py +32 -0
  35. shap_e/models/generation/perceiver.py +244 -0
  36. shap_e/models/generation/pooled_mlp.py +74 -0
  37. shap_e/models/generation/pretrained_clip.py +270 -0
  38. shap_e/models/generation/transformer.py +494 -0
  39. shap_e/models/generation/util.py +23 -0
  40. shap_e/models/nerf/__init__.py +0 -0
  41. shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc +0 -0
  42. shap_e/models/nerf/__pycache__/model.cpython-39.pyc +0 -0
  43. shap_e/models/nerf/__pycache__/ray.cpython-39.pyc +0 -0
  44. shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc +0 -0
  45. shap_e/models/nerf/model.py +255 -0
  46. shap_e/models/nerf/ray.py +512 -0
  47. shap_e/models/nerf/renderer.py +301 -0
  48. shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc +0 -0
  49. shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc +0 -0
  50. shap_e/models/nerstf/mlp.py +174 -0
app.py ADDED
@@ -0,0 +1,375 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from functools import partial
+ from typing import Optional
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+ from shap_e.diffusion.sample import sample_latents
+ from shap_e.models.download import load_model, load_config
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
+ import trimesh
+ import torch.nn as nn
+ import os
+ import random
+ import warnings
+ from huggingface_hub import hf_hub_download
+ import hashlib
+
+ import sys
+
+ sys.tracebacklimit = 0
+ def set_seed(seed=1024):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+         torch.backends.cudnn.deterministic = True
+
+ def freeze_params(params):
+     for param in params:
+         param.requires_grad = False
+
+ class Blocks(gr.Blocks):
+
+     def __init__(
+         self,
+         theme: str = "default",
+         analytics_enabled: Optional[bool] = None,
+         mode: str = "blocks",
+         title: str = "Gradio",
+         css: Optional[str] = None,
+         **kwargs,
+     ):
+         self.extra_configs = {
+             'thumbnail': kwargs.pop('thumbnail', ''),
+             'url': kwargs.pop('url', 'https://gradio.app/'),
+             'creator': kwargs.pop('creator', '@teamGradio'),
+         }
+
+         super(Blocks, self).__init__(theme, analytics_enabled, mode, title, css, **kwargs)
+         warnings.filterwarnings("ignore")
+
+     def get_config_file(self):
+         config = super(Blocks, self).get_config_file()
+
+         for k, v in self.extra_configs.items():
+             config[k] = v
+
+         return config
+ def optimize_all(xm, models, initial_noise, noise_start_t, diffusion, latent_model, device, prompt, instruction, rand_seed):
+     state = {}
+     out_gen_1, out_gen_2, out_gen_3, out_gen_4, state = generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state)
+     edited_1, edited_2, edited_3, edited_4, state = _3d_editing(xm, models, diffusion, initial_noise, noise_start_t, device, instruction, rand_seed, state)
+     print(state)
+     return out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4
+ def generate_3d_with_shap_e(xm, diffusion, latent_model, device, prompt, rand_seed, state):
+     set_seed(rand_seed)
+     batch_size = 4
+     guidance_scale = 15.0
+     xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
+     xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
+     xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
+
+     print("prompt: ", prompt, "rand_seed: ", rand_seed, "state:", state)
+     latents = sample_latents(
+         batch_size=batch_size,
+         model=latent_model,
+         diffusion=diffusion,
+         guidance_scale=guidance_scale,
+         model_kwargs=dict(texts=[prompt] * batch_size),
+         progress=True,
+         clip_denoised=True,
+         use_fp16=True,
+         use_karras=True,
+         karras_steps=64,
+         sigma_min=1e-3,
+         sigma_max=160,
+         s_churn=0,
+     )
+     prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
+     mesh_path = []
+     output_path = './logs'
+     os.makedirs(os.path.join(output_path, 'source'), exist_ok=True)
+     state['latent'] = []
+     state['prompt'] = prompt
+     state['rand_seed_1'] = rand_seed
+     for i, latent in enumerate(latents):
+
+         output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
+         t_obj = decode_latent_mesh(xm, latent).tri_mesh()
+         with open(output_path_tmp, 'w') as f:
+             t_obj.write_obj(f)
+
+         mesh = trimesh.load_mesh(output_path_tmp)
+         angle = np.radians(180)
+         axis = [0, 1, 0]
+         rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
+         mesh.apply_transform(rotation_matrix)
+         angle = np.radians(90)
+         axis = [1, 0, 0]
+         rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
+         mesh.apply_transform(rotation_matrix)
+         output_path_tmp = os.path.join(output_path, 'source', '{}_{}.obj'.format(prompt_hash, i))
+         mesh.export(output_path_tmp)
+         state['latent'].append(latent.clone().detach())
+         mesh_path.append(output_path_tmp)
+
+     return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
+
+ def _3d_editing(xm, models, diffusion, initial_noise, start_t, device, instruction, rand_seed, state):
+     set_seed(rand_seed)
+     mesh_path = []
+     prompt = state['prompt']
+     rand_seed_1 = state['rand_seed_1']
+     print("prompt: ", prompt, "rand_seed: ", rand_seed, "instruction:", instruction, "state:", state)
+     prompt_hash = str(hashlib.sha256((prompt + '_' + str(rand_seed_1) + '_' + instruction + '_' + str(rand_seed)).encode('utf-8')).hexdigest())
+     if 'santa' in instruction:
+         e_type = 'santa_hat'
+     elif 'rainbow' in instruction:
+         e_type = 'rainbow'
+     elif 'gold' in instruction:
+         e_type = 'golden'
+     elif 'lego' in instruction:
+         e_type = 'lego'
+     elif 'wooden' in instruction:
+         e_type = 'wooden'
+     elif 'cyber' in instruction:
+         e_type = 'cyber'
+
+     # import pdb; pdb.set_trace()
+     model = models[e_type].to(device)
+     noise_initial = initial_noise[e_type].to(device)
+     noise_start_t = start_t[e_type]
+     general_save_path = './logs/edited'
+     os.makedirs(general_save_path, exist_ok=True)
+     for i, latent in enumerate(state['latent']):
+         latent = latent.to(device)
+         text_embeddings_clip = model.cached_model_kwargs(1, dict(texts=[instruction]))
+         print("shape of latent: ", latent.clone().unsqueeze(0).shape, "instruction: ", instruction)
+         ref_latent = latent.clone().unsqueeze(0)
+         t_1 = torch.randint(noise_start_t, noise_start_t + 1, (1,), device=device).long()
+
+         noise_input = diffusion.q_sample(ref_latent, t_1, noise=noise_initial)
+         out_1 = diffusion.p_mean_variance(model, noise_input, t_1, clip_denoised=True,
+                                           model_kwargs=text_embeddings_clip,
+                                           condition_latents=ref_latent)
+
+         updated_latents = out_1['pred_xstart']
+
+         if 'santa' in instruction:
+             xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.25]).to(device)
+             xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1]).to(device)
+             xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
+
+         else:
+             xm.renderer.volume.bbox_max = torch.tensor([1.0, 1.0, 1.0]).to(device)
+             xm.renderer.volume.bbox_min = torch.tensor([-1.0, -1.0, -1.0]).to(device)
+             xm.renderer.volume.bbox = torch.stack([xm.renderer.volume.bbox_min, xm.renderer.volume.bbox_max])
+
+         for latent_idx, updated_latent in enumerate(updated_latents):
+             output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))
+
+             t = decode_latent_mesh(xm, updated_latent).tri_mesh()
+             with open(output_path, 'w') as f:
+                 t.write_obj(f)
+             mesh = trimesh.load_mesh(output_path)
+
+             angle = np.radians(180)
+             axis = [0, 1, 0]
+
+             rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
+             mesh.apply_transform(rotation_matrix)
+             angle = np.radians(90)
+             axis = [1, 0, 0]
+
+             rotation_matrix = trimesh.transformations.rotation_matrix(angle, axis)
+             mesh.apply_transform(rotation_matrix)
+
+             output_path = os.path.join(general_save_path, '{}_{}.obj'.format(prompt_hash, i))
+             mesh.export(output_path)
+             mesh_path.append(output_path)
+     return mesh_path[0], mesh_path[1], mesh_path[2], mesh_path[3], state
+ def main():
+
+     css = """
+     #img2img_image, #img2img_image > .fixed-height, #img2img_image > .fixed-height > div, #img2img_image > .fixed-height > div > img
+     {
+         height: var(--height) !important;
+         max-height: var(--height) !important;
+         min-height: var(--height) !important;
+     }
+     #paper-info a {
+         color:#008AD7;
+         text-decoration: none;
+     }
+     #paper-info a:hover {
+         cursor: pointer;
+         text-decoration: none;
+     }
+
+     .tooltip {
+         color: #555;
+         position: relative;
+         display: inline-block;
+         cursor: pointer;
+     }
+
+     .tooltip .tooltiptext {
+         visibility: hidden;
+         width: 400px;
+         background-color: #555;
+         color: #fff;
+         text-align: center;
+         padding: 5px;
+         border-radius: 5px;
+         position: absolute;
+         z-index: 1; /* Set z-index to 1 */
+         left: 10px;
+         top: 100%;
+         opacity: 0;
+         transition: opacity 0.3s;
+     }
+
+     .tooltip:hover .tooltiptext {
+         visibility: visible;
+         opacity: 1;
+         z-index: 9999; /* Set a high z-index value when hovering */
+     }
+
+
+     """
+
+     rescale_js = """
+     function(x) {
+         const root = document.querySelector('gradio-app').shadowRoot || document.querySelector('gradio-app');
+         let image_scale = parseFloat(root.querySelector('#image_scale input').value) || 1.0;
+         const image_width = root.querySelector('#img2img_image').clientWidth;
+         const target_height = parseInt(image_width * image_scale);
+         document.body.style.setProperty('--height', `${target_height}px`);
+         root.querySelectorAll('button.justify-center.rounded')[0].style.display='none';
+         root.querySelectorAll('button.justify-center.rounded')[1].style.display='none';
+         return x;
+     }
+     """
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     latent_model = load_model('text300M', device=device)
+     xm = load_model('transmitter', device=device)
+     diffusion = diffusion_from_config(load_config('diffusion'))
+     freeze_params(xm.parameters())
+     models = dict()
+     initial_noise = dict()
+     noise_start_t = dict()
+     editing_types = ['rainbow', 'santa_hat', 'lego', 'golden', 'wooden', 'cyber']
+
+     for editing_type in editing_types:
+         tmp_model = load_model('text300M', device=device)
+         with torch.no_grad():
+             new_proj = nn.Linear(1024 * 2, 1024, device=device, dtype=tmp_model.wrapped.input_proj.weight.dtype)
+             new_proj.weight = nn.Parameter(torch.zeros_like(new_proj.weight))
+             new_proj.weight[:, :1024].copy_(tmp_model.wrapped.input_proj.weight) #
+             new_proj.bias = nn.Parameter(torch.zeros_like(new_proj.bias))
+             new_proj.bias[:1024].copy_(tmp_model.wrapped.input_proj.bias)
+             tmp_model.wrapped.input_proj = new_proj
+
+         ckp = torch.load(hf_hub_download(repo_id='silentchen/Shap_Editor', subfolder='single', filename='{}.pt'.format(editing_type)), map_location='cpu')
+         tmp_model.load_state_dict(ckp['model'])
+         noise_initial = ckp['initial_noise']['noise'].to(device)
+         initial_noise[editing_type] = noise_initial
+         noise_start_t[editing_type] = ckp['t_start']
+         models[editing_type] = tmp_model
+
+     with Blocks(
+         css=css,
+         analytics_enabled=False,
+         title="SHAPE-EDITOR demo",
+     ) as demo:
+         description = """<p style="text-align: center; font-weight: bold;">
+             <span style="font-size: 28px"> <span style="font-size: 140%">S</span>HAP-<span style="font-size: 140%">E</span>DITOR: Instruction-guided <br> Latent 3D Editing in Seconds</span>
+             <br>
+             <span style="font-size: 18px" id="paper-info">
+                 [<a href=" " target="_blank">Project Page</a>]
+                 [<a href=" " target="_blank">Paper</a>]
+                 [<a href=" " target="_blank">GitHub</a>]
+             </span>
+         </p>
+         """
+         state = gr.State({})
+         gr.HTML(description)
+         with gr.Column():
+             with gr.Column():
+                 gr.HTML('<span style="font-size: 20px; font-weight: bold">Step 1: generate original 3D object using Shap-E.</span>')
+                 prompt = gr.Textbox(
+                     label="Text prompt for initial 3D generation", lines=1
+                 )
+                 gen_btn = gr.Button(value='Generate', scale=1)
+
+
+             with gr.Column():
+                 gr.HTML('<span style="font-size: 20px; font-weight: bold">Generated 3D objects</span>')
+                 with gr.Row():
+                     out_gen_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 1)")
+                     out_gen_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 1)")
+                     out_gen_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 1)")
+                     out_gen_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 1)")
+
+             with gr.Column(scale=1):
+                 gr.HTML('<span style="font-size: 20px; font-weight: bold">Step 2: apply 3D editing with S</span>HAP-<span style="font-size: 140%">E</span>DITOR.</span>')
+
+                 editing_choice = gr.Dropdown(
+                     ["Add a santa hat to it", "Make it look like made of gold", "Make the color of it look like rainbow", "Make it in cyberpunk style", "Make it wooden", "Make it look like make of lego"], value='Add a santa hat to it', multiselect=False, label="Editing effects", info="Select specific editing you want to apply!"
+                 ),
+                 apply_btn = gr.Button(value='Editing', scale=1)
+
+             with gr.Column(scale=3):
+                 gr.HTML('<span style="font-size: 20px; font-weight: bold">Edited 3D objects</span>')
+                 with gr.Row():
+                     edited_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 1 (step 2)")
+                     edited_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 2 (step 2)")
+                     edited_3 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 3 (step 2)")
+                     edited_4 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], visible=True, label="3D Model 4 (step 2)")
+
+
+         with gr.Accordion("Advanced Options", open=False):
+             rand_seed = gr.Slider(minimum=0, maximum=1000, step=1, value=445, label="Random seed")
+
+         gen_btn.click(
+             fn=partial(generate_3d_with_shap_e, xm, diffusion, latent_model, device),
+             inputs=[prompt, rand_seed, state],
+             outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, state],
+             queue=False)
+
+         apply_btn.click(
+             fn=partial(_3d_editing, xm, models, diffusion, initial_noise, noise_start_t, device),
+             inputs=[
+                 editing_choice[0], rand_seed, state
+             ],
+             outputs=[edited_1, edited_2, edited_3, edited_4, state],
+             queue=True
+         )
+         print("Generate examples...")
+         with gr.Column():
+             gr.Examples(
+                 examples=[
+                     [ "a corgi",
+                       "Make the color of it look like rainbow",
+                       456,
+                     ],
+                     ["a penguin",
+                      "Make it look like make of lego",
+                      214,
+                     ],
+                 ],
+                 inputs=[prompt, editing_choice[0], rand_seed],
+                 outputs=[out_gen_1, out_gen_2, out_gen_3, out_gen_4, edited_1, edited_2, edited_3, edited_4],
+                 fn=partial(optimize_all, xm, models, initial_noise, noise_start_t, diffusion, latent_model, device),
+                 cache_examples=True,
+             )
+
+
+     demo.queue(max_size=10, api_open=False)
+     demo.launch(share=True, show_api=False, show_error=True)
+
+ if __name__ == '__main__':
+     main()
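For orientation, the core of `_3d_editing` above is a single denoising step on a Shap-E latent: the latent is perturbed to the checkpoint's start timestep with the editor's fixed initial noise via `diffusion.q_sample`, and one call to `diffusion.p_mean_variance` with the instruction text (and the original latent passed as `condition_latents`) returns the edited latent as `pred_xstart`. The sketch below is a hedged rearrangement of that step using only functions that appear in this commit; `edit_latent` is a hypothetical helper, and `editor`, `latent`, `instruction`, `initial_noise`, and `t_start` stand for the per-effect checkpoint pieces loaded in `main()`.

import torch
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_config, load_model
from shap_e.util.notebooks import decode_latent_mesh

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
xm = load_model('transmitter', device=device)                 # latent -> NeRF/mesh decoder
diffusion = diffusion_from_config(load_config('diffusion'))   # same diffusion config as app.py

def edit_latent(editor, latent, instruction, initial_noise, t_start):
    # `editor`, `initial_noise`, `t_start` would come from one of the fine-tuned
    # text300M checkpoints loaded in main(); `latent` is one latent returned by
    # sample_latents() inside generate_3d_with_shap_e().
    ref_latent = latent.unsqueeze(0)
    t = torch.tensor([t_start], device=ref_latent.device).long()
    kwargs = editor.cached_model_kwargs(1, dict(texts=[instruction]))
    noised = diffusion.q_sample(ref_latent, t, noise=initial_noise)          # perturb to t_start
    out = diffusion.p_mean_variance(editor, noised, t, clip_denoised=True,
                                    model_kwargs=kwargs, condition_latents=ref_latent)
    return out['pred_xstart'][0]                                             # edited latent

# mesh = decode_latent_mesh(xm, edited_latent).tri_mesh()  # then rotate/export as in app.py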
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ filelock
+ pillow
+ torch
+ fire
+ humanize
+ requests
+ tqdm
+ matplotlib
+ scikit-image
+ scipy
+ numpy
+ blobfile
+ clip @ git+https://github.com/openai/CLIP.git
+ trimesh
+
+ # gradio demo
+ gradio
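The demo imports `shap_e` from the copy vendored in this commit rather than from PyPI, so only the packages above need installing (for example with `pip install -r requirements.txt` from the repo root). A minimal sanity check of the resulting environment, assuming it is run from the repo root so the vendored package resolves, might look like:

# Quick environment check for the demo's key dependencies.
import torch, trimesh, gradio
from shap_e.models.download import load_config

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusion config keys:", sorted(load_config('diffusion').keys()))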
shap_e/.DS_Store ADDED
Binary file (6.15 kB).
shap_e/__init__.py ADDED
File without changes
shap_e/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (156 Bytes).
shap_e/diffusion/__init__.py ADDED
File without changes
shap_e/diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (166 Bytes).
shap_e/diffusion/__pycache__/gaussian_diffusion.cpython-39.pyc ADDED
Binary file (33.9 kB).
shap_e/diffusion/__pycache__/k_diffusion.cpython-39.pyc ADDED
Binary file (12.4 kB).
shap_e/diffusion/__pycache__/sample.cpython-39.pyc ADDED
Binary file (3.71 kB).
shap_e/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1143 @@
+ """
+ Based on https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+ """
+
+ import math
+ from typing import Any, Dict, Iterable, Optional, Sequence, Union
+
+ import blobfile as bf
+ import numpy as np
+ import torch as th
+ import yaml
+
+
+ def diffusion_from_config(config: Union[str, Dict[str, Any]]) -> "GaussianDiffusion":
+     if isinstance(config, str):
+         with bf.BlobFile(config, "rb") as f:
+             obj = yaml.load(f, Loader=yaml.SafeLoader)
+         return diffusion_from_config(obj)
+
+     schedule = config["schedule"]
+     steps = config["timesteps"]
+     respace = config.get("respacing", None)
+     mean_type = config.get("mean_type", "epsilon")
+     betas = get_named_beta_schedule(schedule, steps, **config.get("schedule_args", {}))
+     channel_scales = config.get("channel_scales", None)
+     channel_biases = config.get("channel_biases", None)
+     if channel_scales is not None:
+         channel_scales = np.array(channel_scales)
+     if channel_biases is not None:
+         channel_biases = np.array(channel_biases)
+     kwargs = dict(
+         betas=betas,
+         model_mean_type=mean_type,
+         model_var_type="learned_range",
+         loss_type="mse",
+         channel_scales=channel_scales,
+         channel_biases=channel_biases,
+     )
+     if respace is None:
+         return GaussianDiffusion(**kwargs)
+     else:
+         return SpacedDiffusion(use_timesteps=space_timesteps(steps, respace), **kwargs)
+
+
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+     """
+     This is the deprecated API for creating beta schedules.
+
+     See get_named_beta_schedule() for the new library of schedules.
+     """
+     if beta_schedule == "linear":
+         betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+     else:
+         raise NotImplementedError(beta_schedule)
+     assert betas.shape == (num_diffusion_timesteps,)
+     return betas
+
+
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, **extra_args: float):
+     """
+     Get a pre-defined beta schedule for the given name.
+
+     The beta schedule library consists of beta schedules which remain similar
+     in the limit of num_diffusion_timesteps.
+     Beta schedules may be added, but should not be removed or changed once
+     they are committed to maintain backwards compatibility.
+     """
+     if schedule_name == "linear":
+         # Linear schedule from Ho et al, extended to work for any number of
+         # diffusion steps.
+         scale = 1000 / num_diffusion_timesteps
+         return get_beta_schedule(
+             "linear",
+             beta_start=scale * 0.0001,
+             beta_end=scale * 0.02,
+             num_diffusion_timesteps=num_diffusion_timesteps,
+         )
+     elif schedule_name == "cosine":
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+         )
+     elif schedule_name == "inv_parabola":
+         exponent = extra_args.get("power", 2.0)
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: 1 - t**exponent,
+         )
+     elif schedule_name == "translated_parabola":
+         exponent = extra_args.get("power", 2.0)
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: (1 - t) ** exponent,
+         )
+     elif schedule_name == "exp":
+         coefficient = extra_args.get("coefficient", -12.0)
+         return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.exp(t * coefficient))
+     else:
+         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there's 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim") :])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
+         elif section_counts.startswith("exact"):
+             res = set(int(x) for x in section_counts[len("exact") :].split(","))
+             for x in res:
+                 if x < 0 or x >= num_timesteps:
+                     raise ValueError(f"timestep out of bounds: {x}")
+             return res
+         section_counts = [int(x) for x in section_counts.split(",")]
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(f"cannot divide section of {size} steps into {section_count}")
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class GaussianDiffusion:
+     """
+     Utilities for training and sampling diffusion models.
+
+     Ported directly from here:
+     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+     :param betas: a 1-D array of betas for each diffusion timestep from T to 1.
+     :param model_mean_type: a string determining what the model outputs.
+     :param model_var_type: a string determining how variance is output.
+     :param loss_type: a string determining the loss function to use.
+     :param discretized_t0: if True, use discrete gaussian loss for t=0. Only
+                            makes sense for images.
+     :param channel_scales: a multiplier to apply to x_start in training_losses
+                            and sampling functions.
+     """
+
+     def __init__(
+         self,
+         *,
+         betas: Sequence[float],
+         model_mean_type: str,
+         model_var_type: str,
+         loss_type: str,
+         discretized_t0: bool = False,
+         channel_scales: Optional[np.ndarray] = None,
+         channel_biases: Optional[np.ndarray] = None,
+     ):
+         self.model_mean_type = model_mean_type
+         self.model_var_type = model_var_type
+         self.loss_type = loss_type
+         self.discretized_t0 = discretized_t0
+         self.channel_scales = channel_scales
+         self.channel_biases = channel_biases
+
+         # Use float64 for accuracy.
+         betas = np.array(betas, dtype=np.float64)
+         self.betas = betas
+         assert len(betas.shape) == 1, "betas must be 1-D"
+         assert (betas > 0).all() and (betas <= 1).all()
+
+         self.num_timesteps = int(betas.shape[0])
+
+         alphas = 1.0 - betas
+         self.alphas_cumprod = np.cumprod(alphas, axis=0)
+         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         self.posterior_variance = (
+             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         self.posterior_log_variance_clipped = np.log(
+             np.append(self.posterior_variance[1], self.posterior_variance[1:])
+         )
+         self.posterior_mean_coef1 = (
+             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         self.posterior_mean_coef2 = (
+             (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+         )
+
+     def get_sigmas(self, t):
+         return _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, t.shape)
+
+     def q_mean_variance(self, x_start, t):
+         """
+         Get the distribution q(x_t | x_0).
+
+         :param x_start: the [N x C x ...] tensor of noiseless inputs.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+         """
+         mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+         variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+         log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+         return mean, variance, log_variance
+
+     def q_sample(self, x_start, t, noise=None):
+         """
+         Diffuse the data for a given number of diffusion steps.
+
+         In other words, sample from q(x_t | x_0).
+
+         :param x_start: the initial data batch.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :param noise: if specified, the split-out normal noise.
+         :return: A noisy version of x_start.
+         """
+         if noise is None:
+             noise = th.randn_like(x_start)
+         assert noise.shape == x_start.shape
+         return (
+             _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+             + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+         )
+
+     def q_posterior_mean_variance(self, x_start, x_t, t):
+         """
+         Compute the mean and variance of the diffusion posterior:
+
+             q(x_{t-1} | x_t, x_0)
+
+         """
+         assert x_start.shape == x_t.shape
+         posterior_mean = (
+             _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+             + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = _extract_into_tensor(
+             self.posterior_log_variance_clipped, t, x_t.shape
+         )
+         assert (
+             posterior_mean.shape[0]
+             == posterior_variance.shape[0]
+             == posterior_log_variance_clipped.shape[0]
+             == x_start.shape[0]
+         )
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def p_mean_variance(
+         self, model, x, t, clip_denoised=False, denoised_fn=None, model_kwargs=None, condition_latents=None
+     ):
+         """
+         Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+         the initial x, x_0.
+
+         :param model: the model, which takes a signal and a batch of timesteps
+                       as input.
+         :param x: the [N x C x ...] tensor at time t.
+         :param t: a 1-D Tensor of timesteps.
+         :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample. Applies before
+             clip_denoised.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict with the following keys:
+                  - 'mean': the model mean output.
+                  - 'variance': the model variance output.
+                  - 'log_variance': the log of 'variance'.
+                  - 'pred_xstart': the prediction for x_0.
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+         B, C = x.shape[:2]
+         assert t.shape == (B,)
+         model_output = model(x, t, **model_kwargs) if condition_latents is None else model(x, t, condition_latents, **model_kwargs)
+         if isinstance(model_output, tuple):
+             model_output, extra = model_output
+         else:
+             extra = None
+
+         if self.model_var_type in ["learned", "learned_range"]:
+             assert model_output.shape == (B, C * 2, *x.shape[2:])
+             model_output, model_var_values = th.split(model_output, C, dim=1)
+             if self.model_var_type == "learned":
+                 model_log_variance = model_var_values
+                 model_variance = th.exp(model_log_variance)
+             else:
+                 min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+                 max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+                 # The model_var_values is [-1, 1] for [min_var, max_var].
+                 frac = (model_var_values + 1) / 2
+                 model_log_variance = frac * max_log + (1 - frac) * min_log
+                 model_variance = th.exp(model_log_variance)
+         else:
+             model_variance, model_log_variance = {
+                 # for fixedlarge, we set the initial (log-)variance like so
+                 # to get a better decoder log likelihood.
+                 "fixed_large": (
+                     np.append(self.posterior_variance[1], self.betas[1:]),
+                     np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+                 ),
+                 "fixed_small": (
+                     self.posterior_variance,
+                     self.posterior_log_variance_clipped,
+                 ),
+             }[self.model_var_type]
+             model_variance = _extract_into_tensor(model_variance, t, x.shape)
+             model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+         def process_xstart(x):
+             if denoised_fn is not None:
+                 x = denoised_fn(x)
+             if clip_denoised:
+                 return x.clamp(-1, 1)
+             return x
+
+         if self.model_mean_type == "x_prev":
+             pred_xstart = process_xstart(
+                 self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+             )
+             model_mean = model_output
+         elif self.model_mean_type in ["x_start", "epsilon"]:
+             if self.model_mean_type == "x_start":
+                 pred_xstart = process_xstart(model_output)
+             else:
+                 pred_xstart = process_xstart(
+                     self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+                 )
+             model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+         else:
+             raise NotImplementedError(self.model_mean_type)
+
+         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+         return {
+             "mean": model_mean,
+             "variance": model_variance,
+             "log_variance": model_log_variance,
+             "pred_xstart": pred_xstart,
+             "extra": extra,
+         }
+
+     def _predict_xstart_from_eps(self, x_t, t, eps):
+         assert x_t.shape == eps.shape
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+         )
+
+     def _predict_xstart_from_xprev(self, x_t, t, xprev):
+         assert x_t.shape == xprev.shape
+         return (  # (xprev - coef2*x_t) / coef1
+             _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+             - _extract_into_tensor(
+                 self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+             )
+             * x_t
+         )
+
+     def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+         return (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+     def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute the mean for the previous step, given a function cond_fn that
+         computes the gradient of a conditional log probability with respect to
+         x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+         condition on y.
+
+         This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+         """
+         gradient = cond_fn(x, t, **(model_kwargs or {}))
+         new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+         return new_mean
+
+     def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+         """
+         Compute what the p_mean_variance output would have been, should the
+         model's score function be conditioned by cond_fn.
+
+         See condition_mean() for details on cond_fn.
+
+         Unlike condition_mean(), this instead uses the conditioning strategy
+         from Song et al (2020).
+         """
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+         eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+         eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **(model_kwargs or {}))
+
+         out = p_mean_var.copy()
+         out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+         out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+         return out
+
+     def p_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+     ):
+         """
+         Sample x_{t-1} from the model at the given timestep.
+
+         :param model: the model to sample from.
+         :param x: the current tensor at x_{t-1}.
+         :param t: the value of t, starting at 0 for the first diffusion step.
+         :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param cond_fn: if not None, this is a gradient function that acts
+                         similarly to the model.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :return: a dict containing the following keys:
+                  - 'sample': a random sample from the model.
+                  - 'pred_xstart': a prediction of x_0.
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         noise = th.randn_like(x)
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         if cond_fn is not None:
+             out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+     def p_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model.
+
+         :param model: the model module.
+         :param shape: the shape of the samples, (N, C, H, W).
+         :param noise: if specified, the noise from the encoder to sample.
+                       Should be of the same shape as `shape`.
+         :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+         :param denoised_fn: if not None, a function which applies to the
+             x_start prediction before it is used to sample.
+         :param cond_fn: if not None, this is a gradient function that acts
+                         similarly to the model.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param device: if specified, the device to create the samples on.
+                        If not specified, use a model parameter's device.
+         :param progress: if True, show a tqdm progress bar.
+         :return: a non-differentiable batch of samples.
+         """
+         final = None
+         for sample in self.p_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             temp=temp,
+         ):
+             final = sample
+         return final["sample"]
+
+     def p_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model and yield intermediate samples from
+         each timestep of diffusion.
+
+         Arguments are the same as p_sample_loop().
+         Returns a generator over dicts, where each dict is the return value of
+         p_sample().
+         """
+
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device) * temp
+         indices = list(range(self.num_timesteps))[::-1]
+
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+
+             indices = tqdm(indices)
+
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.p_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                 )
+                 yield self.unscale_out_dict(out)
+                 img = out["sample"]
+
+     def ddim_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t-1} from the model using DDIM.
+
+         Same usage as p_sample().
+         """
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+         alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+         alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+         sigma = (
+             eta
+             * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+             * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+         )
+         # Equation 12.
+         noise = th.randn_like(x)
+         mean_pred = (
+             out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+             + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+         )
+         nonzero_mask = (
+             (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+         )  # no noise when t == 0
+         sample = mean_pred + nonzero_mask * sigma * noise
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+     def ddim_reverse_sample(
+         self,
+         model,
+         x,
+         t,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         eta=0.0,
+     ):
+         """
+         Sample x_{t+1} from the model using DDIM reverse ODE.
+         """
+         assert eta == 0.0, "Reverse ODE only for deterministic path"
+         out = self.p_mean_variance(
+             model,
+             x,
+             t,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             model_kwargs=model_kwargs,
+         )
+         if cond_fn is not None:
+             out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+         # Usually our model outputs epsilon, but we re-derive it
+         # in case we used x_start or x_prev prediction.
+         eps = (
+             _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+             - out["pred_xstart"]
+         ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+         alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+         # Equation 12. reversed
+         mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+         return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+     def ddim_sample_loop(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+         temp=1.0,
+     ):
+         """
+         Generate samples from the model using DDIM.
+
+         Same usage as p_sample_loop().
+         """
+         final = None
+         for sample in self.ddim_sample_loop_progressive(
+             model,
+             shape,
+             noise=noise,
+             clip_denoised=clip_denoised,
+             denoised_fn=denoised_fn,
+             cond_fn=cond_fn,
+             model_kwargs=model_kwargs,
+             device=device,
+             progress=progress,
+             eta=eta,
+             temp=temp,
+         ):
+             final = sample
+         return final["sample"]
+
+     def ddim_sample_loop_progressive(
+         self,
+         model,
+         shape,
+         noise=None,
+         clip_denoised=False,
+         denoised_fn=None,
+         cond_fn=None,
+         model_kwargs=None,
+         device=None,
+         progress=False,
+         eta=0.0,
+         temp=1.0,
+     ):
+         """
+         Use DDIM to sample from the model and yield intermediate samples from
+         each timestep of DDIM.
+
+         Same usage as p_sample_loop_progressive().
+         """
+         if device is None:
+             device = next(model.parameters()).device
+         assert isinstance(shape, (tuple, list))
+         if noise is not None:
+             img = noise
+         else:
+             img = th.randn(*shape, device=device) * temp
+         indices = list(range(self.num_timesteps))[::-1]
+
+         if progress:
+             # Lazy import so that we don't depend on tqdm.
+             from tqdm.auto import tqdm
+
+             indices = tqdm(indices)
+
+         for i in indices:
+             t = th.tensor([i] * shape[0], device=device)
+             with th.no_grad():
+                 out = self.ddim_sample(
+                     model,
+                     img,
+                     t,
+                     clip_denoised=clip_denoised,
+                     denoised_fn=denoised_fn,
+                     cond_fn=cond_fn,
+                     model_kwargs=model_kwargs,
+                     eta=eta,
+                 )
+                 yield self.unscale_out_dict(out)
+                 img = out["sample"]
+
+     def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=False, model_kwargs=None):
+         """
+         Get a term for the variational lower-bound.
+
+         The resulting units are bits (rather than nats, as one might expect).
+         This allows for comparison to other papers.
+
+         :return: a dict with the following keys:
+                  - 'output': a shape [N] tensor of NLLs or KLs.
+                  - 'pred_xstart': the x_0 predictions.
+         """
+         true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+             x_start=x_start, x_t=x_t, t=t
+         )
+         out = self.p_mean_variance(
+             model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+         )
+         kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
+         kl = mean_flat(kl) / np.log(2.0)
+
+         decoder_nll = -discretized_gaussian_log_likelihood(
+             x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+         )
+         if not self.discretized_t0:
+             decoder_nll = th.zeros_like(decoder_nll)
+         assert decoder_nll.shape == x_start.shape
+         decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+         # At the first timestep return the decoder NLL,
+         # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+         output = th.where((t == 0), decoder_nll, kl)
+         return {
+             "output": output,
+             "pred_xstart": out["pred_xstart"],
+             "extra": out["extra"],
+         }
+
+     def training_losses(
+         self, model, x_start, t, model_kwargs=None, noise=None
+     ) -> Dict[str, th.Tensor]:
+         """
+         Compute training losses for a single timestep.
+
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param t: a batch of timestep indices.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+         :param noise: if specified, the specific Gaussian noise to try to remove.
+         :return: a dict with the key "loss" containing a tensor of shape [N].
+                  Some mean or variance settings may also have other keys.
+         """
+         x_start = self.scale_channels(x_start)
+         if model_kwargs is None:
+             model_kwargs = {}
+         if noise is None:
+             noise = th.randn_like(x_start)
+         x_t = self.q_sample(x_start, t, noise=noise)
+
+         terms = {}
+
+         if self.loss_type == "kl" or self.loss_type == "rescaled_kl":
+             vb_terms = self._vb_terms_bpd(
+                 model=model,
+                 x_start=x_start,
+                 x_t=x_t,
+                 t=t,
+                 clip_denoised=False,
+                 model_kwargs=model_kwargs,
+             )
+             terms["loss"] = vb_terms["output"]
+             if self.loss_type == "rescaled_kl":
+                 terms["loss"] *= self.num_timesteps
+             extra = vb_terms["extra"]
+         elif self.loss_type == "mse" or self.loss_type == "rescaled_mse":
+             model_output = model(x_t, t, **model_kwargs)
+             if isinstance(model_output, tuple):
+                 model_output, extra = model_output
+             else:
+                 extra = {}
+
+             if self.model_var_type in [
+                 "learned",
+                 "learned_range",
+             ]:
+                 B, C = x_t.shape[:2]
+                 assert model_output.shape == (
+                     B,
+                     C * 2,
+                     *x_t.shape[2:],
+                 ), f"{model_output.shape} != {(B, C * 2, *x_t.shape[2:])}"
+                 model_output, model_var_values = th.split(model_output, C, dim=1)
+                 # Learn the variance using the variational bound, but don't let
+                 # it affect our mean prediction.
+                 frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+                 terms["vb"] = self._vb_terms_bpd(
+                     model=lambda *args, r=frozen_out: r,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t,
+                     clip_denoised=False,
+                 )["output"]
+                 if self.loss_type == "rescaled_mse":
+                     # Divide by 1000 for equivalence with initial implementation.
+                     # Without a factor of 1/1000, the VB term hurts the MSE term.
+                     terms["vb"] *= self.num_timesteps / 1000.0
+
+             target = {
+                 "x_prev": self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
+                 "x_start": x_start,
+                 "epsilon": noise,
+             }[self.model_mean_type]
+             assert model_output.shape == target.shape == x_start.shape
+             terms["mse"] = mean_flat((target - model_output) ** 2)
+             if "vb" in terms:
+                 terms["loss"] = terms["mse"] + terms["vb"]
+             else:
+                 terms["loss"] = terms["mse"]
+         else:
+             raise NotImplementedError(self.loss_type)
+
+         if "losses" in extra:
+             terms.update({k: loss for k, (loss, _scale) in extra["losses"].items()})
+             for loss, scale in extra["losses"].values():
+                 terms["loss"] = terms["loss"] + loss * scale
+
+         return terms
+
+     def _prior_bpd(self, x_start):
+         """
+         Get the prior KL term for the variational lower-bound, measured in
+         bits-per-dim.
+
+         This term can't be optimized, as it only depends on the encoder.
+
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :return: a batch of [N] KL values (in bits), one per batch element.
+         """
+         batch_size = x_start.shape[0]
+         t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+         qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+         kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+         return mean_flat(kl_prior) / np.log(2.0)
+
+     def calc_bpd_loop(self, model, x_start, clip_denoised=False, model_kwargs=None):
+         """
+         Compute the entire variational lower-bound, measured in bits-per-dim,
+         as well as other related quantities.
+
+         :param model: the model to evaluate loss on.
+         :param x_start: the [N x C x ...] tensor of inputs.
+         :param clip_denoised: if True, clip denoised samples.
+         :param model_kwargs: if not None, a dict of extra keyword arguments to
+             pass to the model. This can be used for conditioning.
+
+         :return: a dict containing the following keys:
+                  - total_bpd: the total variational lower-bound, per batch element.
+                  - prior_bpd: the prior term in the lower-bound.
+                  - vb: an [N x T] tensor of terms in the lower-bound.
+                  - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+                  - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+         """
+         device = x_start.device
+         batch_size = x_start.shape[0]
+
+         vb = []
+         xstart_mse = []
+         mse = []
+         for t in list(range(self.num_timesteps))[::-1]:
+             t_batch = th.tensor([t] * batch_size, device=device)
+             noise = th.randn_like(x_start)
+             x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+             # Calculate VLB term at the current timestep
+             with th.no_grad():
+                 out = self._vb_terms_bpd(
+                     model,
+                     x_start=x_start,
+                     x_t=x_t,
+                     t=t_batch,
+                     clip_denoised=clip_denoised,
+                     model_kwargs=model_kwargs,
+                 )
+             vb.append(out["output"])
+             xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+             eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+             mse.append(mean_flat((eps - noise) ** 2))
+
+         vb = th.stack(vb, dim=1)
+         xstart_mse = th.stack(xstart_mse, dim=1)
+         mse = th.stack(mse, dim=1)
+
+         prior_bpd = self._prior_bpd(x_start)
+         total_bpd = vb.sum(dim=1) + prior_bpd
+         return {
+             "total_bpd": total_bpd,
+             "prior_bpd": prior_bpd,
+             "vb": vb,
+             "xstart_mse": xstart_mse,
+             "mse": mse,
+         }
+
+     def scale_channels(self, x: th.Tensor) -> th.Tensor:
+         if self.channel_scales is not None:
+             x = x * th.from_numpy(self.channel_scales).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         if self.channel_biases is not None:
+             x = x + th.from_numpy(self.channel_biases).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         return x
+
+     def unscale_channels(self, x: th.Tensor) -> th.Tensor:
+         if self.channel_biases is not None:
+             x = x - th.from_numpy(self.channel_biases).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         if self.channel_scales is not None:
+             x = x / th.from_numpy(self.channel_scales).to(x).reshape(
+                 [1, -1, *([1] * (len(x.shape) - 2))]
+             )
+         return x
+
+     def unscale_out_dict(
+         self, out: Dict[str, Union[th.Tensor, Any]]
+     ) -> Dict[str, Union[th.Tensor, Any]]:
+         return {
+             k: (self.unscale_channels(v) if isinstance(v, th.Tensor) else v) for k, v in out.items()
+         }
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+     :param use_timesteps: (unordered) timesteps from the original diffusion
+                           process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps: Iterable[int], **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(self, model, *args, **kwargs):
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def training_losses(self, model, *args, **kwargs):
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(model, self.timestep_map, self.original_num_steps)
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, original_num_steps):
+         self.model = model
1049
+ self.timestep_map = timestep_map
1050
+ self.original_num_steps = original_num_steps
1051
+
1052
+ def __call__(self, x, ts, **kwargs):
1053
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
1054
+ new_ts = map_tensor[ts]
1055
+ return self.model(x, new_ts, **kwargs)
1056
+
1057
+
1058
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1059
+ """
1060
+ Extract values from a 1-D numpy array for a batch of indices.
1061
+
1062
+ :param arr: the 1-D numpy array.
1063
+ :param timesteps: a tensor of indices into the array to extract.
1064
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1065
+ dimension equal to the length of timesteps.
1066
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1067
+ """
1068
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1069
+ while len(res.shape) < len(broadcast_shape):
1070
+ res = res[..., None]
1071
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
1072
+
1073
+
1074
+ def normal_kl(mean1, logvar1, mean2, logvar2):
1075
+ """
1076
+ Compute the KL divergence between two gaussians.
1077
+ Shapes are automatically broadcasted, so batches can be compared to
1078
+ scalars, among other use cases.
1079
+ """
1080
+ tensor = None
1081
+ for obj in (mean1, logvar1, mean2, logvar2):
1082
+ if isinstance(obj, th.Tensor):
1083
+ tensor = obj
1084
+ break
1085
+ assert tensor is not None, "at least one argument must be a Tensor"
1086
+
1087
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
1088
+ # Tensors, but it does not work for th.exp().
1089
+ logvar1, logvar2 = [
1090
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)
1091
+ ]
1092
+
1093
+ return 0.5 * (
1094
+ -1.0
1095
+ + logvar2
1096
+ - logvar1
1097
+ + th.exp(logvar1 - logvar2)
1098
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
1099
+ )
1100
+
1101
+
1102
+ def approx_standard_normal_cdf(x):
1103
+ """
1104
+ A fast approximation of the cumulative distribution function of the
1105
+ standard normal.
1106
+ """
1107
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1108
+
1109
+
1110
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1111
+ """
1112
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
1113
+ given image.
1114
+ :param x: the target images. It is assumed that this was uint8 values,
1115
+ rescaled to the range [-1, 1].
1116
+ :param means: the Gaussian mean Tensor.
1117
+ :param log_scales: the Gaussian log stddev Tensor.
1118
+ :return: a tensor like x of log probabilities (in nats).
1119
+ """
1120
+ assert x.shape == means.shape == log_scales.shape
1121
+ centered_x = x - means
1122
+ inv_stdv = th.exp(-log_scales)
1123
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1124
+ cdf_plus = approx_standard_normal_cdf(plus_in)
1125
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1126
+ cdf_min = approx_standard_normal_cdf(min_in)
1127
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1128
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1129
+ cdf_delta = cdf_plus - cdf_min
1130
+ log_probs = th.where(
1131
+ x < -0.999,
1132
+ log_cdf_plus,
1133
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
1134
+ )
1135
+ assert log_probs.shape == x.shape
1136
+ return log_probs
1137
+
1138
+
1139
+ def mean_flat(tensor):
1140
+ """
1141
+ Take the mean over all non-batch dimensions.
1142
+ """
1143
+ return tensor.flatten(1).mean(1)
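The normal_kl helper above is the standard closed-form KL divergence between diagonal Gaussians, parameterized by log-variances. As a quick sanity check (illustrative only, not part of this commit; it assumes normal_kl from the file above is in scope), the scalar case can be compared against the textbook formula KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2/s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2:

# Illustrative check of normal_kl against the scalar closed form; values are arbitrary.
import torch as th

def reference_kl(m1, s1, m2, s2):
    # KL(N(m1, s1^2) || N(m2, s2^2)) for scalar mean/stddev tensors.
    return th.log(s2 / s1) + (s1 ** 2 + (m1 - m2) ** 2) / (2 * s2 ** 2) - 0.5

m1, s1 = th.tensor(0.3), th.tensor(1.2)
m2, s2 = th.tensor(-0.1), th.tensor(0.8)
kl_from_helper = normal_kl(m1, 2 * th.log(s1), m2, 2 * th.log(s2))  # logvar = 2 * log(sigma)
assert th.allclose(kl_from_helper, reference_kl(m1, s1, m2, s2))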
shap_e/diffusion/k_diffusion.py ADDED
@@ -0,0 +1,426 @@
1
+ """
2
+ Based on: https://github.com/crowsonkb/k-diffusion
3
+
4
+ Copyright (c) 2022 Katherine Crowson
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in
14
+ all copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22
+ THE SOFTWARE.
23
+ """
24
+
25
+ import numpy as np
26
+ import torch as th
27
+
28
+ from .gaussian_diffusion import GaussianDiffusion, mean_flat
29
+
30
+
31
+ class KarrasDenoiser:
32
+ def __init__(self, sigma_data: float = 0.5):
33
+ self.sigma_data = sigma_data
34
+
35
+ def get_snr(self, sigmas):
36
+ return sigmas**-2
37
+
38
+ def get_sigmas(self, sigmas):
39
+ return sigmas
40
+
41
+ def get_scalings(self, sigma):
42
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
43
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
44
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
45
+ return c_skip, c_out, c_in
46
+
47
+ def training_losses(self, model, x_start, sigmas, model_kwargs=None, noise=None):
48
+ if model_kwargs is None:
49
+ model_kwargs = {}
50
+ if noise is None:
51
+ noise = th.randn_like(x_start)
52
+
53
+ terms = {}
54
+
55
+ dims = x_start.ndim
56
+ x_t = x_start + noise * append_dims(sigmas, dims)
57
+ c_skip, c_out, _ = [append_dims(x, dims) for x in self.get_scalings(sigmas)]
58
+ model_output, denoised = self.denoise(model, x_t, sigmas, **model_kwargs)
59
+ target = (x_start - c_skip * x_t) / c_out
60
+
61
+ terms["mse"] = mean_flat((model_output - target) ** 2)
62
+ terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
63
+
64
+ if "vb" in terms:
65
+ terms["loss"] = terms["mse"] + terms["vb"]
66
+ else:
67
+ terms["loss"] = terms["mse"]
68
+
69
+ return terms
70
+
71
+ def denoise(self, model, x_t, sigmas, **model_kwargs):
72
+ c_skip, c_out, c_in = [append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)]
73
+ rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
74
+ model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
75
+ denoised = c_out * model_output + c_skip * x_t
76
+ return model_output, denoised
77
+
78
+
79
+ class GaussianToKarrasDenoiser:
80
+ def __init__(self, model, diffusion):
81
+ from scipy import interpolate
82
+
83
+ self.model = model
84
+ self.diffusion = diffusion
85
+ self.alpha_cumprod_to_t = interpolate.interp1d(
86
+ diffusion.alphas_cumprod, np.arange(0, diffusion.num_timesteps)
87
+ )
88
+
89
+ def sigma_to_t(self, sigma):
90
+ alpha_cumprod = 1.0 / (sigma**2 + 1)
91
+ if alpha_cumprod > self.diffusion.alphas_cumprod[0]:
92
+ return 0
93
+ elif alpha_cumprod <= self.diffusion.alphas_cumprod[-1]:
94
+ return self.diffusion.num_timesteps - 1
95
+ else:
96
+ return float(self.alpha_cumprod_to_t(alpha_cumprod))
97
+
98
+ def denoise(self, x_t, sigmas, clip_denoised=True, model_kwargs=None, condition_latents=None):
99
+ t = th.tensor(
100
+ [self.sigma_to_t(sigma) for sigma in sigmas.cpu().numpy()],
101
+ dtype=th.long,
102
+ device=sigmas.device,
103
+ )
104
+ c_in = append_dims(1.0 / (sigmas**2 + 1) ** 0.5, x_t.ndim)
105
+ out = self.diffusion.p_mean_variance(
106
+ self.model, x_t * c_in, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latents
107
+ )
108
+ return None, out["pred_xstart"]
109
+
110
+
111
+ def karras_sample(*args, **kwargs):
112
+ last = None
113
+ x_sequence = []
114
+ # print("karras_sample_model_kwargs", kwargs["model_kwargs"]['embeddings'].shape)
115
+ for x in karras_sample_progressive(*args, **kwargs):
116
+ last = x["x"]
117
+ x_sequence.append(last)
118
+ return last, x_sequence
119
+
120
+
121
+
122
+ def karras_sample_progressive(
123
+ diffusion,
124
+ model,
125
+ shape,
126
+ steps,
127
+ clip_denoised=True,
128
+ progress=False,
129
+ model_kwargs=None,
130
+ device=None,
131
+ sigma_min=0.002,
132
+ sigma_max=80, # higher for highres?
133
+ rho=7.0,
134
+ sampler="heun",
135
+ s_churn=0.0,
136
+ s_tmin=0.0,
137
+ s_tmax=float("inf"),
138
+ s_noise=1.0,
139
+ guidance_scale=0.0,
140
+ condition_latent=None,
141
+ initial_noise=None,
142
+ ):
143
+ sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
144
+ # print("sigmas", sigmas.shape, sigmas)
145
+ if initial_noise is None:
146
+ x_T = th.randn(*shape, device=device) * sigma_max
147
+ else:
148
+ x_T = initial_noise.clone() * sigma_max
149
+ sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
150
+ sampler
151
+ ]
152
+ if sampler != "ancestral":
153
+ sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
154
+ else:
155
+ sampler_args = {}
156
+
157
+ if isinstance(diffusion, KarrasDenoiser):
158
+ def denoiser(x_t, sigma):
159
+ _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
160
+ if clip_denoised:
161
+ denoised = denoised.clamp(-1, 1)
162
+ return denoised
163
+
164
+ elif isinstance(diffusion, GaussianDiffusion):
165
+ model = GaussianToKarrasDenoiser(model, diffusion)
166
+
167
+ def denoiser(x_t, sigma):
168
+ _, denoised = model.denoise(
169
+ x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent
170
+ )
171
+ return denoised
172
+
173
+ else:
174
+ raise NotImplementedError
175
+
176
+ if guidance_scale != 0 and guidance_scale != 1:
177
+
178
+ def guided_denoiser(x_t, sigma):
179
+ x_t = th.cat([x_t, x_t], dim=0)
180
+ sigma = th.cat([sigma, sigma], dim=0)
181
+ x_0 = denoiser(x_t, sigma)
182
+ cond_x_0, uncond_x_0 = th.split(x_0, len(x_0) // 2, dim=0)
183
+ x_0 = uncond_x_0 + guidance_scale * (cond_x_0 - uncond_x_0)
184
+ return x_0
185
+
186
+ else:
187
+ guided_denoiser = denoiser
188
+
189
+ for obj in sample_fn(
190
+ guided_denoiser,
191
+ x_T,
192
+ sigmas,
193
+ progress=progress,
194
+ condition_latent=condition_latent,
195
+ **sampler_args,
196
+ ):
197
+ if isinstance(diffusion, GaussianDiffusion):
198
+ # print("is gaussian diffusion", obj)
199
+ yield diffusion.unscale_out_dict(obj)
200
+ else:
201
+ yield obj
202
+
203
+
204
+ def karras_sample_progressive_condition(
205
+ diffusion,
206
+ model,
207
+ shape,
208
+ steps,
209
+ clip_denoised=True,
210
+ progress=False,
211
+ model_kwargs=None,
212
+ device=None,
213
+ sigma_min=0.002,
214
+ sigma_max=80, # higher for highres?
215
+ rho=7.0,
216
+ sampler="heun",
217
+ s_churn=0.0,
218
+ s_tmin=0.0,
219
+ s_tmax=float("inf"),
220
+ s_noise=1.0,
221
+ text_guidance_scale=0.0,
222
+ image_guidance_scale=0.0,
223
+ condition_latent=None,
224
+ ):
225
+ sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)
226
+ x_T = th.randn(*shape, device=device) * sigma_max
227
+ sample_fn = {"heun": sample_heun, "dpm": sample_dpm, "ancestral": sample_euler_ancestral}[
228
+ sampler
229
+ ]
230
+ if sampler != "ancestral":
231
+ sampler_args = dict(s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise)
232
+ else:
233
+ sampler_args = {}
234
+
235
+ if isinstance(diffusion, KarrasDenoiser):
236
+ def denoiser(x_t, sigma):
237
+ _, denoised = diffusion.denoise(model, x_t, sigma, **model_kwargs)
238
+ if clip_denoised:
239
+ denoised = denoised.clamp(-1, 1)
240
+ return denoised
241
+
242
+ elif isinstance(diffusion, GaussianDiffusion):
243
+ model = GaussianToKarrasDenoiser(model, diffusion)
244
+
245
+ def denoiser(x_t, sigma):
246
+ _, denoised = model.denoise(
247
+ x_t, sigma, clip_denoised=clip_denoised, model_kwargs=model_kwargs, condition_latents=condition_latent
248
+ )
249
+ return denoised
250
+
251
+ else:
252
+ raise NotImplementedError
253
+
254
+ if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (image_guidance_scale != 1.0 and image_guidance_scale != 0.0):
255
+ def guided_denoiser(x_t, sigma):
256
+ x_t = th.cat([x_t, x_t, x_t], dim=0)
257
+ sigma = th.cat([sigma, sigma, sigma], dim=0)
258
+ x_0 = denoiser(x_t, sigma)
259
+ # import pdb; pdb.set_trace()
260
+ cond_x_0_text, cond_x_0_image, uncond_x_0 = th.chunk(x_0, 3, dim=0)
261
+ x_0 = uncond_x_0 + text_guidance_scale * (cond_x_0_text - cond_x_0_image) + image_guidance_scale * (cond_x_0_image - uncond_x_0)
262
+ return x_0
263
+
264
+ else:
265
+ guided_denoiser = denoiser
266
+
267
+ for obj in sample_fn(
268
+ guided_denoiser,
269
+ x_T,
270
+ sigmas,
271
+ progress=progress,
272
+ condition_latent=condition_latent,
273
+ **sampler_args,
274
+ ):
275
+ if isinstance(diffusion, GaussianDiffusion):
276
+ yield diffusion.unscale_out_dict(obj)
277
+ else:
278
+ yield obj
279
+ def karras_sample_addition_condition(*args, **kwargs):
280
+ last = None
281
+ x_sequence = []
282
+ for x in karras_sample_progressive_condition(*args, **kwargs):
283
+ last = x["x"]
284
+ x_sequence.append(x["pred_xstart"])
285
+ return last, x_sequence
286
+
287
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
288
+ """Constructs the noise schedule of Karras et al. (2022)."""
289
+ ramp = th.linspace(0, 1, n)
290
+ min_inv_rho = sigma_min ** (1 / rho)
291
+ max_inv_rho = sigma_max ** (1 / rho)
292
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
293
+ return append_zero(sigmas).to(device)
294
+
295
+
296
+ def to_d(x, sigma, denoised):
297
+ """Converts a denoiser output to a Karras ODE derivative."""
298
+ return (x - denoised) / append_dims(sigma, x.ndim)
299
+
300
+
301
+ def get_ancestral_step(sigma_from, sigma_to):
302
+ """Calculates the noise level (sigma_down) to step down to and the amount
303
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
304
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
305
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
306
+ return sigma_down, sigma_up
307
+
308
+
309
+ @th.no_grad()
310
+ def sample_euler_ancestral(model, x, sigmas, progress=False):
311
+ """Ancestral sampling with Euler method steps."""
312
+ s_in = x.new_ones([x.shape[0]])
313
+ indices = range(len(sigmas) - 1)
314
+ if progress:
315
+ from tqdm.auto import tqdm
316
+
317
+ indices = tqdm(indices)
318
+
319
+ for i in indices:
320
+ denoised = model(x, sigmas[i] * s_in)
321
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
322
+ yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "pred_xstart": denoised}
323
+ d = to_d(x, sigmas[i], denoised)
324
+ # Euler method
325
+ dt = sigma_down - sigmas[i]
326
+ x = x + d * dt
327
+ x = x + th.randn_like(x) * sigma_up
328
+ yield {"x": x, "pred_xstart": x}
329
+
330
+
331
+ @th.no_grad()
332
+ def sample_heun(
333
+ denoiser,
334
+ x,
335
+ sigmas,
336
+ progress=False,
337
+ s_churn=0.0,
338
+ s_tmin=0.0,
339
+ s_tmax=float("inf"),
340
+ s_noise=1.0,
341
+ condition_latent=None,
342
+ ):
343
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
344
+ s_in = x.new_ones([x.shape[0]])
345
+ indices = range(len(sigmas) - 1)
346
+ if progress:
347
+ from tqdm.auto import tqdm
348
+
349
+ indices = tqdm(indices)
350
+
351
+ for i in indices:
352
+ gamma = (
353
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
354
+ )
355
+ eps = th.randn_like(x) * s_noise
356
+ sigma_hat = sigmas[i] * (gamma + 1)
357
+ if gamma > 0:
358
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
359
+ denoised = denoiser(x, sigma_hat * s_in)
360
+ d = to_d(x, sigma_hat, denoised)
361
+ yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "pred_xstart": denoised}
362
+ dt = sigmas[i + 1] - sigma_hat
363
+ if sigmas[i + 1] == 0:
364
+ # Euler method
365
+ x = x + d * dt
366
+ else:
367
+ # Heun's method
368
+ x_2 = x + d * dt
369
+ denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
370
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
371
+ d_prime = (d + d_2) / 2
372
+ x = x + d_prime * dt
373
+ yield {"x": x, "pred_xstart": denoised}
374
+
375
+
376
+ @th.no_grad()
377
+ def sample_dpm(
378
+ denoiser,
379
+ x,
380
+ sigmas,
381
+ progress=False,
382
+ s_churn=0.0,
383
+ s_tmin=0.0,
384
+ s_tmax=float("inf"),
385
+ s_noise=1.0,
386
+ ):
387
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
388
+ s_in = x.new_ones([x.shape[0]])
389
+ indices = range(len(sigmas) - 1)
390
+ if progress:
391
+ from tqdm.auto import tqdm
392
+
393
+ indices = tqdm(indices)
394
+
395
+ for i in indices:
396
+ gamma = (
397
+ min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0
398
+ )
399
+ eps = th.randn_like(x) * s_noise
400
+ sigma_hat = sigmas[i] * (gamma + 1)
401
+ if gamma > 0:
402
+ x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
403
+ denoised = denoiser(x, sigma_hat * s_in)
404
+ d = to_d(x, sigma_hat, denoised)
405
+ yield {"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised}
406
+ # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
407
+ sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
408
+ dt_1 = sigma_mid - sigma_hat
409
+ dt_2 = sigmas[i + 1] - sigma_hat
410
+ x_2 = x + d * dt_1
411
+ denoised_2 = denoiser(x_2, sigma_mid * s_in)
412
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
413
+ x = x + d_2 * dt_2
414
+ yield {"x": x, "pred_xstart": denoised}
415
+
416
+
417
+ def append_dims(x, target_dims):
418
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
419
+ dims_to_append = target_dims - x.ndim
420
+ if dims_to_append < 0:
421
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
422
+ return x[(...,) + (None,) * dims_to_append]
423
+
424
+
425
+ def append_zero(x):
426
+ return th.cat([x, x.new_zeros([1])])
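get_sigmas_karras above interpolates between sigma_max and sigma_min in sigma^(1/rho) space, which packs more of the sampling steps into the low-noise end of the trajectory, and appends a final sigma of zero so the last step lands on the clean prediction. A standalone sketch of the same schedule, using the default sigma range from above but a small step count for readability:

# Standalone sketch of the Karras et al. (2022) noise schedule built by get_sigmas_karras.
import torch as th

n, sigma_min, sigma_max, rho = 8, 0.002, 80.0, 7.0  # small n chosen only for illustration
ramp = th.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
sigmas = th.cat([sigmas, sigmas.new_zeros([1])])  # terminal sigma = 0
print(sigmas)  # decreases from 80.0 to 0.002, then 0.0; spacing is densest near sigma_min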
shap_e/diffusion/sample.py ADDED
@@ -0,0 +1,160 @@
1
+ from typing import Any, Callable, Dict, Optional, List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .gaussian_diffusion import GaussianDiffusion
7
+ from .k_diffusion import karras_sample, karras_sample_addition_condition
8
+
9
+ DEFAULT_KARRAS_STEPS = 64
10
+ DEFAULT_KARRAS_SIGMA_MIN = 1e-3
11
+ DEFAULT_KARRAS_SIGMA_MAX = 160
12
+ DEFAULT_KARRAS_S_CHURN = 0.0
13
+
14
+
15
+ def uncond_guide_model(
16
+ model: Callable[..., torch.Tensor], scale: float
17
+ ) -> Callable[..., torch.Tensor]:
18
+
19
+ def model_fn(x_t, ts, **kwargs):
20
+ half = x_t[: len(x_t) // 2]
21
+ combined = torch.cat([half, half], dim=0)
22
+ model_out = model(combined, ts, **kwargs)
23
+ cond_out, uncond_out = torch.chunk(model_out, 2, dim=0)
24
+ cond_out = uncond_out + scale * (cond_out - uncond_out)
25
+ return torch.cat([cond_out, cond_out], dim=0)
26
+
27
+ return model_fn
28
+
29
+
30
+ def sample_latents(
31
+ *,
32
+ batch_size: int,
33
+ model: nn.Module,
34
+ diffusion: GaussianDiffusion,
35
+ model_kwargs: Dict[str, Any],
36
+ guidance_scale: float,
37
+ clip_denoised: bool,
38
+ use_fp16: bool,
39
+ use_karras: bool,
40
+ karras_steps: int,
41
+ sigma_min: float,
42
+ sigma_max: float,
43
+ s_churn: float,
44
+ device: Optional[torch.device] = None,
45
+ progress: bool = False,
46
+ initial_noise: Optional[torch.Tensor] = None,
47
+ ) -> torch.Tensor:
48
+ sample_shape = (batch_size, model.d_latent)
49
+
50
+ if device is None:
51
+ device = next(model.parameters()).device
52
+
53
+ if hasattr(model, "cached_model_kwargs"):
54
+ model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
55
+ if guidance_scale != 1.0 and guidance_scale != 0.0:
56
+ for k, v in model_kwargs.copy().items():
57
+ # print(k, v.shape)
58
+ model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
59
+
60
+ sample_shape = (batch_size, model.d_latent)
61
+ with torch.autocast(device_type=device.type, enabled=use_fp16):
62
+ if use_karras:
63
+ samples, sample_sequence = karras_sample(
64
+ diffusion=diffusion,
65
+ model=model,
66
+ shape=sample_shape,
67
+ steps=karras_steps,
68
+ clip_denoised=clip_denoised,
69
+ model_kwargs=model_kwargs,
70
+ device=device,
71
+ sigma_min=sigma_min,
72
+ sigma_max=sigma_max,
73
+ s_churn=s_churn,
74
+ guidance_scale=guidance_scale,
75
+ progress=progress,
76
+ initial_noise=initial_noise,
77
+ )
78
+ else:
79
+ internal_batch_size = batch_size
80
+ if guidance_scale != 1.0:
81
+ model = uncond_guide_model(model, guidance_scale)
82
+ internal_batch_size *= 2
83
+ samples = diffusion.p_sample_loop(
84
+ model,
85
+ shape=(internal_batch_size, *sample_shape[1:]),
86
+ model_kwargs=model_kwargs,
87
+ device=device,
88
+ clip_denoised=clip_denoised,
89
+ progress=progress,
90
+ )
91
+
92
+ return samples
93
+
94
+
95
+ def sample_latents_with_additional_latent(
96
+ *,
97
+ batch_size: int,
98
+ model: nn.Module,
99
+ diffusion: GaussianDiffusion,
100
+ model_kwargs: Dict[str, Any],
101
+ text_guidance_scale: float,
102
+ image_guidance_scale: float,
103
+ clip_denoised: bool,
104
+ use_fp16: bool,
105
+ use_karras: bool,
106
+ karras_steps: int,
107
+ sigma_min: float,
108
+ sigma_max: float,
109
+ s_churn: float,
110
+ device: Optional[torch.device] = None,
111
+ progress: bool = False,
112
+ condition_latent: Optional[torch.Tensor] = None,
113
+ ) -> torch.Tensor:
114
+
115
+ if device is None:
116
+ device = next(model.parameters()).device
117
+
118
+ if hasattr(model, "cached_model_kwargs"):
119
+ model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)
120
+ if (text_guidance_scale != 1.0 and text_guidance_scale != 0.0) or (image_guidance_scale != 1.0 and image_guidance_scale != 0.0):
121
+ for k, v in model_kwargs.copy().items():
122
+ # print(k, v.shape)
123
+ model_kwargs[k] = torch.cat([v, torch.zeros_like(v), torch.zeros_like(v)], dim=0)
124
+ condition_latent = torch.cat([condition_latent, condition_latent, torch.zeros_like(condition_latent)], dim=0)
125
+
126
+ sample_shape = (batch_size, model.d_latent)
127
+ # print("sample_shape", sample_shape)
128
+ with torch.autocast(device_type=device.type, enabled=use_fp16):
129
+ if use_karras:
130
+ samples, samples_sequence = karras_sample_addition_condition(
131
+ diffusion=diffusion,
132
+ model=model,
133
+ shape=sample_shape,
134
+ steps=karras_steps,
135
+ clip_denoised=clip_denoised,
136
+ model_kwargs=model_kwargs,
137
+ device=device,
138
+ sigma_min=sigma_min,
139
+ sigma_max=sigma_max,
140
+ s_churn=s_churn,
141
+ text_guidance_scale=text_guidance_scale,
142
+ image_guidance_scale=image_guidance_scale,
143
+ progress=progress,
144
+ condition_latent=condition_latent,
145
+ )
146
+ else:
147
+ internal_batch_size = batch_size
148
+ if text_guidance_scale != 1.0:
149
+ model = uncond_guide_model(model, text_guidance_scale)
150
+ internal_batch_size *= 2
151
+ samples = diffusion.p_sample_loop(
152
+ model,
153
+ shape=(internal_batch_size, *sample_shape[1:]),
154
+ model_kwargs=model_kwargs,
155
+ device=device,
156
+ clip_denoised=clip_denoised,
157
+ progress=progress,
158
+ )
159
+
160
+ return samples
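sample_latents_with_additional_latent triples every conditioning tensor (text+image-conditioned, image-latent-only, unconditional) and the Karras sampler recombines the three x_0 predictions with the two guidance scales, in the same spirit as classifier-free guidance with two conditions. A toy sketch of that recombination rule, with random tensors standing in for real denoiser outputs:

# Toy illustration of the two-scale guidance rule used by karras_sample_progressive_condition.
import torch

text_guidance_scale, image_guidance_scale = 7.5, 1.5  # illustrative values
cond_x0_text = torch.randn(2, 4)   # stand-in: prediction conditioned on text + image latent
cond_x0_image = torch.randn(2, 4)  # stand-in: prediction conditioned on the image latent only
uncond_x0 = torch.randn(2, 4)      # stand-in: unconditional prediction

guided_x0 = (
    uncond_x0
    + text_guidance_scale * (cond_x0_text - cond_x0_image)
    + image_guidance_scale * (cond_x0_image - uncond_x0)
)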
shap_e/examples/encode_model.ipynb ADDED
@@ -0,0 +1,93 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "\n",
11
+ "from shap_e.models.download import load_model\n",
12
+ "from shap_e.util.data_util import load_or_create_multimodal_batch\n",
13
+ "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 2,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "xm = load_model('transmitter', device=device)"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 3,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "model_path = \"example_data/cactus/object.obj\"\n",
41
+ "\n",
42
+ "# This may take a few minutes, since it requires rendering the model twice\n",
43
+ "# in two different modes.\n",
44
+ "batch = load_or_create_multimodal_batch(\n",
45
+ " device,\n",
46
+ " model_path=model_path,\n",
47
+ " mv_light_mode=\"basic\",\n",
48
+ " mv_image_size=256,\n",
49
+ " cache_dir=\"example_data/cactus/cached\",\n",
50
+ " verbose=True, # this will show Blender output during renders\n",
51
+ ")"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": null,
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "with torch.no_grad():\n",
61
+ " latent = xm.encoder.encode_to_bottleneck(batch)\n",
62
+ "\n",
63
+ " render_mode = 'stf' # you can change this to 'nerf'\n",
64
+ " size = 128 # recommended that you lower resolution when using nerf\n",
65
+ "\n",
66
+ " cameras = create_pan_cameras(size, device)\n",
67
+ " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
68
+ " display(gif_widget(images))"
69
+ ]
70
+ }
71
+ ],
72
+ "metadata": {
73
+ "kernelspec": {
74
+ "display_name": "Python 3 (ipykernel)",
75
+ "language": "python",
76
+ "name": "python3"
77
+ },
78
+ "language_info": {
79
+ "codemirror_mode": {
80
+ "name": "ipython",
81
+ "version": 3
82
+ },
83
+ "file_extension": ".py",
84
+ "mimetype": "text/x-python",
85
+ "name": "python",
86
+ "nbconvert_exporter": "python",
87
+ "pygments_lexer": "ipython3",
88
+ "version": "3.9.9"
89
+ }
90
+ },
91
+ "nbformat": 4,
92
+ "nbformat_minor": 5
93
+ }
shap_e/examples/sample_image_to_3d.ipynb ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "964ccced",
7
+ "metadata": {
8
+ "pycharm": {
9
+ "is_executing": true
10
+ }
11
+ },
12
+ "outputs": [],
13
+ "source": [
14
+ "import torch\n",
15
+ "\n",
16
+ "from shap_e.diffusion.sample import sample_latents\n",
17
+ "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
18
+ "from shap_e.models.download import load_model, load_config\n",
19
+ "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget\n",
20
+ "from shap_e.util.image_util import load_image"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "8eed3a76",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "2d922637",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "xm = load_model('transmitter', device=device)\n",
41
+ "model = load_model('image300M', device=device)\n",
42
+ "diffusion = diffusion_from_config(load_config('diffusion'))"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "53d329d0",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "batch_size = 4\n",
53
+ "guidance_scale = 3.0\n",
54
+ "\n",
55
+ "image = load_image(\"example_data/corgi.png\")\n",
56
+ "\n",
57
+ "latents = sample_latents(\n",
58
+ " batch_size=batch_size,\n",
59
+ " model=model,\n",
60
+ " diffusion=diffusion,\n",
61
+ " guidance_scale=guidance_scale,\n",
62
+ " model_kwargs=dict(images=[image] * batch_size),\n",
63
+ " progress=True,\n",
64
+ " clip_denoised=True,\n",
65
+ " use_fp16=True,\n",
66
+ " use_karras=True,\n",
67
+ " karras_steps=64,\n",
68
+ " sigma_min=1e-3,\n",
69
+ " sigma_max=160,\n",
70
+ " s_churn=0,\n",
71
+ ")"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "id": "633da2ec",
78
+ "metadata": {
79
+ "pycharm": {
80
+ "is_executing": true
81
+ }
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "render_mode = 'nerf' # you can change this to 'stf' for mesh rendering\n",
86
+ "size = 64 # this is the size of the renders; higher values take longer to render.\n",
87
+ "\n",
88
+ "cameras = create_pan_cameras(size, device)\n",
89
+ "for i, latent in enumerate(latents):\n",
90
+ " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
91
+ " display(gif_widget(images))"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "outputs": [],
98
+ "source": [],
99
+ "metadata": {
100
+ "collapsed": false
101
+ }
102
+ }
103
+ ],
104
+ "metadata": {
105
+ "kernelspec": {
106
+ "display_name": "Python 3 (ipykernel)",
107
+ "language": "python",
108
+ "name": "python3"
109
+ },
110
+ "language_info": {
111
+ "codemirror_mode": {
112
+ "name": "ipython",
113
+ "version": 3
114
+ },
115
+ "file_extension": ".py",
116
+ "mimetype": "text/x-python",
117
+ "name": "python",
118
+ "nbconvert_exporter": "python",
119
+ "pygments_lexer": "ipython3",
120
+ "version": "3.9.9"
121
+ }
122
+ },
123
+ "nbformat": 4,
124
+ "nbformat_minor": 5
125
+ }
shap_e/examples/sample_text_to_3d.ipynb ADDED
@@ -0,0 +1,124 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "964ccced",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "\n",
12
+ "from shap_e.diffusion.sample import sample_latents\n",
13
+ "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
14
+ "from shap_e.models.download import load_model, load_config\n",
15
+ "from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget"
16
+ ]
17
+ },
18
+ {
19
+ "cell_type": "code",
20
+ "execution_count": null,
21
+ "id": "8eed3a76",
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "id": "2d922637",
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "xm = load_model('transmitter', device=device)\n",
36
+ "model = load_model('text300M', device=device)\n",
37
+ "diffusion = diffusion_from_config(load_config('diffusion'))"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "53d329d0",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "batch_size = 4\n",
48
+ "guidance_scale = 15.0\n",
49
+ "prompt = \"a shark\"\n",
50
+ "\n",
51
+ "latents = sample_latents(\n",
52
+ " batch_size=batch_size,\n",
53
+ " model=model,\n",
54
+ " diffusion=diffusion,\n",
55
+ " guidance_scale=guidance_scale,\n",
56
+ " model_kwargs=dict(texts=[prompt] * batch_size),\n",
57
+ " progress=True,\n",
58
+ " clip_denoised=True,\n",
59
+ " use_fp16=True,\n",
60
+ " use_karras=True,\n",
61
+ " karras_steps=64,\n",
62
+ " sigma_min=1e-3,\n",
63
+ " sigma_max=160,\n",
64
+ " s_churn=0,\n",
65
+ ")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "id": "633da2ec",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "render_mode = 'nerf' # you can change this to 'stf'\n",
76
+ "size = 64 # this is the size of the renders; higher values take longer to render.\n",
77
+ "\n",
78
+ "cameras = create_pan_cameras(size, device)\n",
79
+ "for i, latent in enumerate(latents):\n",
80
+ " images = decode_latent_images(xm, latent, cameras, rendering_mode=render_mode)\n",
81
+ " display(gif_widget(images))"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "id": "85a4dce4",
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "# Example of saving the latents as meshes.\n",
92
+ "from shap_e.util.notebooks import decode_latent_mesh\n",
93
+ "\n",
94
+ "for i, latent in enumerate(latents):\n",
95
+ " t = decode_latent_mesh(xm, latent).tri_mesh()\n",
96
+ " with open(f'example_mesh_{i}.ply', 'wb') as f:\n",
97
+ " t.write_ply(f)\n",
98
+ " with open(f'example_mesh_{i}.obj', 'w') as f:\n",
99
+ " t.write_obj(f)"
100
+ ]
101
+ }
102
+ ],
103
+ "metadata": {
104
+ "kernelspec": {
105
+ "display_name": "Python 3 (ipykernel)",
106
+ "language": "python",
107
+ "name": "python3"
108
+ },
109
+ "language_info": {
110
+ "codemirror_mode": {
111
+ "name": "ipython",
112
+ "version": 3
113
+ },
114
+ "file_extension": ".py",
115
+ "mimetype": "text/x-python",
116
+ "name": "python",
117
+ "nbconvert_exporter": "python",
118
+ "pygments_lexer": "ipython3",
119
+ "version": "3.11.3"
120
+ }
121
+ },
122
+ "nbformat": 4,
123
+ "nbformat_minor": 5
124
+ }
shap_e/models/__init__.py ADDED
File without changes
shap_e/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (163 Bytes). View file
 
shap_e/models/__pycache__/configs.cpython-39.pyc ADDED
Binary file (5 kB). View file
 
shap_e/models/__pycache__/download.cpython-39.pyc ADDED
Binary file (5.17 kB). View file
 
shap_e/models/__pycache__/query.cpython-39.pyc ADDED
Binary file (1.05 kB). View file
 
shap_e/models/__pycache__/renderer.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
shap_e/models/__pycache__/volume.cpython-39.pyc ADDED
Binary file (7.64 kB). View file
 
shap_e/models/configs.py ADDED
@@ -0,0 +1,166 @@
1
+ from typing import Any, Dict, Union
2
+
3
+ import blobfile as bf
4
+ import torch
5
+ import torch.nn as nn
6
+ import yaml
7
+
8
+ from shap_e.models.generation.latent_diffusion import SplitVectorDiffusion
9
+ from shap_e.models.generation.perceiver import PointDiffusionPerceiver
10
+ from shap_e.models.generation.pooled_mlp import PooledMLP
11
+ from shap_e.models.generation.transformer import (
12
+ CLIPImageGridPointDiffusionTransformer,
13
+ CLIPImageGridUpsamplePointDiffusionTransformer,
14
+ CLIPImagePointDiffusionTransformer,
15
+ PointDiffusionTransformer,
16
+ UpsamplePointDiffusionTransformer,
17
+ )
18
+ from shap_e.models.nerf.model import MLPNeRFModel, VoidNeRFModel
19
+ from shap_e.models.nerf.renderer import OneStepNeRFRenderer, TwoStepNeRFRenderer
20
+ from shap_e.models.nerstf.mlp import MLPDensitySDFModel, MLPNeRSTFModel
21
+ from shap_e.models.nerstf.renderer import NeRSTFRenderer
22
+ from shap_e.models.nn.meta import batch_meta_state_dict
23
+ from shap_e.models.stf.mlp import MLPSDFModel, MLPTextureFieldModel
24
+ from shap_e.models.stf.renderer import STFRenderer
25
+ from shap_e.models.transmitter.base import ChannelsDecoder, Transmitter, VectorDecoder
26
+ from shap_e.models.transmitter.channels_encoder import (
27
+ PointCloudPerceiverChannelsEncoder,
28
+ PointCloudTransformerChannelsEncoder,
29
+ )
30
+ from shap_e.models.transmitter.multiview_encoder import MultiviewTransformerEncoder
31
+ from shap_e.models.transmitter.pc_encoder import (
32
+ PointCloudPerceiverEncoder,
33
+ PointCloudTransformerEncoder,
34
+ )
35
+ from shap_e.models.volume import BoundingBoxVolume, SphericalVolume, UnboundedVolume
36
+
37
+
38
+ def model_from_config(config: Union[str, Dict[str, Any]], device: torch.device) -> nn.Module:
39
+ print(config)
40
+ if isinstance(config, str):
41
+ print("config", config)
42
+ with bf.BlobFile(config, "rb") as f:
43
+ obj = yaml.load(f, Loader=yaml.SafeLoader)
44
+ return model_from_config(obj, device=device)
45
+
46
+ config = config.copy()
47
+ name = config.pop("name")
48
+
49
+ if name == "PointCloudTransformerEncoder":
50
+ return PointCloudTransformerEncoder(device=device, dtype=torch.float32, **config)
51
+ elif name == "PointCloudPerceiverEncoder":
52
+ return PointCloudPerceiverEncoder(device=device, dtype=torch.float32, **config)
53
+ elif name == "PointCloudTransformerChannelsEncoder":
54
+ return PointCloudTransformerChannelsEncoder(device=device, dtype=torch.float32, **config)
55
+ elif name == "PointCloudPerceiverChannelsEncoder":
56
+ return PointCloudPerceiverChannelsEncoder(device=device, dtype=torch.float32, **config)
57
+ elif name == "MultiviewTransformerEncoder":
58
+ return MultiviewTransformerEncoder(device=device, dtype=torch.float32, **config)
59
+ elif name == "Transmitter":
60
+ renderer = model_from_config(config.pop("renderer"), device=device)
61
+ param_shapes = {
62
+ k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
63
+ }
64
+ encoder_config = config.pop("encoder").copy()
65
+ encoder_config["param_shapes"] = param_shapes
66
+ encoder = model_from_config(encoder_config, device=device)
67
+ return Transmitter(encoder=encoder, renderer=renderer, **config)
68
+ elif name == "VectorDecoder":
69
+ renderer = model_from_config(config.pop("renderer"), device=device)
70
+ param_shapes = {
71
+ k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
72
+ }
73
+ return VectorDecoder(param_shapes=param_shapes, renderer=renderer, device=device, **config)
74
+ elif name == "ChannelsDecoder":
75
+ renderer = model_from_config(config.pop("renderer"), device=device)
76
+ param_shapes = {
77
+ k: v.shape[1:] for k, v in batch_meta_state_dict(renderer, batch_size=1).items()
78
+ }
79
+ return ChannelsDecoder(
80
+ param_shapes=param_shapes, renderer=renderer, device=device, **config
81
+ )
82
+ elif name == "OneStepNeRFRenderer":
83
+ config = config.copy()
84
+ for field in [
85
+ # Required
86
+ "void_model",
87
+ "foreground_model",
88
+ "volume",
89
+ # Optional to use NeRF++
90
+ "background_model",
91
+ "outer_volume",
92
+ ]:
93
+ if field in config:
94
+ config[field] = model_from_config(config.pop(field).copy(), device)
95
+ return OneStepNeRFRenderer(device=device, **config)
96
+ elif name == "TwoStepNeRFRenderer":
97
+ config = config.copy()
98
+ for field in [
99
+ # Required
100
+ "void_model",
101
+ "coarse_model",
102
+ "fine_model",
103
+ "volume",
104
+ # Optional to use NeRF++
105
+ "coarse_background_model",
106
+ "fine_background_model",
107
+ "outer_volume",
108
+ ]:
109
+ if field in config:
110
+ config[field] = model_from_config(config.pop(field).copy(), device)
111
+ return TwoStepNeRFRenderer(device=device, **config)
112
+ elif name == "PooledMLP":
113
+ return PooledMLP(device, **config)
114
+ elif name == "PointDiffusionTransformer":
115
+ return PointDiffusionTransformer(device=device, dtype=torch.float32, **config)
116
+ elif name == "PointDiffusionPerceiver":
117
+ return PointDiffusionPerceiver(device=device, dtype=torch.float32, **config)
118
+ elif name == "CLIPImagePointDiffusionTransformer":
119
+ return CLIPImagePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
120
+ elif name == "CLIPImageGridPointDiffusionTransformer":
121
+ return CLIPImageGridPointDiffusionTransformer(device=device, dtype=torch.float32, **config)
122
+ elif name == "UpsamplePointDiffusionTransformer":
123
+ return UpsamplePointDiffusionTransformer(device=device, dtype=torch.float32, **config)
124
+ elif name == "CLIPImageGridUpsamplePointDiffusionTransformer":
125
+ return CLIPImageGridUpsamplePointDiffusionTransformer(
126
+ device=device, dtype=torch.float32, **config
127
+ )
128
+ elif name == "SplitVectorDiffusion":
129
+ inner_config = config.pop("inner")
130
+ d_latent = config.pop("d_latent")
131
+ latent_ctx = config.pop("latent_ctx", 1)
132
+ inner_config["input_channels"] = d_latent // latent_ctx
133
+ inner_config["n_ctx"] = latent_ctx
134
+ inner_config["output_channels"] = d_latent // latent_ctx * 2
135
+ inner_model = model_from_config(inner_config, device)
136
+ return SplitVectorDiffusion(
137
+ device=device, wrapped=inner_model, n_ctx=latent_ctx, d_latent=d_latent
138
+ )
139
+ elif name == "STFRenderer":
140
+ config = config.copy()
141
+ for field in ["sdf", "tf", "volume"]:
142
+ config[field] = model_from_config(config.pop(field), device)
143
+ return STFRenderer(device=device, **config)
144
+ elif name == "NeRSTFRenderer":
145
+ config = config.copy()
146
+ for field in ["sdf", "tf", "nerstf", "void", "volume"]:
147
+ if field not in config:
148
+ continue
149
+ config[field] = model_from_config(config.pop(field), device)
150
+ config.setdefault("sdf", None)
151
+ config.setdefault("tf", None)
152
+ config.setdefault("nerstf", None)
153
+ return NeRSTFRenderer(device=device, **config)
154
+
155
+ model_cls = {
156
+ "MLPSDFModel": MLPSDFModel,
157
+ "MLPTextureFieldModel": MLPTextureFieldModel,
158
+ "MLPNeRFModel": MLPNeRFModel,
159
+ "MLPDensitySDFModel": MLPDensitySDFModel,
160
+ "MLPNeRSTFModel": MLPNeRSTFModel,
161
+ "VoidNeRFModel": VoidNeRFModel,
162
+ "BoundingBoxVolume": BoundingBoxVolume,
163
+ "SphericalVolume": SphericalVolume,
164
+ "UnboundedVolume": UnboundedVolume,
165
+ }[name]
166
+ return model_cls(device=device, **config)
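In the SplitVectorDiffusion branch above, the inner transformer's channel counts are derived from d_latent and latent_ctx rather than configured directly: each latent token carries d_latent / latent_ctx channels, and the inner model emits twice that many so its output can be split into epsilon and variance. A worked example with made-up sizes (not the real shap-e config values):

# Illustrative derivation of the inner-model dimensions in the SplitVectorDiffusion branch.
d_latent = 4096                                   # hypothetical flattened latent size
latent_ctx = 64                                   # hypothetical number of latent tokens
input_channels = d_latent // latent_ctx           # 64 channels per token
n_ctx = latent_ctx                                # 64 tokens of context
output_channels = d_latent // latent_ctx * 2      # 128: epsilon and variance per channel
print(input_channels, n_ctx, output_channels)     # 64 64 128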
shap_e/models/download.py ADDED
@@ -0,0 +1,152 @@
1
+ """
2
+ Adapted from: https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/download.py
3
+ """
4
+
5
+ import hashlib
6
+ import os
7
+ from functools import lru_cache
8
+ from typing import Dict, Optional
9
+
10
+ import requests
11
+ import torch
12
+ import yaml
13
+ from filelock import FileLock
14
+ from tqdm.auto import tqdm
15
+
16
+ MODEL_PATHS = {
17
+ "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt",
18
+ "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt",
19
+ "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt",
20
+ "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt",
21
+ }
22
+
23
+ CONFIG_PATHS = {
24
+ "transmitter": "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml",
25
+ "decoder": "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml",
26
+ "text300M": "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml",
27
+ "image300M": "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml",
28
+ "diffusion": "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml",
29
+ }
30
+
31
+ URL_HASHES = {
32
+ "https://openaipublic.azureedge.net/main/shap-e/transmitter.pt": "af02a0b85a8abdfb3919584b63c540ba175f6ad4790f574a7fef4617e5acdc3b",
33
+ "https://openaipublic.azureedge.net/main/shap-e/vector_decoder.pt": "d7e7ebbfe3780499ae89b2da5e7c1354012dba5a6abfe295bed42f25c3be1b98",
34
+ "https://openaipublic.azureedge.net/main/shap-e/text_cond.pt": "e6b4fa599a7b3c3b16c222d5f5fe56f9db9289ff0b6575fbe5c11bc97106aad4",
35
+ "https://openaipublic.azureedge.net/main/shap-e/image_cond.pt": "cb8072c64bbbcf6910488814d212227de5db291780d4ea99c6152f9346cf12aa",
36
+ "https://openaipublic.azureedge.net/main/shap-e/transmitter_config.yaml": "ffe1bcb405104a37d9408391182ab118a4ef313c391e07689684f1f62071605e",
37
+ "https://openaipublic.azureedge.net/main/shap-e/vector_decoder_config.yaml": "e6d373649f8e24d85925f4674b9ac41c57aba5f60e42cde6d10f87381326365c",
38
+ "https://openaipublic.azureedge.net/main/shap-e/text_cond_config.yaml": "f290beeea3d3e9ff15db01bde5382b6e549e463060c0744f89c049505be246c1",
39
+ "https://openaipublic.azureedge.net/main/shap-e/image_cond_config.yaml": "4e0745605a533c543c72add803a78d233e2a6401e0abfa0cad58afb4d74ad0b0",
40
+ "https://openaipublic.azureedge.net/main/shap-e/diffusion_config.yaml": "efcb2cd7ee545b2d27223979d41857802448143990572a42645cd09c2942ed57",
41
+ }
42
+
43
+
44
+ @lru_cache()
45
+ def default_cache_dir() -> str:
46
+ return os.path.join(os.path.abspath(os.getcwd()), "shap_e_model_cache")
47
+
48
+
49
+ def fetch_file_cached(
50
+ url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
51
+ ) -> str:
52
+ """
53
+ Download the file at the given URL into a local file and return the path.
54
+ If cache_dir is specified, it will be used to download the files.
55
+ Otherwise, default_cache_dir() is used.
56
+ """
57
+ expected_hash = URL_HASHES[url]
58
+
59
+ if cache_dir is None:
60
+ cache_dir = default_cache_dir()
61
+ os.makedirs(cache_dir, exist_ok=True)
62
+ local_path = os.path.join(cache_dir, url.split("/")[-1])
63
+ if os.path.exists(local_path):
64
+ check_hash(local_path, expected_hash)
65
+ return local_path
66
+
67
+ response = requests.get(url, stream=True)
68
+ size = int(response.headers.get("content-length", "0"))
69
+ with FileLock(local_path + ".lock"):
70
+ if progress:
71
+ pbar = tqdm(total=size, unit="iB", unit_scale=True)
72
+ tmp_path = local_path + ".tmp"
73
+ with open(tmp_path, "wb") as f:
74
+ for chunk in response.iter_content(chunk_size):
75
+ if progress:
76
+ pbar.update(len(chunk))
77
+ f.write(chunk)
78
+ os.rename(tmp_path, local_path)
79
+ if progress:
80
+ pbar.close()
81
+ check_hash(local_path, expected_hash)
82
+ return local_path
83
+
84
+
85
+ def check_hash(path: str, expected_hash: str):
86
+ actual_hash = hash_file(path)
87
+ if actual_hash != expected_hash:
88
+ raise RuntimeError(
89
+ f"The file {path} should have hash {expected_hash} but has {actual_hash}. "
90
+ "Try deleting it and running this call again."
91
+ )
92
+
93
+
94
+ def hash_file(path: str) -> str:
95
+ sha256_hash = hashlib.sha256()
96
+ with open(path, "rb") as file:
97
+ while True:
98
+ data = file.read(4096)
99
+ if not len(data):
100
+ break
101
+ sha256_hash.update(data)
102
+ return sha256_hash.hexdigest()
103
+
104
+
105
+ def load_config(
106
+ config_name: str,
107
+ progress: bool = False,
108
+ cache_dir: Optional[str] = None,
109
+ chunk_size: int = 4096,
110
+ ):
111
+ if config_name not in CONFIG_PATHS:
112
+ raise ValueError(
113
+ f"Unknown config name {config_name}. Known names are: {CONFIG_PATHS.keys()}."
114
+ )
115
+ path = fetch_file_cached(
116
+ CONFIG_PATHS[config_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
117
+ )
118
+ with open(path, "r") as f:
119
+ return yaml.safe_load(f)
120
+
121
+
122
+ def load_checkpoint(
123
+ checkpoint_name: str,
124
+ device: torch.device,
125
+ progress: bool = True,
126
+ cache_dir: Optional[str] = None,
127
+ chunk_size: int = 4096,
128
+ ) -> Dict[str, torch.Tensor]:
129
+ if checkpoint_name not in MODEL_PATHS:
130
+ raise ValueError(
131
+ f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
132
+ )
133
+ print(checkpoint_name)
134
+ path = fetch_file_cached(
135
+ MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
136
+ )
137
+ return torch.load(path, map_location=device)
138
+
139
+
140
+ def load_model(
141
+ model_name: str,
142
+ device: torch.device,
143
+ **kwargs,
144
+ ) -> Dict[str, torch.Tensor]:
145
+ from .configs import model_from_config
146
+
147
+ model = model_from_config(load_config(model_name, **kwargs), device=device)
148
+ # print(model_name, kwargs)
149
+ # print(model)
150
+ model.load_state_dict(load_checkpoint(model_name, device=device, **kwargs))
151
+ model.eval()
152
+ return model
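Every download in fetch_file_cached is verified against a pinned SHA-256 digest, so a truncated or tampered cache entry fails loudly instead of silently loading a broken checkpoint. A short usage sketch of the helpers above (note that this really downloads the transmitter checkpoint into the local cache directory):

# Usage sketch: fetch a pinned checkpoint and print the digest that check_hash already verified.
local_path = fetch_file_cached(MODEL_PATHS["transmitter"], progress=True)
print(local_path, hash_file(local_path))  # hash_file recomputes the SHA-256 of the cached file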
shap_e/models/generation/__init__.py ADDED
File without changes
shap_e/models/generation/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (174 Bytes). View file
 
shap_e/models/generation/__pycache__/latent_diffusion.cpython-39.pyc ADDED
Binary file (1.44 kB). View file
 
shap_e/models/generation/__pycache__/perceiver.cpython-39.pyc ADDED
Binary file (6.73 kB). View file
 
shap_e/models/generation/__pycache__/pooled_mlp.cpython-39.pyc ADDED
Binary file (2.72 kB). View file
 
shap_e/models/generation/__pycache__/pretrained_clip.cpython-39.pyc ADDED
Binary file (9.69 kB). View file
 
shap_e/models/generation/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (19.3 kB). View file
 
shap_e/models/generation/__pycache__/util.cpython-39.pyc ADDED
Binary file (1.06 kB). View file
 
shap_e/models/generation/latent_diffusion.py ADDED
@@ -0,0 +1,32 @@
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Any, Callable, Dict, Optional
6
+
7
+
8
+ class SplitVectorDiffusion(nn.Module):
9
+ def __init__(self, *, device: torch.device, wrapped: nn.Module, n_ctx: int, d_latent: int):
10
+ super().__init__()
11
+ self.device = device
12
+ self.n_ctx = n_ctx
13
+ self.d_latent = d_latent
14
+ self.wrapped = wrapped
15
+
16
+ if hasattr(self.wrapped, "cached_model_kwargs"):
17
+ self.cached_model_kwargs = self.wrapped.cached_model_kwargs
18
+
19
+ def forward(self, x: torch.Tensor, t: torch.Tensor, conditional_latent: Optional[torch.Tensor] = None, **kwargs):
20
+ h = x.reshape(x.shape[0], self.n_ctx, -1).permute(0, 2, 1)
21
+ if conditional_latent is not None:
22
+ conditional_latent = conditional_latent.reshape(conditional_latent.shape[0], self.n_ctx, -1)
23
+ h = torch.cat([h.permute(0, 2, 1) , conditional_latent], dim=-1).permute(0, 2, 1) # (batch_size, n_ctx, channel) -> (batch_size, d_latent, n_ctx)
24
+ h = self.wrapped(h, t, **kwargs)
25
+ eps, var = torch.chunk(h, 2, dim=1)
26
+ return torch.cat(
27
+ [
28
+ eps.permute(0, 2, 1).flatten(1),
29
+ var.permute(0, 2, 1).flatten(1),
30
+ ],
31
+ dim=1,
32
+ )
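SplitVectorDiffusion exposes a flat (batch, d_latent) interface to the diffusion process while running the wrapped model on (batch, channels, n_ctx) tokens; the epsilon and variance halves of the output are flattened back and concatenated. A shape-only sketch with a stand-in inner module (sizes are illustrative, and DummyInner is not the real transformer):

# Shape-flow sketch for SplitVectorDiffusion; DummyInner just doubles the channel dim.
import torch
import torch.nn as nn

class DummyInner(nn.Module):
    def forward(self, h, t, **kwargs):           # h: (batch, channels, n_ctx)
        return torch.cat([h, h], dim=1)          # fake eps/var output: (batch, 2*channels, n_ctx)

n_ctx, channels = 4, 8                           # so d_latent = 32
model = SplitVectorDiffusion(
    device=torch.device("cpu"), wrapped=DummyInner(), n_ctx=n_ctx, d_latent=n_ctx * channels
)
x = torch.randn(2, n_ctx * channels)             # flat latents
t = torch.zeros(2, dtype=torch.long)             # dummy timesteps
out = model(x, t)
print(out.shape)                                 # torch.Size([2, 64]): eps and var, each d_latent wide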
shap_e/models/generation/perceiver.py ADDED
@@ -0,0 +1,244 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from shap_e.models.nn.checkpoint import checkpoint
8
+
9
+ from .transformer import MLP, Transformer, init_linear
10
+ from .util import timestep_embedding
11
+
12
+
13
+ class MultiheadCrossAttention(nn.Module):
14
+ def __init__(
15
+ self,
16
+ *,
17
+ device: torch.device,
18
+ dtype: torch.dtype,
19
+ n_ctx: int,
20
+ n_data: int,
21
+ width: int,
22
+ heads: int,
23
+ init_scale: float,
24
+ data_width: Optional[int] = None,
25
+ ):
26
+ super().__init__()
27
+ self.n_ctx = n_ctx
28
+ self.n_data = n_data
29
+ self.width = width
30
+ self.heads = heads
31
+ self.data_width = width if data_width is None else data_width
32
+ self.c_q = nn.Linear(width, width, device=device, dtype=dtype)
33
+ self.c_kv = nn.Linear(self.data_width, width * 2, device=device, dtype=dtype)
34
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
35
+ self.attention = QKVMultiheadCrossAttention(
36
+ device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, n_data=n_data
37
+ )
38
+ init_linear(self.c_q, init_scale)
39
+ init_linear(self.c_kv, init_scale)
40
+ init_linear(self.c_proj, init_scale)
41
+
42
+ def forward(self, x, data):
43
+ x = self.c_q(x)
44
+ data = self.c_kv(data)
45
+ x = checkpoint(self.attention, (x, data), (), True)
46
+ x = self.c_proj(x)
47
+ return x
48
+
49
+
50
+ class QKVMultiheadCrossAttention(nn.Module):
51
+ def __init__(
52
+ self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, n_data: int
53
+ ):
54
+ super().__init__()
55
+ self.device = device
56
+ self.dtype = dtype
57
+ self.heads = heads
58
+ self.n_ctx = n_ctx
59
+ self.n_data = n_data
60
+
61
+ def forward(self, q, kv):
62
+ _, n_ctx, _ = q.shape
63
+ bs, n_data, width = kv.shape
64
+ attn_ch = width // self.heads // 2
65
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
66
+ q = q.view(bs, n_ctx, self.heads, -1)
67
+ kv = kv.view(bs, n_data, self.heads, -1)
68
+ k, v = torch.split(kv, attn_ch, dim=-1)
69
+ weight = torch.einsum(
70
+ "bthc,bshc->bhts", q * scale, k * scale
71
+ ) # More stable with f16 than dividing afterwards
72
+ wdtype = weight.dtype
73
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
74
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
75
+
76
+
77
+ class ResidualCrossAttentionBlock(nn.Module):
78
+ def __init__(
79
+ self,
80
+ *,
81
+ device: torch.device,
82
+ dtype: torch.dtype,
83
+ n_ctx: int,
84
+ n_data: int,
85
+ width: int,
86
+ heads: int,
87
+ data_width: Optional[int] = None,
88
+ init_scale: float = 1.0,
89
+ ):
90
+ super().__init__()
91
+
92
+ if data_width is None:
93
+ data_width = width
94
+
95
+ self.attn = MultiheadCrossAttention(
96
+ device=device,
97
+ dtype=dtype,
98
+ n_ctx=n_ctx,
99
+ n_data=n_data,
100
+ width=width,
101
+ heads=heads,
102
+ data_width=data_width,
103
+ init_scale=init_scale,
104
+ )
105
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
106
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
107
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
108
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
109
+
110
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
111
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
112
+ x = x + self.mlp(self.ln_3(x))
113
+ return x
114
+
115
+
116
+ class SimplePerceiver(nn.Module):
117
+ """
118
+ Only does cross attention
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ *,
124
+ device: torch.device,
125
+ dtype: torch.dtype,
126
+ n_ctx: int,
127
+ n_data: int,
128
+ width: int,
129
+ layers: int,
130
+ heads: int,
131
+ init_scale: float = 0.25,
132
+ data_width: Optional[int] = None,
133
+ ):
134
+ super().__init__()
135
+ self.n_ctx = n_ctx
136
+ self.width = width
137
+ self.layers = layers
138
+ init_scale = init_scale * math.sqrt(1.0 / width)
139
+ self.resblocks = nn.ModuleList(
140
+ [
141
+ ResidualCrossAttentionBlock(
142
+ device=device,
143
+ dtype=dtype,
144
+ n_ctx=n_ctx,
145
+ n_data=n_data,
146
+ width=width,
147
+ heads=heads,
148
+ init_scale=init_scale,
149
+ data_width=data_width,
150
+ )
151
+ for _ in range(layers)
152
+ ]
153
+ )
154
+
155
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
156
+ for block in self.resblocks:
157
+ x = block(x, data)
158
+ return x
159
+
160
+
161
+ class PointDiffusionPerceiver(nn.Module):
162
+ def __init__(
163
+ self,
164
+ *,
165
+ device: torch.device,
166
+ dtype: torch.dtype,
167
+ input_channels: int = 3,
168
+ output_channels: int = 3,
169
+ n_ctx: int = 1024,
170
+ n_latent: int = 128,
171
+ width: int = 512,
172
+ encoder_layers: int = 12,
173
+ latent_layers: int = 12,
174
+ decoder_layers: int = 12,
175
+ heads: int = 8,
176
+ init_scale: float = 0.25,
177
+ ):
178
+ super().__init__()
179
+ self.time_embed = MLP(
180
+ device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
181
+ )
182
+ self.latent_embed = MLP(
183
+ device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
184
+ )
185
+ self.n_latent = n_latent
186
+
187
+ self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
188
+ self.encoder = SimplePerceiver(
189
+ device=device,
190
+ dtype=dtype,
191
+ n_ctx=n_latent,
192
+ n_data=n_ctx,
193
+ width=width,
194
+ layers=encoder_layers,
195
+ heads=heads,
196
+ init_scale=init_scale,
197
+ )
198
+ self.processor = Transformer(
199
+ device=device,
200
+ dtype=dtype,
201
+ n_ctx=n_latent,
202
+ width=width,
203
+ layers=latent_layers,
204
+ heads=heads,
205
+ init_scale=init_scale,
206
+ )
207
+ self.decoder = SimplePerceiver(
208
+ device=device,
209
+ dtype=dtype,
210
+ n_ctx=n_ctx,
211
+ n_data=n_latent,
212
+ width=width,
213
+ layers=decoder_layers,
214
+ heads=heads,
215
+ init_scale=init_scale,
216
+ )
217
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
218
+ self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
219
+ self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
220
+ with torch.no_grad():
221
+ self.output_proj.weight.zero_()
222
+ self.output_proj.bias.zero_()
223
+
224
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
225
+ """
226
+ :param x: an [N x C x T] tensor.
227
+ :param t: an [N] tensor.
228
+ :return: an [N x C' x T] tensor.
229
+ """
230
+ assert x.shape[-1] == self.decoder.n_ctx
231
+ t_embed = self.time_embed(timestep_embedding(t, self.encoder.width))
232
+ data = self.input_proj(x.permute(0, 2, 1)) + t_embed[:, None]
233
+ data = self.ln_pre(data)
234
+
235
+ l = torch.arange(self.n_latent).to(x.device)
236
+ h = self.latent_embed(timestep_embedding(l, self.decoder.width))
237
+ h = h.unsqueeze(0).repeat(x.shape[0], 1, 1)
238
+
239
+ h = self.encoder(h, data)
240
+ h = self.processor(h)
241
+ h = self.decoder(data, h)
242
+ h = self.ln_post(h)
243
+ h = self.output_proj(h)
244
+ return h.permute(0, 2, 1)
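A usage sketch (not part of the commit), assuming the class above is importable and constructed on CPU with small hyperparameters; it only exercises the encode -> process -> decode path and checks tensor shapes.

import torch

model = PointDiffusionPerceiver(
    device=torch.device("cpu"), dtype=torch.float32,
    input_channels=3, output_channels=6, n_ctx=256, n_latent=32,
    width=64, encoder_layers=2, latent_layers=2, decoder_layers=2, heads=4,
)
x = torch.randn(2, 3, 256)        # [N x C x T] noisy point cloud
t = torch.randint(0, 1000, (2,))  # diffusion timesteps
out = model(x, t)
assert out.shape == (2, 6, 256)   # [N x C' x T]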
shap_e/models/generation/pooled_mlp.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ import torch.nn as nn
+
+ from .util import timestep_embedding
+
+
+ class PooledMLP(nn.Module):
+     def __init__(
+         self,
+         device: torch.device,
+         *,
+         input_channels: int = 3,
+         output_channels: int = 6,
+         hidden_size: int = 256,
+         resblocks: int = 4,
+         pool_op: str = "max",
+     ):
+         super().__init__()
+         self.input_embed = nn.Conv1d(input_channels, hidden_size, kernel_size=1, device=device)
+         self.time_embed = nn.Linear(hidden_size, hidden_size, device=device)
+
+         blocks = []
+         for _ in range(resblocks):
+             blocks.append(ResBlock(hidden_size, pool_op, device=device))
+         self.sequence = nn.Sequential(*blocks)
+
+         self.out = nn.Conv1d(hidden_size, output_channels, kernel_size=1, device=device)
+         with torch.no_grad():
+             self.out.bias.zero_()
+             self.out.weight.zero_()
+
+     def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+         in_embed = self.input_embed(x)
+         t_embed = self.time_embed(timestep_embedding(t, in_embed.shape[1]))
+         h = in_embed + t_embed[..., None]
+         h = self.sequence(h)
+         h = self.out(h)
+         return h
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, hidden_size: int, pool_op: str, device: torch.device):
+         super().__init__()
+         assert pool_op in ["mean", "max"]
+         self.pool_op = pool_op
+         self.body = nn.Sequential(
+             nn.SiLU(),
+             nn.LayerNorm((hidden_size,), device=device),
+             nn.Linear(hidden_size, hidden_size, device=device),
+             nn.SiLU(),
+             nn.LayerNorm((hidden_size,), device=device),
+             nn.Linear(hidden_size, hidden_size, device=device),
+         )
+         self.gate = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size, device=device),
+             nn.Tanh(),
+         )
+
+     def forward(self, x: torch.Tensor):
+         N, C, T = x.shape
+         out = self.body(x.permute(0, 2, 1).reshape(N * T, C)).reshape([N, T, C]).permute(0, 2, 1)
+         pooled = pool(self.pool_op, x)
+         gate = self.gate(pooled)
+         return x + out * gate[..., None]
+
+
+ def pool(op_name: str, x: torch.Tensor) -> torch.Tensor:
+     if op_name == "max":
+         pooled, _ = torch.max(x, dim=-1)
+     elif op_name == "mean":
+         pooled = torch.mean(x, dim=-1)  # torch.mean returns a tensor, not a (values, indices) tuple
+     else:
+         raise ValueError(f"unknown pool op: {op_name}")
+     return pooled
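A usage sketch (not part of the commit): PooledMLP maps an [N x C x T] point tensor to [N x C' x T], mixing per-point MLPs with a pooled (max or mean) global gate.

import torch

net = PooledMLP(torch.device("cpu"), input_channels=3, output_channels=6,
                hidden_size=32, resblocks=2, pool_op="mean")
x = torch.randn(4, 3, 128)
t = torch.randint(0, 1000, (4,))
assert net(x, t).shape == (4, 6, 128)  # output head is zero-initialized, so values start at 0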
shap_e/models/generation/pretrained_clip.py ADDED
@@ -0,0 +1,270 @@
+ from typing import Iterable, List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from PIL import Image
+
+ from shap_e.models.download import default_cache_dir
+
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+
+
+ class ImageCLIP(nn.Module):
+     """
+     A wrapper around a pre-trained CLIP model that automatically handles
+     batches of texts, images, and embeddings.
+     """
+
+     def __init__(
+         self,
+         device: torch.device,
+         dtype: Optional[torch.dtype] = torch.float32,
+         ensure_used_params: bool = True,
+         clip_name: str = "ViT-L/14",
+         cache_dir: Optional[str] = None,
+     ):
+         super().__init__()
+
+         assert clip_name in ["ViT-L/14", "ViT-B/32"]
+
+         self.device = device
+         self.ensure_used_params = ensure_used_params
+
+         # Lazy import because of torchvision.
+         import clip
+
+         self.clip_model, self.preprocess = clip.load(
+             clip_name, device=device, download_root=cache_dir or default_cache_dir()
+         )
+         self.clip_name = clip_name
+
+         if dtype is not None:
+             self.clip_model.to(dtype)
+         self._tokenize = clip.tokenize
+
+     @property
+     def feature_dim(self) -> int:
+         if self.clip_name == "ViT-L/14":
+             return 768
+         else:
+             return 512
+
+     @property
+     def grid_size(self) -> int:
+         if self.clip_name == "ViT-L/14":
+             return 16
+         else:
+             return 7
+
+     @property
+     def grid_feature_dim(self) -> int:
+         if self.clip_name == "ViT-L/14":
+             return 1024
+         else:
+             return 768
+
+     def forward(
+         self,
+         batch_size: int,
+         images: Optional[Iterable[Optional[ImageType]]] = None,
+         texts: Optional[Iterable[Optional[str]]] = None,
+         embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
+     ) -> torch.Tensor:
+         """
+         Generate a batch of embeddings from a mixture of images, texts,
+         precomputed embeddings, and possibly empty values.
+
+         For each batch element, at most one of images, texts, and embeddings
+         should have a non-None value. Embeddings from multiple modalities
+         cannot be mixed for a single batch element. If no modality is provided,
+         a zero embedding will be used for the batch element.
+         """
+         image_seq = [None] * batch_size if images is None else list(images)
+         text_seq = [None] * batch_size if texts is None else list(texts)
+         embedding_seq = [None] * batch_size if embeddings is None else list(embeddings)
+         assert len(image_seq) == batch_size, "number of images should match batch size"
+         assert len(text_seq) == batch_size, "number of texts should match batch size"
+         assert len(embedding_seq) == batch_size, "number of embeddings should match batch size"
+
+         if self.ensure_used_params:
+             return self._static_multimodal_embed(
+                 images=image_seq, texts=text_seq, embeddings=embedding_seq
+             )
+
+         result = torch.zeros((batch_size, self.feature_dim), device=self.device)
+         index_images = []
+         index_texts = []
+         for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
+             assert (
+                 sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
+             ), "only one modality may be non-None per batch element"
+             if image is not None:
+                 index_images.append((i, image))
+             elif text is not None:
+                 index_texts.append((i, text))
+             elif emb is not None:
+                 result[i] = emb.to(result)
+
+         if len(index_images):
+             embs = self.embed_images((img for _, img in index_images))
+             for (i, _), emb in zip(index_images, embs):
+                 result[i] = emb.to(result)
+         if len(index_texts):
+             embs = self.embed_text((text for _, text in index_texts))
+             for (i, _), emb in zip(index_texts, embs):
+                 result[i] = emb.to(result)
+
+         return result
+
+     def _static_multimodal_embed(
+         self,
+         images: List[Optional[ImageType]] = None,
+         texts: List[Optional[str]] = None,
+         embeddings: List[Optional[torch.Tensor]] = None,
+     ) -> torch.Tensor:
+         """
+         Like forward(), but always runs all encoders to ensure that
+         the forward graph looks the same on every rank.
+         """
+         image_emb = self.embed_images(images)
+         text_emb = self.embed_text(t if t else "" for t in texts)
+         joined_embs = torch.stack(
+             [
+                 emb.to(device=self.device, dtype=torch.float32)
+                 if emb is not None
+                 else torch.zeros(self.feature_dim, device=self.device)
+                 for emb in embeddings
+             ],
+             dim=0,
+         )
+
+         image_flag = torch.tensor([x is not None for x in images], device=self.device)[
+             :, None
+         ].expand_as(image_emb)
+         text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
+             :, None
+         ].expand_as(image_emb)
+         emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
+             :, None
+         ].expand_as(image_emb)
+
+         return (
+             image_flag.float() * image_emb
+             + text_flag.float() * text_emb
+             + emb_flag.float() * joined_embs
+             + self.clip_model.logit_scale * 0  # avoid unused parameters
+         )
+
+     def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
+         """
+         :param xs: N images, stored as numpy arrays, tensors, or PIL images.
+         :return: an [N x D] tensor of features.
+         """
+         clip_inputs = self.images_to_tensor(xs)
+         results = self.clip_model.encode_image(clip_inputs).float()
+         return results / torch.linalg.norm(results, dim=-1, keepdim=True)
+
+     def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
+         """
+         Embed text prompts as an [N x D] tensor.
+         """
+         enc = self.clip_model.encode_text(
+             self._tokenize(list(prompts), truncate=True).to(self.device)
+         ).float()
+         return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)
+
+     def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
+         """
+         Embed images into latent grids.
+
+         :param xs: an iterable of images to embed.
+         :return: a tensor of shape [N x C x L], where L = self.grid_size**2.
+         """
+         if self.ensure_used_params:
+             extra_value = 0.0
+             for p in self.parameters():
+                 extra_value = extra_value + p.mean() * 0.0
+         else:
+             extra_value = 0.0
+
+         x = self.images_to_tensor(xs).to(self.clip_model.dtype)
+
+         # https://github.com/openai/CLIP/blob/4d120f3ec35b30bd0f992f5d8af2d793aad98d2a/clip/model.py#L225
+         vt = self.clip_model.visual
+         x = vt.conv1(x)  # shape = [*, width, grid, grid]
+         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+         x = torch.cat(
+             [
+                 vt.class_embedding.to(x.dtype)
+                 + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
+                 x,
+             ],
+             dim=1,
+         )  # shape = [*, grid ** 2 + 1, width]
+         x = x + vt.positional_embedding.to(x.dtype)
+         x = vt.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+         x = vt.transformer(x)
+         x = x.permute(1, 2, 0)  # LND -> NDL
+
+         return x[..., 1:].contiguous().float() + extra_value
+
+     def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
+         return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)
+
+
+ class FrozenImageCLIP:
+     def __init__(self, device: torch.device, **kwargs):
+         self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
+         for parameter in self.model.parameters():
+             parameter.requires_grad_(False)
+
+     @property
+     def feature_dim(self) -> int:
+         return self.model.feature_dim
+
+     @property
+     def grid_size(self) -> int:
+         return self.model.grid_size
+
+     @property
+     def grid_feature_dim(self) -> int:
+         return self.model.grid_feature_dim
+
+     def __call__(
+         self,
+         batch_size: int,
+         images: Optional[Iterable[Optional[ImageType]]] = None,
+         texts: Optional[Iterable[Optional[str]]] = None,
+         embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
+     ) -> torch.Tensor:
+         # We don't do a no_grad() here so that gradients could still
+         # flow to the input embeddings argument.
+         # This behavior is currently not used, but it could be.
+         return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)
+
+     def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
+         with torch.no_grad():
+             return self.model.embed_images(xs)
+
+     def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
+         with torch.no_grad():
+             return self.model.embed_text(prompts)
+
+     def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
+         with torch.no_grad():
+             return self.model.embed_images_grid(xs)
+
+
+ def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
+     if obj is None:
+         return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
+     if isinstance(obj, np.ndarray):
+         return Image.fromarray(obj.astype(np.uint8))
+     elif isinstance(obj, torch.Tensor):
+         return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
+     else:
+         return obj
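A usage sketch (not part of the commit); it assumes the openai `clip` package is installed and will download ViT-L/14 weights on first use. Each batch element supplies at most one modality, and elements with no modality get a zero embedding.

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = FrozenImageCLIP(device)
image = np.zeros([256, 256, 3], dtype=np.uint8)   # any RGB image as a uint8 array
embs = clip_model(batch_size=3, images=[image, None, None], texts=[None, "a red chair", None])
assert embs.shape == (3, clip_model.feature_dim)  # third row stays a zero embedding

grid = clip_model.embed_images_grid([image])      # [1 x grid_feature_dim x grid_size**2]
assert grid.shape == (1, clip_model.grid_feature_dim, clip_model.grid_size**2)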
shap_e/models/generation/transformer.py ADDED
@@ -0,0 +1,494 @@
+ import math
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from shap_e.models.nn.checkpoint import checkpoint
+
+ from .pretrained_clip import FrozenImageCLIP, ImageCLIP, ImageType
+ from .util import timestep_embedding
+
+ def init_linear(l, stddev):
+     nn.init.normal_(l.weight, std=stddev)
+     if l.bias is not None:
+         nn.init.constant_(l.bias, 0.0)
+
+
+ class MultiheadAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int,
+         width: int,
+         heads: int,
+         init_scale: float,
+     ):
+         super().__init__()
+         self.n_ctx = n_ctx
+         self.width = width
+         self.heads = heads
+         self.c_qkv = nn.Linear(width, width * 3, device=device, dtype=dtype)
+         self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
+         self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
+         init_linear(self.c_qkv, init_scale)
+         init_linear(self.c_proj, init_scale)
+
+     def forward(self, x):
+         x = self.c_qkv(x)
+         x = checkpoint(self.attention, (x,), (), True)
+         x = self.c_proj(x)
+         return x
+
+
+ class MLP(nn.Module):
+     def __init__(self, *, device: torch.device, dtype: torch.dtype, width: int, init_scale: float):
+         super().__init__()
+         self.width = width
+         self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
+         self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
+         self.gelu = nn.GELU()
+         init_linear(self.c_fc, init_scale)
+         init_linear(self.c_proj, init_scale)
+
+     def forward(self, x):
+         return self.c_proj(self.gelu(self.c_fc(x)))
+
+
+ class QKVMultiheadAttention(nn.Module):
+     def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
+         super().__init__()
+         self.device = device
+         self.dtype = dtype
+         self.heads = heads
+         self.n_ctx = n_ctx
+
+     def forward(self, qkv):
+         bs, n_ctx, width = qkv.shape
+         attn_ch = width // self.heads // 3
+         scale = 1 / math.sqrt(math.sqrt(attn_ch))
+         qkv = qkv.view(bs, n_ctx, self.heads, -1)
+         q, k, v = torch.split(qkv, attn_ch, dim=-1)
+         weight = torch.einsum(
+             "bthc,bshc->bhts", q * scale, k * scale
+         )  # More stable with f16 than dividing afterwards
+         wdtype = weight.dtype
+         weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
+         return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int,
+         width: int,
+         heads: int,
+         init_scale: float = 1.0,
+     ):
+         super().__init__()
+
+         self.attn = MultiheadAttention(
+             device=device,
+             dtype=dtype,
+             n_ctx=n_ctx,
+             width=width,
+             heads=heads,
+             init_scale=init_scale,
+         )
+         self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
+         self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
+         self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attn(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int,
+         width: int,
+         layers: int,
+         heads: int,
+         init_scale: float = 0.25,
+     ):
+         super().__init__()
+         self.n_ctx = n_ctx
+         self.width = width
+         self.layers = layers
+         init_scale = init_scale * math.sqrt(1.0 / width)
+         self.resblocks = nn.ModuleList(
+             [
+                 ResidualAttentionBlock(
+                     device=device,
+                     dtype=dtype,
+                     n_ctx=n_ctx,
+                     width=width,
+                     heads=heads,
+                     init_scale=init_scale,
+                 )
+                 for _ in range(layers)
+             ]
+         )
+
+     def forward(self, x: torch.Tensor):
+         for block in self.resblocks:
+             x = block(x)
+         return x
+
+
+ class PointDiffusionTransformer(nn.Module):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         input_channels: int = 3,
+         output_channels: int = 3,
+         n_ctx: int = 1024,
+         width: int = 512,
+         layers: int = 12,
+         heads: int = 8,
+         init_scale: float = 0.25,
+         time_token_cond: bool = False,
+         use_pos_emb: bool = False,
+         pos_emb_init_scale: float = 1.0,
+         pos_emb_n_ctx: Optional[int] = None,
+     ):
+         super().__init__()
+         self.input_channels = input_channels
+         self.output_channels = output_channels
+         self.n_ctx = n_ctx
+         self.time_token_cond = time_token_cond
+         self.use_pos_emb = use_pos_emb
+         self.time_embed = MLP(
+             device=device, dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
+         )
+         self.ln_pre = nn.LayerNorm(width, device=device, dtype=dtype)
+         self.backbone = Transformer(
+             device=device,
+             dtype=dtype,
+             n_ctx=n_ctx + int(time_token_cond),
+             width=width,
+             layers=layers,
+             heads=heads,
+             init_scale=init_scale,
+         )
+         self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+         self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
+         self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
+         # with torch.no_grad():
+         #     self.output_proj.weight.zero_()
+         #     self.output_proj.bias.zero_()
+         if self.use_pos_emb:
+             self.register_parameter(
+                 "pos_emb",
+                 nn.Parameter(
+                     pos_emb_init_scale
+                     * torch.randn(pos_emb_n_ctx or self.n_ctx, width, device=device, dtype=dtype)
+                 ),
+             )
+
+     def forward(self, x: torch.Tensor, t: torch.Tensor):
+         """
+         :param x: an [N x C x T] tensor.
+         :param t: an [N] tensor.
+         :return: an [N x C' x T] tensor.
+         """
+         assert x.shape[-1] == self.n_ctx
+         t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+         return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
+
+     def _forward_with_cond(
+         self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
+     ) -> torch.Tensor:
+         h = self.input_proj(x.permute(0, 2, 1))  # NCL -> NLC
+         for emb, as_token in cond_as_token:
+             if not as_token:
+                 h = h + emb[:, None]
+         if self.use_pos_emb:
+             h = h + self.pos_emb
+         extra_tokens = [
+             (emb[:, None] if len(emb.shape) == 2 else emb)
+             for emb, as_token in cond_as_token
+             if as_token
+         ]
+         if len(extra_tokens):
+             h = torch.cat(extra_tokens + [h], dim=1)
+         h = self.ln_pre(h)
+         h = self.backbone(h)
+         h = self.ln_post(h)
+         if len(extra_tokens):
+             h = h[:, sum(h.shape[1] for h in extra_tokens):]
+         h = self.output_proj(h)
+         return h.permute(0, 2, 1)  # NLC -> NCL
+
+
+
+
+ class CLIPImagePointDiffusionTransformer(PointDiffusionTransformer):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int = 1024,
+         token_cond: bool = False,
+         cond_drop_prob: float = 0.0,
+         frozen_clip: bool = True,
+         **kwargs,
+     ):
+         super().__init__(
+             device=device, dtype=dtype, n_ctx=n_ctx + int(token_cond), pos_emb_n_ctx=n_ctx, **kwargs
+         )
+         # print("!!!!!", "deivce:", device, "dtype:", dtype, "n_ctx:", n_ctx, "token_cond:", token_cond, "cond_drop_prob:", cond_drop_prob, "frozen_clip:", frozen_clip, "kwargs:", kwargs)
+         self.n_ctx = n_ctx
+         self.token_cond = token_cond
+         self.clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
+         self.clip_embed = nn.Linear(
+             self.clip.feature_dim, self.backbone.width, device=device, dtype=dtype
+         )
+         self.cond_drop_prob = cond_drop_prob
+
+     def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         with torch.no_grad():
+             return dict(embeddings=self.clip(batch_size, **model_kwargs))
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         images: Optional[Iterable[Optional[ImageType]]] = None,
+         texts: Optional[Iterable[Optional[str]]] = None,
+         embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
+     ):
+         """
+         :param x: an [N x C x T] tensor.
+         :param t: an [N] tensor.
+         :param images: a batch of images to condition on.
+         :param texts: a batch of texts to condition on.
+         :param embeddings: a batch of CLIP embeddings to condition on.
+         :return: an [N x C' x T] tensor.
+         """
+         # print("x.shape", x.shape, "t.shape", t.shape, "images", images, "texts", texts, "embeddings", embeddings)
+         assert x.shape[-1] == self.n_ctx  # self.n_ctx = 1024
+
+         t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+         clip_out = self.clip(batch_size=len(x), images=images, texts=texts, embeddings=embeddings)
+         assert len(clip_out.shape) == 2 and clip_out.shape[0] == x.shape[0]
+
+         if self.training:
+             mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
+             clip_out = clip_out * mask[:, None].to(clip_out)
+
+         # Rescale the features to have unit variance
+         clip_out = math.sqrt(clip_out.shape[1]) * clip_out
+
+         clip_embed = self.clip_embed(clip_out)
+
+         cond = [(clip_embed, self.token_cond), (t_embed, self.time_token_cond)]
+         return self._forward_with_cond(x, cond)
+
+
+ class CLIPImageGridPointDiffusionTransformer(PointDiffusionTransformer):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int = 1024,
+         cond_drop_prob: float = 0.0,
+         frozen_clip: bool = True,
+         **kwargs,
+     ):
+         clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
+         super().__init__(
+             device=device,
+             dtype=dtype,
+             n_ctx=n_ctx + clip.grid_size**2,
+             pos_emb_n_ctx=n_ctx,
+             **kwargs,
+         )
+         self.n_ctx = n_ctx
+         self.clip = clip
+         self.clip_embed = nn.Sequential(
+             nn.LayerNorm(
+                 normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
+             ),
+             nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
+         )
+         self.cond_drop_prob = cond_drop_prob
+
+     def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         _ = batch_size
+         with torch.no_grad():
+             return dict(embeddings=self.clip.embed_images_grid(model_kwargs["images"]))
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         images: Optional[Iterable[ImageType]] = None,
+         embeddings: Optional[Iterable[torch.Tensor]] = None,
+     ):
+         """
+         :param x: an [N x C x T] tensor.
+         :param t: an [N] tensor.
+         :param images: a batch of images to condition on.
+         :param embeddings: a batch of CLIP latent grids to condition on.
+         :return: an [N x C' x T] tensor.
+         """
+         assert images is not None or embeddings is not None, "must specify images or embeddings"
+         assert images is None or embeddings is None, "cannot specify both images and embeddings"
+         assert x.shape[-1] == self.n_ctx
+
+         t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+
+         if images is not None:
+             clip_out = self.clip.embed_images_grid(images)
+         else:
+             clip_out = embeddings
+
+         if self.training:
+             mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
+             clip_out = clip_out * mask[:, None, None].to(clip_out)
+
+         clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
+         clip_embed = self.clip_embed(clip_out)
+
+         cond = [(t_embed, self.time_token_cond), (clip_embed, True)]
+         return self._forward_with_cond(x, cond)
+
+
+ class UpsamplePointDiffusionTransformer(PointDiffusionTransformer):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         cond_input_channels: Optional[int] = None,
+         cond_ctx: int = 1024,
+         n_ctx: int = 4096 - 1024,
+         channel_scales: Optional[Sequence[float]] = None,
+         channel_biases: Optional[Sequence[float]] = None,
+         **kwargs,
+     ):
+         super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + cond_ctx, **kwargs)
+         self.n_ctx = n_ctx
+         self.cond_input_channels = cond_input_channels or self.input_channels
+         self.cond_point_proj = nn.Linear(
+             self.cond_input_channels, self.backbone.width, device=device, dtype=dtype
+         )
+
+         self.register_buffer(
+             "channel_scales",
+             torch.tensor(channel_scales, dtype=dtype, device=device)
+             if channel_scales is not None
+             else None,
+         )
+         self.register_buffer(
+             "channel_biases",
+             torch.tensor(channel_biases, dtype=dtype, device=device)
+             if channel_biases is not None
+             else None,
+         )
+
+     def forward(self, x: torch.Tensor, t: torch.Tensor, *, low_res: torch.Tensor):
+         """
+         :param x: an [N x C1 x T] tensor.
+         :param t: an [N] tensor.
+         :param low_res: an [N x C2 x T'] tensor of conditioning points.
+         :return: an [N x C3 x T] tensor.
+         """
+         assert x.shape[-1] == self.n_ctx
+         t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+         low_res_embed = self._embed_low_res(low_res)
+         cond = [(t_embed, self.time_token_cond), (low_res_embed, True)]
+         return self._forward_with_cond(x, cond)
+
+     def _embed_low_res(self, x: torch.Tensor) -> torch.Tensor:
+         if self.channel_scales is not None:
+             x = x * self.channel_scales[None, :, None]
+         if self.channel_biases is not None:
+             x = x + self.channel_biases[None, :, None]
+         return self.cond_point_proj(x.permute(0, 2, 1))
+
+
+ class CLIPImageGridUpsamplePointDiffusionTransformer(UpsamplePointDiffusionTransformer):
+     def __init__(
+         self,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+         n_ctx: int = 4096 - 1024,
+         cond_drop_prob: float = 0.0,
+         frozen_clip: bool = True,
+         **kwargs,
+     ):
+         clip = (FrozenImageCLIP if frozen_clip else ImageCLIP)(device)
+         super().__init__(device=device, dtype=dtype, n_ctx=n_ctx + clip.grid_size**2, **kwargs)
+         self.n_ctx = n_ctx
+
+         self.clip = clip
+         self.clip_embed = nn.Sequential(
+             nn.LayerNorm(
+                 normalized_shape=(self.clip.grid_feature_dim,), device=device, dtype=dtype
+             ),
+             nn.Linear(self.clip.grid_feature_dim, self.backbone.width, device=device, dtype=dtype),
+         )
+         self.cond_drop_prob = cond_drop_prob
+
+     def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+         _ = batch_size
+         with torch.no_grad():
+             return dict(
+                 embeddings=self.clip.embed_images_grid(model_kwargs["images"]),
+                 low_res=model_kwargs["low_res"],
+             )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         t: torch.Tensor,
+         *,
+         low_res: torch.Tensor,
+         images: Optional[Iterable[ImageType]] = None,
+         embeddings: Optional[Iterable[torch.Tensor]] = None,
+     ):
+         """
+         :param x: an [N x C1 x T] tensor.
+         :param t: an [N] tensor.
+         :param low_res: an [N x C2 x T'] tensor of conditioning points.
+         :param images: a batch of images to condition on.
+         :param embeddings: a batch of CLIP latent grids to condition on.
+         :return: an [N x C3 x T] tensor.
+         """
+         assert x.shape[-1] == self.n_ctx
+         t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
+         low_res_embed = self._embed_low_res(low_res)
+
+         if images is not None:
+             clip_out = self.clip.embed_images_grid(images)
+         else:
+             clip_out = embeddings
+
+         if self.training:
+             mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
+             clip_out = clip_out * mask[:, None, None].to(clip_out)
+
+         clip_out = clip_out.permute(0, 2, 1)  # NCL -> NLC
+         clip_embed = self.clip_embed(clip_out)
+
+         cond = [(t_embed, self.time_token_cond), (clip_embed, True), (low_res_embed, True)]
+         return self._forward_with_cond(x, cond)
+
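A shape sketch (not part of the commit) of the conditioning mechanism in PointDiffusionTransformer: conditioning tensors are either added to every token (as_token=False) or prepended as extra tokens and stripped again after the backbone, so the output keeps the input's point count.

import torch

model = PointDiffusionTransformer(
    device=torch.device("cpu"), dtype=torch.float32,
    input_channels=3, output_channels=6, n_ctx=256,
    width=64, layers=2, heads=4, time_token_cond=True,
)
x = torch.randn(2, 3, 256)
t = torch.randint(0, 1000, (2,))
assert model(x, t).shape == (2, 6, 256)  # the prepended time token does not change T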
shap_e/models/generation/util.py ADDED
@@ -0,0 +1,23 @@
+ import math
+
+ import torch
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = torch.exp(
+         -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
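A quick check (not part of the commit): the embedding concatenates a cosine half and a sine half, and pads one zero column when `dim` is odd.

import torch

emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=128)
assert emb.shape == (3, 128)
assert torch.allclose(emb[0, :64], torch.ones(64))   # cos(0) = 1 for every frequency
assert torch.allclose(emb[0, 64:], torch.zeros(64))  # sin(0) = 0 for every frequency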
shap_e/models/nerf/__init__.py ADDED
File without changes
shap_e/models/nerf/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (168 Bytes)
 
shap_e/models/nerf/__pycache__/model.cpython-39.pyc ADDED
Binary file (6.51 kB)
 
shap_e/models/nerf/__pycache__/ray.cpython-39.pyc ADDED
Binary file (15.3 kB)
 
shap_e/models/nerf/__pycache__/renderer.cpython-39.pyc ADDED
Binary file (5.62 kB)
 
shap_e/models/nerf/model.py ADDED
@@ -0,0 +1,255 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import partial
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from shap_e.models.nn.checkpoint import checkpoint
10
+ from shap_e.models.nn.encoding import encode_position, spherical_harmonics_basis
11
+ from shap_e.models.nn.meta import MetaModule, subdict
12
+ from shap_e.models.nn.ops import MLP, MetaMLP, get_act, mlp_init, zero_init
13
+ from shap_e.models.nn.utils import ArrayType
14
+ from shap_e.models.query import Query
15
+ from shap_e.util.collections import AttrDict
16
+
17
+
18
+ class NeRFModel(ABC):
19
+ """
20
+ Parametric scene representation whose outputs are integrated by NeRFRenderer
21
+ """
22
+
23
+ @abstractmethod
24
+ def forward(
25
+ self,
26
+ query: Query,
27
+ params: Optional[Dict[str, torch.Tensor]] = None,
28
+ options: Optional[Dict[str, Any]] = None,
29
+ ) -> AttrDict:
30
+ """
31
+ :param query: the points in the field to query.
32
+ :param params: Meta parameters
33
+ :param options: Optional hyperparameters
34
+ :return: An AttrDict containing at least
35
+ - density: [batch_size x ... x 1]
36
+ - channels: [batch_size x ... x n_channels]
37
+ - aux_losses: [batch_size x ... x 1]
38
+ """
39
+
40
+
41
+ class VoidNeRFModel(MetaModule, NeRFModel):
42
+ """
43
+ Implements the default empty space model where all queries are rendered as
44
+ background.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ background: ArrayType,
50
+ trainable: bool = False,
51
+ channel_scale: float = 255.0,
52
+ device: torch.device = torch.device("cuda"),
53
+ ):
54
+ super().__init__()
55
+ background = nn.Parameter(
56
+ torch.from_numpy(np.array(background)).to(dtype=torch.float32, device=device)
57
+ / channel_scale
58
+ )
59
+ if trainable:
60
+ self.register_parameter("background", background)
61
+ else:
62
+ self.register_buffer("background", background)
63
+
64
+ def forward(
65
+ self,
66
+ query: Query,
67
+ params: Optional[Dict[str, torch.Tensor]] = None,
68
+ options: Optional[Dict[str, Any]] = None,
69
+ ) -> AttrDict:
70
+ _ = params
71
+ default_bg = self.background[None]
72
+ background = options.get("background", default_bg) if options is not None else default_bg
73
+
74
+ shape = query.position.shape[:-1]
75
+ ones = [1] * (len(shape) - 1)
76
+ n_channels = background.shape[-1]
77
+ background = torch.broadcast_to(
78
+ background.view(background.shape[0], *ones, n_channels), [*shape, n_channels]
79
+ )
80
+ return background
81
+
82
+
83
+ class MLPNeRFModel(MetaModule, NeRFModel):
84
+ def __init__(
85
+ self,
86
+ # Positional encoding parameters
87
+ n_levels: int = 10,
88
+ # MLP parameters
89
+ d_hidden: int = 256,
90
+ n_density_layers: int = 4,
91
+ n_channel_layers: int = 1,
92
+ n_channels: int = 3,
93
+ sh_degree: int = 4,
94
+ activation: str = "relu",
95
+ density_activation: str = "exp",
96
+ init: Optional[str] = None,
97
+ init_scale: float = 1.0,
98
+ output_activation: str = "sigmoid",
99
+ meta_parameters: bool = False,
100
+ trainable_meta: bool = False,
101
+ zero_out: bool = True,
102
+ register_freqs: bool = True,
103
+ posenc_version: str = "v1",
104
+ device: torch.device = torch.device("cuda"),
105
+ ):
106
+ super().__init__()
107
+
108
+ # Positional encoding
109
+ if register_freqs:
110
+ # not used anymore
111
+ self.register_buffer(
112
+ "freqs",
113
+ 2.0 ** torch.arange(n_levels, device=device, dtype=torch.float).view(1, n_levels),
114
+ )
115
+
116
+ self.posenc_version = posenc_version
117
+ dummy = torch.eye(1, 3)
118
+ d_input = encode_position(posenc_version, position=dummy).shape[-1]
119
+
120
+ self.n_levels = n_levels
121
+
122
+ self.sh_degree = sh_degree
123
+ d_sh_coeffs = sh_degree**2
124
+
125
+ self.meta_parameters = meta_parameters
126
+
127
+ mlp_cls = (
128
+ partial(
129
+ MetaMLP,
130
+ meta_scale=False,
131
+ meta_shift=False,
132
+ meta_proj=True,
133
+ meta_bias=True,
134
+ trainable_meta=trainable_meta,
135
+ )
136
+ if meta_parameters
137
+ else MLP
138
+ )
139
+
140
+ self.density_mlp = mlp_cls(
141
+ d_input=d_input,
142
+ d_hidden=[d_hidden] * (n_density_layers - 1),
143
+ d_output=d_hidden,
144
+ act_name=activation,
145
+ init_scale=init_scale,
146
+ )
147
+
148
+ self.channel_mlp = mlp_cls(
149
+ d_input=d_hidden + d_sh_coeffs,
150
+ d_hidden=[d_hidden] * n_channel_layers,
151
+ d_output=n_channels,
152
+ act_name=activation,
153
+ init_scale=init_scale,
154
+ )
155
+
156
+ self.act = get_act(output_activation)
157
+ self.density_act = get_act(density_activation)
158
+
159
+ mlp_init(
160
+ list(self.density_mlp.affines) + list(self.channel_mlp.affines),
161
+ init=init,
162
+ init_scale=init_scale,
163
+ )
164
+
165
+ if zero_out:
166
+ zero_init(self.channel_mlp.affines[-1])
167
+
168
+ self.to(device)
169
+
170
+ def encode_position(self, query: Query):
171
+ h = encode_position(self.posenc_version, position=query.position)
172
+ return h
173
+
174
+ def forward(
175
+ self,
176
+ query: Query,
177
+ params: Optional[Dict[str, torch.Tensor]] = None,
178
+ options: Optional[Dict[str, Any]] = None,
179
+ ) -> AttrDict:
180
+ params = self.update(params)
181
+
182
+ options = AttrDict() if options is None else AttrDict(options)
183
+
184
+ query = query.copy()
185
+
186
+ h_position = self.encode_position(query)
187
+
188
+ if self.meta_parameters:
189
+ density_params = subdict(params, "density_mlp")
190
+ density_mlp = partial(
191
+ self.density_mlp, params=density_params, options=options, log_prefix="density_"
192
+ )
193
+ density_mlp_parameters = list(density_params.values())
194
+ else:
195
+ density_mlp = partial(self.density_mlp, options=options, log_prefix="density_")
196
+ density_mlp_parameters = self.density_mlp.parameters()
197
+ h_density = checkpoint(
198
+ density_mlp,
199
+ (h_position,),
200
+ density_mlp_parameters,
201
+ options.checkpoint_nerf_mlp,
202
+ )
203
+ h_direction = maybe_get_spherical_harmonics_basis(
204
+ sh_degree=self.sh_degree,
205
+ coords_shape=query.position.shape,
206
+ coords=query.direction,
207
+ device=query.position.device,
208
+ )
209
+
210
+ if self.meta_parameters:
211
+ channel_params = subdict(params, "channel_mlp")
212
+ channel_mlp = partial(
213
+ self.channel_mlp, params=channel_params, options=options, log_prefix="channel_"
214
+ )
215
+ channel_mlp_parameters = list(channel_params.values())
216
+ else:
217
+ channel_mlp = partial(self.channel_mlp, options=options, log_prefix="channel_")
218
+ channel_mlp_parameters = self.channel_mlp.parameters()
219
+ h_channel = checkpoint(
220
+ channel_mlp,
221
+ (torch.cat([h_density, h_direction], dim=-1),),
222
+ channel_mlp_parameters,
223
+ options.checkpoint_nerf_mlp,
224
+ )
225
+
226
+ density_logit = h_density[..., :1]
227
+
228
+ res = AttrDict(
229
+ density_logit=density_logit,
230
+ density=self.density_act(density_logit),
231
+ channels=self.act(h_channel),
232
+ aux_losses=AttrDict(),
233
+ no_weight_grad_aux_losses=AttrDict(),
234
+ )
235
+ if options.return_h_density:
236
+ res.h_density = h_density
237
+
238
+ return res
239
+
240
+
241
+ def maybe_get_spherical_harmonics_basis(
242
+ sh_degree: int,
243
+ coords_shape: Tuple[int],
244
+ coords: Optional[torch.Tensor] = None,
245
+ device: torch.device = torch.device("cuda"),
246
+ ) -> torch.Tensor:
247
+ """
248
+ :param sh_degree: Spherical harmonics degree
249
+ :param coords_shape: [*shape, 3]
250
+ :param coords: optional coordinate tensor of coords_shape
251
+ """
252
+ if coords is None:
253
+ return torch.zeros(*coords_shape[:-1], sh_degree**2).to(device)
254
+
255
+ return spherical_harmonics_basis(coords, sh_degree)
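A shape note (not part of the commit): when no view direction is supplied, the spherical harmonics basis falls back to zeros, so the channel MLP sees a direction-free input of the expected width.

import torch

basis = maybe_get_spherical_harmonics_basis(
    sh_degree=4, coords_shape=(2, 128, 3), coords=None, device=torch.device("cpu")
)
assert basis.shape == (2, 128, 16)  # sh_degree ** 2 coefficients per query point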
shap_e/models/nerf/ray.py ADDED
@@ -0,0 +1,512 @@
1
+ from abc import ABC, abstractmethod
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from shap_e.models.nn.utils import sample_pmf
9
+ from shap_e.models.volume import Volume, VolumeRange
10
+ from shap_e.util.collections import AttrDict
11
+
12
+ from .model import NeRFModel, Query
13
+
14
+
15
+ def render_rays(
16
+ rays: torch.Tensor,
17
+ parts: List["RayVolumeIntegral"],
18
+ void_model: NeRFModel,
19
+ shared: bool = False,
20
+ prev_raw_outputs: Optional[List[AttrDict]] = None,
21
+ render_with_direction: bool = True,
22
+ importance_sampling_options: Optional[Dict[str, Any]] = None,
23
+ ) -> Tuple["RayVolumeIntegralResults", List["RaySampler"], List[AttrDict]]:
24
+ """
25
+ Perform volumetric rendering over a partition of possible t's in the union
26
+ of rendering volumes (written below with some abuse of notations)
27
+
28
+ C(r) := sum(
29
+ transmittance(t[i]) *
30
+ integrate(
31
+ lambda t: density(t) * channels(t) * transmittance(t),
32
+ [t[i], t[i + 1]],
33
+ )
34
+ for i in range(len(parts))
35
+ ) + transmittance(t[-1]) * void_model(t[-1]).channels
36
+
37
+ where
38
+
39
+ 1) transmittance(s) := exp(-integrate(density, [t[0], s])) calculates the
40
+ probability of light passing through the volume specified by [t[0], s].
41
+ (transmittance of 1 means light can pass freely)
42
+ 2) density and channels are obtained by evaluating the appropriate
43
+ part.model at time t.
44
+ 3) [t[i], t[i + 1]] is defined as the range of t where the ray intersects
45
+ (parts[i].volume \\ union(part.volume for part in parts[:i])) at the surface
46
+ of the shell (if bounded). If the ray does not intersect, the integral over
47
+ this segment is evaluated as 0 and transmittance(t[i + 1]) :=
48
+ transmittance(t[i]).
49
+ 4) The last term is integration to infinity (e.g. [t[-1], math.inf]) that
50
+ is evaluated by the void_model (i.e. we consider this space to be empty).
51
+
52
+ :param rays: [batch_size x ... x 2 x 3] origin and direction.
53
+ :param parts: disjoint volume integrals.
54
+ :param void_model: use this model to integrate over the empty space
55
+ :param shared: All RayVolumeIntegrals are calculated with the same model.
56
+ :param prev_raw_outputs: Raw outputs from the previous rendering step
57
+
58
+ :return: A tuple of
59
+ - AttrDict containing the rendered `channels`, `distances`, and the `aux_losses`
60
+ - A list of importance samplers for additional fine-grained rendering
61
+ - A list of raw output for each interval
62
+ """
63
+ if importance_sampling_options is None:
64
+ importance_sampling_options = {}
65
+
66
+ origin, direc = rays[..., 0, :], rays[..., 1, :]
67
+
68
+ if prev_raw_outputs is None:
69
+ prev_raw_outputs = [None] * len(parts)
70
+
71
+ samplers = []
72
+ raw_outputs = []
73
+ t0 = None
74
+ results = None
75
+ # import pdb; pdb.set_trace()
76
+ for part_i, prev_raw_i in zip(parts, prev_raw_outputs):
77
+
78
+ # Integrate over [t[i], t[i + 1]]
79
+ results_i = part_i.render_rays(
80
+ origin,
81
+ direc,
82
+ t0=t0,
83
+ prev_raw=prev_raw_i,
84
+ shared=shared,
85
+ render_with_direction=render_with_direction,
86
+ )
87
+
88
+ # Create an importance sampler for (optional) fine rendering
89
+ samplers.append(
90
+ ImportanceRaySampler(
91
+ results_i.volume_range, results_i.raw, **importance_sampling_options
92
+ )
93
+ )
94
+ raw_outputs.append(results_i.raw)
95
+
96
+ # Pass t[i + 1] as the start of integration for the next interval.
97
+ t0 = results_i.volume_range.next_t0()
98
+
99
+ # Combine the results from [t[0], t[i]] and [t[i], t[i+1]]
100
+ results = results_i if results is None else results.combine(results_i)
101
+
102
+ # While integrating out [t[-1], math.inf] is the correct thing to do, this
103
+ # erases a lot of useful information. Also, void_model is meant to predict
104
+ # the channels at t=math.inf.
105
+
106
+ # # Add the void background over [t[-1], math.inf] to complete integration.
107
+ # results = results.combine(
108
+ # RayVolumeIntegralResults(
109
+ # output=AttrDict(
110
+ # channels=void_model(origin, direc),
111
+ # distances=torch.zeros_like(t0),
112
+ # aux_losses=AttrDict(),
113
+ # ),
114
+ # volume_range=VolumeRange(
115
+ # t0=t0,
116
+ # t1=torch.full_like(t0, math.inf),
117
+ # intersected=torch.full_like(results.volume_range.intersected, True),
118
+ # ),
119
+ # # Void space extends to infinity. It is assumed that no light
120
+ # # passes beyond the void.
121
+ # transmittance=torch.zeros_like(results_i.transmittance),
122
+ # )
123
+ # )
124
+ results.output.channels = results.output.channels + results.transmittance * void_model(
125
+ Query(origin, direc)
126
+ )
127
+
128
+ return results, samplers, raw_outputs
129
+
130
+
131
+ @dataclass
132
+ class RayVolumeIntegralResults:
133
+ """
134
+ Stores the relevant state and results of
135
+
136
+ integrate(
137
+ lambda t: density(t) * channels(t) * transmittance(t),
138
+ [t0, t1],
139
+ )
140
+ """
141
+
142
+ # Rendered output and auxiliary losses
143
+ # output.channels has shape [batch_size, *inner_shape, n_channels]
144
+ output: AttrDict
145
+
146
+ """
147
+ Optional values
148
+ """
149
+
150
+ # Raw values contain the sampled `ts`, `density`, `channels`, etc.
151
+ raw: Optional[AttrDict] = None
152
+
153
+ # Integration
154
+ volume_range: Optional[VolumeRange] = None
155
+
156
+ # If a ray intersects, the transmittance from t0 to t1 (e.g. the
157
+ # probability that the ray passes through this volume).
158
+ # has shape [batch_size, *inner_shape, 1]
159
+ transmittance: Optional[torch.Tensor] = None
160
+
161
+ def combine(self, cur: "RayVolumeIntegralResults") -> "RayVolumeIntegralResults":
162
+ """
163
+ Combines the integration results of `self` over [t0, t1] and
164
+ `cur` over [t1, t2] to produce a new set of results over [t0, t2] by
165
+ using a similar equation to (4) in NeRF++:
166
+
167
+ integrate(
168
+ lambda t: density(t) * channels(t) * transmittance(t),
169
+ [t0, t2]
170
+ )
171
+
172
+ = integrate(
173
+ lambda t: density(t) * channels(t) * transmittance(t),
174
+ [t0, t1]
175
+ ) + transmittance(t1) * integrate(
176
+ lambda t: density(t) * channels(t) * transmittance(t),
177
+ [t1, t2]
178
+ )
179
+ """
180
+ assert torch.allclose(self.volume_range.next_t0(), cur.volume_range.t0)
181
+
182
+ def _combine_fn(
183
+ prev_val: Optional[torch.Tensor],
184
+ cur_val: Optional[torch.Tensor],
185
+ *,
186
+ prev_transmittance: torch.Tensor,
187
+ ):
188
+ assert prev_val is not None
189
+ if cur_val is None:
190
+ # cur_output.aux_losses are empty for the void_model.
191
+ return prev_val
192
+ return prev_val + prev_transmittance * cur_val
193
+
194
+ output = self.output.combine(
195
+ cur.output, combine_fn=partial(_combine_fn, prev_transmittance=self.transmittance)
196
+ )
197
+
198
+ combined = RayVolumeIntegralResults(
199
+ output=output,
200
+ volume_range=self.volume_range.extend(cur.volume_range),
201
+ transmittance=self.transmittance * cur.transmittance,
202
+ )
203
+ return combined
204
+
205
+
206
+ @dataclass
207
+ class RayVolumeIntegral:
208
+ model: NeRFModel
209
+ volume: Volume
210
+ sampler: "RaySampler"
211
+ n_samples: int
212
+
213
+ def render_rays(
214
+ self,
215
+ origin: torch.Tensor,
216
+ direction: torch.Tensor,
217
+ t0: Optional[torch.Tensor] = None,
218
+ prev_raw: Optional[AttrDict] = None,
219
+ shared: bool = False,
220
+ render_with_direction: bool = True,
221
+ ) -> "RayVolumeIntegralResults":
222
+ """
223
+ Perform volumetric rendering over the given volume.
224
+
225
+ :param position: [batch_size, *shape, 3]
226
+ :param direction: [batch_size, *shape, 3]
227
+ :param t0: Optional [batch_size, *shape, 1]
228
+ :param prev_raw: the raw outputs when using multiple levels with this model.
229
+ :param shared: means the same model is used for all RayVolumeIntegral's
230
+ :param render_with_direction: use the incoming ray direction when querying the model.
231
+
232
+ :return: RayVolumeIntegralResults
233
+ """
234
+ # 1. Intersect the rays with the current volume and sample ts to
235
+ # integrate along.
236
+ vrange = self.volume.intersect(origin, direction, t0_lower=t0)
237
+ ts = self.sampler.sample(vrange.t0, vrange.t1, self.n_samples)
238
+
239
+ if prev_raw is not None and not shared:
240
+ # Append the previous ts now before fprop because previous
241
+ # rendering used a different model and we can't reuse the output.
242
+ ts = torch.sort(torch.cat([ts, prev_raw.ts], dim=-2), dim=-2).values
243
+
244
+ # Shape sanity checks
245
+ batch_size, *_shape, _t0_dim = vrange.t0.shape
246
+ _, *ts_shape, _ts_dim = ts.shape
247
+
248
+ # 2. Get the points along the ray and query the model
249
+ directions = torch.broadcast_to(direction.unsqueeze(-2), [batch_size, *ts_shape, 3])
250
+ positions = origin.unsqueeze(-2) + ts * directions
251
+
252
+ optional_directions = directions if render_with_direction else None
253
+ mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2
254
+ raw = self.model(
255
+ Query(
256
+ position=positions,
257
+ direction=optional_directions,
258
+ t_min=torch.cat([vrange.t0[..., None, :], mids], dim=-2),
259
+ t_max=torch.cat([mids, vrange.t1[..., None, :]], dim=-2),
260
+ )
261
+ )
262
+ raw.ts = ts
263
+
264
+ if prev_raw is not None and shared:
265
+ # We can append the additional queries to previous raw outputs
266
+ # before integration
267
+ copy = prev_raw.copy()
268
+ result = torch.sort(torch.cat([raw.pop("ts"), copy.pop("ts")], dim=-2), dim=-2)
269
+ merge_results = partial(self._merge_results, dim=-2, indices=result.indices)
270
+ raw = raw.combine(copy, merge_results)
271
+ raw.ts = result.values
272
+
273
+ # 3. Integrate the raw results
274
+ output, transmittance = self.integrate_samples(vrange, raw)
275
+
276
+ # 4. Clean up results that do not intersect with the volume.
277
+ transmittance = torch.where(
278
+ vrange.intersected, transmittance, torch.ones_like(transmittance)
279
+ )
280
+
281
+ def _mask_fn(_key: str, tensor: torch.Tensor):
282
+ return torch.where(vrange.intersected, tensor, torch.zeros_like(tensor))
283
+
284
+ def _is_tensor(_key: str, value: Any):
285
+ return isinstance(value, torch.Tensor)
286
+
287
+ output = output.map(map_fn=_mask_fn, should_map=_is_tensor)
288
+
289
+ return RayVolumeIntegralResults(
290
+ output=output,
291
+ raw=raw,
292
+ volume_range=vrange,
293
+ transmittance=transmittance,
294
+ )
295
+
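A minimal standalone sketch of the bin partitioning used in step 2 of render_rays above: each sample t_i gets a bin [t_min_i, t_max_i] whose edges are the midpoints between neighboring samples, with the volume's t0 and t1 closing the first and last bins. The tensors below are toy values chosen only so the arithmetic is easy to check.

import torch

t0 = torch.zeros(1, 1)                               # [batch_size, 1]
t1 = torch.full((1, 1), 4.0)
ts = torch.tensor([[0.5, 1.5, 3.0]]).unsqueeze(-1)   # [batch_size, n_samples, 1]

mids = (ts[..., 1:, :] + ts[..., :-1, :]) / 2        # midpoints between consecutive samples
t_min = torch.cat([t0[..., None, :], mids], dim=-2)  # left bin edges
t_max = torch.cat([mids, t1[..., None, :]], dim=-2)  # right bin edges

assert torch.equal(t_min.squeeze(), torch.tensor([0.0, 1.0, 2.25]))
assert torch.equal(t_max.squeeze(), torch.tensor([1.0, 2.25, 4.0]))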
296
+ def integrate_samples(
297
+ self,
298
+ volume_range: VolumeRange,
299
+ raw: AttrDict,
300
+ ) -> Tuple[AttrDict, torch.Tensor]:
301
+ """
302
+ Integrate the raw.channels along with other aux_losses and values to
303
+ produce the final output dictionary containing rendered `channels`,
304
+ estimated `distances` and `aux_losses`.
305
+
306
+ :param volume_range: Specifies the integral range [t0, t1]
307
+ :param raw: Contains a dict of function evaluations at ts. Should have
308
+
309
+ density: torch.Tensor [batch_size, *shape, n_samples, 1]
310
+ channels: torch.Tensor [batch_size, *shape, n_samples, n_channels]
311
+ aux_losses: {key: torch.Tensor [batch_size, *shape, n_samples, 1] for each key}
312
+ no_weight_grad_aux_losses: an optional set of losses for which the weights
313
+ should be detached before integration.
314
+
315
+ after the call, integrate_samples populates some intermediate calculations
316
+ for later use like
317
+
318
+ weights: torch.Tensor [batch_size, *shape, n_samples, 1] (density *
319
+ transmittance)[i] weight for each rgb output at [..., i, :].
320
+ :returns: a tuple of (
321
+ a dictionary of rendered outputs and aux_losses,
322
+ transmittance of this volume,
323
+ )
324
+ """
325
+
326
+ # 1. Calculate the weights
327
+ _, _, dt = volume_range.partition(raw.ts)
328
+ ddensity = raw.density * dt
329
+
330
+ mass = torch.cumsum(ddensity, dim=-2)
331
+ transmittance = torch.exp(-mass[..., -1, :])
332
+
333
+ alphas = 1.0 - torch.exp(-ddensity)
334
+ Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
335
+ # This is the probability of light hitting and reflecting off of
336
+ # something at depth [..., i, :].
337
+ weights = alphas * Ts
338
+
339
+ # 2. Integrate all results
340
+ def _integrate(key: str, samples: torch.Tensor, weights: torch.Tensor):
341
+ if key == "density":
342
+ # Omit integrating the density, because we don't need it
343
+ return None
344
+ return torch.sum(samples * weights, dim=-2)
345
+
346
+ def _is_tensor(_key: str, value: Any):
347
+ return isinstance(value, torch.Tensor)
348
+
349
+ if raw.no_weight_grad_aux_losses:
350
+ extra_aux_losses = raw.no_weight_grad_aux_losses.map(
351
+ partial(_integrate, weights=weights.detach()), should_map=_is_tensor
352
+ )
353
+ else:
354
+ extra_aux_losses = {}
355
+ output = raw.map(partial(_integrate, weights=weights), should_map=_is_tensor)
356
+ if "no_weight_grad_aux_losses" in output:
357
+ del output["no_weight_grad_aux_losses"]
358
+ output.aux_losses.update(extra_aux_losses)
359
+
360
+ # Integrating the ts yields the distance away from the origin; rename the variable.
361
+ output.distances = output.ts
362
+ del output["ts"]
363
+ del output["density"]
364
+
365
+ assert output.distances.shape == (*output.channels.shape[:-1], 1)
366
+ assert output.channels.shape[:-1] == raw.channels.shape[:-2]
367
+ assert output.channels.shape[-1] == raw.channels.shape[-1]
368
+
369
+ # 3. Reduce loss
370
+ def _reduce_loss(_key: str, loss: torch.Tensor):
371
+ return loss.view(loss.shape[0], -1).sum(dim=-1)
372
+
373
+ # 4. Store other useful calculations
374
+ raw.weights = weights
375
+
376
+ output.aux_losses = output.aux_losses.map(_reduce_loss)
377
+
378
+ return output, transmittance
379
+
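A standalone sketch, with made-up tensors, of the discrete volume-rendering math in integrate_samples above: per-bin absorption (alphas) times the transmittance accumulated before each bin (Ts) gives the weights that integrate the channels.

import torch

n_samples = 8
dt = torch.full((1, n_samples, 1), 0.5)          # bin widths, as from volume_range.partition
density = torch.rand(1, n_samples, 1)            # raw densities from the model
channels = torch.rand(1, n_samples, 3)           # e.g. RGB samples

ddensity = density * dt
mass = torch.cumsum(ddensity, dim=-2)            # cumulative optical depth
transmittance = torch.exp(-mass[..., -1, :])     # fraction of light leaving the volume
alphas = 1.0 - torch.exp(-ddensity)              # probability of stopping in each bin
Ts = torch.exp(torch.cat([torch.zeros_like(mass[..., :1, :]), -mass[..., :-1, :]], dim=-2))
weights = alphas * Ts                            # density * transmittance, as in the docstring

rendered = torch.sum(channels * weights, dim=-2) # integrated channels
assert rendered.shape == (1, 3)
assert transmittance.shape == (1, 1)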
380
+ def _merge_results(
381
+ self, a: Optional[torch.Tensor], b: torch.Tensor, dim: int, indices: torch.Tensor
382
+ ):
383
+ """
384
+ :param a: [..., n_a, ...]. The other dictionary containing the b's may
385
+ contain extra tensors from earlier calculations, so a can be None.
386
+ :param b: [..., n_b, ...]
387
+ :param dim: dimension to merge
388
+ :param indices: how the merged results should be sorted at the end
389
+ :return: a concatted and sorted tensor of size [..., n_a + n_b, ...]
390
+ """
391
+ if a is None:
392
+ return None
393
+
394
+ merged = torch.cat([a, b], dim=dim)
395
+ return torch.gather(merged, dim=dim, index=torch.broadcast_to(indices, merged.shape))
396
+
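A standalone sketch, on toy tensors, of the cat + sort + gather pattern that _merge_results implements: two per-ray sample sets are concatenated, sorted by t, and any per-sample payload is reordered with the same sort indices.

import torch

ts_a = torch.tensor([[0.1, 0.4, 0.9]]).unsqueeze(-1)   # [1, n_a, 1]
ts_b = torch.tensor([[0.2, 0.6]]).unsqueeze(-1)         # [1, n_b, 1]
vals_a = torch.full((1, 3, 2), 1.0)                     # payload tied to ts_a
vals_b = torch.full((1, 2, 2), 2.0)                     # payload tied to ts_b

result = torch.sort(torch.cat([ts_a, ts_b], dim=-2), dim=-2)
merged = torch.cat([vals_a, vals_b], dim=-2)
merged = torch.gather(merged, dim=-2, index=torch.broadcast_to(result.indices, merged.shape))

assert torch.allclose(result.values.squeeze(), torch.tensor([0.1, 0.2, 0.4, 0.6, 0.9]))
assert merged.shape == (1, 5, 2)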
397
+
398
+ class RaySampler(ABC):
399
+ @abstractmethod
400
+ def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
401
+ """
402
+ :param t0: start time has shape [batch_size, *shape, 1]
403
+ :param t1: finish time has shape [batch_size, *shape, 1]
404
+ :param n_samples: number of ts to sample
405
+ :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
406
+ """
407
+
408
+
409
+ class StratifiedRaySampler(RaySampler):
410
+ """
411
+ Instead of fixed intervals, a sample is drawn uniformly at random from each
412
+ interval.
413
+ """
414
+
415
+ def __init__(self, depth_mode: str = "linear"):
416
+ """
417
+ :param depth_mode: "linear" samples ts linearly in depth, "geometric" samples in
418
+ log-depth, and "harmonic" samples in inverse depth so that closer points are sampled more densely.
419
+ """
420
+ self.depth_mode = depth_mode
421
+ assert self.depth_mode in ("linear", "geometric", "harmonic")
422
+
423
+ def sample(
424
+ self,
425
+ t0: torch.Tensor,
426
+ t1: torch.Tensor,
427
+ n_samples: int,
428
+ epsilon: float = 1e-3,
429
+ ) -> torch.Tensor:
430
+ """
431
+ :param t0: start time has shape [batch_size, *shape, 1]
432
+ :param t1: finish time has shape [batch_size, *shape, 1]
433
+ :param n_samples: number of ts to sample
+ :param epsilon: lower clamp applied to t0/t1 before the log/inverse-depth interpolation
434
+ :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
435
+ """
436
+ ones = [1] * (len(t0.shape) - 1)
437
+ ts = torch.linspace(0, 1, n_samples).view(*ones, n_samples).to(t0.dtype).to(t0.device)
438
+
439
+ if self.depth_mode == "linear":
440
+ ts = t0 * (1.0 - ts) + t1 * ts
441
+ elif self.depth_mode == "geometric":
442
+ ts = (t0.clamp(epsilon).log() * (1.0 - ts) + t1.clamp(epsilon).log() * ts).exp()
443
+ elif self.depth_mode == "harmonic":
444
+ # The original NeRF recommends this interpolation scheme for
445
+ # spherical scenes, but there could be some weird edge cases when
446
+ # the observer crosses from the inner to outer volume.
447
+ ts = 1.0 / (1.0 / t0.clamp(epsilon) * (1.0 - ts) + 1.0 / t1.clamp(epsilon) * ts)
448
+
449
+ mids = 0.5 * (ts[..., 1:] + ts[..., :-1])
450
+ upper = torch.cat([mids, t1], dim=-1)
451
+ lower = torch.cat([t0, mids], dim=-1)
452
+ t_rand = torch.rand_like(ts)
453
+
454
+ ts = lower + (upper - lower) * t_rand
455
+ return ts.unsqueeze(-1)
456
+
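A usage sketch for StratifiedRaySampler, assuming the shap_e package is importable; the tensor sizes are arbitrary and only the shape contract documented above is being checked.

import torch
from shap_e.models.nerf.ray import StratifiedRaySampler

sampler = StratifiedRaySampler(depth_mode="linear")
t0 = torch.zeros(2, 4096, 1)                      # [batch_size, *shape, 1]
t1 = torch.full((2, 4096, 1), 4.0)
ts = sampler.sample(t0, t1, n_samples=64)

assert ts.shape == (2, 4096, 64, 1)               # [batch_size, *shape, n_samples, 1]
assert (ts[..., 1:, :] >= ts[..., :-1, :]).all()  # jittered, but still ordered along each ray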
457
+
458
+ class ImportanceRaySampler(RaySampler):
459
+ """
460
+ Given the initial estimate of densities, this samples more from
461
+ regions/bins expected to have objects.
462
+ """
463
+
464
+ def __init__(
465
+ self, volume_range: VolumeRange, raw: AttrDict, blur_pool: bool = False, alpha: float = 1e-5
466
+ ):
467
+ """
468
+ :param volume_range: the range in which a ray intersects the given volume.
469
+ :param raw: dictionary of raw outputs from the NeRF models of shape
470
+ [batch_size, *shape, n_coarse_samples, 1]. Should at least contain
471
+
472
+ ts: earlier samples from the coarse rendering step
473
+ weights: discretized version of density * transmittance
474
+ :param blur_pool: if true, use 2-tap max + 2-tap blur filter from mip-NeRF.
475
+ :param alpha: small value to add to weights.
476
+ """
477
+ self.volume_range = volume_range
478
+ self.ts = raw.ts.clone().detach()
479
+ self.weights = raw.weights.clone().detach()
480
+ self.blur_pool = blur_pool
481
+ self.alpha = alpha
482
+
483
+ @torch.no_grad()
484
+ def sample(self, t0: torch.Tensor, t1: torch.Tensor, n_samples: int) -> torch.Tensor:
485
+ """
486
+ :param t0: start time has shape [batch_size, *shape, 1]
487
+ :param t1: finish time has shape [batch_size, *shape, 1]
488
+ :param n_samples: number of ts to sample
489
+ :return: sampled ts of shape [batch_size, *shape, n_samples, 1]
490
+ """
491
+ lower, upper, _ = self.volume_range.partition(self.ts)
492
+
493
+ batch_size, *shape, n_coarse_samples, _ = self.ts.shape
494
+
495
+ weights = self.weights
496
+ if self.blur_pool:
497
+ padded = torch.cat([weights[..., :1, :], weights, weights[..., -1:, :]], dim=-2)
498
+ maxes = torch.maximum(padded[..., :-1, :], padded[..., 1:, :])
499
+ weights = 0.5 * (maxes[..., :-1, :] + maxes[..., 1:, :])
500
+ weights = weights + self.alpha
501
+ pmf = weights / weights.sum(dim=-2, keepdim=True)
502
+ inds = sample_pmf(pmf, n_samples)
503
+ assert inds.shape == (batch_size, *shape, n_samples, 1)
504
+ assert (inds >= 0).all() and (inds < n_coarse_samples).all()
505
+
506
+ t_rand = torch.rand(inds.shape, device=inds.device)
507
+ lower_ = torch.gather(lower, -2, inds)
508
+ upper_ = torch.gather(upper, -2, inds)
509
+
510
+ ts = lower_ + (upper_ - lower_) * t_rand
511
+ ts = torch.sort(ts, dim=-2).values
512
+ return ts
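sample_pmf is imported from elsewhere in the package; the snippet below is a standalone sketch of the inverse-CDF idea it is assumed to implement (not the library's actual code): draw bin indices in proportion to the per-ray weights.

import torch

def sample_pmf_sketch(pmf: torch.Tensor, n_samples: int) -> torch.Tensor:
    # pmf: [batch_size, *shape, n_bins, 1], summing to 1 over n_bins.
    *prefix, n_bins, _ = pmf.shape
    cdf = torch.cumsum(pmf.squeeze(-1), dim=-1)
    u = torch.rand(*prefix, n_samples, device=pmf.device)
    inds = torch.searchsorted(cdf, u, right=True).clamp(max=n_bins - 1)
    return inds.unsqueeze(-1)                     # [batch_size, *shape, n_samples, 1]

pmf = torch.softmax(torch.randn(2, 1024, 64), dim=-1).unsqueeze(-1)
inds = sample_pmf_sketch(pmf, n_samples=128)
assert inds.shape == (2, 1024, 128, 1)
assert (inds >= 0).all() and (inds < 64).all()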
shap_e/models/nerf/renderer.py ADDED
@@ -0,0 +1,301 @@
1
+ from functools import partial
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+
6
+ from shap_e.models.nn.meta import subdict
7
+ from shap_e.models.renderer import RayRenderer
8
+ from shap_e.models.volume import Volume
9
+ from shap_e.util.collections import AttrDict
10
+
11
+ from .model import NeRFModel
12
+ from .ray import RayVolumeIntegral, StratifiedRaySampler, render_rays
13
+
14
+
15
+ class TwoStepNeRFRenderer(RayRenderer):
16
+ """
17
+ Coarse and fine-grained rendering as proposed by NeRF. This class
18
+ additionally supports background rendering like NeRF++.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ n_coarse_samples: int,
24
+ n_fine_samples: int,
25
+ void_model: NeRFModel,
26
+ fine_model: NeRFModel,
27
+ volume: Volume,
28
+ coarse_model: Optional[NeRFModel] = None,
29
+ coarse_background_model: Optional[NeRFModel] = None,
30
+ fine_background_model: Optional[NeRFModel] = None,
31
+ outer_volume: Optional[Volume] = None,
32
+ foreground_stratified_depth_sampling_mode: str = "linear",
33
+ background_stratified_depth_sampling_mode: str = "linear",
34
+ importance_sampling_options: Optional[Dict[str, Any]] = None,
35
+ channel_scale: float = 255,
36
+ device: torch.device = torch.device("cuda"),
37
+ **kwargs,
38
+ ):
39
+ """
40
+ :param outer_volume: the volume in which distant (background) objects are encoded.
41
+ """
42
+ super().__init__(**kwargs)
43
+
44
+ if coarse_model is None:
45
+ assert (
46
+ fine_background_model is None or coarse_background_model is None
47
+ ), "models should be shared for both fg and bg"
48
+
49
+ self.n_coarse_samples = n_coarse_samples
50
+ self.n_fine_samples = n_fine_samples
51
+ self.void_model = void_model
52
+ self.coarse_model = coarse_model
53
+ self.fine_model = fine_model
54
+ self.volume = volume
55
+ self.coarse_background_model = coarse_background_model
56
+ self.fine_background_model = fine_background_model
57
+ self.outer_volume = outer_volume
58
+ self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
59
+ self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
60
+ self.importance_sampling_options = AttrDict(importance_sampling_options or {})
61
+ self.channel_scale = channel_scale
62
+ self.device = device
63
+ self.to(device)
64
+
65
+ if self.coarse_background_model is not None:
66
+ assert self.fine_background_model is not None
67
+ assert self.outer_volume is not None
68
+
69
+ def render_rays(
70
+ self,
71
+ batch: Dict,
72
+ params: Optional[Dict] = None,
73
+ options: Optional[Dict] = None,
74
+ ) -> AttrDict:
75
+ params = self.update(params)
76
+
77
+ batch = AttrDict(batch)
78
+ if options is None:
79
+ options = AttrDict()
80
+ options.setdefault("render_background", True)
81
+ options.setdefault("render_with_direction", True)
82
+ options.setdefault("n_coarse_samples", self.n_coarse_samples)
83
+ options.setdefault("n_fine_samples", self.n_fine_samples)
84
+ options.setdefault(
85
+ "foreground_stratified_depth_sampling_mode",
86
+ self.foreground_stratified_depth_sampling_mode,
87
+ )
88
+ options.setdefault(
89
+ "background_stratified_depth_sampling_mode",
90
+ self.background_stratified_depth_sampling_mode,
91
+ )
92
+
93
+ shared = self.coarse_model is None
94
+
95
+ # First, render rays using the coarse models with stratified ray samples.
96
+ coarse_model, coarse_key = (
97
+ (self.fine_model, "fine_model") if shared else (self.coarse_model, "coarse_model")
98
+ )
99
+ coarse_model = partial(
100
+ coarse_model,
101
+ params=subdict(params, coarse_key),
102
+ options=options,
103
+ )
104
+ parts = [
105
+ RayVolumeIntegral(
106
+ model=coarse_model,
107
+ volume=self.volume,
108
+ sampler=StratifiedRaySampler(
109
+ depth_mode=options.foreground_stratified_depth_sampling_mode,
110
+ ),
111
+ n_samples=options.n_coarse_samples,
112
+ ),
113
+ ]
114
+ if options.render_background and self.outer_volume is not None:
115
+ coarse_background_model, coarse_background_key = (
116
+ (self.fine_background_model, "fine_background_model")
117
+ if shared
118
+ else (self.coarse_background_model, "coarse_background_model")
119
+ )
120
+ coarse_background_model = partial(
121
+ coarse_background_model,
122
+ params=subdict(params, coarse_background_key),
123
+ options=options,
124
+ )
125
+ parts.append(
126
+ RayVolumeIntegral(
127
+ model=coarse_background_model,
128
+ volume=self.outer_volume,
129
+ sampler=StratifiedRaySampler(
130
+ depth_mode=options.background_stratified_depth_sampling_mode,
131
+ ),
132
+ n_samples=options.n_coarse_samples,
133
+ )
134
+ )
135
+ coarse_results, samplers, coarse_raw_outputs = render_rays(
136
+ batch.rays,
137
+ parts,
138
+ partial(self.void_model, options=options),
139
+ shared=shared,
140
+ render_with_direction=options.render_with_direction,
141
+ importance_sampling_options=AttrDict(self.importance_sampling_options),
142
+ )
143
+
144
+ # Then, render rays using the fine models with importance-weighted ray samples.
145
+ fine_model = partial(
146
+ self.fine_model,
147
+ params=subdict(params, "fine_model"),
148
+ options=options,
149
+ )
150
+ parts = [
151
+ RayVolumeIntegral(
152
+ model=fine_model,
153
+ volume=self.volume,
154
+ sampler=samplers[0],
155
+ n_samples=options.n_fine_samples,
156
+ ),
157
+ ]
158
+ if options.render_background and self.outer_volume is not None:
159
+ fine_background_model = partial(
160
+ self.fine_background_model,
161
+ params=subdict(params, "fine_background_model"),
162
+ options=options,
163
+ )
164
+ parts.append(
165
+ RayVolumeIntegral(
166
+ model=fine_background_model,
167
+ volume=self.outer_volume,
168
+ sampler=samplers[1],
169
+ n_samples=options.n_fine_samples,
170
+ )
171
+ )
172
+ fine_results, *_ = render_rays(
173
+ batch.rays,
174
+ parts,
175
+ partial(self.void_model, options=options),
176
+ shared=shared,
177
+ prev_raw_outputs=coarse_raw_outputs,
178
+ render_with_direction=options.render_with_direction,
179
+ )
180
+
181
+ # Combine results
182
+ aux_losses = fine_results.output.aux_losses.copy()
183
+ for key, val in coarse_results.output.aux_losses.items():
184
+ aux_losses[key + "_coarse"] = val
185
+
186
+ return AttrDict(
187
+ channels=fine_results.output.channels * self.channel_scale,
188
+ channels_coarse=coarse_results.output.channels * self.channel_scale,
189
+ distances=fine_results.output.distances,
190
+ transmittance=fine_results.transmittance,
191
+ transmittance_coarse=coarse_results.transmittance,
192
+ t0=fine_results.volume_range.t0,
193
+ t1=fine_results.volume_range.t1,
194
+ intersected=fine_results.volume_range.intersected,
195
+ aux_losses=aux_losses,
196
+ )
197
+
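A trivial standalone sketch of the "Combine results" step above; the loss names are hypothetical, and the point is only that coarse losses survive next to the fine ones under a "_coarse" suffix.

fine_aux = {"nerf_mse": 0.12, "transmittance_reg": 0.03}    # hypothetical keys
coarse_aux = {"nerf_mse": 0.20, "transmittance_reg": 0.05}

aux_losses = dict(fine_aux)
for key, val in coarse_aux.items():
    aux_losses[key + "_coarse"] = val

assert set(aux_losses) == {"nerf_mse", "nerf_mse_coarse", "transmittance_reg", "transmittance_reg_coarse"}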
198
+
199
+ class OneStepNeRFRenderer(RayRenderer):
200
+ """
201
+ Unlike vanilla NeRF, this renders rays with a single stratified-sampling pass
202
+ (no importance-sampled fine pass), using the same foreground/background setup as NeRF++.
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ n_samples: int,
208
+ void_model: NeRFModel,
209
+ foreground_model: NeRFModel,
210
+ volume: Volume,
211
+ background_model: Optional[NeRFModel] = None,
212
+ outer_volume: Optional[Volume] = None,
213
+ foreground_stratified_depth_sampling_mode: str = "linear",
214
+ background_stratified_depth_sampling_mode: str = "linear",
215
+ channel_scale: float = 255,
216
+ device: torch.device = torch.device("cuda"),
217
+ **kwargs,
218
+ ):
219
+ super().__init__(**kwargs)
220
+ self.n_samples = n_samples
221
+ self.void_model = void_model
222
+ self.foreground_model = foreground_model
223
+ self.volume = volume
224
+ self.background_model = background_model
225
+ self.outer_volume = outer_volume
226
+ self.foreground_stratified_depth_sampling_mode = foreground_stratified_depth_sampling_mode
227
+ self.background_stratified_depth_sampling_mode = background_stratified_depth_sampling_mode
228
+ self.channel_scale = channel_scale
229
+ self.device = device
230
+ self.to(device)
231
+
232
+ def render_rays(
233
+ self,
234
+ batch: Dict,
235
+ params: Optional[Dict] = None,
236
+ options: Optional[Dict] = None,
237
+ ) -> AttrDict:
238
+ params = self.update(params)
239
+
240
+ batch = AttrDict(batch)
241
+ if options is None:
242
+ options = AttrDict()
243
+ options.setdefault("render_background", True)
244
+ options.setdefault("render_with_direction", True)
245
+ options.setdefault("n_samples", self.n_samples)
246
+ options.setdefault(
247
+ "foreground_stratified_depth_sampling_mode",
248
+ self.foreground_stratified_depth_sampling_mode,
249
+ )
250
+ options.setdefault(
251
+ "background_stratified_depth_sampling_mode",
252
+ self.background_stratified_depth_sampling_mode,
253
+ )
254
+
255
+ foreground_model = partial(
256
+ self.foreground_model,
257
+ params=subdict(params, "foreground_model"),
258
+ options=options,
259
+ )
260
+ parts = [
261
+ RayVolumeIntegral(
262
+ model=foreground_model,
263
+ volume=self.volume,
264
+ sampler=StratifiedRaySampler(
265
+ depth_mode=options.foreground_stratified_depth_sampling_mode
266
+ ),
267
+ n_samples=options.n_samples,
268
+ ),
269
+ ]
270
+ if options.render_background and self.outer_volume is not None:
271
+ background_model = partial(
272
+ self.background_model,
273
+ params=subdict(params, "background_model"),
274
+ options=options,
275
+ )
276
+ parts.append(
277
+ RayVolumeIntegral(
278
+ model=background_model,
279
+ volume=self.outer_volume,
280
+ sampler=StratifiedRaySampler(
281
+ depth_mode=options.background_stratified_depth_sampling_mode
282
+ ),
283
+ n_samples=options.n_samples,
284
+ )
285
+ )
286
+ results, *_ = render_rays(
287
+ batch.rays,
288
+ parts,
289
+ self.void_model,
290
+ render_with_direction=options.render_with_direction,
291
+ )
292
+
293
+ return AttrDict(
294
+ channels=results.output.channels * self.channel_scale,
295
+ distances=results.output.distances,
296
+ transmittance=results.transmittance,
297
+ t0=results.volume_range.t0,
298
+ t1=results.volume_range.t1,
299
+ intersected=results.volume_range.intersected,
300
+ aux_losses=results.output.aux_losses,
301
+ )
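Both renderers resolve per-call options against their constructor defaults with setdefault, so callers can, for example, disable the background pass or lower the sample count for a quick preview. A standalone sketch of that pattern (the helper name is hypothetical; the option keys mirror the ones used above):

from typing import Any, Dict, Optional

def resolve_options(options: Optional[Dict[str, Any]], n_samples_default: int = 64) -> Dict[str, Any]:
    options = dict(options or {})
    options.setdefault("render_background", True)
    options.setdefault("render_with_direction", True)
    options.setdefault("n_samples", n_samples_default)
    options.setdefault("foreground_stratified_depth_sampling_mode", "linear")
    options.setdefault("background_stratified_depth_sampling_mode", "linear")
    return options

assert resolve_options(None)["n_samples"] == 64
assert resolve_options({"n_samples": 16, "render_background": False})["render_background"] is False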
shap_e/models/nerstf/__pycache__/mlp.cpython-39.pyc ADDED
Binary file (4.74 kB). View file
 
shap_e/models/nerstf/__pycache__/renderer.cpython-39.pyc ADDED
Binary file (6.65 kB). View file
 
shap_e/models/nerstf/mlp.py ADDED
@@ -0,0 +1,174 @@
1
+ from typing import Any, Dict, Optional, Tuple
2
+
3
+ import torch
4
+
5
+ from shap_e.models.nn.ops import get_act
6
+ from shap_e.models.query import Query
7
+ from shap_e.models.stf.mlp import MLPModel
8
+ from shap_e.util.collections import AttrDict
9
+
10
+
11
+ class MLPDensitySDFModel(MLPModel):
12
+ def __init__(
13
+ self,
14
+ initial_bias: float = -0.1,
15
+ sdf_activation="tanh",
16
+ density_activation="exp",
17
+ **kwargs,
18
+ ):
19
+ super().__init__(
20
+ n_output=2,
21
+ output_activation="identity",
22
+ **kwargs,
23
+ )
24
+ self.mlp[-1].bias[0].data.fill_(initial_bias)
25
+ self.sdf_activation = get_act(sdf_activation)
26
+ self.density_activation = get_act(density_activation)
27
+
28
+ def forward(
29
+ self,
30
+ query: Query,
31
+ params: Optional[Dict[str, torch.Tensor]] = None,
32
+ options: Optional[Dict[str, Any]] = None,
33
+ ) -> AttrDict[str, Any]:
34
+ # query.direction is typically None for SDF models and during training
35
+ h, _h_directionless = self._mlp(
36
+ query.position, query.direction, params=params, options=options
37
+ )
38
+ h_sdf, h_density = h.split(1, dim=-1)
39
+ return AttrDict(
40
+ density=self.density_activation(h_density),
41
+ signed_distance=self.sdf_activation(h_sdf),
42
+ )
43
+
44
+
45
+ class MLPNeRSTFModel(MLPModel):
46
+ def __init__(
47
+ self,
48
+ sdf_activation="tanh",
49
+ density_activation="exp",
50
+ channel_activation="sigmoid",
51
+ direction_dependent_shape: bool = True, # To be able to load old models. Set this to be False in future models.
52
+ separate_nerf_channels: bool = False,
53
+ separate_coarse_channels: bool = False,
54
+ initial_density_bias: float = 0.0,
55
+ initial_sdf_bias: float = -0.1,
56
+ **kwargs,
57
+ ):
58
+ h_map, h_directionless_map = indices_for_output_mode(
59
+ direction_dependent_shape=direction_dependent_shape,
60
+ separate_nerf_channels=separate_nerf_channels,
61
+ separate_coarse_channels=separate_coarse_channels,
62
+ )
63
+ n_output = index_mapping_max(h_map)
64
+ super().__init__(
65
+ n_output=n_output,
66
+ output_activation="identity",
67
+ **kwargs,
68
+ )
69
+ self.direction_dependent_shape = direction_dependent_shape
70
+ self.separate_nerf_channels = separate_nerf_channels
71
+ self.separate_coarse_channels = separate_coarse_channels
72
+ self.sdf_activation = get_act(sdf_activation)
73
+ self.density_activation = get_act(density_activation)
74
+ self.channel_activation = get_act(channel_activation)
75
+ self.h_map = h_map
76
+ self.h_directionless_map = h_directionless_map
77
+ self.mlp[-1].bias.data.zero_()
78
+ layer = -1 if self.direction_dependent_shape else self.insert_direction_at
79
+ self.mlp[layer].bias[0].data.fill_(initial_sdf_bias)
80
+ self.mlp[layer].bias[1].data.fill_(initial_density_bias)
81
+
82
+ def forward(
83
+ self,
84
+ query: Query,
85
+ params: Optional[Dict[str, torch.Tensor]] = None,
86
+ options: Optional[Dict[str, Any]] = None,
87
+ ) -> AttrDict[str, Any]:
88
+
89
+ options = AttrDict() if options is None else AttrDict(options)
90
+ h, h_directionless = self._mlp(
91
+ query.position, query.direction, params=params, options=options
92
+ )
93
+ activations = map_indices_to_keys(self.h_map, h)
94
+ activations.update(map_indices_to_keys(self.h_directionless_map, h_directionless))
95
+
96
+ if options.nerf_level == "coarse":
97
+ h_density = activations.density_coarse
98
+ else:
99
+ h_density = activations.density_fine
100
+
101
+ if options.get("rendering_mode", "stf") == "nerf":
102
+ if options.nerf_level == "coarse":
103
+ h_channels = activations.nerf_coarse
104
+ else:
105
+ h_channels = activations.nerf_fine
106
+ else:
107
+ h_channels = activations.stf
108
+ return AttrDict(
109
+ density=self.density_activation(h_density),
110
+ signed_distance=self.sdf_activation(activations.sdf),
111
+ channels=self.channel_activation(h_channels),
112
+ )
113
+
114
+
115
+ IndexMapping = AttrDict[str, Tuple[int, int]]
116
+
117
+
118
+ def indices_for_output_mode(
119
+ direction_dependent_shape: bool,
120
+ separate_nerf_channels: bool,
121
+ separate_coarse_channels: bool,
122
+ ) -> Tuple[IndexMapping, IndexMapping]:
123
+ """
124
+ Get output mappings for (h, h_directionless).
125
+ """
126
+ h_map = AttrDict()
127
+ h_directionless_map = AttrDict()
128
+ if direction_dependent_shape:
129
+ h_map.sdf = (0, 1)
130
+ if separate_coarse_channels:
131
+ assert separate_nerf_channels
132
+ h_map.density_coarse = (1, 2)
133
+ h_map.density_fine = (2, 3)
134
+ h_map.stf = (3, 6)
135
+ h_map.nerf_coarse = (6, 9)
136
+ h_map.nerf_fine = (9, 12)
137
+ else:
138
+ h_map.density_coarse = (1, 2)
139
+ h_map.density_fine = (1, 2)
140
+ if separate_nerf_channels:
141
+ h_map.stf = (2, 5)
142
+ h_map.nerf_coarse = (5, 8)
143
+ h_map.nerf_fine = (5, 8)
144
+ else:
145
+ h_map.stf = (2, 5)
146
+ h_map.nerf_coarse = (2, 5)
147
+ h_map.nerf_fine = (2, 5)
148
+ else:
149
+ h_directionless_map.sdf = (0, 1)
150
+ h_directionless_map.density_coarse = (1, 2)
151
+ if separate_coarse_channels:
152
+ h_directionless_map.density_fine = (2, 3)
153
+ else:
154
+ h_directionless_map.density_fine = h_directionless_map.density_coarse
155
+ h_map.stf = (0, 3)
156
+ if separate_coarse_channels:
157
+ assert separate_nerf_channels
158
+ h_map.nerf_coarse = (3, 6)
159
+ h_map.nerf_fine = (6, 9)
160
+ else:
161
+ if separate_nerf_channels:
162
+ h_map.nerf_coarse = (3, 6)
163
+ else:
164
+ h_map.nerf_coarse = (0, 3)
165
+ h_map.nerf_fine = h_map.nerf_coarse
166
+ return h_map, h_directionless_map
167
+
168
+
169
+ def map_indices_to_keys(mapping: IndexMapping, data: torch.Tensor) -> AttrDict[str, torch.Tensor]:
170
+ return AttrDict({k: data[..., start:end] for k, (start, end) in mapping.items()})
171
+
172
+
173
+ def index_mapping_max(mapping: IndexMapping) -> int:
174
+ return max(end for _, (_, end) in mapping.items())
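A standalone sketch of how the index mappings above slice one MLP output vector into named heads. Plain dicts stand in for AttrDict so this runs without the package; the mapping shown is the one produced for direction_dependent_shape=True with shared NeRF/STF channels.

import torch

h_map = {
    "sdf": (0, 1),
    "density_coarse": (1, 2),
    "density_fine": (1, 2),
    "stf": (2, 5),
    "nerf_coarse": (2, 5),
    "nerf_fine": (2, 5),
}
n_output = max(end for _, end in h_map.values())                      # cf. index_mapping_max
h = torch.randn(4, 4096, n_output)                                    # stand-in for the MLP output
heads = {k: h[..., start:end] for k, (start, end) in h_map.items()}   # cf. map_indices_to_keys

assert n_output == 5
assert heads["sdf"].shape == (4, 4096, 1)
assert heads["nerf_fine"].shape == (4, 4096, 3)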