This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +0 -1
  2. app.py +76 -185
  3. apply_net.py +0 -359
  4. checkpoints/VITONHD.ckpt +0 -3
  5. checkpoints/VITONHD_1024.ckpt +0 -3
  6. checkpoints/eternal_1024.ckpt +0 -3
  7. checkpoints/humanparsing/.gitkeep +0 -0
  8. checkpoints/humanparsing/parsing_atr.onnx +0 -3
  9. checkpoints/humanparsing/parsing_lip.onnx +0 -3
  10. checkpoints/openpose/ckpts/body_pose_model.pth +0 -3
  11. cldm/cldm.py +0 -138
  12. cldm/hack.py +0 -111
  13. cldm/model.py +0 -9
  14. cldm/plms_hacked.py +0 -251
  15. cldm/warping_cldm_network.py +0 -357
  16. configs/VITON.yaml +0 -100
  17. detectron2 +1 -1
  18. examples/garment/00055_00.jpg +0 -0
  19. examples/garment/00470_00.jpg +0 -0
  20. examples/garment/08973_00.jpg +0 -0
  21. examples/garment/12469_00.jpg +0 -0
  22. examples/model/04913_00.jpg +0 -0
  23. examples/model/05032_00.jpg +0 -0
  24. examples_eternal/garment/1.jpg +0 -0
  25. examples_eternal/garment/2.jpg +0 -0
  26. examples_eternal/garment/3.jpg +0 -0
  27. examples_eternal/garment/4.jpg +0 -0
  28. examples_eternal/garment/5.jpg +0 -0
  29. examples_eternal/garment/6.jpg +0 -0
  30. examples_eternal/model/1.jpg +0 -0
  31. examples_eternal/model/2.jpg +0 -0
  32. examples_eternal/model/3.jpg +0 -0
  33. examples_eternal/model/4.jpg +0 -0
  34. examples_eternal/model/6.jpg +0 -0
  35. ldm/data/__init__.py +0 -0
  36. ldm/data/util.py +0 -24
  37. ldm/models/autoencoder.py +0 -203
  38. ldm/models/diffusion/__init__.py +0 -0
  39. ldm/models/diffusion/ddim.py +0 -377
  40. ldm/models/diffusion/ddpm.py +0 -1875
  41. ldm/models/diffusion/dpm_solver/__init__.py +0 -1
  42. ldm/models/diffusion/dpm_solver/dpm_solver.py +0 -1154
  43. ldm/models/diffusion/dpm_solver/sampler.py +0 -87
  44. ldm/models/diffusion/plms.py +0 -244
  45. ldm/models/diffusion/sampling_util.py +0 -22
  46. ldm/modules/attention.py +0 -330
  47. ldm/modules/diffusionmodules/__init__.py +0 -0
  48. ldm/modules/diffusionmodules/model.py +0 -852
  49. ldm/modules/diffusionmodules/openaimodel.py +0 -790
  50. ldm/modules/diffusionmodules/upscaling.py +0 -81
README.md CHANGED
@@ -4,7 +4,6 @@ emoji: 👕👔👗
  colorFrom: blue
  colorTo: blue
  sdk: gradio
- python_version: 3.8.5
  sdk_version: 4.23.0
  app_file: app.py
  pinned: false
app.py CHANGED
@@ -1,212 +1,94 @@
- from preprocess.detectron2.projects.DensePose.apply_net_gradio import DensePose4Gradio
- from preprocess.humanparsing.run_parsing import Parsing
- from preprocess.openpose.run_openpose import OpenPose
-
  import os
  import sys
  import time
- from glob import glob
- from os.path import join as opj
  from pathlib import Path

  import gradio as gr
  import torch
- from omegaconf import OmegaConf
  from PIL import Image
- import spaces
- print(torch.cuda.is_available(), torch.cuda.device_count())
-

- from cldm.model import create_model
- from cldm.plms_hacked import PLMSSampler
- from utils_stableviton import get_mask_location, get_batch, tensor2img, center_crop

  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
- sys.path.insert(0, str(PROJECT_ROOT))

- IMG_H = 1024
- IMG_W = 768

  openpose_model_hd = OpenPose(0)
- openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
  parsing_model_hd = Parsing(0)
  densepose_model_hd = DensePose4Gradio(
      cfg='preprocess/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml',
      model='https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl',
  )

  category_dict = ['upperbody', 'lowerbody', 'dress']
  category_dict_utils = ['upper_body', 'lower_body', 'dresses']

- # #### model init >>>>
- config = OmegaConf.load("./configs/VITON.yaml")
- config.model.params.img_H = IMG_H
- config.model.params.img_W = IMG_W
- params = config.model.params
-
- model = create_model(config_path=None, config=config)
- model.load_state_dict(torch.load("./checkpoints/eternal_1024.ckpt", map_location="cpu")["state_dict"])
- model = model.cuda()
- model.eval()
- sampler = PLMSSampler(model)
-
- model2 = create_model(config_path=None, config=config)
- model2.load_state_dict(torch.load("./checkpoints/VITONHD_1024.ckpt", map_location="cpu")["state_dict"])
- model2 = model.cuda()
- model2.eval()
- sampler2 = PLMSSampler(model2)
- # #### model init <<<<

- @spaces.GPU
- @torch.autocast("cuda")
- @torch.no_grad()
- def stable_viton_model_hd(
-     batch,
-     n_steps,
- ):
-     z, cond = model.get_input(batch, params.first_stage_key)
-     z = z
-     bs = z.shape[0]
-     c_crossattn = cond["c_crossattn"][0][:bs]
-     if c_crossattn.ndim == 4:
-         c_crossattn = model.get_learned_conditioning(c_crossattn)
-         cond["c_crossattn"] = [c_crossattn]
-     uc_cross = model.get_unconditional_conditioning(bs)
-     uc_full = {"c_concat": cond["c_concat"], "c_crossattn": [uc_cross]}
-     uc_full["first_stage_cond"] = cond["first_stage_cond"]
-     for k, v in batch.items():
-         if isinstance(v, torch.Tensor):
-             batch[k] = v.cuda()
-     sampler.model.batch = batch

-     ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
-     start_code = model.q_sample(z, ts)
-     torch.cuda.empty_cache()
-     output, _, _ = sampler.sample(
-         n_steps,
-         bs,
-         (4, IMG_H//8, IMG_W//8),
-         cond,
-         x_T=start_code,
-         verbose=False,
-         eta=0.0,
-         unconditional_conditioning=uc_full,
-     )
-
-     output = model.decode_first_stage(output)
-     output = tensor2img(output)
-     pil_output = Image.fromarray(output)
-     return pil_output
-
- @spaces.GPU
- @torch.autocast("cuda")
- @torch.no_grad()
- def stable_viton_model_hd2(
-     batch,
-     n_steps,
- ):
-     z, cond = model2.get_input(batch, params.first_stage_key)
-     z = z
-     bs = z.shape[0]
-     c_crossattn = cond["c_crossattn"][0][:bs]
-     if c_crossattn.ndim == 4:
-         c_crossattn = model2.get_learned_conditioning(c_crossattn)
-         cond["c_crossattn"] = [c_crossattn]
-     uc_cross = model2.get_unconditional_conditioning(bs)
-     uc_full = {"c_concat": cond["c_concat"], "c_crossattn": [uc_cross]}
-     uc_full["first_stage_cond"] = cond["first_stage_cond"]
-     for k, v in batch.items():
-         if isinstance(v, torch.Tensor):
-             batch[k] = v.cuda()
-     sampler2.model.batch = batch

-     ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
-     start_code = model2.q_sample(z, ts)
-     torch.cuda.empty_cache()
-     output, _, _ = sampler2.sample(
-         n_steps,
-         bs,
-         (4, IMG_H//8, IMG_W//8),
-         cond,
-         x_T=start_code,
-         verbose=False,
-         eta=0.0,
-         unconditional_conditioning=uc_full,
-     )
-
-     output = model2.decode_first_stage(output)
-     output = tensor2img(output)
-     pil_output = Image.fromarray(output)
-     return pil_output
-
- @spaces.GPU
- @torch.no_grad()
- def process_hd(vton_img, garm_img, n_steps, is_custom):
      model_type = 'hd'
      category = 0  # 0:upperbody; 1:lowerbody; 2:dress

-     stt = time.time()
-     print('load images... ', end='')
-     # garm_img = Image.open(garm_img).resize((IMG_W, IMG_H))
-     # vton_img = Image.open(vton_img).resize((IMG_W, IMG_H))
-     garm_img = Image.open(garm_img)
-     vton_img = Image.open(vton_img)
-
-     vton_img = center_crop(vton_img)
-     garm_img = garm_img.resize((IMG_W, IMG_H))
-     vton_img = vton_img.resize((IMG_W, IMG_H))
-
-     print('%.2fs' % (time.time() - stt))
-
-     stt = time.time()
-     print('get agnostic map... ', end='')
-     keypoints = openpose_model_hd(vton_img.resize((IMG_W, IMG_H)))
-     model_parse, _ = parsing_model_hd(vton_img.resize((IMG_W, IMG_H)))
-     mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints, radius=5)
-     mask = mask.resize((IMG_W, IMG_H), Image.NEAREST)
-     mask_gray = mask_gray.resize((IMG_W, IMG_H), Image.NEAREST)
-     masked_vton_img = Image.composite(mask_gray, vton_img, mask)  # agnostic map
-     print('%.2fs' % (time.time() - stt))
-
-     stt = time.time()
-     print('get densepose... ', end='')
-     vton_img = vton_img.resize((IMG_W, IMG_H))  # size for densepose
-     densepose = densepose_model_hd.execute(vton_img)  # densepose
-     print('%.2fs' % (time.time() - stt))
-
-     batch = get_batch(
-         vton_img,
-         garm_img,
-         densepose,
-         masked_vton_img,
-         mask,
-         IMG_H,
-         IMG_W
-     )
-
-     if is_custom:
-         sample = stable_viton_model_hd(
-             batch,
-             n_steps,
-         )
-     else:
-         sample = stable_viton_model_hd2(
-             batch,
-             n_steps,
-         )
-     return sample
-
-
- example_path = opj(os.path.dirname(__file__), 'examples_eternal')
- example_model_ps = sorted(glob(opj(example_path, "model/*")))
- example_garment_ps = sorted(glob(opj(example_path, "garment/*")))

  with gr.Blocks(css='style.css') as demo:
      gr.HTML(
          """
          <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
          <div>
-         <h1>Rdy2Wr.AI StableVITON Demo 👕👔👗</h1>
          <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
          <a href='https://arxiv.org/abs/2312.01725'>
          <img src="https://img.shields.io/badge/arXiv-2312.01725-red">
@@ -232,27 +114,36 @@ with gr.Blocks(css='style.css') as demo:
      gr.Markdown("## Experience virtual try-on with your own images!")
      with gr.Row():
          with gr.Column():
-             vton_img = gr.Image(label="Model", type="filepath", height=384, value=example_model_ps[0])
              example = gr.Examples(
                  inputs=vton_img,
                  examples_per_page=14,
-                 examples=example_model_ps)
          with gr.Column():
-             garm_img = gr.Image(label="Garment", type="filepath", height=384, value=example_garment_ps[0])
              example = gr.Examples(
                  inputs=garm_img,
                  examples_per_page=14,
-                 examples=example_garment_ps)
          with gr.Column():
-             result_gallery = gr.Image(label='Output', show_label=False, scale=1)
-             # result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
          with gr.Column():
              run_button = gr.Button(value="Run")
-             n_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=20, step=1)
-             is_custom = gr.Checkbox(label="customized model")
-             # seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

-     ips = [vton_img, garm_img, n_steps, is_custom]
      run_button.click(fn=process_hd, inputs=ips, outputs=[result_gallery])

- demo.queue().launch()
  import os
  import sys
  import time
  from pathlib import Path

  import gradio as gr
  import torch
  from PIL import Image

+ from utils_stableviton import get_mask_location

  PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+ from preprocess.detectron2.projects.DensePose.apply_net_gradio import DensePose4Gradio
+ from preprocess.humanparsing.run_parsing import Parsing
+ from preprocess.openpose.run_openpose import OpenPose
+
+ os.environ['GRADIO_TEMP_DIR'] = './tmp'  # TODO: turn off when final upload


  openpose_model_hd = OpenPose(0)
  parsing_model_hd = Parsing(0)
  densepose_model_hd = DensePose4Gradio(
      cfg='preprocess/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml',
      model='https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl',
  )
+ stable_viton_model_hd = ...  # TODO: write down stable viton model

  category_dict = ['upperbody', 'lowerbody', 'dress']
  category_dict_utils = ['upper_body', 'lower_body', 'dresses']

+ # import spaces  # TODO: turn on when final upload

+ # @spaces.GPU  # TODO: turn on when final upload


+ def process_hd(vton_img, garm_img, n_samples, n_steps, guidance_scale, seed):
      model_type = 'hd'
      category = 0  # 0:upperbody; 1:lowerbody; 2:dress

+     with torch.no_grad():
+         openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
+
+         stt = time.time()
+         print('load images... ', end='')
+         garm_img = Image.open(garm_img).resize((768, 1024))
+         vton_img = Image.open(vton_img).resize((768, 1024))
+         print('%.2fs' % (time.time() - stt))
+
+         stt = time.time()
+         print('get agnostic map... ', end='')
+         keypoints = openpose_model_hd(vton_img.resize((384, 512)))
+         model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
+         mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
+         mask = mask.resize((768, 1024), Image.NEAREST)
+         mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
+         masked_vton_img = Image.composite(mask_gray, vton_img, mask)  # agnostic map
+         print('%.2fs' % (time.time() - stt))
+
+         stt = time.time()
+         print('get densepose... ', end='')
+         vton_img = vton_img.resize((768, 1024))  # size for densepose
+         densepose = densepose_model_hd.execute(vton_img)  # densepose
+         print('%.2fs' % (time.time() - stt))
+
+         # # stable viton here
+         # images = stable_viton_model_hd(
+         #     vton_img,
+         #     garm_img,
+         #     masked_vton_img,
+         #     densepose,
+         #     n_samples,
+         #     n_steps,
+         #     guidance_scale,
+         #     seed
+         # )
+
+         # return images
+
+
+ example_path = os.path.join(os.path.dirname(__file__), 'examples')
+ model_hd = os.path.join(example_path, 'model/model_1.png')
+ garment_hd = os.path.join(example_path, 'garment/00055_00.jpg')

  with gr.Blocks(css='style.css') as demo:
      gr.HTML(
          """
          <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
          <div>
+         <h1>StableVITON Demo 👕👔👗</h1>
          <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
          <a href='https://arxiv.org/abs/2312.01725'>
          <img src="https://img.shields.io/badge/arXiv-2312.01725-red">
      gr.Markdown("## Experience virtual try-on with your own images!")
      with gr.Row():
          with gr.Column():
+             vton_img = gr.Image(label="Model", type="filepath", height=384, value=model_hd)
              example = gr.Examples(
                  inputs=vton_img,
                  examples_per_page=14,
+                 examples=[
+                     os.path.join(example_path, 'model/model_1.png'),  # TODO more our models
+                     os.path.join(example_path, 'model/model_2.png'),
+                     os.path.join(example_path, 'model/model_3.png'),
+                 ])
          with gr.Column():
+             garm_img = gr.Image(label="Garment", type="filepath", height=384, value=garment_hd)
              example = gr.Examples(
                  inputs=garm_img,
                  examples_per_page=14,
+                 examples=[
+                     os.path.join(example_path, 'garment/00055_00.jpg'),
+                     os.path.join(example_path, 'garment/00126_00.jpg'),
+                     os.path.join(example_path, 'garment/00151_00.jpg'),
+                 ])
          with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
          with gr.Column():
              run_button = gr.Button(value="Run")
+             # TODO: change default values (important!)
+             n_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
+             n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
+             guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
+             seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)

+     ips = [vton_img, garm_img, n_samples, n_steps, guidance_scale, seed]
      run_button.click(fn=process_hd, inputs=ips, outputs=[result_gallery])

+ demo.launch()
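Note on the new app.py: process_hd now stops after the DensePose step and the try-on call stays commented out, so the Run button returns nothing until the `stable_viton_model_hd = ...` placeholder is filled in. A minimal sketch of one way to restore it, condensed from the code removed above (illustrative only; it assumes the deleted cldm/ package, configs/VITON.yaml, the utils_stableviton helpers, and one of the removed checkpoints are put back in the Space):

```python
# Sketch only, not part of this diff: condensed from the removed model-init and
# sampling code above. All names come from the deleted modules; the checkpoint
# path is one of the files this change also deletes, so it must be restored too.
import torch
from omegaconf import OmegaConf
from PIL import Image

from cldm.model import create_model
from cldm.plms_hacked import PLMSSampler
from utils_stableviton import get_batch, tensor2img

IMG_H, IMG_W = 1024, 768

config = OmegaConf.load("./configs/VITON.yaml")
config.model.params.img_H = IMG_H
config.model.params.img_W = IMG_W

model = create_model(config_path=None, config=config)
model.load_state_dict(torch.load("./checkpoints/VITONHD_1024.ckpt", map_location="cpu")["state_dict"])
model = model.cuda().eval()
sampler = PLMSSampler(model)

@torch.no_grad()
def stable_viton_model_hd(batch, n_steps):
    # Encode the batch, build unconditional guidance, noise the latent to t=999,
    # run PLMS sampling, then decode back to a PIL image.
    z, cond = model.get_input(batch, config.model.params.first_stage_key)
    bs = z.shape[0]
    uc_full = {"c_concat": cond["c_concat"],
               "c_crossattn": [model.get_unconditional_conditioning(bs)],
               "first_stage_cond": cond["first_stage_cond"]}
    sampler.model.batch = {k: v.cuda() if isinstance(v, torch.Tensor) else v
                           for k, v in batch.items()}
    ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
    start_code = model.q_sample(z, ts)
    output, _, _ = sampler.sample(n_steps, bs, (4, IMG_H // 8, IMG_W // 8), cond,
                                  x_T=start_code, verbose=False, eta=0.0,
                                  unconditional_conditioning=uc_full)
    return Image.fromarray(tensor2img(model.decode_first_stage(output)))
```

process_hd would then pack its preprocessed images with get_batch(vton_img, garm_img, densepose, masked_vton_img, mask, 1024, 768), as the removed version did, call the function above, and return the result (a one-element list works for the gr.Gallery output).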
apply_net.py DELETED
@@ -1,359 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright (c) Facebook, Inc. and its affiliates.
3
-
4
- import argparse
5
- import glob
6
- import logging
7
- import os
8
- import sys
9
- from typing import Any, ClassVar, Dict, List
10
- import torch
11
-
12
- from detectron2.config import CfgNode, get_cfg
13
- from detectron2.data.detection_utils import read_image
14
- from detectron2.engine.defaults import DefaultPredictor
15
- from detectron2.structures.instances import Instances
16
- from detectron2.utils.logger import setup_logger
17
-
18
- from densepose import add_densepose_config
19
- from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
20
- from densepose.utils.logger import verbosity_to_level
21
- from densepose.vis.base import CompoundVisualizer
22
- from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
23
- from densepose.vis.densepose_outputs_vertex import (
24
- DensePoseOutputsTextureVisualizer,
25
- DensePoseOutputsVertexVisualizer,
26
- get_texture_atlases,
27
- )
28
- from densepose.vis.densepose_results import (
29
- DensePoseResultsContourVisualizer,
30
- DensePoseResultsFineSegmentationVisualizer,
31
- DensePoseResultsUVisualizer,
32
- DensePoseResultsVVisualizer,
33
- )
34
- from densepose.vis.densepose_results_textures import (
35
- DensePoseResultsVisualizerWithTexture,
36
- get_texture_atlas,
37
- )
38
- from densepose.vis.extractor import (
39
- CompoundExtractor,
40
- DensePoseOutputsExtractor,
41
- DensePoseResultExtractor,
42
- create_extractor,
43
- )
44
-
45
- DOC = """Apply Net - a tool to print / visualize DensePose results
46
- """
47
-
48
- LOGGER_NAME = "apply_net"
49
- logger = logging.getLogger(LOGGER_NAME)
50
-
51
- _ACTION_REGISTRY: Dict[str, "Action"] = {}
52
-
53
-
54
- class Action:
55
- @classmethod
56
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
57
- parser.add_argument(
58
- "-v",
59
- "--verbosity",
60
- action="count",
61
- help="Verbose mode. Multiple -v options increase the verbosity.",
62
- )
63
-
64
-
65
- def register_action(cls: type):
66
- """
67
- Decorator for action classes to automate action registration
68
- """
69
- global _ACTION_REGISTRY
70
- _ACTION_REGISTRY[cls.COMMAND] = cls
71
- return cls
72
-
73
-
74
- class InferenceAction(Action):
75
- @classmethod
76
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
77
- super(InferenceAction, cls).add_arguments(parser)
78
- parser.add_argument("cfg", metavar="<config>", help="Config file")
79
- parser.add_argument("model", metavar="<model>", help="Model file")
80
- parser.add_argument(
81
- "--opts",
82
- help="Modify config options using the command-line 'KEY VALUE' pairs",
83
- default=[],
84
- nargs=argparse.REMAINDER,
85
- )
86
-
87
- @classmethod
88
- def execute(cls: type, args: argparse.Namespace, human_img):
89
- logger.info(f"Loading config from {args.cfg}")
90
- opts = []
91
- cfg = cls.setup_config(args.cfg, args.model, args, opts)
92
- logger.info(f"Loading model from {args.model}")
93
- predictor = DefaultPredictor(cfg)
94
- # logger.info(f"Loading data from {args.input}")
95
- # file_list = cls._get_input_file_list(args.input)
96
- # if len(file_list) == 0:
97
- # logger.warning(f"No input images for {args.input}")
98
- # return
99
- context = cls.create_context(args, cfg)
100
- # for file_name in file_list:
101
- # img = read_image(file_name, format="BGR") # predictor expects BGR image.
102
- with torch.no_grad():
103
- outputs = predictor(human_img)["instances"]
104
- out_pose = cls.execute_on_outputs(context, {"image": human_img}, outputs)
105
- cls.postexecute(context)
106
- return out_pose
107
-
108
- @classmethod
109
- def setup_config(
110
- cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
111
- ):
112
- cfg = get_cfg()
113
- add_densepose_config(cfg)
114
- cfg.merge_from_file(config_fpath)
115
- cfg.merge_from_list(args.opts)
116
- if opts:
117
- cfg.merge_from_list(opts)
118
- cfg.MODEL.WEIGHTS = model_fpath
119
- cfg.freeze()
120
- return cfg
121
-
122
- @classmethod
123
- def _get_input_file_list(cls: type, input_spec: str):
124
- if os.path.isdir(input_spec):
125
- file_list = [
126
- os.path.join(input_spec, fname)
127
- for fname in os.listdir(input_spec)
128
- if os.path.isfile(os.path.join(input_spec, fname))
129
- ]
130
- elif os.path.isfile(input_spec):
131
- file_list = [input_spec]
132
- else:
133
- file_list = glob.glob(input_spec)
134
- return file_list
135
-
136
-
137
- @register_action
138
- class DumpAction(InferenceAction):
139
- """
140
- Dump action that outputs results to a pickle file
141
- """
142
-
143
- COMMAND: ClassVar[str] = "dump"
144
-
145
- @classmethod
146
- def add_parser(cls: type, subparsers: argparse._SubParsersAction):
147
- parser = subparsers.add_parser(cls.COMMAND, help="Dump model outputs to a file.")
148
- cls.add_arguments(parser)
149
- parser.set_defaults(func=cls.execute)
150
-
151
- @classmethod
152
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
153
- super(DumpAction, cls).add_arguments(parser)
154
- parser.add_argument(
155
- "--output",
156
- metavar="<dump_file>",
157
- default="results.pkl",
158
- help="File name to save dump to",
159
- )
160
-
161
- @classmethod
162
- def execute_on_outputs(
163
- cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
164
- ):
165
- image_fpath = entry["file_name"]
166
- logger.info(f"Processing {image_fpath}")
167
- result = {"file_name": image_fpath}
168
- if outputs.has("scores"):
169
- result["scores"] = outputs.get("scores").cpu()
170
- if outputs.has("pred_boxes"):
171
- result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
172
- if outputs.has("pred_densepose"):
173
- if isinstance(outputs.pred_densepose, DensePoseChartPredictorOutput):
174
- extractor = DensePoseResultExtractor()
175
- elif isinstance(outputs.pred_densepose, DensePoseEmbeddingPredictorOutput):
176
- extractor = DensePoseOutputsExtractor()
177
- result["pred_densepose"] = extractor(outputs)[0]
178
- context["results"].append(result)
179
-
180
- @classmethod
181
- def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode):
182
- context = {"results": [], "out_fname": args.output}
183
- return context
184
-
185
- @classmethod
186
- def postexecute(cls: type, context: Dict[str, Any]):
187
- out_fname = context["out_fname"]
188
- out_dir = os.path.dirname(out_fname)
189
- if len(out_dir) > 0 and not os.path.exists(out_dir):
190
- os.makedirs(out_dir)
191
- with open(out_fname, "wb") as hFile:
192
- torch.save(context["results"], hFile)
193
- logger.info(f"Output saved to {out_fname}")
194
-
195
-
196
- @register_action
197
- class ShowAction(InferenceAction):
198
- """
199
- Show action that visualizes selected entries on an image
200
- """
201
-
202
- COMMAND: ClassVar[str] = "show"
203
- VISUALIZERS: ClassVar[Dict[str, object]] = {
204
- "dp_contour": DensePoseResultsContourVisualizer,
205
- "dp_segm": DensePoseResultsFineSegmentationVisualizer,
206
- "dp_u": DensePoseResultsUVisualizer,
207
- "dp_v": DensePoseResultsVVisualizer,
208
- "dp_iuv_texture": DensePoseResultsVisualizerWithTexture,
209
- "dp_cse_texture": DensePoseOutputsTextureVisualizer,
210
- "dp_vertex": DensePoseOutputsVertexVisualizer,
211
- "bbox": ScoredBoundingBoxVisualizer,
212
- }
213
-
214
- @classmethod
215
- def add_parser(cls: type, subparsers: argparse._SubParsersAction):
216
- parser = subparsers.add_parser(cls.COMMAND, help="Visualize selected entries")
217
- cls.add_arguments(parser)
218
- parser.set_defaults(func=cls.execute)
219
-
220
- @classmethod
221
- def add_arguments(cls: type, parser: argparse.ArgumentParser):
222
- super(ShowAction, cls).add_arguments(parser)
223
- parser.add_argument(
224
- "visualizations",
225
- metavar="<visualizations>",
226
- help="Comma separated list of visualizations, possible values: "
227
- "[{}]".format(",".join(sorted(cls.VISUALIZERS.keys()))),
228
- )
229
- parser.add_argument(
230
- "--min_score",
231
- metavar="<score>",
232
- default=0.8,
233
- type=float,
234
- help="Minimum detection score to visualize",
235
- )
236
- parser.add_argument(
237
- "--nms_thresh", metavar="<threshold>", default=None, type=float, help="NMS threshold"
238
- )
239
- parser.add_argument(
240
- "--texture_atlas",
241
- metavar="<texture_atlas>",
242
- default=None,
243
- help="Texture atlas file (for IUV texture transfer)",
244
- )
245
- parser.add_argument(
246
- "--texture_atlases_map",
247
- metavar="<texture_atlases_map>",
248
- default=None,
249
- help="JSON string of a dict containing texture atlas files for each mesh",
250
- )
251
- parser.add_argument(
252
- "--output",
253
- metavar="<image_file>",
254
- default="outputres.png",
255
- help="File name to save output to",
256
- )
257
-
258
- @classmethod
259
- def setup_config(
260
- cls: type, config_fpath: str, model_fpath: str, args: argparse.Namespace, opts: List[str]
261
- ):
262
- opts.append("MODEL.ROI_HEADS.SCORE_THRESH_TEST")
263
- opts.append(str(args.min_score))
264
- if args.nms_thresh is not None:
265
- opts.append("MODEL.ROI_HEADS.NMS_THRESH_TEST")
266
- opts.append(str(args.nms_thresh))
267
- cfg = super(ShowAction, cls).setup_config(config_fpath, model_fpath, args, opts)
268
- return cfg
269
-
270
- @classmethod
271
- def execute_on_outputs(
272
- cls: type, context: Dict[str, Any], entry: Dict[str, Any], outputs: Instances
273
- ):
274
- import cv2
275
- import numpy as np
276
- visualizer = context["visualizer"]
277
- extractor = context["extractor"]
278
- # image_fpath = entry["file_name"]
279
- # logger.info(f"Processing {image_fpath}")
280
- image = cv2.cvtColor(entry["image"], cv2.COLOR_BGR2GRAY)
281
- image = np.tile(image[:, :, np.newaxis], [1, 1, 3])
282
- data = extractor(outputs)
283
- image_vis = visualizer.visualize(image, data)
284
-
285
- return image_vis
286
- entry_idx = context["entry_idx"] + 1
287
- out_fname = './image-densepose/' + image_fpath.split('/')[-1]
288
- out_dir = './image-densepose'
289
- out_dir = os.path.dirname(out_fname)
290
- if len(out_dir) > 0 and not os.path.exists(out_dir):
291
- os.makedirs(out_dir)
292
- cv2.imwrite(out_fname, image_vis)
293
- logger.info(f"Output saved to {out_fname}")
294
- context["entry_idx"] += 1
295
-
296
- @classmethod
297
- def postexecute(cls: type, context: Dict[str, Any]):
298
- pass
299
- # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/DressCode/upper_body/images dp_segm -v --opts MODEL.DEVICE cpu
300
-
301
- @classmethod
302
- def _get_out_fname(cls: type, entry_idx: int, fname_base: str):
303
- base, ext = os.path.splitext(fname_base)
304
- return base + ".{0:04d}".format(entry_idx) + ext
305
-
306
- @classmethod
307
- def create_context(cls: type, args: argparse.Namespace, cfg: CfgNode) -> Dict[str, Any]:
308
- vis_specs = args.visualizations.split(",")
309
- visualizers = []
310
- extractors = []
311
- for vis_spec in vis_specs:
312
- texture_atlas = get_texture_atlas(args.texture_atlas)
313
- texture_atlases_dict = get_texture_atlases(args.texture_atlases_map)
314
- vis = cls.VISUALIZERS[vis_spec](
315
- cfg=cfg,
316
- texture_atlas=texture_atlas,
317
- texture_atlases_dict=texture_atlases_dict,
318
- )
319
- visualizers.append(vis)
320
- extractor = create_extractor(vis)
321
- extractors.append(extractor)
322
- visualizer = CompoundVisualizer(visualizers)
323
- extractor = CompoundExtractor(extractors)
324
- context = {
325
- "extractor": extractor,
326
- "visualizer": visualizer,
327
- "out_fname": args.output,
328
- "entry_idx": 0,
329
- }
330
- return context
331
-
332
-
333
- def create_argument_parser() -> argparse.ArgumentParser:
334
- parser = argparse.ArgumentParser(
335
- description=DOC,
336
- formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=120),
337
- )
338
- parser.set_defaults(func=lambda _: parser.print_help(sys.stdout))
339
- subparsers = parser.add_subparsers(title="Actions")
340
- for _, action in _ACTION_REGISTRY.items():
341
- action.add_parser(subparsers)
342
- return parser
343
-
344
-
345
- def main():
346
- parser = create_argument_parser()
347
- args = parser.parse_args()
348
- verbosity = getattr(args, "verbosity", None)
349
- global logger
350
- logger = setup_logger(name=LOGGER_NAME)
351
- logger.setLevel(verbosity_to_level(verbosity))
352
- args.func(args)
353
-
354
-
355
- if __name__ == "__main__":
356
- main()
357
-
358
-
359
- # python ./apply_net.py show ./configs/densepose_rcnn_R_50_FPN_s1x.yaml https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl /home/alin0222/Dresscode/dresses/humanonly dp_segm -v --opts MODEL.DEVICE cuda
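The command-line entry point deleted above is superseded by the Gradio-oriented wrapper that app.py imports. Its usage, taken from app.py, is roughly as follows (a sketch, assuming DensePose4Gradio keeps the constructor and execute() method shown there):

```python
# Sketch based on the usage in app.py: DensePose4Gradio wraps the same config and
# weights the deleted apply_net.py CLI took as its positional <config> and <model> args.
from PIL import Image
from preprocess.detectron2.projects.DensePose.apply_net_gradio import DensePose4Gradio

densepose = DensePose4Gradio(
    cfg='preprocess/detectron2/projects/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml',
    model='https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl',
)
person = Image.open('examples/model/04913_00.jpg').resize((768, 1024))  # any person image
iuv = densepose.execute(person)  # DensePose visualization used as conditioning downstream
```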
checkpoints/VITONHD.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:d4e44bc58b68f289cd7c1660e06a0db5f6fcb5786c037c9e8217eea45a75688f
- size 10198120487
checkpoints/VITONHD_1024.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b6f9c91db9b5813c5e3fe9397c43f475d2725f8c795cf0d54d64ebd0fcbe8463
- size 10198120551
checkpoints/eternal_1024.ckpt DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:9bff84cd05792726e3dfbd39561ea21bd0b1e25a59756da7daab279eb192b441
- size 10198120423
checkpoints/humanparsing/.gitkeep DELETED
File without changes
checkpoints/humanparsing/parsing_atr.onnx DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:04c7d1d070d0e0ae943d86b18cb5aaaea9e278d97462e9cfb270cbbe4cd977f4
- size 266859305
checkpoints/humanparsing/parsing_lip.onnx DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:8436e1dae96e2601c373d1ace29c8f0978b16357d9038c17a8ba756cca376dbc
- size 266863411
checkpoints/openpose/ckpts/body_pose_model.pth DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:25a948c16078b0f08e236bda51a385d855ef4c153598947c28c0d47ed94bb746
- size 209267595
cldm/cldm.py DELETED
@@ -1,138 +0,0 @@
1
- import os
2
- from os.path import join as opj
3
- import omegaconf
4
-
5
- import cv2
6
- import einops
7
- import torch
8
- import torch as th
9
- import torch.nn as nn
10
- import torchvision.transforms as T
11
- import torch.nn.functional as F
12
- import numpy as np
13
-
14
- from ldm.models.diffusion.ddpm import LatentDiffusion
15
- from ldm.util import instantiate_from_config
16
-
17
- class ControlLDM(LatentDiffusion):
18
- def __init__(
19
- self,
20
- control_stage_config,
21
- validation_config,
22
- control_key,
23
- only_mid_control,
24
- use_VAEdownsample=False,
25
- config_name="",
26
- control_scales=None,
27
- use_pbe_weight=False,
28
- u_cond_percent=0.0,
29
- img_H=512,
30
- img_W=384,
31
- always_learnable_param=False,
32
- *args,
33
- **kwargs
34
- ):
35
- self.device = torch.device("cuda")
36
- self.control_stage_config = control_stage_config
37
- self.use_pbe_weight = use_pbe_weight
38
- self.u_cond_percent = u_cond_percent
39
- self.img_H = img_H
40
- self.img_W = img_W
41
- self.config_name = config_name
42
- self.always_learnable_param = always_learnable_param
43
- super().__init__(*args, **kwargs)
44
- control_stage_config.params["use_VAEdownsample"] = use_VAEdownsample
45
- self.control_model = instantiate_from_config(control_stage_config)
46
- self.control_key = control_key
47
- self.only_mid_control = only_mid_control
48
- if control_scales is None:
49
- self.control_scales = [1.0] * 13
50
- else:
51
- self.control_scales = control_scales
52
- self.first_stage_key_cond = kwargs.get("first_stage_key_cond", None)
53
- self.valid_config = validation_config
54
- self.use_VAEDownsample = use_VAEdownsample
55
- @torch.no_grad()
56
- def get_input(self, batch, k, bs=None, *args, **kwargs):
57
- x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
58
- if isinstance(self.control_key, omegaconf.listconfig.ListConfig):
59
- control_lst = []
60
- for key in self.control_key:
61
- control = batch[key]
62
- if bs is not None:
63
- control = control[:bs]
64
- control = control.to(self.device)
65
- control = einops.rearrange(control, 'b h w c -> b c h w')
66
- control = control.to(memory_format=torch.contiguous_format)
67
- control_lst.append(control)
68
- control = control_lst
69
- else:
70
- control = batch[self.control_key]
71
- if bs is not None:
72
- control = control[:bs]
73
- control = control.to(self.device)
74
- control = einops.rearrange(control, 'b h w c -> b c h w')
75
- control = control.to(memory_format=torch.contiguous_format)
76
- control = [control]
77
- cond_dict = dict(c_crossattn=[c], c_concat=control)
78
- if self.first_stage_key_cond is not None:
79
- first_stage_cond = []
80
- for key in self.first_stage_key_cond:
81
- if not "mask" in key:
82
- cond, _ = super().get_input(batch, key, *args, **kwargs)
83
- else:
84
- cond, _ = super().get_input(batch, key, no_latent=True, *args, **kwargs)
85
- first_stage_cond.append(cond)
86
- first_stage_cond = torch.cat(first_stage_cond, dim=1)
87
- cond_dict["first_stage_cond"] = first_stage_cond
88
- return x, cond_dict
89
-
90
- def apply_model(self, x_noisy, t, cond, *args, **kwargs):
91
- assert isinstance(cond, dict)
92
-
93
- diffusion_model = self.model.diffusion_model
94
- cond_txt = torch.cat(cond["c_crossattn"], 1)
95
- if self.proj_out is not None:
96
- if cond_txt.shape[-1] == 1024:
97
- cond_txt = self.proj_out(cond_txt) # [BS x 1 x 768]
98
- if self.always_learnable_param:
99
- cond_txt = self.get_unconditional_conditioning(cond_txt.shape[0])
100
-
101
- if cond['c_concat'] is None:
102
- eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
103
- else:
104
- if "first_stage_cond" in cond:
105
- x_noisy = torch.cat([x_noisy, cond["first_stage_cond"]], dim=1)
106
- if not self.use_VAEDownsample:
107
- hint = cond["c_concat"]
108
- else:
109
- hint = []
110
- for h in cond["c_concat"]:
111
- if h.shape[2] == self.img_H and h.shape[3] == self.img_W:
112
- h = self.encode_first_stage(h)
113
- h = self.get_first_stage_encoding(h).detach()
114
- hint.append(h)
115
- hint = torch.cat(hint, dim=1)
116
- control, _ = self.control_model(x=x_noisy, hint=hint, timesteps=t, context=cond_txt, only_mid_control=self.only_mid_control)
117
- if len(control) == len(self.control_scales):
118
- control = [c * scale for c, scale in zip(control, self.control_scales)]
119
-
120
- eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
121
- return eps, None
122
- @torch.no_grad()
123
- def get_unconditional_conditioning(self, N):
124
- if not self.kwargs["use_imageCLIP"]:
125
- return self.get_learned_conditioning([""] * N)
126
- else:
127
- return self.learnable_vector.repeat(N,1,1)
128
- def low_vram_shift(self, is_diffusing):
129
- if is_diffusing:
130
- self.model = self.model.cuda()
131
- self.control_model = self.control_model.cuda()
132
- self.first_stage_model = self.first_stage_model.cpu()
133
- self.cond_stage_model = self.cond_stage_model.cpu()
134
- else:
135
- self.model = self.model.cpu()
136
- self.control_model = self.control_model.cpu()
137
- self.first_stage_model = self.first_stage_model.cuda()
138
- self.cond_stage_model = self.cond_stage_model.cuda()
cldm/hack.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- import einops
3
-
4
- import ldm.modules.encoders.modules
5
- import ldm.modules.attention
6
-
7
- from transformers import logging
8
- from ldm.modules.attention import default
9
-
10
-
11
- def disable_verbosity():
12
- logging.set_verbosity_error()
13
- print('logging improved.')
14
- return
15
-
16
-
17
- def enable_sliced_attention():
18
- ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19
- print('Enabled sliced_attention.')
20
- return
21
-
22
-
23
- def hack_everything(clip_skip=0):
24
- disable_verbosity()
25
- ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
- ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
- print('Enabled clip hacks.')
28
- return
29
-
30
-
31
- # Written by Lvmin
32
- def _hacked_clip_forward(self, text):
33
- PAD = self.tokenizer.pad_token_id
34
- EOS = self.tokenizer.eos_token_id
35
- BOS = self.tokenizer.bos_token_id
36
-
37
- def tokenize(t):
38
- return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
-
40
- def transformer_encode(t):
41
- if self.clip_skip > 1:
42
- rt = self.transformer(input_ids=t, output_hidden_states=True)
43
- return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
- else:
45
- return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
-
47
- def split(x):
48
- return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
-
50
- def pad(x, p, i):
51
- return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
-
53
- raw_tokens_list = tokenize(text)
54
- tokens_list = []
55
-
56
- for raw_tokens in raw_tokens_list:
57
- raw_tokens_123 = split(raw_tokens)
58
- raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
- raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
- tokens_list.append(raw_tokens_123)
61
-
62
- tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
-
64
- feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
- y = transformer_encode(feed)
66
- z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
-
68
- return z
69
-
70
-
71
- # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
- def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73
- h = self.heads
74
-
75
- q = self.to_q(x)
76
- context = default(context, x)
77
- k = self.to_k(context)
78
- v = self.to_v(context)
79
- del context, x
80
-
81
- q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
-
83
- limit = k.shape[0]
84
- att_step = 1
85
- q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
- k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
- v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
-
89
- q_chunks.reverse()
90
- k_chunks.reverse()
91
- v_chunks.reverse()
92
- sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
- del k, q, v
94
- for i in range(0, limit, att_step):
95
- q_buffer = q_chunks.pop()
96
- k_buffer = k_chunks.pop()
97
- v_buffer = v_chunks.pop()
98
- sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
-
100
- del k_buffer, q_buffer
101
- # attention, what we cannot get enough of, by chunks
102
-
103
- sim_buffer = sim_buffer.softmax(dim=-1)
104
-
105
- sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
- del v_buffer
107
- sim[i:i + att_step, :, :] = sim_buffer
108
-
109
- del sim_buffer
110
- sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
- return self.to_out(sim)
cldm/model.py DELETED
@@ -1,9 +0,0 @@
- from ldm.util import instantiate_from_config
-
-
- def get_state_dict(d):
-     return d.get('state_dict', d)
-
- def create_model(config, **kwargs):
-     model = instantiate_from_config(config.model).cpu()
-     return model
cldm/plms_hacked.py DELETED
@@ -1,251 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
- from ldm.models.diffusion.sampling_util import norm_thresholding
9
-
10
-
11
- class PLMSSampler(object):
12
- def __init__(self, model, schedule="linear", **kwargs):
13
- super().__init__()
14
- self.model = model
15
- self.ddpm_num_timesteps = model.num_timesteps
16
- self.schedule = schedule
17
-
18
- def register_buffer(self, name, attr):
19
- if type(attr) == torch.Tensor:
20
- if attr.device != torch.device("cuda"):
21
- attr = attr.to(torch.device("cuda"))
22
- setattr(self, name, attr)
23
-
24
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
25
- if ddim_eta != 0:
26
- raise ValueError('ddim_eta must be 0 for PLMS')
27
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
28
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
29
- alphas_cumprod = self.model.alphas_cumprod
30
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
31
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
32
-
33
- self.register_buffer('betas', to_torch(self.model.betas))
34
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
36
-
37
- # calculations for diffusion q(x_t | x_{t-1}) and others
38
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
39
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
40
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
41
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
42
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
43
-
44
- # ddim sampling parameters
45
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
46
- ddim_timesteps=self.ddim_timesteps,
47
- eta=ddim_eta,verbose=verbose)
48
- self.register_buffer('ddim_sigmas', ddim_sigmas)
49
- self.register_buffer('ddim_alphas', ddim_alphas)
50
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
51
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
52
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
53
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
54
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
55
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
56
-
57
- @torch.no_grad()
58
- def sample(self,
59
- S,
60
- batch_size,
61
- shape,
62
- conditioning=None,
63
- callback=None,
64
- img_callback=None,
65
- quantize_x0=False,
66
- eta=0.,
67
- mask=None,
68
- x0=None,
69
- temperature=1.,
70
- noise_dropout=0.,
71
- score_corrector=None,
72
- corrector_kwargs=None,
73
- verbose=True,
74
- x_T=None,
75
- log_every_t=100,
76
- unconditional_guidance_scale=5.,
77
- unconditional_conditioning=None,
78
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
- dynamic_threshold=None,
80
- **kwargs
81
- ):
82
- if conditioning is not None:
83
- if isinstance(conditioning, dict):
84
- ctmp = conditioning[list(conditioning.keys())[0]]
85
- while isinstance(ctmp, list): ctmp = ctmp[0]
86
- cbs = ctmp.shape[0]
87
- if cbs != batch_size:
88
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
- else:
90
- if conditioning.shape[0] != batch_size:
91
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
-
93
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
- # sampling
95
- C, H, W = shape
96
- size = (batch_size, C, H, W)
97
- print(f'Data shape for PLMS sampling is {size}')
98
-
99
- samples, intermediates, cond_output_dict = self.plms_sampling(conditioning, size,
100
- callback=callback,
101
- img_callback=img_callback,
102
- quantize_denoised=quantize_x0,
103
- mask=mask, x0=x0,
104
- ddim_use_original_steps=False,
105
- noise_dropout=noise_dropout,
106
- temperature=temperature,
107
- score_corrector=score_corrector,
108
- corrector_kwargs=corrector_kwargs,
109
- x_T=x_T,
110
- log_every_t=log_every_t,
111
- unconditional_guidance_scale=unconditional_guidance_scale,
112
- unconditional_conditioning=unconditional_conditioning,
113
- dynamic_threshold=dynamic_threshold,
114
- )
115
- return samples, intermediates, cond_output_dict
116
-
117
- @torch.no_grad()
118
- def plms_sampling(self, cond, shape,
119
- x_T=None, ddim_use_original_steps=False,
120
- callback=None, timesteps=None, quantize_denoised=False,
121
- mask=None, x0=None, img_callback=None, log_every_t=100,
122
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
- unconditional_guidance_scale=1., unconditional_conditioning=None,
124
- dynamic_threshold=None):
125
- device = self.model.betas.device
126
- b = shape[0]
127
- if x_T is None:
128
- img = torch.randn(shape, device=device)
129
- else:
130
- img = x_T
131
-
132
- if timesteps is None:
133
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134
- elif timesteps is not None and not ddim_use_original_steps:
135
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136
- timesteps = self.ddim_timesteps[:subset_end]
137
-
138
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
139
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141
- print(f"Running PLMS Sampling with {total_steps} timesteps")
142
-
143
- iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144
- old_eps = []
145
-
146
- for i, step in enumerate(iterator):
147
- index = total_steps - i - 1
148
- ts = torch.full((b,), step, device=device, dtype=torch.long)
149
- ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150
-
151
- if mask is not None:
152
- assert x0 is not None
153
- if i < self.first_n_repaint:
154
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
155
- img = img_orig * mask + (1. - mask) * img
156
-
157
- outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
158
- quantize_denoised=quantize_denoised, temperature=temperature,
159
- noise_dropout=noise_dropout, score_corrector=score_corrector,
160
- corrector_kwargs=corrector_kwargs,
161
- unconditional_guidance_scale=unconditional_guidance_scale,
162
- unconditional_conditioning=unconditional_conditioning,
163
- old_eps=old_eps, t_next=ts_next,
164
- dynamic_threshold=dynamic_threshold)
165
- img, pred_x0, e_t = outs
166
- old_eps.append(e_t)
167
- if len(old_eps) >= 4:
168
- old_eps.pop(0)
169
- if callback: callback(i)
170
- if img_callback: img_callback(pred_x0, i)
171
-
172
- if index % log_every_t == 0 or index == total_steps - 1:
173
- intermediates['x_inter'].append(img)
174
- intermediates['pred_x0'].append(pred_x0)
175
- return img, intermediates, None
176
- def undo(self, x_t, t):
177
- beta = extract_into_tensor(self.betas, t, x_t.shape)
178
- x_t_forward = torch.sqrt(1 - beta) * x_t + torch.sqrt(beta) * torch.randn_like(x_t)
179
- return x_t_forward
180
-
181
- @torch.no_grad()
182
- def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
183
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
184
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
185
- dynamic_threshold=None):
186
- b, *_, device = *x.shape, x.device
187
-
188
- def get_model_output(x, t):
189
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
190
- e_t, _ = self.model.apply_model(x, t, c)
191
- else:
192
- model_t, _ = self.model.apply_model(x,t,c)
193
- model_uncond, _ = self.model.apply_model(x,t,unconditional_conditioning)
194
-
195
- if isinstance(model_t, tuple):
196
- model_t, _ = model_t
197
- if isinstance(model_uncond, tuple):
198
- model_uncond, _ = model_uncond
199
- e_t = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
200
-
201
- if score_corrector is not None:
202
- assert self.model.parameterization == "eps"
203
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
204
-
205
- return e_t
206
-
207
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
208
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
209
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
210
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
211
-
212
- def get_x_prev_and_pred_x0(e_t, index):
213
- # select parameters corresponding to the currently considered timestep
214
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
215
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
216
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
217
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
218
-
219
- # current prediction for x_0
220
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
221
- if quantize_denoised:
222
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
223
- if dynamic_threshold is not None:
224
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
225
- # direction pointing to x_t
226
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
227
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
228
- if noise_dropout > 0.:
229
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
230
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
231
- return x_prev, pred_x0
232
-
233
- e_t = get_model_output(x, t)
234
- if len(old_eps) == 0:
235
- # Pseudo Improved Euler (2nd order)
236
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
237
- e_t_next = get_model_output(x_prev, t_next)
238
- e_t_prime = (e_t + e_t_next) / 2
239
- elif len(old_eps) == 1:
240
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
241
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
242
- elif len(old_eps) == 2:
243
- # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
244
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
245
- elif len(old_eps) >= 3:
246
- # 4th order Pseudo Linear Multistep (Adams-Bashforth)
247
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
248
-
249
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
250
-
251
- return x_prev, pred_x0, e_t
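For reference, the branches above combine the current and stored noise predictions with the standard pseudo linear multistep (Adams-Bashforth) weightings before each update; written out:

```latex
% e_t' as formed by p_sample_plms in the deleted sampler above
\begin{aligned}
\text{first step (Heun-style):}\quad & e_t' = \tfrac{1}{2}\bigl(e_t + e_{t_{\text{next}}}\bigr)\\
\text{2nd order:}\quad & e_t' = \tfrac{1}{2}\bigl(3\,e_t - e_{t-1}\bigr)\\
\text{3rd order:}\quad & e_t' = \tfrac{1}{12}\bigl(23\,e_t - 16\,e_{t-1} + 5\,e_{t-2}\bigr)\\
\text{4th order:}\quad & e_t' = \tfrac{1}{24}\bigl(55\,e_t - 59\,e_{t-1} + 37\,e_{t-2} - 9\,e_{t-3}\bigr)
\end{aligned}
```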
cldm/warping_cldm_network.py DELETED
@@ -1,357 +0,0 @@
1
- import torch
2
- import torch as th
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from ldm.modules.diffusionmodules.util import (
7
- conv_nd,
8
- linear,
9
- zero_module,
10
- timestep_embedding
11
- )
12
-
13
- from einops import rearrange
14
- from ldm.modules.attention import SpatialTransformer
15
- from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
16
- from ldm.util import exists
17
-
18
- class StableVITON(UNetModel):
19
- def __init__(
20
- self,
21
- dim_head_denorm=1,
22
- *args,
23
- **kwargs,
24
- ):
25
- super().__init__(*args, **kwargs)
26
- warp_flow_blks = []
27
- warp_zero_convs = []
28
-
29
- self.encode_output_chs = [
30
- 320,
31
- 320,
32
- 640,
33
- 640,
34
- 640,
35
- 1280,
36
- 1280,
37
- 1280,
38
- 1280
39
- ]
40
-
41
- self.encode_output_chs2 = [
42
- 320,
43
- 320,
44
- 320,
45
- 320,
46
- 640,
47
- 640,
48
- 640,
49
- 1280,
50
- 1280
51
- ]
52
-
53
-
54
- for in_ch, cont_ch in zip(self.encode_output_chs, self.encode_output_chs2):
55
- dim_head = in_ch // self.num_heads
56
- dim_head = dim_head // dim_head_denorm
57
- warp_flow_blks.append(SpatialTransformer(
58
- in_channels=in_ch,
59
- n_heads=self.num_heads,
60
- d_head=dim_head,
61
- depth=self.transformer_depth,
62
- context_dim=cont_ch,
63
- use_linear=self.use_linear_in_transformer,
64
- use_checkpoint=self.use_checkpoint,
65
- ))
66
- warp_zero_convs.append(self.make_zero_conv(in_ch))
67
- self.warp_flow_blks = nn.ModuleList(reversed(warp_flow_blks))
68
- self.warp_zero_convs = nn.ModuleList(reversed(warp_zero_convs))
69
- def make_zero_conv(self, channels):
70
- return zero_module(conv_nd(2, channels, channels, 1, padding=0))
71
- def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
72
- hs = []
73
-
74
- with torch.no_grad():
75
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
76
- emb = self.time_embed(t_emb)
77
- h = x.type(self.dtype)
78
- for module in self.input_blocks:
79
- h = module(h, emb, context)
80
- hs.append(h)
81
- h = self.middle_block(h, emb, context)
82
-
83
- if control is not None:
84
- hint = control.pop()
85
-
86
- for module in self.output_blocks[:3]:
87
- control.pop()
88
- h = torch.cat([h, hs.pop()], dim=1)
89
- h = module(h, emb, context)
90
-
91
- n_warp = len(self.encode_output_chs)
92
- for i, (module, warp_blk, warp_zc) in enumerate(zip(self.output_blocks[3:n_warp+3], self.warp_flow_blks, self.warp_zero_convs)):
93
- if control is None or (h.shape[-2] == 8 and h.shape[-1] == 6):
94
- assert 0, f"shape is wrong : {h.shape}"
95
- else:
96
- hint = control.pop()
97
- h = self.warp(h, hint, warp_blk, warp_zc)
98
- h = torch.cat([h, hs.pop()], dim=1)
99
- h = module(h, emb, context)
100
- for module in self.output_blocks[n_warp+3:]:
101
- if control is None:
102
- h = torch.cat([h, hs.pop()], dim=1)
103
- else:
104
- h = torch.cat([h, hs.pop()], dim=1)
105
- h = module(h, emb, context)
106
- h = h.type(x.dtype)
107
- return self.out(h)
108
- def warp(self, x, hint, crossattn_layer, zero_conv, mask1=None, mask2=None):
109
- hint = rearrange(hint, "b c h w -> b (h w) c").contiguous()
110
- output = crossattn_layer(x, hint)
111
- output = zero_conv(output)
112
- return output + x
113
- class NoZeroConvControlNet(nn.Module):
114
- def __init__(
115
- self,
116
- image_size,
117
- in_channels,
118
- model_channels,
119
- hint_channels,
120
- num_res_blocks,
121
- attention_resolutions,
122
- dropout=0,
123
- channel_mult=(1, 2, 4, 8),
124
- conv_resample=True,
125
- dims=2,
126
- use_checkpoint=False,
127
- use_fp16=False,
128
- num_heads=-1,
129
- num_head_channels=-1,
130
- num_heads_upsample=-1,
131
- use_scale_shift_norm=False,
132
- resblock_updown=False,
133
- use_new_attention_order=False,
134
- use_spatial_transformer=False, # custom transformer support
135
- transformer_depth=1, # custom transformer support
136
- context_dim=None, # custom transformer support
137
- n_embed=None,
138
- legacy=True,
139
- disable_self_attentions=None,
140
- num_attention_blocks=None,
141
- disable_middle_self_attn=False,
142
- use_linear_in_transformer=False,
143
- use_VAEdownsample=False,
144
- cond_first_ch=8,
145
- ):
146
- super().__init__()
147
- if use_spatial_transformer:
148
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
149
-
150
- if context_dim is not None:
151
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
152
- from omegaconf.listconfig import ListConfig
153
- if type(context_dim) == ListConfig:
154
- context_dim = list(context_dim)
155
-
156
- if num_heads_upsample == -1:
157
- num_heads_upsample = num_heads
158
-
159
- if num_heads == -1:
160
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
161
-
162
- if num_head_channels == -1:
163
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
164
-
165
- self.dims = dims
166
- self.image_size = image_size
167
- self.in_channels = in_channels
168
- self.model_channels = model_channels
169
- if isinstance(num_res_blocks, int):
170
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
171
- else:
172
- if len(num_res_blocks) != len(channel_mult):
173
- raise ValueError("provide num_res_blocks either as an int (globally constant) or "
174
- "as a list/tuple (per-level) with the same length as channel_mult")
175
- self.num_res_blocks = num_res_blocks
176
- if disable_self_attentions is not None:
177
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
178
- assert len(disable_self_attentions) == len(channel_mult)
179
- if num_attention_blocks is not None:
180
- assert len(num_attention_blocks) == len(self.num_res_blocks)
181
- assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
182
- print(f"Constructor of UNetModel received um_attention_blocks={num_attention_blocks}. "
183
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
184
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
185
- f"attention will still not be set.")
186
-
187
- self.attention_resolutions = attention_resolutions
188
- self.dropout = dropout
189
- self.channel_mult = channel_mult
190
- self.conv_resample = conv_resample
191
- self.use_checkpoint = use_checkpoint
192
- self.dtype = th.float16 if use_fp16 else th.float32
193
- self.num_heads = num_heads
194
- self.num_head_channels = num_head_channels
195
- self.num_heads_upsample = num_heads_upsample
196
- self.predict_codebook_ids = n_embed is not None
197
- self.use_VAEdownsample = use_VAEdownsample
198
-
199
- time_embed_dim = model_channels * 4
200
- self.time_embed = nn.Sequential(
201
- linear(model_channels, time_embed_dim),
202
- nn.SiLU(),
203
- linear(time_embed_dim, time_embed_dim),
204
- )
205
-
206
- self.input_blocks = nn.ModuleList(
207
- [
208
- TimestepEmbedSequential(
209
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
210
- )
211
- ]
212
- )
213
-
214
- self.cond_first_block = TimestepEmbedSequential(
215
- zero_module(conv_nd(dims, cond_first_ch, model_channels, 3, padding=1))
216
- )
217
-
218
-
219
- self._feature_size = model_channels
220
- input_block_chans = [model_channels]
221
- ch = model_channels
222
- ds = 1
223
- for level, mult in enumerate(channel_mult):
224
- for nr in range(self.num_res_blocks[level]):
225
- layers = [
226
- ResBlock(
227
- ch,
228
- time_embed_dim,
229
- dropout,
230
- out_channels=mult * model_channels,
231
- dims=dims,
232
- use_checkpoint=use_checkpoint,
233
- use_scale_shift_norm=use_scale_shift_norm,
234
- )
235
- ]
236
- ch = mult * model_channels
237
- if ds in attention_resolutions:
238
- if num_head_channels == -1:
239
- dim_head = ch // num_heads
240
- else:
241
- num_heads = ch // num_head_channels
242
- dim_head = num_head_channels
243
- if legacy:
244
- # num_heads = 1
245
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
246
- if exists(disable_self_attentions):
247
- disabled_sa = disable_self_attentions[level]
248
- else:
249
- disabled_sa = False
250
-
251
- if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
252
- layers.append(
253
- AttentionBlock(
254
- ch,
255
- use_checkpoint=use_checkpoint,
256
- num_heads=num_heads,
257
- num_head_channels=dim_head,
258
- use_new_attention_order=use_new_attention_order,
259
- ) if not use_spatial_transformer else SpatialTransformer(
260
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
261
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
262
- use_checkpoint=use_checkpoint
263
- )
264
- )
265
- self.input_blocks.append(TimestepEmbedSequential(*layers))
266
- self._feature_size += ch
267
- input_block_chans.append(ch)
268
- if level != len(channel_mult) - 1:
269
- out_ch = ch
270
- self.input_blocks.append(
271
- TimestepEmbedSequential(
272
- ResBlock(
273
- ch,
274
- time_embed_dim,
275
- dropout,
276
- out_channels=out_ch,
277
- dims=dims,
278
- use_checkpoint=use_checkpoint,
279
- use_scale_shift_norm=use_scale_shift_norm,
280
- down=True,
281
- )
282
- if resblock_updown
283
- else Downsample(
284
- ch, conv_resample, dims=dims, out_channels=out_ch
285
- )
286
- )
287
- )
288
- ch = out_ch
289
- input_block_chans.append(ch)
290
- ds *= 2
291
- self._feature_size += ch
292
-
293
- if num_head_channels == -1:
294
- dim_head = ch // num_heads
295
- else:
296
- num_heads = ch // num_head_channels
297
- dim_head = num_head_channels
298
- if legacy:
299
- # num_heads = 1
300
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
301
- self.middle_block = TimestepEmbedSequential(
302
- ResBlock(
303
- ch,
304
- time_embed_dim,
305
- dropout,
306
- dims=dims,
307
- use_checkpoint=use_checkpoint,
308
- use_scale_shift_norm=use_scale_shift_norm,
309
- ),
310
- AttentionBlock(
311
- ch,
312
- use_checkpoint=use_checkpoint,
313
- num_heads=num_heads,
314
- num_head_channels=dim_head,
315
- use_new_attention_order=use_new_attention_order,
316
- ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
317
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
318
- disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
319
- use_checkpoint=use_checkpoint
320
- ),
321
- ResBlock(
322
- ch,
323
- time_embed_dim,
324
- dropout,
325
- dims=dims,
326
- use_checkpoint=use_checkpoint,
327
- use_scale_shift_norm=use_scale_shift_norm,
328
- ),
329
- )
330
- self._feature_size += ch
331
-
332
- def forward(self, x, hint, timesteps, context, only_mid_control=False, **kwargs):
333
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
334
- emb = self.time_embed(t_emb)
335
-
336
- if not self.use_VAEdownsample:
337
- guided_hint = self.input_hint_block(hint, emb, context)
338
- else:
339
- guided_hint = self.cond_first_block(hint, emb, context)
340
-
341
- outs = []
342
- hs = []
343
- h = x.type(self.dtype)
344
- for module in self.input_blocks:
345
- if guided_hint is not None:
346
- h = module(h, emb, context)
347
- h += guided_hint
348
- hs.append(h)
349
- guided_hint = None
350
- else:
351
- h = module(h, emb, context)
352
- hs.append(h)
353
- outs.append(h)
354
-
355
- h = self.middle_block(h, emb, context)
356
- outs.append(h)
357
- return outs, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/VITON.yaml DELETED
@@ -1,100 +0,0 @@
1
- model:
2
- target: cldm.cldm.ControlLDM
3
- params:
4
- linear_start: 0.00085
5
- linear_end: 0.0120
6
- num_timesteps_cond: 1
7
- log_every_t: 200
8
- timesteps: 1000
9
- first_stage_key: "image"
10
- first_stage_key_cond: ["agn", "agn_mask", "image_densepose"]
11
- cond_stage_key: "cloth"
12
- control_key: "cloth"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: False
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False
20
- only_mid_control: False
21
- use_VAEdownsample: True
22
- use_lastzc: True
23
- use_imageCLIP: True
24
- use_pbe_weight: True
25
- u_cond_percent: 0.2
26
- use_attn_mask: False
27
- mask1_key: "agn_mask"
28
- mask2_key: "cloth_mask"
29
-
30
- control_stage_config:
31
- target: cldm.warping_cldm_network.NoZeroConvControlNet
32
- params:
33
- image_size: 32
34
- in_channels: 13
35
- hint_channels: 3
36
- model_channels: 320
37
- attention_resolutions: [ 4, 2, 1 ]
38
- num_res_blocks: 2
39
- channel_mult: [ 1, 2, 4, 4 ]
40
- num_heads: 8
41
- use_spatial_transformer: True
42
- transformer_depth: 1
43
- context_dim: 768
44
- use_checkpoint: True
45
- legacy: False
46
- cond_first_ch: 4
47
-
48
- unet_config:
49
- target: cldm.warping_cldm_network.StableVITON
50
- params:
51
- image_size: 32
52
- in_channels: 13
53
- out_channels: 4
54
- model_channels: 320
55
- attention_resolutions: [ 4, 2, 1 ]
56
- num_res_blocks: 2
57
- channel_mult: [ 1, 2, 4, 4 ]
58
- num_heads: 8
59
- use_spatial_transformer: True
60
- transformer_depth: 1
61
- context_dim: 768
62
- use_checkpoint: True
63
- legacy: False
64
- dim_head_denorm: 1
65
-
66
- first_stage_config:
67
- target: ldm.models.autoencoder.AutoencoderKL
68
- params:
69
- embed_dim: 4
70
- monitor: val/rec_loss
71
- ddconfig:
72
- double_z: true
73
- z_channels: 4
74
- resolution: 256
75
- in_channels: 3
76
- out_ch: 3
77
- ch: 128
78
- ch_mult:
79
- - 1
80
- - 2
81
- - 4
82
- - 4
83
- num_res_blocks: 2
84
- attn_resolutions: []
85
- dropout: 0.0
86
- lossconfig:
87
- target: torch.nn.Identity
88
- validation_config:
89
- ddim_steps: 50
90
- eta: 0.0
91
- scale: 1.0
92
-
93
- cond_stage_config:
94
- target: ldm.modules.image_encoders.modules.FrozenCLIPImageEmbedder
95
- dataset_name: VITONHDDataset
96
- resume_path: ./pretrained_models/VITONHD_PBE_pose.ckpt
97
- default_prompt: ""
98
- log_images_kwargs:
99
- unconditional_guidance_scale: 5.0
100
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
detectron2 CHANGED
@@ -1 +1 @@
1
- preprocess/detectron2/detectron2/
 
1
+ preprocess/detectron2/detectron2
examples/garment/00055_00.jpg ADDED
examples/garment/00470_00.jpg DELETED
Binary file (252 kB)
 
examples/garment/08973_00.jpg DELETED
Binary file (206 kB)
 
examples/garment/12469_00.jpg DELETED
Binary file (54.2 kB)
 
examples/model/04913_00.jpg DELETED
Binary file (165 kB)
 
examples/model/05032_00.jpg DELETED
Binary file (115 kB)
 
examples_eternal/garment/1.jpg DELETED
Binary file (533 kB)
 
examples_eternal/garment/2.jpg DELETED
Binary file (533 kB)
 
examples_eternal/garment/3.jpg DELETED
Binary file (440 kB)
 
examples_eternal/garment/4.jpg DELETED
Binary file (91.1 kB)
 
examples_eternal/garment/5.jpg DELETED
Binary file (76.8 kB)
 
examples_eternal/garment/6.jpg DELETED
Binary file (75.3 kB)
 
examples_eternal/model/1.jpg DELETED
Binary file (80.9 kB)
 
examples_eternal/model/2.jpg DELETED
Binary file (100 kB)
 
examples_eternal/model/3.jpg DELETED
Binary file (135 kB)
 
examples_eternal/model/4.jpg DELETED
Binary file (151 kB)
 
examples_eternal/model/6.jpg DELETED
Binary file (177 kB)
 
ldm/data/__init__.py DELETED
File without changes
ldm/data/util.py DELETED
@@ -1,24 +0,0 @@
1
- import torch
2
-
3
- from ldm.modules.midas.api import load_midas_transform
4
-
5
-
6
- class AddMiDaS(object):
7
- def __init__(self, model_type):
8
- super().__init__()
9
- self.transform = load_midas_transform(model_type)
10
-
11
- def pt2np(self, x):
12
- x = ((x + 1.0) * .5).detach().cpu().numpy()
13
- return x
14
-
15
- def np2pt(self, x):
16
- x = torch.from_numpy(x) * 2 - 1.
17
- return x
18
-
19
- def __call__(self, sample):
20
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
21
- x = self.pt2np(sample['jpg'])
22
- x = self.transform({"image": x})["image"]
23
- sample['midas_in'] = x
24
- return sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/autoencoder.py DELETED
@@ -1,203 +0,0 @@
1
- import torch
2
- # import pytorch_lightning as pl
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- from contextlib import contextmanager
6
-
7
- from ldm.modules.diffusionmodules.model import Encoder, Decoder
8
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
9
-
10
- from ldm.util import instantiate_from_config
11
- from ldm.modules.ema import LitEma
12
-
13
-
14
- class AutoencoderKL(nn.Module):
15
- def __init__(self,
16
- ddconfig,
17
- lossconfig,
18
- embed_dim,
19
- ckpt_path=None,
20
- ignore_keys=[],
21
- image_key="image",
22
- colorize_nlabels=None,
23
- monitor=None,
24
- ema_decay=None,
25
- learn_logvar=False
26
- ):
27
- super().__init__()
28
- self.lossconfig = lossconfig
29
- self.learn_logvar = learn_logvar
30
- self.image_key = image_key
31
- self.encoder = Encoder(**ddconfig)
32
- self.decoder = Decoder(**ddconfig)
33
- self.loss = torch.nn.Identity()
34
- assert ddconfig["double_z"]
35
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
36
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
37
- self.embed_dim = embed_dim
38
- if colorize_nlabels is not None:
39
- assert type(colorize_nlabels)==int
40
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
41
- if monitor is not None:
42
- self.monitor = monitor
43
-
44
- self.use_ema = ema_decay is not None
45
- if self.use_ema:
46
- self.ema_decay = ema_decay
47
- assert 0. < ema_decay < 1.
48
- self.model_ema = LitEma(self, decay=ema_decay)
49
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
50
-
51
- if ckpt_path is not None:
52
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
53
- def init_loss(self):
54
- self.loss = instantiate_from_config(self.lossconfig)
55
- def init_from_ckpt(self, path, ignore_keys=list()):
56
- sd = torch.load(path, map_location="cpu")["state_dict"]
57
- keys = list(sd.keys())
58
- for k in keys:
59
- for ik in ignore_keys:
60
- if k.startswith(ik):
61
- print("Deleting key {} from state_dict.".format(k))
62
- del sd[k]
63
- self.load_state_dict(sd, strict=False)
64
- print(f"Restored from {path}")
65
-
66
- @contextmanager
67
- def ema_scope(self, context=None):
68
- if self.use_ema:
69
- self.model_ema.store(self.parameters())
70
- self.model_ema.copy_to(self)
71
- if context is not None:
72
- print(f"{context}: Switched to EMA weights")
73
- try:
74
- yield None
75
- finally:
76
- if self.use_ema:
77
- self.model_ema.restore(self.parameters())
78
- if context is not None:
79
- print(f"{context}: Restored training weights")
80
-
81
- def on_train_batch_end(self, *args, **kwargs):
82
- if self.use_ema:
83
- self.model_ema(self)
84
-
85
- def encode(self, x):
86
- h = self.encoder(x)
87
- moments = self.quant_conv(h)
88
- posterior = DiagonalGaussianDistribution(moments)
89
- return posterior
90
-
91
- def decode(self, z):
92
- z = self.post_quant_conv(z)
93
- dec = self.decoder(z)
94
- return dec
95
-
96
- def forward(self, input, sample_posterior=True):
97
- posterior = self.encode(input)
98
- if sample_posterior:
99
- z = posterior.sample()
100
- else:
101
- z = posterior.mode()
102
- dec = self.decode(z)
103
- return dec, posterior
104
-
105
- def get_input(self, batch, k):
106
- x = batch[k]
107
- if len(x.shape) == 3:
108
- x = x[..., None]
109
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
110
- return x
111
-
112
- def training_step(self, batch, batch_idx):
113
- real_img = self.get_input(batch, self.image_key)
114
- recon, posterior = self(real_img)
115
- loss = self.loss(real_img, recon, posterior)
116
- return loss
117
-
118
- def validation_step(self, batch, batch_idx):
119
- log_dict = self._validation_step(batch, batch_idx)
120
- with self.ema_scope():
121
- log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
122
- return log_dict
123
-
124
- def _validation_step(self, batch, batch_idx, postfix=""):
125
- inputs = self.get_input(batch, self.image_key)
126
- reconstructions, posterior = self(inputs)
127
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
128
- last_layer=self.get_last_layer(), split="val"+postfix)
129
-
130
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
131
- last_layer=self.get_last_layer(), split="val"+postfix)
132
-
133
- self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
134
- self.log_dict(log_dict_ae)
135
- self.log_dict(log_dict_disc)
136
- return self.log_dict
137
- def configure_optimizers(self):
138
- lr = self.learning_rate
139
- ae_params_list = list(self.decoder.parameters())
140
- if self.learn_logvar:
141
- print(f"{self.__class__.__name__}: Learning logvar")
142
- ae_params_list.append(self.loss.logvar)
143
- opt_ae = torch.optim.Adam(ae_params_list,
144
- lr=lr, betas=(0.5, 0.9))
145
- return [opt_ae], []
146
-
147
- def get_last_layer(self):
148
- return self.decoder.conv_out.weight
149
-
150
- @torch.no_grad()
151
- def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
152
- log = dict()
153
- x = self.get_input(batch, self.image_key)
154
- x = x.to(self.device)
155
- if not only_inputs:
156
- xrec, posterior = self(x)
157
- if x.shape[1] > 3:
158
- # colorize with random projection
159
- assert xrec.shape[1] > 3
160
- x = self.to_rgb(x)
161
- xrec = self.to_rgb(xrec)
162
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
163
- log["reconstructions"] = xrec
164
- if log_ema or self.use_ema:
165
- with self.ema_scope():
166
- xrec_ema, posterior_ema = self(x)
167
- if x.shape[1] > 3:
168
- # colorize with random projection
169
- assert xrec_ema.shape[1] > 3
170
- xrec_ema = self.to_rgb(xrec_ema)
171
- log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
172
- log["reconstructions_ema"] = xrec_ema
173
- log["inputs"] = x
174
- return log
175
-
176
- def to_rgb(self, x):
177
- assert self.image_key == "segmentation"
178
- if not hasattr(self, "colorize"):
179
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
180
- x = F.conv2d(x, weight=self.colorize)
181
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
182
- return x
183
-
184
-
185
- class IdentityFirstStage(torch.nn.Module):
186
- def __init__(self, *args, vq_interface=False, **kwargs):
187
- self.vq_interface = vq_interface
188
- super().__init__()
189
-
190
- def encode(self, x, *args, **kwargs):
191
- return x
192
-
193
- def decode(self, x, *args, **kwargs):
194
- return x
195
-
196
- def quantize(self, x, *args, **kwargs):
197
- if self.vq_interface:
198
- return x, None, [None, None, None]
199
- return x
200
-
201
- def forward(self, x, *args, **kwargs):
202
- return x
203
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/__init__.py DELETED
File without changes
ldm/models/diffusion/ddim.py DELETED
@@ -1,377 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
-
9
-
10
- class DDIMSampler(object):
11
- def __init__(self, model, schedule="linear", **kwargs):
12
- super().__init__()
13
- self.model = model
14
- self.ddpm_num_timesteps = model.num_timesteps
15
- self.schedule = schedule
16
-
17
- def register_buffer(self, name, attr):
18
- if type(attr) == torch.Tensor:
19
- if attr.device != torch.device("cuda"):
20
- attr = attr.to(torch.device("cuda"))
21
- setattr(self, name, attr)
22
-
23
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
- alphas_cumprod = self.model.alphas_cumprod
27
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
-
30
- self.register_buffer('betas', to_torch(self.model.betas))
31
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
-
34
- # calculations for diffusion q(x_t | x_{t-1}) and others
35
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
-
41
- # ddim sampling parameters
42
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
- ddim_timesteps=self.ddim_timesteps,
44
- eta=ddim_eta,verbose=verbose)
45
- self.register_buffer('ddim_sigmas', ddim_sigmas)
46
- self.register_buffer('ddim_alphas', ddim_alphas)
47
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
-
54
- @torch.no_grad()
55
- def sample(self,
56
- S,
57
- batch_size,
58
- shape,
59
- conditioning=None,
60
- callback=None,
61
- normals_sequence=None,
62
- img_callback=None,
63
- quantize_x0=False,
64
- eta=0.,
65
- mask=None,
66
- x0=None,
67
- temperature=1.,
68
- noise_dropout=0.,
69
- score_corrector=None,
70
- corrector_kwargs=None,
71
- verbose=True,
72
- x_T=None,
73
- log_every_t=100,
74
- unconditional_guidance_scale=1.,
75
- unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
- dynamic_threshold=None,
77
- ucg_schedule=None,
78
- **kwargs
79
- ):
80
- if conditioning is not None:
81
- if isinstance(conditioning, dict):
82
- ctmp = conditioning[list(conditioning.keys())[0]]
83
- while isinstance(ctmp, list): ctmp = ctmp[0]
84
- cbs = ctmp.shape[0]
85
- if cbs != batch_size:
86
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
-
88
- elif isinstance(conditioning, list):
89
- for ctmp in conditioning:
90
- if ctmp.shape[0] != batch_size:
91
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
-
93
- else:
94
- if conditioning.shape[0] != batch_size:
95
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
-
97
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
- # sampling
99
- C, H, W = shape
100
- size = (batch_size, C, H, W)
101
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
-
103
- samples, intermediates, cond_output_dict = self.ddim_sampling(conditioning, size,
104
- callback=callback,
105
- img_callback=img_callback,
106
- quantize_denoised=quantize_x0,
107
- mask=mask, x0=x0,
108
- ddim_use_original_steps=False,
109
- noise_dropout=noise_dropout,
110
- temperature=temperature,
111
- score_corrector=score_corrector,
112
- corrector_kwargs=corrector_kwargs,
113
- x_T=x_T,
114
- log_every_t=log_every_t,
115
- unconditional_guidance_scale=unconditional_guidance_scale,
116
- unconditional_conditioning=unconditional_conditioning,
117
- dynamic_threshold=dynamic_threshold,
118
- ucg_schedule=ucg_schedule
119
- )
120
- return samples, intermediates, cond_output_dict
121
-
122
- @torch.no_grad()
123
- def ddim_sampling(self, cond, shape,
124
- x_T=None, ddim_use_original_steps=False,
125
- callback=None, timesteps=None, quantize_denoised=False,
126
- mask=None, x0=None, img_callback=None, log_every_t=100,
127
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
- unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
- ucg_schedule=None):
130
- device = self.model.betas.device
131
- b = shape[0]
132
- if x_T is None:
133
- img = torch.randn(shape, device=device)
134
- else:
135
- img = x_T
136
-
137
- if timesteps is None:
138
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
- elif timesteps is not None and not ddim_use_original_steps:
140
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
- timesteps = self.ddim_timesteps[:subset_end]
142
-
143
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
- print(f"Running DDIM Sampling with {total_steps} timesteps")
147
-
148
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
-
150
- for i, step in enumerate(iterator):
151
- index = total_steps - i - 1
152
- ts = torch.full((b,), step, device=device, dtype=torch.long)
153
-
154
- if mask is not None:
155
- assert x0 is not None
156
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
- img = img_orig * mask + (1. - mask) * img
158
-
159
- if ucg_schedule is not None:
160
- assert len(ucg_schedule) == len(time_range)
161
- unconditional_guidance_scale = ucg_schedule[i]
162
-
163
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
- quantize_denoised=quantize_denoised, temperature=temperature,
165
- noise_dropout=noise_dropout, score_corrector=score_corrector,
166
- corrector_kwargs=corrector_kwargs,
167
- unconditional_guidance_scale=unconditional_guidance_scale,
168
- unconditional_conditioning=unconditional_conditioning,
169
- dynamic_threshold=dynamic_threshold)
170
- img, pred_x0, cond_output_dict = outs
171
- if callback: callback(i)
172
- if img_callback: img_callback(pred_x0, i)
173
-
174
- if index % log_every_t == 0 or index == total_steps - 1:
175
- intermediates['x_inter'].append(img)
176
- intermediates['pred_x0'].append(pred_x0)
177
-
178
- if cond_output_dict is not None:
179
- cond_output = cond_output_dict["cond_output"]
180
- if self.model.use_noisy_cond:
181
- b = cond_output.shape[0]
182
-
183
- alphas = self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas
184
- alphas_prev = self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev
185
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas
186
- sigmas = self.model.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas
187
-
188
- device = cond_output.device
189
- a_t = torch.full((b, 1, 1, 1), alphas[0], device=device)
190
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[0], device=device)
191
- sigma_t = torch.full((b, 1, 1, 1), sigmas[0], device=device)
192
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[0], device=device)
193
-
194
- c = cond_output_dict["cond_input"]
195
- e_t = cond_output
196
- pred_c0 = (c - sqrt_one_minus_at * e_t) / a_t.sqrt()
197
- dir_ct = (1. - a_prev - sigma_t**2).sqrt() * e_t
198
- noise = sigma_t * noise_like(c.shape, device, False) * temperature
199
-
200
- if noise_dropout > 0.:
201
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
202
- cond_output = a_prev.sqrt() * pred_c0 + dir_ct + noise
203
- cond_output_dict[f"cond_sample"] = cond_output
204
- return img, intermediates, cond_output_dict
205
-
206
- @torch.no_grad()
207
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
208
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
209
- unconditional_guidance_scale=1., unconditional_conditioning=None,
210
- dynamic_threshold=None):
211
- b, *_, device = *x.shape, x.device
212
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
213
- model_output, cond_output_dict = self.model.apply_model(x, t, c)
214
- else:
215
- # x_in = torch.cat([x] * 2)
216
- # t_in = torch.cat([t] * 2)
217
- # if isinstance(c, dict):
218
- # assert isinstance(unconditional_conditioning, dict)
219
- # c_in = dict()
220
- # for k in c:
221
- # if isinstance(c[k], list):
222
- # c_in[k] = [torch.cat([
223
- # unconditional_conditioning[k][i],
224
- # c[k][i]]) for i in range(len(c[k]))]
225
- # else:
226
- # c_in[k] = torch.cat([
227
- # unconditional_conditioning[k],
228
- # c[k]])
229
- # elif isinstance(c, list):
230
- # c_in = list()
231
- # assert isinstance(unconditional_conditioning, list)
232
- # for i in range(len(c)):
233
- # c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
234
- # else:
235
- # c_in = torch.cat([unconditional_conditioning, c])
236
- x_in = x
237
- t_in = t
238
- model_t, cond_output_dict_cond = self.model.apply_model(x_in, t_in, c)
239
- model_uncond, cond_output_dict_uncond = self.model.apply_model(x_in, t_in, unconditional_conditioning)
240
- if isinstance(model_t, tuple):
241
- model_t, _ = model_t
242
- if isinstance(model_uncond, tuple):
243
- model_uncond, _ = model_uncond
244
- if cond_output_dict_cond is not None:
245
- cond_output_dict = dict()
246
- for k in cond_output_dict_cond.keys():
247
- cond_output_dict[k] = torch.cat([cond_output_dict_uncond[k], cond_output_dict_cond[k]])
248
- else:
249
- cond_output_dict = None
250
- # model_output, cond_output_dict = self.model.apply_model(x_in, t_in, c_in)
251
- # model_uncond, model_t = model_output.chunk(2)
252
- model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
253
-
254
- if self.model.parameterization == "v":
255
- e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
256
- else:
257
- e_t = model_output
258
-
259
- if score_corrector is not None:
260
- assert self.model.parameterization == "eps", 'not implemented'
261
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
262
-
263
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
264
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
265
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
266
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
267
- # select parameters corresponding to the currently considered timestep
268
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
269
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
270
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
271
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
272
-
273
- # current prediction for x_0
274
- if self.model.parameterization != "v":
275
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
276
- else:
277
- pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
278
-
279
- if quantize_denoised:
280
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
281
-
282
- if dynamic_threshold is not None:
283
- raise NotImplementedError()
284
-
285
- # direction pointing to x_t
286
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
287
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
288
- if noise_dropout > 0.:
289
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
290
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
291
-
292
- return x_prev, pred_x0, cond_output_dict
293
-
294
- @torch.no_grad()
295
- def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
296
- unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
297
- num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
298
-
299
- assert t_enc <= num_reference_steps
300
- num_steps = t_enc
301
-
302
- if use_original_steps:
303
- alphas_next = self.alphas_cumprod[:num_steps]
304
- alphas = self.alphas_cumprod_prev[:num_steps]
305
- else:
306
- alphas_next = self.ddim_alphas[:num_steps]
307
- alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
308
-
309
- x_next = x0
310
- intermediates = []
311
- inter_steps = []
312
- for i in tqdm(range(num_steps), desc='Encoding Image'):
313
- t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
314
- if unconditional_guidance_scale == 1.:
315
- noise_pred = self.model.apply_model(x_next, t, c)[0]
316
- else:
317
- assert unconditional_conditioning is not None
318
- e_t_uncond, noise_pred = torch.chunk(
319
- self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
320
- torch.cat((unconditional_conditioning, c))), 2)
321
- noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)[0]
322
-
323
- xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
324
- weighted_noise_pred = alphas_next[i].sqrt() * (
325
- (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
326
- x_next = xt_weighted + weighted_noise_pred
327
- if return_intermediates and i % (
328
- num_steps // return_intermediates) == 0 and i < num_steps - 1:
329
- intermediates.append(x_next)
330
- inter_steps.append(i)
331
- elif return_intermediates and i >= num_steps - 2:
332
- intermediates.append(x_next)
333
- inter_steps.append(i)
334
- if callback: callback(i)
335
-
336
- out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
337
- if return_intermediates:
338
- out.update({'intermediates': intermediates})
339
- return x_next, out
340
-
341
- @torch.no_grad()
342
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
343
- # fast, but does not allow for exact reconstruction
344
- # t serves as an index to gather the correct alphas
345
- if use_original_steps:
346
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
347
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
348
- else:
349
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
350
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
351
-
352
- if noise is None:
353
- noise = torch.randn_like(x0)
354
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
355
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
356
-
357
- @torch.no_grad()
358
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
359
- use_original_steps=False, callback=None):
360
-
361
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
362
- timesteps = timesteps[:t_start]
363
-
364
- time_range = np.flip(timesteps)
365
- total_steps = timesteps.shape[0]
366
- print(f"Running DDIM Sampling with {total_steps} timesteps")
367
-
368
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
369
- x_dec = x_latent
370
- for i, step in enumerate(iterator):
371
- index = total_steps - i - 1
372
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
373
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
374
- unconditional_guidance_scale=unconditional_guidance_scale,
375
- unconditional_conditioning=unconditional_conditioning)
376
- if callback: callback(i)
377
- return x_dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/ddpm.py DELETED
@@ -1,1875 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
- import numpy as np
12
- # import pytorch_lightning as pl
13
- from torch.optim.lr_scheduler import LambdaLR
14
- from einops import rearrange, repeat
15
- from contextlib import contextmanager, nullcontext
16
- from functools import partial
17
- import itertools
18
- from tqdm import tqdm
19
- from torchvision.utils import make_grid
20
- # from pytorch_lightning.utilities.distributed import rank_zero_only
21
- from omegaconf import ListConfig
22
- from torchvision.transforms.functional import resize
23
- import torchvision.transforms as T
24
- import random
25
- import torch.nn.functional as F
26
- from diffusers.models.autoencoder_kl import AutoencoderKLOutput
27
- from diffusers.models.vae import DecoderOutput
28
-
29
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
30
- from ldm.modules.ema import LitEma
31
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
32
- from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
33
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like, zero_module, conv_nd
34
- from ldm.models.diffusion.ddim import DDIMSampler
35
-
36
- __conditioning_keys__ = {'concat': 'c_concat',
37
- 'crossattn': 'c_crossattn',
38
- 'adm': 'y'}
39
-
40
-
41
- def disabled_train(self, mode=True):
42
- """Overwrite model.train with this function to make sure train/eval mode
43
- does not change anymore."""
44
- return self
45
-
46
-
47
- def uniform_on_device(r1, r2, shape, device):
48
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
49
-
50
- class DDPM(nn.Module):
51
- # classic DDPM with Gaussian diffusion, in image space
52
- def __init__(self,
53
- unet_config,
54
- timesteps=1000,
55
- beta_schedule="linear",
56
- loss_type="l2",
57
- ckpt_path=None,
58
- ignore_keys=[],
59
- load_only_unet=False,
60
- monitor="val/loss",
61
- use_ema=True,
62
- first_stage_key="image",
63
- image_size=256,
64
- channels=3,
65
- log_every_t=100,
66
- clip_denoised=True,
67
- linear_start=1e-4,
68
- linear_end=2e-2,
69
- cosine_s=8e-3,
70
- given_betas=None,
71
- original_elbo_weight=0.,
72
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
73
- l_simple_weight=1.,
74
- conditioning_key=None,
75
- parameterization="eps", # all assuming fixed variance schedules
76
- scheduler_config=None,
77
- use_positional_encodings=False,
78
- learn_logvar=False,
79
- logvar_init=0.,
80
- make_it_fit=False,
81
- ucg_training=None,
82
- reset_ema=False,
83
- reset_num_ema_updates=False,
84
- l_cond_simple_weight=1.0,
85
- l_cond_recon_weight=1.0,
86
- **kwargs
87
- ):
88
- super().__init__()
89
- assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
90
- self.parameterization = parameterization
91
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
92
- self.unet_config = unet_config
93
- self.cond_stage_model = None
94
- self.clip_denoised = clip_denoised
95
- self.log_every_t = log_every_t
96
- self.first_stage_key = first_stage_key
97
- self.image_size = image_size # try conv?
98
- self.channels = channels
99
- self.use_positional_encodings = use_positional_encodings
100
- self.model = DiffusionWrapper(unet_config, conditioning_key)
101
- count_params(self.model, verbose=True)
102
- self.use_ema = use_ema
103
- if self.use_ema:
104
- self.model_ema = LitEma(self.model)
105
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
106
-
107
- self.use_scheduler = scheduler_config is not None
108
- if self.use_scheduler:
109
- self.scheduler_config = scheduler_config
110
- self.imagenet_norm = T.Normalize((0.48145466, 0.4578275, 0.40821073),
111
- (0.26862954, 0.26130258, 0.27577711))
112
-
113
- self.v_posterior = v_posterior
114
- self.original_elbo_weight = original_elbo_weight
115
- self.l_simple_weight = l_simple_weight
116
- self.l_cond_simple_weight = l_cond_simple_weight
117
- self.l_cond_recon_weight = l_cond_recon_weight
118
-
119
- if monitor is not None:
120
- self.monitor = monitor
121
- self.make_it_fit = make_it_fit
122
- if reset_ema: assert exists(ckpt_path)
123
- if ckpt_path is not None:
124
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
125
- if reset_ema:
126
- assert self.use_ema
127
- print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
128
- self.model_ema = LitEma(self.model)
129
- if reset_num_ema_updates:
130
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
131
- assert self.use_ema
132
- self.model_ema.reset_num_updates()
133
-
134
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
135
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
136
-
137
- self.loss_type = loss_type
138
-
139
- self.learn_logvar = learn_logvar
140
- logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
141
- if self.learn_logvar:
142
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
143
- else:
144
- self.register_buffer('logvar', logvar)
145
-
146
- self.ucg_training = ucg_training or dict()
147
- if self.ucg_training:
148
- self.ucg_prng = np.random.RandomState()
149
-
150
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
151
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
152
- if exists(given_betas):
153
- betas = given_betas
154
- else:
155
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
156
- cosine_s=cosine_s)
157
- alphas = 1. - betas
158
- alphas_cumprod = np.cumprod(alphas, axis=0)
159
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
160
-
161
- timesteps, = betas.shape
162
- self.num_timesteps = int(timesteps)
163
- self.linear_start = linear_start
164
- self.linear_end = linear_end
165
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
166
-
167
- to_torch = partial(torch.tensor, dtype=torch.float32)
168
-
169
- self.register_buffer('betas', to_torch(betas))
170
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
171
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
172
-
173
- # calculations for diffusion q(x_t | x_{t-1}) and others
174
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
175
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
176
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
177
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
178
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
179
-
180
- # calculations for posterior q(x_{t-1} | x_t, x_0)
181
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
182
- 1. - alphas_cumprod) + self.v_posterior * betas
183
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
184
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
185
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
186
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
187
- self.register_buffer('posterior_mean_coef1', to_torch(
188
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
189
- self.register_buffer('posterior_mean_coef2', to_torch(
190
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
191
-
192
- if self.parameterization == "eps":
193
- lvlb_weights = self.betas ** 2 / (
194
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
195
- elif self.parameterization == "x0":
196
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
197
- elif self.parameterization == "v":
198
- lvlb_weights = torch.ones_like(self.betas ** 2 / (
199
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
200
- else:
201
- raise NotImplementedError("mu not supported")
202
- lvlb_weights[0] = lvlb_weights[1]
203
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
204
- assert not torch.isnan(self.lvlb_weights).all()
205
-
206
- @contextmanager
207
- def ema_scope(self, context=None):
208
- if self.use_ema:
209
- self.model_ema.store(self.model.parameters())
210
- self.model_ema.copy_to(self.model)
211
- if context is not None:
212
- print(f"{context}: Switched to EMA weights")
213
- try:
214
- yield None
215
- finally:
216
- if self.use_ema:
217
- self.model_ema.restore(self.model.parameters())
218
- if context is not None:
219
- print(f"{context}: Restored training weights")
220
-
221
- @torch.no_grad()
222
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
223
- sd = torch.load(path, map_location="cpu")
224
- if "state_dict" in list(sd.keys()):
225
- sd = sd["state_dict"]
226
- keys = list(sd.keys())
227
- for k in keys:
228
- for ik in ignore_keys:
229
- if k.startswith(ik):
230
- print("Deleting key {} from state_dict.".format(k))
231
- del sd[k]
232
- if self.make_it_fit:
233
- n_params = len([name for name, _ in
234
- itertools.chain(self.named_parameters(),
235
- self.named_buffers())])
236
- for name, param in tqdm(
237
- itertools.chain(self.named_parameters(),
238
- self.named_buffers()),
239
- desc="Fitting old weights to new weights",
240
- total=n_params
241
- ):
242
- if not name in sd:
243
- continue
244
- old_shape = sd[name].shape
245
- new_shape = param.shape
246
- assert len(old_shape) == len(new_shape)
247
- if len(new_shape) > 2:
248
- # we only modify first two axes
249
- assert new_shape[2:] == old_shape[2:]
250
- # assumes first axis corresponds to output dim
251
- if not new_shape == old_shape:
252
- new_param = param.clone()
253
- old_param = sd[name]
254
- if len(new_shape) == 1:
255
- for i in range(new_param.shape[0]):
256
- new_param[i] = old_param[i % old_shape[0]]
257
- elif len(new_shape) >= 2:
258
- for i in range(new_param.shape[0]):
259
- for j in range(new_param.shape[1]):
260
- new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
261
-
262
- n_used_old = torch.ones(old_shape[1])
263
- for j in range(new_param.shape[1]):
264
- n_used_old[j % old_shape[1]] += 1
265
- n_used_new = torch.zeros(new_shape[1])
266
- for j in range(new_param.shape[1]):
267
- n_used_new[j] = n_used_old[j % old_shape[1]]
268
-
269
- n_used_new = n_used_new[None, :]
270
- while len(n_used_new.shape) < len(new_shape):
271
- n_used_new = n_used_new.unsqueeze(-1)
272
- new_param /= n_used_new
273
-
274
- sd[name] = new_param
275
-
276
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
277
- sd, strict=False)
278
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
279
- if len(missing) > 0:
280
- print(f"Missing Keys:\n {missing}")
281
- if len(unexpected) > 0:
282
- print(f"\nUnexpected Keys:\n {unexpected}")
283
-
284
- def q_mean_variance(self, x_start, t):
285
- """
286
- Get the distribution q(x_t | x_0).
287
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
288
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
289
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
290
- """
291
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
292
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
293
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
294
- return mean, variance, log_variance
295
-
296
- def predict_start_from_noise(self, x_t, t, noise):
297
- return (
298
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
299
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
300
- )
301
-
302
- def predict_start_from_z_and_v(self, x_t, t, v):
303
- # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
304
- # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
305
- return (
306
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
307
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
308
- )
309
-
310
- def predict_eps_from_z_and_v(self, x_t, t, v):
311
- return (
312
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
313
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
314
- )
315
-
316
- def q_posterior(self, x_start, x_t, t):
317
- posterior_mean = (
318
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
319
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
320
- )
321
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
322
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
323
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
324
-
325
- def p_mean_variance(self, x, t, clip_denoised: bool):
326
- model_out = self.model(x, t)
327
- if self.parameterization == "eps":
328
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
329
- elif self.parameterization == "x0":
330
- x_recon = model_out
331
- if clip_denoised:
332
- x_recon.clamp_(-1., 1.)
333
-
334
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
335
- return model_mean, posterior_variance, posterior_log_variance
336
-
337
- @torch.no_grad()
338
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
339
- b, *_, device = *x.shape, x.device
340
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
341
- noise = noise_like(x.shape, device, repeat_noise)
342
- # no noise when t == 0
343
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
344
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
345
-
346
- @torch.no_grad()
347
- def p_sample_loop(self, shape, return_intermediates=False):
348
- device = self.betas.device
349
- b = shape[0]
350
- img = torch.randn(shape, device=device)
351
- intermediates = [img]
352
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
353
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
354
- clip_denoised=self.clip_denoised)
355
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
356
- intermediates.append(img)
357
- if return_intermediates:
358
- return img, intermediates
359
- return img
360
-
361
- @torch.no_grad()
362
- def sample(self, batch_size=16, return_intermediates=False):
363
- image_size = self.image_size
364
- channels = self.channels
365
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
366
- return_intermediates=return_intermediates)
367
-
368
- def q_sample(self, x_start, t, noise=None):
369
- noise = default(noise, lambda: torch.randn_like(x_start))
370
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
371
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
372
-
373
- def get_v(self, x, noise, t):
374
- return (
375
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
376
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
377
- )
378
-
379
- def get_loss(self, pred, target, mean=True):
380
- if self.loss_type == 'l1':
381
- loss = (target - pred).abs()
382
- if mean:
383
- loss = loss.mean()
384
- elif self.loss_type == 'l2':
385
- if mean:
386
- loss = torch.nn.functional.mse_loss(target, pred)
387
- else:
388
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
389
- else:
390
- raise NotImplementedError("unknown loss type '{loss_type}'")
391
-
392
- return loss
393
-
394
- def p_losses(self, x_start, t, noise=None):
395
- noise = default(noise, lambda: torch.randn_like(x_start))
396
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
397
- model_out = self.model(x_noisy, t)
398
-
399
- loss_dict = {}
400
- if self.parameterization == "eps":
401
- target = noise
402
- elif self.parameterization == "x0":
403
- target = x_start
404
- elif self.parameterization == "v":
405
- target = self.get_v(x_start, noise, t)
406
- else:
407
- raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
408
-
409
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
410
-
411
- log_prefix = 'train' if self.training else 'val'
412
-
413
- loss_dict.update({f'{log_prefix}_loss_simple': loss.mean()})
414
- loss_simple = loss.mean() * self.l_simple_weight
415
-
416
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
417
- loss_dict.update({f'{log_prefix}_loss_vlb': loss_vlb})
418
-
419
- loss = loss_simple + self.original_elbo_weight * loss_vlb
420
-
421
- loss_dict.update({f'{log_prefix}_loss': loss})
422
-
423
- return loss, loss_dict
424
-
425
- def forward(self, x, *args, **kwargs):
426
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
427
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
428
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
429
- return self.p_losses(x, t, *args, **kwargs)
430
-
431
- def get_input(self, batch, k):
432
- x = batch[k]
433
- if len(x.shape) == 3:
434
- x = x[..., None]
435
- x = rearrange(x, 'b h w c -> b c h w')
436
- x = x.to(memory_format=torch.contiguous_format).float()
437
- return x
438
-
439
- def shared_step(self, batch):
440
- x = self.get_input(batch, self.first_stage_key)
441
- loss, loss_dict = self(x)
442
- return loss, loss_dict
443
-
444
- def training_step(self, batch, batch_idx):
445
- self.batch = batch
446
- for k in self.ucg_training:
447
- p = self.ucg_training[k]["p"]
448
- val = self.ucg_training[k]["val"]
449
- if val is None:
450
- val = ""
451
- for i in range(len(batch[k])):
452
- if self.ucg_prng.choice(2, p=[1 - p, p]):
453
- batch[k][i] = val
454
- loss, loss_dict = self.shared_step(batch)
455
-
456
- self.log_dict(loss_dict, prog_bar=True,
457
- logger=True, on_step=True, on_epoch=True)
458
-
459
- self.log("global_step", self.global_step,
460
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
461
-
462
- if self.use_scheduler:
463
- lr = self.optimizers().param_groups[0]['lr']
464
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
465
-
466
- return loss
467
-
468
- @torch.no_grad()
469
- def validation_step(self, batch, batch_idx):
470
- _, loss_dict_no_ema = self.shared_step(batch)
471
- with self.ema_scope():
472
- _, loss_dict_ema = self.shared_step(batch)
473
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
474
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
475
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
476
-
477
- def on_train_batch_end(self, *args, **kwargs):
478
- if self.use_ema:
479
- self.model_ema(self.model)
480
-
481
- def _get_rows_from_list(self, samples):
482
- n_imgs_per_row = len(samples)
483
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
484
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
485
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
486
- return denoise_grid
487
-
488
- @torch.no_grad()
489
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
490
- log = dict()
491
- x = self.get_input(batch, self.first_stage_key)
492
- N = min(x.shape[0], N)
493
- n_row = min(x.shape[0], n_row)
494
- x = x.to(self.device)[:N]
495
- log["inputs"] = x
496
-
497
- # get diffusion row
498
- diffusion_row = list()
499
- x_start = x[:n_row]
500
-
501
- for t in range(self.num_timesteps):
502
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
503
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
504
- t = t.to(self.device).long()
505
- noise = torch.randn_like(x_start)
506
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
507
- diffusion_row.append(x_noisy)
508
-
509
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
510
-
511
- if sample:
512
- # get denoise row
513
- with self.ema_scope("Plotting"):
514
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
515
-
516
- log["samples"] = samples
517
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
518
-
519
- if return_keys:
520
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
521
- return log
522
- else:
523
- return {key: log[key] for key in return_keys}
524
- return log
525
-
526
- def configure_optimizers(self):
527
- lr = self.learning_rate
528
- params = list(self.model.parameters())
529
- if self.learn_logvar:
530
- params = params + [self.logvar]
531
- opt = torch.optim.AdamW(params, lr=lr)
532
- return opt
533
-
534
-
535
- class LatentDiffusion(DDPM):
536
- """main class"""
537
-
538
- def __init__(self,
539
- first_stage_config,
540
- cond_stage_config,
541
- num_timesteps_cond=None,
542
- cond_stage_key="image",
543
- cond_stage_trainable=False,
544
- concat_mode=True,
545
- cond_stage_forward=None,
546
- conditioning_key=None,
547
- scale_factor=1.0,
548
- scale_by_std=False,
549
- force_null_conditioning=False,
550
- *args, **kwargs):
551
- self.kwargs = kwargs
552
- self.force_null_conditioning = force_null_conditioning
553
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
554
- self.scale_by_std = scale_by_std
555
- self.cond_stage_trainable = cond_stage_trainable
556
- assert self.num_timesteps_cond <= kwargs['timesteps']
557
- if conditioning_key is None:
558
- conditioning_key = 'concat' if concat_mode else 'crossattn'
559
- if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
560
- conditioning_key = None
561
- ckpt_path = kwargs.pop("ckpt_path", None)
562
- reset_ema = kwargs.pop("reset_ema", False)
563
- reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
564
- ignore_keys = kwargs.pop("ignore_keys", [])
565
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
566
- self.concat_mode = concat_mode
567
- self.cond_stage_key = cond_stage_key
568
- try:
569
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
570
- except Exception:
571
- self.num_downs = 0
572
- if not scale_by_std:
573
- self.scale_factor = scale_factor
574
- else:
575
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
576
-
577
- self.instantiate_first_stage(first_stage_config)
578
- self.instantiate_cond_stage(cond_stage_config)
579
- self.cond_stage_forward = cond_stage_forward
580
- self.clip_denoised = False
581
- self.bbox_tokenizer = None
582
-
583
- if self.kwargs["use_imageCLIP"]:
584
- self.proj_out = nn.Linear(1024, 768)
585
- else:
586
- self.proj_out = None
587
- if self.use_pbe_weight:
588
- print("learnable vector gene")
589
- self.learnable_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True)
590
- else:
591
- self.learnable_vector = None
592
-
593
- if self.kwargs["use_lastzc"]: # deprecated
594
- self.lastzc = zero_module(conv_nd(2, 4, 4, 1, 1, 0))
595
- else:
596
- self.lastzc = None
597
-
598
- self.restarted_from_ckpt = False
599
- if ckpt_path is not None:
600
- self.init_from_ckpt(ckpt_path, ignore_keys)
601
- self.restarted_from_ckpt = True
602
- if reset_ema:
603
- assert self.use_ema
604
- print(
605
- f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
606
- self.model_ema = LitEma(self.model)
607
- if reset_num_ema_updates:
608
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
609
- assert self.use_ema
610
- self.model_ema.reset_num_updates()
611
-
612
- def make_cond_schedule(self, ):
613
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
614
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
615
- self.cond_ids[:self.num_timesteps_cond] = ids
616
-
617
- # @rank_zero_only
618
- @torch.no_grad()
619
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
620
- # only for very first batch
621
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
622
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
623
- # set rescale weight to 1./std of encodings
624
- print("### USING STD-RESCALING ###")
625
- x = super().get_input(batch, self.first_stage_key)
626
- x = x.to(self.device)
627
- encoder_posterior = self.encode_first_stage(x)
628
- z = self.get_first_stage_encoding(encoder_posterior).detach()
629
- del self.scale_factor
630
- self.register_buffer('scale_factor', 1. / z.flatten().std())
631
- print(f"setting self.scale_factor to {self.scale_factor}")
632
- print("### USING STD-RESCALING ###")
633
-
634
- def register_schedule(self,
635
- given_betas=None, beta_schedule="linear", timesteps=1000,
636
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
637
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
638
-
639
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
640
- if self.shorten_cond_schedule:
641
- self.make_cond_schedule()
642
-
643
- def instantiate_first_stage(self, config):
644
- model = instantiate_from_config(config)
645
- self.first_stage_model = model.eval()
646
- self.first_stage_model.train = disabled_train
647
- for param in self.first_stage_model.parameters():
648
- param.requires_grad = False
649
-
650
- def instantiate_cond_stage(self, config):
651
- if not self.cond_stage_trainable:
652
- if config == "__is_first_stage__":
653
- print("Using first stage also as cond stage.")
654
- self.cond_stage_model = self.first_stage_model
655
- elif config == "__is_unconditional__":
656
- print(f"Training {self.__class__.__name__} as an unconditional model.")
657
- self.cond_stage_model = None
658
- else:
659
- model = instantiate_from_config(config)
660
- self.cond_stage_model = model
661
- else:
662
- assert config != '__is_first_stage__'
663
- assert config != '__is_unconditional__'
664
- model = instantiate_from_config(config)
665
- self.cond_stage_model = model
666
-
667
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
668
- denoise_row = []
669
- for zd in tqdm(samples, desc=desc):
670
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
671
- force_not_quantize=force_no_decoder_quantization))
672
- n_imgs_per_row = len(denoise_row)
673
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
674
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
675
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
676
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
677
- return denoise_grid
678
-
679
- def get_first_stage_encoding(self, encoder_posterior):
680
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
681
- z = encoder_posterior.sample()
682
- elif isinstance(encoder_posterior, torch.Tensor):
683
- z = encoder_posterior
684
- elif isinstance(encoder_posterior, AutoencoderKLOutput):
685
- z = encoder_posterior.latent_dist.sample()
686
- else:
687
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
688
- return self.scale_factor * z
689
-
690
- def get_learned_conditioning(self, c):
691
- if self.cond_stage_forward is None:
692
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
693
- c = self.cond_stage_model.encode(c)
694
- if isinstance(c, DiagonalGaussianDistribution):
695
- c = c.mode()
696
- else:
697
- c = self.cond_stage_model(c)
698
- else:
699
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
700
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
701
- return c
702
-
703
- def meshgrid(self, h, w):
704
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
705
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
706
-
707
- arr = torch.cat([y, x], dim=-1)
708
- return arr
709
-
710
- def delta_border(self, h, w):
711
- """
712
- :param h: height
713
- :param w: width
714
- :return: normalized distance to image border,
715
- with min distance = 0 at border and max dist = 0.5 at image center
716
- """
717
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
718
- arr = self.meshgrid(h, w) / lower_right_corner
719
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
720
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
721
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
722
- return edge_dist
723
-
724
- def get_weighting(self, h, w, Ly, Lx, device):
725
- weighting = self.delta_border(h, w)
726
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
727
- self.split_input_params["clip_max_weight"], )
728
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
729
-
730
- if self.split_input_params["tie_braker"]:
731
- L_weighting = self.delta_border(Ly, Lx)
732
- L_weighting = torch.clip(L_weighting,
733
- self.split_input_params["clip_min_tie_weight"],
734
- self.split_input_params["clip_max_tie_weight"])
735
-
736
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
737
- weighting = weighting * L_weighting
738
- return weighting
739
-
740
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
741
- """
742
- :param x: img of size (bs, c, h, w)
743
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
744
- """
745
- bs, nc, h, w = x.shape
746
-
747
- # number of crops in image
748
- Ly = (h - kernel_size[0]) // stride[0] + 1
749
- Lx = (w - kernel_size[1]) // stride[1] + 1
750
-
751
- if uf == 1 and df == 1:
752
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
753
- unfold = torch.nn.Unfold(**fold_params)
754
-
755
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
756
-
757
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
758
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
759
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
760
-
761
- elif uf > 1 and df == 1:
762
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
763
- unfold = torch.nn.Unfold(**fold_params)
764
-
765
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
766
- dilation=1, padding=0,
767
- stride=(stride[0] * uf, stride[1] * uf))
768
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
769
-
770
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
771
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
772
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
773
-
774
- elif df > 1 and uf == 1:
775
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
776
- unfold = torch.nn.Unfold(**fold_params)
777
-
778
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
779
- dilation=1, padding=0,
780
- stride=(stride[0] // df, stride[1] // df))
781
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
782
-
783
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
784
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
785
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
786
-
787
- else:
788
- raise NotImplementedError
789
-
790
- return fold, unfold, normalization, weighting
791
-
792
- @torch.no_grad()
793
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
794
- cond_key=None, return_original_cond=False, bs=None, return_x=False, no_latent=False, is_controlnet=False):
795
- x = super().get_input(batch, k)
796
- if bs is not None:
797
- x = x[:bs]
798
- x = x.to(self.device)
799
- if no_latent:
800
- _,_,h,w = x.shape
801
- x = resize(x, (h//8, w//8))
802
- return [x, None]
803
- encoder_posterior = self.encode_first_stage(x)
804
- z = self.get_first_stage_encoding(encoder_posterior).detach()
805
- if is_controlnet and self.lastzc is not None:
806
- z = self.lastzc(z)
807
-
808
- if self.model.conditioning_key is not None and not self.force_null_conditioning:
809
- if cond_key is None:
810
- cond_key = self.cond_stage_key
811
- if cond_key != self.first_stage_key:
812
- if cond_key in ['caption', 'coordinates_bbox', "txt"]:
813
- xc = batch[cond_key]
814
- elif cond_key in ['class_label', 'cls']:
815
- xc = batch
816
- else:
817
- xc = super().get_input(batch, cond_key).to(self.device)
818
- else:
819
- xc = x
820
- if not self.cond_stage_trainable or force_c_encode:
821
- if self.kwargs["use_imageCLIP"]:
822
- xc = resize(xc, (224,224))
823
- xc = self.imagenet_norm((xc+1)/2)
824
- c = xc
825
- else:
826
- if isinstance(xc, dict) or isinstance(xc, list):
827
- c = self.get_learned_conditioning(xc)
828
- else:
829
- c = self.get_learned_conditioning(xc.to(self.device))
830
- c = c.float()
831
- else:
832
- if self.kwargs["use_imageCLIP"]:
833
- xc = resize(xc, (224,224))
834
- xc = self.imagenet_norm((xc+1)/2)
835
- c = xc
836
- if bs is not None:
837
- c = c[:bs]
838
-
839
- if self.use_positional_encodings:
840
- pos_x, pos_y = self.compute_latent_shifts(batch)
841
- ckey = __conditioning_keys__[self.model.conditioning_key]
842
- c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
843
-
844
- else:
845
- c = None
846
- xc = None
847
- if self.use_positional_encodings:
848
- pos_x, pos_y = self.compute_latent_shifts(batch)
849
- c = {'pos_x': pos_x, 'pos_y': pos_y}
850
-
851
- out = [z, c]
852
- if return_first_stage_outputs:
853
- xrec = self.decode_first_stage(z)
854
- out.extend([x, xrec])
855
- if return_x:
856
- out.extend([x])
857
- if return_original_cond:
858
- out.append(xc)
859
- return out
860
-
861
- @torch.no_grad()
862
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
863
- if predict_cids:
864
- if z.dim() == 4:
865
- z = torch.argmax(z.exp(), dim=1).long()
866
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
867
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
868
-
869
- z = 1. / self.scale_factor * z
870
- output = self.first_stage_model.decode(z)
871
- if not isinstance(output, DecoderOutput):
872
- return output
873
- else:
874
- return output.sample
875
- def decode_first_stage_train(self, z, predict_cids=False, force_not_quantize=False):
876
- if predict_cids:
877
- if z.dim() == 4:
878
- z = torch.argmax(z.exp(), dim=1).long()
879
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
880
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
881
-
882
- z = 1. / self.scale_factor * z
883
- return self.first_stage_model.decode(z)
884
-
885
- @torch.no_grad()
886
- def encode_first_stage(self, x):
887
- return self.first_stage_model.encode(x)
888
-
889
- def shared_step(self, batch, **kwargs):
890
- x, c = self.get_input(batch, self.first_stage_key)
891
- loss = self(x, c)
892
- return loss
893
-
894
- def forward(self, x, c, *args, **kwargs):
895
- if not self.use_pbe_weight:
896
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
897
- if self.model.conditioning_key is not None:
898
- assert c is not None
899
- if self.cond_stage_trainable:
900
- c = self.get_learned_conditioning(c)
901
- if self.shorten_cond_schedule: # TODO: drop this option
902
- tc = self.cond_ids[t].to(self.device)
903
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
904
- return self.p_losses(x, c, t, *args, **kwargs)
905
- # pbe negative condition
906
- else:
907
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
908
- self.u_cond_prop=random.uniform(0, 1)
909
- c["c_crossattn"] = [self.get_learned_conditioning(c["c_crossattn"])]
910
- if self.u_cond_prop < self.u_cond_percent:
911
- c["c_crossattn"] = [self.learnable_vector.repeat(x.shape[0],1,1)]
912
- return self.p_losses(x, c, t, *args, **kwargs)
913
-
914
-
915
- def apply_model(self, x_noisy, t, cond, return_ids=False):
916
- if isinstance(cond, dict):
917
- # hybrid case, cond is expected to be a dict
918
- pass
919
- else:
920
- if not isinstance(cond, list):
921
- cond = [cond]
922
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
923
- cond = {key: cond}
924
-
925
- x_recon = self.model(x_noisy, t, **cond)
926
-
927
- if isinstance(x_recon, tuple) and not return_ids:
928
- return x_recon[0]
929
- else:
930
- return x_recon
931
-
932
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
933
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
934
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
935
-
936
- def _prior_bpd(self, x_start):
937
- """
938
- Get the prior KL term for the variational lower-bound, measured in
939
- bits-per-dim.
940
- This term can't be optimized, as it only depends on the encoder.
941
- :param x_start: the [N x C x ...] tensor of inputs.
942
- :return: a batch of [N] KL values (in bits), one per batch element.
943
- """
944
- batch_size = x_start.shape[0]
945
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
946
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
947
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
948
- return mean_flat(kl_prior) / np.log(2.0)
949
- def p_losses(self, x_start, cond, t, noise=None):
950
- loss_dict = {}
951
- noise = default(noise, lambda: torch.randn_like(x_start))
952
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
953
- model_output, cond_output_dict = self.apply_model(x_noisy, t, cond)
954
-
955
- prefix = 'train' if self.training else 'val'
956
-
957
- if self.parameterization == "x0":
958
- target = x_start
959
- elif self.parameterization == "eps":
960
- target = noise
961
- elif self.parameterization == "v":
962
- target = self.get_v(x_start, noise, t)
963
- else:
964
- raise NotImplementedError()
965
- model_loss = None
966
- if isinstance(model_output, tuple):
967
- model_output, model_loss = model_output
968
-
969
- if self.only_agn_simple_loss:
970
- _, _, l_h, l_w = model_output.shape
971
- m_agn = F.interpolate(super().get_input(self.batch, "agn_mask"), (l_h, l_w))
972
- loss_simple = self.get_loss(model_output * (1-m_agn), target * (1-m_agn), mean=False).mean([1, 2, 3])
973
- else:
974
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
975
- loss_dict.update({f'simple': loss_simple.mean()})
976
-
977
- logvar_t = self.logvar[t].to(self.device)
978
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
979
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
980
- if self.learn_logvar:
981
- loss_dict.update({f'gamma': loss.mean()})
982
- loss_dict.update({'logvar': self.logvar.data.mean()})
983
- loss = self.l_simple_weight * loss.mean()
984
-
985
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
986
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
987
- if self.original_elbo_weight != 0:
988
- loss_dict.update({f'loss_vlb': loss_vlb})
989
- loss += (self.original_elbo_weight * loss_vlb)
990
-
991
- if model_loss is not None:
992
- loss += model_loss
993
- loss_dict.update({f"model loss" : model_loss})
994
- loss_dict.update({f'{prefix}_loss': loss})
995
-
996
- return loss, loss_dict
997
-
998
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
999
- return_x0=False, score_corrector=None, corrector_kwargs=None):
1000
- t_in = t
1001
- model_out, cond_output_dict = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1002
- if isinstance(model_out, tuple):
1003
- model_out, _ = model_out
1004
-
1005
- if score_corrector is not None:
1006
- assert self.parameterization == "eps"
1007
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1008
-
1009
- if return_codebook_ids:
1010
- model_out, logits = model_out
1011
-
1012
- if self.parameterization == "eps":
1013
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1014
- elif self.parameterization == "x0":
1015
- x_recon = model_out
1016
- else:
1017
- raise NotImplementedError()
1018
-
1019
- if clip_denoised:
1020
- x_recon.clamp_(-1., 1.)
1021
- if quantize_denoised:
1022
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1023
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1024
- if return_codebook_ids:
1025
- return model_mean, posterior_variance, posterior_log_variance, logits
1026
- elif return_x0:
1027
- return model_mean, posterior_variance, posterior_log_variance, x_recon
1028
- else:
1029
- return model_mean, posterior_variance, posterior_log_variance
1030
-
1031
- @torch.no_grad()
1032
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1033
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1034
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1035
- b, *_, device = *x.shape, x.device
1036
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1037
- return_codebook_ids=return_codebook_ids,
1038
- quantize_denoised=quantize_denoised,
1039
- return_x0=return_x0,
1040
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1041
- if return_codebook_ids:
1042
- raise DeprecationWarning("Support dropped.")
1043
- model_mean, _, model_log_variance, logits = outputs
1044
- elif return_x0:
1045
- model_mean, _, model_log_variance, x0 = outputs
1046
- else:
1047
- model_mean, _, model_log_variance = outputs
1048
-
1049
- noise = noise_like(x.shape, device, repeat_noise) * temperature
1050
- if noise_dropout > 0.:
1051
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1052
- # no noise when t == 0
1053
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1054
-
1055
- if return_codebook_ids:
1056
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1057
- if return_x0:
1058
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1059
- else:
1060
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1061
-
1062
- @torch.no_grad()
1063
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1064
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1065
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1066
- log_every_t=None):
1067
- if not log_every_t:
1068
- log_every_t = self.log_every_t
1069
- timesteps = self.num_timesteps
1070
- if batch_size is not None:
1071
- b = batch_size if batch_size is not None else shape[0]
1072
- shape = [batch_size] + list(shape)
1073
- else:
1074
- b = batch_size = shape[0]
1075
- if x_T is None:
1076
- img = torch.randn(shape, device=self.device)
1077
- else:
1078
- img = x_T
1079
- intermediates = []
1080
- if cond is not None:
1081
- if isinstance(cond, dict):
1082
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1083
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1084
- else:
1085
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1086
-
1087
- if start_T is not None:
1088
- timesteps = min(timesteps, start_T)
1089
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1090
- total=timesteps) if verbose else reversed(
1091
- range(0, timesteps))
1092
- if type(temperature) == float:
1093
- temperature = [temperature] * timesteps
1094
-
1095
- for i in iterator:
1096
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1097
- if self.shorten_cond_schedule:
1098
- assert self.model.conditioning_key != 'hybrid'
1099
- tc = self.cond_ids[ts].to(cond.device)
1100
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1101
-
1102
- img, x0_partial = self.p_sample(img, cond, ts,
1103
- clip_denoised=self.clip_denoised,
1104
- quantize_denoised=quantize_denoised, return_x0=True,
1105
- temperature=temperature[i], noise_dropout=noise_dropout,
1106
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1107
- if mask is not None:
1108
- assert x0 is not None
1109
- img_orig = self.q_sample(x0, ts)
1110
- img = img_orig * mask + (1. - mask) * img
1111
-
1112
- if i % log_every_t == 0 or i == timesteps - 1:
1113
- intermediates.append(x0_partial)
1114
- if callback: callback(i)
1115
- if img_callback: img_callback(img, i)
1116
- return img, intermediates
1117
-
1118
- @torch.no_grad()
1119
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1120
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1121
- mask=None, x0=None, img_callback=None, start_T=None,
1122
- log_every_t=None):
1123
-
1124
- if not log_every_t:
1125
- log_every_t = self.log_every_t
1126
- device = self.betas.device
1127
- b = shape[0]
1128
- if x_T is None:
1129
- img = torch.randn(shape, device=device)
1130
- else:
1131
- img = x_T
1132
-
1133
- intermediates = [img]
1134
- if timesteps is None:
1135
- timesteps = self.num_timesteps
1136
-
1137
- if start_T is not None:
1138
- timesteps = min(timesteps, start_T)
1139
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1140
- range(0, timesteps))
1141
-
1142
- if mask is not None:
1143
- assert x0 is not None
1144
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1145
-
1146
- for i in iterator:
1147
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1148
- if self.shorten_cond_schedule:
1149
- assert self.model.conditioning_key != 'hybrid'
1150
- tc = self.cond_ids[ts].to(cond.device)
1151
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1152
-
1153
- img = self.p_sample(img, cond, ts,
1154
- clip_denoised=self.clip_denoised,
1155
- quantize_denoised=quantize_denoised)
1156
- if mask is not None:
1157
- img_orig = self.q_sample(x0, ts)
1158
- img = img_orig * mask + (1. - mask) * img
1159
-
1160
- if i % log_every_t == 0 or i == timesteps - 1:
1161
- intermediates.append(img)
1162
- if callback: callback(i)
1163
- if img_callback: img_callback(img, i)
1164
-
1165
- if return_intermediates:
1166
- return img, intermediates
1167
- return img
1168
-
1169
- @torch.no_grad()
1170
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1171
- verbose=True, timesteps=None, quantize_denoised=False,
1172
- mask=None, x0=None, shape=None, **kwargs):
1173
- if shape is None:
1174
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1175
- if cond is not None:
1176
- if isinstance(cond, dict):
1177
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1178
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1179
- else:
1180
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1181
- return self.p_sample_loop(cond,
1182
- shape,
1183
- return_intermediates=return_intermediates, x_T=x_T,
1184
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1185
- mask=mask, x0=x0)
1186
-
1187
- @torch.no_grad()
1188
- def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1189
- if ddim:
1190
- ddim_sampler = DDIMSampler(self)
1191
- shape = (self.channels, self.image_size, self.image_size)
1192
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
1193
- shape, cond, verbose=False, **kwargs)
1194
-
1195
- else:
1196
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1197
- return_intermediates=True, **kwargs)
1198
-
1199
- return samples, intermediates
1200
-
1201
- @torch.no_grad()
1202
- def get_unconditional_conditioning(self, batch_size, null_label=None):
1203
- if null_label is not None:
1204
- xc = null_label
1205
- if isinstance(xc, ListConfig):
1206
- xc = list(xc)
1207
- if isinstance(xc, dict) or isinstance(xc, list):
1208
- c = self.get_learned_conditioning(xc)
1209
- else:
1210
- if hasattr(xc, "to"):
1211
- xc = xc.to(self.device)
1212
- c = self.get_learned_conditioning(xc)
1213
- else:
1214
- if self.cond_stage_key in ["class_label", "cls"]:
1215
- xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
1216
- return self.get_learned_conditioning(xc)
1217
- else:
1218
- raise NotImplementedError("todo")
1219
- if isinstance(c, list): # in case the encoder gives us a list
1220
- for i in range(len(c)):
1221
- c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
1222
- else:
1223
- c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
1224
- return c
1225
-
1226
- @torch.no_grad()
1227
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1228
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1229
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1230
- use_ema_scope=True,
1231
- **kwargs):
1232
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1233
- use_ddim = ddim_steps is not None
1234
-
1235
- log = dict()
1236
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1237
- return_first_stage_outputs=True,
1238
- force_c_encode=True,
1239
- return_original_cond=True,
1240
- bs=N)
1241
- N = min(x.shape[0], N)
1242
- n_row = min(x.shape[0], n_row)
1243
- log["inputs"] = x
1244
- log["reconstruction"] = xrec
1245
- if self.model.conditioning_key is not None:
1246
- if hasattr(self.cond_stage_model, "decode"):
1247
- xc = self.cond_stage_model.decode(c)
1248
- log["conditioning"] = xc
1249
- elif self.cond_stage_key in ["caption", "txt"]:
1250
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1251
- log["conditioning"] = xc
1252
- elif self.cond_stage_key in ['class_label', "cls"]:
1253
- try:
1254
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1255
- log['conditioning'] = xc
1256
- except KeyError:
1257
- # probably no "human_label" in batch
1258
- pass
1259
- elif isimage(xc):
1260
- log["conditioning"] = xc
1261
- if ismap(xc):
1262
- log["original_conditioning"] = self.to_rgb(xc)
1263
-
1264
- if plot_diffusion_rows:
1265
- # get diffusion row
1266
- diffusion_row = list()
1267
- z_start = z[:n_row]
1268
- for t in range(self.num_timesteps):
1269
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1270
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1271
- t = t.to(self.device).long()
1272
- noise = torch.randn_like(z_start)
1273
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1274
- diffusion_row.append(self.decode_first_stage(z_noisy))
1275
-
1276
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1277
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1278
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1279
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1280
- log["diffusion_row"] = diffusion_grid
1281
-
1282
- if sample:
1283
- # get denoise row
1284
- with ema_scope("Sampling"):
1285
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1286
- ddim_steps=ddim_steps, eta=ddim_eta)
1287
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1288
- x_samples = self.decode_first_stage(samples)
1289
- log["samples"] = x_samples
1290
- if plot_denoise_rows:
1291
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1292
- log["denoise_row"] = denoise_grid
1293
-
1294
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1295
- self.first_stage_model, IdentityFirstStage):
1296
- # also display when quantizing x0 while sampling
1297
- with ema_scope("Plotting Quantized Denoised"):
1298
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1299
- ddim_steps=ddim_steps, eta=ddim_eta,
1300
- quantize_denoised=True)
1301
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1302
- # quantize_denoised=True)
1303
- x_samples = self.decode_first_stage(samples.to(self.device))
1304
- log["samples_x0_quantized"] = x_samples
1305
-
1306
- if unconditional_guidance_scale > 1.0:
1307
- uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1308
- if self.model.conditioning_key == "crossattn-adm":
1309
- uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1310
- with ema_scope("Sampling with classifier-free guidance"):
1311
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1312
- ddim_steps=ddim_steps, eta=ddim_eta,
1313
- unconditional_guidance_scale=unconditional_guidance_scale,
1314
- unconditional_conditioning=uc,
1315
- )
1316
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1317
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1318
-
1319
- if inpaint:
1320
- # make a simple center square
1321
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1322
- mask = torch.ones(N, h, w).to(self.device)
1323
- # zeros will be filled in
1324
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1325
- mask = mask[:, None, ...]
1326
- with ema_scope("Plotting Inpaint"):
1327
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1328
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1329
- x_samples = self.decode_first_stage(samples.to(self.device))
1330
- log["samples_inpainting"] = x_samples
1331
- log["mask"] = mask
1332
-
1333
- # outpaint
1334
- mask = 1. - mask
1335
- with ema_scope("Plotting Outpaint"):
1336
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1337
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1338
- x_samples = self.decode_first_stage(samples.to(self.device))
1339
- log["samples_outpainting"] = x_samples
1340
-
1341
- if plot_progressive_rows:
1342
- with ema_scope("Plotting Progressives"):
1343
- img, progressives = self.progressive_denoising(c,
1344
- shape=(self.channels, self.image_size, self.image_size),
1345
- batch_size=N)
1346
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1347
- log["progressive_row"] = prog_row
1348
-
1349
- if return_keys:
1350
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1351
- return log
1352
- else:
1353
- return {key: log[key] for key in return_keys}
1354
- return log
1355
-
1356
- def configure_optimizers(self):
1357
- lr = self.learning_rate
1358
- params = list(self.model.parameters())
1359
- if self.cond_stage_trainable:
1360
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1361
- params = params + list(self.cond_stage_model.parameters())
1362
- if self.learn_logvar:
1363
- print('Diffusion model optimizing logvar')
1364
- params.append(self.logvar)
1365
- opt = torch.optim.AdamW(params, lr=lr)
1366
- if self.use_scheduler:
1367
- assert 'target' in self.scheduler_config
1368
- scheduler = instantiate_from_config(self.scheduler_config)
1369
-
1370
- print("Setting up LambdaLR scheduler...")
1371
- scheduler = [
1372
- {
1373
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1374
- 'interval': 'step',
1375
- 'frequency': 1
1376
- }]
1377
- return [opt], scheduler
1378
- return opt
1379
-
1380
- @torch.no_grad()
1381
- def to_rgb(self, x):
1382
- x = x.float()
1383
- if not hasattr(self, "colorize"):
1384
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1385
- x = nn.functional.conv2d(x, weight=self.colorize)
1386
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1387
- return x
1388
-
1389
-
1390
- class DiffusionWrapper(nn.Module):
1391
- def __init__(self, diff_model_config, conditioning_key):
1392
- super().__init__()
1393
- self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
1394
- self.diffusion_model = instantiate_from_config(diff_model_config)
1395
- self.conditioning_key = conditioning_key
1396
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1397
-
1398
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
1399
- if self.conditioning_key is None:
1400
- out = self.diffusion_model(x, t)
1401
- elif self.conditioning_key == 'concat':
1402
- xc = torch.cat([x] + c_concat, dim=1)
1403
- out = self.diffusion_model(xc, t)
1404
- elif self.conditioning_key == 'crossattn':
1405
- if not self.sequential_cross_attn:
1406
- cc = torch.cat(c_crossattn, 1)
1407
- else:
1408
- cc = c_crossattn
1409
- out = self.diffusion_model(x, t, context=cc)
1410
- elif self.conditioning_key == 'hybrid':
1411
- xc = torch.cat([x] + c_concat, dim=1)
1412
- cc = torch.cat(c_crossattn, 1)
1413
- out = self.diffusion_model(xc, t, context=cc)
1414
- elif self.conditioning_key == 'hybrid-adm':
1415
- assert c_adm is not None
1416
- xc = torch.cat([x] + c_concat, dim=1)
1417
- cc = torch.cat(c_crossattn, 1)
1418
- out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1419
- elif self.conditioning_key == 'crossattn-adm':
1420
- assert c_adm is not None
1421
- cc = torch.cat(c_crossattn, 1)
1422
- out = self.diffusion_model(x, t, context=cc, y=c_adm)
1423
- elif self.conditioning_key == 'adm':
1424
- cc = c_crossattn[0]
1425
- out = self.diffusion_model(x, t, y=cc)
1426
- else:
1427
- raise NotImplementedError()
1428
-
1429
- return out
1430
-
1431
-
1432
- class LatentUpscaleDiffusion(LatentDiffusion):
1433
- def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
1434
- super().__init__(*args, **kwargs)
1435
- # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1436
- assert not self.cond_stage_trainable
1437
- self.instantiate_low_stage(low_scale_config)
1438
- self.low_scale_key = low_scale_key
1439
- self.noise_level_key = noise_level_key
1440
-
1441
- def instantiate_low_stage(self, config):
1442
- model = instantiate_from_config(config)
1443
- self.low_scale_model = model.eval()
1444
- self.low_scale_model.train = disabled_train
1445
- for param in self.low_scale_model.parameters():
1446
- param.requires_grad = False
1447
-
1448
- @torch.no_grad()
1449
- def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1450
- if not log_mode:
1451
- z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1452
- else:
1453
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1454
- force_c_encode=True, return_original_cond=True, bs=bs)
1455
- x_low = batch[self.low_scale_key][:bs]
1456
- x_low = rearrange(x_low, 'b h w c -> b c h w')
1457
- x_low = x_low.to(memory_format=torch.contiguous_format).float()
1458
- zx, noise_level = self.low_scale_model(x_low)
1459
- if self.noise_level_key is not None:
1460
- # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1461
- raise NotImplementedError('TODO')
1462
-
1463
- all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1464
- if log_mode:
1465
- # TODO: maybe disable if too expensive
1466
- x_low_rec = self.low_scale_model.decode(zx)
1467
- return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1468
- return z, all_conds
1469
-
1470
- @torch.no_grad()
1471
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1472
- plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
1473
- unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
1474
- **kwargs):
1475
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1476
- use_ddim = ddim_steps is not None
1477
-
1478
- log = dict()
1479
- z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
1480
- log_mode=True)
1481
- N = min(x.shape[0], N)
1482
- n_row = min(x.shape[0], n_row)
1483
- log["inputs"] = x
1484
- log["reconstruction"] = xrec
1485
- log["x_lr"] = x_low
1486
- log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
1487
- if self.model.conditioning_key is not None:
1488
- if hasattr(self.cond_stage_model, "decode"):
1489
- xc = self.cond_stage_model.decode(c)
1490
- log["conditioning"] = xc
1491
- elif self.cond_stage_key in ["caption", "txt"]:
1492
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1493
- log["conditioning"] = xc
1494
- elif self.cond_stage_key in ['class_label', 'cls']:
1495
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1496
- log['conditioning'] = xc
1497
- elif isimage(xc):
1498
- log["conditioning"] = xc
1499
- if ismap(xc):
1500
- log["original_conditioning"] = self.to_rgb(xc)
1501
-
1502
- if plot_diffusion_rows:
1503
- # get diffusion row
1504
- diffusion_row = list()
1505
- z_start = z[:n_row]
1506
- for t in range(self.num_timesteps):
1507
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1508
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1509
- t = t.to(self.device).long()
1510
- noise = torch.randn_like(z_start)
1511
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1512
- diffusion_row.append(self.decode_first_stage(z_noisy))
1513
-
1514
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1515
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1516
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1517
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1518
- log["diffusion_row"] = diffusion_grid
1519
-
1520
- if sample:
1521
- # get denoise row
1522
- with ema_scope("Sampling"):
1523
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1524
- ddim_steps=ddim_steps, eta=ddim_eta)
1525
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1526
- x_samples = self.decode_first_stage(samples)
1527
- log["samples"] = x_samples
1528
- if plot_denoise_rows:
1529
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1530
- log["denoise_row"] = denoise_grid
1531
-
1532
- if unconditional_guidance_scale > 1.0:
1533
- uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1534
- # TODO explore better "unconditional" choices for the other keys
1535
- # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1536
- uc = dict()
1537
- for k in c:
1538
- if k == "c_crossattn":
1539
- assert isinstance(c[k], list) and len(c[k]) == 1
1540
- uc[k] = [uc_tmp]
1541
- elif k == "c_adm": # todo: only run with text-based guidance?
1542
- assert isinstance(c[k], torch.Tensor)
1543
- #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1544
- uc[k] = c[k]
1545
- elif isinstance(c[k], list):
1546
- uc[k] = [c[k][i] for i in range(len(c[k]))]
1547
- else:
1548
- uc[k] = c[k]
1549
-
1550
- with ema_scope("Sampling with classifier-free guidance"):
1551
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1552
- ddim_steps=ddim_steps, eta=ddim_eta,
1553
- unconditional_guidance_scale=unconditional_guidance_scale,
1554
- unconditional_conditioning=uc,
1555
- )
1556
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1557
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1558
-
1559
- if plot_progressive_rows:
1560
- with ema_scope("Plotting Progressives"):
1561
- img, progressives = self.progressive_denoising(c,
1562
- shape=(self.channels, self.image_size, self.image_size),
1563
- batch_size=N)
1564
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1565
- log["progressive_row"] = prog_row
1566
-
1567
- return log
1568
-
1569
-
1570
- class LatentFinetuneDiffusion(LatentDiffusion):
1571
- """
1572
- Basis for different finetunes, such as inpainting or depth2image
1573
- To disable finetuning mode, set finetune_keys to None
1574
- """
1575
-
1576
- def __init__(self,
1577
- concat_keys: tuple,
1578
- finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
1579
- "model_ema.diffusion_modelinput_blocks00weight"
1580
- ),
1581
- keep_finetune_dims=4,
1582
- # if model was trained without concat mode before and we would like to keep these channels
1583
- c_concat_log_start=None, # to log reconstruction of c_concat codes
1584
- c_concat_log_end=None,
1585
- *args, **kwargs
1586
- ):
1587
- ckpt_path = kwargs.pop("ckpt_path", None)
1588
- ignore_keys = kwargs.pop("ignore_keys", list())
1589
- super().__init__(*args, **kwargs)
1590
- self.finetune_keys = finetune_keys
1591
- self.concat_keys = concat_keys
1592
- self.keep_dims = keep_finetune_dims
1593
- self.c_concat_log_start = c_concat_log_start
1594
- self.c_concat_log_end = c_concat_log_end
1595
- if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
1596
- if exists(ckpt_path):
1597
- self.init_from_ckpt(ckpt_path, ignore_keys)
1598
-
1599
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
1600
- sd = torch.load(path, map_location="cpu")
1601
- if "state_dict" in list(sd.keys()):
1602
- sd = sd["state_dict"]
1603
- keys = list(sd.keys())
1604
- for k in keys:
1605
- for ik in ignore_keys:
1606
- if k.startswith(ik):
1607
- print("Deleting key {} from state_dict.".format(k))
1608
- del sd[k]
1609
-
1610
- # make it explicit, finetune by including extra input channels
1611
- if exists(self.finetune_keys) and k in self.finetune_keys:
1612
- new_entry = None
1613
- for name, param in self.named_parameters():
1614
- if name in self.finetune_keys:
1615
- print(
1616
- f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
1617
- new_entry = torch.zeros_like(param) # zero init
1618
- assert exists(new_entry), 'did not find matching parameter to modify'
1619
- new_entry[:, :self.keep_dims, ...] = sd[k]
1620
- sd[k] = new_entry
1621
-
1622
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
1623
- sd, strict=False)
1624
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
1625
- if len(missing) > 0:
1626
- print(f"Missing Keys: {missing}")
1627
- if len(unexpected) > 0:
1628
- print(f"Unexpected Keys: {unexpected}")
1629
-
1630
- @torch.no_grad()
1631
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1632
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1633
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1634
- use_ema_scope=True,
1635
- **kwargs):
1636
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1637
- use_ddim = ddim_steps is not None
1638
-
1639
- log = dict()
1640
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
1641
- c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
1642
- N = min(x.shape[0], N)
1643
- n_row = min(x.shape[0], n_row)
1644
- log["inputs"] = x
1645
- log["reconstruction"] = xrec
1646
- if self.model.conditioning_key is not None:
1647
- if hasattr(self.cond_stage_model, "decode"):
1648
- xc = self.cond_stage_model.decode(c)
1649
- log["conditioning"] = xc
1650
- elif self.cond_stage_key in ["caption", "txt"]:
1651
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1652
- log["conditioning"] = xc
1653
- elif self.cond_stage_key in ['class_label', 'cls']:
1654
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1655
- log['conditioning'] = xc
1656
- elif isimage(xc):
1657
- log["conditioning"] = xc
1658
- if ismap(xc):
1659
- log["original_conditioning"] = self.to_rgb(xc)
1660
-
1661
- if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
1662
- log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
1663
-
1664
- if plot_diffusion_rows:
1665
- # get diffusion row
1666
- diffusion_row = list()
1667
- z_start = z[:n_row]
1668
- for t in range(self.num_timesteps):
1669
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1670
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1671
- t = t.to(self.device).long()
1672
- noise = torch.randn_like(z_start)
1673
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1674
- diffusion_row.append(self.decode_first_stage(z_noisy))
1675
-
1676
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1677
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1678
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1679
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1680
- log["diffusion_row"] = diffusion_grid
1681
-
1682
- if sample:
1683
- # get denoise row
1684
- with ema_scope("Sampling"):
1685
- samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1686
- batch_size=N, ddim=use_ddim,
1687
- ddim_steps=ddim_steps, eta=ddim_eta)
1688
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1689
- x_samples = self.decode_first_stage(samples)
1690
- log["samples"] = x_samples
1691
- if plot_denoise_rows:
1692
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1693
- log["denoise_row"] = denoise_grid
1694
-
1695
- if unconditional_guidance_scale > 1.0:
1696
- uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1697
- uc_cat = c_cat
1698
- uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
1699
- with ema_scope("Sampling with classifier-free guidance"):
1700
- samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1701
- batch_size=N, ddim=use_ddim,
1702
- ddim_steps=ddim_steps, eta=ddim_eta,
1703
- unconditional_guidance_scale=unconditional_guidance_scale,
1704
- unconditional_conditioning=uc_full,
1705
- )
1706
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1707
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1708
-
1709
- return log
1710
-
1711
-
1712
- class LatentInpaintDiffusion(LatentFinetuneDiffusion):
1713
- """
1714
- can either run as pure inpainting model (only concat mode) or with mixed conditionings,
1715
- e.g. mask as concat and text via cross-attn.
1716
- To disable finetuning mode, set finetune_keys to None
1717
- """
1718
-
1719
- def __init__(self,
1720
- concat_keys=("mask", "masked_image"),
1721
- masked_image_key="masked_image",
1722
- *args, **kwargs
1723
- ):
1724
- super().__init__(concat_keys, *args, **kwargs)
1725
- self.masked_image_key = masked_image_key
1726
- assert self.masked_image_key in concat_keys
1727
-
1728
- @torch.no_grad()
1729
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1730
- # note: restricted to non-trainable encoders currently
1731
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
1732
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1733
- force_c_encode=True, return_original_cond=True, bs=bs)
1734
-
1735
- assert exists(self.concat_keys)
1736
- c_cat = list()
1737
- for ck in self.concat_keys:
1738
- cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1739
- if bs is not None:
1740
- cc = cc[:bs]
1741
- cc = cc.to(self.device)
1742
- bchw = z.shape
1743
- if ck != self.masked_image_key:
1744
- cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1745
- else:
1746
- cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1747
- c_cat.append(cc)
1748
- c_cat = torch.cat(c_cat, dim=1)
1749
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1750
- if return_first_stage_outputs:
1751
- return z, all_conds, x, xrec, xc
1752
- return z, all_conds
1753
-
1754
- @torch.no_grad()
1755
- def log_images(self, *args, **kwargs):
1756
- log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
1757
- log["masked_image"] = rearrange(args[0]["masked_image"],
1758
- 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1759
- return log
1760
-
1761
-
1762
- class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
1763
- """
1764
- condition on monocular depth estimation
1765
- """
1766
-
1767
- def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
1768
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
1769
- self.depth_model = instantiate_from_config(depth_stage_config)
1770
- self.depth_stage_key = concat_keys[0]
1771
-
1772
- @torch.no_grad()
1773
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1774
- # note: restricted to non-trainable encoders currently
1775
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
1776
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1777
- force_c_encode=True, return_original_cond=True, bs=bs)
1778
-
1779
- assert exists(self.concat_keys)
1780
- assert len(self.concat_keys) == 1
1781
- c_cat = list()
1782
- for ck in self.concat_keys:
1783
- cc = batch[ck]
1784
- if bs is not None:
1785
- cc = cc[:bs]
1786
- cc = cc.to(self.device)
1787
- cc = self.depth_model(cc)
1788
- cc = torch.nn.functional.interpolate(
1789
- cc,
1790
- size=z.shape[2:],
1791
- mode="bicubic",
1792
- align_corners=False,
1793
- )
1794
-
1795
- depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
1796
- keepdim=True)
1797
- cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
1798
- c_cat.append(cc)
1799
- c_cat = torch.cat(c_cat, dim=1)
1800
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1801
- if return_first_stage_outputs:
1802
- return z, all_conds, x, xrec, xc
1803
- return z, all_conds
1804
-
1805
- @torch.no_grad()
1806
- def log_images(self, *args, **kwargs):
1807
- log = super().log_images(*args, **kwargs)
1808
- depth = self.depth_model(args[0][self.depth_stage_key])
1809
- depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
1810
- torch.amax(depth, dim=[1, 2, 3], keepdim=True)
1811
- log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
1812
- return log
1813
-
1814
-
1815
- class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
1816
- """
1817
- condition on low-res image (and optionally on some spatial noise augmentation)
1818
- """
1819
- def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
1820
- low_scale_config=None, low_scale_key=None, *args, **kwargs):
1821
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
1822
- self.reshuffle_patch_size = reshuffle_patch_size
1823
- self.low_scale_model = None
1824
- if low_scale_config is not None:
1825
- print("Initializing a low-scale model")
1826
- assert exists(low_scale_key)
1827
- self.instantiate_low_stage(low_scale_config)
1828
- self.low_scale_key = low_scale_key
1829
-
1830
- def instantiate_low_stage(self, config):
1831
- model = instantiate_from_config(config)
1832
- self.low_scale_model = model.eval()
1833
- self.low_scale_model.train = disabled_train
1834
- for param in self.low_scale_model.parameters():
1835
- param.requires_grad = False
1836
-
1837
- @torch.no_grad()
1838
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1839
- # note: restricted to non-trainable encoders currently
1840
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
1841
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1842
- force_c_encode=True, return_original_cond=True, bs=bs)
1843
-
1844
- assert exists(self.concat_keys)
1845
- assert len(self.concat_keys) == 1
1846
- # optionally make spatial noise_level here
1847
- c_cat = list()
1848
- noise_level = None
1849
- for ck in self.concat_keys:
1850
- cc = batch[ck]
1851
- cc = rearrange(cc, 'b h w c -> b c h w')
1852
- if exists(self.reshuffle_patch_size):
1853
- assert isinstance(self.reshuffle_patch_size, int)
1854
- cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
1855
- p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
1856
- if bs is not None:
1857
- cc = cc[:bs]
1858
- cc = cc.to(self.device)
1859
- if exists(self.low_scale_model) and ck == self.low_scale_key:
1860
- cc, noise_level = self.low_scale_model(cc)
1861
- c_cat.append(cc)
1862
- c_cat = torch.cat(c_cat, dim=1)
1863
- if exists(noise_level):
1864
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
1865
- else:
1866
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1867
- if return_first_stage_outputs:
1868
- return z, all_conds, x, xrec, xc
1869
- return z, all_conds
1870
-
1871
- @torch.no_grad()
1872
- def log_images(self, *args, **kwargs):
1873
- log = super().log_images(*args, **kwargs)
1874
- log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1875
- return log
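The depth conditioning in the deleted `LatentDepth2ImageDiffusion.get_input` is rescaled per sample to [-1, 1] before being concatenated with the latent input. A minimal standalone sketch of just that normalization step (the tensor shapes and the `eps` name are illustrative, not part of the original code):

```python
import torch

def normalize_depth(depth: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Per-sample min/max over (C, H, W), then rescale to [-1, 1], mirroring the deleted get_input.
    d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    return 2.0 * (depth - d_min) / (d_max - d_min + eps) - 1.0

depth = torch.rand(2, 1, 96, 64)       # dummy depth-estimator output batch
cond = normalize_depth(depth)          # values now lie in roughly [-1, 1]
```
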
ldm/models/diffusion/dpm_solver/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .sampler import DPMSolverSampler
 
 
ldm/models/diffusion/dpm_solver/dpm_solver.py DELETED
@@ -1,1154 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import math
4
- from tqdm import tqdm
5
-
6
-
7
- class NoiseScheduleVP:
8
- def __init__(
9
- self,
10
- schedule='discrete',
11
- betas=None,
12
- alphas_cumprod=None,
13
- continuous_beta_0=0.1,
14
- continuous_beta_1=20.,
15
- ):
16
- """Create a wrapper class for the forward SDE (VP type).
17
- ***
18
- Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
19
- We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
- ***
21
- The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
- We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
- Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
- log_alpha_t = self.marginal_log_mean_coeff(t)
25
- sigma_t = self.marginal_std(t)
26
- lambda_t = self.marginal_lambda(t)
27
- Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
- t = self.inverse_lambda(lambda_t)
29
- ===============================================================
30
- We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
- 1. For discrete-time DPMs:
32
- For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
- t_i = (i + 1) / N
34
- e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
- We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
- Args:
37
- betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
- alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
- Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
- **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
- The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
- q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
- Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
- alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
- and
46
- log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
- 2. For continuous-time DPMs:
48
- We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
- schedule are the default settings in DDPM and improved-DDPM:
50
- Args:
51
- beta_min: A `float` number. The smallest beta for the linear schedule.
52
- beta_max: A `float` number. The largest beta for the linear schedule.
53
- cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
- cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
- T: A `float` number. The ending time of the forward process.
56
- ===============================================================
57
- Args:
58
- schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
- 'linear' or 'cosine' for continuous-time DPMs.
60
- Returns:
61
- A wrapper object of the forward SDE (VP type).
62
-
63
- ===============================================================
64
- Example:
65
- # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
- >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
- # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
- >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
- # For continuous-time DPMs (VPSDE), linear schedule:
70
- >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
- """
72
-
73
- if schedule not in ['discrete', 'linear', 'cosine']:
74
- raise ValueError(
75
- "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
- schedule))
77
-
78
- self.schedule = schedule
79
- if schedule == 'discrete':
80
- if betas is not None:
81
- log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
- else:
83
- assert alphas_cumprod is not None
84
- log_alphas = 0.5 * torch.log(alphas_cumprod)
85
- self.total_N = len(log_alphas)
86
- self.T = 1.
87
- self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
- self.log_alpha_array = log_alphas.reshape((1, -1,))
89
- else:
90
- self.total_N = 1000
91
- self.beta_0 = continuous_beta_0
92
- self.beta_1 = continuous_beta_1
93
- self.cosine_s = 0.008
94
- self.cosine_beta_max = 999.
95
- self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
- 1. + self.cosine_s) / math.pi - self.cosine_s
97
- self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
- self.schedule = schedule
99
- if schedule == 'cosine':
100
- # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
- # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
102
- self.T = 0.9946
103
- else:
104
- self.T = 1.
105
-
106
- def marginal_log_mean_coeff(self, t):
107
- """
108
- Compute log(alpha_t) of a given continuous-time label t in [0, T].
109
- """
110
- if self.schedule == 'discrete':
111
- return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
112
- self.log_alpha_array.to(t.device)).reshape((-1))
113
- elif self.schedule == 'linear':
114
- return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
115
- elif self.schedule == 'cosine':
116
- log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
117
- log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
118
- return log_alpha_t
119
-
120
- def marginal_alpha(self, t):
121
- """
122
- Compute alpha_t of a given continuous-time label t in [0, T].
123
- """
124
- return torch.exp(self.marginal_log_mean_coeff(t))
125
-
126
- def marginal_std(self, t):
127
- """
128
- Compute sigma_t of a given continuous-time label t in [0, T].
129
- """
130
- return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
131
-
132
- def marginal_lambda(self, t):
133
- """
134
- Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
135
- """
136
- log_mean_coeff = self.marginal_log_mean_coeff(t)
137
- log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
138
- return log_mean_coeff - log_std
139
-
140
- def inverse_lambda(self, lamb):
141
- """
142
- Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
143
- """
144
- if self.schedule == 'linear':
145
- tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
146
- Delta = self.beta_0 ** 2 + tmp
147
- return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
148
- elif self.schedule == 'discrete':
149
- log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
150
- t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
151
- torch.flip(self.t_array.to(lamb.device), [1]))
152
- return t.reshape((-1,))
153
- else:
154
- log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
155
- t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
156
- 1. + self.cosine_s) / math.pi - self.cosine_s
157
- t = t_fn(log_alpha)
158
- return t
159
-
160
-
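The wrapper above only needs `betas` (or `alphas_cumprod`) to expose alpha_t, sigma_t and the half-logSNR lambda_t as functions of continuous time. A minimal usage sketch for the discrete case; the linear beta schedule below is purely illustrative:

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)          # illustrative DDPM-style schedule, N = 1000
ns = NoiseScheduleVP('discrete', betas=betas)

t = torch.tensor([0.5])
alpha_t  = ns.marginal_alpha(t)                   # exp(log alpha_t)
sigma_t  = ns.marginal_std(t)                     # sqrt(1 - alpha_t ** 2)
lambda_t = ns.marginal_lambda(t)                  # log(alpha_t) - log(sigma_t)
t_back   = ns.inverse_lambda(lambda_t)            # approximately recovers t (piecewise-linear interpolation)
```
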
161
- def model_wrapper(
162
- model,
163
- noise_schedule,
164
- model_type="noise",
165
- model_kwargs={},
166
- guidance_type="uncond",
167
- condition=None,
168
- unconditional_condition=None,
169
- guidance_scale=1.,
170
- classifier_fn=None,
171
- classifier_kwargs={},
172
- ):
173
- """Create a wrapper function for the noise prediction model.
174
- DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
- firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
176
- We support four types of the diffusion model by setting `model_type`:
177
- 1. "noise": noise prediction model. (Trained by predicting noise).
178
- 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
- 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
- The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
- [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
- arXiv preprint arXiv:2202.00512 (2022).
183
- [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
- arXiv preprint arXiv:2210.02303 (2022).
185
-
186
- 4. "score": marginal score function. (Trained by denoising score matching).
187
- Note that the score function and the noise prediction model follows a simple relationship:
188
- ```
189
- noise(x_t, t) = -sigma_t * score(x_t, t)
190
- ```
191
- We support three types of guided sampling by DPMs by setting `guidance_type`:
192
- 1. "uncond": unconditional sampling by DPMs.
193
- The input `model` has the following format:
194
- ``
195
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
- ``
197
- 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
- The input `model` has the following format:
199
- ``
200
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
- ``
202
- The input `classifier_fn` has the following format:
203
- ``
204
- classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
- ``
206
- [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
- in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
- 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
- The input `model` has the following format:
210
- ``
211
- model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
- ``
213
- And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
- [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
- arXiv preprint arXiv:2207.12598 (2022).
216
-
217
- The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
- or continuous-time labels (i.e. epsilon to T).
219
- We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
220
- ``
221
- def model_fn(x, t_continuous) -> noise:
222
- t_input = get_model_input_time(t_continuous)
223
- return noise_pred(model, x, t_input, **model_kwargs)
224
- ``
225
- where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
- ===============================================================
227
- Args:
228
- model: A diffusion model with the corresponding format described above.
229
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
- model_type: A `str`. The parameterization type of the diffusion model.
231
- "noise" or "x_start" or "v" or "score".
232
- model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
- guidance_type: A `str`. The type of the guidance for sampling.
234
- "uncond" or "classifier" or "classifier-free".
235
- condition: A pytorch tensor. The condition for the guided sampling.
236
- Only used for "classifier" or "classifier-free" guidance type.
237
- unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
- Only used for "classifier-free" guidance type.
239
- guidance_scale: A `float`. The scale for the guided sampling.
240
- classifier_fn: A classifier function. Only used for the classifier guidance.
241
- classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
- Returns:
243
- A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
- """
245
-
246
- def get_model_input_time(t_continuous):
247
- """
248
- Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
- For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
- For continuous-time DPMs, we just use `t_continuous`.
251
- """
252
- if noise_schedule.schedule == 'discrete':
253
- return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
- else:
255
- return t_continuous
256
-
257
- def noise_pred_fn(x, t_continuous, cond=None):
258
- if t_continuous.reshape((-1,)).shape[0] == 1:
259
- t_continuous = t_continuous.expand((x.shape[0]))
260
- t_input = get_model_input_time(t_continuous)
261
- if cond is None:
262
- output = model(x, t_input, **model_kwargs)
263
- else:
264
- output = model(x, t_input, cond, **model_kwargs)
265
- if model_type == "noise":
266
- return output
267
- elif model_type == "x_start":
268
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
- dims = x.dim()
270
- return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
- elif model_type == "v":
272
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
- dims = x.dim()
274
- return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
- elif model_type == "score":
276
- sigma_t = noise_schedule.marginal_std(t_continuous)
277
- dims = x.dim()
278
- return -expand_dims(sigma_t, dims) * output
279
-
280
- def cond_grad_fn(x, t_input):
281
- """
282
- Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
- """
284
- with torch.enable_grad():
285
- x_in = x.detach().requires_grad_(True)
286
- log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
- return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
-
289
- def model_fn(x, t_continuous):
290
- """
291
- The noise prediction model function that is used for DPM-Solver.
292
- """
293
- if t_continuous.reshape((-1,)).shape[0] == 1:
294
- t_continuous = t_continuous.expand((x.shape[0]))
295
- if guidance_type == "uncond":
296
- return noise_pred_fn(x, t_continuous)
297
- elif guidance_type == "classifier":
298
- assert classifier_fn is not None
299
- t_input = get_model_input_time(t_continuous)
300
- cond_grad = cond_grad_fn(x, t_input)
301
- sigma_t = noise_schedule.marginal_std(t_continuous)
302
- noise = noise_pred_fn(x, t_continuous)
303
- return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
- elif guidance_type == "classifier-free":
305
- if guidance_scale == 1. or unconditional_condition is None:
306
- return noise_pred_fn(x, t_continuous, cond=condition)
307
- else:
308
- x_in = torch.cat([x] * 2)
309
- t_in = torch.cat([t_continuous] * 2)
310
- c_in = torch.cat([unconditional_condition, condition])
311
- noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
312
- return noise_uncond + guidance_scale * (noise - noise_uncond)
313
-
314
- assert model_type in ["noise", "x_start", "v"]
315
- assert guidance_type in ["uncond", "classifier", "classifier-free"]
316
- return model_fn
317
-
318
-
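For classifier-free guidance the wrapper doubles the batch internally and returns eps_uncond + guidance_scale * (eps_cond - eps_uncond), so the solver only ever sees a single noise-prediction function. A runnable sketch with a dummy model standing in for a real conditional DPM (all shapes and the embedding size are illustrative assumptions):

```python
import torch

class DummyEps(torch.nn.Module):
    # Stand-in for a real conditional noise predictor: model(x, t_input, cond) -> eps.
    def forward(self, x, t_input, cond):
        return torch.zeros_like(x)

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)

model_fn = model_wrapper(
    DummyEps(), ns,
    model_type="noise",
    guidance_type="classifier-free",
    condition=torch.zeros(2, 77, 768),               # e.g. text embeddings for the prompt
    unconditional_condition=torch.zeros(2, 77, 768),  # e.g. empty-prompt embeddings
    guidance_scale=7.5,
)
eps = model_fn(torch.randn(2, 4, 64, 64), torch.tensor([0.5]))  # guided noise prediction
```
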
319
- class DPM_Solver:
320
- def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
321
- """Construct a DPM-Solver.
322
- We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
323
- If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
324
- If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
325
- In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
326
- The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
327
- Args:
328
- model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
329
- ``
330
- def model_fn(x, t_continuous):
331
- return noise
332
- ``
333
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
334
- predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
335
- thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
336
- max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
337
-
338
- [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
339
- """
340
- self.model = model_fn
341
- self.noise_schedule = noise_schedule
342
- self.predict_x0 = predict_x0
343
- self.thresholding = thresholding
344
- self.max_val = max_val
345
-
346
- def noise_prediction_fn(self, x, t):
347
- """
348
- Return the noise prediction model.
349
- """
350
- return self.model(x, t)
351
-
352
- def data_prediction_fn(self, x, t):
353
- """
354
- Return the data prediction model (with thresholding).
355
- """
356
- noise = self.noise_prediction_fn(x, t)
357
- dims = x.dim()
358
- alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
359
- x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
360
- if self.thresholding:
361
- p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
362
- s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
363
- s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
364
- x0 = torch.clamp(x0, -s, s) / s
365
- return x0
366
-
367
- def model_fn(self, x, t):
368
- """
369
- Convert the model to the noise prediction model or the data prediction model.
370
- """
371
- if self.predict_x0:
372
- return self.data_prediction_fn(x, t)
373
- else:
374
- return self.noise_prediction_fn(x, t)
375
-
376
- def get_time_steps(self, skip_type, t_T, t_0, N, device):
377
- """Compute the intermediate time steps for sampling.
378
- Args:
379
- skip_type: A `str`. The type for the spacing of the time steps. We support three types:
380
- - 'logSNR': uniform logSNR for the time steps.
381
- - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
382
- - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
383
- t_T: A `float`. The starting time of the sampling (default is T).
384
- t_0: A `float`. The ending time of the sampling (default is epsilon).
385
- N: A `int`. The total number of the spacing of the time steps.
386
- device: A torch device.
387
- Returns:
388
- A pytorch tensor of the time steps, with the shape (N + 1,).
389
- """
390
- if skip_type == 'logSNR':
391
- lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
392
- lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
393
- logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
394
- return self.noise_schedule.inverse_lambda(logSNR_steps)
395
- elif skip_type == 'time_uniform':
396
- return torch.linspace(t_T, t_0, N + 1).to(device)
397
- elif skip_type == 'time_quadratic':
398
- t_order = 2
399
- t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
400
- return t
401
- else:
402
- raise ValueError(
403
- "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
-
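The three spacings above can be compared directly on a toy configuration; only the noise schedule is used, so any callable can stand in for the model. A small sketch (the schedule and the endpoints are illustrative):

```python
import torch

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP('discrete', betas=betas)
solver = DPM_Solver(lambda x, t: x, ns)            # the model is never called by get_time_steps

for skip in ('time_uniform', 'time_quadratic', 'logSNR'):
    ts = solver.get_time_steps(skip, t_T=1.0, t_0=1e-3, N=4, device='cpu')
    print(skip, ts)                                # 5 decreasing time points from t_T down to t_0
```
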
405
- def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
406
- """
407
- Get the order of each step for sampling by the singlestep DPM-Solver.
408
- We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named "DPM-Solver-fast".
409
- Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
410
- - If order == 1:
411
- We take `steps` of DPM-Solver-1 (i.e. DDIM).
412
- - If order == 2:
413
- - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
414
- - If steps % 2 == 0, we use K steps of DPM-Solver-2.
415
- - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
416
- - If order == 3:
417
- - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
418
- - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
419
- - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
420
- - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
421
- ============================================
422
- Args:
423
- order: A `int`. The max order for the solver (2 or 3).
424
- steps: A `int`. The total number of function evaluations (NFE).
425
- skip_type: A `str`. The type for the spacing of the time steps. We support three types:
426
- - 'logSNR': uniform logSNR for the time steps.
427
- - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
428
- - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
429
- t_T: A `float`. The starting time of the sampling (default is T).
430
- t_0: A `float`. The ending time of the sampling (default is epsilon).
431
- device: A torch device.
432
- Returns:
433
- orders: A list of the solver order of each step.
434
- """
435
- if order == 3:
436
- K = steps // 3 + 1
437
- if steps % 3 == 0:
438
- orders = [3, ] * (K - 2) + [2, 1]
439
- elif steps % 3 == 1:
440
- orders = [3, ] * (K - 1) + [1]
441
- else:
442
- orders = [3, ] * (K - 1) + [2]
443
- elif order == 2:
444
- if steps % 2 == 0:
445
- K = steps // 2
446
- orders = [2, ] * K
447
- else:
448
- K = steps // 2 + 1
449
- orders = [2, ] * (K - 1) + [1]
450
- elif order == 1:
451
- K = 1
452
- orders = [1, ] * steps
453
- else:
454
- raise ValueError("'order' must be '1' or '2' or '3'.")
455
- if skip_type == 'logSNR':
456
- # To reproduce the results in DPM-Solver paper
457
- timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
- else:
459
- timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
460
- torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
461
- return timesteps_outer, orders
462
-
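As a concrete instance of the bookkeeping above: for steps = 20 and order = 3, K = 20 // 3 + 1 = 7 and 20 % 3 == 2, so the schedule is six steps of DPM-Solver-3 followed by one step of DPM-Solver-2, i.e. exactly 20 function evaluations. A standalone check of just that rule (no model or schedule needed):

```python
def split_orders_3(steps: int) -> list:
    # The order == 3 branch of the rule documented above.
    K = steps // 3 + 1
    if steps % 3 == 0:
        return [3] * (K - 2) + [2, 1]
    if steps % 3 == 1:
        return [3] * (K - 1) + [1]
    return [3] * (K - 1) + [2]

orders = split_orders_3(20)
assert orders == [3, 3, 3, 3, 3, 3, 2] and sum(orders) == 20
```
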
463
- def denoise_to_zero_fn(self, x, s):
464
- """
465
- Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
466
- """
467
- return self.data_prediction_fn(x, s)
468
-
469
- def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
470
- """
471
- DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
472
- Args:
473
- x: A pytorch tensor. The initial value at time `s`.
474
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
475
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
476
- model_s: A pytorch tensor. The model function evaluated at time `s`.
477
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
478
- return_intermediate: A `bool`. If true, also return the model value at time `s`.
479
- Returns:
480
- x_t: A pytorch tensor. The approximated solution at time `t`.
481
- """
482
- ns = self.noise_schedule
483
- dims = x.dim()
484
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
485
- h = lambda_t - lambda_s
486
- log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
487
- sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
488
- alpha_t = torch.exp(log_alpha_t)
489
-
490
- if self.predict_x0:
491
- phi_1 = torch.expm1(-h)
492
- if model_s is None:
493
- model_s = self.model_fn(x, s)
494
- x_t = (
495
- expand_dims(sigma_t / sigma_s, dims) * x
496
- - expand_dims(alpha_t * phi_1, dims) * model_s
497
- )
498
- if return_intermediate:
499
- return x_t, {'model_s': model_s}
500
- else:
501
- return x_t
502
- else:
503
- phi_1 = torch.expm1(h)
504
- if model_s is None:
505
- model_s = self.model_fn(x, s)
506
- x_t = (
507
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
508
- - expand_dims(sigma_t * phi_1, dims) * model_s
509
- )
510
- if return_intermediate:
511
- return x_t, {'model_s': model_s}
512
- else:
513
- return x_t
514
-
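Written out, the first-order update above is DDIM in the half-logSNR parameterisation. With h = lambda_t - lambda_s, the two branches of the code compute

```latex
x_t = \frac{\sigma_t}{\sigma_s}\, x_s \;-\; \alpha_t \bigl(e^{-h} - 1\bigr)\, x_\theta(x_s, s) \quad \text{(data prediction)}, \\
x_t = \frac{\alpha_t}{\alpha_s}\, x_s \;-\; \sigma_t \bigl(e^{h} - 1\bigr)\, \epsilon_\theta(x_s, s) \quad \text{(noise prediction)},
```

which correspond to the `phi_1 = torch.expm1(-h)` and `phi_1 = torch.expm1(h)` cases respectively.
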
515
- def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
516
- solver_type='dpm_solver'):
517
- """
518
- Singlestep solver DPM-Solver-2 from time `s` to time `t`.
519
- Args:
520
- x: A pytorch tensor. The initial value at time `s`.
521
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
522
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
523
- r1: A `float`. The hyperparameter of the second-order solver.
524
- model_s: A pytorch tensor. The model function evaluated at time `s`.
525
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
526
- return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
527
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
528
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
529
- Returns:
530
- x_t: A pytorch tensor. The approximated solution at time `t`.
531
- """
532
- if solver_type not in ['dpm_solver', 'taylor']:
533
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
534
- if r1 is None:
535
- r1 = 0.5
536
- ns = self.noise_schedule
537
- dims = x.dim()
538
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
539
- h = lambda_t - lambda_s
540
- lambda_s1 = lambda_s + r1 * h
541
- s1 = ns.inverse_lambda(lambda_s1)
542
- log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
543
- s1), ns.marginal_log_mean_coeff(t)
544
- sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
545
- alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
546
-
547
- if self.predict_x0:
548
- phi_11 = torch.expm1(-r1 * h)
549
- phi_1 = torch.expm1(-h)
550
-
551
- if model_s is None:
552
- model_s = self.model_fn(x, s)
553
- x_s1 = (
554
- expand_dims(sigma_s1 / sigma_s, dims) * x
555
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
556
- )
557
- model_s1 = self.model_fn(x_s1, s1)
558
- if solver_type == 'dpm_solver':
559
- x_t = (
560
- expand_dims(sigma_t / sigma_s, dims) * x
561
- - expand_dims(alpha_t * phi_1, dims) * model_s
562
- - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
563
- )
564
- elif solver_type == 'taylor':
565
- x_t = (
566
- expand_dims(sigma_t / sigma_s, dims) * x
567
- - expand_dims(alpha_t * phi_1, dims) * model_s
568
- + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
569
- model_s1 - model_s)
570
- )
571
- else:
572
- phi_11 = torch.expm1(r1 * h)
573
- phi_1 = torch.expm1(h)
574
-
575
- if model_s is None:
576
- model_s = self.model_fn(x, s)
577
- x_s1 = (
578
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
579
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
580
- )
581
- model_s1 = self.model_fn(x_s1, s1)
582
- if solver_type == 'dpm_solver':
583
- x_t = (
584
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
585
- - expand_dims(sigma_t * phi_1, dims) * model_s
586
- - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
587
- )
588
- elif solver_type == 'taylor':
589
- x_t = (
590
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
591
- - expand_dims(sigma_t * phi_1, dims) * model_s
592
- - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
593
- )
594
- if return_intermediate:
595
- return x_t, {'model_s': model_s, 'model_s1': model_s1}
596
- else:
597
- return x_t
598
-
599
- def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
600
- return_intermediate=False, solver_type='dpm_solver'):
601
- """
602
- Singlestep solver DPM-Solver-3 from time `s` to time `t`.
603
- Args:
604
- x: A pytorch tensor. The initial value at time `s`.
605
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
606
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
607
- r1: A `float`. The hyperparameter of the third-order solver.
608
- r2: A `float`. The hyperparameter of the third-order solver.
609
- model_s: A pytorch tensor. The model function evaluated at time `s`.
610
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
- model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
612
- If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
613
- return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
614
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
615
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
616
- Returns:
617
- x_t: A pytorch tensor. The approximated solution at time `t`.
618
- """
619
- if solver_type not in ['dpm_solver', 'taylor']:
620
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
621
- if r1 is None:
622
- r1 = 1. / 3.
623
- if r2 is None:
624
- r2 = 2. / 3.
625
- ns = self.noise_schedule
626
- dims = x.dim()
627
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
628
- h = lambda_t - lambda_s
629
- lambda_s1 = lambda_s + r1 * h
630
- lambda_s2 = lambda_s + r2 * h
631
- s1 = ns.inverse_lambda(lambda_s1)
632
- s2 = ns.inverse_lambda(lambda_s2)
633
- log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
634
- s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
635
- sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
636
- s2), ns.marginal_std(t)
637
- alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
638
-
639
- if self.predict_x0:
640
- phi_11 = torch.expm1(-r1 * h)
641
- phi_12 = torch.expm1(-r2 * h)
642
- phi_1 = torch.expm1(-h)
643
- phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
644
- phi_2 = phi_1 / h + 1.
645
- phi_3 = phi_2 / h - 0.5
646
-
647
- if model_s is None:
648
- model_s = self.model_fn(x, s)
649
- if model_s1 is None:
650
- x_s1 = (
651
- expand_dims(sigma_s1 / sigma_s, dims) * x
652
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
653
- )
654
- model_s1 = self.model_fn(x_s1, s1)
655
- x_s2 = (
656
- expand_dims(sigma_s2 / sigma_s, dims) * x
657
- - expand_dims(alpha_s2 * phi_12, dims) * model_s
658
- + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
659
- )
660
- model_s2 = self.model_fn(x_s2, s2)
661
- if solver_type == 'dpm_solver':
662
- x_t = (
663
- expand_dims(sigma_t / sigma_s, dims) * x
664
- - expand_dims(alpha_t * phi_1, dims) * model_s
665
- + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
666
- )
667
- elif solver_type == 'taylor':
668
- D1_0 = (1. / r1) * (model_s1 - model_s)
669
- D1_1 = (1. / r2) * (model_s2 - model_s)
670
- D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
671
- D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
672
- x_t = (
673
- expand_dims(sigma_t / sigma_s, dims) * x
674
- - expand_dims(alpha_t * phi_1, dims) * model_s
675
- + expand_dims(alpha_t * phi_2, dims) * D1
676
- - expand_dims(alpha_t * phi_3, dims) * D2
677
- )
678
- else:
679
- phi_11 = torch.expm1(r1 * h)
680
- phi_12 = torch.expm1(r2 * h)
681
- phi_1 = torch.expm1(h)
682
- phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
683
- phi_2 = phi_1 / h - 1.
684
- phi_3 = phi_2 / h - 0.5
685
-
686
- if model_s is None:
687
- model_s = self.model_fn(x, s)
688
- if model_s1 is None:
689
- x_s1 = (
690
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
691
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
692
- )
693
- model_s1 = self.model_fn(x_s1, s1)
694
- x_s2 = (
695
- expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
696
- - expand_dims(sigma_s2 * phi_12, dims) * model_s
697
- - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
698
- )
699
- model_s2 = self.model_fn(x_s2, s2)
700
- if solver_type == 'dpm_solver':
701
- x_t = (
702
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
703
- - expand_dims(sigma_t * phi_1, dims) * model_s
704
- - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
705
- )
706
- elif solver_type == 'taylor':
707
- D1_0 = (1. / r1) * (model_s1 - model_s)
708
- D1_1 = (1. / r2) * (model_s2 - model_s)
709
- D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
710
- D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
711
- x_t = (
712
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
713
- - expand_dims(sigma_t * phi_1, dims) * model_s
714
- - expand_dims(sigma_t * phi_2, dims) * D1
715
- - expand_dims(sigma_t * phi_3, dims) * D2
716
- )
717
-
718
- if return_intermediate:
719
- return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
720
- else:
721
- return x_t
722
-
723
- def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
724
- """
725
- Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
726
- Args:
727
- x: A pytorch tensor. The initial value at time `s`.
728
- model_prev_list: A list of pytorch tensor. The previous computed model values.
729
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
730
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
731
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
732
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
733
- Returns:
734
- x_t: A pytorch tensor. The approximated solution at time `t`.
735
- """
736
- if solver_type not in ['dpm_solver', 'taylor']:
737
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
738
- ns = self.noise_schedule
739
- dims = x.dim()
740
- model_prev_1, model_prev_0 = model_prev_list
741
- t_prev_1, t_prev_0 = t_prev_list
742
- lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
743
- t_prev_0), ns.marginal_lambda(t)
744
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
745
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
746
- alpha_t = torch.exp(log_alpha_t)
747
-
748
- h_0 = lambda_prev_0 - lambda_prev_1
749
- h = lambda_t - lambda_prev_0
750
- r0 = h_0 / h
751
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
752
- if self.predict_x0:
753
- if solver_type == 'dpm_solver':
754
- x_t = (
755
- expand_dims(sigma_t / sigma_prev_0, dims) * x
756
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
757
- - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
758
- )
759
- elif solver_type == 'taylor':
760
- x_t = (
761
- expand_dims(sigma_t / sigma_prev_0, dims) * x
762
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
763
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
764
- )
765
- else:
766
- if solver_type == 'dpm_solver':
767
- x_t = (
768
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
769
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
770
- - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
771
- )
772
- elif solver_type == 'taylor':
773
- x_t = (
774
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
775
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
776
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
777
- )
778
- return x_t
779
-
780
- def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
781
- """
782
- Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
783
- Args:
784
- x: A pytorch tensor. The initial value at time `s`.
785
- model_prev_list: A list of pytorch tensor. The previous computed model values.
786
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
787
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
788
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
789
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
790
- Returns:
791
- x_t: A pytorch tensor. The approximated solution at time `t`.
792
- """
793
- ns = self.noise_schedule
794
- dims = x.dim()
795
- model_prev_2, model_prev_1, model_prev_0 = model_prev_list
796
- t_prev_2, t_prev_1, t_prev_0 = t_prev_list
797
- lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
798
- t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
799
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
800
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
801
- alpha_t = torch.exp(log_alpha_t)
802
-
803
- h_1 = lambda_prev_1 - lambda_prev_2
804
- h_0 = lambda_prev_0 - lambda_prev_1
805
- h = lambda_t - lambda_prev_0
806
- r0, r1 = h_0 / h, h_1 / h
807
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
808
- D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
809
- D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
810
- D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
811
- if self.predict_x0:
812
- x_t = (
813
- expand_dims(sigma_t / sigma_prev_0, dims) * x
814
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
815
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
816
- - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
817
- )
818
- else:
819
- x_t = (
820
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
821
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
822
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
823
- - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
824
- )
825
- return x_t
826
-
827
- def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
- r2=None):
829
- """
830
- Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
- Args:
832
- x: A pytorch tensor. The initial value at time `s`.
833
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
- order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
- return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
839
- r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
- r2: A `float`. The hyperparameter of the third-order solver.
841
- Returns:
842
- x_t: A pytorch tensor. The approximated solution at time `t`.
843
- """
844
- if order == 1:
845
- return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
- elif order == 2:
847
- return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
- solver_type=solver_type, r1=r1)
849
- elif order == 3:
850
- return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
- solver_type=solver_type, r1=r1, r2=r2)
852
- else:
853
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
-
855
- def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
- """
857
- Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
- Args:
859
- x: A pytorch tensor. The initial value at time `s`.
860
- model_prev_list: A list of pytorch tensor. The previous computed model values.
861
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
- order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
866
- Returns:
867
- x_t: A pytorch tensor. The approximated solution at time `t`.
868
- """
869
- if order == 1:
870
- return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
- elif order == 2:
872
- return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
- elif order == 3:
874
- return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
- else:
876
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
-
878
- def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
879
- solver_type='dpm_solver'):
880
- """
881
- The adaptive step size solver based on singlestep DPM-Solver.
882
- Args:
883
- x: A pytorch tensor. The initial value at time `t_T`.
884
- order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
885
- t_T: A `float`. The starting time of the sampling (default is T).
886
- t_0: A `float`. The ending time of the sampling (default is epsilon).
887
- h_init: A `float`. The initial step size (for logSNR).
888
- atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
889
- rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
890
- theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
891
- t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
892
- current time and `t_0` is less than `t_err`. The default setting is 1e-5.
893
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
894
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
895
- Returns:
896
- x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
- [1] A. Jolicoeur-Martineau, K. Li, R. PichΓ©-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
898
- """
899
- ns = self.noise_schedule
900
- s = t_T * torch.ones((x.shape[0],)).to(x)
901
- lambda_s = ns.marginal_lambda(s)
902
- lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
903
- h = h_init * torch.ones_like(s).to(x)
904
- x_prev = x
905
- nfe = 0
906
- if order == 2:
907
- r1 = 0.5
908
- lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
909
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
910
- solver_type=solver_type,
911
- **kwargs)
912
- elif order == 3:
913
- r1, r2 = 1. / 3., 2. / 3.
914
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
915
- return_intermediate=True,
916
- solver_type=solver_type)
917
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
918
- solver_type=solver_type,
919
- **kwargs)
920
- else:
921
- raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
922
- while torch.abs((s - t_0)).mean() > t_err:
923
- t = ns.inverse_lambda(lambda_s + h)
924
- x_lower, lower_noise_kwargs = lower_update(x, s, t)
925
- x_higher = higher_update(x, s, t, **lower_noise_kwargs)
926
- delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
927
- norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
928
- E = norm_fn((x_higher - x_lower) / delta).max()
929
- if torch.all(E <= 1.):
930
- x = x_higher
931
- s = t
932
- x_prev = x_lower
933
- lambda_s = ns.marginal_lambda(s)
934
- h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
935
- nfe += order
936
- print('adaptive solver nfe', nfe)
937
- return x
938
-
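The loop above is a standard embedded-pair step-size controller in logSNR time: the lower- and higher-order estimates give a local error E = ||(x_higher - x_lower) / delta|| with delta = max(atol, rtol * max(|x_lower|, |x_prev|)). The step is accepted only when E <= 1, and in either case the next logSNR step size is updated as

```latex
h \;\leftarrow\; \min\!\left(\theta\, h\, E^{-1/\text{order}},\; \lambda_0 - \lambda_s\right), \qquad \theta = 0.9 ,
```

so rejected steps shrink h (since E > 1) and accepted steps may grow it, capped so the solver never overshoots lambda_0.
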
939
- def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
- method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
- atol=0.0078, rtol=0.05,
942
- ):
943
- """
944
- Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
- =====================================================
946
- We support the following algorithms for both noise prediction model and data prediction model:
947
- - 'singlestep':
948
- Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
- We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
- The total number of function evaluations (NFE) == `steps`.
951
- Given a fixed NFE == `steps`, the sampling procedure is:
952
- - If `order` == 1:
953
- - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
- - If `order` == 2:
955
- - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
- - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
- - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
- - If `order` == 3:
959
- - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
- - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
- - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
- - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
- - 'multistep':
964
- Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
- We initialize the first `order` values by lower order multistep solvers.
966
- Given a fixed NFE == `steps`, the sampling procedure is:
967
- Denote K = steps.
968
- - If `order` == 1:
969
- - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
- - If `order` == 2:
971
- - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
972
- - If `order` == 3:
973
- - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
974
- - 'singlestep_fixed':
975
- Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
- We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
- - 'adaptive':
978
- Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
- We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
- You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
981
- (NFE) and the sample quality.
982
- - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
- - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
- =====================================================
985
- Some advices for choosing the algorithm:
986
- - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
- Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
- e.g.
989
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
- >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
- skip_type='time_uniform', method='singlestep')
992
- - For **guided sampling with large guidance scale** by DPMs:
993
- Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
- e.g.
995
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
- >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
- skip_type='time_uniform', method='multistep')
998
- We support three types of `skip_type`:
999
- - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1000
- - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1001
- - 'time_quadratic': quadratic time for the time steps.
1002
- =====================================================
1003
- Args:
1004
- x: A pytorch tensor. The initial value at time `t_start`
1005
- e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
- steps: A `int`. The total number of function evaluations (NFE).
1007
- t_start: A `float`. The starting time of the sampling.
1008
- If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
- t_end: A `float`. The ending time of the sampling.
1010
- If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
- e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
- For discrete-time DPMs:
1013
- - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
- For continuous-time DPMs:
1015
- - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
- order: A `int`. The order of DPM-Solver.
1017
- skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
- method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
- denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1020
- Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1021
- This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1022
- score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1023
- for diffusion models sampling by diffusion SDEs for low-resolutional images
1024
- (such as CIFAR-10). However, we observed that such trick does not matter for
1025
- high-resolutional images. As it needs an additional NFE, we do not recommend
1026
- it for high-resolutional images.
1027
- lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1028
- Only valid for `method=multistep` and `steps < 15`. We empirically find that
1029
- this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1030
- (especially for steps <= 10). So we recommend to set it to be `True`.
1031
- solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
- atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
- rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
- Returns:
1035
- x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
- """
1037
- t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
- t_T = self.noise_schedule.T if t_start is None else t_start
1039
- device = x.device
1040
- if method == 'adaptive':
1041
- with torch.no_grad():
1042
- x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
- solver_type=solver_type)
1044
- elif method == 'multistep':
1045
- assert steps >= order
1046
- timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
- assert timesteps.shape[0] - 1 == steps
1048
- with torch.no_grad():
1049
- vec_t = timesteps[0].expand((x.shape[0]))
1050
- model_prev_list = [self.model_fn(x, vec_t)]
1051
- t_prev_list = [vec_t]
1052
- # Init the first `order` values by lower order multistep DPM-Solver.
1053
- for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
- vec_t = timesteps[init_order].expand(x.shape[0])
1055
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
- solver_type=solver_type)
1057
- model_prev_list.append(self.model_fn(x, vec_t))
1058
- t_prev_list.append(vec_t)
1059
- # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
- for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
- vec_t = timesteps[step].expand(x.shape[0])
1062
- if lower_order_final and steps < 15:
1063
- step_order = min(order, steps + 1 - step)
1064
- else:
1065
- step_order = order
1066
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
- solver_type=solver_type)
1068
- for i in range(order - 1):
1069
- t_prev_list[i] = t_prev_list[i + 1]
1070
- model_prev_list[i] = model_prev_list[i + 1]
1071
- t_prev_list[-1] = vec_t
1072
- # We do not need to evaluate the final model value.
1073
- if step < steps:
1074
- model_prev_list[-1] = self.model_fn(x, vec_t)
1075
- elif method in ['singlestep', 'singlestep_fixed']:
1076
- if method == 'singlestep':
1077
- timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
- skip_type=skip_type,
1079
- t_T=t_T, t_0=t_0,
1080
- device=device)
1081
- elif method == 'singlestep_fixed':
1082
- K = steps // order
1083
- orders = [order, ] * K
1084
- timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
- for i, order in enumerate(orders):
1086
- t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
- timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
- N=order, device=device)
1089
- lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
- vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
- h = lambda_inner[-1] - lambda_inner[0]
1092
- r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
- r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
- x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
- if denoise_to_zero:
1096
- x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
- return x
1098
-
1099
-
1100
- #############################################################
1101
- # other utility functions
1102
- #############################################################
1103
-
1104
- def interpolate_fn(x, xp, yp):
1105
- """
1106
- A piecewise linear function y = f(x), using xp and yp as keypoints.
1107
- We implement f(x) in a differentiable way (i.e. applicable for autograd).
1108
- The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1109
- Args:
1110
- x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1111
- xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1112
- yp: PyTorch tensor with shape [C, K].
1113
- Returns:
1114
- The function values f(x), with shape [N, C].
1115
- """
1116
- N, K = x.shape[0], xp.shape[1]
1117
- all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1118
- sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1119
- x_idx = torch.argmin(x_indices, dim=2)
1120
- cand_start_idx = x_idx - 1
1121
- start_idx = torch.where(
1122
- torch.eq(x_idx, 0),
1123
- torch.tensor(1, device=x.device),
1124
- torch.where(
1125
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1126
- ),
1127
- )
1128
- end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1129
- start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1130
- end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1131
- start_idx2 = torch.where(
1132
- torch.eq(x_idx, 0),
1133
- torch.tensor(0, device=x.device),
1134
- torch.where(
1135
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1136
- ),
1137
- )
1138
- y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1139
- start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1140
- end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1141
- cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1142
- return cand
1143
-
1144
-
1145
- def expand_dims(v, dims):
1146
- """
1147
- Expand the tensor `v` to the dim `dims`.
1148
- Args:
1149
- `v`: a PyTorch tensor with shape [N].
1150
- `dim`: a `int`.
1151
- Returns:
1152
- a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1153
- """
1154
- return v[(...,) + (None,) * (dims - 1)]
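As a quick illustration of the broadcasting helper deleted above, here is a self-contained sketch (torch only; the `expand_dims` indexing trick is re-expressed inline since the module itself is being removed):

```python
import torch

# expand_dims(v, dims): reshape a [N] tensor to [N, 1, ..., 1] with `dims` dimensions total,
# so per-sample scalars (e.g. sigma_t, alpha_t) broadcast against [N, C, H, W] latents.
v = torch.tensor([0.1, 0.2, 0.3])                 # one value per batch element
v4 = v[(...,) + (None,) * (4 - 1)]                # equivalent to expand_dims(v, 4)
assert v4.shape == (3, 1, 1, 1)

x = torch.randn(3, 4, 8, 8)
scaled = v4 * x                                   # broadcasts cleanly over C, H, W
assert scaled.shape == x.shape
```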
 
ldm/models/diffusion/dpm_solver/sampler.py DELETED
@@ -1,87 +0,0 @@
- """SAMPLING ONLY."""
- import torch
-
- from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
-
-
- MODEL_TYPES = {
- "eps": "noise",
- "v": "v"
- }
-
-
- class DPMSolverSampler(object):
- def __init__(self, model, **kwargs):
- super().__init__()
- self.model = model
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
- self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- @torch.no_grad()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
-
- print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
-
- device = self.model.betas.device
- if x_T is None:
- img = torch.randn(size, device=device)
- else:
- img = x_T
-
- ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
-
- model_fn = model_wrapper(
- lambda x, t, c: self.model.apply_model(x, t, c),
- ns,
- model_type=MODEL_TYPES[self.model.parameterization],
- guidance_type="classifier-free",
- condition=conditioning,
- unconditional_condition=unconditional_conditioning,
- guidance_scale=unconditional_guidance_scale,
- )
-
- dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
- x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
-
- return x.to(device), None
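For context, a hypothetical usage sketch of the removed wrapper, not code from this repo: it assumes the original `ldm` package (deleted in this change) and a trained LatentDiffusion checkpoint; `model`, `cond`, and `uncond` are placeholders.

```python
# Hypothetical usage of the removed DPMSolverSampler (requires the original ldm package).
from ldm.models.diffusion.dpm_solver.sampler import DPMSolverSampler

sampler = DPMSolverSampler(model)             # `model`: a trained LatentDiffusion instance
samples, _ = sampler.sample(
    S=25,                                     # number of function evaluations (steps)
    batch_size=4,
    shape=(4, 64, 96),                        # (C, H, W) of the latent, e.g. 512x768 images / 8
    conditioning=cond,                        # same format the model was trained with
    unconditional_conditioning=uncond,
    unconditional_guidance_scale=5.0,
)
images = model.decode_first_stage(samples)    # decode latents back to image space (LatentDiffusion method)
```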
 
ldm/models/diffusion/plms.py DELETED
@@ -1,244 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
- from functools import partial
7
-
8
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
- from ldm.models.diffusion.sampling_util import norm_thresholding
10
-
11
-
12
- class PLMSSampler(object):
13
- def __init__(self, model, schedule="linear", **kwargs):
14
- super().__init__()
15
- self.model = model
16
- self.ddpm_num_timesteps = model.num_timesteps
17
- self.schedule = schedule
18
-
19
- def register_buffer(self, name, attr):
20
- if type(attr) == torch.Tensor:
21
- if attr.device != torch.device("cuda"):
22
- attr = attr.to(torch.device("cuda"))
23
- setattr(self, name, attr)
24
-
25
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
- if ddim_eta != 0:
27
- raise ValueError('ddim_eta must be 0 for PLMS')
28
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
- alphas_cumprod = self.model.alphas_cumprod
31
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33
-
34
- self.register_buffer('betas', to_torch(self.model.betas))
35
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37
-
38
- # calculations for diffusion q(x_t | x_{t-1}) and others
39
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
-
45
- # ddim sampling parameters
46
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
- ddim_timesteps=self.ddim_timesteps,
48
- eta=ddim_eta,verbose=verbose)
49
- self.register_buffer('ddim_sigmas', ddim_sigmas)
50
- self.register_buffer('ddim_alphas', ddim_alphas)
51
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
-
58
- @torch.no_grad()
59
- def sample(self,
60
- S,
61
- batch_size,
62
- shape,
63
- conditioning=None,
64
- callback=None,
65
- normals_sequence=None,
66
- img_callback=None,
67
- quantize_x0=False,
68
- eta=0.,
69
- mask=None,
70
- x0=None,
71
- temperature=1.,
72
- noise_dropout=0.,
73
- score_corrector=None,
74
- corrector_kwargs=None,
75
- verbose=True,
76
- x_T=None,
77
- log_every_t=100,
78
- unconditional_guidance_scale=1.,
79
- unconditional_conditioning=None,
80
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
- dynamic_threshold=None,
82
- **kwargs
83
- ):
84
- if conditioning is not None:
85
- if isinstance(conditioning, dict):
86
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
- if cbs != batch_size:
88
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
- else:
90
- if conditioning.shape[0] != batch_size:
91
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
-
93
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
- # sampling
95
- C, H, W = shape
96
- size = (batch_size, C, H, W)
97
- print(f'Data shape for PLMS sampling is {size}')
98
-
99
- samples, intermediates = self.plms_sampling(conditioning, size,
100
- callback=callback,
101
- img_callback=img_callback,
102
- quantize_denoised=quantize_x0,
103
- mask=mask, x0=x0,
104
- ddim_use_original_steps=False,
105
- noise_dropout=noise_dropout,
106
- temperature=temperature,
107
- score_corrector=score_corrector,
108
- corrector_kwargs=corrector_kwargs,
109
- x_T=x_T,
110
- log_every_t=log_every_t,
111
- unconditional_guidance_scale=unconditional_guidance_scale,
112
- unconditional_conditioning=unconditional_conditioning,
113
- dynamic_threshold=dynamic_threshold,
114
- )
115
- return samples, intermediates
116
-
117
- @torch.no_grad()
118
- def plms_sampling(self, cond, shape,
119
- x_T=None, ddim_use_original_steps=False,
120
- callback=None, timesteps=None, quantize_denoised=False,
121
- mask=None, x0=None, img_callback=None, log_every_t=100,
122
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
- unconditional_guidance_scale=1., unconditional_conditioning=None,
124
- dynamic_threshold=None):
125
- device = self.model.betas.device
126
- b = shape[0]
127
- if x_T is None:
128
- img = torch.randn(shape, device=device)
129
- else:
130
- img = x_T
131
-
132
- if timesteps is None:
133
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134
- elif timesteps is not None and not ddim_use_original_steps:
135
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136
- timesteps = self.ddim_timesteps[:subset_end]
137
-
138
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
139
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141
- print(f"Running PLMS Sampling with {total_steps} timesteps")
142
-
143
- iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144
- old_eps = []
145
-
146
- for i, step in enumerate(iterator):
147
- index = total_steps - i - 1
148
- ts = torch.full((b,), step, device=device, dtype=torch.long)
149
- ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150
-
151
- if mask is not None:
152
- assert x0 is not None
153
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154
- img = img_orig * mask + (1. - mask) * img
155
-
156
- outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157
- quantize_denoised=quantize_denoised, temperature=temperature,
158
- noise_dropout=noise_dropout, score_corrector=score_corrector,
159
- corrector_kwargs=corrector_kwargs,
160
- unconditional_guidance_scale=unconditional_guidance_scale,
161
- unconditional_conditioning=unconditional_conditioning,
162
- old_eps=old_eps, t_next=ts_next,
163
- dynamic_threshold=dynamic_threshold)
164
- img, pred_x0, e_t = outs
165
- old_eps.append(e_t)
166
- if len(old_eps) >= 4:
167
- old_eps.pop(0)
168
- if callback: callback(i)
169
- if img_callback: img_callback(pred_x0, i)
170
-
171
- if index % log_every_t == 0 or index == total_steps - 1:
172
- intermediates['x_inter'].append(img)
173
- intermediates['pred_x0'].append(pred_x0)
174
-
175
- return img, intermediates
176
-
177
- @torch.no_grad()
178
- def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181
- dynamic_threshold=None):
182
- b, *_, device = *x.shape, x.device
183
-
184
- def get_model_output(x, t):
185
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186
- e_t = self.model.apply_model(x, t, c)
187
- else:
188
- x_in = torch.cat([x] * 2)
189
- t_in = torch.cat([t] * 2)
190
- c_in = torch.cat([unconditional_conditioning, c])
191
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193
-
194
- if score_corrector is not None:
195
- assert self.model.parameterization == "eps"
196
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197
-
198
- return e_t
199
-
200
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204
-
205
- def get_x_prev_and_pred_x0(e_t, index):
206
- # select parameters corresponding to the currently considered timestep
207
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211
-
212
- # current prediction for x_0
213
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214
- if quantize_denoised:
215
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216
- if dynamic_threshold is not None:
217
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218
- # direction pointing to x_t
219
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221
- if noise_dropout > 0.:
222
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224
- return x_prev, pred_x0
225
-
226
- e_t = get_model_output(x, t)
227
- if len(old_eps) == 0:
228
- # Pseudo Improved Euler (2nd order)
229
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230
- e_t_next = get_model_output(x_prev, t_next)
231
- e_t_prime = (e_t + e_t_next) / 2
232
- elif len(old_eps) == 1:
233
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
235
- elif len(old_eps) == 2:
236
- # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
237
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238
- elif len(old_eps) >= 3:
239
- # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
240
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241
-
242
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243
-
244
- return x_prev, pred_x0, e_t
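The multistep branches above combine the current noise prediction with up to three buffered ones using standard Adams-Bashforth coefficients. A self-contained sketch of just that bookkeeping (torch only; coefficients copied from p_sample_plms above):

```python
import torch

def plms_eps_prime(e_t, old_eps):
    """Combine the current eps prediction with buffered ones, mirroring p_sample_plms."""
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2                                      # 2nd order
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12             # 3rd order
    if len(old_eps) >= 3:
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24  # 4th order
    # First step has no history; the original code instead uses a pseudo improved
    # Euler step (a second model evaluation at x_prev) rather than this passthrough.
    return e_t

eps_buffer = [torch.randn(1, 4, 8, 8) for _ in range(3)]
e_t = torch.randn(1, 4, 8, 8)
print(plms_eps_prime(e_t, eps_buffer).shape)   # torch.Size([1, 4, 8, 8])
```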
 
ldm/models/diffusion/sampling_util.py DELETED
@@ -1,22 +0,0 @@
- import torch
- import numpy as np
-
-
- def append_dims(x, target_dims):
- """Appends dimensions to the end of a tensor until it has target_dims dimensions.
- From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
- dims_to_append = target_dims - x.ndim
- if dims_to_append < 0:
- raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
- return x[(...,) + (None,) * dims_to_append]
-
-
- def norm_thresholding(x0, value):
- s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
- return x0 * (value / s)
-
-
- def spatial_norm_thresholding(x0, value):
- # b c h w
- s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
- return x0 * (value / s)
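A quick self-contained check of the dynamic-thresholding helper deleted above (the two functions are copied verbatim so the snippet runs with torch alone):

```python
import torch

def append_dims(x, target_dims):
    # copied from the deleted sampling_util.py above
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    return x[(...,) + (None,) * dims_to_append]

def norm_thresholding(x0, value):
    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
    return x0 * (value / s)

x0 = torch.randn(2, 4, 8, 8) * 5.0            # deliberately large per-sample RMS
out = norm_thresholding(x0, value=1.0)
print(out.pow(2).flatten(1).mean(1).sqrt())   # per-sample RMS is now clamped to <= 1.0
```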
 
ldm/modules/attention.py DELETED
@@ -1,330 +0,0 @@
1
- from inspect import isfunction
2
- import math
3
- import torch
4
- import torch.nn.functional as F
5
- from torch import nn, einsum
6
- from einops import rearrange, repeat
7
- from typing import Optional, Any
8
- import os
9
-
10
- from ldm.modules.diffusionmodules.util import checkpoint
11
-
12
- try:
13
- import xformers
14
- import xformers.ops
15
- XFORMERS_IS_AVAILBLE = True
16
- except:
17
- XFORMERS_IS_AVAILBLE = False
18
-
19
- # CrossAttn precision handling
20
- import os
21
- _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
-
23
- def exists(val):
24
- return val is not None
25
-
26
-
27
- def uniq(arr):
28
- return{el: True for el in arr}.keys()
29
-
30
-
31
- def default(val, d):
32
- if exists(val):
33
- return val
34
- return d() if isfunction(d) else d
35
-
36
- class GEGLU(nn.Module):
37
- def __init__(self, dim_in, dim_out):
38
- super().__init__()
39
- self.proj = nn.Linear(dim_in, dim_out * 2)
40
-
41
- def forward(self, x):
42
- x, gate = self.proj(x).chunk(2, dim=-1)
43
- return x * F.gelu(gate)
44
-
45
-
46
- class FeedForward(nn.Module):
47
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
48
- super().__init__()
49
- inner_dim = int(dim * mult)
50
- dim_out = default(dim_out, dim)
51
- project_in = nn.Sequential(
52
- nn.Linear(dim, inner_dim),
53
- nn.GELU()
54
- ) if not glu else GEGLU(dim, inner_dim)
55
-
56
- self.net = nn.Sequential(
57
- project_in,
58
- nn.Dropout(dropout),
59
- nn.Linear(inner_dim, dim_out)
60
- )
61
-
62
- def forward(self, x):
63
- return self.net(x)
64
-
65
-
66
- def zero_module(module):
67
- """
68
- Zero out the parameters of a module and return it.
69
- """
70
- for p in module.parameters():
71
- p.detach().zero_()
72
- return module
73
-
74
-
75
- def Normalize(in_channels):
76
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77
-
78
-
79
- class SpatialSelfAttention(nn.Module):
80
- def __init__(self, in_channels):
81
- super().__init__()
82
- self.in_channels = in_channels
83
-
84
- self.norm = Normalize(in_channels)
85
- self.q = torch.nn.Conv2d(in_channels,
86
- in_channels,
87
- kernel_size=1,
88
- stride=1,
89
- padding=0)
90
- self.k = torch.nn.Conv2d(in_channels,
91
- in_channels,
92
- kernel_size=1,
93
- stride=1,
94
- padding=0)
95
- self.v = torch.nn.Conv2d(in_channels,
96
- in_channels,
97
- kernel_size=1,
98
- stride=1,
99
- padding=0)
100
- self.proj_out = torch.nn.Conv2d(in_channels,
101
- in_channels,
102
- kernel_size=1,
103
- stride=1,
104
- padding=0)
105
-
106
- def forward(self, x):
107
- h_ = x
108
- h_ = self.norm(h_)
109
- q = self.q(h_)
110
- k = self.k(h_)
111
- v = self.v(h_)
112
-
113
- b,c,h,w = q.shape
114
- q = rearrange(q, 'b c h w -> b (h w) c')
115
- k = rearrange(k, 'b c h w -> b c (h w)')
116
- w_ = torch.einsum('bij,bjk->bik', q, k)
117
-
118
- w_ = w_ * (int(c)**(-0.5))
119
- w_ = torch.nn.functional.softmax(w_, dim=2)
120
-
121
- v = rearrange(v, 'b c h w -> b c (h w)')
122
- w_ = rearrange(w_, 'b i j -> b j i')
123
- h_ = torch.einsum('bij,bjk->bik', v, w_)
124
- h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
125
- h_ = self.proj_out(h_)
126
-
127
- return x+h_
128
-
129
- class CrossAttention(nn.Module):
130
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kwargs):
131
- super().__init__()
132
- inner_dim = dim_head * heads
133
- context_dim = default(context_dim, query_dim)
134
-
135
- self.scale = dim_head ** -0.5
136
- self.heads = heads
137
-
138
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
139
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
140
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
141
-
142
- self.to_out = nn.Sequential(
143
- nn.Linear(inner_dim, query_dim),
144
- nn.Dropout(dropout)
145
- )
146
-
147
-
148
- def forward(self, x, context=None, mask=None):
149
- h = self.heads
150
- q = self.to_q(x)
151
- context = default(context, x)
152
- k = self.to_k(context)
153
- v = self.to_v(context)
154
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
155
-
156
- if _ATTN_PRECISION =="fp32":
157
- with torch.autocast(enabled=False, device_type = 'cuda'):
158
- q, k = q.float(), k.float()
159
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
160
- else:
161
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
162
-
163
- del q, k
164
- if exists(mask):
165
- mask = rearrange(mask, 'b ... -> b (...)')
166
- max_neg_value = -torch.finfo(sim.dtype).max
167
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
168
- sim.masked_fill_(~mask, max_neg_value)
169
-
170
- sim = sim.softmax(dim=-1)
171
-
172
- out = einsum('b i j, b j d -> b i d', sim, v)
173
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
174
- return self.to_out(out)
175
-
176
- class MemoryEfficientCrossAttention(nn.Module):
177
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
178
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, zero_init=False, **kwargs):
179
- super().__init__()
180
- print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
181
- f"{heads} heads.")
182
- inner_dim = dim_head * heads
183
- context_dim = default(context_dim, query_dim)
184
-
185
- self.heads = heads
186
- self.dim_head = dim_head
187
- if not zero_init:
188
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
189
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
190
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
191
- else:
192
- self.to_q = zero_module(nn.Linear(query_dim, inner_dim, bias=False))
193
- self.to_k = zero_module(nn.Linear(context_dim, inner_dim, bias=False))
194
- self.to_v = zero_module(nn.Linear(context_dim, inner_dim, bias=False))
195
-
196
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
197
- self.attention_op: Optional[Any] = None
198
-
199
-
200
- def forward(self, x, context=None, mask=None, **kwargs):
201
- q = self.to_q(x)
202
- context = default(context, x)
203
- k = self.to_k(context)
204
- v = self.to_v(context)
205
- b, _, _ = q.shape
206
- q, k, v = map(
207
- lambda t: t.unsqueeze(3)
208
- .reshape(b, t.shape[1], self.heads, self.dim_head)
209
- .permute(0, 2, 1, 3)
210
- .reshape(b * self.heads, t.shape[1], self.dim_head)
211
- .contiguous(),
212
- (q, k, v),
213
- )
214
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
215
- if exists(mask):
216
- raise NotImplementedError
217
- out = (
218
- out.unsqueeze(0)
219
- .reshape(b, self.heads, out.shape[1], self.dim_head)
220
- .permute(0, 2, 1, 3)
221
- .reshape(b, out.shape[1], self.heads * self.dim_head)
222
- )
223
- return self.to_out(out)
224
-
225
- class BasicTransformerBlock(nn.Module):
226
- ATTENTION_MODES = {
227
- "softmax": CrossAttention, # vanilla attention
228
- "softmax-xformers": MemoryEfficientCrossAttention
229
- }
230
- def __init__(
231
- self,
232
- dim,
233
- n_heads,
234
- d_head,
235
- dropout=0.,
236
- context_dim=None,
237
- gated_ff=True,
238
- checkpoint=True,
239
- disable_self_attn=False
240
- ):
241
- super().__init__()
242
- attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
243
- assert attn_mode in self.ATTENTION_MODES
244
- attn_cls = self.ATTENTION_MODES[attn_mode]
245
- self.disable_self_attn = disable_self_attn
246
- self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
247
- context_dim=context_dim if self.disable_self_attn else None)
248
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
249
- self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
250
- heads=n_heads, dim_head=d_head, dropout=dropout)
251
- self.norm1 = nn.LayerNorm(dim)
252
- self.norm2 = nn.LayerNorm(dim)
253
- self.norm3 = nn.LayerNorm(dim)
254
- self.checkpoint = checkpoint
255
-
256
- def forward(self, x, context=None,hint=None):
257
- if hint is None:
258
- return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
259
- else:
260
- return checkpoint(self._forward, (x, context, hint), self.parameters(), self.checkpoint)
261
-
262
- def _forward(self, x, context=None,hint=None):
263
- x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,hint=hint) + x
264
- x = self.attn2(self.norm2(x), context=context) + x
265
- x = self.ff(self.norm3(x)) + x
266
- return x
267
-
268
- class SpatialTransformer(nn.Module):
269
- """
270
- Transformer block for image-like data.
271
- First, project the input (aka embedding)
272
- and reshape to b, t, d.
273
- Then apply standard transformer action.
274
- Finally, reshape to image
275
- NEW: use_linear for more efficiency instead of the 1x1 convs
276
- """
277
- def __init__(self, in_channels, n_heads, d_head,
278
- depth=1, dropout=0., context_dim=None,
279
- disable_self_attn=False, use_linear=False,
280
- use_checkpoint=True):
281
- super().__init__()
282
- if exists(context_dim) and not isinstance(context_dim, list):
283
- context_dim = [context_dim]
284
- self.in_channels = in_channels
285
- inner_dim = n_heads * d_head
286
- self.norm = Normalize(in_channels)
287
- if not use_linear:
288
- self.proj_in = nn.Conv2d(in_channels,
289
- inner_dim,
290
- kernel_size=1,
291
- stride=1,
292
- padding=0)
293
- else:
294
- self.proj_in = nn.Linear(in_channels, inner_dim)
295
-
296
- self.transformer_blocks = nn.ModuleList(
297
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
298
- disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
299
- for d in range(depth)]
300
- )
301
- if not use_linear:
302
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
303
- in_channels,
304
- kernel_size=1,
305
- stride=1,
306
- padding=0))
307
- else:
308
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
309
- self.use_linear = use_linear
310
-
311
- def forward(self, x, context=None,hint=None):
312
- # note: if no context is given, cross-attention defaults to self-attention
313
- if not isinstance(context, list):
314
- context = [context]
315
- b, c, h, w = x.shape
316
- x_in = x
317
- x = self.norm(x)
318
- if not self.use_linear:
319
- x = self.proj_in(x)
320
- x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
321
- if self.use_linear:
322
- x = self.proj_in(x)
323
- for i, block in enumerate(self.transformer_blocks):
324
- x = block(x, context=context[i],hint=hint)
325
- if self.use_linear:
326
- x = self.proj_out(x)
327
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
328
- if not self.use_linear:
329
- x = self.proj_out(x)
330
- return x + x_in
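The head split/merge used by CrossAttention above is easy to verify in isolation; a minimal sketch assuming torch and einops are installed (shapes are illustrative, not values from this repo):

```python
import torch
from einops import rearrange

b, n, h, d = 2, 16, 8, 64
x = torch.randn(b, n, h * d)                        # (batch, tokens, heads * dim_head)

# split heads exactly as CrossAttention.forward does
xh = rearrange(x, 'b n (h d) -> (b h) n d', h=h)    # (batch * heads, tokens, dim_head)
assert xh.shape == (b * h, n, d)

# ... attention runs per head on xh ...

# merge heads back, as done before to_out
merged = rearrange(xh, '(b h) n d -> b n (h d)', h=h)
assert torch.allclose(merged, x)
```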
 
ldm/modules/diffusionmodules/__init__.py DELETED
File without changes
ldm/modules/diffusionmodules/model.py DELETED
@@ -1,852 +0,0 @@
1
- # pytorch_diffusion + derived encoder decoder
2
- import math
3
- import torch
4
- import torch.nn as nn
5
- import numpy as np
6
- from einops import rearrange
7
- from typing import Optional, Any
8
-
9
- from ldm.modules.attention import MemoryEfficientCrossAttention
10
-
11
- try:
12
- import xformers
13
- import xformers.ops
14
- XFORMERS_IS_AVAILBLE = True
15
- except:
16
- XFORMERS_IS_AVAILBLE = False
17
- print("No module 'xformers'. Proceeding without it.")
18
-
19
-
20
- def get_timestep_embedding(timesteps, embedding_dim):
21
- """
22
- This matches the implementation in Denoising Diffusion Probabilistic Models:
23
- From Fairseq.
24
- Build sinusoidal embeddings.
25
- This matches the implementation in tensor2tensor, but differs slightly
26
- from the description in Section 3.5 of "Attention Is All You Need".
27
- """
28
- assert len(timesteps.shape) == 1
29
-
30
- half_dim = embedding_dim // 2
31
- emb = math.log(10000) / (half_dim - 1)
32
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
- emb = emb.to(device=timesteps.device)
34
- emb = timesteps.float()[:, None] * emb[None, :]
35
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
- if embedding_dim % 2 == 1: # zero pad
37
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
- return emb
39
-
40
-
41
- def nonlinearity(x):
42
- # swish
43
- return x*torch.sigmoid(x)
44
-
45
-
46
- def Normalize(in_channels, num_groups=32):
47
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
-
49
-
50
- class Upsample(nn.Module):
51
- def __init__(self, in_channels, with_conv):
52
- super().__init__()
53
- self.with_conv = with_conv
54
- if self.with_conv:
55
- self.conv = torch.nn.Conv2d(in_channels,
56
- in_channels,
57
- kernel_size=3,
58
- stride=1,
59
- padding=1)
60
-
61
- def forward(self, x):
62
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
63
- if self.with_conv:
64
- x = self.conv(x)
65
- return x
66
-
67
-
68
- class Downsample(nn.Module):
69
- def __init__(self, in_channels, with_conv):
70
- super().__init__()
71
- self.with_conv = with_conv
72
- if self.with_conv:
73
- # no asymmetric padding in torch conv, must do it ourselves
74
- self.conv = torch.nn.Conv2d(in_channels,
75
- in_channels,
76
- kernel_size=3,
77
- stride=2,
78
- padding=0)
79
-
80
- def forward(self, x):
81
- if self.with_conv:
82
- pad = (0,1,0,1)
83
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
84
- x = self.conv(x)
85
- else:
86
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
87
- return x
88
-
89
-
90
- class ResnetBlock(nn.Module):
91
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
92
- dropout, temb_channels=512):
93
- super().__init__()
94
- self.in_channels = in_channels
95
- out_channels = in_channels if out_channels is None else out_channels
96
- self.out_channels = out_channels
97
- self.use_conv_shortcut = conv_shortcut
98
-
99
- self.norm1 = Normalize(in_channels)
100
- self.conv1 = torch.nn.Conv2d(in_channels,
101
- out_channels,
102
- kernel_size=3,
103
- stride=1,
104
- padding=1)
105
- if temb_channels > 0:
106
- self.temb_proj = torch.nn.Linear(temb_channels,
107
- out_channels)
108
- self.norm2 = Normalize(out_channels)
109
- self.dropout = torch.nn.Dropout(dropout)
110
- self.conv2 = torch.nn.Conv2d(out_channels,
111
- out_channels,
112
- kernel_size=3,
113
- stride=1,
114
- padding=1)
115
- if self.in_channels != self.out_channels:
116
- if self.use_conv_shortcut:
117
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
118
- out_channels,
119
- kernel_size=3,
120
- stride=1,
121
- padding=1)
122
- else:
123
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
124
- out_channels,
125
- kernel_size=1,
126
- stride=1,
127
- padding=0)
128
-
129
- def forward(self, x, temb):
130
- h = x
131
- h = self.norm1(h)
132
- h = nonlinearity(h)
133
- h = self.conv1(h)
134
-
135
- if temb is not None:
136
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
137
-
138
- h = self.norm2(h)
139
- h = nonlinearity(h)
140
- h = self.dropout(h)
141
- h = self.conv2(h)
142
-
143
- if self.in_channels != self.out_channels:
144
- if self.use_conv_shortcut:
145
- x = self.conv_shortcut(x)
146
- else:
147
- x = self.nin_shortcut(x)
148
-
149
- return x+h
150
-
151
-
152
- class AttnBlock(nn.Module):
153
- def __init__(self, in_channels):
154
- super().__init__()
155
- self.in_channels = in_channels
156
-
157
- self.norm = Normalize(in_channels)
158
- self.q = torch.nn.Conv2d(in_channels,
159
- in_channels,
160
- kernel_size=1,
161
- stride=1,
162
- padding=0)
163
- self.k = torch.nn.Conv2d(in_channels,
164
- in_channels,
165
- kernel_size=1,
166
- stride=1,
167
- padding=0)
168
- self.v = torch.nn.Conv2d(in_channels,
169
- in_channels,
170
- kernel_size=1,
171
- stride=1,
172
- padding=0)
173
- self.proj_out = torch.nn.Conv2d(in_channels,
174
- in_channels,
175
- kernel_size=1,
176
- stride=1,
177
- padding=0)
178
-
179
- def forward(self, x):
180
- h_ = x
181
- h_ = self.norm(h_)
182
- q = self.q(h_)
183
- k = self.k(h_)
184
- v = self.v(h_)
185
-
186
- # compute attention
187
- b,c,h,w = q.shape
188
- q = q.reshape(b,c,h*w)
189
- q = q.permute(0,2,1) # b,hw,c
190
- k = k.reshape(b,c,h*w) # b,c,hw
191
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
192
- w_ = w_ * (int(c)**(-0.5))
193
- w_ = torch.nn.functional.softmax(w_, dim=2)
194
-
195
- # attend to values
196
- v = v.reshape(b,c,h*w)
197
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
198
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
199
- h_ = h_.reshape(b,c,h,w)
200
-
201
- h_ = self.proj_out(h_)
202
-
203
- return x+h_
204
-
205
- class MemoryEfficientAttnBlock(nn.Module):
206
- """
207
- Uses xformers efficient implementation,
208
- see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
209
- Note: this is a single-head self-attention operation
210
- """
211
- #
212
- def __init__(self, in_channels):
213
- super().__init__()
214
- self.in_channels = in_channels
215
-
216
- self.norm = Normalize(in_channels)
217
- self.q = torch.nn.Conv2d(in_channels,
218
- in_channels,
219
- kernel_size=1,
220
- stride=1,
221
- padding=0)
222
- self.k = torch.nn.Conv2d(in_channels,
223
- in_channels,
224
- kernel_size=1,
225
- stride=1,
226
- padding=0)
227
- self.v = torch.nn.Conv2d(in_channels,
228
- in_channels,
229
- kernel_size=1,
230
- stride=1,
231
- padding=0)
232
- self.proj_out = torch.nn.Conv2d(in_channels,
233
- in_channels,
234
- kernel_size=1,
235
- stride=1,
236
- padding=0)
237
- self.attention_op: Optional[Any] = None
238
-
239
- def forward(self, x):
240
- h_ = x
241
- h_ = self.norm(h_)
242
- q = self.q(h_)
243
- k = self.k(h_)
244
- v = self.v(h_)
245
-
246
- # compute attention
247
- B, C, H, W = q.shape
248
- q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
249
-
250
- q, k, v = map(
251
- lambda t: t.unsqueeze(3)
252
- .reshape(B, t.shape[1], 1, C)
253
- .permute(0, 2, 1, 3)
254
- .reshape(B * 1, t.shape[1], C)
255
- .contiguous(),
256
- (q, k, v),
257
- )
258
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
259
-
260
- out = (
261
- out.unsqueeze(0)
262
- .reshape(B, 1, out.shape[1], C)
263
- .permute(0, 2, 1, 3)
264
- .reshape(B, out.shape[1], C)
265
- )
266
- out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
267
- out = self.proj_out(out)
268
- return x+out
269
-
270
-
271
- class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
272
- def forward(self, x, context=None, mask=None):
273
- b, c, h, w = x.shape
274
- x = rearrange(x, 'b c h w -> b (h w) c')
275
- out = super().forward(x, context=context, mask=mask)
276
- out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
277
- return x + out
278
-
279
-
280
- def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
281
- assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
282
- if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
283
- attn_type = "vanilla-xformers"
284
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
285
- if attn_type == "vanilla":
286
- assert attn_kwargs is None
287
- return AttnBlock(in_channels)
288
- elif attn_type == "vanilla-xformers":
289
- print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
290
- return MemoryEfficientAttnBlock(in_channels)
291
- elif type == "memory-efficient-cross-attn":
292
- attn_kwargs["query_dim"] = in_channels
293
- return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
294
- elif attn_type == "none":
295
- return nn.Identity(in_channels)
296
- else:
297
- raise NotImplementedError()
298
-
299
-
300
- class Model(nn.Module):
301
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
302
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
303
- resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
304
- super().__init__()
305
- if use_linear_attn: attn_type = "linear"
306
- self.ch = ch
307
- self.temb_ch = self.ch*4
308
- self.num_resolutions = len(ch_mult)
309
- self.num_res_blocks = num_res_blocks
310
- self.resolution = resolution
311
- self.in_channels = in_channels
312
-
313
- self.use_timestep = use_timestep
314
- if self.use_timestep:
315
- # timestep embedding
316
- self.temb = nn.Module()
317
- self.temb.dense = nn.ModuleList([
318
- torch.nn.Linear(self.ch,
319
- self.temb_ch),
320
- torch.nn.Linear(self.temb_ch,
321
- self.temb_ch),
322
- ])
323
-
324
- # downsampling
325
- self.conv_in = torch.nn.Conv2d(in_channels,
326
- self.ch,
327
- kernel_size=3,
328
- stride=1,
329
- padding=1)
330
-
331
- curr_res = resolution
332
- in_ch_mult = (1,)+tuple(ch_mult)
333
- self.down = nn.ModuleList()
334
- for i_level in range(self.num_resolutions):
335
- block = nn.ModuleList()
336
- attn = nn.ModuleList()
337
- block_in = ch*in_ch_mult[i_level]
338
- block_out = ch*ch_mult[i_level]
339
- for i_block in range(self.num_res_blocks):
340
- block.append(ResnetBlock(in_channels=block_in,
341
- out_channels=block_out,
342
- temb_channels=self.temb_ch,
343
- dropout=dropout))
344
- block_in = block_out
345
- if curr_res in attn_resolutions:
346
- attn.append(make_attn(block_in, attn_type=attn_type))
347
- down = nn.Module()
348
- down.block = block
349
- down.attn = attn
350
- if i_level != self.num_resolutions-1:
351
- down.downsample = Downsample(block_in, resamp_with_conv)
352
- curr_res = curr_res // 2
353
- self.down.append(down)
354
-
355
- # middle
356
- self.mid = nn.Module()
357
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
358
- out_channels=block_in,
359
- temb_channels=self.temb_ch,
360
- dropout=dropout)
361
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
362
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
363
- out_channels=block_in,
364
- temb_channels=self.temb_ch,
365
- dropout=dropout)
366
-
367
- # upsampling
368
- self.up = nn.ModuleList()
369
- for i_level in reversed(range(self.num_resolutions)):
370
- block = nn.ModuleList()
371
- attn = nn.ModuleList()
372
- block_out = ch*ch_mult[i_level]
373
- skip_in = ch*ch_mult[i_level]
374
- for i_block in range(self.num_res_blocks+1):
375
- if i_block == self.num_res_blocks:
376
- skip_in = ch*in_ch_mult[i_level]
377
- block.append(ResnetBlock(in_channels=block_in+skip_in,
378
- out_channels=block_out,
379
- temb_channels=self.temb_ch,
380
- dropout=dropout))
381
- block_in = block_out
382
- if curr_res in attn_resolutions:
383
- attn.append(make_attn(block_in, attn_type=attn_type))
384
- up = nn.Module()
385
- up.block = block
386
- up.attn = attn
387
- if i_level != 0:
388
- up.upsample = Upsample(block_in, resamp_with_conv)
389
- curr_res = curr_res * 2
390
- self.up.insert(0, up) # prepend to get consistent order
391
-
392
- # end
393
- self.norm_out = Normalize(block_in)
394
- self.conv_out = torch.nn.Conv2d(block_in,
395
- out_ch,
396
- kernel_size=3,
397
- stride=1,
398
- padding=1)
399
-
400
- def forward(self, x, t=None, context=None):
401
- #assert x.shape[2] == x.shape[3] == self.resolution
402
- if context is not None:
403
- # assume aligned context, cat along channel axis
404
- x = torch.cat((x, context), dim=1)
405
- if self.use_timestep:
406
- # timestep embedding
407
- assert t is not None
408
- temb = get_timestep_embedding(t, self.ch)
409
- temb = self.temb.dense[0](temb)
410
- temb = nonlinearity(temb)
411
- temb = self.temb.dense[1](temb)
412
- else:
413
- temb = None
414
-
415
- # downsampling
416
- hs = [self.conv_in(x)]
417
- for i_level in range(self.num_resolutions):
418
- for i_block in range(self.num_res_blocks):
419
- h = self.down[i_level].block[i_block](hs[-1], temb)
420
- if len(self.down[i_level].attn) > 0:
421
- h = self.down[i_level].attn[i_block](h)
422
- hs.append(h)
423
- if i_level != self.num_resolutions-1:
424
- hs.append(self.down[i_level].downsample(hs[-1]))
425
-
426
- # middle
427
- h = hs[-1]
428
- h = self.mid.block_1(h, temb)
429
- h = self.mid.attn_1(h)
430
- h = self.mid.block_2(h, temb)
431
-
432
- # upsampling
433
- for i_level in reversed(range(self.num_resolutions)):
434
- for i_block in range(self.num_res_blocks+1):
435
- h = self.up[i_level].block[i_block](
436
- torch.cat([h, hs.pop()], dim=1), temb)
437
- if len(self.up[i_level].attn) > 0:
438
- h = self.up[i_level].attn[i_block](h)
439
- if i_level != 0:
440
- h = self.up[i_level].upsample(h)
441
-
442
- # end
443
- h = self.norm_out(h)
444
- h = nonlinearity(h)
445
- h = self.conv_out(h)
446
- return h
447
-
448
- def get_last_layer(self):
449
- return self.conv_out.weight
450
-
451
-
452
- class Encoder(nn.Module):
453
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
454
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
455
- resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
456
- **ignore_kwargs):
457
- super().__init__()
458
- if use_linear_attn: attn_type = "linear"
459
- self.ch = ch
460
- self.temb_ch = 0
461
- self.num_resolutions = len(ch_mult)
462
- self.num_res_blocks = num_res_blocks
463
- self.resolution = resolution
464
- self.in_channels = in_channels
465
-
466
- # downsampling
467
- self.conv_in = torch.nn.Conv2d(in_channels,
468
- self.ch,
469
- kernel_size=3,
470
- stride=1,
471
- padding=1)
472
-
473
- curr_res = resolution
474
- in_ch_mult = (1,)+tuple(ch_mult)
475
- self.in_ch_mult = in_ch_mult
476
- self.down = nn.ModuleList()
477
- for i_level in range(self.num_resolutions):
478
- block = nn.ModuleList()
479
- attn = nn.ModuleList()
480
- block_in = ch*in_ch_mult[i_level]
481
- block_out = ch*ch_mult[i_level]
482
- for i_block in range(self.num_res_blocks):
483
- block.append(ResnetBlock(in_channels=block_in,
484
- out_channels=block_out,
485
- temb_channels=self.temb_ch,
486
- dropout=dropout))
487
- block_in = block_out
488
- if curr_res in attn_resolutions:
489
- attn.append(make_attn(block_in, attn_type=attn_type))
490
- down = nn.Module()
491
- down.block = block
492
- down.attn = attn
493
- if i_level != self.num_resolutions-1:
494
- down.downsample = Downsample(block_in, resamp_with_conv)
495
- curr_res = curr_res // 2
496
- self.down.append(down)
497
-
498
- # middle
499
- self.mid = nn.Module()
500
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
501
- out_channels=block_in,
502
- temb_channels=self.temb_ch,
503
- dropout=dropout)
504
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
505
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
506
- out_channels=block_in,
507
- temb_channels=self.temb_ch,
508
- dropout=dropout)
509
-
510
- # end
511
- self.norm_out = Normalize(block_in)
512
- self.conv_out = torch.nn.Conv2d(block_in,
513
- 2*z_channels if double_z else z_channels,
514
- kernel_size=3,
515
- stride=1,
516
- padding=1)
517
-
518
- def forward(self, x):
519
- # timestep embedding
520
- temb = None
521
-
522
- # downsampling
523
- hs = [self.conv_in(x)]
524
- for i_level in range(self.num_resolutions):
525
- for i_block in range(self.num_res_blocks):
526
- h = self.down[i_level].block[i_block](hs[-1], temb)
527
- if len(self.down[i_level].attn) > 0:
528
- h = self.down[i_level].attn[i_block](h)
529
- hs.append(h)
530
- if i_level != self.num_resolutions-1:
531
- hs.append(self.down[i_level].downsample(hs[-1]))
532
-
533
- # middle
534
- h = hs[-1]
535
- h = self.mid.block_1(h, temb)
536
- h = self.mid.attn_1(h)
537
- h = self.mid.block_2(h, temb)
538
-
539
- # end
540
- h = self.norm_out(h)
541
- h = nonlinearity(h)
542
- h = self.conv_out(h)
543
- return h
544
-
545
-
546
- class Decoder(nn.Module):
547
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
548
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
549
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
550
- attn_type="vanilla", **ignorekwargs):
551
- super().__init__()
552
- if use_linear_attn: attn_type = "linear"
553
- self.ch = ch
554
- self.temb_ch = 0
555
- self.num_resolutions = len(ch_mult)
556
- self.num_res_blocks = num_res_blocks
557
- self.resolution = resolution
558
- self.in_channels = in_channels
559
- self.give_pre_end = give_pre_end
560
- self.tanh_out = tanh_out
561
-
562
- # compute in_ch_mult, block_in and curr_res at lowest res
563
- in_ch_mult = (1,)+tuple(ch_mult)
564
- block_in = ch*ch_mult[self.num_resolutions-1]
565
- curr_res = resolution // 2**(self.num_resolutions-1)
566
- self.z_shape = (1,z_channels,curr_res,curr_res)
567
- print("Working with z of shape {} = {} dimensions.".format(
568
- self.z_shape, np.prod(self.z_shape)))
569
-
570
- # z to block_in
571
- self.conv_in = torch.nn.Conv2d(z_channels,
572
- block_in,
573
- kernel_size=3,
574
- stride=1,
575
- padding=1)
576
-
577
- # middle
578
- self.mid = nn.Module()
579
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
580
- out_channels=block_in,
581
- temb_channels=self.temb_ch,
582
- dropout=dropout)
583
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
584
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
585
- out_channels=block_in,
586
- temb_channels=self.temb_ch,
587
- dropout=dropout)
588
-
589
- # upsampling
590
- self.up = nn.ModuleList()
591
- for i_level in reversed(range(self.num_resolutions)):
592
- block = nn.ModuleList()
593
- attn = nn.ModuleList()
594
- block_out = ch*ch_mult[i_level]
595
- for i_block in range(self.num_res_blocks+1):
596
- block.append(ResnetBlock(in_channels=block_in,
597
- out_channels=block_out,
598
- temb_channels=self.temb_ch,
599
- dropout=dropout))
600
- block_in = block_out
601
- if curr_res in attn_resolutions:
602
- attn.append(make_attn(block_in, attn_type=attn_type))
603
- up = nn.Module()
604
- up.block = block
605
- up.attn = attn
606
- if i_level != 0:
607
- up.upsample = Upsample(block_in, resamp_with_conv)
608
- curr_res = curr_res * 2
609
- self.up.insert(0, up) # prepend to get consistent order
610
-
611
- # end
612
- self.norm_out = Normalize(block_in)
613
- self.conv_out = torch.nn.Conv2d(block_in,
614
- out_ch,
615
- kernel_size=3,
616
- stride=1,
617
- padding=1)
618
-
619
- def forward(self, z):
620
- #assert z.shape[1:] == self.z_shape[1:]
621
- self.last_z_shape = z.shape
622
-
623
- # timestep embedding
624
- temb = None
625
-
626
- # z to block_in
627
- h = self.conv_in(z)
628
-
629
- # middle
630
- h = self.mid.block_1(h, temb)
631
- h = self.mid.attn_1(h)
632
- h = self.mid.block_2(h, temb)
633
-
634
- # upsampling
635
- for i_level in reversed(range(self.num_resolutions)):
636
- for i_block in range(self.num_res_blocks+1):
637
- h = self.up[i_level].block[i_block](h, temb)
638
- if len(self.up[i_level].attn) > 0:
639
- h = self.up[i_level].attn[i_block](h)
640
- if i_level != 0:
641
- h = self.up[i_level].upsample(h)
642
-
643
- # end
644
- if self.give_pre_end:
645
- return h
646
-
647
- h = self.norm_out(h)
648
- h = nonlinearity(h)
649
- h = self.conv_out(h)
650
- if self.tanh_out:
651
- h = torch.tanh(h)
652
- return h
653
-
654
-
655
- class SimpleDecoder(nn.Module):
656
- def __init__(self, in_channels, out_channels, *args, **kwargs):
657
- super().__init__()
658
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
659
- ResnetBlock(in_channels=in_channels,
660
- out_channels=2 * in_channels,
661
- temb_channels=0, dropout=0.0),
662
- ResnetBlock(in_channels=2 * in_channels,
663
- out_channels=4 * in_channels,
664
- temb_channels=0, dropout=0.0),
665
- ResnetBlock(in_channels=4 * in_channels,
666
- out_channels=2 * in_channels,
667
- temb_channels=0, dropout=0.0),
668
- nn.Conv2d(2*in_channels, in_channels, 1),
669
- Upsample(in_channels, with_conv=True)])
670
- # end
671
- self.norm_out = Normalize(in_channels)
672
- self.conv_out = torch.nn.Conv2d(in_channels,
673
- out_channels,
674
- kernel_size=3,
675
- stride=1,
676
- padding=1)
677
-
678
- def forward(self, x):
679
- for i, layer in enumerate(self.model):
680
- if i in [1,2,3]:
681
- x = layer(x, None)
682
- else:
683
- x = layer(x)
684
-
685
- h = self.norm_out(x)
686
- h = nonlinearity(h)
687
- x = self.conv_out(h)
688
- return x
689
-
690
-
691
- class UpsampleDecoder(nn.Module):
692
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
693
- ch_mult=(2,2), dropout=0.0):
694
- super().__init__()
695
- # upsampling
696
- self.temb_ch = 0
697
- self.num_resolutions = len(ch_mult)
698
- self.num_res_blocks = num_res_blocks
699
- block_in = in_channels
700
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
701
- self.res_blocks = nn.ModuleList()
702
- self.upsample_blocks = nn.ModuleList()
703
- for i_level in range(self.num_resolutions):
704
- res_block = []
705
- block_out = ch * ch_mult[i_level]
706
- for i_block in range(self.num_res_blocks + 1):
707
- res_block.append(ResnetBlock(in_channels=block_in,
708
- out_channels=block_out,
709
- temb_channels=self.temb_ch,
710
- dropout=dropout))
711
- block_in = block_out
712
- self.res_blocks.append(nn.ModuleList(res_block))
713
- if i_level != self.num_resolutions - 1:
714
- self.upsample_blocks.append(Upsample(block_in, True))
715
- curr_res = curr_res * 2
716
-
717
- # end
718
- self.norm_out = Normalize(block_in)
719
- self.conv_out = torch.nn.Conv2d(block_in,
720
- out_channels,
721
- kernel_size=3,
722
- stride=1,
723
- padding=1)
724
-
725
- def forward(self, x):
726
- # upsampling
727
- h = x
728
- for k, i_level in enumerate(range(self.num_resolutions)):
729
- for i_block in range(self.num_res_blocks + 1):
730
- h = self.res_blocks[i_level][i_block](h, None)
731
- if i_level != self.num_resolutions - 1:
732
- h = self.upsample_blocks[k](h)
733
- h = self.norm_out(h)
734
- h = nonlinearity(h)
735
- h = self.conv_out(h)
736
- return h
737
-
738
-
739
- class LatentRescaler(nn.Module):
740
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
741
- super().__init__()
742
- # residual block, interpolate, residual block
743
- self.factor = factor
744
- self.conv_in = nn.Conv2d(in_channels,
745
- mid_channels,
746
- kernel_size=3,
747
- stride=1,
748
- padding=1)
749
- self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
750
- out_channels=mid_channels,
751
- temb_channels=0,
752
- dropout=0.0) for _ in range(depth)])
753
- self.attn = AttnBlock(mid_channels)
754
- self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
755
- out_channels=mid_channels,
756
- temb_channels=0,
757
- dropout=0.0) for _ in range(depth)])
758
-
759
- self.conv_out = nn.Conv2d(mid_channels,
760
- out_channels,
761
- kernel_size=1,
762
- )
763
-
764
- def forward(self, x):
765
- x = self.conv_in(x)
766
- for block in self.res_block1:
767
- x = block(x, None)
768
- x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
769
- x = self.attn(x)
770
- for block in self.res_block2:
771
- x = block(x, None)
772
- x = self.conv_out(x)
773
- return x
774
-
775
-
776
- class MergedRescaleEncoder(nn.Module):
777
- def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
778
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
779
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
780
- super().__init__()
781
- intermediate_chn = ch * ch_mult[-1]
782
- self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
783
- z_channels=intermediate_chn, double_z=False, resolution=resolution,
784
- attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
785
- out_ch=None)
786
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
787
- mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
788
-
789
- def forward(self, x):
790
- x = self.encoder(x)
791
- x = self.rescaler(x)
792
- return x
793
-
794
-
795
- class MergedRescaleDecoder(nn.Module):
796
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
797
- dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
798
- super().__init__()
799
- tmp_chn = z_channels*ch_mult[-1]
800
- self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
801
- resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
802
- ch_mult=ch_mult, resolution=resolution, ch=ch)
803
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
804
- out_channels=tmp_chn, depth=rescale_module_depth)
805
-
806
- def forward(self, x):
807
- x = self.rescaler(x)
808
- x = self.decoder(x)
809
- return x
810
-
811
-
812
- class Upsampler(nn.Module):
813
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
814
- super().__init__()
815
- assert out_size >= in_size
816
- num_blocks = int(np.log2(out_size//in_size))+1
817
- factor_up = 1.+ (out_size % in_size)
818
- print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
819
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
820
- out_channels=in_channels)
821
- self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
822
- attn_resolutions=[], in_channels=None, ch=in_channels,
823
- ch_mult=[ch_mult for _ in range(num_blocks)])
824
-
825
- def forward(self, x):
826
- x = self.rescaler(x)
827
- x = self.decoder(x)
828
- return x
829
-
830
-
831
- class Resize(nn.Module):
832
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
833
- super().__init__()
834
- self.with_conv = learned
835
- self.mode = mode
836
- if self.with_conv:
837
- print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
838
- raise NotImplementedError()
839
- assert in_channels is not None
840
- # no asymmetric padding in torch conv, must do it ourselves
841
- self.conv = torch.nn.Conv2d(in_channels,
842
- in_channels,
843
- kernel_size=4,
844
- stride=2,
845
- padding=1)
846
-
847
- def forward(self, x, scale_factor=1.0):
848
- if scale_factor==1.0:
849
- return x
850
- else:
851
- x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
852
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/diffusionmodules/openaimodel.py DELETED
@@ -1,790 +0,0 @@
1
- from abc import abstractmethod
2
- import math
3
-
4
- import numpy as np
5
- import torch as th
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- from ldm.modules.diffusionmodules.util import (
10
- checkpoint,
11
- conv_nd,
12
- linear,
13
- avg_pool_nd,
14
- zero_module,
15
- normalization,
16
- timestep_embedding,
17
- )
18
- from ldm.modules.attention import SpatialTransformer
19
- from ldm.util import exists
20
-
21
-
22
- # dummy replace
23
- def convert_module_to_f16(x):
24
- pass
25
-
26
- def convert_module_to_f32(x):
27
- pass
28
-
29
-
30
- ## go
31
- class AttentionPool2d(nn.Module):
32
- """
33
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
- """
35
-
36
- def __init__(
37
- self,
38
- spacial_dim: int,
39
- embed_dim: int,
40
- num_heads_channels: int,
41
- output_dim: int = None,
42
- ):
43
- super().__init__()
44
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
- self.num_heads = embed_dim // num_heads_channels
48
- self.attention = QKVAttention(self.num_heads)
49
-
50
- def forward(self, x):
51
- b, c, *_spatial = x.shape
52
- x = x.reshape(b, c, -1) # NC(HW)
53
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
- x = self.qkv_proj(x)
56
- x = self.attention(x)
57
- x = self.c_proj(x)
58
- return x[:, :, 0]
59
-
60
-
61
- class TimestepBlock(nn.Module):
62
- """
63
- Any module where forward() takes timestep embeddings as a second argument.
64
- """
65
-
66
- @abstractmethod
67
- def forward(self, x, emb):
68
- """
69
- Apply the module to `x` given `emb` timestep embeddings.
70
- """
71
-
72
-
73
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
- """
75
- A sequential module that passes timestep embeddings to the children that
76
- support it as an extra input.
77
- """
78
-
79
- def forward(self, x, emb, context=None, hint=None):
80
- for layer in self:
81
- if isinstance(layer, TimestepBlock):
82
- x = layer(x, emb)
83
- elif isinstance(layer, SpatialTransformer):
84
- x = layer(x, context, hint)
85
- else:
86
- x = layer(x)
87
- return x
88
-
89
-
90
- class Upsample(nn.Module):
91
- """
92
- An upsampling layer with an optional convolution.
93
- :param channels: channels in the inputs and outputs.
94
- :param use_conv: a bool determining if a convolution is applied.
95
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
96
- upsampling occurs in the inner-two dimensions.
97
- """
98
-
99
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
100
- super().__init__()
101
- self.channels = channels
102
- self.out_channels = out_channels or channels
103
- self.use_conv = use_conv
104
- self.dims = dims
105
- if use_conv:
106
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
107
-
108
- def forward(self, x):
109
- assert x.shape[1] == self.channels
110
- if self.dims == 3:
111
- x = F.interpolate(
112
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
113
- )
114
- else:
115
- x = F.interpolate(x, scale_factor=2, mode="nearest")
116
- if self.use_conv:
117
- x = self.conv(x)
118
- return x
119
-
120
- class TransposedUpsample(nn.Module):
121
- 'Learned 2x upsampling without padding'
122
- def __init__(self, channels, out_channels=None, ks=5):
123
- super().__init__()
124
- self.channels = channels
125
- self.out_channels = out_channels or channels
126
-
127
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
128
-
129
- def forward(self,x):
130
- return self.up(x)
131
-
132
-
133
- class Downsample(nn.Module):
134
- """
135
- A downsampling layer with an optional convolution.
136
- :param channels: channels in the inputs and outputs.
137
- :param use_conv: a bool determining if a convolution is applied.
138
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
139
- downsampling occurs in the inner-two dimensions.
140
- """
141
-
142
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
143
- super().__init__()
144
- self.channels = channels
145
- self.out_channels = out_channels or channels
146
- self.use_conv = use_conv
147
- self.dims = dims
148
- stride = 2 if dims != 3 else (1, 2, 2)
149
- if use_conv:
150
- self.op = conv_nd(
151
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
152
- )
153
- else:
154
- assert self.channels == self.out_channels
155
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
156
-
157
- def forward(self, x):
158
- assert x.shape[1] == self.channels
159
- return self.op(x)
160
-
161
-
162
- class ResBlock(TimestepBlock):
163
- """
164
- A residual block that can optionally change the number of channels.
165
- :param channels: the number of input channels.
166
- :param emb_channels: the number of timestep embedding channels.
167
- :param dropout: the rate of dropout.
168
- :param out_channels: if specified, the number of out channels.
169
- :param use_conv: if True and out_channels is specified, use a spatial
170
- convolution instead of a smaller 1x1 convolution to change the
171
- channels in the skip connection.
172
- :param dims: determines if the signal is 1D, 2D, or 3D.
173
- :param use_checkpoint: if True, use gradient checkpointing on this module.
174
- :param up: if True, use this block for upsampling.
175
- :param down: if True, use this block for downsampling.
176
- """
177
-
178
- def __init__(
179
- self,
180
- channels,
181
- emb_channels,
182
- dropout,
183
- out_channels=None,
184
- use_conv=False,
185
- use_scale_shift_norm=False,
186
- dims=2,
187
- use_checkpoint=False,
188
- up=False,
189
- down=False,
190
- ):
191
- super().__init__()
192
- self.channels = channels
193
- self.emb_channels = emb_channels
194
- self.dropout = dropout
195
- self.out_channels = out_channels or channels
196
- self.use_conv = use_conv
197
- self.use_checkpoint = use_checkpoint
198
- self.use_scale_shift_norm = use_scale_shift_norm
199
-
200
- self.in_layers = nn.Sequential(
201
- normalization(channels),
202
- nn.SiLU(),
203
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
204
- )
205
-
206
- self.updown = up or down
207
-
208
- if up:
209
- self.h_upd = Upsample(channels, False, dims)
210
- self.x_upd = Upsample(channels, False, dims)
211
- elif down:
212
- self.h_upd = Downsample(channels, False, dims)
213
- self.x_upd = Downsample(channels, False, dims)
214
- else:
215
- self.h_upd = self.x_upd = nn.Identity()
216
-
217
- self.emb_layers = nn.Sequential(
218
- nn.SiLU(),
219
- linear(
220
- emb_channels,
221
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
222
- ),
223
- )
224
- self.out_layers = nn.Sequential(
225
- normalization(self.out_channels),
226
- nn.SiLU(),
227
- nn.Dropout(p=dropout),
228
- zero_module(
229
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
230
- ),
231
- )
232
-
233
- if self.out_channels == channels:
234
- self.skip_connection = nn.Identity()
235
- elif use_conv:
236
- self.skip_connection = conv_nd(
237
- dims, channels, self.out_channels, 3, padding=1
238
- )
239
- else:
240
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
-
242
- def forward(self, x, emb):
243
- """
244
- Apply the block to a Tensor, conditioned on a timestep embedding.
245
- :param x: an [N x C x ...] Tensor of features.
246
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
247
- :return: an [N x C x ...] Tensor of outputs.
248
- """
249
- return checkpoint(
250
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
251
- )
252
-
253
-
254
- def _forward(self, x, emb):
255
- if self.updown:
256
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
257
- h = in_rest(x)
258
- h = self.h_upd(h)
259
- x = self.x_upd(x)
260
- h = in_conv(h)
261
- else:
262
- h = self.in_layers(x)
263
- emb_out = self.emb_layers(emb).type(h.dtype)
264
- while len(emb_out.shape) < len(h.shape):
265
- emb_out = emb_out[..., None]
266
- if self.use_scale_shift_norm:
267
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
268
- scale, shift = th.chunk(emb_out, 2, dim=1)
269
- h = out_norm(h) * (1 + scale) + shift
270
- h = out_rest(h)
271
- else:
272
- h = h + emb_out
273
- h = self.out_layers(h)
274
- return self.skip_connection(x) + h
275
-
276
-
277
- class AttentionBlock(nn.Module):
278
- """
279
- An attention block that allows spatial positions to attend to each other.
280
- Originally ported from here, but adapted to the N-d case.
281
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
282
- """
283
-
284
- def __init__(
285
- self,
286
- channels,
287
- num_heads=1,
288
- num_head_channels=-1,
289
- use_checkpoint=False,
290
- use_new_attention_order=False,
291
- ):
292
- super().__init__()
293
- self.channels = channels
294
- if num_head_channels == -1:
295
- self.num_heads = num_heads
296
- else:
297
- assert (
298
- channels % num_head_channels == 0
299
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
300
- self.num_heads = channels // num_head_channels
301
- self.use_checkpoint = use_checkpoint
302
- self.norm = normalization(channels)
303
- self.qkv = conv_nd(1, channels, channels * 3, 1)
304
- if use_new_attention_order:
305
- # split qkv before split heads
306
- self.attention = QKVAttention(self.num_heads)
307
- else:
308
- # split heads before split qkv
309
- self.attention = QKVAttentionLegacy(self.num_heads)
310
-
311
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
312
-
313
- def forward(self, x):
314
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
315
- #return pt_checkpoint(self._forward, x) # pytorch
316
-
317
- def _forward(self, x):
318
- b, c, *spatial = x.shape
319
- x = x.reshape(b, c, -1)
320
- qkv = self.qkv(self.norm(x))
321
- h = self.attention(qkv)
322
- h = self.proj_out(h)
323
- return (x + h).reshape(b, c, *spatial)
324
-
325
-
326
- def count_flops_attn(model, _x, y):
327
- """
328
- A counter for the `thop` package to count the operations in an
329
- attention operation.
330
- Meant to be used like:
331
- macs, params = thop.profile(
332
- model,
333
- inputs=(inputs, timestamps),
334
- custom_ops={QKVAttention: QKVAttention.count_flops},
335
- )
336
- """
337
- b, c, *spatial = y[0].shape
338
- num_spatial = int(np.prod(spatial))
339
- # We perform two matmuls with the same number of ops.
340
- # The first computes the weight matrix, the second computes
341
- # the combination of the value vectors.
342
- matmul_ops = 2 * b * (num_spatial ** 2) * c
343
- model.total_ops += th.DoubleTensor([matmul_ops])
344
-
345
-
346
- class QKVAttentionLegacy(nn.Module):
347
- """
348
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
349
- """
350
-
351
- def __init__(self, n_heads):
352
- super().__init__()
353
- self.n_heads = n_heads
354
-
355
- def forward(self, qkv):
356
- """
357
- Apply QKV attention.
358
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
359
- :return: an [N x (H * C) x T] tensor after attention.
360
- """
361
- bs, width, length = qkv.shape
362
- assert width % (3 * self.n_heads) == 0
363
- ch = width // (3 * self.n_heads)
364
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
365
- scale = 1 / math.sqrt(math.sqrt(ch))
366
- weight = th.einsum(
367
- "bct,bcs->bts", q * scale, k * scale
368
- ) # More stable with f16 than dividing afterwards
369
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
370
- a = th.einsum("bts,bcs->bct", weight, v)
371
- return a.reshape(bs, -1, length)
372
-
373
- @staticmethod
374
- def count_flops(model, _x, y):
375
- return count_flops_attn(model, _x, y)
376
-
377
-
378
- class QKVAttention(nn.Module):
379
- """
380
- A module which performs QKV attention and splits in a different order.
381
- """
382
-
383
- def __init__(self, n_heads):
384
- super().__init__()
385
- self.n_heads = n_heads
386
-
387
- def forward(self, qkv):
388
- """
389
- Apply QKV attention.
390
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
391
- :return: an [N x (H * C) x T] tensor after attention.
392
- """
393
- bs, width, length = qkv.shape
394
- assert width % (3 * self.n_heads) == 0
395
- ch = width // (3 * self.n_heads)
396
- q, k, v = qkv.chunk(3, dim=1)
397
- scale = 1 / math.sqrt(math.sqrt(ch))
398
- weight = th.einsum(
399
- "bct,bcs->bts",
400
- (q * scale).view(bs * self.n_heads, ch, length),
401
- (k * scale).view(bs * self.n_heads, ch, length),
402
- ) # More stable with f16 than dividing afterwards
403
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
404
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
405
- return a.reshape(bs, -1, length)
406
-
407
- @staticmethod
408
- def count_flops(model, _x, y):
409
- return count_flops_attn(model, _x, y)
410
-
411
-
412
- class UNetModel(nn.Module):
413
- """
414
- The full UNet model with attention and timestep embedding.
415
- :param in_channels: channels in the input Tensor.
416
- :param model_channels: base channel count for the model.
417
- :param out_channels: channels in the output Tensor.
418
- :param num_res_blocks: number of residual blocks per downsample.
419
- :param attention_resolutions: a collection of downsample rates at which
420
- attention will take place. May be a set, list, or tuple.
421
- For example, if this contains 4, then at 4x downsampling, attention
422
- will be used.
423
- :param dropout: the dropout probability.
424
- :param channel_mult: channel multiplier for each level of the UNet.
425
- :param conv_resample: if True, use learned convolutions for upsampling and
426
- downsampling.
427
- :param dims: determines if the signal is 1D, 2D, or 3D.
428
- :param num_classes: if specified (as an int), then this model will be
429
- class-conditional with `num_classes` classes.
430
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
431
- :param num_heads: the number of attention heads in each attention layer.
432
- :param num_heads_channels: if specified, ignore num_heads and instead use
433
- a fixed channel width per attention head.
434
- :param num_heads_upsample: works with num_heads to set a different number
435
- of heads for upsampling. Deprecated.
436
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
437
- :param resblock_updown: use residual blocks for up/downsampling.
438
- :param use_new_attention_order: use a different attention pattern for potentially
439
- increased efficiency.
440
- """
441
-
442
- def __init__(
443
- self,
444
- image_size,
445
- in_channels,
446
- model_channels,
447
- out_channels,
448
- num_res_blocks,
449
- attention_resolutions,
450
- dropout=0,
451
- channel_mult=(1, 2, 4, 8),
452
- conv_resample=True,
453
- dims=2,
454
- num_classes=None,
455
- use_checkpoint=False,
456
- use_fp16=False,
457
- num_heads=-1,
458
- num_head_channels=-1,
459
- num_heads_upsample=-1,
460
- use_scale_shift_norm=False,
461
- resblock_updown=False,
462
- use_new_attention_order=False,
463
- use_spatial_transformer=False, # custom transformer support
464
- transformer_depth=1, # custom transformer support
465
- context_dim=None, # custom transformer support
466
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
467
- legacy=True,
468
- disable_self_attentions=None,
469
- num_attention_blocks=None,
470
- disable_middle_self_attn=False,
471
- use_linear_in_transformer=False,
472
- no_control=False,
473
- ):
474
- super().__init__()
475
- self.no_control = no_control
476
- if use_spatial_transformer:
477
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
478
-
479
- if context_dim is not None:
480
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
481
- from omegaconf.listconfig import ListConfig
482
- if type(context_dim) == ListConfig:
483
- context_dim = list(context_dim)
484
-
485
- if num_heads_upsample == -1:
486
- num_heads_upsample = num_heads
487
-
488
- if num_heads == -1:
489
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
490
-
491
- if num_head_channels == -1:
492
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
493
-
494
- self.image_size = image_size
495
- self.in_channels = in_channels
496
- self.model_channels = model_channels
497
- self.out_channels = out_channels
498
- if isinstance(num_res_blocks, int):
499
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
500
- else:
501
- if len(num_res_blocks) != len(channel_mult):
502
- raise ValueError("provide num_res_blocks either as an int (globally constant) or "
503
- "as a list/tuple (per-level) with the same length as channel_mult")
504
- self.num_res_blocks = num_res_blocks
505
- if disable_self_attentions is not None:
506
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
507
- assert len(disable_self_attentions) == len(channel_mult)
508
- if num_attention_blocks is not None:
509
- assert len(num_attention_blocks) == len(self.num_res_blocks)
510
- assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
511
- print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
512
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
513
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
514
- f"attention will still not be set.")
515
-
516
- self.attention_resolutions = attention_resolutions
517
- self.dropout = dropout
518
- self.channel_mult = channel_mult
519
- self.conv_resample = conv_resample
520
- self.num_classes = num_classes
521
- self.use_checkpoint = use_checkpoint
522
- self.dtype = th.float16 if use_fp16 else th.float32
523
- self.num_heads = num_heads
524
- self.num_head_channels = num_head_channels
525
- self.num_heads_upsample = num_heads_upsample
526
- self.predict_codebook_ids = n_embed is not None
527
- self.transformer_depth = transformer_depth
528
- self.context_dim = context_dim
529
- self.use_linear_in_transformer = use_linear_in_transformer
530
-
531
- time_embed_dim = model_channels * 4
532
- self.time_embed = nn.Sequential(
533
- linear(model_channels, time_embed_dim),
534
- nn.SiLU(),
535
- linear(time_embed_dim, time_embed_dim),
536
- )
537
-
538
- if self.num_classes is not None:
539
- if isinstance(self.num_classes, int):
540
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
541
- elif self.num_classes == "continuous":
542
- print("setting up linear c_adm embedding layer")
543
- self.label_emb = nn.Linear(1, time_embed_dim)
544
- else:
545
- raise ValueError()
546
-
547
- self.input_blocks = nn.ModuleList(
548
- [
549
- TimestepEmbedSequential(
550
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
551
- )
552
- ]
553
- )
554
- self._feature_size = model_channels
555
- input_block_chans = [model_channels]
556
- ch = model_channels
557
- ds = 1
558
- for level, mult in enumerate(channel_mult):
559
- for nr in range(self.num_res_blocks[level]):
560
- layers = [
561
- ResBlock(
562
- ch,
563
- time_embed_dim,
564
- dropout,
565
- out_channels=mult * model_channels,
566
- dims=dims,
567
- use_checkpoint=use_checkpoint,
568
- use_scale_shift_norm=use_scale_shift_norm,
569
- )
570
- ]
571
- ch = mult * model_channels
572
- if ds in attention_resolutions:
573
- if num_head_channels == -1:
574
- dim_head = ch // num_heads
575
- else:
576
- num_heads = ch // num_head_channels
577
- dim_head = num_head_channels
578
- if legacy:
579
- #num_heads = 1
580
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
581
- if exists(disable_self_attentions):
582
- disabled_sa = disable_self_attentions[level]
583
- else:
584
- disabled_sa = False
585
-
586
- if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
587
- layers.append(
588
- AttentionBlock(
589
- ch,
590
- use_checkpoint=use_checkpoint,
591
- num_heads=num_heads,
592
- num_head_channels=dim_head,
593
- use_new_attention_order=use_new_attention_order,
594
- ) if not use_spatial_transformer else SpatialTransformer(
595
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
596
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
597
- use_checkpoint=use_checkpoint
598
- )
599
- )
600
- self.input_blocks.append(TimestepEmbedSequential(*layers))
601
- self._feature_size += ch
602
- input_block_chans.append(ch)
603
- if level != len(channel_mult) - 1:
604
- out_ch = ch
605
- self.input_blocks.append(
606
- TimestepEmbedSequential(
607
- ResBlock(
608
- ch,
609
- time_embed_dim,
610
- dropout,
611
- out_channels=out_ch,
612
- dims=dims,
613
- use_checkpoint=use_checkpoint,
614
- use_scale_shift_norm=use_scale_shift_norm,
615
- down=True,
616
- )
617
- if resblock_updown
618
- else Downsample(
619
- ch, conv_resample, dims=dims, out_channels=out_ch
620
- )
621
- )
622
- )
623
- ch = out_ch
624
- input_block_chans.append(ch)
625
- ds *= 2
626
- self._feature_size += ch
627
-
628
- if num_head_channels == -1:
629
- dim_head = ch // num_heads
630
- else:
631
- num_heads = ch // num_head_channels
632
- dim_head = num_head_channels
633
- if legacy:
634
- #num_heads = 1
635
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
636
- self.middle_block = TimestepEmbedSequential(
637
- ResBlock(
638
- ch,
639
- time_embed_dim,
640
- dropout,
641
- dims=dims,
642
- use_checkpoint=use_checkpoint,
643
- use_scale_shift_norm=use_scale_shift_norm,
644
- ),
645
- AttentionBlock(
646
- ch,
647
- use_checkpoint=use_checkpoint,
648
- num_heads=num_heads,
649
- num_head_channels=dim_head,
650
- use_new_attention_order=use_new_attention_order,
651
- ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
652
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
653
- disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
654
- use_checkpoint=use_checkpoint
655
- ),
656
- ResBlock(
657
- ch,
658
- time_embed_dim,
659
- dropout,
660
- dims=dims,
661
- use_checkpoint=use_checkpoint,
662
- use_scale_shift_norm=use_scale_shift_norm,
663
- ),
664
- )
665
- self._feature_size += ch
666
-
667
- self.output_blocks = nn.ModuleList([])
668
- for level, mult in list(enumerate(channel_mult))[::-1]:
669
- for i in range(self.num_res_blocks[level] + 1):
670
- ich = input_block_chans.pop()
671
- layers = [
672
- ResBlock(
673
- ch + ich,
674
- time_embed_dim,
675
- dropout,
676
- out_channels=model_channels * mult,
677
- dims=dims,
678
- use_checkpoint=use_checkpoint,
679
- use_scale_shift_norm=use_scale_shift_norm,
680
- )
681
- ]
682
- ch = model_channels * mult
683
- if ds in attention_resolutions:
684
- if num_head_channels == -1:
685
- dim_head = ch // num_heads
686
- else:
687
- num_heads = ch // num_head_channels
688
- dim_head = num_head_channels
689
- if legacy:
690
- #num_heads = 1
691
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
692
- if exists(disable_self_attentions):
693
- disabled_sa = disable_self_attentions[level]
694
- else:
695
- disabled_sa = False
696
-
697
- if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
698
- layers.append(
699
- AttentionBlock(
700
- ch,
701
- use_checkpoint=use_checkpoint,
702
- num_heads=num_heads_upsample,
703
- num_head_channels=dim_head,
704
- use_new_attention_order=use_new_attention_order,
705
- ) if not use_spatial_transformer else SpatialTransformer(
706
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
707
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
708
- use_checkpoint=use_checkpoint
709
- )
710
- )
711
- if level and i == self.num_res_blocks[level]:
712
- out_ch = ch
713
- layers.append(
714
- ResBlock(
715
- ch,
716
- time_embed_dim,
717
- dropout,
718
- out_channels=out_ch,
719
- dims=dims,
720
- use_checkpoint=use_checkpoint,
721
- use_scale_shift_norm=use_scale_shift_norm,
722
- up=True,
723
- )
724
- if resblock_updown
725
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
726
- )
727
- ds //= 2
728
- self.output_blocks.append(TimestepEmbedSequential(*layers))
729
- self._feature_size += ch
730
-
731
- self.out = nn.Sequential(
732
- normalization(ch),
733
- nn.SiLU(),
734
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
735
- )
736
- if self.predict_codebook_ids:
737
- self.id_predictor = nn.Sequential(
738
- normalization(ch),
739
- conv_nd(dims, model_channels, n_embed, 1),
740
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
741
- )
742
-
743
- def convert_to_fp16(self):
744
- """
745
- Convert the torso of the model to float16.
746
- """
747
- self.input_blocks.apply(convert_module_to_f16)
748
- self.middle_block.apply(convert_module_to_f16)
749
- self.output_blocks.apply(convert_module_to_f16)
750
-
751
- def convert_to_fp32(self):
752
- """
753
- Convert the torso of the model to float32.
754
- """
755
- self.input_blocks.apply(convert_module_to_f32)
756
- self.middle_block.apply(convert_module_to_f32)
757
- self.output_blocks.apply(convert_module_to_f32)
758
-
759
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
760
- """
761
- Apply the model to an input batch.
762
- :param x: an [N x C x ...] Tensor of inputs.
763
- :param timesteps: a 1-D batch of timesteps.
764
- :param context: conditioning plugged in via crossattn
765
- :param y: an [N] Tensor of labels, if class-conditional.
766
- :return: an [N x C x ...] Tensor of outputs.
767
- """
768
- assert (y is not None) == (
769
- self.num_classes is not None
770
- ), "must specify y if and only if the model is class-conditional"
771
- hs = []
772
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
773
- emb = self.time_embed(t_emb)
774
- if self.num_classes is not None:
775
- assert y.shape[0] == x.shape[0]
776
- emb = emb + self.label_emb(y)
777
-
778
- h = x.type(self.dtype)
779
- for module in self.input_blocks:
780
- h = module(h, emb, context)
781
- hs.append(h)
782
- h = self.middle_block(h, emb, context)
783
- for module in self.output_blocks:
784
- h = th.cat([h, hs.pop()], dim=1)
785
- h = module(h, emb, context)
786
- h = h.type(x.dtype)
787
- if self.predict_codebook_ids:
788
- return self.id_predictor(h)
789
- else:
790
- return self.out(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/modules/diffusionmodules/upscaling.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- from functools import partial
5
-
6
- from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
- from ldm.util import default
8
-
9
-
10
- class AbstractLowScaleModel(nn.Module):
11
- # for concatenating a downsampled image to the latent representation
12
- def __init__(self, noise_schedule_config=None):
13
- super(AbstractLowScaleModel, self).__init__()
14
- if noise_schedule_config is not None:
15
- self.register_schedule(**noise_schedule_config)
16
-
17
- def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
- cosine_s=cosine_s)
21
- alphas = 1. - betas
22
- alphas_cumprod = np.cumprod(alphas, axis=0)
23
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
-
25
- timesteps, = betas.shape
26
- self.num_timesteps = int(timesteps)
27
- self.linear_start = linear_start
28
- self.linear_end = linear_end
29
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
-
31
- to_torch = partial(torch.tensor, dtype=torch.float32)
32
-
33
- self.register_buffer('betas', to_torch(betas))
34
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
-
37
- # calculations for diffusion q(x_t | x_{t-1}) and others
38
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
-
44
- def q_sample(self, x_start, t, noise=None):
45
- noise = default(noise, lambda: torch.randn_like(x_start))
46
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
-
49
- def forward(self, x):
50
- return x, None
51
-
52
- def decode(self, x):
53
- return x
54
-
55
-
56
- class SimpleImageConcat(AbstractLowScaleModel):
57
- # no noise level conditioning
58
- def __init__(self):
59
- super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
- self.max_noise_level = 0
61
-
62
- def forward(self, x):
63
- # fix to constant noise level
64
- return x, torch.zeros(x.shape[0], device=x.device).long()
65
-
66
-
67
- class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
- def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
- super().__init__(noise_schedule_config=noise_schedule_config)
70
- self.max_noise_level = max_noise_level
71
-
72
- def forward(self, x, noise_level=None):
73
- if noise_level is None:
74
- noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
- else:
76
- assert isinstance(noise_level, torch.Tensor)
77
- z = self.q_sample(x, noise_level)
78
- return z, noise_level
79
-
80
-
81
-