Warvito commited on
Commit
c9cd3be
0 Parent(s):

commit message

Browse files
.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ .idea/
132
+ outputs/
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Testing Gradio
3
+ emoji: 🏢
4
+ colorFrom: yellow
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.1.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import shutil
3
+ import uuid
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import mediapy
9
+ import mlflow.pytorch
10
+ import numpy as np
11
+ import torch
12
+ from skimage import img_as_ubyte
13
+
14
+ from models.ddim import DDIMSampler
15
+
16
+ import nibabel as nib
17
+
18
+ ffmpeg_path = shutil.which("ffmpeg")
19
+ mediapy.set_ffmpeg(ffmpeg_path)
20
+
21
+ # Loading model
22
+ vqvae = mlflow.pytorch.load_model(
23
+ "./trained_models/vae/final_model"
24
+ )
25
+ vqvae.eval()
26
+
27
+ diffusion = mlflow.pytorch.load_model(
28
+ "./trained_models/ddpm/final_model"
29
+ )
30
+ diffusion.eval()
31
+
32
+ device = torch.device("cpu")
33
+ diffusion = diffusion.to(device)
34
+ vqvae = vqvae.to(device)
35
+
36
+
37
+ def sample_fn(
38
+ gender_radio,
39
+ age_slider,
40
+ ventricular_slider,
41
+ brain_slider,
42
+ ):
43
+ print("Sampling brain!")
44
+ print(f"Gender: {gender_radio}")
45
+ print(f"Age: {age_slider}")
46
+ print(f"Ventricular volume: {ventricular_slider}")
47
+ print(f"Brain volume: {brain_slider}")
48
+
49
+ age_slider = (age_slider - 44) / (82 - 44)
50
+
51
+ cond = torch.Tensor([[gender_radio, age_slider, ventricular_slider, brain_slider]])
52
+ latent_shape = [1, 3, 20, 28, 20]
53
+ cond_crossatten = cond.unsqueeze(1)
54
+ cond_concat = cond.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
55
+ cond_concat = cond_concat.expand(list(cond.shape[0:2]) + list(latent_shape[2:]))
56
+ conditioning = {
57
+ "c_concat": [cond_concat.float().to(device)],
58
+ "c_crossattn": [cond_crossatten.float().to(device)],
59
+ }
60
+
61
+ ddim = DDIMSampler(diffusion)
62
+ num_timesteps = 50
63
+ latent_vectors, _ = ddim.sample(
64
+ num_timesteps,
65
+ conditioning=conditioning,
66
+ batch_size=1,
67
+ shape=list(latent_shape[1:]),
68
+ eta=1.0,
69
+ )
70
+
71
+ with torch.no_grad():
72
+ x_hat = vqvae.reconstruct_ldm_outputs(latent_vectors).cpu()
73
+
74
+ return x_hat.numpy()
75
+
76
+
77
+ def create_videos_and_file(
78
+ gender_radio,
79
+ age_slider,
80
+ ventricular_slider,
81
+ brain_slider,
82
+ ):
83
+ output_dir = Path(
84
+ f"/media/walter/Storage/Projects/gradio_medical_ldm/outputs/{str(uuid.uuid4())}"
85
+ )
86
+ output_dir.mkdir(exist_ok=True)
87
+
88
+ image_data = sample_fn(
89
+ gender_radio,
90
+ age_slider,
91
+ ventricular_slider,
92
+ brain_slider,
93
+ )
94
+ image_data = image_data[0, 0, 5:-5, 5:-5, :-15]
95
+ image_data = (image_data - image_data.min()) / (image_data.max() - image_data.min())
96
+ image_data = (image_data * 255).astype(np.uint8)
97
+
98
+ # Write frames to video
99
+ with mediapy.VideoWriter(
100
+ f"{str(output_dir)}/brain_axial.mp4", shape=(150, 214), fps=12, crf=18
101
+ ) as w:
102
+ for idx in range(image_data.shape[2]):
103
+ img = image_data[:, :, idx]
104
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
105
+ frame = img_as_ubyte(img)
106
+ w.add_image(frame)
107
+
108
+ with mediapy.VideoWriter(
109
+ f"{str(output_dir)}/brain_sagittal.mp4", shape=(145, 214), fps=12, crf=18
110
+ ) as w:
111
+ for idx in range(image_data.shape[0]):
112
+ img = np.rot90(image_data[idx, :, :])
113
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
114
+ frame = img_as_ubyte(img)
115
+ w.add_image(frame)
116
+
117
+ with mediapy.VideoWriter(
118
+ f"{str(output_dir)}/brain_coronal.mp4", shape=(145, 150), fps=12, crf=18
119
+ ) as w:
120
+ for idx in range(image_data.shape[1]):
121
+ img = np.rot90(np.flip(image_data, axis=1)[:, idx, :])
122
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
123
+ frame = img_as_ubyte(img)
124
+ w.add_image(frame)
125
+
126
+ # Create file
127
+ affine = np.array(
128
+ [
129
+ [-1.0, 0.0, 0.0, 96.48149872],
130
+ [0.0, 1.0, 0.0, -141.47715759],
131
+ [0.0, 0.0, 1.0, -156.55375671],
132
+ [0.0, 0.0, 0.0, 1.0],
133
+ ]
134
+ )
135
+ empty_header = nib.Nifti1Header()
136
+ sample_nii = nib.Nifti1Image(image_data, affine, empty_header)
137
+ nib.save(sample_nii, f"{str(output_dir)}/my_brain.nii.gz")
138
+
139
+ # time.sleep(2)
140
+
141
+ return (
142
+ f"{str(output_dir)}/brain_axial.mp4",
143
+ f"{str(output_dir)}/brain_sagittal.mp4",
144
+ f"{str(output_dir)}/brain_coronal.mp4",
145
+ f"{str(output_dir)}/my_brain.nii.gz",
146
+ )
147
+
148
+
149
+ def randomise():
150
+ random_age = round(random.uniform(44.0, 82.0), 2)
151
+ return (
152
+ random.choice(["Female", "Male"]),
153
+ random_age,
154
+ round(random.uniform(0, 1.0), 2),
155
+ round(random.uniform(0, 1.0), 2),
156
+ )
157
+
158
+
159
+ def unrest_randomise():
160
+ random_age = round(random.uniform(18.0, 100.0), 2)
161
+ return (
162
+ random.choice([1, 0]),
163
+ random_age,
164
+ round(random.uniform(-1.0, 2.0), 2),
165
+ round(random.uniform(-1.0, 2.0), 2),
166
+ )
167
+
168
+
169
+ # TEXT
170
+ title = "Generating Brain Imaging with Diffusion Models"
171
+ description = """
172
+ <center><b>WORK IN PROGRESS. DO NOT SHARE.</b></center>
173
+ <center><a href="https://arxiv.org/">[PAPER]</a> <a href="https://academictorrents.com/details/63aeb864bbe2115ded0aa0d7d36334c026f0660b">[DATASET]</a></center>
174
+
175
+ <details>
176
+ <summary>Instructions</summary>
177
+
178
+ With this app, you can generate synthetic brain images with one click!<br />You have two ways to set how your generated brain will look like:<br />- Using the "Inputs" tab that creates well-behaved brains using the same value ranges that our models learned as described in paper linked above<br />- Or using the "Unrestricted Inputs" tab to generate the wildest brains!<br />After customisation, just hit "Generate" and wait a few seconds.<br />Note: if are having problems with the videos, try our app using chrome. <b>Enjoy!<b>
179
+ </details>
180
+
181
+ """
182
+
183
+ article = """
184
+ Checkout our dataset with [100K synthetic brain](https://academictorrents.com/details/63aeb864bbe2115ded0aa0d7d36334c026f0660b)! 🧠🧠🧠
185
+
186
+ App made by [Walter Hugo Lopez Pinaya](https://twitter.com/warvito) from [AMIGO](https://amigos.ai/)
187
+ <center><img src="https://amigos.ai/assets/images/logo_dark_rect.png" alt="amigos.ai" width=300px></center>
188
+ """
189
+
190
+ demo = gr.Blocks()
191
+
192
+ with demo:
193
+ gr.Markdown(
194
+ "<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>"
195
+ )
196
+ gr.Markdown(description)
197
+ with gr.Row():
198
+ with gr.Column():
199
+ with gr.Box():
200
+ with gr.Tabs():
201
+ with gr.TabItem("Inputs"):
202
+ with gr.Row():
203
+ gender_radio = gr.Radio(
204
+ choices=["Female", "Male"],
205
+ value="Female",
206
+ type="index",
207
+ label="Gender",
208
+ interactive=True,
209
+ )
210
+ age_slider = gr.Slider(
211
+ minimum=44,
212
+ maximum=82,
213
+ value=63,
214
+ label="Age [years]",
215
+ interactive=True,
216
+ )
217
+ with gr.Row():
218
+ ventricular_slider = gr.Slider(
219
+ minimum=0,
220
+ maximum=1,
221
+ value=0.5,
222
+ label="Volume of ventricular cerebrospinal fluid",
223
+ interactive=True,
224
+ )
225
+ brain_slider = gr.Slider(
226
+ minimum=0,
227
+ maximum=1,
228
+ value=0.5,
229
+ label="Volume of brain",
230
+ interactive=True,
231
+ )
232
+ with gr.Row():
233
+ submit_btn = gr.Button("Generate", variant="primary")
234
+ randomize_btn = gr.Button("I'm Feeling Lucky")
235
+
236
+ with gr.TabItem("Unrestricted Inputs"):
237
+ with gr.Row():
238
+ unrest_gender_number = gr.Number(
239
+ value=1.0,
240
+ precision=1,
241
+ label="Gender [Female=0, Male=1]",
242
+ interactive=True,
243
+ )
244
+ unrest_age_number = gr.Number(
245
+ value=63,
246
+ precision=1,
247
+ label="Age [years]",
248
+ interactive=True,
249
+ )
250
+ with gr.Row():
251
+ unrest_ventricular_number = gr.Number(
252
+ value=0.5,
253
+ precision=2,
254
+ label="Volume of ventricular cerebrospinal fluid",
255
+ interactive=True,
256
+ )
257
+ unrest_brain_number = gr.Number(
258
+ value=0.5,
259
+ precision=2,
260
+ label="Volume of brain",
261
+ interactive=True,
262
+ )
263
+ with gr.Row():
264
+ unrest_submit_btn = gr.Button("Generate", variant="primary")
265
+ unrest_randomize_btn = gr.Button("I'm Feeling Lucky")
266
+
267
+ gr.Examples(
268
+ examples=[
269
+ [1, 63, 1.3, 0.5],
270
+ [0, 63, 1.9, 0.5],
271
+ [1, 63, -0.5, 0.5],
272
+ [0, 63, 0.5, -0.3],
273
+ ],
274
+ inputs=[
275
+ unrest_gender_number,
276
+ unrest_age_number,
277
+ unrest_ventricular_number,
278
+ unrest_brain_number,
279
+ ],
280
+ )
281
+
282
+ with gr.Column():
283
+ with gr.Box():
284
+ with gr.Tabs():
285
+ with gr.TabItem("Axial View"):
286
+ axial_sample_plot = gr.Video(show_label=False)
287
+ with gr.TabItem("Sagittal View"):
288
+ sagittal_sample_plot = gr.Video(show_label=False)
289
+ with gr.TabItem("Coronal View"):
290
+ coronal_sample_plot = gr.Video(show_label=False)
291
+ sample_file = gr.File(label="My Brain")
292
+
293
+ gr.Markdown(article)
294
+
295
+ submit_btn.click(
296
+ create_videos_and_file,
297
+ [
298
+ gender_radio,
299
+ age_slider,
300
+ ventricular_slider,
301
+ brain_slider,
302
+ ],
303
+ [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot, sample_file],
304
+ # [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot],
305
+ )
306
+ unrest_submit_btn.click(
307
+ create_videos_and_file,
308
+ [
309
+ unrest_gender_number,
310
+ unrest_age_number,
311
+ unrest_ventricular_number,
312
+ unrest_brain_number,
313
+ ],
314
+ [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot, sample_file],
315
+ # [axial_sample_plot, sagittal_sample_plot, coronal_sample_plot],
316
+ )
317
+
318
+ randomize_btn.click(
319
+ fn=randomise,
320
+ inputs=[],
321
+ queue=False,
322
+ outputs=[gender_radio, age_slider, ventricular_slider, brain_slider],
323
+ )
324
+
325
+ unrest_randomize_btn.click(
326
+ fn=unrest_randomise,
327
+ inputs=[],
328
+ queue=False,
329
+ outputs=[
330
+ unrest_gender_number,
331
+ unrest_age_number,
332
+ unrest_ventricular_number,
333
+ unrest_brain_number,
334
+ ],
335
+ )
336
+
337
+ # demo.launch(share=True, enable_queue=True)
338
+ demo.launch(enable_queue=True)
models/aekl_no_attention.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AUTOENCODER WITH ARCHTECTURE FROM VERSION 2
3
+ """
4
+ from typing import Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ @torch.jit.script
12
+ def swish(x):
13
+ return x * torch.sigmoid(x)
14
+
15
+
16
+ def Normalize(in_channels):
17
+ return nn.GroupNorm(
18
+ num_groups=32,
19
+ num_channels=in_channels,
20
+ eps=1e-6,
21
+ affine=True
22
+ )
23
+
24
+
25
+ class Upsample(nn.Module):
26
+ def __init__(self, in_channels):
27
+ super().__init__()
28
+ self.conv = nn.Conv3d(
29
+ in_channels,
30
+ in_channels,
31
+ kernel_size=3,
32
+ stride=1,
33
+ padding=1
34
+ )
35
+
36
+ def forward(self, x):
37
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
38
+ x = self.conv(x)
39
+ return x
40
+
41
+
42
+ class Downsample(nn.Module):
43
+ def __init__(self, in_channels):
44
+ super().__init__()
45
+ self.conv = nn.Conv3d(
46
+ in_channels,
47
+ in_channels,
48
+ kernel_size=3,
49
+ stride=2,
50
+ padding=0
51
+ )
52
+
53
+ def forward(self, x):
54
+ pad = (0, 1, 0, 1, 0, 1)
55
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class ResBlock(nn.Module):
61
+ def __init__(self, in_channels, out_channels=None):
62
+ super().__init__()
63
+ self.in_channels = in_channels
64
+ self.out_channels = in_channels if out_channels is None else out_channels
65
+ self.norm1 = Normalize(in_channels)
66
+ self.conv1 = nn.Conv3d(
67
+ in_channels,
68
+ out_channels,
69
+ kernel_size=3,
70
+ stride=1,
71
+ padding=1
72
+ )
73
+ self.norm2 = Normalize(out_channels)
74
+ self.conv2 = nn.Conv3d(
75
+ out_channels,
76
+ out_channels,
77
+ kernel_size=3,
78
+ stride=1,
79
+ padding=1
80
+ )
81
+
82
+ if self.in_channels != self.out_channels:
83
+ self.nin_shortcut = nn.Conv3d(
84
+ in_channels,
85
+ out_channels,
86
+ kernel_size=1,
87
+ stride=1,
88
+ padding=0
89
+ )
90
+
91
+ def forward(self, x):
92
+ h = x
93
+ h = self.norm1(h)
94
+ h = F.silu(h)
95
+ h = self.conv1(h)
96
+
97
+ h = self.norm2(h)
98
+ h = F.silu(h)
99
+ h = self.conv2(h)
100
+
101
+ if self.in_channels != self.out_channels:
102
+ x = self.nin_shortcut(x)
103
+
104
+ return x + h
105
+
106
+
107
+ class Encoder(nn.Module):
108
+ def __init__(
109
+ self,
110
+ in_channels: int,
111
+ n_channels: int,
112
+ z_channels: int,
113
+ ch_mult: Tuple[int],
114
+ num_res_blocks: int,
115
+ resolution: Tuple[int],
116
+ attn_resolutions: Tuple[int],
117
+ **ignorekwargs,
118
+ ) -> None:
119
+ super().__init__()
120
+ self.in_channels = in_channels
121
+ self.n_channels = n_channels
122
+ self.num_resolutions = len(ch_mult)
123
+ self.num_res_blocks = num_res_blocks
124
+ self.resolution = resolution
125
+ self.attn_resolutions = attn_resolutions
126
+
127
+ curr_res = resolution
128
+ in_ch_mult = (1,) + tuple(ch_mult)
129
+
130
+ blocks = []
131
+ # initial convolution
132
+ blocks.append(
133
+ nn.Conv3d(
134
+ in_channels,
135
+ n_channels,
136
+ kernel_size=3,
137
+ stride=1,
138
+ padding=1
139
+ )
140
+ )
141
+
142
+ # residual and downsampling blocks, with attention on smaller res (16x16)
143
+ for i in range(self.num_resolutions):
144
+ block_in_ch = n_channels * in_ch_mult[i]
145
+ block_out_ch = n_channels * ch_mult[i]
146
+ for _ in range(self.num_res_blocks):
147
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
148
+ block_in_ch = block_out_ch
149
+
150
+ if i != self.num_resolutions - 1:
151
+ blocks.append(Downsample(block_in_ch))
152
+ curr_res = tuple(ti // 2 for ti in curr_res)
153
+
154
+ # normalise and convert to latent size
155
+ blocks.append(Normalize(block_in_ch))
156
+ blocks.append(
157
+ nn.Conv3d(
158
+ block_in_ch,
159
+ z_channels,
160
+ kernel_size=3,
161
+ stride=1,
162
+ padding=1
163
+ )
164
+ )
165
+
166
+ self.blocks = nn.ModuleList(blocks)
167
+
168
+ def forward(self, x):
169
+ for block in self.blocks:
170
+ x = block(x)
171
+ return x
172
+
173
+
174
+ class Decoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ n_channels: int,
178
+ z_channels: int,
179
+ out_channels: int,
180
+ ch_mult: Tuple[int],
181
+ num_res_blocks: int,
182
+ resolution: Tuple[int],
183
+ attn_resolutions: Tuple[int],
184
+ **ignorekwargs,
185
+ ) -> None:
186
+ super().__init__()
187
+ self.n_channels = n_channels
188
+ self.z_channels = z_channels
189
+ self.out_channels = out_channels
190
+ self.ch_mult = ch_mult
191
+ self.num_resolutions = len(ch_mult)
192
+ self.num_res_blocks = num_res_blocks
193
+ self.resolution = resolution
194
+ self.attn_resolutions = attn_resolutions
195
+
196
+ block_in_ch = n_channels * self.ch_mult[-1]
197
+ curr_res = tuple(ti // 2 ** (self.num_resolutions - 1) for ti in resolution)
198
+
199
+ blocks = []
200
+ # initial conv
201
+ blocks.append(
202
+ nn.Conv3d(
203
+ z_channels,
204
+ block_in_ch,
205
+ kernel_size=3,
206
+ stride=1,
207
+ padding=1
208
+ )
209
+ )
210
+
211
+ for i in reversed(range(self.num_resolutions)):
212
+ block_out_ch = n_channels * self.ch_mult[i]
213
+
214
+ for _ in range(self.num_res_blocks):
215
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
216
+ block_in_ch = block_out_ch
217
+
218
+ if i != 0:
219
+ blocks.append(Upsample(block_in_ch))
220
+ curr_res = tuple(ti * 2 for ti in curr_res)
221
+
222
+ blocks.append(Normalize(block_in_ch))
223
+ blocks.append(
224
+ nn.Conv3d(
225
+ block_in_ch,
226
+ out_channels,
227
+ kernel_size=3,
228
+ stride=1,
229
+ padding=1
230
+ )
231
+ )
232
+
233
+ self.blocks = nn.ModuleList(blocks)
234
+
235
+ def forward(self, x):
236
+ for block in self.blocks:
237
+ x = block(x)
238
+ return x
239
+
240
+
241
+ class AutoencoderKL(nn.Module):
242
+ def __init__(self, embed_dim: int, hparams) -> None:
243
+ super().__init__()
244
+ self.encoder = Encoder(**hparams)
245
+ self.decoder = Decoder(**hparams)
246
+ self.quant_conv_mu = torch.nn.Conv3d(hparams["z_channels"], embed_dim, 1)
247
+ self.quant_conv_log_sigma = torch.nn.Conv3d(hparams["z_channels"], embed_dim, 1)
248
+ self.post_quant_conv = torch.nn.Conv3d(embed_dim, hparams["z_channels"], 1)
249
+ self.embed_dim = embed_dim
250
+
251
+ def decode(self, z):
252
+ z = self.post_quant_conv(z)
253
+ dec = self.decoder(z)
254
+ return dec
255
+
256
+ def reconstruct_ldm_outputs(self, z):
257
+ x_hat = self.decode(z)
258
+ return x_hat
models/attention.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+ from torch import nn, einsum
7
+
8
+
9
+ def exists(val):
10
+ return val is not None
11
+
12
+
13
+ def default(val, d):
14
+ if exists(val):
15
+ return val
16
+ return d() if isfunction(d) else d
17
+
18
+
19
+ # feedforward
20
+ class GEGLU(nn.Module):
21
+ def __init__(self, dim_in, dim_out):
22
+ super().__init__()
23
+ self.proj = nn.Linear(dim_in, dim_out * 2)
24
+
25
+ def forward(self, x):
26
+ x, gate = self.proj(x).chunk(2, dim=-1)
27
+ return x * F.gelu(gate)
28
+
29
+
30
+ class FeedForward(nn.Module):
31
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
32
+ super().__init__()
33
+ inner_dim = int(dim * mult)
34
+ dim_out = default(dim_out, dim)
35
+ project_in = nn.Sequential(
36
+ nn.Linear(dim, inner_dim),
37
+ nn.GELU()
38
+ ) if not glu else GEGLU(dim, inner_dim)
39
+
40
+ self.net = nn.Sequential(
41
+ project_in,
42
+ nn.Dropout(dropout),
43
+ nn.Linear(inner_dim, dim_out)
44
+ )
45
+
46
+ def forward(self, x):
47
+ return self.net(x)
48
+
49
+
50
+ def zero_module(module):
51
+ """
52
+ Zero out the parameters of a module and return it.
53
+ """
54
+ for p in module.parameters():
55
+ p.detach().zero_()
56
+ return module
57
+
58
+
59
+ def Normalize(in_channels):
60
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
61
+
62
+
63
+ class CrossAttention(nn.Module):
64
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
65
+ super().__init__()
66
+ inner_dim = dim_head * heads
67
+ context_dim = default(context_dim, query_dim)
68
+
69
+ self.scale = dim_head ** -0.5
70
+ self.heads = heads
71
+
72
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
73
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
74
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
75
+
76
+ self.to_out = nn.Sequential(
77
+ nn.Linear(inner_dim, query_dim),
78
+ nn.Dropout(dropout)
79
+ )
80
+
81
+ def forward(self, x, context=None, mask=None):
82
+ h = self.heads
83
+
84
+ q = self.to_q(x)
85
+ context = default(context, x)
86
+ k = self.to_k(context)
87
+ v = self.to_v(context)
88
+
89
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
90
+
91
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
92
+
93
+ if exists(mask):
94
+ mask = rearrange(mask, 'b ... -> b (...)')
95
+ max_neg_value = -torch.finfo(sim.dtype).max
96
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
97
+ sim.masked_fill_(~mask, max_neg_value)
98
+
99
+ # attention, what we cannot get enough of
100
+ attn = sim.softmax(dim=-1)
101
+
102
+ out = einsum('b i j, b j d -> b i d', attn, v)
103
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
104
+ return self.to_out(out)
105
+
106
+
107
+ class BasicTransformerBlock(nn.Module):
108
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
109
+ super().__init__()
110
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
111
+ dropout=dropout) # is a self-attention
112
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
113
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
114
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
115
+ self.norm1 = nn.LayerNorm(dim)
116
+ self.norm2 = nn.LayerNorm(dim)
117
+ self.norm3 = nn.LayerNorm(dim)
118
+
119
+ def forward(self, x, context=None):
120
+ x = self.attn1(self.norm1(x)) + x
121
+ x = self.attn2(self.norm2(x), context=context) + x
122
+ x = self.ff(self.norm3(x)) + x
123
+ return x
124
+
125
+
126
+ class SpatialTransformer(nn.Module):
127
+ """
128
+ Transformer block for image-like data.
129
+ First, project the input (aka embedding)
130
+ and reshape to b, t, d.
131
+ Then apply standard transformer action.
132
+ Finally, reshape to image
133
+ """
134
+
135
+ def __init__(self, in_channels, n_heads, d_head,
136
+ depth=1, dropout=0., context_dim=None):
137
+ super().__init__()
138
+ self.in_channels = in_channels
139
+ inner_dim = n_heads * d_head
140
+ self.norm = Normalize(in_channels)
141
+
142
+ self.proj_in = nn.Conv3d(in_channels,
143
+ inner_dim,
144
+ kernel_size=1,
145
+ stride=1,
146
+ padding=0)
147
+
148
+ self.transformer_blocks = nn.ModuleList(
149
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
150
+ for d in range(depth)]
151
+ )
152
+
153
+ self.proj_out = zero_module(nn.Conv3d(inner_dim,
154
+ in_channels,
155
+ kernel_size=1,
156
+ stride=1,
157
+ padding=0))
158
+
159
+ def forward(self, x, context=None):
160
+ # note: if no context is given, cross-attention defaults to self-attention
161
+ b, c, h, w, d = x.shape
162
+ x_in = x
163
+ x = self.norm(x)
164
+ x = self.proj_in(x)
165
+ x = rearrange(x, 'b c h w d -> b (h w d) c')
166
+ for block in self.transformer_blocks:
167
+ x = block(x, context=context)
168
+ x = rearrange(x, 'b (h w d) c -> b c h w d', h=h, w=w, d=d)
169
+ x = self.proj_out(x)
170
+ return x + x_in
models/ddim.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from tqdm import tqdm
7
+
8
+
9
+ def exists(x):
10
+ return x is not None
11
+
12
+
13
+ def default(val, d):
14
+ if exists(val):
15
+ return val
16
+ return d() if isfunction(d) else d
17
+
18
+
19
+ def noise_like(shape, device, repeat=False):
20
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
21
+ noise = lambda: torch.randn(shape, device=device)
22
+ return repeat_noise() if repeat else noise()
23
+
24
+
25
+ def extract(a, t, x_shape):
26
+ b, *_ = t.shape
27
+ out = a.gather(-1, t)
28
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
29
+
30
+
31
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
32
+ # select alphas for computing the variance schedule
33
+ alphas = alphacums[ddim_timesteps]
34
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
35
+
36
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
37
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
38
+ if verbose:
39
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
40
+ print(f'For the chosen value of eta, which is {eta}, '
41
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
42
+ return sigmas, alphas, alphas_prev
43
+
44
+
45
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
46
+ if ddim_discr_method == 'uniform':
47
+ c = num_ddpm_timesteps // num_ddim_timesteps
48
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
49
+ elif ddim_discr_method == 'quad':
50
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
51
+ else:
52
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
53
+
54
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
55
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
56
+ steps_out = ddim_timesteps + 1
57
+ if verbose:
58
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
59
+ return steps_out
60
+
61
+
62
+ class DDIMSampler(object):
63
+ def __init__(self, model, schedule="linear", **kwargs):
64
+ super().__init__()
65
+ self.model = model
66
+ self.ddpm_num_timesteps = model.num_timesteps
67
+ self.schedule = schedule
68
+
69
+ def register_buffer(self, name, attr):
70
+ if type(attr) == torch.Tensor:
71
+ if attr.device != torch.device("cuda"):
72
+ attr = attr.to(torch.device("cuda"))
73
+ setattr(self, name, attr)
74
+
75
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
76
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
77
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
78
+ alphas_cumprod = self.model.alphas_cumprod
79
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
80
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.device("cuda"))
81
+
82
+ self.register_buffer('betas', to_torch(self.model.betas))
83
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
84
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
85
+
86
+ # calculations for diffusion q(x_t | x_{t-1}) and others
87
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
88
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
89
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
90
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
91
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
92
+
93
+ # ddim sampling parameters
94
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
95
+ ddim_timesteps=self.ddim_timesteps,
96
+ eta=ddim_eta, verbose=verbose)
97
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
98
+ self.register_buffer('ddim_alphas', ddim_alphas)
99
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
100
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
101
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
102
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
103
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
104
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
105
+
106
+ @torch.no_grad()
107
+ def sample(self,
108
+ S,
109
+ batch_size,
110
+ shape,
111
+ conditioning=None,
112
+ callback=None,
113
+ img_callback=None,
114
+ quantize_x0=False,
115
+ eta=0.,
116
+ mask=None,
117
+ x0=None,
118
+ temperature=1.,
119
+ noise_dropout=0.,
120
+ score_corrector=None,
121
+ corrector_kwargs=None,
122
+ verbose=True,
123
+ x_T=None,
124
+ log_every_t=100,
125
+ **kwargs
126
+ ):
127
+
128
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
129
+ # sampling
130
+ C, H, W, D = shape
131
+ size = (batch_size, C, H, W, D)
132
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
133
+
134
+ samples, intermediates = self.ddim_sampling(
135
+ conditioning, size,
136
+ callback=callback,
137
+ img_callback=img_callback,
138
+ quantize_denoised=quantize_x0,
139
+ mask=mask, x0=x0,
140
+ ddim_use_original_steps=False,
141
+ noise_dropout=noise_dropout,
142
+ temperature=temperature,
143
+ score_corrector=score_corrector,
144
+ corrector_kwargs=corrector_kwargs,
145
+ x_T=x_T,
146
+ log_every_t=log_every_t
147
+ )
148
+ return samples, intermediates
149
+
150
+ @torch.no_grad()
151
+ def ddim_sampling(self, cond, shape,
152
+ x_T=None, ddim_use_original_steps=False,
153
+ callback=None, timesteps=None, quantize_denoised=False,
154
+ mask=None, x0=None, img_callback=None, log_every_t=100,
155
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
156
+ device = self.model.betas.device
157
+ b = shape[0]
158
+ if x_T is None:
159
+ img = torch.randn(shape, device=device)
160
+ else:
161
+ img = x_T
162
+
163
+ if timesteps is None:
164
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
165
+ elif timesteps is not None and not ddim_use_original_steps:
166
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
167
+ timesteps = self.ddim_timesteps[:subset_end]
168
+
169
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
170
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
171
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
172
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
173
+
174
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
175
+
176
+ for i, step in enumerate(iterator):
177
+ index = total_steps - i - 1
178
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
179
+
180
+ if mask is not None:
181
+ assert x0 is not None
182
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
183
+ img = img_orig * mask + (1. - mask) * img
184
+
185
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
186
+ quantize_denoised=quantize_denoised, temperature=temperature,
187
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
188
+ corrector_kwargs=corrector_kwargs)
189
+ img, pred_x0 = outs
190
+ if callback: callback(i)
191
+ if img_callback: img_callback(pred_x0, i)
192
+
193
+ if index % log_every_t == 0 or index == total_steps - 1:
194
+ intermediates['x_inter'].append(img)
195
+ intermediates['pred_x0'].append(pred_x0)
196
+
197
+ return img, intermediates
198
+
199
+ @torch.no_grad()
200
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
201
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
202
+ b, *_, device = *x.shape, x.device
203
+ e_t = self.model.apply_model(x, t, c)
204
+ if score_corrector is not None:
205
+ assert self.model.parameterization == "eps"
206
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
207
+
208
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
209
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
210
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
211
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
212
+ # select parameters corresponding to the currently considered timestep
213
+ a_t = torch.full((b, 1, 1, 1, 1), alphas[index], device=device)
214
+ a_prev = torch.full((b, 1, 1, 1, 1), alphas_prev[index], device=device)
215
+ sigma_t = torch.full((b, 1, 1, 1, 1), sigmas[index], device=device)
216
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
217
+
218
+ # current prediction for x_0
219
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
220
+ if quantize_denoised:
221
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
222
+ # direction pointing to x_t
223
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
224
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
225
+ if noise_dropout > 0.:
226
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
227
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
228
+ return x_prev, pred_x0
models/ddpm_v2_conditioned.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from inspect import isfunction
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from tqdm import tqdm
9
+
10
+ from models.unet_v2_conditioned import UNetModel
11
+
12
+
13
+ def exists(x):
14
+ return x is not None
15
+
16
+
17
+ def default(val, d):
18
+ if exists(val):
19
+ return val
20
+ return d() if isfunction(d) else d
21
+
22
+
23
+ def noise_like(shape, device, repeat=False):
24
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
25
+ noise = lambda: torch.randn(shape, device=device)
26
+ return repeat_noise() if repeat else noise()
27
+
28
+
29
+ def extract(a, t, x_shape):
30
+ b, *_ = t.shape
31
+ out = a.gather(-1, t)
32
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
33
+
34
+
35
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
36
+ if schedule == "linear":
37
+ betas = (
38
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
39
+ )
40
+
41
+ elif schedule == "cosine":
42
+ timesteps = (
43
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
44
+ )
45
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
46
+ alphas = torch.cos(alphas).pow(2)
47
+ alphas = alphas / alphas[0]
48
+ betas = 1 - alphas[1:] / alphas[:-1]
49
+ betas = np.clip(betas, a_min=0, a_max=0.999)
50
+
51
+ elif schedule == "sqrt_linear":
52
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
53
+ elif schedule == "sqrt":
54
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
55
+ else:
56
+ raise ValueError(f"schedule '{schedule}' unknown.")
57
+ return betas.numpy()
58
+
59
+
60
+ class DDPM(nn.Module):
61
+ def __init__(
62
+ self,
63
+ unet_config,
64
+ timesteps: int = 1000,
65
+ beta_schedule="linear",
66
+ loss_type="l2",
67
+ log_every_t=100,
68
+ clip_denoised=False,
69
+ linear_start=1e-4,
70
+ linear_end=2e-2,
71
+ cosine_s=8e-3,
72
+ original_elbo_weight=0.,
73
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
74
+ l_simple_weight=1.,
75
+ parameterization="eps", # all assuming fixed variance schedules
76
+ learn_logvar=False,
77
+ logvar_init=0.,
78
+ conditioning_key=None,
79
+ ):
80
+ super().__init__()
81
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
82
+ self.parameterization = parameterization
83
+
84
+ if conditioning_key == "unconditioned":
85
+ conditioning_key = None
86
+ self.conditioning_key = conditioning_key
87
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
88
+
89
+ self.clip_denoised = clip_denoised
90
+ self.log_every_t = log_every_t
91
+
92
+ self.v_posterior = v_posterior
93
+ self.original_elbo_weight = original_elbo_weight
94
+ self.l_simple_weight = l_simple_weight
95
+
96
+ self.loss_type = loss_type
97
+
98
+ self.register_schedule(
99
+ beta_schedule=beta_schedule,
100
+ timesteps=timesteps,
101
+ linear_start=linear_start,
102
+ linear_end=linear_end,
103
+ cosine_s=cosine_s,
104
+ )
105
+
106
+ self.learn_logvar = learn_logvar
107
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
108
+ if self.learn_logvar:
109
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
110
+
111
+ def register_schedule(
112
+ self,
113
+ beta_schedule="linear",
114
+ timesteps=1000,
115
+ linear_start=1e-4,
116
+ linear_end=2e-2,
117
+ cosine_s=8e-3
118
+ ):
119
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
120
+ cosine_s=cosine_s)
121
+ alphas = 1. - betas
122
+ alphas_cumprod = np.cumprod(alphas, axis=0)
123
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
124
+
125
+ timesteps, = betas.shape
126
+ self.num_timesteps = int(timesteps)
127
+ self.linear_start = linear_start
128
+ self.linear_end = linear_end
129
+
130
+ to_torch = partial(torch.tensor, dtype=torch.float32)
131
+
132
+ self.register_buffer('betas', to_torch(betas))
133
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
134
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
135
+
136
+ # calculations for diffusion q(x_t | x_{t-1}) and others
137
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
138
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
139
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
140
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
141
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
142
+
143
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
144
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
145
+ 1. - alphas_cumprod) + self.v_posterior * betas
146
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
147
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
148
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
149
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
150
+ self.register_buffer('posterior_mean_coef1', to_torch(
151
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
152
+ self.register_buffer('posterior_mean_coef2', to_torch(
153
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
154
+
155
+ if self.parameterization == "eps":
156
+ lvlb_weights = self.betas ** 2 / (
157
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
158
+ elif self.parameterization == "x0":
159
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
160
+ else:
161
+ raise NotImplementedError("mu not supported")
162
+ # TODO how to choose this term
163
+ lvlb_weights[0] = lvlb_weights[1]
164
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
165
+ assert not torch.isnan(self.lvlb_weights).all()
166
+
167
+ def q_mean_variance(self, x_start, t):
168
+ """
169
+ Get the distribution q(x_t | x_0).
170
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
171
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
172
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
173
+ """
174
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
175
+ variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
176
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
177
+ return mean, variance, log_variance
178
+
179
+ def predict_start_from_noise(self, x_t, t, noise):
180
+ return (
181
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
182
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
183
+ )
184
+
185
+ def q_posterior(self, x_start, x_t, t):
186
+ """
187
+ Compute the mean and variance of the diffusion posterior:
188
+ q(x_{t-1} | x_t, x_0)
189
+ """
190
+ posterior_mean = (
191
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
192
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
193
+ )
194
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
195
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
196
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
197
+
198
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_x0=False):
199
+ """
200
+ Apply the model to get p(x_{t-1} | x_t)
201
+ :param model: the model, which takes a signal and a batch of timesteps
202
+ as input.
203
+ :param x: the [N x C x ...] tensor at time t.
204
+ :param t: a 1-D Tensor of timesteps.
205
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
206
+
207
+ """
208
+ t_in = t
209
+ model_out = self.apply_model(x, t_in, c)
210
+ if self.parameterization == "eps":
211
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
212
+ elif self.parameterization == "x0":
213
+ x_recon = model_out
214
+
215
+ if clip_denoised:
216
+ x_recon.clamp_(-1., 1.)
217
+
218
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
219
+ if return_x0:
220
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
221
+ else:
222
+ return model_mean, posterior_variance, posterior_log_variance
223
+
224
+ @torch.no_grad()
225
+ def p_sample(
226
+ self,
227
+ x,
228
+ c,
229
+ t,
230
+ clip_denoised=True,
231
+ repeat_noise=False,
232
+ return_x0=False,
233
+ temperature=1.,
234
+ noise_dropout=0.,
235
+ ):
236
+ """
237
+ Sample x_{t-1} from the model at the given timestep.
238
+ :param x: the current tensor at x_{t-1}.
239
+ :param t: the value of t, starting at 0 for the first diffusion step.
240
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
241
+ """
242
+
243
+ b, *_, device = *x.shape, x.device
244
+ outputs = self.p_mean_variance(
245
+ x=x,
246
+ c=c,
247
+ t=t,
248
+ clip_denoised=clip_denoised,
249
+ return_x0=return_x0,
250
+ )
251
+ if return_x0:
252
+ model_mean, _, model_log_variance, x0 = outputs
253
+ else:
254
+ model_mean, _, model_log_variance = outputs
255
+
256
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
257
+ if noise_dropout > 0.:
258
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
259
+ # no noise when t == 0
260
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
261
+ if return_x0:
262
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
263
+ else:
264
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
265
+
266
+ @torch.no_grad()
267
+ def p_sample_loop(self, cond, shape, return_intermediates=False):
268
+ device = self.betas.device
269
+
270
+ b = shape[0]
271
+ img = torch.randn(shape, device=device)
272
+ intermediates = [img]
273
+
274
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
275
+ img = self.p_sample(img, cond, torch.full((b,), i, device=device, dtype=torch.long),
276
+ clip_denoised=self.clip_denoised)
277
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
278
+ intermediates.append(img)
279
+ if return_intermediates:
280
+ return img, intermediates
281
+ return img
282
+
283
+ @torch.no_grad()
284
+ def sample(self, batch_size=16, return_intermediates=False):
285
+ image_size = self.image_size
286
+ channels = self.channels
287
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
288
+ return_intermediates=return_intermediates)
289
+
290
+ def q_sample(self, x_start, t, noise=None):
291
+ """
292
+ Diffuse the data for a given number of diffusion steps.
293
+ In other words, sample from q(x_t | x_0).
294
+ :param x_start: the initial data batch.
295
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
296
+ :param noise: if specified, the split-out normal noise.
297
+ :return: A noisy version of x_start.
298
+ """
299
+ noise = default(noise, lambda: torch.randn_like(x_start))
300
+
301
+ return (
302
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
303
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
304
+ )
305
+
306
+ def get_loss(self, pred, target, mean=True):
307
+ if self.loss_type == 'l1':
308
+ loss = (target - pred).abs()
309
+ if mean:
310
+ loss = loss.mean()
311
+ elif self.loss_type == 'l2':
312
+ if mean:
313
+ loss = torch.nn.functional.mse_loss(target, pred)
314
+ else:
315
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
316
+ else:
317
+ raise NotImplementedError("unknown loss type '{loss_type}'")
318
+
319
+ return loss
320
+
321
+ def p_losses(self, x_start, cond, t, noise=None):
322
+ noise = default(noise, lambda: torch.randn_like(x_start))
323
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
324
+ model_output = self.apply_model(x_noisy, t, cond)
325
+
326
+ loss_dict = {}
327
+ if self.parameterization == "eps":
328
+ target = noise
329
+ elif self.parameterization == "x0":
330
+ target = x_start
331
+ else:
332
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
333
+
334
+ loss_simple = self.get_loss(model_output, target, mean=False).mean(dim=[1, 2, 3, 4])
335
+ loss_dict.update({f'loss_simple': loss_simple.mean()})
336
+
337
+ logvar_t = self.logvar[t].to(x_start.device)
338
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
339
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
340
+ if self.learn_logvar:
341
+ loss_dict.update({f'loss_gamma': loss.mean()})
342
+ loss_dict.update({'logvar': self.logvar.data.mean()})
343
+
344
+ loss = self.l_simple_weight * loss.mean()
345
+
346
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3, 4))
347
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
348
+ loss_dict.update({f'loss_vlb': loss_vlb})
349
+ loss += (self.original_elbo_weight * loss_vlb)
350
+ loss_dict.update({f'loss': loss})
351
+
352
+ return loss, loss_dict
353
+
354
+ def forward(self, x, c, *args, **kwargs):
355
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
356
+ return self.p_losses(x, c, t, *args, **kwargs)
357
+
358
+ def configure_optimizers(self):
359
+ lr = self.learning_rate
360
+ params = list(self.model.parameters())
361
+ if self.learn_logvar:
362
+ print('Diffusion model optimizing logvar')
363
+ params.append(self.logvar)
364
+ opt = torch.optim.AdamW(params, lr=lr)
365
+ return opt
366
+
367
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
368
+
369
+ if isinstance(cond, dict):
370
+ # hybrid case, cond is exptected to be a dict
371
+ pass
372
+ else:
373
+ if not isinstance(cond, list):
374
+ cond = [cond]
375
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
376
+ cond = {key: cond}
377
+
378
+ x_recon = self.model(x_noisy, t, **cond)
379
+
380
+ if isinstance(x_recon, tuple) and not return_ids:
381
+ return x_recon[0]
382
+ else:
383
+ return x_recon
384
+
385
+
386
+
387
+ class DiffusionWrapper(nn.Module):
388
+ def __init__(self, unet_config, conditioning_key):
389
+ super().__init__()
390
+ self.diffusion_model = UNetModel(
391
+ **unet_config.get("params", dict())
392
+ )
393
+ self.conditioning_key = conditioning_key
394
+
395
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
396
+ xc = torch.cat([x] + c_concat, dim=1)
397
+ cc = torch.cat(c_crossattn, 1)
398
+ out = self.diffusion_model(xc, t, context=cc)
399
+
400
+
401
+ return out
models/unet_v2_conditioned.py ADDED
@@ -0,0 +1,557 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from abc import abstractmethod
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import repeat
10
+
11
+ from models.attention import SpatialTransformer
12
+
13
+
14
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
15
+ """
16
+ Create sinusoidal timestep embeddings.
17
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
18
+ These may be fractional.
19
+ :param dim: the dimension of the output.
20
+ :param max_period: controls the minimum frequency of the embeddings.
21
+ :return: an [N x dim] Tensor of positional embeddings.
22
+ """
23
+ if not repeat_only:
24
+ half = dim // 2
25
+ freqs = torch.exp(
26
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
27
+ ).to(device=timesteps.device)
28
+ args = timesteps[:, None].float() * freqs[None]
29
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
30
+ if dim % 2:
31
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
32
+ else:
33
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
34
+ return embedding
35
+
36
+
37
+ def zero_module(module):
38
+ """
39
+ Zero out the parameters of a module and return it.
40
+ """
41
+ for p in module.parameters():
42
+ p.detach().zero_()
43
+ return module
44
+
45
+
46
+ class TimestepBlock(nn.Module):
47
+ @abstractmethod
48
+ def forward(self, x, emb):
49
+ """
50
+ Apply the module to `x` given `emb` timestep embeddings.
51
+ """
52
+
53
+
54
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
55
+ """
56
+ A sequential module that passes timestep embeddings to the children that
57
+ support it as an extra input.
58
+ """
59
+
60
+ def forward(self, x, emb, context=None):
61
+ for layer in self:
62
+ if isinstance(layer, TimestepBlock):
63
+ x = layer(x, emb)
64
+ elif isinstance(layer, SpatialTransformer):
65
+ x = layer(x, context)
66
+ else:
67
+ x = layer(x)
68
+ return x
69
+
70
+
71
+ def Normalize(in_channels):
72
+ return nn.GroupNorm(
73
+ num_groups=32,
74
+ num_channels=in_channels,
75
+ eps=1e-6,
76
+ affine=True
77
+ )
78
+
79
+
80
+ def count_flops_attn(model, _x, y):
81
+ """
82
+ A counter for the `thop` package to count the operations in an
83
+ attention operation.
84
+ Meant to be used like:
85
+ macs, params = thop.profile(
86
+ model,
87
+ inputs=(inputs, timestamps),
88
+ custom_ops={QKVAttention: QKVAttention.count_flops},
89
+ )
90
+ """
91
+ b, c, *spatial = y[0].shape
92
+ num_spatial = int(np.prod(spatial))
93
+ # We perform two matmuls with the same number of ops.
94
+ # The first computes the weight matrix, the second computes
95
+ # the combination of the value vectors.
96
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
97
+ model.total_ops += th.DoubleTensor([matmul_ops])
98
+
99
+
100
+ class QKVAttentionLegacy(nn.Module):
101
+ """
102
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
103
+ """
104
+
105
+ def __init__(self, n_heads):
106
+ super().__init__()
107
+ self.n_heads = n_heads
108
+
109
+ def forward(self, qkv):
110
+ """
111
+ Apply QKV attention.
112
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
113
+ :return: an [N x (H * C) x T] tensor after attention.
114
+ """
115
+ bs, width, length = qkv.shape
116
+ assert width % (3 * self.n_heads) == 0
117
+ ch = width // (3 * self.n_heads)
118
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
119
+ scale = 1 / math.sqrt(math.sqrt(ch))
120
+ weight = th.einsum(
121
+ "bct,bcs->bts", q * scale, k * scale
122
+ ) # More stable with f16 than dividing afterwards
123
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
124
+ a = th.einsum("bts,bcs->bct", weight, v)
125
+ return a.reshape(bs, -1, length)
126
+
127
+ @staticmethod
128
+ def count_flops(model, _x, y):
129
+ return count_flops_attn(model, _x, y)
130
+
131
+
132
+ class AttentionBlock(nn.Module):
133
+ """
134
+ An attention block that allows spatial positions to attend to each other.
135
+ Originally ported from here, but adapted to the N-d case.
136
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ channels,
142
+ num_heads=1,
143
+ num_head_channels=-1,
144
+ use_checkpoint=False,
145
+ ):
146
+ super().__init__()
147
+ self.channels = channels
148
+ if num_head_channels == -1:
149
+ self.num_heads = num_heads
150
+ else:
151
+ assert (
152
+ channels % num_head_channels == 0
153
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
154
+ self.num_heads = channels // num_head_channels
155
+ self.use_checkpoint = use_checkpoint
156
+ self.norm = Normalize(channels)
157
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
158
+ self.attention = QKVAttentionLegacy(self.num_heads)
159
+
160
+ self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
161
+
162
+ def forward(self, x):
163
+ return self._forward(x, )
164
+
165
+ def _forward(self, x):
166
+ b, c, *spatial = x.shape
167
+ x = x.reshape(b, c, -1)
168
+ qkv = self.qkv(self.norm(x))
169
+ h = self.attention(qkv)
170
+ h = self.proj_out(h)
171
+ return (x + h).reshape(b, c, *spatial)
172
+
173
+
174
+ class Downsample(nn.Module):
175
+ """
176
+ A downsampling layer with an optional convolution.
177
+
178
+ :param channels: channels in the inputs and outputs.
179
+ :param use_conv: a bool determining if a convolution is applied.
180
+ """
181
+
182
+ def __init__(self, channels, use_conv, out_channels=None, padding=1):
183
+ super().__init__()
184
+ self.channels = channels
185
+ self.out_channels = out_channels or channels
186
+ self.use_conv = use_conv
187
+ if use_conv:
188
+ self.op = nn.Conv3d(
189
+ self.channels, self.out_channels, 3, stride=2, padding=padding
190
+ )
191
+ else:
192
+ assert self.channels == self.out_channels
193
+ self.op = nn.AvgPool3d(kernel_size=2, stride=2)
194
+
195
+ def forward(self, x):
196
+ assert x.shape[1] == self.channels
197
+ return self.op(x)
198
+
199
+
200
+ class Upsample(nn.Module):
201
+ """
202
+ An upsampling layer with an optional convolution.
203
+ :param channels: channels in the inputs and outputs.
204
+ :param use_conv: a bool determining if a convolution is applied.
205
+ """
206
+
207
+ def __init__(self, channels, use_conv, out_channels=None, padding=1):
208
+ super().__init__()
209
+ self.channels = channels
210
+ self.out_channels = out_channels or channels
211
+ self.use_conv = use_conv
212
+ if use_conv:
213
+ self.conv = nn.Conv3d(self.channels, self.out_channels, 3, padding=padding)
214
+
215
+ def forward(self, x):
216
+ assert x.shape[1] == self.channels
217
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
218
+ if self.use_conv:
219
+ x = self.conv(x)
220
+ return x
221
+
222
+
223
+ class ResBlock(TimestepBlock):
224
+ """
225
+ A residual block that can optionally change the number of channels.
226
+ :param channels: the number of input channels.
227
+ :param emb_channels: the number of timestep embedding channels.
228
+ :param dropout: the rate of dropout.
229
+ :param out_channels: if specified, the number of out channels.
230
+ :param use_conv: if True and out_channels is specified, use a spatial
231
+ convolution instead of a smaller 1x1 convolution to change the
232
+ channels in the skip connection.
233
+ :param up: if True, use this block for upsampling.
234
+ :param down: if True, use this block for downsampling.
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ channels,
240
+ emb_channels,
241
+ dropout,
242
+ out_channels=None,
243
+ use_conv=False,
244
+ use_scale_shift_norm=False,
245
+ up=False,
246
+ down=False,
247
+ ):
248
+ super().__init__()
249
+ self.channels = channels
250
+ self.emb_channels = emb_channels
251
+ self.dropout = dropout
252
+ self.out_channels = out_channels or channels
253
+ self.use_conv = use_conv
254
+ self.use_scale_shift_norm = use_scale_shift_norm
255
+
256
+ self.in_layers = nn.Sequential(
257
+ Normalize(channels),
258
+ nn.SiLU(),
259
+ nn.Conv3d(channels, self.out_channels, 3, padding=1),
260
+ )
261
+
262
+ self.updown = up or down
263
+
264
+ if up:
265
+ self.h_upd = Upsample(channels, False)
266
+ self.x_upd = Upsample(channels, False)
267
+ elif down:
268
+ self.h_upd = Downsample(channels, False)
269
+ self.x_upd = Downsample(channels, False)
270
+ else:
271
+ self.h_upd = self.x_upd = nn.Identity()
272
+
273
+ self.emb_layers = nn.Sequential(
274
+ nn.SiLU(),
275
+ nn.Linear(
276
+ emb_channels,
277
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
278
+ ),
279
+ )
280
+ self.out_layers = nn.Sequential(
281
+ Normalize(self.out_channels),
282
+ nn.SiLU(),
283
+ nn.Dropout(p=dropout),
284
+ zero_module(
285
+ nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1)
286
+ ),
287
+ )
288
+
289
+ if self.out_channels == channels:
290
+ self.skip_connection = nn.Identity()
291
+ elif use_conv:
292
+ self.skip_connection = nn.Conv3d(
293
+ channels, self.out_channels, 3, padding=1
294
+ )
295
+ else:
296
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1)
297
+
298
+ def forward(self, x, emb):
299
+ return self._forward(x, emb)
300
+
301
+ def _forward(self, x, emb):
302
+ if self.updown:
303
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
304
+ h = in_rest(x)
305
+ h = self.h_upd(h)
306
+ x = self.x_upd(x)
307
+ h = in_conv(h)
308
+ else:
309
+ h = self.in_layers(x)
310
+ emb_out = self.emb_layers(emb).type(h.dtype)
311
+ while len(emb_out.shape) < len(h.shape):
312
+ emb_out = emb_out[..., None]
313
+ if self.use_scale_shift_norm:
314
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
315
+ scale, shift = th.chunk(emb_out, 2, dim=1)
316
+ h = out_norm(h) * (1 + scale) + shift
317
+ h = out_rest(h)
318
+ else:
319
+ h = h + emb_out
320
+ h = self.out_layers(h)
321
+ return self.skip_connection(x) + h
322
+
323
+
324
+ class UNetModel(nn.Module):
325
+ def __init__(
326
+ self,
327
+ image_size,
328
+ in_channels,
329
+ model_channels,
330
+ out_channels,
331
+ num_res_blocks,
332
+ attention_resolutions,
333
+ dropout=0,
334
+ channel_mult=(1, 2, 4, 8),
335
+ conv_resample=True,
336
+ num_classes=None,
337
+ num_heads=1,
338
+ num_head_channels=-1,
339
+ num_heads_upsample=-1,
340
+ use_scale_shift_norm=False,
341
+ resblock_updown=False,
342
+ use_spatial_transformer=False, # custom transformer support
343
+ transformer_depth=1, # custom transformer support
344
+ context_dim=None, # custom transformer support
345
+ n_embed=None # custom support for prediction of discrete ids into codebook of first stage vq model
346
+ ):
347
+ super().__init__()
348
+
349
+ if use_spatial_transformer:
350
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
351
+
352
+ if context_dim is not None:
353
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
354
+
355
+
356
+
357
+ if num_heads_upsample == -1:
358
+ num_heads_upsample = num_heads
359
+
360
+ self.image_size = image_size
361
+ self.in_channels = in_channels
362
+ self.model_channels = model_channels
363
+ self.out_channels = out_channels
364
+ self.num_res_blocks = num_res_blocks
365
+ self.attention_resolutions = attention_resolutions
366
+ self.dropout = dropout
367
+ self.channel_mult = channel_mult
368
+ self.conv_resample = conv_resample
369
+ self.num_classes = num_classes
370
+ self.num_heads = num_heads
371
+ self.num_head_channels = num_head_channels
372
+ self.num_heads_upsample = num_heads_upsample
373
+ self.predict_codebook_ids = n_embed is not None
374
+
375
+ time_embed_dim = model_channels * 4
376
+ self.time_embed = nn.Sequential(
377
+ nn.Linear(model_channels, time_embed_dim),
378
+ nn.SiLU(),
379
+ nn.Linear(time_embed_dim, time_embed_dim),
380
+ )
381
+
382
+ if self.num_classes is not None:
383
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
384
+
385
+ self.input_blocks = nn.ModuleList(
386
+ [
387
+ TimestepEmbedSequential(
388
+ nn.Conv3d(in_channels, model_channels, 3, padding=1)
389
+ )
390
+ ]
391
+ )
392
+ self._feature_size = model_channels
393
+ input_block_chans = [model_channels]
394
+ ch = model_channels
395
+ ds = 1
396
+ for level, mult in enumerate(channel_mult):
397
+ for _ in range(num_res_blocks):
398
+ layers = [
399
+ ResBlock(
400
+ ch,
401
+ time_embed_dim,
402
+ dropout,
403
+ out_channels=mult * model_channels,
404
+ use_scale_shift_norm=use_scale_shift_norm,
405
+ )
406
+ ]
407
+ ch = mult * model_channels
408
+ if ds in attention_resolutions:
409
+ dim_head = ch // num_heads
410
+ layers.append(
411
+ AttentionBlock(
412
+ ch,
413
+ num_heads=num_heads,
414
+ num_head_channels=num_head_channels,
415
+ ) if not use_spatial_transformer else SpatialTransformer(
416
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
417
+ )
418
+ )
419
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
420
+ self._feature_size += ch
421
+ input_block_chans.append(ch)
422
+ if level != len(channel_mult) - 1:
423
+ out_ch = ch
424
+ self.input_blocks.append(
425
+ TimestepEmbedSequential(
426
+ ResBlock(
427
+ ch,
428
+ time_embed_dim,
429
+ dropout,
430
+ out_channels=out_ch,
431
+ use_scale_shift_norm=use_scale_shift_norm,
432
+ down=True,
433
+ )
434
+ if resblock_updown
435
+ else Downsample(
436
+ ch, conv_resample, out_channels=out_ch
437
+ )
438
+ )
439
+ )
440
+ ch = out_ch
441
+ input_block_chans.append(ch)
442
+ ds *= 2
443
+ self._feature_size += ch
444
+
445
+ dim_head = ch // num_heads
446
+ self.middle_block = TimestepEmbedSequential(
447
+ ResBlock(
448
+ ch,
449
+ time_embed_dim,
450
+ dropout,
451
+ use_scale_shift_norm=use_scale_shift_norm,
452
+ ),
453
+ AttentionBlock(
454
+ ch,
455
+ num_heads=num_heads,
456
+ num_head_channels=num_head_channels,
457
+ ) if not use_spatial_transformer else SpatialTransformer(
458
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
459
+ ),
460
+ ResBlock(
461
+ ch,
462
+ time_embed_dim,
463
+ dropout,
464
+ use_scale_shift_norm=use_scale_shift_norm,
465
+ ),
466
+ )
467
+ self._feature_size += ch
468
+
469
+ self.output_blocks = nn.ModuleList([])
470
+ for level, mult in list(enumerate(channel_mult))[::-1]:
471
+ for i in range(num_res_blocks + 1):
472
+ ich = input_block_chans.pop()
473
+ layers = [
474
+ ResBlock(
475
+ ch + ich,
476
+ time_embed_dim,
477
+ dropout,
478
+ out_channels=model_channels * mult,
479
+ use_scale_shift_norm=use_scale_shift_norm,
480
+ )
481
+ ]
482
+ ch = model_channels * mult
483
+ if ds in attention_resolutions:
484
+ dim_head = ch // num_heads
485
+ layers.append(
486
+ AttentionBlock(
487
+ ch,
488
+ num_heads=num_heads_upsample,
489
+ num_head_channels=num_head_channels,
490
+ ) if not use_spatial_transformer else SpatialTransformer(
491
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
492
+ )
493
+ )
494
+ if level and i == num_res_blocks:
495
+ out_ch = ch
496
+ layers.append(
497
+ ResBlock(
498
+ ch,
499
+ time_embed_dim,
500
+ dropout,
501
+ out_channels=out_ch,
502
+ use_scale_shift_norm=use_scale_shift_norm,
503
+ up=True,
504
+ )
505
+ if resblock_updown
506
+ else Upsample(ch, conv_resample, out_channels=out_ch)
507
+ )
508
+ ds //= 2
509
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
510
+ self._feature_size += ch
511
+
512
+ self.out = nn.Sequential(
513
+ Normalize(ch),
514
+ nn.SiLU(),
515
+ zero_module(nn.Conv3d(model_channels, out_channels, 3, padding=1)),
516
+ )
517
+ if self.predict_codebook_ids:
518
+ self.id_predictor = nn.Sequential(
519
+ Normalize(ch),
520
+ nn.Conv3d(model_channels, n_embed, 1),
521
+ )
522
+
523
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
524
+ """
525
+ Apply the model to an input batch.
526
+ :param x: an [N x C x ...] Tensor of inputs.
527
+ :param timesteps: a 1-D batch of timesteps.
528
+ :param context: conditioning plugged in via crossattn
529
+ :param y: an [N] Tensor of labels, if class-conditional.
530
+ :return: an [N x C x ...] Tensor of outputs.
531
+ """
532
+ assert (y is not None) == (
533
+ self.num_classes is not None
534
+ ), "must specify y if and only if the model is class-conditional"
535
+ assert timesteps is not None, 'need to implement no-timestep usage'
536
+ hs = []
537
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
538
+ emb = self.time_embed(t_emb)
539
+
540
+ if self.num_classes is not None:
541
+ assert y.shape == (x.shape[0],)
542
+ emb = emb + self.label_emb(y)
543
+
544
+ h = x
545
+ for module in self.input_blocks:
546
+ h = module(h, emb, context)
547
+ hs.append(h)
548
+ h = self.middle_block(h, emb, context)
549
+ for module in self.output_blocks:
550
+ h = th.cat([h, hs.pop()], dim=1)
551
+ h = module(h, emb, context)
552
+
553
+ if self.predict_codebook_ids:
554
+ # return self.out(h), self.id_predictor(h)
555
+ return self.id_predictor(h)
556
+ else:
557
+ return self.out(h)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ einops>=0.3.2
2
+ gradio==3.1.1
3
+ mediapy==1.0.3
4
+ mlflow
5
+ nibabel
6
+ omegaconf==2.1.1
7
+ opencv-python==4.6.0.66
8
+ plotly==5.9.0
9
+ scikit-image==0.19.3
10
+ tqdm
11
+
12
+ -f https://download.pytorch.org/whl/torch_stable.html
13
+ torch==1.11.0+cpu
trained_models/ddpm/.gitkeep ADDED
File without changes
trained_models/ddpm/MLmodel ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ artifact_path: final_model
2
+ flavors:
3
+ python_function:
4
+ data: data
5
+ env: conda.yaml
6
+ loader_module: mlflow.pytorch
7
+ pickle_module_name: mlflow.pytorch.pickle_module
8
+ python_version: 3.8.12
9
+ pytorch:
10
+ model_data: data
11
+ pytorch_version: 1.11.0a0+bfe5ad2
12
+ model_uuid: 6cf6d11600204707bfb1373170c6c137
13
+ run_id: c7b62c88595843d3a404368c87df5607
14
+ utc_time_created: '2022-04-19 14:50:01.769881'
trained_models/ddpm/conda.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ channels:
2
+ - conda-forge
3
+ dependencies:
4
+ - python=3.8.12
5
+ - pip
6
+ - pip:
7
+ - mlflow
8
+ - attrs==21.4.0
9
+ - cloudpickle==2.0.0
10
+ - einops==0.4.0
11
+ - ipython==7.31.0
12
+ - omegaconf==2.1.1
13
+ - torch==1.11.0a0
14
+ - tqdm==4.62.3
15
+ name: mlflow-env
trained_models/ddpm/data/pickle_module_info.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ mlflow.pytorch.pickle_module
trained_models/ddpm/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ mlflow
2
+ attrs==21.4.0
3
+ cloudpickle==2.0.0
4
+ einops==0.4.0
5
+ ipython==7.31.0
6
+ omegaconf==2.1.1
7
+ torch==1.11.0a0
8
+ tqdm==4.62.3
trained_models/vae/.gitkeep ADDED
File without changes
trained_models/vae/MLmodel ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ artifact_path: final_model
2
+ flavors:
3
+ python_function:
4
+ data: data
5
+ env: conda.yaml
6
+ loader_module: mlflow.pytorch
7
+ pickle_module_name: mlflow.pytorch.pickle_module
8
+ python_version: 3.8.12
9
+ pytorch:
10
+ model_data: data
11
+ pytorch_version: 1.11.0a0+bfe5ad2
12
+ model_uuid: b09405e06c9f42d5902b2467888ec060
13
+ run_id: 2f37b3b604a44b189b020028aa53f991
14
+ utc_time_created: '2022-03-29 20:38:58.307349'
trained_models/vae/conda.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ channels:
2
+ - conda-forge
3
+ dependencies:
4
+ - python=3.8.12
5
+ - pip
6
+ - pip:
7
+ - mlflow
8
+ - attrs==21.4.0
9
+ - cloudpickle==2.0.0
10
+ - ipython==7.31.0
11
+ - omegaconf==2.1.1
12
+ - torch==1.11.0a0
13
+ - tqdm==4.62.3
14
+ name: mlflow-env
trained_models/vae/data/pickle_module_info.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ mlflow.pytorch.pickle_module
trained_models/vae/requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ mlflow
2
+ attrs==21.4.0
3
+ cloudpickle==2.0.0
4
+ ipython==7.31.0
5
+ omegaconf==2.1.1
6
+ torch==1.11.0a0
7
+ tqdm==4.62.3