customdiffusion360 committed
Commit ad7bc89
1 Parent(s): 2954c0e

first commit

This view is limited to 50 files because it contains too many changes. See raw diff.

Files changed (50)
  1. Dockerfile +64 -0
  2. README.md +4 -4
  3. app.py +395 -0
  4. assets/car0_mesh_centered_flipped.obj +0 -0
  5. assets/chair191_mesh_centered_flipped.obj +0 -0
  6. assets/motorcycle12_mesh_centered_flipped.obj +0 -0
  7. assets/motorcycle29_mesh_centered_flipped.obj +0 -0
  8. assets/plane.obj +14 -0
  9. assets/teddybear0_mesh_centered_flipped.obj +0 -0
  10. assets/teddybear31_mesh_centered_flipped.obj +0 -0
  11. configs/train_co3d_concept.yaml +198 -0
  12. pretrained_models/car0/checkpoints/step=000001600.ckpt +3 -0
  13. pretrained_models/car0/configs/2024-04-12T21-30-20-lightning.yaml +31 -0
  14. pretrained_models/car0/configs/2024-04-12T21-30-20-project.yaml +168 -0
  15. pretrained_models/car0/configs/2024-04-13T11-31-55-lightning.yaml +31 -0
  16. pretrained_models/car0/configs/2024-04-13T11-31-55-project.yaml +170 -0
  17. pretrained_models/car0/configs/2024-04-13T11-42-30-lightning.yaml +31 -0
  18. pretrained_models/car0/configs/2024-04-13T11-42-30-project.yaml +170 -0
  19. pretrained_models/chair191/checkpoints/step=000001600.ckpt +3 -0
  20. pretrained_models/chair191/configs/2024-04-12T22-10-18-lightning.yaml +31 -0
  21. pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml +168 -0
  22. pretrained_models/motorcycle12/checkpoints/step=000001600.ckpt +3 -0
  23. pretrained_models/motorcycle12/configs/2024-04-12T23-30-18-project.yaml +168 -0
  24. pretrained_models/teddybear31/checkpoints/step=000001600.ckpt +3 -0
  25. pretrained_models/teddybear31/configs/2024-04-12T22-50-24-lightning.yaml +31 -0
  26. pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml +168 -0
  27. requirements.txt +37 -0
  28. sampling_for_demo.py +487 -0
  29. scripts.js +147 -0
  30. sgm/__init__.py +4 -0
  31. sgm/data/__init__.py +1 -0
  32. sgm/data/data_co3d.py +762 -0
  33. sgm/lr_scheduler.py +135 -0
  34. sgm/models/__init__.py +2 -0
  35. sgm/models/autoencoder.py +335 -0
  36. sgm/models/diffusion.py +556 -0
  37. sgm/modules/__init__.py +6 -0
  38. sgm/modules/attention.py +1202 -0
  39. sgm/modules/autoencoding/__init__.py +0 -0
  40. sgm/modules/autoencoding/lpips/__init__.py +0 -0
  41. sgm/modules/autoencoding/lpips/loss.py +0 -0
  42. sgm/modules/autoencoding/lpips/loss/LICENSE +23 -0
  43. sgm/modules/autoencoding/lpips/loss/__init__.py +0 -0
  44. sgm/modules/autoencoding/lpips/loss/lpips.py +147 -0
  45. sgm/modules/autoencoding/lpips/model/LICENSE +58 -0
  46. sgm/modules/autoencoding/lpips/model/__init__.py +0 -0
  47. sgm/modules/autoencoding/lpips/model/model.py +88 -0
  48. sgm/modules/autoencoding/lpips/util.py +128 -0
  49. sgm/modules/autoencoding/lpips/vqperceptual.py +17 -0
  50. sgm/modules/autoencoding/regularizers/__init__.py +31 -0
Dockerfile ADDED
@@ -0,0 +1,64 @@
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
+
+ ARG DEBIAN_FRONTEND=noninteractive
+
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     wget \
+     git \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     SYSTEM=spaces
+
+ # Install miniconda
+ RUN mkdir -p /home/user/conda
+ ENV CONDA_DIR /home/user/conda
+ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
+     /bin/bash ~/miniconda.sh -b -p /home/user/conda
+
+ # Put conda in path so we can use conda activate
+ ENV PATH=$CONDA_DIR/bin:$PATH
+
+ # Activate
+ RUN conda init bash
+
+ RUN . /home/user/conda/bin/activate
+
+ # Install dependencies
+ RUN conda create -n pose python=3.8
+ RUN conda activate pose
+
+ RUN pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118
+
+ RUN pip install -r /code/requirements.txt
+
+ RUN conda install -c conda-forge cudatoolkit-dev -y
+ ENV CUDA_HOME=$CONDA_PREFIX/pkgs/cuda-toolkit/
+ RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+
+ RUN wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P /code/pretrained_models
+ RUN wget https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors -P /code/pretrained_models
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python", "app.py"]
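The image above builds its Python environment in a conda env and pre-downloads the SDXL base and VAE weights into /code/pretrained_models. Below is a minimal smoke-test sketch, not part of the commit, that one could run inside the container to confirm the GPU stack and downloaded weights are visible; the `check_environment` helper and the default weights path are assumptions for illustration.

```python
# Hypothetical smoke test for the container environment (not part of this commit).
import os
import torch


def check_environment(weights_dir: str = "/code/pretrained_models") -> None:
    # Confirm the cu118 PyTorch wheel actually sees a GPU at runtime.
    print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

    # PyTorch3D is built from source in the Dockerfile; importing it verifies the build.
    import pytorch3d
    print("pytorch3d:", pytorch3d.__version__)

    # The Dockerfile wgets these two files; the demo expects them in a models directory.
    for name in ["sd_xl_base_1.0.safetensors", "sdxl_vae.safetensors"]:
        path = os.path.join(weights_dir, name)
        print(f"{name}: {'found' if os.path.exists(path) else 'MISSING'} at {path}")


if __name__ == "__main__":
    check_environment()
```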
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Customdiffusion360
- emoji: 🚀
+ title: CustomDiffusion360
+ emoji: 📷
  colorFrom: gray
  colorTo: yellow
- sdk: gradio
- sdk_version: 4.26.0
+ sdk: docker
+ app_port: 7860
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,395 @@
+ import os
+ import gradio as gr
+ import plotly.graph_objects as go
+ import torch
+ import json
+ import glob
+ import numpy as np
+ from PIL import Image
+ import time
+ import tqdm
+ import copy
+
+ # Mesh imports
+ from pytorch3d.io import load_objs_as_meshes
+ from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene
+ from pytorch3d.transforms import Transform3d, RotateAxisAngle, Translate, Rotate
+
+ from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model
+
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+ def transform_mesh(mesh, transform, scale=1.0):
+     mesh = mesh.clone()
+     verts = mesh.verts_packed() * scale
+     verts = transform.transform_points(verts)
+     mesh.offset_verts_(verts - mesh.verts_packed())
+     return mesh
+
+
+ def get_input_pose_fig():
+     global curr_camera_dict
+     global obj_filename
+     global plane_trans
+
+     plane_filename = 'assets/plane.obj'
+
+     mesh_scale = 0.75
+     mesh = load_objs_as_meshes([obj_filename], device=device)
+     mesh.scale_verts_(mesh_scale)
+
+     plane = load_objs_as_meshes([plane_filename], device=device)
+
+     ### plane
+     rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
+     plane = transform_mesh(plane, rotate_x)
+     translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
+     plane = transform_mesh(plane, translate_y)
+
+     fig = plot_scene({
+         "plot": {
+             "object": mesh,
+         },
+     },
+         axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
+         xaxis=dict(range=[-1, 1]),
+         yaxis=dict(range=[-1, 1]),
+         zaxis=dict(range=[-1, 1])
+     )
+
+     plane = plane.detach().cpu()
+     verts = plane.verts_packed()
+     faces = plane.faces_packed()
+
+     fig.add_trace(
+         go.Mesh3d(
+             x=verts[:, 0],
+             y=verts[:, 1],
+             z=verts[:, 2],
+             i=faces[:, 0],
+             j=faces[:, 1],
+             k=faces[:, 2],
+             opacity=0.7,
+             color='gray',
+             hoverinfo='skip',
+         ),
+     )
+
+
+     print("fig: curr camera dict")
+     print(curr_camera_dict)
+     camera_dict = curr_camera_dict
+
+     fig.update_layout(scene=dict(
+         xaxis=dict(showticklabels=True, visible=True),
+         yaxis=dict(showticklabels=True, visible=True),
+         zaxis=dict(showticklabels=True, visible=True),
+     ))
+     # show grid
+     fig.update_layout(scene=dict(
+         xaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
+         yaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
+         zaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
+         bgcolor='#dedede',
+     ))
+
+     fig.update_layout(
+         camera_dict,
+         width=512, height=512,
+     )
+
+     return fig
+
+
+ def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
+     print("prompt is ", prompt)
+     global current_data, current_model
+
+     # run model
+     images = sample(
+         current_model, current_data,
+         num_images=1,
+         prompt=prompt,
+         appendpath="",
+         camera_json=cam_pose_json,
+         train=False,
+         scale=scale,
+         scale_im=scale_im,
+         beta=1.0,
+         num_ref=8,
+         skipreflater=False,
+         num_steps=steps,
+         valid=False,
+         max_images=20,
+         seed=seed
+     )
+
+     result = images[0]
+     print(result.shape)
+     result = Image.fromarray((np.clip(((result+1.0)/2.0).permute(1, 2, 0).cpu().numpy(), 0., 1.)*255).astype(np.uint8))
+     print('result obtained')
+     return result
+
+
+
+ def update_curr_camera_dict(camera_json):
+     # TODO: this does not always update the figure, also there's always flashes
+     global curr_camera_dict
+     global prev_camera_dict
+     if camera_json is None:
+         camera_json = json.dumps(prev_camera_dict)
+     camera_json = camera_json.replace("'", "\"")
+     curr_camera_dict = json.loads(camera_json) # ["scene.camera"]
+     print("update curr camera dict")
+     print(curr_camera_dict)
+     return camera_json
+
+
+ MODELS_DIR = "pretrained-models/"
+
+ def select_and_load_model(category, category_single_id):
+     global current_data, current_model, base_model
+     current_model = None
+     current_model = copy.deepcopy(base_model)
+
+     ### choose model checkpoint and config
+     delta_ckpt = glob.glob(f"{MODELS_DIR}/*{category}{category_single_id}*/checkpoints/step=*.ckpt")[0]
+     print(f"Loading model from {delta_ckpt}")
+
+     logdir = delta_ckpt.split('/checkpoints')[0]
+     config = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))[-1]
+
+     start_time = time.time()
+     current_model, current_data = load_and_return_model_and_data(config, current_model,
+                                                                  delta_ckpt=delta_ckpt
+                                                                  )
+
+     print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")
+
+     print("!!! model loaded")
+
+     input_prompt = f"photo of a <new1> {category}"
+     return "### Model loaded!", input_prompt
+
+
+ global current_data
+ global current_model
+ current_data = None
+ current_model = None
+
+ global base_model
+ BASE_CONFIG = "configs/train_co3d_concept.yaml"
+ BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
+
+ start_time = time.time()
+ base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
+ print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
+
+ global curr_camera_dict
+ curr_camera_dict = {
+     "scene.camera": {
+         "up": {"x": -0.13227683305740356,
+                "y": -0.9911391735076904,
+                "z": -0.013464212417602539},
+         "center": {"x": -0.005292057991027832,
+                    "y": 0.020704858005046844,
+                    "z": 0.0873757004737854},
+         "eye": {"x": 0.8585731983184814,
+                 "y": -0.08790968358516693,
+                 "z": -0.40458938479423523},
+     },
+     "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
+     "scene.aspectmode": "manual"
+ }
+
+ global prev_camera_dict
+ prev_camera_dict = copy.deepcopy(curr_camera_dict)
+
+ global obj_filename
+ obj_filename = "assets/car0_mesh_centered_flipped.obj"
+ global plane_trans
+ plane_trans = 0.16
+
+ my_fig = get_input_pose_fig()
+
+ scripts = open("scripts.js", "r").read()
+
+
+ def update_category_single_id(category):
+     global curr_camera_dict
+     global prev_camera_dict
+     global obj_filename
+     global plane_trans
+     choices = None
+
+     if category == "car":
+         choices = ["0"]
+         curr_camera_dict = {
+             "scene.camera": {
+                 "up": {"x": -0.13227683305740356,
+                        "y": -0.9911391735076904,
+                        "z": -0.013464212417602539},
+                 "center": {"x": -0.005292057991027832,
+                            "y": 0.020704858005046844,
+                            "z": 0.0873757004737854},
+                 "eye": {"x": 0.8585731983184814,
+                         "y": -0.08790968358516693,
+                         "z": -0.40458938479423523},
+             },
+             "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
+             "scene.aspectmode": "manual"
+         }
+         plane_trans = 0.16
+
+     elif category == "chair":
+         choices = ["191"]
+         curr_camera_dict = {
+             "scene.camera": {
+                 "up": {"x": 1.0477e-04,
+                        "y": -9.9995e-01,
+                        "z": 1.0288e-02},
+                 "center": {"x": 0.0539,
+                            "y": 0.0015,
+                            "z": 0.0007},
+                 "eye": {"x": 0.0410,
+                         "y": -0.0091,
+                         "z": -0.9991},
+             },
+             "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
+             "scene.aspectmode": "manual"
+         }
+         plane_trans = 0.38
+
+     elif category == "motorcycle":
+         choices = ["12"]
+         curr_camera_dict = {
+             "scene.camera": {
+                 "up": {"x": 0.0308,
+                        "y": -0.9994,
+                        "z": -0.0147},
+                 "center": {"x": 0.0240,
+                            "y": -0.0310,
+                            "z": -0.0016},
+                 "eye": {"x": -0.0580,
+                         "y": -0.0188,
+                         "z": -0.9981},
+             },
+             "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
+             "scene.aspectmode": "manual"
+         }
+         plane_trans = 0.16
+
+     elif category == "teddybear":
+         choices = ["31"]
+         curr_camera_dict = {
+             "scene.camera": {
+                 "up": {"x": 0.4304,
+                        "y": -0.9023,
+                        "z": -0.0221},
+                 "center": {"x": -0.0658,
+                            "y": 0.2081,
+                            "z": 0.0175},
+                 "eye": {"x": -0.4456,
+                         "y": 0.0493,
+                         "z": -0.8939},
+             },
+             "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
+             "scene.aspectmode": "manual",
+         }
+         plane_trans = 0.23
+
+     obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
+     prev_camera_dict = copy.deepcopy(curr_camera_dict)
+     return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])
+
+
+ head = """
+ <script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
+ """
+
+ ORIGINAL_SPACE_ID = 'customdiffusion360'
+ SPACE_ID = os.getenv('SPACE_ID')
+
+ SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
+ '''
+
+ with gr.Blocks(head=head,
+                css="style.css",
+                js=scripts,
+                title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:
+
+     gr.HTML("""
+         <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+         <div>
+         <h1>Customizing Text-to-Image Diffusion with Camera Viewpoint Control</h1>
+         </div>
+         </div>
+         <div>
+         </br>
+         </div>
+         <hr></hr>
+         """,
+         visible=True
+     )
+
+
+     if SPACE_ID == ORIGINAL_SPACE_ID:
+         gr.Markdown(SHARED_UI_WARNING)
+
+     with gr.Row():
+         with gr.Column(min_width=150):
+             gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
+
+             category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")
+
+             category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)
+
+             category.change(update_category_single_id, [category], [category_single_id])
+
+             load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")
+
+             load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")
+
+         with gr.Column(min_width=512):
+             gr.Markdown("## 2. CAMERA POSE VISUALIZATION")
+
+             # TODO ? don't use gradio plotly element so we can remove menu buttons
+             map = gr.Plot(value=my_fig, min_width=512, elem_id="map")
+
+             ### hidden elements
+             update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
+             input_pose = gr.TextArea(value=curr_camera_dict, label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
+             check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")
+
+             ## TODO: track init_camera_dict and with js?
+
+             ### visible elements
+             input_prompt = gr.Textbox(value="photo of a <new1> car", label="Prompt", interactive=True)
+             scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
+             scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
+             steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
+             seed = gr.Textbox(value=42, label="Seed")
+
+         with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
+             run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)
+
+
+         with gr.Column(min_width=512):
+             gr.Markdown("## 3. OUR OUTPUT")
+             result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")
+
+     load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
+     load_model_btn.click(get_input_pose_fig, [], [map])
+
+     update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],) # js=send_js_camera_to_gradio)
+     # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
+     run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)
+
+     demo.load(js=scripts)
+
+
+ if __name__ == "__main__":
+     demo.queue().launch(debug=True)
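In app.py, run_inference forwards the Plotly camera state to sample() as a JSON string, and update_curr_camera_dict round-trips it with json.loads after normalizing quotes. The sketch below only illustrates the expected payload shape, using the default car viewpoint hard-coded above (values rounded here); in the running demo the dict comes from the Plotly relayout event wired up in scripts.js.

```python
import json

# Same structure as curr_camera_dict in app.py: a Plotly relayout payload with
# the camera orientation ("up", "center", "eye") plus aspect settings.
camera_pose = {
    "scene.camera": {
        "up": {"x": -0.1323, "y": -0.9911, "z": -0.0135},
        "center": {"x": -0.0053, "y": 0.0207, "z": 0.0874},
        "eye": {"x": 0.8586, "y": -0.0879, "z": -0.4046},
    },
    "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
    "scene.aspectmode": "manual",
}

# update_curr_camera_dict replaces single quotes before json.loads, so a
# json.dumps string (double quotes) is already in the form the app expects.
cam_pose_json = json.dumps(camera_pose)
print(cam_pose_json[:80], "...")
```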
assets/car0_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/chair191_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/motorcycle12_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/motorcycle29_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/plane.obj ADDED
@@ -0,0 +1,14 @@
+ o plane
+ v 3.000000 -3.000000 0.000000
+ v 3.000000 3.000000 0.000000
+ v -3.000000 3.000000 0.000000
+ v -3.000000 -3.000000 0.000000
+
+ vt 3.000000 0.000000
+ vt 3.000000 3.000000
+ vt 0.000000 3.000000
+ vt 0.000000 0.000000
+
+ s off
+ f 1/1 2/2 3/3
+ f 1/1 3/3 4/4
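plane.obj is just a 6×6 quad split into two triangles; app.py rotates it 90° about X and shifts it down by a per-category offset so it acts as a ground plane under the object mesh. The short sketch below mirrors that usage (the file path, the 0.16 plane_trans, and the 0.75 mesh scale are taken from app.py; the rest is illustrative, not the repo's own code path):

```python
import torch
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.transforms import RotateAxisAngle, Translate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the two-triangle ground plane shipped in assets/.
plane = load_objs_as_meshes(["assets/plane.obj"], device=device)

# Same transforms app.py applies: stand the XY quad up as a floor (rotate about X),
# then offset it by plane_trans * mesh_scale (0.16 * 0.75 for the car category).
rotate_x = RotateAxisAngle(angle=90.0, axis="X", device=device)
translate_y = Translate(0, 0.16 * 0.75, 0, device=device)

verts = plane.verts_packed()
verts = rotate_x.transform_points(verts)
verts = translate_y.transform_points(verts)
plane = plane.update_padded(verts.unsqueeze(0))  # single-mesh batch
print(plane.verts_packed().shape, plane.faces_packed().shape)
```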
assets/teddybear0_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
assets/teddybear31_mesh_centered_flipped.obj ADDED
The diff for this file is too large to render. See raw diff
 
configs/train_co3d_concept.yaml ADDED
@@ -0,0 +1,198 @@
1
+ model:
2
+ base_learning_rate: 1.0e-4
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: True
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+
15
+ denoiser_config:
16
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
17
+ params:
18
+ num_idx: 1000
19
+
20
+ weighting_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
22
+ scaling_config:
23
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
24
+ discretization_config:
25
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
26
+
27
+ network_config:
28
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
29
+ params:
30
+ adm_in_channels: 2816
31
+ num_classes: sequential
32
+ use_checkpoint: False
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [4, 2]
37
+ num_res_blocks: 2
38
+ channel_mult: [1, 2, 4]
39
+ num_head_channels: 64
40
+ use_linear_in_transformer: True
41
+ transformer_depth: [1, 2, 10]
42
+ context_dim: 2048
43
+ spatial_transformer_attn_type: softmax-xformers
44
+ image_cross_blocks: [0, 2, 4, 6, 8, 10]
45
+ rgb: True
46
+ far: 2
47
+ num_samples: 24
48
+ not_add_context_in_triplane: False
49
+ rgb_predict: True
50
+ add_lora: False
51
+ average: False
52
+ use_prev_weights_imp_sample: True
53
+ stratified: True
54
+ imp_sampling_percent: 0.9
55
+
56
+ conditioner_config:
57
+ target: sgm.modules.GeneralConditioner
58
+ params:
59
+ emb_models:
60
+ # crossattn cond
61
+ - is_trainable: False
62
+ input_keys: txt,txt_ref
63
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
64
+ params:
65
+ layer: hidden
66
+ layer_idx: 11
67
+ modifier_token: <new1>
68
+ # crossattn and vector cond
69
+ - is_trainable: False
70
+ input_keys: txt,txt_ref
71
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
72
+ params:
73
+ arch: ViT-bigG-14
74
+ version: laion2b_s39b_b160k
75
+ layer: penultimate
76
+ always_return_pooled: True
77
+ legacy: False
78
+ modifier_token: <new1>
79
+ # vector cond
80
+ - is_trainable: False
81
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
82
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
83
+ params:
84
+ outdim: 256 # multiplied by two
85
+ # vector cond
86
+ - is_trainable: False
87
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
88
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
89
+ params:
90
+ outdim: 256 # multiplied by two
91
+ # vector cond
92
+ - is_trainable: False
93
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
94
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
95
+ params:
96
+ outdim: 256 # multiplied by two
97
+
98
+ first_stage_config:
99
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
100
+ params:
101
+ ckpt_path: pretrained-models/sdxl_vae.safetensors
102
+ embed_dim: 4
103
+ monitor: val/rec_loss
104
+ ddconfig:
105
+ attn_type: vanilla-xformers
106
+ double_z: true
107
+ z_channels: 4
108
+ resolution: 256
109
+ in_channels: 3
110
+ out_ch: 3
111
+ ch: 128
112
+ ch_mult: [1, 2, 4, 4]
113
+ num_res_blocks: 2
114
+ attn_resolutions: []
115
+ dropout: 0.0
116
+ lossconfig:
117
+ target: torch.nn.Identity
118
+
119
+ loss_fn_config:
120
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
121
+ params:
122
+ sigma_sampler_config:
123
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
124
+ params:
125
+ num_idx: 1000
126
+ discretization_config:
127
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
128
+ sigma_sampler_config_ref:
129
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
130
+ params:
131
+ num_idx: 50
132
+
133
+ discretization_config:
134
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
135
+
136
+ sampler_config:
137
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
138
+ params:
139
+ num_steps: 50
140
+
141
+ discretization_config:
142
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
143
+
144
+ guider_config:
145
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
146
+ params:
147
+ scale: 7.5
148
+
149
+ data:
150
+ target: sgm.data.data_co3d.CustomDataDictLoader
151
+ params:
152
+ batch_size: 1
153
+ num_workers: 4
154
+ category: teddybear
155
+ img_size: 512
156
+ skip: 2
157
+ num_images: 5
158
+ mask_images: True
159
+ single_id: 0
160
+ bbox: True
161
+ addreg: True
162
+ drop_ratio: 0.25
163
+ drop_txt: 0.1
164
+ modifier_token: <new1>
165
+
166
+ lightning:
167
+ modelcheckpoint:
168
+ params:
169
+ every_n_train_steps: 1600
170
+ save_top_k: -1
171
+ save_on_train_epoch_end: False
172
+
173
+ callbacks:
174
+ metrics_over_trainsteps_checkpoint:
175
+ params:
176
+ every_n_train_steps: 25000
177
+
178
+ image_logger:
179
+ target: main.ImageLogger
180
+ params:
181
+ disabled: False
182
+ enable_autocast: False
183
+ batch_frequency: 5000
184
+ max_images: 8
185
+ increase_log_steps: False
186
+ log_first_step: False
187
+ log_images_kwargs:
188
+ use_ema_scope: False
189
+ N: 1
190
+ n_rows: 2
191
+
192
+ trainer:
193
+ devices: 0,1,2,3
194
+ benchmark: True
195
+ num_sanity_val_steps: 0
196
+ accumulate_grad_batches: 1
197
+ max_steps: 1610
198
+ # val_check_interval: 400
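Like the other YAMLs in this commit, the config above is a tree of target/params entries that name the class to build and its constructor keyword arguments, the convention used throughout the sgm package. The helper below is a generic illustration of that pattern written from scratch for this note, not the repository's own loader, and it assumes the usual sgm nesting of sampler_config under model.params:

```python
import importlib
from omegaconf import OmegaConf


def build_from_config(cfg: dict):
    """Instantiate `target` with `params` as keyword arguments (illustrative helper)."""
    module_name, _, class_name = cfg["target"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**cfg.get("params", {}))


# Example: construct just the EulerEDMSampler subtree of this config
# (requires the sgm package from this repo to be importable).
cfg = OmegaConf.load("configs/train_co3d_concept.yaml")
sampler_cfg = OmegaConf.to_container(cfg.model.params.sampler_config, resolve=True)
sampler = build_from_config(sampler_cfg)
print(type(sampler).__name__)
```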
pretrained_models/car0/checkpoints/step=000001600.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b073c96fe525f9530dc69e5fd8e94d6527a7651c5bb4ede5953750fbe157ebd
+ size 777852660
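The checkpoint is stored as a Git LFS pointer: the three lines above record only the object's sha256 and byte size, and the actual ~778 MB file is fetched by git-lfs at clone time. A hedged sketch of verifying a downloaded checkpoint against such a pointer follows; the pointer-parsing helper is written for illustration and is not a git-lfs API.

```python
import hashlib
import os


def read_lfs_pointer(pointer_path: str) -> dict:
    # A pointer file is plain text: "version ...", "oid sha256:<hex>", "size <bytes>".
    fields = dict(line.split(" ", 1) for line in open(pointer_path).read().split("\n") if line)
    return {"oid": fields["oid"].split(":", 1)[1].strip(), "size": int(fields["size"])}


def verify_checkpoint(ckpt_path: str, pointer_path: str) -> bool:
    meta = read_lfs_pointer(pointer_path)
    if os.path.getsize(ckpt_path) != meta["size"]:
        return False
    sha = hashlib.sha256()
    with open(ckpt_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    return sha.hexdigest() == meta["oid"]


# e.g. verify_checkpoint("step=000001600.ckpt", "step=000001600.ckpt.pointer")
```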
pretrained_models/car0/configs/2024-04-12T21-30-20-lightning.yaml ADDED
@@ -0,0 +1,31 @@
1
+ lightning:
2
+ modelcheckpoint:
3
+ params:
4
+ every_n_train_steps: 1600
5
+ save_top_k: -1
6
+ save_on_train_epoch_end: false
7
+ callbacks:
8
+ metrics_over_trainsteps_checkpoint:
9
+ params:
10
+ every_n_train_steps: 25000
11
+ image_logger:
12
+ target: main.ImageLogger
13
+ params:
14
+ disabled: false
15
+ enable_autocast: false
16
+ batch_frequency: 5000
17
+ max_images: 8
18
+ increase_log_steps: false
19
+ log_first_step: false
20
+ log_images_kwargs:
21
+ use_ema_scope: false
22
+ 'N': 1
23
+ n_rows: 2
24
+ trainer:
25
+ devices: 0,1,2,3
26
+ benchmark: true
27
+ num_sanity_val_steps: 0
28
+ accumulate_grad_batches: 1
29
+ max_steps: 1610
30
+ val_check_interval: 400
31
+ accelerator: gpu
pretrained_models/car0/configs/2024-04-12T21-30-20-project.yaml ADDED
@@ -0,0 +1,168 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
66
+ use_prev_weights_imp_sample: true
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: car
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 0
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
pretrained_models/car0/configs/2024-04-13T11-31-55-lightning.yaml ADDED
@@ -0,0 +1,31 @@
1
+ lightning:
2
+ modelcheckpoint:
3
+ params:
4
+ every_n_train_steps: 1600
5
+ save_top_k: -1
6
+ save_on_train_epoch_end: false
7
+ callbacks:
8
+ metrics_over_trainsteps_checkpoint:
9
+ params:
10
+ every_n_train_steps: 25000
11
+ image_logger:
12
+ target: main.ImageLogger
13
+ params:
14
+ disabled: false
15
+ enable_autocast: false
16
+ batch_frequency: 5000
17
+ max_images: 8
18
+ increase_log_steps: false
19
+ log_first_step: false
20
+ log_images_kwargs:
21
+ use_ema_scope: false
22
+ 'N': 1
23
+ n_rows: 2
24
+ trainer:
25
+ devices: 0,1,2,3
26
+ benchmark: true
27
+ num_sanity_val_steps: 0
28
+ accumulate_grad_batches: 1
29
+ max_steps: 1610
30
+ val_check_interval: 400
31
+ accelerator: gpu
pretrained_models/car0/configs/2024-04-13T11-31-55-project.yaml ADDED
@@ -0,0 +1,170 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
66
+ use_prev_weights_imp_sample: true
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: car
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 0
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
169
+ --log_dir: null
170
+ check_logs: null
pretrained_models/car0/configs/2024-04-13T11-42-30-lightning.yaml ADDED
@@ -0,0 +1,31 @@
1
+ lightning:
2
+ modelcheckpoint:
3
+ params:
4
+ every_n_train_steps: 1600
5
+ save_top_k: -1
6
+ save_on_train_epoch_end: false
7
+ callbacks:
8
+ metrics_over_trainsteps_checkpoint:
9
+ params:
10
+ every_n_train_steps: 25000
11
+ image_logger:
12
+ target: main.ImageLogger
13
+ params:
14
+ disabled: false
15
+ enable_autocast: false
16
+ batch_frequency: 5000
17
+ max_images: 8
18
+ increase_log_steps: false
19
+ log_first_step: false
20
+ log_images_kwargs:
21
+ use_ema_scope: false
22
+ 'N': 1
23
+ n_rows: 2
24
+ trainer:
25
+ devices: 0,1,2,3
26
+ benchmark: true
27
+ num_sanity_val_steps: 0
28
+ accumulate_grad_batches: 1
29
+ max_steps: 1610
30
+ val_check_interval: 400
31
+ accelerator: gpu
pretrained_models/car0/configs/2024-04-13T11-42-30-project.yaml ADDED
@@ -0,0 +1,170 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
66
+ use_prev_weights_imp_sample: true
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: car
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 0
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
169
+ --log_dir: null
170
+ check_logs: null
pretrained_models/chair191/checkpoints/step=000001600.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d83a24db919f95ea487572b45524ae1073a1638612a751ecb9589d2060e9b991
+ size 777852660
pretrained_models/chair191/configs/2024-04-12T22-10-18-lightning.yaml ADDED
@@ -0,0 +1,31 @@
1
+ lightning:
2
+ modelcheckpoint:
3
+ params:
4
+ every_n_train_steps: 1600
5
+ save_top_k: -1
6
+ save_on_train_epoch_end: false
7
+ callbacks:
8
+ metrics_over_trainsteps_checkpoint:
9
+ params:
10
+ every_n_train_steps: 25000
11
+ image_logger:
12
+ target: main.ImageLogger
13
+ params:
14
+ disabled: false
15
+ enable_autocast: false
16
+ batch_frequency: 5000
17
+ max_images: 8
18
+ increase_log_steps: false
19
+ log_first_step: false
20
+ log_images_kwargs:
21
+ use_ema_scope: false
22
+ 'N': 1
23
+ n_rows: 2
24
+ trainer:
25
+ devices: 0,1,2,3
26
+ benchmark: true
27
+ num_sanity_val_steps: 0
28
+ accumulate_grad_batches: 1
29
+ max_steps: 1610
30
+ val_check_interval: 400
31
+ accelerator: gpu
pretrained_models/chair191/configs/2024-04-12T22-10-18-project.yaml ADDED
@@ -0,0 +1,168 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
66
+ use_prev_weights_imp_sample: true
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: /sensei-fs/tenants/Sensei-AdobeResearchTeam/nupkumar1/custom-pose/pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: chair
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 191
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
pretrained_models/motorcycle12/checkpoints/step=000001600.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11acc84e7c6fbbc9b47f7021dcdaa55032c1e48b88b6bc9a8bb8689f59521c99
+ size 777852660
pretrained_models/motorcycle12/configs/2024-04-12T23-30-18-project.yaml ADDED
@@ -0,0 +1,168 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
66
+ use_prev_weights_imp_sample: true
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: motorcycle
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 12
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
pretrained_models/teddybear31/checkpoints/step=000001600.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c474759a84423022acf8eddadb58c5dcb49a9c80077b8f08f3a137f44e5eb76
3
+ size 777852660
pretrained_models/teddybear31/configs/2024-04-12T22-50-24-lightning.yaml ADDED
@@ -0,0 +1,31 @@
1
+ lightning:
2
+ modelcheckpoint:
3
+ params:
4
+ every_n_train_steps: 1600
5
+ save_top_k: -1
6
+ save_on_train_epoch_end: false
7
+ callbacks:
8
+ metrics_over_trainsteps_checkpoint:
9
+ params:
10
+ every_n_train_steps: 25000
11
+ image_logger:
12
+ target: main.ImageLogger
13
+ params:
14
+ disabled: false
15
+ enable_autocast: false
16
+ batch_frequency: 5000
17
+ max_images: 8
18
+ increase_log_steps: false
19
+ log_first_step: false
20
+ log_images_kwargs:
21
+ use_ema_scope: false
22
+ 'N': 1
23
+ n_rows: 2
24
+ trainer:
25
+ devices: 0,1,2,3
26
+ benchmark: true
27
+ num_sanity_val_steps: 0
28
+ accumulate_grad_batches: 1
29
+ max_steps: 1610
30
+ val_check_interval: 400
31
+ accelerator: gpu
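For reference, the trainer block above maps roughly onto the following PyTorch Lightning 2.x Trainer arguments. This is a hedged sketch only; the actual wiring happens in main.py (referenced above via main.ImageLogger), which is not shown here:

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices="0,1,2,3",            # four GPUs, as listed in the config
    benchmark=True,
    num_sanity_val_steps=0,
    accumulate_grad_batches=1,
    max_steps=1610,
    val_check_interval=400,
)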
pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml ADDED
@@ -0,0 +1,168 @@
1
+ model:
2
+ base_learning_rate: 0.0001
3
+ target: sgm.models.diffusion.DiffusionEngine
4
+ params:
5
+ scale_factor: 0.13025
6
+ disable_first_stage_autocast: true
7
+ trainkeys: pose
8
+ multiplier: 0.05
9
+ loss_rgb_lambda: 5
10
+ loss_fg_lambda: 10
11
+ loss_bg_lambda: 10
12
+ log_keys:
13
+ - txt
14
+ denoiser_config:
15
+ target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
16
+ params:
17
+ num_idx: 1000
18
+ weighting_config:
19
+ target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
20
+ scaling_config:
21
+ target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
22
+ discretization_config:
23
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
24
+ network_config:
25
+ target: sgm.modules.diffusionmodules.openaimodel.UNetModel
26
+ params:
27
+ adm_in_channels: 2816
28
+ num_classes: sequential
29
+ use_checkpoint: false
30
+ in_channels: 4
31
+ out_channels: 4
32
+ model_channels: 320
33
+ attention_resolutions:
34
+ - 4
35
+ - 2
36
+ num_res_blocks: 2
37
+ channel_mult:
38
+ - 1
39
+ - 2
40
+ - 4
41
+ num_head_channels: 64
42
+ use_linear_in_transformer: true
43
+ transformer_depth:
44
+ - 1
45
+ - 2
46
+ - 10
47
+ context_dim: 2048
48
+ spatial_transformer_attn_type: softmax-xformers
49
+ image_cross_blocks:
50
+ - 0
51
+ - 2
52
+ - 4
53
+ - 6
54
+ - 8
55
+ - 10
56
+ rgb: true
57
+ far: 2
58
+ num_samples: 24
59
+ not_add_context_in_triplane: false
60
+ rgb_predict: true
61
+ add_lora: false
62
+ average: false
63
+ use_prev_weights_imp_sample: true
64
+ stratified: true
65
+ imp_sampling_percent: 0.9
67
+ conditioner_config:
68
+ target: sgm.modules.GeneralConditioner
69
+ params:
70
+ emb_models:
71
+ - is_trainable: false
72
+ input_keys: txt,txt_ref
73
+ target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
74
+ params:
75
+ layer: hidden
76
+ layer_idx: 11
77
+ modifier_token: <new1>
78
+ - is_trainable: false
79
+ input_keys: txt,txt_ref
80
+ target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
81
+ params:
82
+ arch: ViT-bigG-14
83
+ version: laion2b_s39b_b160k
84
+ layer: penultimate
85
+ always_return_pooled: true
86
+ legacy: false
87
+ modifier_token: <new1>
88
+ - is_trainable: false
89
+ input_keys: original_size_as_tuple,original_size_as_tuple_ref
90
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
91
+ params:
92
+ outdim: 256
93
+ - is_trainable: false
94
+ input_keys: crop_coords_top_left,crop_coords_top_left_ref
95
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
96
+ params:
97
+ outdim: 256
98
+ - is_trainable: false
99
+ input_keys: target_size_as_tuple,target_size_as_tuple_ref
100
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
101
+ params:
102
+ outdim: 256
103
+ first_stage_config:
104
+ target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
105
+ params:
106
+ ckpt_path: pretrained-models/sdxl_vae.safetensors
107
+ embed_dim: 4
108
+ monitor: val/rec_loss
109
+ ddconfig:
110
+ attn_type: vanilla-xformers
111
+ double_z: true
112
+ z_channels: 4
113
+ resolution: 256
114
+ in_channels: 3
115
+ out_ch: 3
116
+ ch: 128
117
+ ch_mult:
118
+ - 1
119
+ - 2
120
+ - 4
121
+ - 4
122
+ num_res_blocks: 2
123
+ attn_resolutions: []
124
+ dropout: 0.0
125
+ lossconfig:
126
+ target: torch.nn.Identity
127
+ loss_fn_config:
128
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef
129
+ params:
130
+ sigma_sampler_config:
131
+ target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling
132
+ params:
133
+ num_idx: 1000
134
+ discretization_config:
135
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
136
+ sigma_sampler_config_ref:
137
+ target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
138
+ params:
139
+ num_idx: 50
140
+ discretization_config:
141
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
142
+ sampler_config:
143
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
144
+ params:
145
+ num_steps: 50
146
+ discretization_config:
147
+ target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
148
+ guider_config:
149
+ target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef
150
+ params:
151
+ scale: 7.5
152
+ data:
153
+ target: sgm.data.data_co3d.CustomDataDictLoader
154
+ params:
155
+ batch_size: 1
156
+ num_workers: 4
157
+ category: teddybear
158
+ img_size: 512
159
+ skip: 2
160
+ num_images: 5
161
+ mask_images: true
162
+ single_id: 31
163
+ bbox: true
164
+ addreg: true
165
+ drop_ratio: 0.25
166
+ drop_txt: 0.1
167
+ modifier_token: <new1>
168
+ categoryname: null
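The guider above (VanillaCFGImgRef with scale 7.5) is a classifier-free guidance variant. Ignoring the image-reference handling that the ImgRef guiders add, the core combination step looks like this simplified sketch:

import torch

def cfg_combine(uncond: torch.Tensor, cond: torch.Tensor, scale: float = 7.5) -> torch.Tensor:
    # Standard classifier-free guidance: move the prediction away from the
    # unconditional branch in the direction of the conditional branch.
    return uncond + scale * (cond - uncond)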
requirements.txt ADDED
@@ -0,0 +1,37 @@
1
+ omegaconf
2
+ einops
3
+ fire
4
+ tqdm
5
+ pillow
6
+ numpy
7
+ webdataset>=0.2.33
8
+ ninja
9
+ matplotlib
10
+ torchmetrics
11
+ opencv-python==4.6.0.66
12
+ fairscale
13
+ pytorch-lightning==2.0.1
15
+ fsspec
16
+ kornia==0.6.9
17
+ natsort
18
+ open-clip-torch
19
+ chardet==5.1.0
20
+ tensorboardx==2.6
21
+ pandas
22
+ pudb
23
+ pyyaml
24
+ urllib3<1.27,>=1.25.4
25
+ scipy
26
+ streamlit>=0.73.1
27
+ timm
28
+ tokenizers==0.12.1
29
+ transformers==4.19.1
30
+ triton==2.1.0
31
+ torchdata==0.7.0
32
+ wandb
33
+ invisible-watermark
34
+ xformers
35
+ -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
36
+ -e git+https://github.com/openai/CLIP.git@main#egg=clip
37
+ -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
sampling_for_demo.py ADDED
@@ -0,0 +1,487 @@
1
+ import glob
2
+ import os
3
+ import sys
4
+ import copy
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ import torch
9
+ from einops import rearrange
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from pytorch_lightning import seed_everything
13
+ from pytorch3d.renderer.cameras import PerspectiveCameras
14
+ from pytorch3d.renderer import look_at_view_transform
15
+ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
16
+
17
+ import json
18
+
19
+ sys.path.append('./')
20
+ from sgm.util import instantiate_from_config, load_safetensors
21
+
22
+ choices = []
23
+
24
+ def append_dims(x, target_dims):
25
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
26
+ dims_to_append = target_dims - x.ndim
27
+ if dims_to_append < 0:
28
+ raise ValueError(
29
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
30
+ )
31
+ return x[(...,) + (None,) * dims_to_append]
32
+
33
+
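+ # A quick, hedged illustration of append_dims (shapes below are arbitrary):
+ # sigma = torch.randn(4)                  # one sigma per batch element
+ # x = torch.randn(4, 3, 64, 64)
+ # sigma_b = append_dims(sigma, x.ndim)    # shape (4, 1, 1, 1), broadcastable against x
+ # noised = x + sigma_b * torch.randn_like(x)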
34
+ def load_base_model(config, ckpt=None, verbose=True):
35
+ config = OmegaConf.load(config)
36
+ # load model
37
+ config.model.params.network_config.params.far = 3
38
+ config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
39
+ guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
40
+ 'params': {'scale': 7.5, 'scale_im': 3.5}
41
+ }
42
+ config.model.params.sampler_config.params.guider_config = guider_config
43
+
44
+ model = instantiate_from_config(config.model)
45
+
46
+ if ckpt is not None:
47
+ print(f"Loading model from {ckpt}")
48
+ if ckpt.endswith("ckpt"):
49
+ pl_sd = torch.load(ckpt, map_location="cpu")
50
+ if "global_step" in pl_sd:
51
+ print(f"Global Step: {pl_sd['global_step']}")
52
+ sd = pl_sd["state_dict"]
53
+ elif ckpt.endswith("safetensors"):
54
+ sd = load_safetensors(ckpt)
55
+ if 'modifier_token' in config.data.params:
56
+ del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
57
+ del sd['conditioner.embedders.1.model.token_embedding.weight']
58
+ else:
59
+ raise NotImplementedError
60
+
61
+ m, u = model.load_state_dict(sd, strict=False)
62
+
63
+ model.cuda()
64
+ model.eval()
65
+ return model
66
+
67
+
68
+ def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
69
+ """
70
+ model is preloaded base stable diffusion model
71
+ """
72
+
73
+ msg = None
74
+ if delta_ckpt is not None:
75
+ pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
76
+ sd_delta = pl_sd_delta["delta_state_dict"]
77
+
78
+ # TODO: add new delta loading embedding stuff?
79
+
80
+ for name, module in model.model.diffusion_model.named_modules():
81
+ if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
82
+ if hasattr(module, 'pose_emb_layers'):
83
+ module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
84
+ del sd_delta[f'model.diffusion_model.{name}.references']
85
+
86
+ m, u = model.load_state_dict(sd_delta, strict=False)
87
+
88
+
89
+ if len(m) > 0 and verbose:
90
+ print("missing keys:", m)
91
+ if len(u) > 0 and verbose:
92
+ print("unexpected keys:", u)
93
+
94
+ if freeze:
95
+ for param in model.parameters():
96
+ param.requires_grad = False
97
+
98
+ model.cuda()
99
+ model.eval()
100
+ return model, msg
101
+
102
+
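+ # Hedged usage sketch for the two loaders above; the base-model checkpoint path is a
+ # placeholder, while the config and delta paths follow the pretrained_models layout in this repo:
+ # base = load_base_model(
+ #     "pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml",
+ #     ckpt="path/to/sd_xl_base_1.0.safetensors",   # placeholder
+ # )
+ # model, _ = load_delta_model(base, delta_ckpt="pretrained_models/teddybear31/checkpoints/step=000001600.ckpt")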
103
+ def get_unique_embedder_keys_from_conditioner(conditioner):
104
+ p = [x.input_keys for x in conditioner.embedders]
105
+ return list(set([item for sublist in p for item in sublist])) + ['jpg_ref']
106
+
107
+
108
+ def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None, drop_im=None):
109
+ # note: if no context is given, cross-attention defaults to self-attention
110
+ if not isinstance(context, list):
111
+ context = [context]
112
+ b, c, h, w = x.shape
113
+ x_in = x
114
+ fg_masks = []
115
+ alphas = []
116
+ rgbs = []
117
+
118
+ x = self.norm(x)
119
+
120
+ if not self.use_linear:
121
+ x = self.proj_in(x)
122
+
123
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
124
+ if self.use_linear:
125
+ x = self.proj_in(x)
126
+
127
+ prev_weights = None
128
+ counter = 0
129
+ for i, block in enumerate(self.transformer_blocks):
130
+ if i > 0 and len(context) == 1:
131
+ i = 0 # use same context for each block
132
+ if self.image_cross and (counter % self.poscontrol_interval == 0):
133
+ x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
134
+ prev_weights = weights
135
+ fg_masks.append(fg_mask)
136
+ if alpha is not None:
137
+ alphas.append(alpha)
138
+ if rgb is not None:
139
+ rgbs.append(rgb)
140
+ else:
141
+ x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
142
+ counter += 1
143
+ if self.use_linear:
144
+ x = self.proj_out(x)
145
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
146
+ if not self.use_linear:
147
+ x = self.proj_out(x)
148
+ if len(fg_masks) > 0:
149
+ if len(rgbs) <= 0:
150
+ rgbs = None
151
+ if len(alphas) <= 0:
152
+ alphas = None
153
+ return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
154
+ else:
155
+ return x + x_in, None, None, prev_weights, None, None
156
+
157
+
158
+ def _customforward(
159
+ self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
160
+ ):
161
+ if context_ref is not None:
162
+ global choices
163
+ batch_size = x.size(0)
164
+ # IP2P like sampling or default sampling
165
+ if batch_size % 3 == 0:
166
+ batch_size = batch_size // 3
167
+ context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
168
+ context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0)
169
+ else:
170
+ batch_size = batch_size // 2
171
+ context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
172
+ context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0)
173
+
174
+ fg_mask = None
175
+ weights = None
176
+ alphas = None
177
+ predicted_rgb = None
178
+
179
+ x = (
180
+ self.attn1(
181
+ self.norm1(x),
182
+ context=context if self.disable_self_attn else None,
183
+ additional_tokens=additional_tokens,
184
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
185
+ if not self.disable_self_attn
186
+ else 0,
187
+ )
188
+ + x
189
+ )
190
+
191
+ x = (
192
+ self.attn2(
193
+ self.norm2(x), context=context, additional_tokens=additional_tokens,
194
+ )
195
+ + x
196
+ )
197
+
198
+ if context_ref is not None:
199
+ if self.rendered_feat is not None:
200
+ x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
201
+ else:
202
+ xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
203
+ context_ref,
204
+ context,
205
+ pose,
206
+ prev_weights,
207
+ mask_ref)
208
+ self.rendered_feat = xref
209
+ x = self.pose_emb_layers(torch.cat([x, xref], -1))
210
+
211
+ x = self.ff(self.norm3(x)) + x
212
+ return x, fg_mask, weights, alphas, predicted_rgb
213
+
214
+
215
+ def log_images(
216
+ model,
217
+ batch,
218
+ N: int = 1,
219
+ noise=None,
220
+ scale_im=3.5,
221
+ num_steps: int = 10,
222
+ ucg_keys: List[str] = None,
223
+ **kwargs,
224
+ ):
225
+
226
+ log = dict()
227
+ conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
228
+ ucg_keys = conditioner_input_keys
229
+ pose = batch['pose']
230
+
231
+ c, uc = model.conditioner.get_unconditional_conditioning(
232
+ batch,
233
+ force_uc_zero_embeddings=ucg_keys
234
+ if len(model.conditioner.embedders) > 0
235
+ else [],
236
+ force_ref_zero_embeddings=True
237
+ )
238
+
239
+ _, n = 1, len(pose)-1
240
+ sampling_kwargs = {}
241
+
242
+ if scale_im > 0:
243
+ if uc is not None:
244
+ if isinstance(pose, list):
245
+ pose = pose[:N]*3
246
+ else:
247
+ pose = torch.cat([pose[:N]] * 3)
248
+ else:
249
+ if uc is not None:
250
+ if isinstance(pose, list):
251
+ pose = pose[:N]*2
252
+ else:
253
+ pose = torch.cat([pose[:N]] * 2)
254
+
255
+ sampling_kwargs['pose'] = pose
256
+ sampling_kwargs['drop_im'] = None
257
+ sampling_kwargs['mask_ref'] = None
258
+
259
+ for k in c:
260
+ if isinstance(c[k], torch.Tensor):
261
+ c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to('cuda'), (c, uc))
262
+
263
+ import time
264
+ st = time.time()
265
+ with model.ema_scope("Plotting"):
266
+ samples = model.sample(
267
+ c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
268
+ )
269
+ model.clear_rendered_feat()
270
+ samples = model.decode_first_stage(samples)
271
+ print("Time taken for sampling", time.time() - st)
272
+ log["samples"] = samples.cpu()
273
+
274
+ return log
275
+
276
+
277
+ def process_camera_json(camera_json, example_cam):
278
+ # replace all single quotes in the camera_json with double quotes
279
+ camera_json = camera_json.replace("'", "\"")
280
+ print("input camera json")
281
+ print(camera_json)
282
+
283
+ camera_dict = json.loads(camera_json)["scene.camera"]
284
+ eye = torch.tensor([camera_dict["eye"]["x"], camera_dict["eye"]["y"], camera_dict["eye"]["z"]], dtype=torch.float32).unsqueeze(0)
285
+ up = torch.tensor([camera_dict["up"]["x"], camera_dict["up"]["y"], camera_dict["up"]["z"]], dtype=torch.float32).unsqueeze(0)
286
+ center = torch.tensor([camera_dict["center"]["x"], camera_dict["center"]["y"], camera_dict["center"]["z"]], dtype=torch.float32).unsqueeze(0)
287
+ new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)
288
+
289
+ ## temp
290
+ # new_R = torch.tensor([[[ 0.4988, 0.2666, 0.8247],
291
+ # [-0.1917, -0.8940, 0.4049],
292
+ # [ 0.8453, -0.3601, -0.3948]]], dtype=torch.float32)
293
+ # new_T = torch.tensor([[ 0.0739, -0.0013, 0.9973]], dtype=torch.float32)
294
+
295
+
296
+ # new_R = torch.tensor([[[ 0.2530, 0.2989, 0.9201],
297
+ # [-0.2652, -0.8932, 0.3631],
298
+ # [ 0.9304, -0.3359, -0.1467],]], dtype=torch.float32)
299
+ # new_T = torch.tensor([[ 0.0081, 0.0337, 1.0452]], dtype=torch.float32)
300
+
301
+
302
+ print("focal length", example_cam.focal_length)
303
+ print("principal point", example_cam.principal_point)
304
+
305
+ newcam = PerspectiveCameras(R=new_R,
306
+ T=new_T,
307
+ focal_length=example_cam.focal_length,
308
+ principal_point=example_cam.principal_point,
309
+ image_size=512)
310
+
311
+ print("input pose")
312
+ print(newcam.get_world_to_view_transform().get_matrix())
313
+ return newcam
314
+
315
+
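+ # Example of the JSON this function expects (single quotes are tolerated and rewritten to
+ # double quotes above); the numbers are illustrative only, based on the default view in scripts.js:
+ # camera_json = ('{"scene.camera": {"up": {"x": -0.13, "y": -0.99, "z": -0.01}, '
+ #                '"center": {"x": 0.0, "y": 0.02, "z": 0.09}, '
+ #                '"eye": {"x": 0.86, "y": -0.09, "z": -0.40}}}')
+ # newcam = process_camera_json(camera_json, example_cam)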
316
+ def load_and_return_model_and_data(config, model,
317
+ ckpt="/data/gdsu/customization3d/stable-diffusion-xl-base-1.0/sd_xl_base_1.0.safetensors",
318
+ delta_ckpt=None,
319
+ train=False,
320
+ valid=False,
321
+ far=3,
322
+ num_images=1,
323
+ num_ref=8,
324
+ max_images=20,
325
+ ):
326
+ config = OmegaConf.load(config)
327
+ # load data
328
+ data = None
329
+ # config.data.params.jitter = False
330
+ # config.data.params.addreg = False
331
+ # config.data.params.bbox = False
332
+
333
+ # data = instantiate_from_config(config.data)
334
+ # data = data.train_dataset
335
+
336
+ # single_id = data.single_id
337
+
338
+ # if hasattr(data, 'rotations'):
339
+ # total_images = len(data.rotations[data.sequence_list[single_id]])
340
+ # else:
341
+ # total_images = len(data.annotations['chair'])
342
+ # print(f"Total images in dataset: {total_images}")
343
+
344
+ model, msg = load_delta_model(model, delta_ckpt,)
345
+
346
+ # change forward methods to store rendered features and use the pre-calculated reference features
347
+ def register_recr(net_):
348
+ if net_.__class__.__name__ == 'SpatialTransformer':
349
+ print(net_.__class__.__name__, "adding control")
350
+ bound_method = customforward.__get__(net_, net_.__class__)
351
+ setattr(net_, 'forward', bound_method)
352
+ return
353
+ elif hasattr(net_, 'children'):
354
+ for net__ in net_.children():
355
+ register_recr(net__)
356
+ return
357
+
358
+ def register_recr2(net_):
359
+ if net_.__class__.__name__ == 'BasicTransformerBlock':
360
+ print(net_.__class__.__name__, "adding control")
361
+ bound_method = _customforward.__get__(net_, net_.__class__)
362
+ setattr(net_, 'forward', bound_method)
363
+ return
364
+ elif hasattr(net_, 'children'):
365
+ for net__ in net_.children():
366
+ register_recr2(net__)
367
+ return
368
+
369
+ sub_nets = model.model.diffusion_model.named_children()
370
+ for net in sub_nets:
371
+ register_recr(net[1])
372
+ register_recr2(net[1])
373
+
374
+ # start sampling
375
+ model.clear_rendered_feat()
376
+
377
+ return model, data
378
+
379
+
380
+ def sample(model, data,
381
+ num_images=1,
382
+ prompt="",
383
+ appendpath="",
384
+ camera_json=None,
385
+ train=False,
386
+ scale=7.5,
387
+ scale_im=3.5,
388
+ beta=1.0,
389
+ num_ref=8,
390
+ skipreflater=False,
391
+ num_steps=10,
392
+ valid=False,
393
+ max_images=20,
394
+ seed=42,
395
+ camera_path="pretrained-models/car0/camera.bin",
396
+ ):
397
+
398
+ """
399
+ Only works with num_images=1 (because of camera_json processing)
400
+ """
401
+
402
+ if num_images != 1:
403
+ print("forcing num_images to be 1")
404
+ num_images = 1
405
+
406
+ # set guidance scales
407
+ model.sampler.guider.scale_im = scale_im
408
+ model.sampler.guider.scale = scale
409
+
410
+ seed_everything(seed)
411
+
412
+ # load cameras
413
+ cameras_val, cameras_train = torch.load(camera_path)
414
+ global choices
415
+ num_ref = 8
416
+ max_diff = len(cameras_train)/num_ref
417
+ choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
418
+ cameras_train_final = [cameras_train[i] for i in choices]
419
+
420
+ # start sampling
421
+ model.clear_rendered_feat()
422
+
423
+ if prompt == "":
424
+ prompt = None
425
+
426
+ noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)
427
+
428
+ # random sample camera poses
429
+ pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
430
+ print(pose_ids)
431
+ pose_ids[0] = 21
432
+
433
+ pose = [cameras_val[i] for i in pose_ids]
434
+
435
+ print("example camera")
436
+ print(pose[0].R)
437
+ print(pose[0].T)
438
+ print(pose[0].focal_length)
439
+ print(pose[0].principal_point)
440
+
441
+ # prepare batches [if translating then call required functions on the target pose]
442
+ batches = []
443
+ for i in range(num_images):
444
+ batch = {'pose': [pose[i]] + cameras_train_final,
445
+ "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
446
+ "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
447
+ "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
448
+ "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
449
+ "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
450
+ "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
451
+ }
452
+ batch_ = copy.deepcopy(batch)
453
+ batch_["pose"][0] = process_camera_json(camera_json, pose[0])
454
+ batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
455
+ # print('batched')
456
+ # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
457
+ batches.append(batch_)
458
+
459
+ print(f'len batches: {len(batches)}')
460
+
461
+ image = None
462
+
463
+ with torch.no_grad():
464
+ for batch in batches:
465
+ for key in batch.keys():
466
+ if isinstance(batch[key], torch.Tensor):
467
+ batch[key] = batch[key].to('cuda')
468
+ elif 'pose' in key:
469
+ batch[key] = [x.to('cuda') for x in batch[key]]
470
+ else:
471
+ pass
472
+
473
+ if prompt is not None:
474
+ batch["txt"] = [prompt for _ in range(1)]
475
+ batch["txt_ref"] = [prompt for _ in range(len(batch["pose"])-1)]
476
+
477
+ print(batch["txt"])
478
+ N = 1
479
+ log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
480
+ image = log_["samples"]
481
+
482
+ torch.cuda.empty_cache()
483
+ model.clear_rendered_feat()
484
+
485
+ print("generation done")
486
+ return image
487
+
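Taken together, a typical invocation of this module looks roughly like the sketch below. app.py is the actual driver; the base checkpoint and camera.bin paths are placeholders or assumptions rather than files shown in this commit:

config = "pretrained_models/teddybear31/configs/2024-04-12T22-50-24-project.yaml"
model = load_base_model(config, ckpt="path/to/sd_xl_base_1.0.safetensors")   # placeholder path
model, data = load_and_return_model_and_data(
    config, model,
    delta_ckpt="pretrained_models/teddybear31/checkpoints/step=000001600.ckpt",
)
camera_json = '{"scene.camera": {"eye": {"x": 0.86, "y": -0.09, "z": -0.40}, "up": {"x": -0.13, "y": -0.99, "z": -0.01}, "center": {"x": 0.0, "y": 0.02, "z": 0.09}}}'
image = sample(
    model, data,
    prompt="photo of a <new1> teddybear",
    camera_json=camera_json,
    camera_path="pretrained_models/teddybear31/camera.bin",   # assumed location, mirroring the car0 default
)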
scripts.js ADDED
@@ -0,0 +1,147 @@
1
+ async () => {
2
+
3
+ globalThis.init_camera_dict = {
4
+ "scene.camera": {
5
+ "up": {"x": -0.13227683305740356,
6
+ "y": -0.9911391735076904,
7
+ "z": -0.013464212417602539},
8
+ "center": {"x": -0.005292057991027832,
9
+ "y": 0.020704858005046844,
10
+ "z": 0.0873757004737854},
11
+ "eye": {"x": 0.8585731983184814,
12
+ "y": -0.08790968358516693,
13
+ "z": -0.40458938479423523},
14
+ },
15
+ "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
16
+ "scene.aspectmode": "manual"
17
+ };
18
+
19
+ // globalThis.restrictCamera = (data) => {
20
+ // var plotlyDiv = document.getElementById("map").getElementsByClassName('js-plotly-plot')[0];
21
+ // // var curr_eye = plotlyDiv.layout.scene.camera.eye;
22
+ // // var curr_center = plotlyDiv.layout.scene.camera.center;
23
+ // // var curr_up = plotlyDiv.layout.scene.camera.up;
24
+
25
+ // var curr_eye = data["scene.camera"]["eye"];
26
+ // var curr_center = data["scene.camera"]["center"];
27
+ // var curr_up = data["scene.camera"]["up"];
28
+
29
+ // var D = Math.sqrt((curr_eye.x - curr_center.x)**2 + (curr_eye.y - curr_center.y)**2 + (curr_eye.z - curr_center.z)**2);
30
+ // console.log("D", D);
31
+
32
+ // const max_D = 1.47;
33
+ // const min_D = 0.8;
34
+
35
+ // // calculate elevation
36
+ // var elevation = Math.atan2(curr_eye.y - curr_center.y, Math.sqrt((curr_eye.x - curr_center.x)**2 + (curr_eye.z - curr_center.z)**2)) * 180 / Math.PI;
37
+ // console.log("elevation", elevation);
38
+ // const max_elev = 3.2;
39
+ // const min_elev = -30;
40
+
41
+ // const eps = 0.01;
42
+
43
+ // if (D > max_D) {
44
+ // // find new_eye such that D = max_D
45
+ // var new_dict = {
46
+ // "scene.camera": {
47
+ // "eye": {
48
+ // "x": curr_center.x + (curr_eye.x - curr_center.x) * max_D / D - eps,
49
+ // "y": curr_center.y + (curr_eye.y - curr_center.y) * max_D / D - eps,
50
+ // "z": curr_center.z + (curr_eye.z - curr_center.z) * max_D / D - eps,
51
+ // },
52
+ // "up": curr_up,
53
+ // "center": curr_center,
54
+ // }
55
+ // };
56
+
57
+ // Plotly.relayout(plotlyDiv, new_dict);
58
+
59
+ // } else if (D < min_D) {
60
+ // // find new_eye such that D = min_D
61
+ // var new_dict = {
62
+ // "scene.camera": {
63
+ // "eye": {
64
+ // "x": curr_center.x + (curr_eye.x - curr_center.x) * min_D / D - eps,
65
+ // "y": curr_center.y + (curr_eye.y - curr_center.y) * min_D / D - eps,
66
+ // "z": curr_center.z + (curr_eye.z - curr_center.z) * min_D / D - eps,
67
+ // },
68
+ // "up": curr_up,
69
+ // "center": curr_center,
70
+ // }
71
+ // };
72
+
73
+ // Plotly.relayout(plotlyDiv, new_dict);
74
+ // }
75
+
76
+ // const eta = 0.001;
77
+ // if (elevation > max_elev) {
78
+ // // find new eye such that y elevation = max_elev
79
+ // var new_dict = {
80
+ // "scene.camera": {
81
+ // "eye": {
82
+ // "x": curr_eye.x,
83
+ // "y": curr_center.y + (curr_eye.y - curr_center.y) * Math.tan((max_elev - eta) * Math.PI / 180),
84
+ // "z": curr_eye.z,
85
+ // },
86
+ // "up": curr_up,
87
+ // "center": curr_center,
88
+ // }
89
+ // };
90
+
91
+
92
+ // Plotly.relayout(plotlyDiv, new_dict);
93
+
94
+ // } else if (elevation < min_elev) {
95
+ // // find new eye such that y elevation = min_elev
96
+ // var new_dict = {
97
+ // "scene.camera": {
98
+ // "eye": {
99
+ // "x": curr_eye.x,
100
+ // "y": curr_center.y + (curr_eye.y - curr_center.y) * Math.tan((min_elev + eta) * Math.PI / 180),
101
+ // "z": curr_eye.z,
102
+ // },
103
+ // "up": curr_up,
104
+ // "center": curr_center,
105
+ // }
106
+ // };
107
+
108
+ // Plotly.relayout(plotlyDiv, new_dict);
109
+ // }
110
+
111
+ // }
112
+
113
+ globalThis.latestCam = () => {
114
+ var plotlyDiv = document.getElementById("map").getElementsByClassName('js-plotly-plot')[0];
115
+
116
+ globalThis.prev_camera_dict = {};
117
+ console.log("prev camera dict", globalThis.prev_camera_dict);
118
+
119
+ // Listen for the event and log to the console
120
+ plotlyDiv.on('plotly_relayout', function(data) {
121
+ console.log('plotly_relayout event triggered:', data);
122
+
123
+ if ("scene.camera.up" in data) {
124
+ Object.assign(globalThis.prev_camera_dict, globalThis.camera_dict);
125
+ Object.assign(globalThis.camera_dict, globalThis.init_camera_dict);
126
+ }
127
+
128
+ if ('scene.camera' in data) {
129
+ Object.assign(globalThis.prev_camera_dict, globalThis.camera_dict);
130
+ globalThis.camera_dict = data;
131
+ }
132
+
133
+ var camera_json = JSON.stringify(globalThis.camera_dict);
134
+ var input_pose = document.getElementById("input_pose").getElementsByTagName("textarea")[0];
135
+ let myEvent = new Event("input")
136
+ input_pose.value = camera_json;
137
+ input_pose.dispatchEvent(myEvent);
138
+
139
+ var update_pose_btn = document.getElementById("update_pose_button");
140
+ update_pose_btn.dispatchEvent(new Event("click"));
141
+ // globalThis.restrictCamera(data);
142
+ });
143
+ }
144
+
145
+ return latestCam(this);
146
+
147
+ }
sgm/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .models import AutoencodingEngine, DiffusionEngine
2
+ from .util import get_configs_path, instantiate_from_config
3
+
4
+ __version__ = "0.1.0"
sgm/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # from .dataset import StableDataModuleFromConfig
sgm/data/data_co3d.py ADDED
@@ -0,0 +1,762 @@
1
+ # code taken and modified from https://github.com/amyxlase/relpose-plus-plus/blob/b33f7d5000cf2430bfcda6466c8e89bc2dcde43f/relpose/dataset/co3d_v2.py#L346)
2
+ import os.path as osp
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import pytorch_lightning as pl
8
+
9
+ from PIL import Image, ImageFile
10
+ import json
11
+ import gzip
12
+ from torch.utils.data import DataLoader, Dataset
13
+ from torchvision import transforms
14
+ from pytorch3d.renderer.cameras import PerspectiveCameras
15
+ from pytorch3d.renderer.camera_utils import join_cameras_as_batch
16
+ from pytorch3d.implicitron.dataset.utils import adjust_camera_to_bbox_crop_, adjust_camera_to_image_scale_
17
+ from pytorch3d.transforms import Rotate, Translate
18
+
19
+
20
+ CO3D_DIR = "data/training/"
21
+
22
+ Image.MAX_IMAGE_PIXELS = None
23
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
24
+
25
+
26
+ # Added: normalize camera poses
27
+ def intersect_skew_line_groups(p, r, mask):
28
+ # p, r both of shape (B, N, n_intersected_lines, 3)
29
+ # mask of shape (B, N, n_intersected_lines)
30
+ p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
31
+ _, p_line_intersect = _point_line_distance(
32
+ p, r, p_intersect[..., None, :].expand_as(p)
33
+ )
34
+ intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
35
+ dim=-1
36
+ )
37
+ return p_intersect, p_line_intersect, intersect_dist_squared, r
38
+
39
+
40
+ def intersect_skew_lines_high_dim(p, r, mask=None):
41
+ # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
42
+ dim = p.shape[-1]
43
+ # make sure the heading vectors are l2-normed
44
+ if mask is None:
45
+ mask = torch.ones_like(p[..., 0])
46
+ r = torch.nn.functional.normalize(r, dim=-1)
47
+
48
+ eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
49
+ I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
50
+ sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
51
+ p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
52
+
53
+ if torch.any(torch.isnan(p_intersect)):
54
+ print(p_intersect)
55
+ assert False
56
+ return p_intersect, r
57
+
58
+
59
+ def _point_line_distance(p1, r1, p2):
60
+ df = p2 - p1
61
+ proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
62
+ line_pt_nearest = p2 - proj_vector
63
+ d = (proj_vector).norm(dim=-1)
64
+ return d, line_pt_nearest
65
+
66
+
67
+ def compute_optical_axis_intersection(cameras):
68
+ centers = cameras.get_camera_center()
69
+ principal_points = cameras.principal_point
70
+
71
+ one_vec = torch.ones((len(cameras), 1))
72
+ optical_axis = torch.cat((principal_points, one_vec), -1)
73
+
74
+ pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
75
+
76
+ pp2 = torch.zeros((pp.shape[0], 3))
77
+ for i in range(0, pp.shape[0]):
78
+ pp2[i] = pp[i][i]
79
+
80
+ directions = pp2 - centers
81
+ centers = centers.unsqueeze(0).unsqueeze(0)
82
+ directions = directions.unsqueeze(0).unsqueeze(0)
83
+
84
+ p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
85
+ p=centers, r=directions, mask=None
86
+ )
87
+
88
+ p_intersect = p_intersect.squeeze().unsqueeze(0)
89
+ dist = (p_intersect - centers).norm(dim=-1)
90
+
91
+ return p_intersect, dist, p_line_intersect, pp2, r
92
+
93
+
94
+ def normalize_cameras(cameras, scale=1.0):
95
+ """
96
+ Normalizes cameras such that the optical axes intersect at the origin and the maximum
97
+ camera distance to the origin is 1.
98
+
99
+ Args:
100
+ cameras (List[camera]).
101
+ """
102
+
103
+ # Let distance from first camera to origin be unit
104
+ new_cameras = cameras.clone()
105
+ new_transform = new_cameras.get_world_to_view_transform()
106
+
107
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
108
+ cameras
109
+ )
110
+ t = Translate(p_intersect)
111
+
112
+ # scale = dist.squeeze()[0]
113
+ scale = max(dist.squeeze())
114
+
115
+ # Degenerate case
116
+ if scale == 0:
117
+ print(cameras.T)
118
+ print(new_transform.get_matrix()[:, 3, :3])
119
+ return -1
120
+ assert scale != 0
121
+
122
+ new_transform = t.compose(new_transform)
123
+ new_cameras.R = new_transform.get_matrix()[:, :3, :3]
124
+ new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
125
+ return new_cameras, p_intersect, p_line_intersect, pp, r
126
+
127
+
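+ # Hedged usage sketch (illustrative values only): build a small batch of PyTorch3D
+ # cameras looking at the origin and normalise them with the helper above.
+ # from pytorch3d.renderer import look_at_view_transform
+ # R, T = look_at_view_transform(dist=2.0, elev=10.0, azim=torch.linspace(0, 180, 4))
+ # cams = PerspectiveCameras(R=R, T=T, focal_length=torch.full((4, 1), 2.0),
+ #                           principal_point=torch.zeros(4, 2))
+ # new_cams, p_intersect, _, _, _ = normalize_cameras(cams)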
128
+ def centerandalign(cameras, scale=1.0):
129
+ """
130
+ Centers the cameras at the intersection of their optical axes, aligns the best-fit plane
131
+ of the camera centers with the y-axis, and scales by the first camera's distance to that intersection.
132
+
133
+ Args:
134
+ cameras (List[camera]).
135
+ """
136
+
137
+ # Let distance from first camera to origin be unit
138
+ new_cameras = cameras.clone()
139
+ new_transform = new_cameras.get_world_to_view_transform()
140
+
141
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
142
+ cameras
143
+ )
144
+ t = Translate(p_intersect)
145
+
146
+ centers = [cam.get_camera_center() for cam in new_cameras]
147
+ centers = torch.concat(centers, 0).cpu().numpy()
148
+ m = len(cameras)
149
+
150
+ # https://math.stackexchange.com/questions/99299/best-fitting-plane-given-a-set-of-points
151
+ A = np.hstack((centers[:m, :2], np.ones((m, 1))))
152
+ B = centers[:m, 2:]
153
+ if A.shape[0] == 2:
154
+ x = A.T @ np.linalg.inv(A @ A.T) @ B
155
+ else:
156
+ x = np.linalg.inv(A.T @ A) @ A.T @ B
157
+ a, b, c = x.flatten()
158
+ n = np.array([a, b, 1])
159
+ n /= np.linalg.norm(n)
160
+
161
+ # https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
162
+ v = np.cross(n, [0, 1, 0])
163
+ s = np.linalg.norm(v)
164
+ c = np.dot(n, [0, 1, 0])
165
+ V = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
166
+ rot = torch.from_numpy(np.eye(3) + V + V @ V * (1 - c) / s**2).float()
167
+
168
+ scale = dist.squeeze()[0]
169
+
170
+ # Degenerate case
171
+ if scale == 0:
172
+ print(cameras.T)
173
+ print(new_transform.get_matrix()[:, 3, :3])
174
+ return -1
175
+ assert scale != 0
176
+
177
+ rot = Rotate(rot.T)
178
+
179
+ new_transform = rot.compose(t).compose(new_transform)
180
+ new_cameras.R = new_transform.get_matrix()[:, :3, :3]
181
+ new_cameras.T = new_transform.get_matrix()[:, 3, :3] / scale
182
+ return new_cameras
183
+
184
+
185
+ def square_bbox(bbox, padding=0.0, astype=None):
186
+ """
187
+ Computes a square bounding box, with optional padding parameters.
188
+
189
+ Args:
190
+ bbox: Bounding box in xyxy format (4,).
191
+
192
+ Returns:
193
+ square_bbox in xyxy format (4,).
194
+ """
195
+ if astype is None:
196
+ astype = type(bbox[0])
197
+ bbox = np.array(bbox)
198
+ center = ((bbox[:2] + bbox[2:]) / 2).round().astype(int)
199
+ extents = (bbox[2:] - bbox[:2]) / 2
200
+ s = (max(extents) * (1 + padding)).round().astype(int)
201
+ square_bbox = np.array(
202
+ [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
203
+ dtype=astype,
204
+ )
205
+
206
+ return square_bbox
207
+
208
+
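+ # Quick illustration of square_bbox (values are just an example):
+ #   square_bbox(np.array([10, 20, 110, 60]))
+ #   -> array([ 10, -10, 110,  90])   # a 100x100 box centred on the input bbox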
209
+ class Co3dDataset(Dataset):
210
+ def __init__(
211
+ self,
212
+ category,
213
+ split="train",
214
+ skip=2,
215
+ img_size=1024,
216
+ num_images=4,
217
+ mask_images=False,
218
+ single_id=0,
219
+ bbox=False,
220
+ modifier_token=None,
221
+ addreg=False,
222
+ drop_ratio=0.5,
223
+ drop_txt=0.1,
224
+ categoryname=None,
225
+ aligncameras=False,
226
+ repeat=100,
227
+ addlen=False,
228
+ onlyref=False,
229
+ ):
230
+ """
231
+ Args:
232
+ category (iterable): List of categories to use. If "all" is in the list,
233
+ all training categories are used.
234
+ num_images (int): Default number of images in each batch.
235
+ normalize_cameras (bool): If True, normalizes cameras so that the
236
+ intersection of the optical axes is placed at the origin and the norm
237
+ of the first camera translation is 1.
238
+ mask_images (bool): If True, masks out the background of the images.
239
+ """
240
+ # category = CATEGORIES
241
+ category = sorted(category.split(','))
242
+ self.category = category
243
+ self.single_id = single_id
244
+ self.addlen = addlen
245
+ self.onlyref = onlyref
246
+ self.categoryname = categoryname
247
+ self.bbox = bbox
248
+ self.modifier_token = modifier_token
249
+ self.addreg = addreg
250
+ self.drop_txt = drop_txt
251
+ self.skip = skip
252
+ if self.addreg:
253
+ with open(f'data/regularization/{category[0]}_sp_generated/caption.txt', "r") as f:
254
+ self.regcaptions = f.read().splitlines()
255
+ self.reglen = len(self.regcaptions)
256
+ self.regimpath = f'data/regularization/{category[0]}_sp_generated'
257
+
258
+ self.low_quality_translations = []
259
+ self.rotations = {}
260
+ self.category_map = {}
261
+ co3d_dir = CO3D_DIR
262
+ for c in category:
263
+ subset = 'fewview_dev'
264
+ category_dir = osp.join(co3d_dir, c)
265
+ frame_file = osp.join(category_dir, "frame_annotations.jgz")
266
+ sequence_file = osp.join(category_dir, "sequence_annotations.jgz")
267
+ subset_lists_file = osp.join(category_dir, f"set_lists/set_lists_{subset}.json")
268
+ bbox_file = osp.join(category_dir, f"{c}_bbox.jgz")
269
+
270
+ with open(subset_lists_file) as f:
271
+ subset_lists_data = json.load(f)
272
+
273
+ with gzip.open(sequence_file, "r") as fin:
274
+ sequence_data = json.loads(fin.read())
275
+
276
+ with gzip.open(bbox_file, "r") as fin:
277
+ bbox_data = json.loads(fin.read())
278
+
279
+ with gzip.open(frame_file, "r") as fin:
280
+ frame_data = json.loads(fin.read())
281
+
282
+ frame_data_processed = {}
283
+ for f_data in frame_data:
284
+ sequence_name = f_data["sequence_name"]
285
+ if sequence_name not in frame_data_processed:
286
+ frame_data_processed[sequence_name] = {}
287
+ frame_data_processed[sequence_name][f_data["frame_number"]] = f_data
288
+
289
+ good_quality_sequences = set()
290
+ for seq_data in sequence_data:
291
+ if seq_data["viewpoint_quality_score"] > 0.5:
292
+ good_quality_sequences.add(seq_data["sequence_name"])
293
+
294
+ for subset in ["train"]:
295
+ for seq_name, frame_number, filepath in subset_lists_data[subset]:
296
+ if seq_name not in good_quality_sequences:
297
+ continue
298
+
299
+ if seq_name not in self.rotations:
300
+ self.rotations[seq_name] = []
301
+ self.category_map[seq_name] = c
302
+
303
+ mask_path = filepath.replace("images", "masks").replace(".jpg", ".png")
304
+
305
+ frame_data = frame_data_processed[seq_name][frame_number]
306
+
307
+ self.rotations[seq_name].append(
308
+ {
309
+ "filepath": filepath,
310
+ "R": frame_data["viewpoint"]["R"],
311
+ "T": frame_data["viewpoint"]["T"],
312
+ "focal_length": frame_data["viewpoint"]["focal_length"],
313
+ "principal_point": frame_data["viewpoint"]["principal_point"],
314
+ "mask": mask_path,
315
+ "txt": "a car",
316
+ "bbox": bbox_data[mask_path]
317
+ }
318
+ )
319
+
320
+ for seq_name in self.rotations:
321
+ seq_data = self.rotations[seq_name]
322
+ cameras = PerspectiveCameras(
323
+ focal_length=[data["focal_length"] for data in seq_data],
324
+ principal_point=[data["principal_point"] for data in seq_data],
325
+ R=[data["R"] for data in seq_data],
326
+ T=[data["T"] for data in seq_data],
327
+ )
328
+
329
+ normalized_cameras, _, _, _, _ = normalize_cameras(cameras)
330
+ if aligncameras:
331
+ normalized_cameras = centerandalign(cameras)
332
+
333
+ if normalized_cameras == -1:
334
+ print("Error in normalizing cameras: camera scale was 0")
335
+ del self.rotations[seq_name]
336
+ continue
337
+
338
+ for i, data in enumerate(seq_data):
339
+ self.rotations[seq_name][i]["R"] = normalized_cameras.R[i]
340
+ self.rotations[seq_name][i]["T"] = normalized_cameras.T[i]
341
+ self.rotations[seq_name][i]["R_original"] = torch.from_numpy(np.array(seq_data[i]["R"]))
342
+ self.rotations[seq_name][i]["T_original"] = torch.from_numpy(np.array(seq_data[i]["T"]))
343
+
344
+ # Make sure translations are not ridiculous
345
+ if self.rotations[seq_name][i]["T"][0] + self.rotations[seq_name][i]["T"][1] + self.rotations[seq_name][i]["T"][2] > 1e5:
346
+ bad_seq = True
347
+ self.low_quality_translations.append(seq_name)
348
+ break
349
+
350
+ for seq_name in self.low_quality_translations:
351
+ if seq_name in self.rotations:
352
+ del self.rotations[seq_name]
353
+
354
+ self.sequence_list = list(self.rotations.keys())
355
+
356
+ self.transform = transforms.Compose(
357
+ [
358
+ transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
359
+ transforms.ToTensor(),
360
+ transforms.Lambda(lambda x: x * 2.0 - 1.0)
361
+ ]
362
+ )
363
+ self.transformim = transforms.Compose(
364
+ [
365
+ transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
366
+ transforms.CenterCrop(img_size),
367
+ transforms.ToTensor(),
368
+ transforms.Lambda(lambda x: x * 2.0 - 1.0)
369
+ ]
370
+ )
371
+ self.transformmask = transforms.Compose(
372
+ [
373
+ transforms.Resize(img_size // 8),
374
+ transforms.ToTensor(),
375
+ ]
376
+ )
377
+
378
+ self.num_images = num_images
379
+ self.image_size = img_size
380
+ self.normalize_cameras = normalize_cameras
381
+ self.mask_images = mask_images
382
+ self.drop_ratio = drop_ratio
383
+ self.kernel_tensor = torch.ones((1, 1, 7, 7))
384
+ self.repeat = repeat
385
+ print("Training sequences:", self.sequence_list)
386
+ self.valid_ids = np.arange(0, len(self.rotations[self.sequence_list[self.single_id]]), skip).tolist()
387
+ if split == 'test':
388
+ self.valid_ids = list(set(np.arange(0, len(self.rotations[self.sequence_list[self.single_id]])).tolist()).difference(self.valid_ids))
389
+
390
+ print(
391
+ f"Low quality translation sequences, not used: {self.low_quality_translations}"
392
+ )
393
+ print(f"Data size: {len(self)}")
394
+
395
+ def __len__(self):
396
+ return (len(self.valid_ids))*self.repeat + (1 if self.addlen else 0)
397
+
398
+ def _padded_bbox(self, bbox, w, h):
399
+ # both branches of the original if/else were identical; keep a single assignment
400
+ bbox = np.array([0, 0, w, h])
403
+ return square_bbox(bbox.astype(np.float32))
404
+
405
+ def _crop_bbox(self, bbox, w, h):
406
+ bbox = square_bbox(bbox.astype(np.float32))
407
+
408
+ side_length = bbox[2] - bbox[0]
409
+ center = (bbox[:2] + bbox[2:]) / 2
410
+ extent = side_length / 2
411
+
412
+ # Final coordinates need to be integer for cropping.
413
+ ul = (center - extent).round().astype(int)
414
+ lr = ul + np.round(2 * extent).astype(int)
415
+ return np.concatenate((ul, lr))
416
+
417
+ def _crop_image(self, image, bbox, white_bg=False):
418
+ if white_bg:
419
+ # Only support PIL Images
420
+ image_crop = Image.new(
421
+ "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
422
+ )
423
+ image_crop.paste(image, (-bbox[0], -bbox[1]))
424
+ else:
425
+ image_crop = transforms.functional.crop(
426
+ image,
427
+ top=bbox[1],
428
+ left=bbox[0],
429
+ height=bbox[3] - bbox[1],
430
+ width=bbox[2] - bbox[0],
431
+ )
432
+ return image_crop
433
+
434
+ def __getitem__(self, index, specific_id=None, validation=False):
435
+ sequence_name = self.sequence_list[self.single_id]
436
+
437
+ metadata = self.rotations[sequence_name]
438
+
439
+ if validation:
440
+ drop_text = False
441
+ drop_im = False
442
+ else:
443
+ drop_im = np.random.uniform(0, 1) < self.drop_ratio
444
+ if not drop_im:
445
+ drop_text = np.random.uniform(0, 1) < self.drop_txt
446
+ else:
447
+ drop_text = False
448
+
449
+ size = self.image_size
450
+
451
+ # sample reference ids
452
+ listofindices = self.valid_ids.copy()
453
+ max_diff = len(listofindices) // (self.num_images-1)
454
+ if (index*self.skip) % len(metadata) in listofindices:
455
+ listofindices.remove((index*self.skip) % len(metadata))
456
+ references = np.random.choice(np.arange(0, len(listofindices)+1, max_diff), self.num_images-1, replace=False)
457
+ rem = np.random.randint(0, max_diff)
458
+ references = [listofindices[(x + rem) % len(listofindices)] for x in references]
459
+ ids = [(index*self.skip) % len(metadata)] + references
460
+
461
+ # special case to save features corresponding to ref image as part of model buffer
462
+ if self.onlyref:
463
+ ids = references + [(index*self.skip) % len(metadata)]
464
+ if specific_id is not None: # remove this later
465
+ ids = specific_id
466
+
467
+ # get data
468
+ batch = self.get_data(index=self.single_id, ids=ids)
469
+
470
+ # text prompt
471
+ if self.modifier_token is not None:
472
+ name = self.category[0] if self.categoryname is None else self.categoryname
473
+ batch['txt'] = [f'photo of a {self.modifier_token} {name}' for _ in range(len(batch['txt']))]
474
+
475
+ # replace with regularization image if drop_im
476
+ if drop_im and self.addreg:
477
+ select_id = np.random.randint(0, self.reglen)
478
+ batch["image"] = [self.transformim(Image.open(f'{self.regimpath}/images/{select_id}.png').convert('RGB'))]
479
+ batch['txt'] = [self.regcaptions[select_id]]
480
+ batch["original_size_as_tuple"] = torch.ones_like(batch["original_size_as_tuple"])*1024
481
+
482
+ # create camera class and adjust intrinsics for crop
483
+ cameras = [PerspectiveCameras(R=batch['R'][i].unsqueeze(0),
484
+ T=batch['T'][i].unsqueeze(0),
485
+ focal_length=batch['focal_lengths'][i].unsqueeze(0),
486
+ principal_point=batch['principal_points'][i].unsqueeze(0),
487
+ image_size=self.image_size
488
+ )
489
+ for i in range(len(ids))]
490
+ for i, cam in enumerate(cameras):
491
+ adjust_camera_to_bbox_crop_(cam, batch["original_size_as_tuple"][i, :2], batch["crop_coords"][i])
492
+ adjust_camera_to_image_scale_(cam, batch["original_size_as_tuple"][i, 2:], torch.tensor([self.image_size, self.image_size]))
493
+
494
+ # create mask and dilated mask for mask based losses
495
+ batch["depth"] = batch["mask"].clone()
496
+ batch["mask"] = torch.clamp(torch.nn.functional.conv2d(batch["mask"], self.kernel_tensor, padding='same'), 0, 1)
497
+ if not self.mask_images:
498
+ batch["mask"] = [None for i in range(len(ids))]
499
+
500
+ # special case to save features corresponding to zero image
501
+ if index == self.__len__()-1 and self.addlen:
502
+ batch["image"][0] *= 0.
503
+
504
+ return {"jpg": batch["image"][0],
505
+ "txt": batch["txt"][0] if not drop_text else "",
506
+ "jpg_ref": batch["image"][1:] if not drop_im else torch.stack([2*torch.rand_like(batch["image"][0])-1. for _ in range(len(ids)-1)], dim=0),
507
+ "txt_ref": batch["txt"][1:] if not drop_im else ["" for _ in range(len(ids)-1)],
508
+ "pose": cameras,
509
+ "mask": batch["mask"][0] if not drop_im else torch.ones_like(batch["mask"][0]),
510
+ "mask_ref": batch["masks_padding"][1:],
511
+ "depth": batch["depth"][0] if len(batch["depth"]) > 0 else None,
512
+ "filepaths": batch["filepaths"],
513
+ "original_size_as_tuple": batch["original_size_as_tuple"][0][2:],
514
+ "target_size_as_tuple": torch.ones_like(batch["original_size_as_tuple"][0][2:])*size,
515
+ "crop_coords_top_left": torch.zeros_like(batch["crop_coords"][0][:2]),
516
+ "original_size_as_tuple_ref": batch["original_size_as_tuple"][1:][:, 2:],
517
+ "target_size_as_tuple_ref": torch.ones_like(batch["original_size_as_tuple"][1:][:, 2:])*size,
518
+ "crop_coords_top_left_ref": torch.zeros_like(batch["crop_coords"][1:][:, :2]),
519
+ "drop_im": torch.Tensor([1-drop_im*1.])
520
+ }
521
+
522
+ def get_data(self, index=None, sequence_name=None, ids=(0, 1)):
523
+ if sequence_name is None:
524
+ sequence_name = self.sequence_list[index]
525
+ metadata = self.rotations[sequence_name]
526
+ category = self.category_map[sequence_name]
527
+ annos = [metadata[i] for i in ids]
528
+ images = []
529
+ rotations = []
530
+ translations = []
531
+ focal_lengths = []
532
+ principal_points = []
533
+ txts = []
534
+ masks = []
535
+ filepaths = []
536
+ images_transformed = []
537
+ masks_transformed = []
538
+ original_size_as_tuple = []
539
+ crop_parameters = []
540
+ masks_padding = []
541
+ depths = []
542
+
543
+ for counter, anno in enumerate(annos):
544
+ filepath = anno["filepath"]
545
+ filepaths.append(filepath)
546
+ image = Image.open(osp.join(CO3D_DIR, filepath)).convert("RGB")
547
+
548
+ mask_name = osp.basename(filepath.replace(".jpg", ".png"))
549
+
550
+ mask_path = osp.join(
551
+ CO3D_DIR, category, sequence_name, "masks", mask_name
552
+ )
553
+ mask = Image.open(mask_path).convert("L")
554
+
555
+ if mask.size != image.size:
556
+ mask = mask.resize(image.size)
557
+
558
+ mask_padded = Image.fromarray((np.ones_like(mask) > 0))
559
+ mask = Image.fromarray((np.array(mask) > 125))
560
+ masks.append(mask)
561
+
562
+ # crop image around object
563
+ w, h = image.width, image.height
564
+ bbox = np.array(anno["bbox"])
565
+ if len(bbox) == 0:
566
+ bbox = np.array([0, 0, w, h])
567
+
568
+ if self.bbox and counter > 0:
569
+ bbox = self._crop_bbox(bbox, w, h)
570
+ else:
571
+ bbox = self._padded_bbox(None, w, h)
572
+ image = self._crop_image(image, bbox)
573
+ mask = self._crop_image(mask, bbox)
574
+ mask_padded = self._crop_image(mask_padded, bbox)
575
+ masks_padding.append(self.transformmask(mask_padded))
576
+ images_transformed.append(self.transform(image))
577
+ masks_transformed.append(self.transformmask(mask))
578
+
579
+ crop_parameters.append(torch.tensor([bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1] ]).int())
580
+ original_size_as_tuple.append(torch.tensor([w, h, bbox[2] - bbox[0], bbox[3] - bbox[1]]))
581
+ images.append(image)
582
+ rotations.append(anno["R"])
583
+ translations.append(anno["T"])
584
+ focal_lengths.append(torch.tensor(anno["focal_length"]))
585
+ principal_points.append(torch.tensor(anno["principal_point"]))
586
+ txts.append(anno["txt"])
587
+
588
+ images = images_transformed
589
+ batch = {
590
+ "model_id": sequence_name,
591
+ "category": category,
592
+ "original_size_as_tuple": torch.stack(original_size_as_tuple),
593
+ "crop_coords": torch.stack(crop_parameters),
594
+ "n": len(metadata),
595
+ "ind": torch.tensor(ids),
596
+ "txt": txts,
597
+ "filepaths": filepaths,
598
+ "masks_padding": torch.stack(masks_padding) if len(masks_padding) > 0 else [],
599
+ "depth": torch.stack(depths) if len(depths) > 0 else [],
600
+ }
601
+
602
+ batch["R"] = torch.stack(rotations)
603
+ batch["T"] = torch.stack(translations)
604
+ batch["focal_lengths"] = torch.stack(focal_lengths)
605
+ batch["principal_points"] = torch.stack(principal_points)
606
+
607
+ # Add images
608
+ if self.transform is None:
609
+ batch["image"] = images
610
+ else:
611
+ batch["image"] = torch.stack(images)
612
+ batch["mask"] = torch.stack(masks_transformed)
613
+
614
+ return batch
615
+
616
+ @staticmethod
617
+ def collate_fn(batch):
618
+ """A function to collate the data across batches. This function must be passed to pytorch's DataLoader to collate batches.
619
+ Args:
620
+ batch(list): List of objects returned by this class' __getitem__ function. This is given by pytorch's dataloader that calls __getitem__
621
+ multiple times and expects a collated batch.
622
+ Returns:
623
+ dict: The collated dictionary representing the data in the batch.
624
+ """
625
+ result = {
626
+ "jpg": [],
627
+ "txt": [],
628
+ "jpg_ref": [],
629
+ "txt_ref": [],
630
+ "pose": [],
631
+ "original_size_as_tuple": [],
632
+ "original_size_as_tuple_ref": [],
633
+ "crop_coords_top_left": [],
634
+ "crop_coords_top_left_ref": [],
635
+ "target_size_as_tuple_ref": [],
636
+ "target_size_as_tuple": [],
637
+ "drop_im": [],
638
+ "mask_ref": [],
639
+ }
640
+ if batch[0]["mask"] is not None:
641
+ result["mask"] = []
642
+ if batch[0]["depth"] is not None:
643
+ result["depth"] = []
644
+
645
+ for batch_obj in batch:
646
+ for key in result.keys():
647
+ result[key].append(batch_obj[key])
648
+ for key in result.keys():
649
+ if not (key == 'pose' or 'txt' in key or 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key):
650
+ result[key] = torch.stack(result[key], dim=0)
651
+ elif 'txt_ref' in key:
652
+ result[key] = [item for sublist in result[key] for item in sublist]
653
+ elif 'size_as_tuple_ref' in key or 'coords_top_left_ref' in key:
654
+ result[key] = torch.cat(result[key], dim=0)
655
+ elif 'pose' in key:
656
+ result[key] = [join_cameras_as_batch(cameras) for cameras in result[key]]
657
+
658
+ return result
659
+
660
+
661
+ class CustomDataDictLoader(pl.LightningDataModule):
662
+ def __init__(
663
+ self,
664
+ category,
665
+ batch_size,
666
+ mask_images=False,
667
+ skip=1,
668
+ img_size=1024,
669
+ num_images=4,
670
+ num_workers=0,
671
+ shuffle=True,
672
+ single_id=0,
673
+ modifier_token=None,
674
+ bbox=False,
675
+ addreg=False,
676
+ drop_ratio=0.5,
677
+ jitter=False,
678
+ drop_txt=0.1,
679
+ categoryname=None,
680
+ ):
681
+ super().__init__()
682
+
683
+ self.batch_size = batch_size
684
+ self.num_workers = num_workers
685
+ self.shuffle = shuffle
686
+ self.train_dataset = Co3dDataset(category,
687
+ img_size=img_size,
688
+ mask_images=mask_images,
689
+ skip=skip,
690
+ num_images=num_images,
691
+ single_id=single_id,
692
+ modifier_token=modifier_token,
693
+ bbox=bbox,
694
+ addreg=addreg,
695
+ drop_ratio=drop_ratio,
696
+ drop_txt=drop_txt,
697
+ categoryname=categoryname,
698
+ )
699
+ self.val_dataset = Co3dDataset(category,
700
+ img_size=img_size,
701
+ mask_images=mask_images,
702
+ skip=skip,
703
+ num_images=2,
704
+ single_id=single_id,
705
+ modifier_token=modifier_token,
706
+ bbox=bbox,
707
+ addreg=addreg,
708
+ drop_ratio=0.,
709
+ drop_txt=0.,
710
+ categoryname=categoryname,
711
+ repeat=1,
712
+ addlen=True,
713
+ onlyref=True,
714
+ )
715
+ self.test_dataset = Co3dDataset(category,
716
+ img_size=img_size,
717
+ mask_images=mask_images,
718
+ split="test",
719
+ skip=skip,
720
+ num_images=2,
721
+ single_id=single_id,
722
+ modifier_token=modifier_token,
723
+ bbox=False,
724
+ addreg=addreg,
725
+ drop_ratio=0.,
726
+ drop_txt=0.,
727
+ categoryname=categoryname,
728
+ repeat=1,
729
+ )
730
+ self.collate_fn = Co3dDataset.collate_fn
731
+
732
+ def prepare_data(self):
733
+ pass
734
+
735
+ def train_dataloader(self):
736
+ return DataLoader(
737
+ self.train_dataset,
738
+ batch_size=self.batch_size,
739
+ shuffle=self.shuffle,
740
+ num_workers=self.num_workers,
741
+ collate_fn=self.collate_fn,
742
+ drop_last=True,
743
+ )
744
+
745
+ def test_dataloader(self):
746
+ return DataLoader(
747
+ self.train_dataset,
748
+ batch_size=self.batch_size,
749
+ shuffle=False,
750
+ num_workers=self.num_workers,
751
+ collate_fn=self.collate_fn,
752
+ )
753
+
754
+ def val_dataloader(self):
755
+ return DataLoader(
756
+ self.val_dataset,
757
+ batch_size=self.batch_size,
758
+ shuffle=False,
759
+ num_workers=self.num_workers,
760
+ collate_fn=self.collate_fn,
761
+ drop_last=True
762
+ )
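
The collate_fn docstring above says the function has to be handed to PyTorch's DataLoader, and CustomDataDictLoader wires that up for the train/val/test splits. A minimal usage sketch, not part of the commit: the import path, the "car" category, and the hyperparameters are assumptions, and it only runs if the corresponding CO3D data is available locally.

from sgm.data.data_co3d import CustomDataDictLoader  # assumed module path

# assumed category and settings; requires the CO3D "car" sequences on disk
data = CustomDataDictLoader(
    category="car",
    batch_size=2,
    img_size=512,
    num_images=4,
    num_workers=4,
    single_id=0,
)
loader = data.train_dataloader()              # DataLoader built with Co3dDataset.collate_fn
batch = next(iter(loader))
print(batch["jpg"].shape, len(batch["txt"]))  # stacked target images and per-sample prompts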
sgm/lr_scheduler.py ADDED
@@ -0,0 +1,135 @@
+ import numpy as np
+
+
+ class LambdaWarmUpCosineScheduler:
+     """
+     note: use with a base_lr of 1.0
+     """
+
+     def __init__(
+         self,
+         warm_up_steps,
+         lr_min,
+         lr_max,
+         lr_start,
+         max_decay_steps,
+         verbosity_interval=0,
+     ):
+         self.lr_warm_up_steps = warm_up_steps
+         self.lr_start = lr_start
+         self.lr_min = lr_min
+         self.lr_max = lr_max
+         self.lr_max_decay_steps = max_decay_steps
+         self.last_lr = 0.0
+         self.verbosity_interval = verbosity_interval
+
+     def schedule(self, n, **kwargs):
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+         if n < self.lr_warm_up_steps:
+             lr = (
+                 self.lr_max - self.lr_start
+             ) / self.lr_warm_up_steps * n + self.lr_start
+             self.last_lr = lr
+             return lr
+         else:
+             t = (n - self.lr_warm_up_steps) / (
+                 self.lr_max_decay_steps - self.lr_warm_up_steps
+             )
+             t = min(t, 1.0)
+             lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+                 1 + np.cos(t * np.pi)
+             )
+             self.last_lr = lr
+             return lr
+
+     def __call__(self, n, **kwargs):
+         return self.schedule(n, **kwargs)
+
+
+ class LambdaWarmUpCosineScheduler2:
+     """
+     supports repeated iterations, configurable via lists
+     note: use with a base_lr of 1.0.
+     """
+
+     def __init__(
+         self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+     ):
+         assert (
+             len(warm_up_steps)
+             == len(f_min)
+             == len(f_max)
+             == len(f_start)
+             == len(cycle_lengths)
+         )
+         self.lr_warm_up_steps = warm_up_steps
+         self.f_start = f_start
+         self.f_min = f_min
+         self.f_max = f_max
+         self.cycle_lengths = cycle_lengths
+         self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+         self.last_f = 0.0
+         self.verbosity_interval = verbosity_interval
+
+     def find_in_interval(self, n):
+         interval = 0
+         for cl in self.cum_cycles[1:]:
+             if n <= cl:
+                 return interval
+             interval += 1
+
+     def schedule(self, n, **kwargs):
+         cycle = self.find_in_interval(n)
+         n = n - self.cum_cycles[cycle]
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(
+                     f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                     f"current cycle {cycle}"
+                 )
+         if n < self.lr_warm_up_steps[cycle]:
+             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+                 cycle
+             ] * n + self.f_start[cycle]
+             self.last_f = f
+             return f
+         else:
+             t = (n - self.lr_warm_up_steps[cycle]) / (
+                 self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+             )
+             t = min(t, 1.0)
+             f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+                 1 + np.cos(t * np.pi)
+             )
+             self.last_f = f
+             return f
+
+     def __call__(self, n, **kwargs):
+         return self.schedule(n, **kwargs)
+
+
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+     def schedule(self, n, **kwargs):
+         cycle = self.find_in_interval(n)
+         n = n - self.cum_cycles[cycle]
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(
+                     f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                     f"current cycle {cycle}"
+                 )
+
+         if n < self.lr_warm_up_steps[cycle]:
+             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+                 cycle
+             ] * n + self.f_start[cycle]
+             self.last_f = f
+             return f
+         else:
+             f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+                 self.cycle_lengths[cycle] - n
+             ) / (self.cycle_lengths[cycle])
+             self.last_f = f
+             return f
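
These classes return a learning-rate multiplier rather than an absolute learning rate, which is why both docstrings say to use them with a base_lr of 1.0 scale. They are intended to be wrapped in torch.optim.lr_scheduler.LambdaLR, which is how DiffusionEngine.configure_optimizers uses them later in this commit. A minimal sketch with assumed warm-up and cycle values:

import torch
from torch.optim.lr_scheduler import LambdaLR
from sgm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# one cycle: warm the multiplier up from 1e-6 to 1.0 over 100 steps,
# then hold it at 1.0 (f_min == f_max) for the rest of the 10k-step cycle
sched_fn = LambdaLinearScheduler(
    warm_up_steps=[100], f_start=[1e-6], f_min=[1.0], f_max=[1.0],
    cycle_lengths=[10_000],
)
scheduler = LambdaLR(opt, lr_lambda=sched_fn.schedule)

for step in range(200):
    opt.step()
    scheduler.step()  # effective lr = 1e-4 * sched_fn.schedule(current step)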
sgm/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .autoencoder import AutoencodingEngine
+ from .diffusion import DiffusionEngine
sgm/models/autoencoder.py ADDED
@@ -0,0 +1,335 @@
1
+ import re
2
+ from abc import abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Any, Dict, Tuple, Union
5
+
6
+ import pytorch_lightning as pl
7
+ import torch
8
+ from omegaconf import ListConfig
9
+ from packaging import version
10
+ from safetensors.torch import load_file as load_safetensors
11
+
12
+ from ..modules.diffusionmodules.model import Decoder, Encoder
13
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
14
+ from ..modules.ema import LitEma
15
+ from ..util import default, get_obj_from_str, instantiate_from_config
16
+
17
+
18
+ class AbstractAutoencoder(pl.LightningModule):
19
+ """
20
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
21
+ unCLIP models, etc. Hence, it is fairly general, and specific features
22
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ ema_decay: Union[None, float] = None,
28
+ monitor: Union[None, str] = None,
29
+ input_key: str = "jpg",
30
+ ckpt_path: Union[None, str] = None,
31
+ ignore_keys: Union[Tuple, list, ListConfig] = (),
32
+ ):
33
+ super().__init__()
34
+ self.input_key = input_key
35
+ self.use_ema = ema_decay is not None
36
+ if monitor is not None:
37
+ self.monitor = monitor
38
+
39
+ if self.use_ema:
40
+ self.model_ema = LitEma(self, decay=ema_decay)
41
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
42
+
43
+ if ckpt_path is not None:
44
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
45
+
46
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
47
+ self.automatic_optimization = False
48
+
49
+ def init_from_ckpt(
50
+ self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple()
51
+ ) -> None:
52
+ if path.endswith("ckpt"):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ elif path.endswith("safetensors"):
55
+ sd = load_safetensors(path)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ keys = list(sd.keys())
60
+ for k in keys:
61
+ for ik in ignore_keys:
62
+ if re.match(ik, k):
63
+ print("Deleting key {} from state_dict.".format(k))
64
+ del sd[k]
65
+ missing, unexpected = self.load_state_dict(sd, strict=False)
66
+ print(
67
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
68
+ )
69
+ if len(missing) > 0:
70
+ print(f"Missing Keys: {missing}")
71
+ if len(unexpected) > 0:
72
+ print(f"Unexpected Keys: {unexpected}")
73
+
74
+ @abstractmethod
75
+ def get_input(self, batch) -> Any:
76
+ raise NotImplementedError()
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ # for EMA computation
80
+ if self.use_ema:
81
+ self.model_ema(self)
82
+
83
+ @contextmanager
84
+ def ema_scope(self, context=None):
85
+ if self.use_ema:
86
+ self.model_ema.store(self.parameters())
87
+ self.model_ema.copy_to(self)
88
+ if context is not None:
89
+ print(f"{context}: Switched to EMA weights")
90
+ try:
91
+ yield None
92
+ finally:
93
+ if self.use_ema:
94
+ self.model_ema.restore(self.parameters())
95
+ if context is not None:
96
+ print(f"{context}: Restored training weights")
97
+
98
+ @abstractmethod
99
+ def encode(self, *args, **kwargs) -> torch.Tensor:
100
+ raise NotImplementedError("encode()-method of abstract base class called")
101
+
102
+ @abstractmethod
103
+ def decode(self, *args, **kwargs) -> torch.Tensor:
104
+ raise NotImplementedError("decode()-method of abstract base class called")
105
+
106
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
107
+ print(f"loading >>> {cfg['target']} <<< optimizer from config")
108
+ return get_obj_from_str(cfg["target"])(
109
+ params, lr=lr, **cfg.get("params", dict())
110
+ )
111
+
112
+ def configure_optimizers(self) -> Any:
113
+ raise NotImplementedError()
114
+
115
+
116
+ class AutoencodingEngine(AbstractAutoencoder):
117
+ """
118
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
119
+ (we also restore them explicitly as special cases for legacy reasons).
120
+ Regularizations such as KL or VQ are moved to the regularizer class.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ *args,
126
+ encoder_config: Dict,
127
+ decoder_config: Dict,
128
+ loss_config: Dict,
129
+ regularizer_config: Dict,
130
+ optimizer_config: Union[Dict, None] = None,
131
+ lr_g_factor: float = 1.0,
132
+ **kwargs,
133
+ ):
134
+ super().__init__(*args, **kwargs)
135
+ # todo: add options to freeze encoder/decoder
136
+ self.encoder = instantiate_from_config(encoder_config)
137
+ self.decoder = instantiate_from_config(decoder_config)
138
+ self.loss = instantiate_from_config(loss_config)
139
+ self.regularization = instantiate_from_config(regularizer_config)
140
+ self.optimizer_config = default(
141
+ optimizer_config, {"target": "torch.optim.Adam"}
142
+ )
143
+ self.lr_g_factor = lr_g_factor
144
+
145
+ def get_input(self, batch: Dict) -> torch.Tensor:
146
+ # assuming unified data format, dataloader returns a dict.
147
+ # image tensors should be scaled to -1 ... 1 and in channels-first format (e.g., bchw instead of bhwc)
148
+ return batch[self.input_key]
149
+
150
+ def get_autoencoder_params(self) -> list:
151
+ params = (
152
+ list(self.encoder.parameters())
153
+ + list(self.decoder.parameters())
154
+ + list(self.regularization.get_trainable_parameters())
155
+ + list(self.loss.get_trainable_autoencoder_parameters())
156
+ )
157
+ return params
158
+
159
+ def get_discriminator_params(self) -> list:
160
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
161
+ return params
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.get_last_layer()
165
+
166
+ def encode(self, x: Any, return_reg_log: bool = False) -> Any:
167
+ z = self.encoder(x)
168
+ z, reg_log = self.regularization(z)
169
+ if return_reg_log:
170
+ return z, reg_log
171
+ return z
172
+
173
+ def decode(self, z: Any) -> torch.Tensor:
174
+ x = self.decoder(z)
175
+ return x
176
+
177
+ def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
178
+ z, reg_log = self.encode(x, return_reg_log=True)
179
+ dec = self.decode(z)
180
+ return z, dec, reg_log
181
+
182
+ def training_step(self, batch, batch_idx, optimizer_idx) -> Any:
183
+ x = self.get_input(batch)
184
+ z, xrec, regularization_log = self(x)
185
+
186
+ if optimizer_idx == 0:
187
+ # autoencode
188
+ aeloss, log_dict_ae = self.loss(
189
+ regularization_log,
190
+ x,
191
+ xrec,
192
+ optimizer_idx,
193
+ self.global_step,
194
+ last_layer=self.get_last_layer(),
195
+ split="train",
196
+ )
197
+
198
+ self.log_dict(
199
+ log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True
200
+ )
201
+ return aeloss
202
+
203
+ if optimizer_idx == 1:
204
+ # discriminator
205
+ discloss, log_dict_disc = self.loss(
206
+ regularization_log,
207
+ x,
208
+ xrec,
209
+ optimizer_idx,
210
+ self.global_step,
211
+ last_layer=self.get_last_layer(),
212
+ split="train",
213
+ )
214
+ self.log_dict(
215
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
216
+ )
217
+ return discloss
218
+
219
+ def validation_step(self, batch, batch_idx) -> Dict:
220
+ log_dict = self._validation_step(batch, batch_idx)
221
+ with self.ema_scope():
222
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
223
+ log_dict.update(log_dict_ema)
224
+ return log_dict
225
+
226
+ def _validation_step(self, batch, batch_idx, postfix="") -> Dict:
227
+ x = self.get_input(batch)
228
+
229
+ z, xrec, regularization_log = self(x)
230
+ aeloss, log_dict_ae = self.loss(
231
+ regularization_log,
232
+ x,
233
+ xrec,
234
+ 0,
235
+ self.global_step,
236
+ last_layer=self.get_last_layer(),
237
+ split="val" + postfix,
238
+ )
239
+
240
+ discloss, log_dict_disc = self.loss(
241
+ regularization_log,
242
+ x,
243
+ xrec,
244
+ 1,
245
+ self.global_step,
246
+ last_layer=self.get_last_layer(),
247
+ split="val" + postfix,
248
+ )
249
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
250
+ log_dict_ae.update(log_dict_disc)
251
+ self.log_dict(log_dict_ae)
252
+ return log_dict_ae
253
+
254
+ def configure_optimizers(self) -> Any:
255
+ ae_params = self.get_autoencoder_params()
256
+ disc_params = self.get_discriminator_params()
257
+
258
+ opt_ae = self.instantiate_optimizer_from_config(
259
+ ae_params,
260
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
261
+ self.optimizer_config,
262
+ )
263
+ opt_disc = self.instantiate_optimizer_from_config(
264
+ disc_params, self.learning_rate, self.optimizer_config
265
+ )
266
+
267
+ return [opt_ae, opt_disc], []
268
+
269
+ @torch.no_grad()
270
+ def log_images(self, batch: Dict, **kwargs) -> Dict:
271
+ log = dict()
272
+ x = self.get_input(batch)
273
+ _, xrec, _ = self(x)
274
+ log["inputs"] = x
275
+ log["reconstructions"] = xrec
276
+ with self.ema_scope():
277
+ _, xrec_ema, _ = self(x)
278
+ log["reconstructions_ema"] = xrec_ema
279
+ return log
280
+
281
+
282
+ class AutoencoderKL(AutoencodingEngine):
283
+ def __init__(self, embed_dim: int, **kwargs):
284
+ ddconfig = kwargs.pop("ddconfig")
285
+ ckpt_path = kwargs.pop("ckpt_path", None)
286
+ ignore_keys = kwargs.pop("ignore_keys", ())
287
+ super().__init__(
288
+ encoder_config={"target": "torch.nn.Identity"},
289
+ decoder_config={"target": "torch.nn.Identity"},
290
+ regularizer_config={"target": "torch.nn.Identity"},
291
+ loss_config=kwargs.pop("lossconfig"),
292
+ **kwargs,
293
+ )
294
+ assert ddconfig["double_z"]
295
+ self.encoder = Encoder(**ddconfig)
296
+ self.decoder = Decoder(**ddconfig)
297
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
298
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
299
+ self.embed_dim = embed_dim
300
+
301
+ if ckpt_path is not None:
302
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
303
+
304
+ def encode(self, x):
305
+ assert (
306
+ not self.training
307
+ ), f"{self.__class__.__name__} only supports inference currently"
308
+ h = self.encoder(x)
309
+ moments = self.quant_conv(h)
310
+ posterior = DiagonalGaussianDistribution(moments)
311
+ return posterior
312
+
313
+ def decode(self, z, **decoder_kwargs):
314
+ z = self.post_quant_conv(z)
315
+ dec = self.decoder(z, **decoder_kwargs)
316
+ return dec
317
+
318
+
319
+ class AutoencoderKLInferenceWrapper(AutoencoderKL):
320
+ def encode(self, x):
321
+ return super().encode(x).sample()
322
+
323
+
324
+ class IdentityFirstStage(AbstractAutoencoder):
325
+ def __init__(self, *args, **kwargs):
326
+ super().__init__(*args, **kwargs)
327
+
328
+ def get_input(self, x: Any) -> Any:
329
+ return x
330
+
331
+ def encode(self, x: Any, *args, **kwargs) -> Any:
332
+ return x
333
+
334
+ def decode(self, x: Any, *args, **kwargs) -> Any:
335
+ return x
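
A hedged sketch of pushing an image through the KL autoencoder defined above at inference time. The ddconfig values are assumptions (a standard SD-style VAE layout), not taken from this commit's config files, and the weights are random, so the example only illustrates the encode/decode shapes.

import torch
from sgm.models.autoencoder import AutoencoderKLInferenceWrapper

# assumed SD-style VAE configuration (8x spatial downsampling, 4 latent channels)
ddconfig = dict(
    double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
    ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
vae = AutoencoderKLInferenceWrapper(
    embed_dim=4, ddconfig=ddconfig, lossconfig={"target": "torch.nn.Identity"},
).eval()

with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)   # image scaled to [-1, 1]
    z = vae.encode(x)                 # sampled latents, shape (1, 4, 32, 32)
    x_rec = vae.decode(z)             # reconstruction, shape (1, 3, 256, 256)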
sgm/models/diffusion.py ADDED
@@ -0,0 +1,556 @@
1
+ from contextlib import contextmanager
2
+ from typing import Any, Dict, List, Tuple, Union, DefaultDict
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from omegaconf import ListConfig, OmegaConf
7
+ from safetensors.torch import load_file as load_safetensors
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from einops import rearrange
10
+ import math
11
+ import torch.nn as nn
12
+ from ..modules import UNCONDITIONAL_CONFIG
13
+ from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
14
+ from ..modules.ema import LitEma
15
+ from ..util import (
16
+ default,
17
+ disabled_train,
18
+ get_obj_from_str,
19
+ instantiate_from_config,
20
+ log_txt_as_img,
21
+ )
22
+
23
+
24
+ import collections
25
+ from functools import partial
26
+
27
+
28
+ def save_activations(
29
+ activations: DefaultDict,
30
+ name: str,
31
+ module: nn.Module,
32
+ inp: Tuple,
33
+ out: torch.Tensor
34
+ ) -> None:
35
+ """PyTorch Forward hook to save outputs at each forward
36
+ pass. Mutates specified dict objects with each fwd pass.
37
+ """
38
+ if isinstance(out, tuple):
39
+ if out[1] is None:
40
+ activations[name].append(out[0].detach())
41
+
42
+ class DiffusionEngine(pl.LightningModule):
43
+ def __init__(
44
+ self,
45
+ network_config,
46
+ denoiser_config,
47
+ first_stage_config,
48
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
49
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
50
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
51
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
52
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
53
+ network_wrapper: Union[None, str] = None,
54
+ ckpt_path: Union[None, str] = None,
55
+ use_ema: bool = False,
56
+ ema_decay_rate: float = 0.9999,
57
+ scale_factor: float = 1.0,
58
+ disable_first_stage_autocast=False,
59
+ input_key: str = "jpg",
60
+ log_keys: Union[List, None] = None,
61
+ no_cond_log: bool = False,
62
+ compile_model: bool = False,
63
+ trainkeys='pose',
64
+ multiplier=0.05,
65
+ loss_rgb_lambda=20.,
66
+ loss_fg_lambda=10.,
67
+ loss_bg_lambda=20.,
68
+ ):
69
+ super().__init__()
70
+ self.log_keys = log_keys
71
+ self.input_key = input_key
72
+ self.trainkeys = trainkeys
73
+ self.multiplier = multiplier
74
+ self.loss_rgb_lambda = loss_rgb_lambda
75
+ self.loss_fg_lambda = loss_fg_lambda
76
+ self.loss_bg_lambda = loss_bg_lambda
77
+ self.rgb = network_config.params.rgb
78
+ self.rgb_predict = network_config.params.rgb_predict
79
+ self.add_token = ('modifier_token' in conditioner_config.params.emb_models[1].params)
80
+ self.optimizer_config = default(
81
+ optimizer_config, {"target": "torch.optim.AdamW"}
82
+ )
83
+ model = instantiate_from_config(network_config)
84
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
85
+ model, compile_model=compile_model
86
+ )
87
+
88
+ self.denoiser = instantiate_from_config(denoiser_config)
89
+ self.sampler = (
90
+ instantiate_from_config(sampler_config)
91
+ if sampler_config is not None
92
+ else None
93
+ )
94
+ self.conditioner = instantiate_from_config(
95
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
96
+ )
97
+ self.scheduler_config = scheduler_config
98
+ self._init_first_stage(first_stage_config)
99
+
100
+ self.loss_fn = (
101
+ instantiate_from_config(loss_fn_config)
102
+ if loss_fn_config is not None
103
+ else None
104
+ )
105
+
106
+ self.use_ema = use_ema
107
+ if self.use_ema:
108
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
109
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
110
+
111
+ self.scale_factor = scale_factor
112
+ self.disable_first_stage_autocast = disable_first_stage_autocast
113
+ self.no_cond_log = no_cond_log
114
+
115
+ if ckpt_path is not None:
116
+ self.init_from_ckpt(ckpt_path)
117
+
118
+ blocks = []
119
+ if self.trainkeys == 'poseattn':
120
+ for x in self.model.diffusion_model.named_parameters():
121
+ if not ('pose' in x[0] or 'transformer_blocks' in x[0]):
122
+ x[1].requires_grad = False
123
+ else:
124
+ if 'pose' in x[0]:
125
+ x[1].requires_grad = True
126
+ blocks.append(x[0].split('.pose')[0])
127
+
128
+ blocks = set(blocks)
129
+ for x in self.model.diffusion_model.named_parameters():
130
+ if 'transformer_blocks' in x[0]:
131
+ reqgrad = False
132
+ for each in blocks:
133
+ if each in x[0] and ('attn1' in x[0] or 'attn2' in x[0] or 'pose' in x[0]):
134
+ reqgrad = True
135
+ x[1].requires_grad = True
136
+ if not reqgrad:
137
+ x[1].requires_grad = False
138
+ elif self.trainkeys == 'pose':
139
+ for x in self.model.diffusion_model.named_parameters():
140
+ if not ('pose' in x[0]):
141
+ x[1].requires_grad = False
142
+ else:
143
+ x[1].requires_grad = True
144
+ elif self.trainkeys == 'all':
145
+ for x in self.model.diffusion_model.named_parameters():
146
+ x[1].requires_grad = True
147
+
148
+ self.model = self.model.to(memory_format=torch.channels_last)
149
+
150
+ def register_activation_hooks(
151
+ self,
152
+ ) -> None:
153
+ self.activations_dict = collections.defaultdict(list)
154
+ handles = []
155
+ for name, module in self.model.diffusion_model.named_modules():
156
+ if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
157
+ if hasattr(module, 'pose_emb_layers'):
158
+ handle = module.register_forward_hook(
159
+ partial(save_activations, self.activations_dict, name)
160
+ )
161
+ handles.append(handle)
162
+ self.handles = handles
163
+
164
+ def clear_rendered_feat(self,):
165
+ for name, module in self.model.diffusion_model.named_modules():
166
+ if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
167
+ if hasattr(module, 'pose_emb_layers'):
168
+ module.rendered_feat = None
169
+
170
+ def remove_activation_hooks(
171
+ self, handles
172
+ ) -> None:
173
+ for handle in handles:
174
+ handle.remove()
175
+
176
+ def init_from_ckpt(
177
+ self,
178
+ path: str,
179
+ ) -> None:
180
+ if path.endswith("ckpt"):
181
+ sd = torch.load(path, map_location="cpu")["state_dict"]
182
+ elif path.endswith("safetensors"):
183
+ sd = load_safetensors(path)
184
+ else:
185
+ raise NotImplementedError
186
+
187
+ missing, unexpected = self.load_state_dict(sd, strict=False)
188
+ print(
189
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
190
+ )
191
+ if len(missing) > 0:
192
+ print(f"Missing Keys: {missing}")
193
+ if len(unexpected) > 0:
194
+ print(f"Unexpected Keys: {unexpected}")
195
+
196
+ def _init_first_stage(self, config):
197
+ model = instantiate_from_config(config).eval()
198
+ model.train = disabled_train
199
+ for param in model.parameters():
200
+ param.requires_grad = False
201
+ self.first_stage_model = model
202
+
203
+ def get_input(self, batch):
204
+ return batch[self.input_key], batch[self.input_key + '_ref'] if self.input_key + '_ref' in batch else None, batch['pose'] if 'pose' in batch else None, batch['mask'] if "mask" in batch else None, batch['mask_ref'] if "mask_ref" in batch else None, batch['depth'] if "depth" in batch else None, batch['drop_im'] if "drop_im" in batch else 0.
205
+
206
+ @torch.no_grad()
207
+ def decode_first_stage(self, z):
208
+ z = 1.0 / self.scale_factor * z
209
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
210
+ out = self.first_stage_model.decode(z)
211
+ return out
212
+
213
+ @torch.no_grad()
214
+ def encode_first_stage(self, x):
215
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
216
+ z = self.first_stage_model.encode(x)
217
+ z = self.scale_factor * z
218
+ return z
219
+
220
+ def forward(self, x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch):
221
+ loss, loss_fg, loss_bg, loss_rgb = self.loss_fn(self.model, self.denoiser, self.conditioner, x, x_rgb, xr, pose, mask, mask_ref, opacity, batch)
222
+ loss_mean = loss.mean()
223
+ loss_dict = {"loss": loss_mean.item()}
224
+ if self.rgb and self.global_step > 0:
225
+ loss_fg = (loss_fg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
226
+ loss_bg = (loss_bg.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
227
+ loss_mean += self.loss_fg_lambda*loss_fg
228
+ loss_mean += self.loss_bg_lambda*loss_bg
229
+ loss_dict["loss_fg"] = loss_fg.item()
230
+ loss_dict["loss_bg"] = loss_bg.item()
231
+ if self.rgb_predict and loss_rgb.mean() > 0:
232
+ loss_rgb = (loss_rgb.mean(1)*drop_im.reshape(-1)).sum()/(drop_im.sum() + 1e-12)
233
+ loss_mean += self.loss_rgb_lambda*loss_rgb
234
+ loss_dict["loss_rgb"] = loss_rgb.item()
235
+ return loss_mean, loss_dict
236
+
237
+ def shared_step(self, batch: Dict) -> Any:
238
+ x, xr, pose, mask, mask_ref, opacity, drop_im = self.get_input(batch)
239
+ x_rgb = x.clone().detach()
240
+ x = self.encode_first_stage(x)
241
+ x = x.to(memory_format=torch.channels_last)
242
+ if xr is not None:
243
+ b, n = xr.shape[0], xr.shape[1]
244
+ xr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...")), "(b n) ... -> b n ...", b=b, n=n)
245
+ xr = drop_im.reshape(b, 1, 1, 1, 1)*xr + (1-drop_im.reshape(b, 1, 1, 1, 1))*torch.zeros_like(xr)
246
+ batch["global_step"] = self.global_step
247
+ loss, loss_dict = self(x, x_rgb, xr, pose, mask, mask_ref, opacity, drop_im, batch)
248
+ return loss, loss_dict
249
+
250
+ def training_step(self, batch, batch_idx):
251
+ loss, loss_dict = self.shared_step(batch)
252
+
253
+ self.log_dict(
254
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
255
+ )
256
+
257
+ self.log(
258
+ "global_step",
259
+ self.global_step,
260
+ prog_bar=True,
261
+ logger=True,
262
+ on_step=True,
263
+ on_epoch=False,
264
+ )
265
+
266
+ if self.scheduler_config is not None:
267
+ lr = self.optimizers().param_groups[0]["lr"]
268
+ self.log(
269
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
270
+ )
271
+ return loss
272
+
273
+ def validation_step(self, batch, batch_idx):
274
+ # print("validation data", len(self.trainer.val_dataloaders))
275
+ loss, loss_dict = self.shared_step(batch)
276
+ return loss
277
+
278
+ def on_train_start(self, *args, **kwargs):
279
+ if self.sampler is None or self.loss_fn is None:
280
+ raise ValueError("Sampler and loss function need to be set for training.")
281
+
282
+ def on_train_batch_end(self, *args, **kwargs):
283
+ if self.use_ema:
284
+ self.model_ema(self.model)
285
+
286
+ def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
287
+ optimizer.zero_grad(set_to_none=True)
288
+
289
+ @contextmanager
290
+ def ema_scope(self, context=None):
291
+ if self.use_ema:
292
+ self.model_ema.store(self.model.parameters())
293
+ self.model_ema.copy_to(self.model)
294
+ if context is not None:
295
+ print(f"{context}: Switched to EMA weights")
296
+ try:
297
+ yield None
298
+ finally:
299
+ if self.use_ema:
300
+ self.model_ema.restore(self.model.parameters())
301
+ if context is not None:
302
+ print(f"{context}: Restored training weights")
303
+
304
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
305
+ return get_obj_from_str(cfg["target"])(
306
+ params, lr=lr, **cfg.get("params", dict())
307
+ )
308
+
309
+ def configure_optimizers(self):
310
+ lr = self.learning_rate
311
+ params = []
312
+ blocks = []
313
+ lowlrparams = []
314
+ if self.trainkeys == 'poseattn':
315
+ lowlrparams = []
316
+ for x in self.model.diffusion_model.named_parameters():
317
+ if ('pose' in x[0]):
318
+ params += [x[1]]
319
+ blocks.append(x[0].split('.pose')[0])
320
+ print(x[0])
321
+ blocks = set(blocks)
322
+ for x in self.model.diffusion_model.named_parameters():
323
+ if 'transformer_blocks' in x[0]:
324
+ for each in blocks:
325
+ if each in x[0] and not ('pose' in x[0]) and ('attn1' in x[0] or 'attn2' in x[0]):
326
+ lowlrparams += [x[1]]
327
+ elif self.trainkeys == 'pose':
328
+ for x in self.model.diffusion_model.named_parameters():
329
+ if ('pose' in x[0]):
330
+ params += [x[1]]
331
+ print(x[0])
332
+ elif self.trainkeys == 'all':
333
+ lowlrparams = []
334
+ for x in self.model.diffusion_model.named_parameters():
335
+ if ('pose' in x[0]):
336
+ params += [x[1]]
337
+ print(x[0])
338
+ else:
339
+ lowlrparams += [x[1]]
340
+
341
+ for i, embedder in enumerate(self.conditioner.embedders[:2]):
342
+ if embedder.is_trainable:
343
+ params = params + list(embedder.parameters())
344
+ if self.add_token:
345
+ if i == 0:
346
+ for name, param in embedder.transformer.get_input_embeddings().named_parameters():
347
+ param.requires_grad = True
348
+ print(name, "conditional model param")
349
+ params += [param]
350
+ else:
351
+ for name, param in embedder.model.token_embedding.named_parameters():
352
+ param.requires_grad = True
353
+ print(name, "conditional model param")
354
+ params += [param]
355
+
356
+ if len(lowlrparams) > 0:
357
+ print("different optimizer groups")
358
+ opt = self.instantiate_optimizer_from_config([{'params': params}, {'params': lowlrparams, 'lr': self.multiplier*lr}], lr, self.optimizer_config)
359
+ else:
360
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
361
+ if self.scheduler_config is not None:
362
+ scheduler = instantiate_from_config(self.scheduler_config)
363
+ print("Setting up LambdaLR scheduler...")
364
+ scheduler = [
365
+ {
366
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
367
+ "interval": "step",
368
+ "frequency": 1,
369
+ }
370
+ ]
371
+ return [opt], scheduler
372
+ return opt
373
+
374
+ @torch.no_grad()
375
+ def sample(
376
+ self,
377
+ cond: Dict,
378
+ uc: Union[Dict, None] = None,
379
+ batch_size: int = 16,
380
+ num_steps=None,
381
+ randn=None,
382
+ shape: Union[None, Tuple, List] = None,
383
+ return_rgb=False,
384
+ mask=None,
385
+ init_im=None,
386
+ **kwargs,
387
+ ):
388
+ if randn is None:
389
+ randn = torch.randn(batch_size, *shape)
390
+
391
+ denoiser = lambda input, sigma, c: self.denoiser(
392
+ self.model, input, sigma, c, **kwargs
393
+ )
394
+ if mask is not None:
395
+ samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, mask=mask, init_im=init_im, num_steps=num_steps)
396
+ else:
397
+ samples, rgb_list = self.sampler(denoiser, randn.to(self.device), cond, uc=uc, num_steps=num_steps)
398
+ if return_rgb:
399
+ return samples, rgb_list
400
+ return samples
401
+
402
+ @torch.no_grad()
403
+ def samplemulti(
404
+ self,
405
+ cond,
406
+ uc=None,
407
+ batch_size: int = 16,
408
+ num_steps=None,
409
+ randn=None,
410
+ shape: Union[None, Tuple, List] = None,
411
+ return_rgb=False,
412
+ mask=None,
413
+ init_im=None,
414
+ multikwargs=None,
415
+ ):
416
+ if randn is None:
417
+ randn = torch.randn(batch_size, *shape)
418
+
419
+ samples, rgb_list = self.sampler(self.denoiser, self.model, randn.to(self.device), cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs)
420
+ if return_rgb:
421
+ return samples, rgb_list
422
+ return samples
423
+
424
+ @torch.no_grad()
425
+ def log_conditionings(self, batch: Dict, n: int, refernce: bool = True) -> Dict:
426
+ """
427
+ Defines heuristics to log different conditionings.
428
+ These can be lists of strings (text-to-image), tensors, ints, ...
429
+ """
430
+ image_h, image_w = batch[self.input_key].shape[2:]
431
+ log = dict()
432
+
433
+ for embedder in self.conditioner.embedders:
434
+ if refernce:
435
+ check = (embedder.input_keys[0] in self.log_keys)
436
+ else:
437
+ check = (embedder.input_key in self.log_keys)
438
+ if (
439
+ (self.log_keys is None) or check
440
+ ) and not self.no_cond_log:
441
+ if refernce:
442
+ x = batch[embedder.input_keys[0]][:n]
443
+ else:
444
+ x = batch[embedder.input_key][:n]
445
+ if isinstance(x, torch.Tensor):
446
+ if x.dim() == 1:
447
+ # class-conditional, convert integer to string
448
+ x = [str(x[i].item()) for i in range(x.shape[0])]
449
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
450
+ elif x.dim() == 2:
451
+ # size and crop cond and the like
452
+ x = [
453
+ "x".join([str(xx) for xx in x[i].tolist()])
454
+ for i in range(x.shape[0])
455
+ ]
456
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
457
+ else:
458
+ raise NotImplementedError()
459
+ elif isinstance(x, (List, ListConfig)):
460
+ if isinstance(x[0], str):
461
+ # strings
462
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
463
+ else:
464
+ raise NotImplementedError()
465
+ else:
466
+ raise NotImplementedError()
467
+ if refernce:
468
+ log[embedder.input_keys[0]] = xc
469
+ else:
470
+ log[embedder.input_key] = xc
471
+ return log
472
+
473
+ @torch.no_grad()
474
+ def log_images(
475
+ self,
476
+ batch: Dict,
477
+ N: int = 8,
478
+ sample: bool = True,
479
+ ucg_keys: List[str] = None,
480
+ **kwargs,
481
+ ) -> Dict:
482
+ log = dict()
483
+
484
+ x, xr, pose, mask, mask_ref, depth, drop_im = self.get_input(batch)
485
+
486
+ if xr is not None:
487
+ conditioner_input_keys = [e.input_keys for e in self.conditioner.embedders]
488
+ else:
489
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
490
+
491
+ if ucg_keys:
492
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
493
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
494
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
495
+ )
496
+ else:
497
+ ucg_keys = conditioner_input_keys
498
+
499
+ c, uc = self.conditioner.get_unconditional_conditioning(
500
+ batch,
501
+ force_uc_zero_embeddings=ucg_keys
502
+ if len(self.conditioner.embedders) > 0
503
+ else [],
504
+ )
505
+
506
+ N = min(x.shape[0], N)
507
+ x = x.to(self.device)[:N]
508
+ zr = None
509
+ if xr is not None:
510
+ xr = xr.to(self.device)[:N]
511
+ b, n = xr.shape[0], xr.shape[1]
512
+ log["reference"] = rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)
513
+ zr = rearrange(self.encode_first_stage(rearrange(xr, "b n ... -> (b n) ...", b=b, n=n)), "(b n) ... -> b n ...", b=b, n=n)
514
+
515
+ log["inputs"] = x
516
+ b = x.shape[0]
517
+ if mask is not None:
518
+ log["mask"] = mask
519
+ if depth is not None:
520
+ log["depth"] = depth
521
+ z = self.encode_first_stage(x)
522
+
523
+ if uc is not None:
524
+ if xr is not None:
525
+ zr = torch.cat([torch.zeros_like(zr), zr])
526
+ drop_im = torch.cat([drop_im, drop_im])
527
+ if isinstance(pose, list):
528
+ pose = pose[:N]*2
529
+ else:
530
+ pose = torch.cat([pose[:N]] * 2)
531
+
532
+ sampling_kwargs = {'input_ref':zr}
533
+ sampling_kwargs['pose'] = pose
534
+ sampling_kwargs['mask_ref'] = None
535
+ sampling_kwargs['drop_im'] = drop_im
536
+
537
+ log["reconstructions"] = self.decode_first_stage(z)
538
+ log.update(self.log_conditionings(batch, N, refernce=True if xr is not None else False))
539
+
540
+ for k in c:
541
+ if isinstance(c[k], torch.Tensor):
542
+ if xr is not None:
543
+ c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to(self.device), (c, uc))
544
+ else:
545
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
546
+ if sample:
547
+ with self.ema_scope("Plotting"):
548
+ samples, rgb_list = self.sample(
549
+ c, shape=z.shape[1:], uc=uc, batch_size=N, return_rgb=True, **sampling_kwargs
550
+ )
551
+ samples = self.decode_first_stage(samples)
552
+ log["samples"] = samples
553
+ if len(rgb_list) > 0:
554
+ size = int(math.sqrt(rgb_list[0].size(1)))
555
+ log["predicted_rgb"] = rgb_list[0].reshape(-1, size, size, 3).permute(0, 3, 1, 2)
556
+ return log
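
register_activation_hooks above relies on functools.partial to bind the shared dictionary and the module name into save_activations, so each forward pass appends that block's output to activations_dict[name]. Below is a standalone sketch of the same hook pattern on a toy model; the model and the simplified hook are assumptions (the commit's save_activations additionally unpacks tuple outputs), and the clean-up mirrors remove_activation_hooks.

import collections
from functools import partial

import torch
import torch.nn as nn

def save_linear_out(activations, name, module, inp, out):
    # same calling convention as a torch forward hook: (module, input, output)
    activations[name].append(out.detach())

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
activations = collections.defaultdict(list)
handles = [
    module.register_forward_hook(partial(save_linear_out, activations, name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
]

model(torch.randn(2, 8))
print({k: v[0].shape for k, v in activations.items()})  # e.g. {'0': (2, 16), '2': (2, 4)}

for handle in handles:  # same clean-up as remove_activation_hooks
    handle.remove()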
sgm/modules/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .encoders.modules import GeneralConditioner
+
+ UNCONDITIONAL_CONFIG = {
+     "target": "sgm.modules.GeneralConditioner",
+     "params": {"emb_models": []},
+ }
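
UNCONDITIONAL_CONFIG is the fallback handed to instantiate_from_config when DiffusionEngine receives no conditioner_config, yielding a GeneralConditioner with an empty embedder list. A sketch of how such a target/params dict gets resolved; the two helpers below are a minimal re-implementation for illustration only (the real ones live in sgm.util).

import importlib

def get_obj_from_str(string):
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

conditioner = instantiate_from_config(
    {"target": "sgm.modules.GeneralConditioner", "params": {"emb_models": []}}
)
print(type(conditioner).__name__)  # GeneralConditioner with no embedders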
sgm/modules/attention.py ADDED
@@ -0,0 +1,1202 @@
1
+ import logging
2
+ import math
3
+ import itertools
4
+ from inspect import isfunction
5
+ from typing import Any, Optional
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ from packaging import version
11
+ from torch import nn
12
+ from .diffusionmodules.util import checkpoint
13
+ from torch.autograd import Function
14
+ from torch.cuda.amp import custom_bwd, custom_fwd
15
+
16
+ from ..modules.diffusionmodules.util import zero_module
17
+ from ..modules.nerfsd_pytorch3d import NerfSDModule, VolRender
18
+
19
+ logpy = logging.getLogger(__name__)
20
+
21
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
22
+ SDP_IS_AVAILABLE = True
23
+ from torch.backends.cuda import SDPBackend, sdp_kernel
24
+
25
+ BACKEND_MAP = {
26
+ SDPBackend.MATH: {
27
+ "enable_math": True,
28
+ "enable_flash": False,
29
+ "enable_mem_efficient": False,
30
+ },
31
+ SDPBackend.FLASH_ATTENTION: {
32
+ "enable_math": False,
33
+ "enable_flash": True,
34
+ "enable_mem_efficient": False,
35
+ },
36
+ SDPBackend.EFFICIENT_ATTENTION: {
37
+ "enable_math": False,
38
+ "enable_flash": False,
39
+ "enable_mem_efficient": True,
40
+ },
41
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
42
+ }
43
+ else:
44
+ from contextlib import nullcontext
45
+
46
+ SDP_IS_AVAILABLE = False
47
+ sdp_kernel = nullcontext
48
+ BACKEND_MAP = {}
49
+ logpy.warn(
50
+ f"No SDP backend available, likely because you are running in pytorch "
51
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
52
+ f"You might want to consider upgrading."
53
+ )
54
+
55
+ try:
56
+ import xformers
57
+ import xformers.ops
58
+
59
+ XFORMERS_IS_AVAILABLE = True
60
+ except:
61
+ XFORMERS_IS_AVAILABLE = False
62
+ logpy.warn("no module 'xformers'. Processing without...")
63
+
64
+
65
+ def exists(val):
66
+ return val is not None
67
+
68
+
69
+ def uniq(arr):
70
+ return {el: True for el in arr}.keys()
71
+
72
+
73
+ def default(val, d):
74
+ if exists(val):
75
+ return val
76
+ return d() if isfunction(d) else d
77
+
78
+
79
+ def max_neg_value(t):
80
+ return -torch.finfo(t.dtype).max
81
+
82
+
83
+ def init_(tensor):
84
+ dim = tensor.shape[-1]
85
+ std = 1 / math.sqrt(dim)
86
+ tensor.uniform_(-std, std)
87
+ return tensor
88
+
89
+
90
+ # feedforward
91
+ class GEGLU(nn.Module):
92
+ def __init__(self, dim_in, dim_out):
93
+ super().__init__()
94
+ self.proj = nn.Linear(dim_in, dim_out * 2)
95
+
96
+ def forward(self, x):
97
+ x, gate = self.proj(x).chunk(2, dim=-1)
98
+ return x * F.gelu(gate)
99
+
100
+
101
+ class FeedForward(nn.Module):
102
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
103
+ super().__init__()
104
+ inner_dim = int(dim * mult)
105
+ dim_out = default(dim_out, dim)
106
+ project_in = (
107
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
108
+ if not glu
109
+ else GEGLU(dim, inner_dim)
110
+ )
111
+
112
+ self.net = nn.Sequential(
113
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
114
+ )
115
+
116
+ def forward(self, x):
117
+ return self.net(x)
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class LinearAttention(nn.Module):
127
+ def __init__(self, dim, heads=4, dim_head=32):
128
+ super().__init__()
129
+ self.heads = heads
130
+ hidden_dim = dim_head * heads
131
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
132
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
133
+
134
+ def forward(self, x):
135
+ b, c, h, w = x.shape
136
+ qkv = self.to_qkv(x)
137
+ q, k, v = rearrange(
138
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
139
+ )
140
+ k = k.softmax(dim=-1)
141
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
142
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
143
+ out = rearrange(
144
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
145
+ )
146
+ return self.to_out(out)
147
+
148
+
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
+ class _TruncExp(Function): # pylint: disable=abstract-method
195
+ # Implementation from torch-ngp:
196
+ # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
197
+ @staticmethod
198
+ @custom_fwd(cast_inputs=torch.float32)
199
+ def forward(ctx, x): # pylint: disable=arguments-differ
200
+ ctx.save_for_backward(x)
201
+ return torch.exp(x)
202
+
203
+ @staticmethod
204
+ @custom_bwd
205
+ def backward(ctx, g): # pylint: disable=arguments-differ
206
+ x = ctx.saved_tensors[0]
207
+ return g * torch.exp(x.clamp(-15, 15))
208
+
209
+
210
+ trunc_exp = _TruncExp.apply
211
+ """Same as torch.exp, but with the backward pass clipped to prevent vanishing/exploding
212
+ gradients."""
213
+
214
+
215
+ class CrossAttention(nn.Module):
216
+ def __init__(
217
+ self,
218
+ query_dim,
219
+ context_dim=None,
220
+ heads=8,
221
+ dim_head=64,
222
+ dropout=0.0,
223
+ backend=None,
224
+ ):
225
+ super().__init__()
226
+ inner_dim = dim_head * heads
227
+ context_dim = default(context_dim, query_dim)
228
+
229
+ self.scale = dim_head**-0.5
230
+ self.heads = heads
231
+
232
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
233
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
234
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
235
+
236
+ self.to_out = nn.Sequential(
237
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
238
+ )
239
+ self.backend = backend
240
+
241
+ def forward(
242
+ self,
243
+ x,
244
+ context=None,
245
+ mask=None,
246
+ additional_tokens=None,
247
+ n_times_crossframe_attn_in_self=0,
248
+ ):
249
+ h = self.heads
250
+
251
+ if additional_tokens is not None:
252
+ # get the number of masked tokens at the beginning of the output sequence
253
+ n_tokens_to_mask = additional_tokens.shape[1]
254
+ # add additional token
255
+ x = torch.cat([additional_tokens, x], dim=1)
256
+
257
+ q = self.to_q(x)
258
+ context = default(context, x)
259
+ k = self.to_k(context)
260
+ v = self.to_v(context)
261
+
262
+ if n_times_crossframe_attn_in_self:
263
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
264
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
265
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
266
+ k = repeat(
267
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
268
+ )
269
+ v = repeat(
270
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
271
+ )
272
+
273
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
274
+
275
+ ## old
276
+ """
277
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
278
+ del q, k
279
+
280
+ if exists(mask):
281
+ mask = rearrange(mask, 'b ... -> b (...)')
282
+ max_neg_value = -torch.finfo(sim.dtype).max
283
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
284
+ sim.masked_fill_(~mask, max_neg_value)
285
+
286
+ # attention, what we cannot get enough of
287
+ sim = sim.softmax(dim=-1)
288
+
289
+ out = einsum('b i j, b j d -> b i d', sim, v)
290
+ """
291
+ ## new
292
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
293
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
294
+ out = F.scaled_dot_product_attention(
295
+ q, k, v, attn_mask=mask
296
+ ) # scale is dim_head ** -0.5 per default
297
+
298
+ del q, k, v
299
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
300
+
301
+ if additional_tokens is not None:
302
+ # remove additional token
303
+ out = out[:, n_tokens_to_mask:]
304
+ return self.to_out(out)
305
+
306
+
307
+ class MemoryEfficientCrossAttention(nn.Module):
308
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
309
+ def __init__(
310
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, add_lora=False, **kwargs
311
+ ):
312
+ super().__init__()
313
+ logpy.debug(
314
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
315
+ f"context_dim is {context_dim} and using {heads} heads with a "
316
+ f"dimension of {dim_head}."
317
+ )
318
+ inner_dim = dim_head * heads
319
+ context_dim = default(context_dim, query_dim)
320
+
321
+ self.heads = heads
322
+ self.dim_head = dim_head
323
+ self.add_lora = add_lora
324
+
325
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
326
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
327
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
328
+
329
+ self.to_out = nn.Sequential(
330
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
331
+ )
332
+ if add_lora:
333
+ r = 32
334
+ self.to_q_attn3_down = nn.Linear(query_dim, r, bias=False)
335
+ self.to_q_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
336
+ self.to_k_attn3_down = nn.Linear(context_dim, r, bias=False)
337
+ self.to_k_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
338
+ self.to_v_attn3_down = nn.Linear(context_dim, r, bias=False)
339
+ self.to_v_attn3_up = zero_module(nn.Linear(r, inner_dim, bias=False))
340
+ self.to_o_attn3_down = nn.Linear(inner_dim, r, bias=False)
341
+ self.to_o_attn3_up = zero_module(nn.Linear(r, query_dim, bias=False))
342
+ self.dropoutq = nn.Dropout(0.1)
343
+ self.dropoutk = nn.Dropout(0.1)
344
+ self.dropoutv = nn.Dropout(0.1)
345
+ self.dropouto = nn.Dropout(0.1)
346
+
347
+ nn.init.normal_(self.to_q_attn3_down.weight, std=1 / r)
348
+ nn.init.normal_(self.to_k_attn3_down.weight, std=1 / r)
349
+ nn.init.normal_(self.to_v_attn3_down.weight, std=1 / r)
350
+ nn.init.normal_(self.to_o_attn3_down.weight, std=1 / r)
351
+
352
+ self.attention_op: Optional[Any] = None
353
+
354
+ def forward(
355
+ self,
356
+ x,
357
+ context=None,
358
+ mask=None,
359
+ additional_tokens=None,
360
+ n_times_crossframe_attn_in_self=0,
361
+ ):
362
+ if additional_tokens is not None:
363
+ # get the number of masked tokens at the beginning of the output sequence
364
+ n_tokens_to_mask = additional_tokens.shape[1]
365
+ # add additional token
366
+ x = torch.cat([additional_tokens, x], dim=1)
367
+
368
+ context_k = context # b, n, c, h, w
369
+
370
+ q = self.to_q(x)
371
+ context = default(context, x)
372
+ context_k = default(context_k, x)
373
+ k = self.to_k(context_k)
374
+ v = self.to_v(context_k)
375
+ if self.add_lora:
376
+ q += self.dropoutq(self.to_q_attn3_up(self.to_q_attn3_down(x)))
377
+ k += self.dropoutk(self.to_k_attn3_up(self.to_k_attn3_down(context_k)))
378
+ v += self.dropoutv(self.to_v_attn3_up(self.to_v_attn3_down(context_k)))
379
+
380
+ if n_times_crossframe_attn_in_self:
381
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
382
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
383
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
384
+ k = repeat(
385
+ k[::n_times_crossframe_attn_in_self],
386
+ "b ... -> (b n) ...",
387
+ n=n_times_crossframe_attn_in_self,
388
+ )
389
+ v = repeat(
390
+ v[::n_times_crossframe_attn_in_self],
391
+ "b ... -> (b n) ...",
392
+ n=n_times_crossframe_attn_in_self,
393
+ )
394
+
395
+ b, _, _ = q.shape
396
+ q, k, v = map(
397
+ lambda t: t.unsqueeze(3)
398
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
399
+ .permute(0, 2, 1, 3)
400
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
401
+ .contiguous(),
402
+ (q, k, v),
403
+ )
404
+
405
+ attn_bias = None
406
+
407
+ # actually compute the attention, what we cannot get enough of
408
+ out = xformers.ops.memory_efficient_attention(
409
+ q, k, v, attn_bias=attn_bias, op=self.attention_op
410
+ )
411
+
412
+ # TODO: Use this directly in the attention operation, as a bias
413
+ if exists(mask):
414
+ raise NotImplementedError
415
+ out = (
416
+ out.unsqueeze(0)
417
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
418
+ .permute(0, 2, 1, 3)
419
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
420
+ )
421
+ if additional_tokens is not None:
422
+ # remove additional token
423
+ out = out[:, n_tokens_to_mask:]
424
+ final = self.to_out(out)
425
+ if self.add_lora:
426
+ final += self.dropouto(self.to_o_attn3_up(self.to_o_attn3_down(out)))
427
+ return final
428
+
429
+
430
+ class BasicTransformerBlock(nn.Module):
431
+ ATTENTION_MODES = {
432
+ "softmax": CrossAttention, # vanilla attention
433
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
434
+ }
435
+
436
+ def __init__(
437
+ self,
438
+ dim,
439
+ n_heads,
440
+ d_head,
441
+ dropout=0.0,
442
+ context_dim=None,
443
+ gated_ff=True,
444
+ checkpoint=True,
445
+ disable_self_attn=False,
446
+ attn_mode="softmax",
447
+ sdp_backend=None,
448
+ image_cross=False,
449
+ far=2,
450
+ num_samples=32,
451
+ add_lora=False,
452
+ rgb_predict=False,
453
+ mode='pixel-nerf',
454
+ average=False,
455
+ num_freqs=16,
456
+ use_prev_weights_imp_sample=False,
457
+ imp_sample_next_step=False,
458
+ stratified=False,
459
+ imp_sampling_percent=0.9,
460
+ near_plane=0.
461
+ ):
462
+
463
+ super().__init__()
464
+ assert attn_mode in self.ATTENTION_MODES
465
+ self.add_lora = add_lora
466
+ self.image_cross = image_cross
467
+ self.rgb_predict = rgb_predict
468
+ self.use_prev_weights_imp_sample = use_prev_weights_imp_sample
469
+ self.imp_sample_next_step = imp_sample_next_step
470
+ self.rendered_feat = None
471
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
472
+ logpy.warn(
473
+ f"Attention mode '{attn_mode}' is not available. Falling "
474
+ f"back to native attention. This is not a problem in "
475
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
476
+ f"version {torch.__version__}."
477
+ )
478
+ attn_mode = "softmax"
479
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
480
+ logpy.warn(
481
+ "We do not support vanilla attention anymore, as it is too "
482
+ "expensive. Sorry."
483
+ )
484
+ if not XFORMERS_IS_AVAILABLE:
485
+ assert (
486
+ False
487
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
488
+ else:
489
+ logpy.info("Falling back to xformers efficient attention.")
490
+ attn_mode = "softmax-xformers"
491
+ attn_cls = self.ATTENTION_MODES[attn_mode]
492
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
493
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
494
+ else:
495
+ assert sdp_backend is None
496
+ self.disable_self_attn = disable_self_attn
497
+ self.attn1 = attn_cls(
498
+ query_dim=dim,
499
+ heads=n_heads,
500
+ dim_head=d_head,
501
+ dropout=dropout,
502
+ add_lora=self.add_lora,
503
+ context_dim=context_dim if self.disable_self_attn else None,
504
+ backend=sdp_backend,
505
+ ) # is a self-attention if not self.disable_self_attn
506
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
507
+ self.attn2 = attn_cls(
508
+ query_dim=dim,
509
+ context_dim=context_dim,
510
+ heads=n_heads,
511
+ dim_head=d_head,
512
+ dropout=dropout,
513
+ add_lora=self.add_lora,
514
+ backend=sdp_backend,
515
+ ) # is self-attn if context is none
516
+ if image_cross:
517
+ self.pose_emb_layers = nn.Linear(2*dim, dim, bias=False)
518
+ nn.init.eye_(self.pose_emb_layers.weight)
519
+ self.pose_featurenerf = NerfSDModule(mode=mode,
520
+ out_channels=dim,
521
+ far_plane=far,
522
+ num_samples=num_samples,
523
+ rgb_predict=rgb_predict,
524
+ average=average,
525
+ num_freqs=num_freqs,
526
+ stratified=stratified,
527
+ imp_sampling_percent=imp_sampling_percent,
528
+ near_plane=near_plane,
529
+ )
530
+
531
+ self.renderer = VolRender()
532
+
533
+ self.norm1 = nn.LayerNorm(dim)
534
+ self.norm2 = nn.LayerNorm(dim)
535
+ self.norm3 = nn.LayerNorm(dim)
536
+ self.checkpoint = checkpoint
537
+ if self.checkpoint:
538
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
539
+
540
+ def forward(
541
+ self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
542
+ ):
543
+ kwargs = {"x": x}
544
+
545
+ if context is not None:
546
+ kwargs.update({"context": context})
547
+
548
+ if context_ref is not None:
549
+ kwargs.update({"context_ref": context_ref})
550
+
551
+ if pose is not None:
552
+ kwargs.update({"pose": pose})
553
+
554
+ if mask_ref is not None:
555
+ kwargs.update({"mask_ref": mask_ref})
556
+
557
+ if prev_weights is not None:
558
+ kwargs.update({"prev_weights": prev_weights})
559
+
560
+ if additional_tokens is not None:
561
+ kwargs.update({"additional_tokens": additional_tokens})
562
+
563
+ if n_times_crossframe_attn_in_self:
564
+ kwargs.update(
565
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
566
+ )
567
+
568
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
569
+ return checkpoint(
570
+ self._forward, (x, context, context_ref, pose, mask_ref, prev_weights), self.parameters(), self.checkpoint
571
+ )
572
+
573
+ def reference_attn(self, x, context_ref, context, pose, prev_weights, mask_ref):
574
+ feats, sigmas, dists, _, predicted_rgb, sigmas_uniform, dists_uniform = self.pose_featurenerf(pose,
575
+ context_ref,
576
+ mask_ref,
577
+ prev_weights=prev_weights if self.use_prev_weights_imp_sample else None,
578
+ imp_sample_next_step=self.imp_sample_next_step)
579
+
580
+ b, hw, d = feats.size()[:3]
581
+ feats = rearrange(feats, "b hw d ... -> b (hw d) ...")
582
+
583
+ feats = (
584
+ self.attn2(
585
+ self.norm2(feats), context=context,
586
+ )
587
+ + feats
588
+ )
589
+
590
+ feats = rearrange(feats, "b (hw d) ... -> b hw d ...", hw=hw, d=d)
591
+
592
+ sigmas_ = trunc_exp(sigmas)
593
+ if sigmas_uniform is not None:
594
+ sigmas_uniform = trunc_exp(sigmas_uniform)
595
+
596
+ context_ref, fg_mask, alphas, weights_uniform, predicted_rgb = self.renderer(feats, sigmas_, dists, densities_uniform=sigmas_uniform, dists_uniform=dists_uniform, return_weights_uniform=True, rgb=F.sigmoid(predicted_rgb) if predicted_rgb is not None else None)
597
+ if self.use_prev_weights_imp_sample:
598
+ prev_weights = weights_uniform
599
+
600
+ return context_ref, fg_mask, prev_weights, alphas, predicted_rgb
601
+
602
+ def _forward(
603
+ self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
604
+ ):
605
+ fg_mask = None
606
+ weights = None
607
+ alphas = None
608
+ predicted_rgb = None
609
+ xref = None
610
+
611
+ x = (
612
+ self.attn1(
613
+ self.norm1(x),
614
+ context=context if self.disable_self_attn else None,
615
+ additional_tokens=additional_tokens,
616
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
617
+ if not self.disable_self_attn
618
+ else 0,
619
+ )
620
+ + x
621
+ )
622
+ x = (
623
+ self.attn2(
624
+ self.norm2(x), context=context, additional_tokens=additional_tokens
625
+ )
626
+ + x
627
+ )
628
+ with torch.amp.autocast(device_type='cuda', dtype=torch.float32):
629
+ if context_ref is not None:
630
+ xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x,
631
+ rearrange(context_ref, "(b n) ... -> b n ...", b=x.size(0), n=context_ref.size(0) // x.size(0)),
632
+ context,
633
+ pose,
634
+ prev_weights,
635
+ mask_ref)
636
+ x = self.pose_emb_layers(torch.cat([x, xref], -1))
637
+
638
+ x = self.ff(self.norm3(x)) + x
639
+ return x, fg_mask, weights, alphas, predicted_rgb
640
+
641
+
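For orientation (not part of the commit), here is a minimal smoke test of BasicTransformerBlock in its plain text-conditioning mode. It assumes PyTorch >= 2.0 (so native SDP attention is available) and the repo root on PYTHONPATH; the dimensions are illustrative, checkpoint=False bypasses the gradient-checkpoint wrapper, and image_cross=False skips the NeRF reference branch, so the extra outputs come back as None, matching _forward above.

import torch
from sgm.modules.attention import BasicTransformerBlock

# hypothetical dims: 320-dim image tokens attending to 1024-dim text tokens
block = BasicTransformerBlock(
    dim=320, n_heads=8, d_head=40, context_dim=1024,
    checkpoint=False, image_cross=False,
)
x = torch.randn(2, 64, 320)      # (batch, h*w tokens, dim)
ctx = torch.randn(2, 77, 1024)   # (batch, text tokens, context_dim)

# context_ref/pose are left as None, so fg_mask, weights, alphas and rgb are all None
out, fg_mask, weights, alphas, rgb = block(x, context=ctx)
print(out.shape)                 # torch.Size([2, 64, 320])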
642
+ class BasicTransformerSingleLayerBlock(nn.Module):
643
+ ATTENTION_MODES = {
644
+ "softmax": CrossAttention, # vanilla attention
645
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
646
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
647
+ }
648
+
649
+ def __init__(
650
+ self,
651
+ dim,
652
+ n_heads,
653
+ d_head,
654
+ dropout=0.0,
655
+ context_dim=None,
656
+ gated_ff=True,
657
+ checkpoint=True,
658
+ attn_mode="softmax",
659
+ ):
660
+ super().__init__()
661
+ assert attn_mode in self.ATTENTION_MODES
662
+ attn_cls = self.ATTENTION_MODES[attn_mode]
663
+ self.attn1 = attn_cls(
664
+ query_dim=dim,
665
+ heads=n_heads,
666
+ dim_head=d_head,
667
+ dropout=dropout,
668
+ context_dim=context_dim,
669
+ )
670
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
671
+ self.norm1 = nn.LayerNorm(dim)
672
+ self.norm2 = nn.LayerNorm(dim)
673
+ self.checkpoint = checkpoint
674
+
675
+ def forward(self, x, context=None):
676
+ return checkpoint(
677
+ self._forward, (x, context), self.parameters(), self.checkpoint
678
+ )
679
+
680
+ def _forward(self, x, context=None):
681
+ x = self.attn1(self.norm1(x), context=context) + x
682
+ x = self.ff(self.norm2(x)) + x
683
+ return x
684
+
685
+
686
+ class SpatialTransformer(nn.Module):
687
+ """
688
+ Transformer block for image-like data.
689
+ First, project the input (aka embedding)
690
+ and reshape to b, t, d.
691
+ Then apply standard transformer action.
692
+ Finally, reshape to image
693
+ NEW: use_linear for more efficiency instead of the 1x1 convs
694
+ """
695
+
696
+ def __init__(
697
+ self,
698
+ in_channels,
699
+ n_heads,
700
+ d_head,
701
+ depth=1,
702
+ dropout=0.0,
703
+ context_dim=None,
704
+ disable_self_attn=False,
705
+ use_linear=False,
706
+ attn_type="softmax",
707
+ use_checkpoint=True,
708
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
709
+ sdp_backend=None,
710
+ image_cross=True,
711
+ rgb_predict=False,
712
+ far=2,
713
+ num_samples=32,
714
+ add_lora=False,
715
+ mode='feature-nerf',
716
+ average=False,
717
+ num_freqs=16,
718
+ use_prev_weights_imp_sample=False,
719
+ stratified=False,
720
+ poscontrol_interval=4,
721
+ imp_sampling_percent=0.9,
722
+ near_plane=0.
723
+ ):
724
+ super().__init__()
725
+ logpy.debug(
726
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
727
+ f"{in_channels} channels and {n_heads} heads."
728
+ )
729
+ from omegaconf import ListConfig
730
+
731
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
732
+ context_dim = [context_dim]
733
+ if exists(context_dim) and isinstance(context_dim, list):
734
+ if depth != len(context_dim):
735
+ logpy.warn(
736
+ f"{self.__class__.__name__}: Found context dims "
737
+ f"{context_dim} of depth {len(context_dim)}, which does not "
738
+ f"match the specified 'depth' of {depth}. Setting context_dim "
739
+ f"to {depth * [context_dim[0]]} now."
740
+ )
741
+ # depth does not match context dims.
742
+ assert all(
743
+ map(lambda x: x == context_dim[0], context_dim)
744
+ ), "need homogeneous context_dim to match depth automatically"
745
+ context_dim = depth * [context_dim[0]]
746
+ elif context_dim is None:
747
+ context_dim = [None] * depth
748
+ self.in_channels = in_channels
749
+ inner_dim = n_heads * d_head
750
+ self.norm = Normalize(in_channels)
751
+
752
+ self.image_cross = image_cross
753
+ self.poscontrol_interval = poscontrol_interval
754
+
755
+ if not use_linear:
756
+ self.proj_in = nn.Conv2d(
757
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
758
+ )
759
+ else:
760
+ self.proj_in = nn.Linear(in_channels, inner_dim)
761
+
762
+ self.transformer_blocks = nn.ModuleList(
763
+ [
764
+ BasicTransformerBlock(
765
+ inner_dim,
766
+ n_heads,
767
+ d_head,
768
+ dropout=dropout,
769
+ context_dim=context_dim[d],
770
+ disable_self_attn=disable_self_attn,
771
+ attn_mode=attn_type,
772
+ checkpoint=use_checkpoint,
773
+ sdp_backend=sdp_backend,
774
+ image_cross=self.image_cross and (d % poscontrol_interval == 0),
775
+ far=far,
776
+ num_samples=num_samples,
777
+ add_lora=add_lora and self.image_cross and (d % poscontrol_interval == 0),
778
+ rgb_predict=rgb_predict,
779
+ mode=mode,
780
+ average=average,
781
+ num_freqs=num_freqs,
782
+ use_prev_weights_imp_sample=use_prev_weights_imp_sample,
783
+ imp_sample_next_step=(use_prev_weights_imp_sample and self.image_cross and (d % poscontrol_interval == 0) and depth >= poscontrol_interval and d < (depth // poscontrol_interval) * poscontrol_interval ),
784
+ stratified=stratified,
785
+ imp_sampling_percent=imp_sampling_percent,
786
+ near_plane=near_plane,
787
+ )
788
+ for d in range(depth)
789
+ ]
790
+ )
791
+ if not use_linear:
792
+ self.proj_out = zero_module(
793
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
794
+ )
795
+ else:
796
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
797
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
798
+ self.use_linear = use_linear
799
+
800
+ def forward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None):
801
+ # note: if no context is given, cross-attention defaults to self-attention
802
+ if xr is None:
803
+ if not isinstance(context, list):
804
+ context = [context]
805
+ b, c, h, w = x.shape
806
+ x_in = x
807
+ x = self.norm(x)
808
+ if not self.use_linear:
809
+ x = self.proj_in(x)
810
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
811
+ if self.use_linear:
812
+ x = self.proj_in(x)
813
+ for i, block in enumerate(self.transformer_blocks):
814
+ if i > 0 and len(context) == 1:
815
+ i = 0 # use same context for each block
816
+ x, _, _, _, _ = block(x, context=context[i])
817
+ if self.use_linear:
818
+ x = self.proj_out(x)
819
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
820
+ if not self.use_linear:
821
+ x = self.proj_out(x)
822
+ return x + x_in, None, None, None, None, None
823
+ else:
824
+ if not isinstance(context, list):
825
+ context = [context]
826
+ contextr = [contextr]
827
+ b, c, h, w = x.shape
828
+ b1, _, _, _ = xr.shape
829
+ x_in = x
830
+ xr_in = xr
831
+ fg_masks = []
832
+ alphas = []
833
+ rgbs = []
834
+
835
+ x = self.norm(x)
836
+ with torch.no_grad():
837
+ xr = self.norm(xr)
838
+
839
+ if not self.use_linear:
840
+ x = self.proj_in(x)
841
+ with torch.no_grad():
842
+ xr = self.proj_in(xr)
843
+
844
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
845
+ xr = rearrange(xr, "b1 c h w -> b1 (h w) c").contiguous()
846
+ if self.use_linear:
847
+ x = self.proj_in(x)
848
+ with torch.no_grad():
849
+ xr = self.proj_in(xr)
850
+
851
+ prev_weights = None
852
+ counter = 0
853
+ for i, block in enumerate(self.transformer_blocks):
854
+ if i > 0 and len(context) == 1:
855
+ i = 0 # use same context for each block
856
+ if self.image_cross and (counter % self.poscontrol_interval == 0):
857
+ with torch.no_grad():
858
+ xr, _, _, _, _ = block(xr, context=contextr[i])
859
+ x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=xr.detach(), pose=pose, mask_ref=mask_ref, prev_weights=prev_weights)
860
+ prev_weights = weights
861
+ fg_masks.append(fg_mask)
862
+ if alpha is not None:
863
+ alphas.append(alpha)
864
+ if rgb is not None:
865
+ rgbs.append(rgb)
866
+ else:
867
+ with torch.no_grad():
868
+ xr, _, _, _, _ = block(xr, context=contextr[i])
869
+ x, _, _, _, _ = block(x, context=context[i])
870
+ counter += 1
871
+ if self.use_linear:
872
+ x = self.proj_out(x)
873
+ with torch.no_grad():
874
+ xr = self.proj_out(xr)
875
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
876
+ xr = rearrange(xr, "b1 (h w) c -> b1 c h w", h=h, w=w).contiguous()
877
+ if not self.use_linear:
878
+ x = self.proj_out(x)
879
+ with torch.no_grad():
880
+ xr = self.proj_out(xr)
881
+ if len(fg_masks) > 0:
882
+ if len(rgbs) <= 0:
883
+ rgbs = None
884
+ if len(alphas) <= 0:
885
+ alphas = None
886
+ return x + x_in, (xr + xr_in).detach(), fg_masks, prev_weights, alphas, rgbs
887
+ else:
888
+ return x + x_in, (xr + xr_in).detach(), None, prev_weights, None, None
889
+
890
+
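A minimal sketch (with assumed, illustrative sizes) of how this dual-stream SpatialTransformer degenerates to the standard Stable Diffusion block when no reference stream is given: with xr=None the input is normalized, projected, flattened to tokens, run through the blocks, and reshaped back, and the reference-related outputs are None. Passing xr, contextr, and pose instead activates the reference-attention branch every poscontrol_interval blocks; image_cross=False below avoids constructing that branch at all.

import torch
from sgm.modules.attention import SpatialTransformer

st = SpatialTransformer(
    in_channels=320, n_heads=8, d_head=40, depth=1,
    context_dim=1024, use_linear=True, use_checkpoint=False,
    image_cross=False,              # do not build the NeRF/reference path
)
x = torch.randn(2, 320, 32, 32)     # image-like features (b, c, h, w)
ctx = torch.randn(2, 77, 1024)      # text conditioning

out, xr_out, fg_masks, weights, alphas, rgbs = st(x, xr=None, context=ctx)
print(out.shape)                    # torch.Size([2, 320, 32, 32]); the other outputs are None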
891
+ def benchmark_attn():
892
+ # Let's define a helpful benchmarking function:
893
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
894
+ device = "cuda" if torch.cuda.is_available() else "cpu"
895
+ import torch.nn.functional as F
896
+ import torch.utils.benchmark as benchmark
897
+
898
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
899
+ t0 = benchmark.Timer(
900
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
901
+ )
902
+ return t0.blocked_autorange().mean * 1e6
903
+
904
+ # Let's define the hyper-parameters of our input
905
+ batch_size = 32
906
+ max_sequence_len = 1024
907
+ num_heads = 32
908
+ embed_dimension = 32
909
+
910
+ dtype = torch.float16
911
+
912
+ query = torch.rand(
913
+ batch_size,
914
+ num_heads,
915
+ max_sequence_len,
916
+ embed_dimension,
917
+ device=device,
918
+ dtype=dtype,
919
+ )
920
+ key = torch.rand(
921
+ batch_size,
922
+ num_heads,
923
+ max_sequence_len,
924
+ embed_dimension,
925
+ device=device,
926
+ dtype=dtype,
927
+ )
928
+ value = torch.rand(
929
+ batch_size,
930
+ num_heads,
931
+ max_sequence_len,
932
+ embed_dimension,
933
+ device=device,
934
+ dtype=dtype,
935
+ )
936
+
937
+ print("q/k/v shape:", query.shape, key.shape, value.shape)
938
+
939
+ # Let's explore the speed of each of the 3 implementations
940
+ from torch.backends.cuda import SDPBackend, sdp_kernel
941
+
942
+ # Helpful arguments mapper
943
+ backend_map = {
944
+ SDPBackend.MATH: {
945
+ "enable_math": True,
946
+ "enable_flash": False,
947
+ "enable_mem_efficient": False,
948
+ },
949
+ SDPBackend.FLASH_ATTENTION: {
950
+ "enable_math": False,
951
+ "enable_flash": True,
952
+ "enable_mem_efficient": False,
953
+ },
954
+ SDPBackend.EFFICIENT_ATTENTION: {
955
+ "enable_math": False,
956
+ "enable_flash": False,
957
+ "enable_mem_efficient": True,
958
+ },
959
+ }
960
+
961
+ from torch.profiler import ProfilerActivity, profile, record_function
962
+
963
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
964
+
965
+ print(
966
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
967
+ )
968
+ with profile(
969
+ activities=activities, record_shapes=False, profile_memory=True
970
+ ) as prof:
971
+ with record_function("Default detailed stats"):
972
+ for _ in range(25):
973
+ o = F.scaled_dot_product_attention(query, key, value)
974
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
975
+
976
+ print(
977
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
978
+ )
979
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
980
+ with profile(
981
+ activities=activities, record_shapes=False, profile_memory=True
982
+ ) as prof:
983
+ with record_function("Math implementation stats"):
984
+ for _ in range(25):
985
+ o = F.scaled_dot_product_attention(query, key, value)
986
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
987
+
988
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
989
+ try:
990
+ print(
991
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
992
+ )
993
+ except RuntimeError:
994
+ print("FlashAttention is not supported. See warnings for reasons.")
995
+ with profile(
996
+ activities=activities, record_shapes=False, profile_memory=True
997
+ ) as prof:
998
+ with record_function("FlashAttention stats"):
999
+ for _ in range(25):
1000
+ o = F.scaled_dot_product_attention(query, key, value)
1001
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1002
+
1003
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
1004
+ try:
1005
+ print(
1006
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
1007
+ )
1008
+ except RuntimeError:
1009
+ print("EfficientAttention is not supported. See warnings for reasons.")
1010
+ with profile(
1011
+ activities=activities, record_shapes=False, profile_memory=True
1012
+ ) as prof:
1013
+ with record_function("EfficientAttention stats"):
1014
+ for _ in range(25):
1015
+ o = F.scaled_dot_product_attention(query, key, value)
1016
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1017
+
1018
+
1019
+ def run_model(model, x, context):
1020
+ return model(x, context)
1021
+
1022
+
1023
+ def benchmark_transformer_blocks():
1024
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1025
+ import torch.utils.benchmark as benchmark
1026
+
1027
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
1028
+ t0 = benchmark.Timer(
1029
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
1030
+ )
1031
+ return t0.blocked_autorange().mean * 1e6
1032
+
1033
+ checkpoint = True
1034
+ compile = False
1035
+
1036
+ batch_size = 32
1037
+ h, w = 64, 64
1038
+ context_len = 77
1039
+ embed_dimension = 1024
1040
+ context_dim = 1024
1041
+ d_head = 64
1042
+
1043
+ transformer_depth = 4
1044
+
1045
+ n_heads = embed_dimension // d_head
1046
+
1047
+ dtype = torch.float16
1048
+
1049
+ model_native = SpatialTransformer(
1050
+ embed_dimension,
1051
+ n_heads,
1052
+ d_head,
1053
+ context_dim=context_dim,
1054
+ use_linear=True,
1055
+ use_checkpoint=checkpoint,
1056
+ attn_type="softmax",
1057
+ depth=transformer_depth,
1058
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
1059
+ ).to(device)
1060
+ model_efficient_attn = SpatialTransformer(
1061
+ embed_dimension,
1062
+ n_heads,
1063
+ d_head,
1064
+ context_dim=context_dim,
1065
+ use_linear=True,
1066
+ depth=transformer_depth,
1067
+ use_checkpoint=checkpoint,
1068
+ attn_type="softmax-xformers",
1069
+ ).to(device)
1070
+ if not checkpoint and compile:
1071
+ print("compiling models")
1072
+ model_native = torch.compile(model_native)
1073
+ model_efficient_attn = torch.compile(model_efficient_attn)
1074
+
1075
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
1076
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
1077
+
1078
+ from torch.profiler import ProfilerActivity, profile, record_function
1079
+
1080
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
1081
+
1082
+ with torch.autocast("cuda"):
1083
+ print(
1084
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
1085
+ )
1086
+ print(
1087
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
1088
+ )
1089
+
1090
+ print(75 * "+")
1091
+ print("NATIVE")
1092
+ print(75 * "+")
1093
+ torch.cuda.reset_peak_memory_stats()
1094
+ with profile(
1095
+ activities=activities, record_shapes=False, profile_memory=True
1096
+ ) as prof:
1097
+ with record_function("NativeAttention stats"):
1098
+ for _ in range(25):
1099
+ model_native(x, c)
1100
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1101
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
1102
+
1103
+ print(75 * "+")
1104
+ print("Xformers")
1105
+ print(75 * "+")
1106
+ torch.cuda.reset_peak_memory_stats()
1107
+ with profile(
1108
+ activities=activities, record_shapes=False, profile_memory=True
1109
+ ) as prof:
1110
+ with record_function("xformers stats"):
1111
+ for _ in range(25):
1112
+ model_efficient_attn(x, c)
1113
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
1114
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
1115
+
1116
+
1117
+ def test01():
1118
+ # conv1x1 vs linear
1119
+ from ..util import count_params
1120
+
1121
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
1122
+ print(count_params(conv))
1123
+ linear = torch.nn.Linear(3, 32).cuda()
1124
+ print(count_params(linear))
1125
+
1126
+ print(conv.weight.shape)
1127
+
1128
+ # use same initialization
1129
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
1130
+ linear.bias = torch.nn.Parameter(conv.bias)
1131
+
1132
+ print(linear.weight.shape)
1133
+
1134
+ x = torch.randn(11, 3, 64, 64).cuda()
1135
+
1136
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
1137
+ print(xr.shape)
1138
+ out_linear = linear(xr)
1139
+ print(out_linear.mean(), out_linear.shape)
1140
+
1141
+ out_conv = conv(x)
1142
+ print(out_conv.mean(), out_conv.shape)
1143
+ print("done with test01.\n")
1144
+
1145
+
1146
+ def test02():
1147
+ # try cosine flash attention
1148
+ import time
1149
+
1150
+ torch.backends.cuda.matmul.allow_tf32 = True
1151
+ torch.backends.cudnn.allow_tf32 = True
1152
+ torch.backends.cudnn.benchmark = True
1153
+ print("testing cosine flash attention...")
1154
+ DIM = 1024
1155
+ SEQLEN = 4096
1156
+ BS = 16
1157
+
1158
+ print(" softmax (vanilla) first...")
1159
+ model = BasicTransformerBlock(
1160
+ dim=DIM,
1161
+ n_heads=16,
1162
+ d_head=64,
1163
+ dropout=0.0,
1164
+ context_dim=None,
1165
+ attn_mode="softmax",
1166
+ ).cuda()
1167
+ try:
1168
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
1169
+ tic = time.time()
1170
+ y = model(x)
1171
+ toc = time.time()
1172
+ print(y.shape, toc - tic)
1173
+ except RuntimeError as e:
1174
+ # likely oom
1175
+ print(str(e))
1176
+
1177
+ print("\n now flash-cosine...")
1178
+ model = BasicTransformerBlock(
1179
+ dim=DIM,
1180
+ n_heads=16,
1181
+ d_head=64,
1182
+ dropout=0.0,
1183
+ context_dim=None,
1184
+ attn_mode="flash-cosine",
1185
+ ).cuda()
1186
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
1187
+ tic = time.time()
1188
+ y = model(x)
1189
+ toc = time.time()
1190
+ print(y.shape, toc - tic)
1191
+ print("done with test02.\n")
1192
+
1193
+
1194
+ if __name__ == "__main__":
1195
+ # test01()
1196
+ # test02()
1197
+ # test03()
1198
+
1199
+ # benchmark_attn()
1200
+ benchmark_transformer_blocks()
1201
+
1202
+ print("done.")
sgm/modules/autoencoding/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/loss/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/loss/lpips.py ADDED
@@ -0,0 +1,147 @@
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ from collections import namedtuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import models
8
+
9
+ from ..util import get_ckpt_path
10
+
11
+
12
+ class LPIPS(nn.Module):
13
+ # Learned perceptual metric
14
+ def __init__(self, use_dropout=True):
15
+ super().__init__()
16
+ self.scaling_layer = ScalingLayer()
17
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
18
+ self.net = vgg16(pretrained=True, requires_grad=False)
19
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
20
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
21
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
22
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
23
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
24
+ self.load_from_pretrained()
25
+ for param in self.parameters():
26
+ param.requires_grad = False
27
+
28
+ def load_from_pretrained(self, name="vgg_lpips"):
29
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
30
+ self.load_state_dict(
31
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
32
+ )
33
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, name="vgg_lpips"):
37
+ if name != "vgg_lpips":
38
+ raise NotImplementedError
39
+ model = cls()
40
+ ckpt = get_ckpt_path(name)
41
+ model.load_state_dict(
42
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
43
+ )
44
+ return model
45
+
46
+ def forward(self, input, target):
47
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
48
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
49
+ feats0, feats1, diffs = {}, {}, {}
50
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
51
+ for kk in range(len(self.chns)):
52
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
53
+ outs1[kk]
54
+ )
55
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
56
+
57
+ res = [
58
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
59
+ for kk in range(len(self.chns))
60
+ ]
61
+ val = res[0]
62
+ for l in range(1, len(self.chns)):
63
+ val += res[l]
64
+ return val
65
+
66
+
67
+ class ScalingLayer(nn.Module):
68
+ def __init__(self):
69
+ super(ScalingLayer, self).__init__()
70
+ self.register_buffer(
71
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
72
+ )
73
+ self.register_buffer(
74
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
75
+ )
76
+
77
+ def forward(self, inp):
78
+ return (inp - self.shift) / self.scale
79
+
80
+
81
+ class NetLinLayer(nn.Module):
82
+ """A single linear layer which does a 1x1 conv"""
83
+
84
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
85
+ super(NetLinLayer, self).__init__()
86
+ layers = (
87
+ [
88
+ nn.Dropout(),
89
+ ]
90
+ if (use_dropout)
91
+ else []
92
+ )
93
+ layers += [
94
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
95
+ ]
96
+ self.model = nn.Sequential(*layers)
97
+
98
+
99
+ class vgg16(torch.nn.Module):
100
+ def __init__(self, requires_grad=False, pretrained=True):
101
+ super(vgg16, self).__init__()
102
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
103
+ self.slice1 = torch.nn.Sequential()
104
+ self.slice2 = torch.nn.Sequential()
105
+ self.slice3 = torch.nn.Sequential()
106
+ self.slice4 = torch.nn.Sequential()
107
+ self.slice5 = torch.nn.Sequential()
108
+ self.N_slices = 5
109
+ for x in range(4):
110
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(4, 9):
112
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(9, 16):
114
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(16, 23):
116
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
117
+ for x in range(23, 30):
118
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
119
+ if not requires_grad:
120
+ for param in self.parameters():
121
+ param.requires_grad = False
122
+
123
+ def forward(self, X):
124
+ h = self.slice1(X)
125
+ h_relu1_2 = h
126
+ h = self.slice2(h)
127
+ h_relu2_2 = h
128
+ h = self.slice3(h)
129
+ h_relu3_3 = h
130
+ h = self.slice4(h)
131
+ h_relu4_3 = h
132
+ h = self.slice5(h)
133
+ h_relu5_3 = h
134
+ vgg_outputs = namedtuple(
135
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
136
+ )
137
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
138
+ return out
139
+
140
+
141
+ def normalize_tensor(x, eps=1e-10):
142
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
143
+ return x / (norm_factor + eps)
144
+
145
+
146
+ def spatial_average(x, keepdim=True):
147
+ return x.mean([2, 3], keepdim=keepdim)
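As a quick usage note (a hypothetical snippet, not part of the commit): LPIPS expects inputs scaled to [-1, 1], the first instantiation fetches the pretrained VGG backbone and the vgg_lpips linear heads (network access or cached checkpoints are assumed), and a lower value means the two images are perceptually closer.

import torch
from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

lpips = LPIPS().eval()                        # downloads/loads vgg_lpips weights on first use
img_a = torch.rand(1, 3, 256, 256) * 2 - 1    # random images scaled to [-1, 1]
img_b = torch.rand(1, 3, 256, 256) * 2 - 1

with torch.no_grad():
    dist = lpips(img_a, img_b)                # tensor of shape (1, 1, 1, 1)
print(dist.item())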
sgm/modules/autoencoding/lpips/model/LICENSE ADDED
@@ -0,0 +1,58 @@
1
+ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
18
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
20
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
21
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
22
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
23
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24
+
25
+
26
+ --------------------------- LICENSE FOR pix2pix --------------------------------
27
+ BSD License
28
+
29
+ For pix2pix software
30
+ Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
31
+ All rights reserved.
32
+
33
+ Redistribution and use in source and binary forms, with or without
34
+ modification, are permitted provided that the following conditions are met:
35
+
36
+ * Redistributions of source code must retain the above copyright notice, this
37
+ list of conditions and the following disclaimer.
38
+
39
+ * Redistributions in binary form must reproduce the above copyright notice,
40
+ this list of conditions and the following disclaimer in the documentation
41
+ and/or other materials provided with the distribution.
42
+
43
+ ----------------------------- LICENSE FOR DCGAN --------------------------------
44
+ BSD License
45
+
46
+ For dcgan.torch software
47
+
48
+ Copyright (c) 2015, Facebook, Inc. All rights reserved.
49
+
50
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
51
+
52
+ Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
53
+
54
+ Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
55
+
56
+ Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
57
+
58
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
sgm/modules/autoencoding/lpips/model/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/model/model.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+
3
+ import torch.nn as nn
4
+
5
+ from ..util import ActNorm
6
+
7
+
8
+ def weights_init(m):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
12
+ elif classname.find("BatchNorm") != -1:
13
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
14
+ nn.init.constant_(m.bias.data, 0)
15
+
16
+
17
+ class NLayerDiscriminator(nn.Module):
18
+ """Defines a PatchGAN discriminator as in Pix2Pix
19
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20
+ """
21
+
22
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
23
+ """Construct a PatchGAN discriminator
24
+ Parameters:
25
+ input_nc (int) -- the number of channels in input images
26
+ ndf (int) -- the number of filters in the last conv layer
27
+ n_layers (int) -- the number of conv layers in the discriminator
28
+ norm_layer -- normalization layer
29
+ """
30
+ super(NLayerDiscriminator, self).__init__()
31
+ if not use_actnorm:
32
+ norm_layer = nn.BatchNorm2d
33
+ else:
34
+ norm_layer = ActNorm
35
+ if (
36
+ type(norm_layer) == functools.partial
37
+ ): # no need to use bias as BatchNorm2d has affine parameters
38
+ use_bias = norm_layer.func != nn.BatchNorm2d
39
+ else:
40
+ use_bias = norm_layer != nn.BatchNorm2d
41
+
42
+ kw = 4
43
+ padw = 1
44
+ sequence = [
45
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
46
+ nn.LeakyReLU(0.2, True),
47
+ ]
48
+ nf_mult = 1
49
+ nf_mult_prev = 1
50
+ for n in range(1, n_layers): # gradually increase the number of filters
51
+ nf_mult_prev = nf_mult
52
+ nf_mult = min(2**n, 8)
53
+ sequence += [
54
+ nn.Conv2d(
55
+ ndf * nf_mult_prev,
56
+ ndf * nf_mult,
57
+ kernel_size=kw,
58
+ stride=2,
59
+ padding=padw,
60
+ bias=use_bias,
61
+ ),
62
+ norm_layer(ndf * nf_mult),
63
+ nn.LeakyReLU(0.2, True),
64
+ ]
65
+
66
+ nf_mult_prev = nf_mult
67
+ nf_mult = min(2**n_layers, 8)
68
+ sequence += [
69
+ nn.Conv2d(
70
+ ndf * nf_mult_prev,
71
+ ndf * nf_mult,
72
+ kernel_size=kw,
73
+ stride=1,
74
+ padding=padw,
75
+ bias=use_bias,
76
+ ),
77
+ norm_layer(ndf * nf_mult),
78
+ nn.LeakyReLU(0.2, True),
79
+ ]
80
+
81
+ sequence += [
82
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
83
+ ] # output 1 channel prediction map
84
+ self.main = nn.Sequential(*sequence)
85
+
86
+ def forward(self, input):
87
+ """Standard forward."""
88
+ return self.main(input)
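For reference, a hypothetical instantiation of the PatchGAN discriminator: with the defaults (3 input channels, ndf=64, three strided layers) a 256x256 image maps to a 30x30 grid of patch logits, and weights_init applies the DCGAN-style normal(0, 0.02) initialization.

import torch
from sgm.modules.autoencoding.lpips.model.model import NLayerDiscriminator, weights_init

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
disc.apply(weights_init)                      # conv weights ~ N(0, 0.02), batch-norm affine init

logits = disc(torch.randn(2, 3, 256, 256))
print(logits.shape)                           # torch.Size([2, 1, 30, 30]) per-patch real/fake logits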
sgm/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
1
+ import hashlib
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
+
11
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
+
13
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
+
15
+
16
+ def download(url, local_path, chunk_size=1024):
17
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
+ with requests.get(url, stream=True) as r:
19
+ total_size = int(r.headers.get("content-length", 0))
20
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
+ with open(local_path, "wb") as f:
22
+ for data in r.iter_content(chunk_size=chunk_size):
23
+ if data:
24
+ f.write(data)
25
+ pbar.update(chunk_size)
26
+
27
+
28
+ def md5_hash(path):
29
+ with open(path, "rb") as f:
30
+ content = f.read()
31
+ return hashlib.md5(content).hexdigest()
32
+
33
+
34
+ def get_ckpt_path(name, root, check=False):
35
+ assert name in URL_MAP
36
+ path = os.path.join(root, CKPT_MAP[name])
37
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
+ download(URL_MAP[name], path)
40
+ md5 = md5_hash(path)
41
+ assert md5 == MD5_MAP[name], md5
42
+ return path
43
+
44
+
45
+ class ActNorm(nn.Module):
46
+ def __init__(
47
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
+ ):
49
+ assert affine
50
+ super().__init__()
51
+ self.logdet = logdet
52
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
+ self.allow_reverse_init = allow_reverse_init
55
+
56
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
+
58
+ def initialize(self, input):
59
+ with torch.no_grad():
60
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
+ mean = (
62
+ flatten.mean(1)
63
+ .unsqueeze(1)
64
+ .unsqueeze(2)
65
+ .unsqueeze(3)
66
+ .permute(1, 0, 2, 3)
67
+ )
68
+ std = (
69
+ flatten.std(1)
70
+ .unsqueeze(1)
71
+ .unsqueeze(2)
72
+ .unsqueeze(3)
73
+ .permute(1, 0, 2, 3)
74
+ )
75
+
76
+ self.loc.data.copy_(-mean)
77
+ self.scale.data.copy_(1 / (std + 1e-6))
78
+
79
+ def forward(self, input, reverse=False):
80
+ if reverse:
81
+ return self.reverse(input)
82
+ if len(input.shape) == 2:
83
+ input = input[:, :, None, None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ _, _, height, width = input.shape
89
+
90
+ if self.training and self.initialized.item() == 0:
91
+ self.initialize(input)
92
+ self.initialized.fill_(1)
93
+
94
+ h = self.scale * (input + self.loc)
95
+
96
+ if squeeze:
97
+ h = h.squeeze(-1).squeeze(-1)
98
+
99
+ if self.logdet:
100
+ log_abs = torch.log(torch.abs(self.scale))
101
+ logdet = height * width * torch.sum(log_abs)
102
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
103
+ return h, logdet
104
+
105
+ return h
106
+
107
+ def reverse(self, output):
108
+ if self.training and self.initialized.item() == 0:
109
+ if not self.allow_reverse_init:
110
+ raise RuntimeError(
111
+ "Initializing ActNorm in reverse direction is "
112
+ "disabled by default. Use allow_reverse_init=True to enable."
113
+ )
114
+ else:
115
+ self.initialize(output)
116
+ self.initialized.fill_(1)
117
+
118
+ if len(output.shape) == 2:
119
+ output = output[:, :, None, None]
120
+ squeeze = True
121
+ else:
122
+ squeeze = False
123
+
124
+ h = output / self.scale - self.loc
125
+
126
+ if squeeze:
127
+ h = h.squeeze(-1).squeeze(-1)
128
+ return h
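A small sanity check of ActNorm (toy shapes, assumed usage): the first forward pass in training mode performs data-dependent initialization of loc and scale from the batch statistics, after which reverse inverts the transform; calling reverse before that first forward raises unless allow_reverse_init=True.

import torch
from sgm.modules.autoencoding.lpips.util import ActNorm

actnorm = ActNorm(num_features=64)
x = torch.randn(8, 64, 16, 16)

y = actnorm(x)                     # first call initializes loc/scale, so y is roughly zero-mean, unit-std
x_rec = actnorm.reverse(y)         # inverse transform, recovers x up to numerical error
print((x - x_rec).abs().max())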
sgm/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def hinge_d_loss(logits_real, logits_fake):
6
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
7
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
+ d_loss = 0.5 * (loss_real + loss_fake)
9
+ return d_loss
10
+
11
+
12
+ def vanilla_d_loss(logits_real, logits_fake):
13
+ d_loss = 0.5 * (
14
+ torch.mean(torch.nn.functional.softplus(-logits_real))
15
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
16
+ )
17
+ return d_loss
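To make the hinge objective concrete, a tiny worked example with toy logits (numbers are illustrative): real = [2.0, 0.5] and fake = [-3.0, 0.5] give relu(1 - real) = [0, 0.5] and relu(1 + fake) = [0, 1.5], so hinge_d_loss returns 0.5 * (0.25 + 0.75) = 0.5.

import torch
from sgm.modules.autoencoding.lpips.vqperceptual import hinge_d_loss, vanilla_d_loss

real = torch.tensor([2.0, 0.5])
fake = torch.tensor([-3.0, 0.5])
print(hinge_d_loss(real, fake))    # tensor(0.5000)
print(vanilla_d_loss(real, fake))  # softplus-based variant on the same logits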
sgm/modules/autoencoding/regularizers/__init__.py ADDED
@@ -0,0 +1,31 @@
1
+ from abc import abstractmethod
2
+ from typing import Any, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ....modules.distributions.distributions import \
9
+ DiagonalGaussianDistribution
10
+ from .base import AbstractRegularizer
11
+
12
+
13
+ class DiagonalGaussianRegularizer(AbstractRegularizer):
14
+ def __init__(self, sample: bool = True):
15
+ super().__init__()
16
+ self.sample = sample
17
+
18
+ def get_trainable_parameters(self) -> Any:
19
+ yield from ()
20
+
21
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22
+ log = dict()
23
+ posterior = DiagonalGaussianDistribution(z)
24
+ if self.sample:
25
+ z = posterior.sample()
26
+ else:
27
+ z = posterior.mode()
28
+ kl_loss = posterior.kl()
29
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30
+ log["kl_loss"] = kl_loss
31
+ return z, log
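Finally, a hypothetical call pattern for the regularizer: the encoder output is read as concatenated mean / log-variance channels, so a (b, 2c, h, w) tensor of moments yields a (b, c, h, w) latent (sampled, or the mode when sample=False) plus a scalar KL term in the returned log dict.

import torch
from sgm.modules.autoencoding.regularizers import DiagonalGaussianRegularizer

reg = DiagonalGaussianRegularizer(sample=True)
moments = torch.randn(2, 8, 32, 32)    # 4 mean channels + 4 log-variance channels
z, log = reg(moments)
print(z.shape, log["kl_loss"])         # torch.Size([2, 4, 32, 32]) and a scalar tensor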