shgao commited on
Commit
0c7479d
1 Parent(s): d8212b4

update new demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +0 -34
  2. README.md +0 -15
  3. annotator/util.py +18 -0
  4. app.py +17 -11
  5. cldm/cldm.py +435 -0
  6. cldm/ddim_hacked.py +316 -0
  7. cldm/hack.py +111 -0
  8. cldm/logger.py +76 -0
  9. cldm/model.py +28 -0
  10. config.py +1 -0
  11. dataset_build.py +45 -0
  12. sam2edit.py → editany.py +19 -11
  13. editany_beauty.py +66 -0
  14. editany_demo.py +394 -0
  15. sam2edit_handsome.py → editany_handsome.py +28 -15
  16. sam2edit_lora.py → editany_lora.py +465 -175
  17. editany_test.py +73 -0
  18. environment.yaml +38 -0
  19. font/DejaVuSans.ttf +0 -0
  20. ldm/data/__init__.py +0 -0
  21. ldm/data/util.py +24 -0
  22. ldm/models/autoencoder.py +219 -0
  23. ldm/models/diffusion/__init__.py +0 -0
  24. ldm/models/diffusion/ddim.py +336 -0
  25. ldm/models/diffusion/ddpm.py +1797 -0
  26. ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  27. ldm/models/diffusion/dpm_solver/dpm_solver.py +1154 -0
  28. ldm/models/diffusion/dpm_solver/sampler.py +87 -0
  29. ldm/models/diffusion/plms.py +244 -0
  30. ldm/models/diffusion/sampling_util.py +22 -0
  31. ldm/modules/attention.py +341 -0
  32. ldm/modules/diffusionmodules/__init__.py +0 -0
  33. ldm/modules/diffusionmodules/model.py +852 -0
  34. ldm/modules/diffusionmodules/openaimodel.py +786 -0
  35. ldm/modules/diffusionmodules/upscaling.py +81 -0
  36. ldm/modules/diffusionmodules/util.py +270 -0
  37. ldm/modules/distributions/__init__.py +0 -0
  38. ldm/modules/distributions/distributions.py +92 -0
  39. ldm/modules/ema.py +80 -0
  40. ldm/modules/encoders/__init__.py +0 -0
  41. ldm/modules/encoders/modules.py +213 -0
  42. ldm/modules/image_degradation/__init__.py +2 -0
  43. ldm/modules/image_degradation/bsrgan.py +730 -0
  44. ldm/modules/image_degradation/bsrgan_light.py +651 -0
  45. ldm/modules/image_degradation/utils/test.png +0 -0
  46. ldm/modules/image_degradation/utils_image.py +916 -0
  47. ldm/modules/midas/__init__.py +0 -0
  48. ldm/modules/midas/api.py +170 -0
  49. ldm/modules/midas/midas/__init__.py +0 -0
  50. ldm/modules/midas/midas/base_model.py +16 -0
.gitattributes DELETED
@@ -1,34 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tflite filter=lfs diff=lfs merge=lfs -text
29
- *.tgz filter=lfs diff=lfs merge=lfs -text
30
- *.wasm filter=lfs diff=lfs merge=lfs -text
31
- *.xz filter=lfs diff=lfs merge=lfs -text
32
- *.zip filter=lfs diff=lfs merge=lfs -text
33
- *.zst filter=lfs diff=lfs merge=lfs -text
34
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,15 +0,0 @@
1
- ---
2
- title: EditAnything
3
- emoji: 🦀
4
- colorFrom: green
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 3.24.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
-
14
- # Edit Anything by Segment-Anything
15
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
annotator/util.py CHANGED
@@ -53,3 +53,21 @@ def resize_points(clicked_points, original_shape, resolution):
53
  resized_points.append(resized_point)
54
 
55
  return resized_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  resized_points.append(resized_point)
54
 
55
  return resized_points
56
+
57
+ def get_bounding_box(mask):
58
+ # Convert PIL Image to numpy array
59
+ mask = np.array(mask).astype(np.uint8)
60
+
61
+ # Take the first channel (R) of the mask
62
+ mask = mask[:,:,0]
63
+
64
+ # Get the indices of elements that are not zero
65
+ rows = np.any(mask, axis=0)
66
+ cols = np.any(mask, axis=1)
67
+
68
+ # Get the minimum and maximum indices where the elements are not zero
69
+ rmin, rmax = np.where(rows)[0][[0, -1]]
70
+ cmin, cmax = np.where(cols)[0][[0, -1]]
71
+
72
+ # Return as [xmin, ymin, xmax, ymax]
73
+ return [rmin, cmin, rmax, cmax]
app.py CHANGED
@@ -1,15 +1,11 @@
1
- import subprocess
2
- result = subprocess.run(
3
- ['pip', 'install', 'gradio==3.28.3'], check=True)
4
  import gradio as gr
5
-
6
  import os
7
 
8
- from sam2edit import create_demo as create_demo_edit_anything
9
  from sam2image import create_demo as create_demo_generate_anything
10
- from sam2edit_beauty import create_demo as create_demo_beauty
11
- from sam2edit_handsome import create_demo as create_demo_handsome
12
- from sam2edit_lora import EditAnythingLoraModel, init_sam_model, init_blip_processor, init_blip_model
13
  from huggingface_hub import hf_hub_download, snapshot_download
14
 
15
  DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
@@ -29,9 +25,9 @@ with gr.Blocks() as demo:
29
  gr.Markdown(DESCRIPTION)
30
  with gr.Tabs():
31
  with gr.TabItem('🖌Edit Anything'):
32
- model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2-inpainting",
33
- controlmodel_name='LAION Pretrained(v0-4)-SD21',
34
- lora_model_path=None, use_blip=True, extra_inpaint=False,
35
  sam_generator=sam_generator,
36
  mask_predictor=mask_predictor,
37
  blip_processor=blip_processor,
@@ -57,6 +53,16 @@ with gr.Blocks() as demo:
57
  # blip_processor=blip_processor,
58
  # blip_model=blip_model)
59
  # create_demo_handsome(model.process, model.process_image_click)
 
 
 
 
 
 
 
 
 
 
60
  # with gr.TabItem('Generate Anything'):
61
  # create_demo_generate_anything()
62
  with gr.Tabs():
 
 
 
 
1
  import gradio as gr
 
2
  import os
3
 
4
+ from editany import create_demo as create_demo_edit_anything
5
  from sam2image import create_demo as create_demo_generate_anything
6
+ from editany_beauty import create_demo as create_demo_beauty
7
+ from editany_handsome import create_demo as create_demo_handsome
8
+ from editany_lora import EditAnythingLoraModel, init_sam_model, init_blip_processor, init_blip_model
9
  from huggingface_hub import hf_hub_download, snapshot_download
10
 
11
  DESCRIPTION = f'''# [Edit Anything](https://github.com/sail-sg/EditAnything)
 
25
  gr.Markdown(DESCRIPTION)
26
  with gr.Tabs():
27
  with gr.TabItem('🖌Edit Anything'):
28
+ model = EditAnythingLoraModel(base_model_path="runwayml/stable-diffusion-v1-5",
29
+ controlmodel_name='LAION Pretrained(v0-4)-SD15',
30
+ lora_model_path=None, use_blip=True, extra_inpaint=True,
31
  sam_generator=sam_generator,
32
  mask_predictor=mask_predictor,
33
  blip_processor=blip_processor,
 
53
  # blip_processor=blip_processor,
54
  # blip_model=blip_model)
55
  # create_demo_handsome(model.process, model.process_image_click)
56
+ # with gr.TabItem('Edit More'):
57
+ # model = EditAnythingLoraModel(base_model_path="andite/anything-v4.0",
58
+ # lora_model_path=None, use_blip=True, extra_inpaint=True,
59
+ # sam_generator=sam_generator,
60
+ # mask_predictor=mask_predictor,
61
+ # blip_processor=blip_processor,
62
+ # blip_model=blip_model,
63
+ # lora_weight=0.5,
64
+ # )
65
+ create_demo_beauty(model.process, model.process_image_click)
66
  # with gr.TabItem('Generate Anything'):
67
  # create_demo_generate_anything()
68
  with gr.Tabs():
cldm/cldm.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ from ldm.modules.diffusionmodules.util import (
7
+ conv_nd,
8
+ linear,
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from einops import rearrange, repeat
14
+ from torchvision.utils import make_grid
15
+ from ldm.modules.attention import SpatialTransformer
16
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
17
+ from ldm.models.diffusion.ddpm import LatentDiffusion
18
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
19
+ from ldm.models.diffusion.ddim import DDIMSampler
20
+
21
+
22
+ class ControlledUnetModel(UNetModel):
23
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
24
+ hs = []
25
+ with torch.no_grad():
26
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
27
+ emb = self.time_embed(t_emb)
28
+ h = x.type(self.dtype)
29
+ for module in self.input_blocks:
30
+ h = module(h, emb, context)
31
+ hs.append(h)
32
+ h = self.middle_block(h, emb, context)
33
+
34
+ if control is not None:
35
+ h += control.pop()
36
+
37
+ for i, module in enumerate(self.output_blocks):
38
+ if only_mid_control or control is None:
39
+ h = torch.cat([h, hs.pop()], dim=1)
40
+ else:
41
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
42
+ h = module(h, emb, context)
43
+
44
+ h = h.type(x.dtype)
45
+ return self.out(h)
46
+
47
+
48
+ class ControlNet(nn.Module):
49
+ def __init__(
50
+ self,
51
+ image_size,
52
+ in_channels,
53
+ model_channels,
54
+ hint_channels,
55
+ num_res_blocks,
56
+ attention_resolutions,
57
+ dropout=0,
58
+ channel_mult=(1, 2, 4, 8),
59
+ conv_resample=True,
60
+ dims=2,
61
+ use_checkpoint=False,
62
+ use_fp16=False,
63
+ num_heads=-1,
64
+ num_head_channels=-1,
65
+ num_heads_upsample=-1,
66
+ use_scale_shift_norm=False,
67
+ resblock_updown=False,
68
+ use_new_attention_order=False,
69
+ use_spatial_transformer=False, # custom transformer support
70
+ transformer_depth=1, # custom transformer support
71
+ context_dim=None, # custom transformer support
72
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
73
+ legacy=True,
74
+ disable_self_attentions=None,
75
+ num_attention_blocks=None,
76
+ disable_middle_self_attn=False,
77
+ use_linear_in_transformer=False,
78
+ ):
79
+ super().__init__()
80
+ if use_spatial_transformer:
81
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
82
+
83
+ if context_dim is not None:
84
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
85
+ from omegaconf.listconfig import ListConfig
86
+ if type(context_dim) == ListConfig:
87
+ context_dim = list(context_dim)
88
+
89
+ if num_heads_upsample == -1:
90
+ num_heads_upsample = num_heads
91
+
92
+ if num_heads == -1:
93
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
94
+
95
+ if num_head_channels == -1:
96
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
97
+
98
+ self.dims = dims
99
+ self.image_size = image_size
100
+ self.in_channels = in_channels
101
+ self.model_channels = model_channels
102
+ if isinstance(num_res_blocks, int):
103
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
104
+ else:
105
+ if len(num_res_blocks) != len(channel_mult):
106
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
107
+ "as a list/tuple (per-level) with the same length as channel_mult")
108
+ self.num_res_blocks = num_res_blocks
109
+ if disable_self_attentions is not None:
110
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
111
+ assert len(disable_self_attentions) == len(channel_mult)
112
+ if num_attention_blocks is not None:
113
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
114
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
115
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
116
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
117
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
118
+ f"attention will still not be set.")
119
+
120
+ self.attention_resolutions = attention_resolutions
121
+ self.dropout = dropout
122
+ self.channel_mult = channel_mult
123
+ self.conv_resample = conv_resample
124
+ self.use_checkpoint = use_checkpoint
125
+ self.dtype = th.float16 if use_fp16 else th.float32
126
+ self.num_heads = num_heads
127
+ self.num_head_channels = num_head_channels
128
+ self.num_heads_upsample = num_heads_upsample
129
+ self.predict_codebook_ids = n_embed is not None
130
+
131
+ time_embed_dim = model_channels * 4
132
+ self.time_embed = nn.Sequential(
133
+ linear(model_channels, time_embed_dim),
134
+ nn.SiLU(),
135
+ linear(time_embed_dim, time_embed_dim),
136
+ )
137
+
138
+ self.input_blocks = nn.ModuleList(
139
+ [
140
+ TimestepEmbedSequential(
141
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
142
+ )
143
+ ]
144
+ )
145
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
146
+
147
+ self.input_hint_block = TimestepEmbedSequential(
148
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
149
+ nn.SiLU(),
150
+ conv_nd(dims, 16, 16, 3, padding=1),
151
+ nn.SiLU(),
152
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
153
+ nn.SiLU(),
154
+ conv_nd(dims, 32, 32, 3, padding=1),
155
+ nn.SiLU(),
156
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
157
+ nn.SiLU(),
158
+ conv_nd(dims, 96, 96, 3, padding=1),
159
+ nn.SiLU(),
160
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
161
+ nn.SiLU(),
162
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
163
+ )
164
+
165
+ self._feature_size = model_channels
166
+ input_block_chans = [model_channels]
167
+ ch = model_channels
168
+ ds = 1
169
+ for level, mult in enumerate(channel_mult):
170
+ for nr in range(self.num_res_blocks[level]):
171
+ layers = [
172
+ ResBlock(
173
+ ch,
174
+ time_embed_dim,
175
+ dropout,
176
+ out_channels=mult * model_channels,
177
+ dims=dims,
178
+ use_checkpoint=use_checkpoint,
179
+ use_scale_shift_norm=use_scale_shift_norm,
180
+ )
181
+ ]
182
+ ch = mult * model_channels
183
+ if ds in attention_resolutions:
184
+ if num_head_channels == -1:
185
+ dim_head = ch // num_heads
186
+ else:
187
+ num_heads = ch // num_head_channels
188
+ dim_head = num_head_channels
189
+ if legacy:
190
+ # num_heads = 1
191
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
192
+ if exists(disable_self_attentions):
193
+ disabled_sa = disable_self_attentions[level]
194
+ else:
195
+ disabled_sa = False
196
+
197
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
198
+ layers.append(
199
+ AttentionBlock(
200
+ ch,
201
+ use_checkpoint=use_checkpoint,
202
+ num_heads=num_heads,
203
+ num_head_channels=dim_head,
204
+ use_new_attention_order=use_new_attention_order,
205
+ ) if not use_spatial_transformer else SpatialTransformer(
206
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
207
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
208
+ use_checkpoint=use_checkpoint
209
+ )
210
+ )
211
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
212
+ self.zero_convs.append(self.make_zero_conv(ch))
213
+ self._feature_size += ch
214
+ input_block_chans.append(ch)
215
+ if level != len(channel_mult) - 1:
216
+ out_ch = ch
217
+ self.input_blocks.append(
218
+ TimestepEmbedSequential(
219
+ ResBlock(
220
+ ch,
221
+ time_embed_dim,
222
+ dropout,
223
+ out_channels=out_ch,
224
+ dims=dims,
225
+ use_checkpoint=use_checkpoint,
226
+ use_scale_shift_norm=use_scale_shift_norm,
227
+ down=True,
228
+ )
229
+ if resblock_updown
230
+ else Downsample(
231
+ ch, conv_resample, dims=dims, out_channels=out_ch
232
+ )
233
+ )
234
+ )
235
+ ch = out_ch
236
+ input_block_chans.append(ch)
237
+ self.zero_convs.append(self.make_zero_conv(ch))
238
+ ds *= 2
239
+ self._feature_size += ch
240
+
241
+ if num_head_channels == -1:
242
+ dim_head = ch // num_heads
243
+ else:
244
+ num_heads = ch // num_head_channels
245
+ dim_head = num_head_channels
246
+ if legacy:
247
+ # num_heads = 1
248
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
249
+ self.middle_block = TimestepEmbedSequential(
250
+ ResBlock(
251
+ ch,
252
+ time_embed_dim,
253
+ dropout,
254
+ dims=dims,
255
+ use_checkpoint=use_checkpoint,
256
+ use_scale_shift_norm=use_scale_shift_norm,
257
+ ),
258
+ AttentionBlock(
259
+ ch,
260
+ use_checkpoint=use_checkpoint,
261
+ num_heads=num_heads,
262
+ num_head_channels=dim_head,
263
+ use_new_attention_order=use_new_attention_order,
264
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
265
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
266
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
267
+ use_checkpoint=use_checkpoint
268
+ ),
269
+ ResBlock(
270
+ ch,
271
+ time_embed_dim,
272
+ dropout,
273
+ dims=dims,
274
+ use_checkpoint=use_checkpoint,
275
+ use_scale_shift_norm=use_scale_shift_norm,
276
+ ),
277
+ )
278
+ self.middle_block_out = self.make_zero_conv(ch)
279
+ self._feature_size += ch
280
+
281
+ def make_zero_conv(self, channels):
282
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
283
+
284
+ def forward(self, x, hint, timesteps, context, **kwargs):
285
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
286
+ emb = self.time_embed(t_emb)
287
+
288
+ guided_hint = self.input_hint_block(hint, emb, context)
289
+
290
+ outs = []
291
+
292
+ h = x.type(self.dtype)
293
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
294
+ if guided_hint is not None:
295
+ h = module(h, emb, context)
296
+ h += guided_hint
297
+ guided_hint = None
298
+ else:
299
+ h = module(h, emb, context)
300
+ outs.append(zero_conv(h, emb, context))
301
+
302
+ h = self.middle_block(h, emb, context)
303
+ outs.append(self.middle_block_out(h, emb, context))
304
+
305
+ return outs
306
+
307
+
308
+ class ControlLDM(LatentDiffusion):
309
+
310
+ def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
311
+ super().__init__(*args, **kwargs)
312
+ self.control_model = instantiate_from_config(control_stage_config)
313
+ self.control_key = control_key
314
+ self.only_mid_control = only_mid_control
315
+ self.control_scales = [1.0] * 13
316
+
317
+ @torch.no_grad()
318
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
319
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
320
+ control = batch[self.control_key]
321
+ if bs is not None:
322
+ control = control[:bs]
323
+ control = control.to(self.device)
324
+ control = einops.rearrange(control, 'b h w c -> b c h w')
325
+ control = control.to(memory_format=torch.contiguous_format).float()
326
+ return x, dict(c_crossattn=[c], c_concat=[control])
327
+
328
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
329
+ assert isinstance(cond, dict)
330
+ diffusion_model = self.model.diffusion_model
331
+
332
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
333
+
334
+ if cond['c_concat'] is None:
335
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
336
+ else:
337
+ control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
338
+ control = [c * scale for c, scale in zip(control, self.control_scales)]
339
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
340
+
341
+ return eps
342
+
343
+ @torch.no_grad()
344
+ def get_unconditional_conditioning(self, N):
345
+ return self.get_learned_conditioning([""] * N)
346
+
347
+ @torch.no_grad()
348
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
349
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
350
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
351
+ use_ema_scope=True,
352
+ **kwargs):
353
+ use_ddim = ddim_steps is not None
354
+
355
+ log = dict()
356
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
357
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
358
+ N = min(z.shape[0], N)
359
+ n_row = min(z.shape[0], n_row)
360
+ log["reconstruction"] = self.decode_first_stage(z)
361
+ log["control"] = c_cat * 2.0 - 1.0
362
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
363
+
364
+ if plot_diffusion_rows:
365
+ # get diffusion row
366
+ diffusion_row = list()
367
+ z_start = z[:n_row]
368
+ for t in range(self.num_timesteps):
369
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
370
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
371
+ t = t.to(self.device).long()
372
+ noise = torch.randn_like(z_start)
373
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
374
+ diffusion_row.append(self.decode_first_stage(z_noisy))
375
+
376
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
377
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
378
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
379
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
380
+ log["diffusion_row"] = diffusion_grid
381
+
382
+ if sample:
383
+ # get denoise row
384
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
385
+ batch_size=N, ddim=use_ddim,
386
+ ddim_steps=ddim_steps, eta=ddim_eta)
387
+ x_samples = self.decode_first_stage(samples)
388
+ log["samples"] = x_samples
389
+ if plot_denoise_rows:
390
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
391
+ log["denoise_row"] = denoise_grid
392
+
393
+ if unconditional_guidance_scale > 1.0:
394
+ uc_cross = self.get_unconditional_conditioning(N)
395
+ uc_cat = c_cat # torch.zeros_like(c_cat)
396
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
397
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
398
+ batch_size=N, ddim=use_ddim,
399
+ ddim_steps=ddim_steps, eta=ddim_eta,
400
+ unconditional_guidance_scale=unconditional_guidance_scale,
401
+ unconditional_conditioning=uc_full,
402
+ )
403
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
404
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
405
+
406
+ return log
407
+
408
+ @torch.no_grad()
409
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
410
+ ddim_sampler = DDIMSampler(self)
411
+ b, c, h, w = cond["c_concat"][0].shape
412
+ shape = (self.channels, h // 8, w // 8)
413
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
414
+ return samples, intermediates
415
+
416
+ def configure_optimizers(self):
417
+ lr = self.learning_rate
418
+ params = list(self.control_model.parameters())
419
+ if not self.sd_locked:
420
+ params += list(self.model.diffusion_model.output_blocks.parameters())
421
+ params += list(self.model.diffusion_model.out.parameters())
422
+ opt = torch.optim.AdamW(params, lr=lr)
423
+ return opt
424
+
425
+ def low_vram_shift(self, is_diffusing):
426
+ if is_diffusing:
427
+ self.model = self.model.cuda()
428
+ self.control_model = self.control_model.cuda()
429
+ self.first_stage_model = self.first_stage_model.cpu()
430
+ self.cond_stage_model = self.cond_stage_model.cpu()
431
+ else:
432
+ self.model = self.model.cpu()
433
+ self.control_model = self.control_model.cpu()
434
+ self.first_stage_model = self.first_stage_model.cuda()
435
+ self.cond_stage_model = self.cond_stage_model.cuda()
cldm/ddim_hacked.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+
88
+ elif isinstance(conditioning, list):
89
+ for ctmp in conditioning:
90
+ if ctmp.shape[0] != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+
93
+ else:
94
+ if conditioning.shape[0] != batch_size:
95
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
+
97
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
+ # sampling
99
+ C, H, W = shape
100
+ size = (batch_size, C, H, W)
101
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
+
103
+ samples, intermediates = self.ddim_sampling(conditioning, size,
104
+ callback=callback,
105
+ img_callback=img_callback,
106
+ quantize_denoised=quantize_x0,
107
+ mask=mask, x0=x0,
108
+ ddim_use_original_steps=False,
109
+ noise_dropout=noise_dropout,
110
+ temperature=temperature,
111
+ score_corrector=score_corrector,
112
+ corrector_kwargs=corrector_kwargs,
113
+ x_T=x_T,
114
+ log_every_t=log_every_t,
115
+ unconditional_guidance_scale=unconditional_guidance_scale,
116
+ unconditional_conditioning=unconditional_conditioning,
117
+ dynamic_threshold=dynamic_threshold,
118
+ ucg_schedule=ucg_schedule
119
+ )
120
+ return samples, intermediates
121
+
122
+ @torch.no_grad()
123
+ def ddim_sampling(self, cond, shape,
124
+ x_T=None, ddim_use_original_steps=False,
125
+ callback=None, timesteps=None, quantize_denoised=False,
126
+ mask=None, x0=None, img_callback=None, log_every_t=100,
127
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
+ ucg_schedule=None):
130
+ device = self.model.betas.device
131
+ b = shape[0]
132
+ if x_T is None:
133
+ img = torch.randn(shape, device=device)
134
+ else:
135
+ img = x_T
136
+
137
+ if timesteps is None:
138
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
+ elif timesteps is not None and not ddim_use_original_steps:
140
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
+ timesteps = self.ddim_timesteps[:subset_end]
142
+
143
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
147
+
148
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
+
150
+ for i, step in enumerate(iterator):
151
+ index = total_steps - i - 1
152
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
153
+
154
+ if mask is not None:
155
+ assert x0 is not None
156
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
+ img = img_orig * mask + (1. - mask) * img
158
+
159
+ if ucg_schedule is not None:
160
+ assert len(ucg_schedule) == len(time_range)
161
+ unconditional_guidance_scale = ucg_schedule[i]
162
+
163
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
+ quantize_denoised=quantize_denoised, temperature=temperature,
165
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
166
+ corrector_kwargs=corrector_kwargs,
167
+ unconditional_guidance_scale=unconditional_guidance_scale,
168
+ unconditional_conditioning=unconditional_conditioning,
169
+ dynamic_threshold=dynamic_threshold)
170
+ img, pred_x0 = outs
171
+ if callback: callback(i)
172
+ if img_callback: img_callback(pred_x0, i)
173
+
174
+ if index % log_every_t == 0 or index == total_steps - 1:
175
+ intermediates['x_inter'].append(img)
176
+ intermediates['pred_x0'].append(pred_x0)
177
+
178
+ return img, intermediates
179
+
180
+ @torch.no_grad()
181
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
182
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
183
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
184
+ dynamic_threshold=None):
185
+ b, *_, device = *x.shape, x.device
186
+
187
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188
+ model_output = self.model.apply_model(x, t, c)
189
+ else:
190
+ model_t = self.model.apply_model(x, t, c)
191
+ model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
192
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
193
+
194
+ if self.model.parameterization == "v":
195
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
196
+ else:
197
+ e_t = model_output
198
+
199
+ if score_corrector is not None:
200
+ assert self.model.parameterization == "eps", 'not implemented'
201
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
202
+
203
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
204
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
205
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
206
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
207
+ # select parameters corresponding to the currently considered timestep
208
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
209
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
210
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
211
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
212
+
213
+ # current prediction for x_0
214
+ if self.model.parameterization != "v":
215
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
216
+ else:
217
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
218
+
219
+ if quantize_denoised:
220
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
221
+
222
+ if dynamic_threshold is not None:
223
+ raise NotImplementedError()
224
+
225
+ # direction pointing to x_t
226
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
227
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
228
+ if noise_dropout > 0.:
229
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
230
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
231
+ return x_prev, pred_x0
232
+
233
+ @torch.no_grad()
234
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
235
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
236
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
237
+
238
+ assert t_enc <= num_reference_steps
239
+ num_steps = t_enc
240
+
241
+ if use_original_steps:
242
+ alphas_next = self.alphas_cumprod[:num_steps]
243
+ alphas = self.alphas_cumprod_prev[:num_steps]
244
+ else:
245
+ alphas_next = self.ddim_alphas[:num_steps]
246
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
247
+
248
+ x_next = x0
249
+ intermediates = []
250
+ inter_steps = []
251
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
252
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
253
+ if unconditional_guidance_scale == 1.:
254
+ noise_pred = self.model.apply_model(x_next, t, c)
255
+ else:
256
+ assert unconditional_conditioning is not None
257
+ e_t_uncond, noise_pred = torch.chunk(
258
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
259
+ torch.cat((unconditional_conditioning, c))), 2)
260
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
261
+
262
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
263
+ weighted_noise_pred = alphas_next[i].sqrt() * (
264
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
265
+ x_next = xt_weighted + weighted_noise_pred
266
+ if return_intermediates and i % (
267
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
268
+ intermediates.append(x_next)
269
+ inter_steps.append(i)
270
+ elif return_intermediates and i >= num_steps - 2:
271
+ intermediates.append(x_next)
272
+ inter_steps.append(i)
273
+ if callback: callback(i)
274
+
275
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
276
+ if return_intermediates:
277
+ out.update({'intermediates': intermediates})
278
+ return x_next, out
279
+
280
+ @torch.no_grad()
281
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
282
+ # fast, but does not allow for exact reconstruction
283
+ # t serves as an index to gather the correct alphas
284
+ if use_original_steps:
285
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
286
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
287
+ else:
288
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
289
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
290
+
291
+ if noise is None:
292
+ noise = torch.randn_like(x0)
293
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
294
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
295
+
296
+ @torch.no_grad()
297
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
298
+ use_original_steps=False, callback=None):
299
+
300
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
301
+ timesteps = timesteps[:t_start]
302
+
303
+ time_range = np.flip(timesteps)
304
+ total_steps = timesteps.shape[0]
305
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
306
+
307
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
308
+ x_dec = x_latent
309
+ for i, step in enumerate(iterator):
310
+ index = total_steps - i - 1
311
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
312
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
313
+ unconditional_guidance_scale=unconditional_guidance_scale,
314
+ unconditional_conditioning=unconditional_conditioning)
315
+ if callback: callback(i)
316
+ return x_dec
cldm/hack.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ print('logging improved.')
14
+ return
15
+
16
+
17
+ def enable_sliced_attention():
18
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
19
+ print('Enabled sliced_attention.')
20
+ return
21
+
22
+
23
+ def hack_everything(clip_skip=0):
24
+ disable_verbosity()
25
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
26
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
27
+ print('Enabled clip hacks.')
28
+ return
29
+
30
+
31
+ # Written by Lvmin
32
+ def _hacked_clip_forward(self, text):
33
+ PAD = self.tokenizer.pad_token_id
34
+ EOS = self.tokenizer.eos_token_id
35
+ BOS = self.tokenizer.bos_token_id
36
+
37
+ def tokenize(t):
38
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
39
+
40
+ def transformer_encode(t):
41
+ if self.clip_skip > 1:
42
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
43
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
44
+ else:
45
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
46
+
47
+ def split(x):
48
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
49
+
50
+ def pad(x, p, i):
51
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
52
+
53
+ raw_tokens_list = tokenize(text)
54
+ tokens_list = []
55
+
56
+ for raw_tokens in raw_tokens_list:
57
+ raw_tokens_123 = split(raw_tokens)
58
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
59
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
60
+ tokens_list.append(raw_tokens_123)
61
+
62
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
63
+
64
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
65
+ y = transformer_encode(feed)
66
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
67
+
68
+ return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
73
+ h = self.heads
74
+
75
+ q = self.to_q(x)
76
+ context = default(context, x)
77
+ k = self.to_k(context)
78
+ v = self.to_v(context)
79
+ del context, x
80
+
81
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
82
+
83
+ limit = k.shape[0]
84
+ att_step = 1
85
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
86
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
87
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
88
+
89
+ q_chunks.reverse()
90
+ k_chunks.reverse()
91
+ v_chunks.reverse()
92
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
93
+ del k, q, v
94
+ for i in range(0, limit, att_step):
95
+ q_buffer = q_chunks.pop()
96
+ k_buffer = k_chunks.pop()
97
+ v_buffer = v_chunks.pop()
98
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
99
+
100
+ del k_buffer, q_buffer
101
+ # attention, what we cannot get enough of, by chunks
102
+
103
+ sim_buffer = sim_buffer.softmax(dim=-1)
104
+
105
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
106
+ del v_buffer
107
+ sim[i:i + att_step, :, :] = sim_buffer
108
+
109
+ del sim_buffer
110
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
111
+ return self.to_out(sim)
cldm/logger.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from PIL import Image
7
+ from pytorch_lightning.callbacks import Callback
8
+ from pytorch_lightning.utilities.distributed import rank_zero_only
9
+
10
+
11
+ class ImageLogger(Callback):
12
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
13
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
14
+ log_images_kwargs=None):
15
+ super().__init__()
16
+ self.rescale = rescale
17
+ self.batch_freq = batch_frequency
18
+ self.max_images = max_images
19
+ if not increase_log_steps:
20
+ self.log_steps = [self.batch_freq]
21
+ self.clamp = clamp
22
+ self.disabled = disabled
23
+ self.log_on_batch_idx = log_on_batch_idx
24
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
25
+ self.log_first_step = log_first_step
26
+
27
+ @rank_zero_only
28
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
29
+ root = os.path.join(save_dir, "image_log", split)
30
+ for k in images:
31
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
32
+ if self.rescale:
33
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
34
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
35
+ grid = grid.numpy()
36
+ grid = (grid * 255).astype(np.uint8)
37
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
38
+ path = os.path.join(root, filename)
39
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
40
+ Image.fromarray(grid).save(path)
41
+
42
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
43
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
44
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
45
+ hasattr(pl_module, "log_images") and
46
+ callable(pl_module.log_images) and
47
+ self.max_images > 0):
48
+ logger = type(pl_module.logger)
49
+
50
+ is_train = pl_module.training
51
+ if is_train:
52
+ pl_module.eval()
53
+
54
+ with torch.no_grad():
55
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
56
+
57
+ for k in images:
58
+ N = min(images[k].shape[0], self.max_images)
59
+ images[k] = images[k][:N]
60
+ if isinstance(images[k], torch.Tensor):
61
+ images[k] = images[k].detach().cpu()
62
+ if self.clamp:
63
+ images[k] = torch.clamp(images[k], -1., 1.)
64
+
65
+ self.log_local(pl_module.logger.save_dir, split, images,
66
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
67
+
68
+ if is_train:
69
+ pl_module.train()
70
+
71
+ def check_frequency(self, check_idx):
72
+ return check_idx % self.batch_freq == 0
73
+
74
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
75
+ if not self.disabled:
76
+ self.log_img(pl_module, batch, batch_idx, split="train")
cldm/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path):
25
+ config = OmegaConf.load(config_path)
26
+ model = instantiate_from_config(config.model).cpu()
27
+ print(f'Loaded model config from [{config_path}]')
28
+ return model
config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ save_memory = False
dataset_build.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import json
3
+
4
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
5
+ import torch
6
+ import os
7
+
8
+ processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
9
+ model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16)
10
+
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ model.to(device)
13
+
14
+ def get_blip2_text(image):
15
+ inputs = processor(image, return_tensors="pt").to(device, torch.float16)
16
+ generated_ids = model.generate(**inputs, max_new_tokens=50)
17
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
18
+ return generated_text
19
+
20
+
21
+ data_path = "files"
22
+ save_path = ""
23
+
24
+ image_names = os.listdir(data_path)
25
+ image_names = sorted(image_names)
26
+
27
+ text_data = {}
28
+ f = open("data.txt","w")
29
+ for each in image_names:
30
+ if '.jpg' in each:
31
+ this_data = {}
32
+ this_data['target'] = each
33
+ this_data['source'] = each[:-4]+'.json'
34
+ this_image = Image.open(os.path.join(data_path, each))
35
+ print(each)
36
+ generated_text = get_blip2_text(this_image)
37
+ this_data['prompt'] = generated_text
38
+ print(this_data)
39
+ f.write(str(this_data)+"\n")
40
+ f.close()
41
+
42
+
43
+
44
+
45
+
sam2edit.py → editany.py RENAMED
@@ -2,27 +2,35 @@
2
  import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
- from sam2edit_lora import EditAnythingLoraModel, config_dict
6
- from sam2edit_demo import create_demo_template
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
  def create_demo(process, process_image_click=None):
11
 
12
  examples = None
13
- INFO = f'''
14
  ## EditAnything https://github.com/sail-sg/EditAnything
15
- '''
16
  WARNING_INFO = None
17
 
18
- demo = create_demo_template(process, process_image_click, examples=examples,
19
- INFO=INFO, WARNING_INFO=WARNING_INFO, enable_auto_prompt_default=True)
 
 
 
 
 
 
20
  return demo
21
 
22
 
23
- if __name__ == '__main__':
24
- model = EditAnythingLoraModel(base_model_path="stabilityai/stable-diffusion-2",
25
- controlmodel_name='LAION Pretrained(v0-4)-SD21', extra_inpaint=False,
26
- lora_model_path=None, use_blip=True)
 
 
27
  demo = create_demo(model.process, model.process_image_click)
28
- demo.queue().launch(server_name='0.0.0.0')
 
2
  import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
+ from editany_lora import EditAnythingLoraModel, config_dict
6
+ from editany_demo import create_demo_template
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
  def create_demo(process, process_image_click=None):
11
 
12
  examples = None
13
+ INFO = f"""
14
  ## EditAnything https://github.com/sail-sg/EditAnything
15
+ """
16
  WARNING_INFO = None
17
 
18
+ demo = create_demo_template(
19
+ process,
20
+ process_image_click,
21
+ examples=examples,
22
+ INFO=INFO,
23
+ WARNING_INFO=WARNING_INFO,
24
+ enable_auto_prompt_default=True,
25
+ )
26
  return demo
27
 
28
 
29
+ if __name__ == "__main__":
30
+ model = EditAnythingLoraModel(
31
+ base_model_path="runwayml/stable-diffusion-v1-5",
32
+ controlmodel_name='LAION Pretrained(v0-4)-SD15',
33
+ lora_model_path=None, use_blip=True, extra_inpaint=True,
34
+ )
35
  demo = create_demo(model.process, model.process_image_click)
36
+ demo.queue().launch(server_name="0.0.0.0")
editany_beauty.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import os
3
+ import gradio as gr
4
+ from diffusers.utils import load_image
5
+ from editany_lora import EditAnythingLoraModel, config_dict
6
+ from editany_demo import create_demo_template
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+
9
+
10
+ def create_demo(process, process_image_click=None):
11
+
12
+ examples = [
13
+ [
14
+ "dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
15
+ "(((mole))),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, bad anatomy,(long hair:1.4),DeepNegative,(fat:1.2),facing away, looking away,tilted head, lowres,bad anatomy,bad hands, text, error, missing fingers,extra digit, fewer digits, cropped, worstquality, low quality, normal quality,jpegartifacts,signature, watermark, username,blurry,bad feet,cropped,poorly drawn hands,poorly drawn face,mutation,deformed,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,extra fingers,fewer digits,extra limbs,extra arms,extra legs,malformed limbs,fused fingers,too many fingers,long neck,cross-eyed,mutated hands,polar lowres,bad body,bad proportions,gross proportions,text,error,missing fingers,missing arms,missing legs,extra digit, extra arms, extra leg, extra foot,(freckles),(mole:2)",
16
+ 5,
17
+ ],
18
+ [
19
+ "best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (multicolored|blue|pink hair: 1.2), green eyes, slender, haunting smile, (makeup:0.3), red lips, <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
20
+ "EasyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v",
21
+ 8,
22
+ ],
23
+ [
24
+ "best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (blue|pink hair), green eyes, slender, smile, (makeup:0.4), red lips, (full body, sitting, beach), <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
25
+ "asyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v",
26
+ 7,
27
+ ],
28
+ [
29
+ "mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
30
+ "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v",
31
+ 7,
32
+ ],
33
+ ]
34
+ INFO = f"""
35
+ ## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything
36
+ This model is good at generating beautiful female.
37
+ """
38
+ WARNING_INFO = f"""### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
39
+ We are not responsible for possible risks using this model.
40
+ Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
41
+ """
42
+ demo = create_demo_template(
43
+ process,
44
+ process_image_click,
45
+ examples=examples,
46
+ INFO=INFO,
47
+ WARNING_INFO=WARNING_INFO,
48
+ )
49
+ return demo
50
+
51
+
52
+ if __name__ == "__main__":
53
+ sd_models_path = snapshot_download("shgao/sdmodels")
54
+ lora_model_path = hf_hub_download(
55
+ "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors"
56
+ )
57
+ model = EditAnythingLoraModel(
58
+ base_model_path=os.path.join(
59
+ sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
60
+ lora_model_path=lora_model_path,
61
+ use_blip=True,
62
+ extra_inpaint=True,
63
+ lora_weight=0.5,
64
+ )
65
+ demo = create_demo(model.process, model.process_image_click)
66
+ demo.queue().launch(server_name="0.0.0.0")
editany_demo.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import gradio as gr
3
+
4
+
5
+ def create_demo_template(
6
+ process,
7
+ process_image_click=None,
8
+ examples=None,
9
+ INFO="EditAnything https://github.com/sail-sg/EditAnything",
10
+ WARNING_INFO=None,
11
+ enable_auto_prompt_default=False,
12
+ ):
13
+
14
+ print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
15
+ block = gr.Blocks()
16
+ with block as demo:
17
+ clicked_points = gr.State([])
18
+ origin_image = gr.State(None)
19
+ click_mask = gr.State(None)
20
+ ref_clicked_points = gr.State([])
21
+ ref_origin_image = gr.State(None)
22
+ ref_click_mask = gr.State(None)
23
+ with gr.Row():
24
+ gr.Markdown(INFO)
25
+ with gr.Row().style(equal_height=False):
26
+ with gr.Column():
27
+ with gr.Tab("Click🖱"):
28
+ source_image_click = gr.Image(
29
+ type="pil",
30
+ interactive=True,
31
+ label="Image: Upload an image and click the region you want to edit.",
32
+ )
33
+ with gr.Column():
34
+ with gr.Row():
35
+ point_prompt = gr.Radio(
36
+ choices=["Foreground Point",
37
+ "Background Point"],
38
+ value="Foreground Point",
39
+ label="Point Label",
40
+ interactive=True,
41
+ show_label=False,
42
+ )
43
+ clear_button_click = gr.Button(
44
+ value="Clear Click Points", interactive=True
45
+ )
46
+ clear_button_image = gr.Button(
47
+ value="Clear Image", interactive=True
48
+ )
49
+ with gr.Row():
50
+ run_button_click = gr.Button(
51
+ label="Run EditAnying", interactive=True
52
+ )
53
+ with gr.Tab("Brush🖌️"):
54
+ source_image_brush = gr.Image(
55
+ source="upload",
56
+ label="Image: Upload an image and cover the region you want to edit with sketch",
57
+ type="numpy",
58
+ tool="sketch",
59
+ )
60
+ run_button = gr.Button(
61
+ label="Run EditAnying", interactive=True)
62
+ with gr.Column():
63
+ enable_all_generate = gr.Checkbox(
64
+ label="Auto generation on all region.", value=False
65
+ )
66
+ control_scale = gr.Slider(
67
+ label="Mask Align strength",
68
+ info="Large value -> strict alignment with SAM mask",
69
+ minimum=0,
70
+ maximum=1,
71
+ value=0.5,
72
+ step=0.1,
73
+ )
74
+ with gr.Column():
75
+ enable_auto_prompt = gr.Checkbox(
76
+ label="Auto generate text prompt from input image with BLIP2",
77
+ info="Warning: Enable this may makes your prompt not working.",
78
+ value=enable_auto_prompt_default,
79
+ )
80
+ a_prompt = gr.Textbox(
81
+ label="Positive Prompt",
82
+ info="Text in the expected things of edited region",
83
+ value="best quality, extremely detailed,",
84
+ )
85
+ n_prompt = gr.Textbox(
86
+ label="Negative Prompt",
87
+ value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, NSFW",
88
+ )
89
+ with gr.Row():
90
+ num_samples = gr.Slider(
91
+ label="Images", minimum=1, maximum=12, value=2, step=1
92
+ )
93
+ seed = gr.Slider(
94
+ label="Seed",
95
+ minimum=-1,
96
+ maximum=2147483647,
97
+ step=1,
98
+ randomize=True,
99
+ )
100
+ with gr.Row():
101
+ enable_tile = gr.Checkbox(
102
+ label="Tile refinement for high resolution generation",
103
+ info="Slow inference",
104
+ value=True,
105
+ )
106
+ refine_alignment_ratio = gr.Slider(
107
+ label="Alignment Strength",
108
+ info="Large value -> strict alignment with input image. Small value -> strong global consistency",
109
+ minimum=0.0,
110
+ maximum=1.0,
111
+ value=0.95,
112
+ step=0.05,
113
+ )
114
+
115
+ with gr.Accordion("Reference options", open=False):
116
+ # ref_image = gr.Image(
117
+ # source='upload', label="Upload a reference image", type="pil", value=None)
118
+ ref_image = gr.Image(
119
+ source="upload",
120
+ label="Upload a reference image and cover the region you want to use with sketch",
121
+ type="pil",
122
+ tool="sketch",
123
+ )
124
+ with gr.Column():
125
+ ref_auto_prompt = gr.Checkbox(
126
+ label="Ref. Auto Prompt", value=True
127
+ )
128
+ ref_prompt = gr.Textbox(
129
+ label="Prompt",
130
+ info="Text in the prompt of edited region",
131
+ value="best quality, extremely detailed, ",
132
+ )
133
+ # ref_image = gr.Image(
134
+ # type="pil", interactive=True,
135
+ # label="Image: Upload an image and click the region you want to use as reference.",
136
+ # )
137
+ # with gr.Column():
138
+ # with gr.Row():
139
+ # ref_point_prompt = gr.Radio(
140
+ # choices=["Foreground Point", "Background Point"],
141
+ # value="Foreground Point",
142
+ # label="Point Label",
143
+ # interactive=True, show_label=False)
144
+ # ref_clear_button_click = gr.Button(
145
+ # value="Clear Click Points", interactive=True)
146
+ # ref_clear_button_image = gr.Button(
147
+ # value="Clear Image", interactive=True)
148
+ with gr.Row():
149
+ reference_attn = gr.Checkbox(
150
+ label="reference_attn", value=True)
151
+ attention_auto_machine_weight = gr.Slider(
152
+ label="attention_weight",
153
+ minimum=0,
154
+ maximum=1.0,
155
+ value=0.8,
156
+ step=0.01,
157
+ )
158
+ with gr.Row():
159
+ reference_adain = gr.Checkbox(
160
+ label="reference_adain", value=False
161
+ )
162
+ gn_auto_machine_weight = gr.Slider(
163
+ label="gn_weight",
164
+ minimum=0,
165
+ maximum=1.0,
166
+ value=0.1,
167
+ step=0.01,
168
+ )
169
+ style_fidelity = gr.Slider(
170
+ label="Style fidelity",
171
+ minimum=0,
172
+ maximum=1.0,
173
+ value=0.5,
174
+ step=0.01,
175
+ )
176
+ ref_sam_scale = gr.Slider(
177
+ label="SAM Control Scale",
178
+ minimum=0,
179
+ maximum=1.0,
180
+ value=0.3,
181
+ step=0.1,
182
+ )
183
+ ref_inpaint_scale = gr.Slider(
184
+ label="Inpaint Control Scale",
185
+ minimum=0,
186
+ maximum=1.0,
187
+ value=0.2,
188
+ step=0.1,
189
+ )
190
+ with gr.Row():
191
+ ref_textinv = gr.Checkbox(
192
+ label="Use textual inversion token", value=False
193
+ )
194
+ ref_textinv_path = gr.Textbox(
195
+ label="textual inversion token path",
196
+ info="Text in the inversion token path",
197
+ value=None,
198
+ )
199
+
200
+ with gr.Accordion("Advanced options", open=False):
201
+ mask_image = gr.Image(
202
+ source="upload",
203
+ label="Upload a predefined mask of edit region: Switch to Brush mode when using this!",
204
+ type="numpy",
205
+ value=None,
206
+ )
207
+ image_resolution = gr.Slider(
208
+ label="Image Resolution",
209
+ minimum=256,
210
+ maximum=768,
211
+ value=512,
212
+ step=64,
213
+ )
214
+ refine_image_resolution = gr.Slider(
215
+ label="Image Resolution",
216
+ minimum=256,
217
+ maximum=8192,
218
+ value=1024,
219
+ step=64,
220
+ )
221
+ guess_mode = gr.Checkbox(label="Guess Mode", value=False)
222
+ detect_resolution = gr.Slider(
223
+ label="SAM Resolution",
224
+ minimum=128,
225
+ maximum=2048,
226
+ value=1024,
227
+ step=1,
228
+ )
229
+ ddim_steps = gr.Slider(
230
+ label="Steps", minimum=1, maximum=100, value=30, step=1
231
+ )
232
+ scale = gr.Slider(
233
+ label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
234
+ alpha_weight = gr.Slider(
235
+ label="Alpha weight", info="Alpha mixing with original image", minimum=0,
236
+ maximum=1, value=0.0, step=0.1)
237
+ use_scale_map = gr.Checkbox(
238
+ label='Use scale map', value=False)
239
+ eta = gr.Number(label="eta (DDIM)", value=0.0)
240
+ condition_model = gr.Textbox(
241
+ label="Condition model path",
242
+ info="Text in the Controlnet model path in hugglingface",
243
+ value="EditAnything",
244
+ )
245
+ with gr.Column():
246
+ result_gallery_refine = gr.Gallery(
247
+ label="Output High quality", show_label=True, elem_id="gallery"
248
+ ).style(grid=2, preview=False)
249
+ result_gallery_init = gr.Gallery(
250
+ label="Output Low quality", show_label=True, elem_id="gallery"
251
+ ).style(grid=2, height="auto")
252
+ result_gallery_ref = gr.Gallery(
253
+ label="Output Ref", show_label=False, elem_id="gallery"
254
+ ).style(grid=2, height="auto")
255
+ result_text = gr.Text(label="BLIP2+Human Prompt Text")
256
+
257
+ ips = [
258
+ source_image_brush,
259
+ enable_all_generate,
260
+ mask_image,
261
+ control_scale,
262
+ enable_auto_prompt,
263
+ a_prompt,
264
+ n_prompt,
265
+ num_samples,
266
+ image_resolution,
267
+ detect_resolution,
268
+ ddim_steps,
269
+ guess_mode,
270
+ scale,
271
+ seed,
272
+ eta,
273
+ enable_tile,
274
+ refine_alignment_ratio,
275
+ refine_image_resolution,
276
+ alpha_weight,
277
+ use_scale_map,
278
+ condition_model,
279
+ ref_image,
280
+ attention_auto_machine_weight,
281
+ gn_auto_machine_weight,
282
+ style_fidelity,
283
+ reference_attn,
284
+ reference_adain,
285
+ ref_prompt,
286
+ ref_sam_scale,
287
+ ref_inpaint_scale,
288
+ ref_auto_prompt,
289
+ ref_textinv,
290
+ ref_textinv_path,
291
+ ]
292
+ run_button.click(
293
+ fn=process,
294
+ inputs=ips,
295
+ outputs=[
296
+ result_gallery_refine,
297
+ result_gallery_init,
298
+ result_gallery_ref,
299
+ result_text,
300
+ ],
301
+ )
302
+
303
+ ip_click = [
304
+ origin_image,
305
+ enable_all_generate,
306
+ click_mask,
307
+ control_scale,
308
+ enable_auto_prompt,
309
+ a_prompt,
310
+ n_prompt,
311
+ num_samples,
312
+ image_resolution,
313
+ detect_resolution,
314
+ ddim_steps,
315
+ guess_mode,
316
+ scale,
317
+ seed,
318
+ eta,
319
+ enable_tile,
320
+ refine_alignment_ratio,
321
+ refine_image_resolution,
322
+ alpha_weight,
323
+ use_scale_map,
324
+ condition_model,
325
+ ref_image,
326
+ attention_auto_machine_weight,
327
+ gn_auto_machine_weight,
328
+ style_fidelity,
329
+ reference_attn,
330
+ reference_adain,
331
+ ref_prompt,
332
+ ref_sam_scale,
333
+ ref_inpaint_scale,
334
+ ref_auto_prompt,
335
+ ref_textinv,
336
+ ref_textinv_path,
337
+ ]
338
+
339
+ run_button_click.click(
340
+ fn=process,
341
+ inputs=ip_click,
342
+ outputs=[
343
+ result_gallery_refine,
344
+ result_gallery_init,
345
+ result_gallery_ref,
346
+ result_text,
347
+ ],
348
+ )
349
+
350
+ source_image_click.upload(
351
+ lambda image: image.copy() if image is not None else None,
352
+ inputs=[source_image_click],
353
+ outputs=[origin_image],
354
+ )
355
+ source_image_click.select(
356
+ process_image_click,
357
+ inputs=[origin_image, point_prompt,
358
+ clicked_points, image_resolution],
359
+ outputs=[source_image_click, clicked_points, click_mask],
360
+ show_progress=True,
361
+ queue=True,
362
+ )
363
+ clear_button_click.click(
364
+ fn=lambda original_image: (original_image.copy(), [], None)
365
+ if original_image is not None
366
+ else (None, [], None),
367
+ inputs=[origin_image],
368
+ outputs=[source_image_click, clicked_points, click_mask],
369
+ )
370
+ clear_button_image.click(
371
+ fn=lambda: (None, [], None, None, None),
372
+ inputs=[],
373
+ outputs=[
374
+ source_image_click,
375
+ clicked_points,
376
+ click_mask,
377
+ result_gallery_init,
378
+ result_text,
379
+ ],
380
+ )
381
+
382
+ if examples is not None:
383
+ with gr.Row():
384
+ ex = gr.Examples(
385
+ examples=examples,
386
+ fn=process,
387
+ inputs=[a_prompt, n_prompt, scale],
388
+ outputs=[result_gallery_init],
389
+ cache_examples=False,
390
+ )
391
+ if WARNING_INFO is not None:
392
+ with gr.Row():
393
+ gr.Markdown(WARNING_INFO)
394
+ return demo
sam2edit_handsome.py → editany_handsome.py RENAMED
@@ -2,36 +2,49 @@
2
  import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
- from sam2edit_lora import EditAnythingLoraModel, config_dict
6
- from sam2edit_demo import create_demo_template
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
  def create_demo(process, process_image_click=None):
11
 
12
  examples = [
13
- ["1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>",
14
- "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
15
- ["1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>",
16
- "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", 6],
 
 
 
 
 
 
17
  ]
18
 
19
  print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
20
 
21
- INFO = f'''
22
  ## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything
23
  This model is good at generating handsome male.
24
- '''
25
- WARNING_INFO = f'''### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
26
  We are not responsible for possible risks using this model.
27
  Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
28
- '''
29
- demo = create_demo_template(process, process_image_click, examples=examples, INFO=INFO, WARNING_INFO=WARNING_INFO)
 
 
 
 
 
 
30
  return demo
31
 
32
 
33
- if __name__ == '__main__':
34
- model = EditAnythingLoraModel(base_model_path='Realistic_Vision_V2.0',
35
- lora_model_path=None, use_blip=True)
 
36
  demo = create_demo(model.process, model.process_image_click)
37
- demo.queue().launch(server_name='0.0.0.0')
 
2
  import os
3
  import gradio as gr
4
  from diffusers.utils import load_image
5
+ from editany_lora import EditAnythingLoraModel, config_dict
6
+ from editany_demo import create_demo_template
7
  from huggingface_hub import hf_hub_download, snapshot_download
8
 
9
 
10
  def create_demo(process, process_image_click=None):
11
 
12
  examples = [
13
+ [
14
+ "1man, muscle,full body, vest, short straight hair, glasses, Gym, barbells, dumbbells, treadmills, boxing rings, squat racks, plates, dumbbell racks soft lighting, masterpiece, best quality, 8k uhd, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6>",
15
+ "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
16
+ 6,
17
+ ],
18
+ [
19
+ "1man, 25 years- old, full body, wearing long-sleeve white shirt and tie, muscular rand black suit, soft lighting, masterpiece, best quality, 8k uhd, dslr, film grain, Fujifilm XT3 photorealistic painting art by midjourney and greg rutkowski <lora:asianmale_v10:0.6> <lora:uncutPenisLora_v10:0.6>",
20
+ "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
21
+ 6,
22
+ ],
23
  ]
24
 
25
  print("The GUI is not fully tested yet. Please open an issue if you find bugs.")
26
 
27
+ INFO = f"""
28
  ## Generate Your Handsome powered by EditAnything https://github.com/sail-sg/EditAnything
29
  This model is good at generating handsome male.
30
+ """
31
+ WARNING_INFO = f"""### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
32
  We are not responsible for possible risks using this model.
33
  Base model from https://huggingface.co/SG161222/Realistic_Vision_V2.0 Thanks!
34
+ """
35
+ demo = create_demo_template(
36
+ process,
37
+ process_image_click,
38
+ examples=examples,
39
+ INFO=INFO,
40
+ WARNING_INFO=WARNING_INFO,
41
+ )
42
  return demo
43
 
44
 
45
+ if __name__ == "__main__":
46
+ model = EditAnythingLoraModel(
47
+ base_model_path="Realistic_Vision_V2.0", lora_model_path=None, use_blip=True
48
+ )
49
  demo = create_demo(model.process, model.process_image_click)
50
+ demo.queue().launch(server_name="0.0.0.0")
sam2edit_lora.py → editany_lora.py RENAMED
@@ -14,41 +14,70 @@ import random
14
  import os
15
  import requests
16
  from io import BytesIO
17
- from annotator.util import resize_image, HWC3, resize_points
18
 
19
  import torch
20
  from safetensors.torch import load_file
21
  from collections import defaultdict
22
  from diffusers import StableDiffusionControlNetPipeline
23
  from diffusers import ControlNetModel, UniPCMultistepScheduler
24
- from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
 
 
 
25
  # need the latest transformers
26
  # pip install git+https://github.com/huggingface/transformers.git
27
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
28
  from diffusers import ControlNetModel, DiffusionPipeline
 
29
  import PIL.Image
30
 
31
  # Segment-Anything init.
32
  # pip install git+https://github.com/facebookresearch/segment-anything.git
33
  try:
34
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
 
 
 
 
35
  except ImportError:
36
- print('segment_anything not installed')
37
  result = subprocess.run(
38
- ['pip', 'install', 'git+https://github.com/facebookresearch/segment-anything.git'], check=True)
39
- print(f'Install segment_anything {result}')
40
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
41
- if not os.path.exists('./models/sam_vit_h_4b8939.pth'):
 
 
 
 
 
 
 
 
 
 
42
  result = subprocess.run(
43
- ['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth', '-P', 'models'], check=True)
44
- print(f'Download sam_vit_h_4b8939.pth {result}')
 
 
 
 
 
 
 
45
 
46
  device = "cuda" if torch.cuda.is_available() else "cpu"
47
 
48
- config_dict = OrderedDict([
49
- ('LAION Pretrained(v0-4)-SD15', 'shgao/edit-anything-v0-4-sd15'),
50
- ('LAION Pretrained(v0-4)-SD21', 'shgao/edit-anything-v0-4-sd21'),
51
- ])
 
 
 
 
52
 
53
 
54
  def init_sam_model(sam_generator=None, mask_predictor=None):
@@ -58,8 +87,10 @@ def init_sam_model(sam_generator=None, mask_predictor=None):
58
  model_type = "default"
59
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
60
  sam.to(device=device)
61
- sam_generator = SamAutomaticMaskGenerator(
62
- sam) if sam_generator is None else sam_generator
 
 
63
  mask_predictor = SamPredictor(
64
  sam) if mask_predictor is None else mask_predictor
65
  return sam_generator, mask_predictor
@@ -72,13 +103,14 @@ def init_blip_processor():
72
 
73
  def init_blip_model():
74
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
75
- "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")
 
76
  return blip_model
77
 
78
 
79
  def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
80
  # https://github.com/huggingface/diffusers/issues/2136
81
- """ Get pipeline embeds for prompts bigger than the maxlength of the pipe
82
  :param pipeline:
83
  :param prompt:
84
  :param negative_prompt:
@@ -88,30 +120,40 @@ def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
88
  max_length = pipeline.tokenizer.model_max_length
89
 
90
  # simple way to determine length of tokens
91
- count_prompt = len(re.split(r', ', prompt))
92
- count_negative_prompt = len(re.split(r', ', negative_prompt))
93
 
94
  # create the tensor based on which prompt is longer
95
  if count_prompt >= count_negative_prompt:
96
  input_ids = pipeline.tokenizer(
97
- prompt, return_tensors="pt", truncation=False).input_ids.to(device)
 
98
  shape_max_length = input_ids.shape[-1]
99
- negative_ids = pipeline.tokenizer(negative_prompt, truncation=False, padding="max_length",
100
- max_length=shape_max_length, return_tensors="pt").input_ids.to(device)
 
 
 
 
 
101
  else:
102
  negative_ids = pipeline.tokenizer(
103
- negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device)
 
104
  shape_max_length = negative_ids.shape[-1]
105
- input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length",
106
- max_length=shape_max_length).input_ids.to(device)
 
 
 
 
 
107
 
108
  concat_embeds = []
109
  neg_embeds = []
110
  for i in range(0, shape_max_length, max_length):
111
- concat_embeds.append(pipeline.text_encoder(
112
- input_ids[:, i: i + max_length])[0])
113
- neg_embeds.append(pipeline.text_encoder(
114
- negative_ids[:, i: i + max_length])[0])
115
 
116
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
117
 
@@ -120,8 +162,8 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
120
  LORA_PREFIX_UNET = "lora_unet"
121
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
122
  # load LoRA weight from .safetensors
 
123
  if isinstance(checkpoint_path, str):
124
-
125
  state_dict = load_file(checkpoint_path, device=device)
126
 
127
  updates = defaultdict(dict)
@@ -129,19 +171,17 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
129
  # it is suggested to print out the key, it usually will be something like below
130
  # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
131
 
132
- layer, elem = key.split('.', 1)
133
  updates[layer][elem] = value
134
 
135
  # directly update weight in diffusers model
136
  for layer, elems in updates.items():
137
 
138
  if "text" in layer:
139
- layer_infos = layer.split(
140
- LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
141
  curr_layer = pipeline.text_encoder
142
  else:
143
- layer_infos = layer.split(
144
- LORA_PREFIX_UNET + "_")[-1].split("_")
145
  curr_layer = pipeline.unet
146
 
147
  # find the target layer
@@ -160,9 +200,9 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
160
  temp_name = layer_infos.pop(0)
161
 
162
  # get elements for this layer
163
- weight_up = elems['lora_up.weight'].to(dtype)
164
- weight_down = elems['lora_down.weight'].to(dtype)
165
- alpha = elems['alpha']
166
  if alpha:
167
  alpha = alpha.item() / weight_up.shape[1]
168
  else:
@@ -170,11 +210,20 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
170
 
171
  # update weight
172
  if len(weight_up.shape) == 4:
173
- curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
174
- 3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
 
 
 
 
 
 
 
 
175
  else:
176
- curr_layer.weight.data += multiplier * \
177
- alpha * torch.mm(weight_up, weight_down)
 
178
  else:
179
  for ckptpath in checkpoint_path:
180
  state_dict = load_file(ckptpath, device=device)
@@ -184,18 +233,18 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
184
  # it is suggested to print out the key, it usually will be something like below
185
  # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
186
 
187
- layer, elem = key.split('.', 1)
188
  updates[layer][elem] = value
189
 
190
  # directly update weight in diffusers model
191
  for layer, elems in updates.items():
192
  if "text" in layer:
193
- layer_infos = layer.split(
194
- LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
 
195
  curr_layer = pipeline.text_encoder
196
  else:
197
- layer_infos = layer.split(
198
- LORA_PREFIX_UNET + "_")[-1].split("_")
199
  curr_layer = pipeline.unet
200
 
201
  # find the target layer
@@ -214,9 +263,9 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
214
  temp_name = layer_infos.pop(0)
215
 
216
  # get elements for this layer
217
- weight_up = elems['lora_up.weight'].to(dtype)
218
- weight_down = elems['lora_down.weight'].to(dtype)
219
- alpha = elems['alpha']
220
  if alpha:
221
  alpha = alpha.item() / weight_up.shape[1]
222
  else:
@@ -224,52 +273,75 @@ def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
224
 
225
  # update weight
226
  if len(weight_up.shape) == 4:
227
- curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(
228
- 3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
 
 
 
 
 
 
 
 
229
  else:
230
- curr_layer.weight.data += multiplier * \
231
- alpha * torch.mm(weight_up, weight_down)
 
232
  return pipeline
233
 
234
 
235
  def make_inpaint_condition(image, image_mask):
236
- # image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
237
  image = image / 255.0
238
- print("img", image.max(), image.min(), image_mask.max(), image_mask.min())
239
- # image_mask = np.array(image_mask.convert("L"))
240
- assert image.shape[0:1] == image_mask.shape[0:
241
- 1], "image and image_mask must have the same image size"
242
  image[image_mask > 128] = -1.0 # set as masked pixel
243
  image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
244
  image = torch.from_numpy(image)
245
  return image
246
 
247
 
248
- def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, generation_only=False, extra_inpaint=True, lora_weight=1.0):
 
 
 
 
 
 
 
249
  controlnet = []
250
- controlnet.append(ControlNetModel.from_pretrained(
251
- controlnet_path, torch_dtype=torch.float16)) # sam control
 
 
252
  if (not generation_only) and extra_inpaint: # inpainting control
253
  print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
254
  controlnet.append(
255
  ControlNetModel.from_pretrained(
256
- 'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16) # inpainting controlnet
 
257
  )
258
 
259
  if generation_only and extra_inpaint:
260
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
261
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
 
 
 
262
  )
263
  else:
264
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
265
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
 
 
 
266
  )
267
  if lora_model_path is not None:
268
  pipe = load_lora_weights(
269
- pipe, [lora_model_path], lora_weight, 'cpu', torch.float32)
 
270
  # speed up diffusion process with faster scheduler and memory optimization
271
- pipe.scheduler = UniPCMultistepScheduler.from_config(
272
- pipe.scheduler.config)
273
  # remove following line if xformers is not installed
274
  pipe.enable_xformers_memory_efficient_attention()
275
 
@@ -278,23 +350,33 @@ def obtain_generation_model(base_model_path, lora_model_path, controlnet_path, g
278
 
279
 
280
  def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
281
- controlnet = ControlNetModel.from_pretrained(
282
- 'lllyasviel/control_v11f1e_sd15_tile', torch_dtype=torch.float16) # tile controlnet
283
- if base_model_path == 'runwayml/stable-diffusion-v1-5' or base_model_path == 'stabilityai/stable-diffusion-2-inpainting':
 
 
 
 
284
  print("base_model_path", base_model_path)
285
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
286
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
 
 
 
287
  )
288
  else:
289
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
290
- base_model_path, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
 
 
 
291
  )
292
  if lora_model_path is not None:
293
  pipe = load_lora_weights(
294
- pipe, [lora_model_path], lora_weight, 'cpu', torch.float32)
 
295
  # speed up diffusion process with faster scheduler and memory optimization
296
- pipe.scheduler = UniPCMultistepScheduler.from_config(
297
- pipe.scheduler.config)
298
  # remove following line if xformers is not installed
299
  pipe.enable_xformers_memory_efficient_attention()
300
 
@@ -305,20 +387,20 @@ def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
305
  def show_anns(anns):
306
  if len(anns) == 0:
307
  return
308
- sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
309
  full_img = None
310
 
311
  # for ann in sorted_anns:
312
  for i in range(len(sorted_anns)):
313
  ann = anns[i]
314
- m = ann['segmentation']
315
  if full_img is None:
316
  full_img = np.zeros((m.shape[0], m.shape[1], 3))
317
  map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
318
  map[m != 0] = i + 1
319
  color_mask = np.random.random((1, 3)).tolist()[0]
320
  full_img[m != 0] = color_mask
321
- full_img = full_img*255
322
  # anno encoding from https://github.com/LUSSeg/ImageNet-S
323
  res = np.zeros((map.shape[0], map.shape[1], 3))
324
  res[:, :, 0] = map % 256
@@ -329,19 +411,22 @@ def show_anns(anns):
329
 
330
 
331
  class EditAnythingLoraModel:
332
- def __init__(self,
333
- base_model_path='../chilloutmix_NiPrunedFp32Fix',
334
- lora_model_path='../40806/mix4', use_blip=True,
335
- blip_processor=None,
336
- blip_model=None,
337
- sam_generator=None,
338
- controlmodel_name='LAION Pretrained(v0-4)-SD15',
339
- # used when the base model is not an inpainting model.
340
- extra_inpaint=True,
341
- tile_model=None,
342
- lora_weight=1.0,
343
- mask_predictor=None
344
- ):
 
 
 
345
  self.device = device
346
  self.use_blip = use_blip
347
 
@@ -351,12 +436,20 @@ class EditAnythingLoraModel:
351
  self.lora_model_path = lora_model_path
352
  self.defalut_enable_all_generate = False
353
  self.extra_inpaint = extra_inpaint
 
354
  self.pipe = obtain_generation_model(
355
- base_model_path, lora_model_path, self.default_controlnet_path, generation_only=False, extra_inpaint=extra_inpaint, lora_weight=lora_weight)
356
-
 
 
 
 
 
 
357
  # Segment-Anything init.
358
  self.sam_generator, self.mask_predictor = init_sam_model(
359
- sam_generator, mask_predictor)
 
360
  # BLIP2 init.
361
  if use_blip:
362
  if blip_processor is not None:
@@ -374,14 +467,17 @@ class EditAnythingLoraModel:
374
  self.tile_pipe = tile_model
375
  else:
376
  self.tile_pipe = obtain_tile_model(
377
- base_model_path, lora_model_path, lora_weight=lora_weight)
 
378
 
379
  def get_blip2_text(self, image):
380
  inputs = self.blip_processor(image, return_tensors="pt").to(
381
- self.device, torch.float16)
 
382
  generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
383
  generated_text = self.blip_processor.batch_decode(
384
- generated_ids, skip_special_tokens=True)[0].strip()
 
385
  return generated_text
386
 
387
  def get_sam_control(self, image):
@@ -408,11 +504,14 @@ class EditAnythingLoraModel:
408
  return masks
409
 
410
  @torch.inference_mode()
411
- def process_image_click(self, original_image: gr.Image,
412
- point_prompt: gr.Radio,
413
- clicked_points: gr.State,
414
- image_resolution,
415
- evt: gr.SelectData):
 
 
 
416
  # Get the clicked coordinates
417
  clicked_coords = evt.index
418
  x, y = clicked_coords
@@ -426,17 +525,16 @@ class EditAnythingLoraModel:
426
  img = resize_image(input_image, image_resolution)
427
 
428
  # Update the clicked_points
429
- resized_points = resize_points(clicked_points,
430
- input_image.shape,
431
- image_resolution)
432
  mask_click_np = self.get_click_mask(img, resized_points)
433
 
434
  # Convert mask_click_np to HWC format
435
  mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
436
 
437
  mask_image = HWC3(mask_click_np.astype(np.uint8))
438
- mask_image = cv2.resize(
439
- mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
440
  # mask_image = Image.fromarray(mask_image_tmp)
441
 
442
  # Draw circles for all clicked points
@@ -454,50 +552,161 @@ class EditAnythingLoraModel:
454
 
455
  # Combine the edited_image and the mask_image using cv2.addWeighted()
456
  overlay_image = cv2.addWeighted(
457
- edited_image, opacity_edited,
458
- (mask_image * np.array([0/255, 255/255, 0/255])).astype(np.uint8),
459
- opacity_mask, 0
 
 
 
460
  )
461
 
462
- return Image.fromarray(overlay_image), clicked_points, Image.fromarray(mask_image)
 
 
 
 
463
 
464
  @torch.inference_mode()
465
- def process(self, source_image, enable_all_generate, mask_image,
466
- control_scale,
467
- enable_auto_prompt, a_prompt, n_prompt,
468
- num_samples, image_resolution, detect_resolution,
469
- ddim_steps, guess_mode, scale, seed, eta,
470
- enable_tile=True, refine_alignment_ratio=None, refine_image_resolution=None, condition_model=None):
471
-
472
- if condition_model is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  this_controlnet_path = self.default_controlnet_path
474
  else:
475
- this_controlnet_path = config_dict[condition_model]
476
- input_image = source_image["image"] if isinstance(
477
- source_image, dict) else np.array(source_image, dtype=np.uint8)
 
 
 
478
  if mask_image is None:
479
  if enable_all_generate != self.defalut_enable_all_generate:
480
  self.pipe = obtain_generation_model(
481
- self.base_model_path, self.lora_model_path, this_controlnet_path, enable_all_generate, self.extra_inpaint)
482
-
 
 
 
 
483
  self.defalut_enable_all_generate = enable_all_generate
484
  if enable_all_generate:
485
- print("source_image",
486
- source_image["mask"].shape, input_image.shape,)
487
- mask_image = np.ones(
488
- (input_image.shape[0], input_image.shape[1], 3))*255
 
 
 
 
 
489
  else:
490
  mask_image = source_image["mask"]
491
  else:
492
  mask_image = np.array(mask_image, dtype=np.uint8)
493
  if self.default_controlnet_path != this_controlnet_path:
494
- print("To Use:", this_controlnet_path,
495
- "Current:", self.default_controlnet_path)
 
 
 
 
496
  print("Change condition model to:", this_controlnet_path)
497
  self.pipe = obtain_generation_model(
498
- self.base_model_path, self.lora_model_path, this_controlnet_path, enable_all_generate, self.extra_inpaint)
 
 
 
 
 
499
  self.default_controlnet_path = this_controlnet_path
500
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  with torch.no_grad():
503
  if self.use_blip and enable_auto_prompt:
@@ -505,7 +714,7 @@ class EditAnythingLoraModel:
505
  blip2_prompt = self.get_blip2_text(input_image)
506
  print("Generated text:", blip2_prompt)
507
  if len(a_prompt) > 0:
508
- a_prompt = blip2_prompt + ',' + a_prompt
509
  else:
510
  a_prompt = blip2_prompt
511
 
@@ -517,20 +726,22 @@ class EditAnythingLoraModel:
517
  print("Generating SAM seg:")
518
  # the default SAM model is trained with 1024 size.
519
  full_segmask, detected_map = self.get_sam_control(
520
- resize_image(input_image, detect_resolution))
 
521
 
522
  detected_map = HWC3(detected_map.astype(np.uint8))
523
  detected_map = cv2.resize(
524
- detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
525
 
526
- control = torch.from_numpy(
527
- detected_map.copy()).float().cuda()
528
  control = torch.stack([control for _ in range(num_samples)], dim=0)
529
- control = einops.rearrange(control, 'b h w c -> b c h w').clone()
530
 
531
  mask_imag_ori = HWC3(mask_image.astype(np.uint8))
532
  mask_image_tmp = cv2.resize(
533
- mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR)
 
534
  mask_image = Image.fromarray(mask_image_tmp)
535
 
536
  if seed == -1:
@@ -540,15 +751,22 @@ class EditAnythingLoraModel:
540
  postive_prompt = a_prompt
541
  negative_prompt = n_prompt
542
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
543
- self.pipe, postive_prompt, negative_prompt, "cuda")
 
544
  prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
545
  negative_prompt_embeds = torch.cat(
546
- [negative_prompt_embeds] * num_samples, dim=0)
 
 
547
  if enable_all_generate and self.extra_inpaint:
548
  self.pipe.safety_checker = lambda images, clip_input: (
549
  images, False)
 
 
 
550
  x_samples = self.pipe(
551
- prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
 
552
  num_images_per_prompt=num_samples,
553
  num_inference_steps=ddim_steps,
554
  generator=generator,
@@ -557,58 +775,130 @@ class EditAnythingLoraModel:
557
  image=[control.type(torch.float16)],
558
  controlnet_conditioning_scale=[float(control_scale)],
559
  guidance_scale=scale,
 
560
  ).images
561
  else:
562
  multi_condition_image = []
563
  multi_condition_scale = []
564
  multi_condition_image.append(control.type(torch.float16))
565
  multi_condition_scale.append(float(control_scale))
 
 
 
566
  if self.extra_inpaint:
567
  inpaint_image = make_inpaint_condition(img, mask_image_tmp)
568
- print(inpaint_image.shape)
569
  multi_condition_image.append(
570
  inpaint_image.type(torch.float16))
571
  multi_condition_scale.append(1.0)
572
- x_samples = self.pipe(
573
- image=img,
574
- mask_image=mask_image,
575
- prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
576
- num_images_per_prompt=num_samples,
577
- num_inference_steps=ddim_steps,
578
- generator=generator,
579
- controlnet_conditioning_image=multi_condition_image,
580
- height=H,
581
- width=W,
582
- controlnet_conditioning_scale=multi_condition_scale,
583
- guidance_scale=scale,
584
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
  results = [x_samples[i] for i in range(num_samples)]
586
 
587
  results_tile = []
588
  if enable_tile:
589
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
590
- self.tile_pipe, postive_prompt, negative_prompt, "cuda")
 
591
  for i in range(num_samples):
592
- img_tile = PIL.Image.fromarray(resize_image(
593
- np.array(x_samples[i]), refine_image_resolution))
 
 
594
  if i == 0:
595
  mask_image_tile = cv2.resize(
596
- mask_imag_ori, (img_tile.size[0], img_tile.size[1]), interpolation=cv2.INTER_LINEAR)
 
 
 
597
  mask_image_tile = Image.fromarray(mask_image_tile)
598
- x_samples_tile = self.tile_pipe(
599
- image=img_tile,
600
- mask_image=mask_image_tile,
601
- prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
602
- num_images_per_prompt=1,
603
- num_inference_steps=ddim_steps,
604
- generator=generator,
605
- controlnet_conditioning_image=img_tile,
606
- height=img_tile.size[1],
607
- width=img_tile.size[0],
608
- controlnet_conditioning_scale=1.0,
609
- alignment_ratio=refine_alignment_ratio,
610
- guidance_scale=scale,
611
- ).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  results_tile += x_samples_tile
613
 
614
  return results_tile, results, [full_segmask, mask_image], postive_prompt
 
14
  import os
15
  import requests
16
  from io import BytesIO
17
+ from annotator.util import resize_image, HWC3, resize_points, get_bounding_box
18
 
19
  import torch
20
  from safetensors.torch import load_file
21
  from collections import defaultdict
22
  from diffusers import StableDiffusionControlNetPipeline
23
  from diffusers import ControlNetModel, UniPCMultistepScheduler
24
+
25
+ from utils.stable_diffusion_controlnet import ControlNetModel2
26
+ from utils.stable_diffusion_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline, \
27
+ StableDiffusionControlNetInpaintMixingPipeline, prepare_mask_image
28
  # need the latest transformers
29
  # pip install git+https://github.com/huggingface/transformers.git
30
  from transformers import AutoProcessor, Blip2ForConditionalGeneration
31
  from diffusers import ControlNetModel, DiffusionPipeline
32
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
  import PIL.Image
34
 
35
  # Segment-Anything init.
36
  # pip install git+https://github.com/facebookresearch/segment-anything.git
37
  try:
38
+ from segment_anything import (
39
+ sam_model_registry,
40
+ SamAutomaticMaskGenerator,
41
+ SamPredictor,
42
+ )
43
  except ImportError:
44
+ print("segment_anything not installed")
45
  result = subprocess.run(
46
+ [
47
+ "pip",
48
+ "install",
49
+ "git+https://github.com/facebookresearch/segment-anything.git",
50
+ ],
51
+ check=True,
52
+ )
53
+ print(f"Install segment_anything {result}")
54
+ from segment_anything import (
55
+ sam_model_registry,
56
+ SamAutomaticMaskGenerator,
57
+ SamPredictor,
58
+ )
59
+ if not os.path.exists("./models/sam_vit_h_4b8939.pth"):
60
  result = subprocess.run(
61
+ [
62
+ "wget",
63
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
64
+ "-P",
65
+ "models",
66
+ ],
67
+ check=True,
68
+ )
69
+ print(f"Download sam_vit_h_4b8939.pth {result}")
70
 
71
  device = "cuda" if torch.cuda.is_available() else "cpu"
72
 
73
+ config_dict = OrderedDict(
74
+ [
75
+ ("LAION Pretrained(v0-4)-SD15", "shgao/edit-anything-v0-4-sd15"),
76
+ ("LAION Pretrained(v0-4)-SD21", "shgao/edit-anything-v0-4-sd21"),
77
+ ("LAION Pretrained(v0-3)-SD21", "shgao/edit-anything-v0-3"),
78
+ ("SAM Pretrained(v0-1)-SD21", "shgao/edit-anything-v0-1-1"),
79
+ ]
80
+ )
81
 
82
 
83
  def init_sam_model(sam_generator=None, mask_predictor=None):
 
87
  model_type = "default"
88
  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
89
  sam.to(device=device)
90
+ sam_generator = (
91
+ SamAutomaticMaskGenerator(
92
+ sam) if sam_generator is None else sam_generator
93
+ )
94
  mask_predictor = SamPredictor(
95
  sam) if mask_predictor is None else mask_predictor
96
  return sam_generator, mask_predictor
 
103
 
104
  def init_blip_model():
105
  blip_model = Blip2ForConditionalGeneration.from_pretrained(
106
+ "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto"
107
+ )
108
  return blip_model
109
 
110
 
111
  def get_pipeline_embeds(pipeline, prompt, negative_prompt, device):
112
  # https://github.com/huggingface/diffusers/issues/2136
113
+ """Get pipeline embeds for prompts bigger than the maxlength of the pipe
114
  :param pipeline:
115
  :param prompt:
116
  :param negative_prompt:
 
120
  max_length = pipeline.tokenizer.model_max_length
121
 
122
  # simple way to determine length of tokens
123
+ count_prompt = len(re.split(r", ", prompt))
124
+ count_negative_prompt = len(re.split(r", ", negative_prompt))
125
 
126
  # create the tensor based on which prompt is longer
127
  if count_prompt >= count_negative_prompt:
128
  input_ids = pipeline.tokenizer(
129
+ prompt, return_tensors="pt", truncation=False
130
+ ).input_ids.to(device)
131
  shape_max_length = input_ids.shape[-1]
132
+ negative_ids = pipeline.tokenizer(
133
+ negative_prompt,
134
+ truncation=False,
135
+ padding="max_length",
136
+ max_length=shape_max_length,
137
+ return_tensors="pt",
138
+ ).input_ids.to(device)
139
  else:
140
  negative_ids = pipeline.tokenizer(
141
+ negative_prompt, return_tensors="pt", truncation=False
142
+ ).input_ids.to(device)
143
  shape_max_length = negative_ids.shape[-1]
144
+ input_ids = pipeline.tokenizer(
145
+ prompt,
146
+ return_tensors="pt",
147
+ truncation=False,
148
+ padding="max_length",
149
+ max_length=shape_max_length,
150
+ ).input_ids.to(device)
151
 
152
  concat_embeds = []
153
  neg_embeds = []
154
  for i in range(0, shape_max_length, max_length):
155
+ concat_embeds.append(pipeline.text_encoder(input_ids[:, i : i + max_length])[0])
156
+ neg_embeds.append(pipeline.text_encoder(negative_ids[:, i : i + max_length])[0])
 
 
157
 
158
  return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)
159
 
 
162
  LORA_PREFIX_UNET = "lora_unet"
163
  LORA_PREFIX_TEXT_ENCODER = "lora_te"
164
  # load LoRA weight from .safetensors
165
+ print('device: {}'.format(device))
166
  if isinstance(checkpoint_path, str):
 
167
  state_dict = load_file(checkpoint_path, device=device)
168
 
169
  updates = defaultdict(dict)
 
171
  # it is suggested to print out the key, it usually will be something like below
172
  # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
173
 
174
+ layer, elem = key.split(".", 1)
175
  updates[layer][elem] = value
176
 
177
  # directly update weight in diffusers model
178
  for layer, elems in updates.items():
179
 
180
  if "text" in layer:
181
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
 
182
  curr_layer = pipeline.text_encoder
183
  else:
184
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
 
185
  curr_layer = pipeline.unet
186
 
187
  # find the target layer
 
200
  temp_name = layer_infos.pop(0)
201
 
202
  # get elements for this layer
203
+ weight_up = elems["lora_up.weight"].to(dtype)
204
+ weight_down = elems["lora_down.weight"].to(dtype)
205
+ alpha = elems["alpha"]
206
  if alpha:
207
  alpha = alpha.item() / weight_up.shape[1]
208
  else:
 
210
 
211
  # update weight
212
  if len(weight_up.shape) == 4:
213
+ curr_layer.weight.data += (
214
+ multiplier
215
+ * alpha
216
+ * torch.mm(
217
+ weight_up.squeeze(3).squeeze(2),
218
+ weight_down.squeeze(3).squeeze(2),
219
+ )
220
+ .unsqueeze(2)
221
+ .unsqueeze(3)
222
+ )
223
  else:
224
+ curr_layer.weight.data += (
225
+ multiplier * alpha * torch.mm(weight_up, weight_down)
226
+ )
227
  else:
228
  for ckptpath in checkpoint_path:
229
  state_dict = load_file(ckptpath, device=device)
 
233
  # it is suggested to print out the key, it usually will be something like below
234
  # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
235
 
236
+ layer, elem = key.split(".", 1)
237
  updates[layer][elem] = value
238
 
239
  # directly update weight in diffusers model
240
  for layer, elems in updates.items():
241
  if "text" in layer:
242
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split(
243
+ "_"
244
+ )
245
  curr_layer = pipeline.text_encoder
246
  else:
247
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
 
248
  curr_layer = pipeline.unet
249
 
250
  # find the target layer
 
263
  temp_name = layer_infos.pop(0)
264
 
265
  # get elements for this layer
266
+ weight_up = elems["lora_up.weight"].to(dtype)
267
+ weight_down = elems["lora_down.weight"].to(dtype)
268
+ alpha = elems["alpha"]
269
  if alpha:
270
  alpha = alpha.item() / weight_up.shape[1]
271
  else:
 
273
 
274
  # update weight
275
  if len(weight_up.shape) == 4:
276
+ curr_layer.weight.data += (
277
+ multiplier
278
+ * alpha
279
+ * torch.mm(
280
+ weight_up.squeeze(3).squeeze(2),
281
+ weight_down.squeeze(3).squeeze(2),
282
+ )
283
+ .unsqueeze(2)
284
+ .unsqueeze(3)
285
+ )
286
  else:
287
+ curr_layer.weight.data += (
288
+ multiplier * alpha * torch.mm(weight_up, weight_down)
289
+ )
290
  return pipeline
291
 
292
 
293
  def make_inpaint_condition(image, image_mask):
 
294
  image = image / 255.0
295
+ assert (
296
+ image.shape[0:1] == image_mask.shape[0:1]
297
+ ), "image and image_mask must have the same image size"
 
298
  image[image_mask > 128] = -1.0 # set as masked pixel
299
  image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
300
  image = torch.from_numpy(image)
301
  return image
302
 
303
 
304
+ def obtain_generation_model(
305
+ base_model_path,
306
+ lora_model_path,
307
+ controlnet_path,
308
+ generation_only=False,
309
+ extra_inpaint=True,
310
+ lora_weight=1.0,
311
+ ):
312
  controlnet = []
313
+ controlnet.append(
314
+ ControlNetModel2.from_pretrained(
315
+ controlnet_path, torch_dtype=torch.float16)
316
+ ) # sam control
317
  if (not generation_only) and extra_inpaint: # inpainting control
318
  print("Warning: ControlNet based inpainting model only support SD1.5 for now.")
319
  controlnet.append(
320
  ControlNetModel.from_pretrained(
321
+ "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
322
+ ) # inpainting controlnet
323
  )
324
 
325
  if generation_only and extra_inpaint:
326
  pipe = StableDiffusionControlNetPipeline.from_pretrained(
327
+ base_model_path,
328
+ controlnet=controlnet,
329
+ torch_dtype=torch.float16,
330
+ safety_checker=None,
331
  )
332
  else:
333
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
334
+ base_model_path,
335
+ controlnet=controlnet,
336
+ torch_dtype=torch.float16,
337
+ safety_checker=None,
338
  )
339
  if lora_model_path is not None:
340
  pipe = load_lora_weights(
341
+ pipe, [lora_model_path], lora_weight, "cpu", torch.float32
342
+ )
343
  # speed up diffusion process with faster scheduler and memory optimization
344
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
345
  # remove following line if xformers is not installed
346
  pipe.enable_xformers_memory_efficient_attention()
347
 
 
350
 
351
 
352
  def obtain_tile_model(base_model_path, lora_model_path, lora_weight=1.0):
353
+ controlnet = ControlNetModel2.from_pretrained(
354
+ "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
355
+ ) # tile controlnet
356
+ if (
357
+ base_model_path == "runwayml/stable-diffusion-v1-5"
358
+ or base_model_path == "stabilityai/stable-diffusion-2-inpainting"
359
+ ):
360
  print("base_model_path", base_model_path)
361
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
362
+ "runwayml/stable-diffusion-v1-5",
363
+ controlnet=controlnet,
364
+ torch_dtype=torch.float16,
365
+ safety_checker=None,
366
  )
367
  else:
368
  pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
369
+ base_model_path,
370
+ controlnet=controlnet,
371
+ torch_dtype=torch.float16,
372
+ safety_checker=None,
373
  )
374
  if lora_model_path is not None:
375
  pipe = load_lora_weights(
376
+ pipe, [lora_model_path], lora_weight, "cpu", torch.float32
377
+ )
378
  # speed up diffusion process with faster scheduler and memory optimization
379
+ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
380
  # remove following line if xformers is not installed
381
  pipe.enable_xformers_memory_efficient_attention()
382
 
 
387
  def show_anns(anns):
388
  if len(anns) == 0:
389
  return
390
+ sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
391
  full_img = None
392
 
393
  # for ann in sorted_anns:
394
  for i in range(len(sorted_anns)):
395
  ann = anns[i]
396
+ m = ann["segmentation"]
397
  if full_img is None:
398
  full_img = np.zeros((m.shape[0], m.shape[1], 3))
399
  map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
400
  map[m != 0] = i + 1
401
  color_mask = np.random.random((1, 3)).tolist()[0]
402
  full_img[m != 0] = color_mask
403
+ full_img = full_img * 255
404
  # anno encoding from https://github.com/LUSSeg/ImageNet-S
405
  res = np.zeros((map.shape[0], map.shape[1], 3))
406
  res[:, :, 0] = map % 256
 
411
 
412
 
413
  class EditAnythingLoraModel:
414
+ def __init__(
415
+ self,
416
+ base_model_path="../chilloutmix_NiPrunedFp32Fix",
417
+ lora_model_path="../40806/mix4",
418
+ use_blip=True,
419
+ blip_processor=None,
420
+ blip_model=None,
421
+ sam_generator=None,
422
+ controlmodel_name="LAION Pretrained(v0-4)-SD15",
423
+ # used when the base model is not an inpainting model.
424
+ extra_inpaint=True,
425
+ tile_model=None,
426
+ lora_weight=1.0,
427
+ alpha_mixing=None,
428
+ mask_predictor=None,
429
+ ):
430
  self.device = device
431
  self.use_blip = use_blip
432
 
 
436
  self.lora_model_path = lora_model_path
437
  self.defalut_enable_all_generate = False
438
  self.extra_inpaint = extra_inpaint
439
+ self.last_ref_infer = False
440
  self.pipe = obtain_generation_model(
441
+ base_model_path,
442
+ lora_model_path,
443
+ self.default_controlnet_path,
444
+ generation_only=False,
445
+ extra_inpaint=extra_inpaint,
446
+ lora_weight=lora_weight,
447
+ )
448
+ # self.pipe.load_textual_inversion("textual_inversion_cat/learned_embeds.bin")
449
  # Segment-Anything init.
450
  self.sam_generator, self.mask_predictor = init_sam_model(
451
+ sam_generator, mask_predictor
452
+ )
453
  # BLIP2 init.
454
  if use_blip:
455
  if blip_processor is not None:
 
467
  self.tile_pipe = tile_model
468
  else:
469
  self.tile_pipe = obtain_tile_model(
470
+ base_model_path, lora_model_path, lora_weight=lora_weight
471
+ )
472
 
473
  def get_blip2_text(self, image):
474
  inputs = self.blip_processor(image, return_tensors="pt").to(
475
+ self.device, torch.float16
476
+ )
477
  generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
478
  generated_text = self.blip_processor.batch_decode(
479
+ generated_ids, skip_special_tokens=True
480
+ )[0].strip()
481
  return generated_text
482
 
483
  def get_sam_control(self, image):
 
504
  return masks
505
 
506
  @torch.inference_mode()
507
+ def process_image_click(
508
+ self,
509
+ original_image: gr.Image,
510
+ point_prompt: gr.Radio,
511
+ clicked_points: gr.State,
512
+ image_resolution,
513
+ evt: gr.SelectData,
514
+ ):
515
  # Get the clicked coordinates
516
  clicked_coords = evt.index
517
  x, y = clicked_coords
 
525
  img = resize_image(input_image, image_resolution)
526
 
527
  # Update the clicked_points
528
+ resized_points = resize_points(
529
+ clicked_points, input_image.shape, image_resolution
530
+ )
531
  mask_click_np = self.get_click_mask(img, resized_points)
532
 
533
  # Convert mask_click_np to HWC format
534
  mask_click_np = np.transpose(mask_click_np, (1, 2, 0)) * 255.0
535
 
536
  mask_image = HWC3(mask_click_np.astype(np.uint8))
537
+ mask_image = cv2.resize(mask_image, (W, H), interpolation=cv2.INTER_LINEAR)
 
538
  # mask_image = Image.fromarray(mask_image_tmp)
539
 
540
  # Draw circles for all clicked points
 
552
 
553
  # Combine the edited_image and the mask_image using cv2.addWeighted()
554
  overlay_image = cv2.addWeighted(
555
+ edited_image,
556
+ opacity_edited,
557
+ (mask_image *
558
+ np.array([0 / 255, 255 / 255, 0 / 255])).astype(np.uint8),
559
+ opacity_mask,
560
+ 0,
561
  )
562
 
563
+ return (
564
+ Image.fromarray(overlay_image),
565
+ clicked_points,
566
+ Image.fromarray(mask_image),
567
+ )
568
 
569
  @torch.inference_mode()
570
+ def process(
571
+ self,
572
+ source_image,
573
+ enable_all_generate,
574
+ mask_image,
575
+ control_scale,
576
+ enable_auto_prompt,
577
+ a_prompt,
578
+ n_prompt,
579
+ num_samples,
580
+ image_resolution,
581
+ detect_resolution,
582
+ ddim_steps,
583
+ guess_mode,
584
+ scale,
585
+ seed,
586
+ eta,
587
+ enable_tile=True,
588
+ refine_alignment_ratio=None,
589
+ refine_image_resolution=None,
590
+ alpha_weight=0.5,
591
+ use_scale_map=False,
592
+ condition_model=None,
593
+ ref_image=None,
594
+ attention_auto_machine_weight=1.0,
595
+ gn_auto_machine_weight=1.0,
596
+ style_fidelity=0.5,
597
+ reference_attn=True,
598
+ reference_adain=True,
599
+ ref_prompt=None,
600
+ ref_sam_scale=None,
601
+ ref_inpaint_scale=None,
602
+ ref_auto_prompt=False,
603
+ ref_textinv=True,
604
+ ref_textinv_path=None,
605
+ ):
606
+
607
+ if condition_model is None or condition_model == "EditAnything":
608
  this_controlnet_path = self.default_controlnet_path
609
  else:
610
+ this_controlnet_path = condition_model
611
+ input_image = (
612
+ source_image["image"]
613
+ if isinstance(source_image, dict)
614
+ else np.array(source_image, dtype=np.uint8)
615
+ )
616
  if mask_image is None:
617
  if enable_all_generate != self.defalut_enable_all_generate:
618
  self.pipe = obtain_generation_model(
619
+ self.base_model_path,
620
+ self.lora_model_path,
621
+ this_controlnet_path,
622
+ enable_all_generate,
623
+ self.extra_inpaint,
624
+ )
625
  self.defalut_enable_all_generate = enable_all_generate
626
  if enable_all_generate:
627
+ print(
628
+ "source_image",
629
+ source_image["mask"].shape,
630
+ input_image.shape,
631
+ )
632
+ mask_image = (
633
+ np.ones((input_image.shape[0],
634
+ input_image.shape[1], 3)) * 255
635
+ )
636
  else:
637
  mask_image = source_image["mask"]
638
  else:
639
  mask_image = np.array(mask_image, dtype=np.uint8)
640
  if self.default_controlnet_path != this_controlnet_path:
641
+ print(
642
+ "To Use:",
643
+ this_controlnet_path,
644
+ "Current:",
645
+ self.default_controlnet_path,
646
+ )
647
  print("Change condition model to:", this_controlnet_path)
648
  self.pipe = obtain_generation_model(
649
+ self.base_model_path,
650
+ self.lora_model_path,
651
+ this_controlnet_path,
652
+ enable_all_generate,
653
+ self.extra_inpaint,
654
+ )
655
  self.default_controlnet_path = this_controlnet_path
656
  torch.cuda.empty_cache()
657
+ if self.last_ref_infer:
658
+ print("Redefine the model to overwrite the ref mode")
659
+ self.pipe = obtain_generation_model(
660
+ self.base_model_path,
661
+ self.lora_model_path,
662
+ this_controlnet_path,
663
+ enable_all_generate,
664
+ self.extra_inpaint,
665
+ )
666
+ self.last_ref_infer = False
667
+
668
+ if ref_image is not None:
669
+ ref_mask = ref_image["mask"]
670
+ ref_image = ref_image["image"]
671
+ if ref_auto_prompt or ref_textinv:
672
+ bbox = get_bounding_box(
673
+ np.array(ref_mask) / 255
674
+ ) # reverse the mask to make 1 the choosen region
675
+ cropped_ref_mask = ref_mask.crop(
676
+ (bbox[0], bbox[1], bbox[2], bbox[3]))
677
+ cropped_ref_image = ref_image.crop(
678
+ (bbox[0], bbox[1], bbox[2], bbox[3]))
679
+ # cropped_ref_image.save("debug.jpg")
680
+ cropped_ref_image = np.array(cropped_ref_image) * (
681
+ np.array(cropped_ref_mask)[:, :, :3] / 255.0
682
+ )
683
+ cropped_ref_image = Image.fromarray(
684
+ cropped_ref_image.astype("uint8"))
685
+
686
+ if ref_auto_prompt:
687
+ generated_prompt = self.get_blip2_text(cropped_ref_image)
688
+ ref_prompt += generated_prompt
689
+ a_prompt += generated_prompt
690
+ print("Generated ref text:", ref_prompt)
691
+ print("Generated input text:", a_prompt)
692
+ self.last_ref_infer = True
693
+ # ref_image = cropped_ref_image
694
+ # ref_mask = cropped_ref_mask
695
+ if ref_textinv:
696
+ try:
697
+ self.pipe.load_textual_inversion(ref_textinv_path)
698
+ print("Load textinv embedding from:", ref_textinv_path)
699
+ except:
700
+ print("No textinvert embeddings found.")
701
+ ref_data_path = "./utils/tmp/textinv/img"
702
+ if not os.path.exists(ref_data_path):
703
+ os.makedirs(ref_data_path)
704
+ cropped_ref_image.save(os.path.join(ref_data_path, 'ref.png'))
705
+ print("Ref image region is save to:", ref_data_path)
706
+ print("Plese finetune with run_texutal_inversion.sh in utils folder to get the textinvert embeddings.")
707
+
708
+ else:
709
+ ref_mask = None
710
 
711
  with torch.no_grad():
712
  if self.use_blip and enable_auto_prompt:
 
714
  blip2_prompt = self.get_blip2_text(input_image)
715
  print("Generated text:", blip2_prompt)
716
  if len(a_prompt) > 0:
717
+ a_prompt = blip2_prompt + "," + a_prompt
718
  else:
719
  a_prompt = blip2_prompt
720
 
 
726
  print("Generating SAM seg:")
727
  # the default SAM model is trained with 1024 size.
728
  full_segmask, detected_map = self.get_sam_control(
729
+ resize_image(input_image, detect_resolution)
730
+ )
731
 
732
  detected_map = HWC3(detected_map.astype(np.uint8))
733
  detected_map = cv2.resize(
734
+ detected_map, (W, H), interpolation=cv2.INTER_LINEAR
735
+ )
736
 
737
+ control = torch.from_numpy(detected_map.copy()).float().cuda()
 
738
  control = torch.stack([control for _ in range(num_samples)], dim=0)
739
+ control = einops.rearrange(control, "b h w c -> b c h w").clone()
740
 
741
  mask_imag_ori = HWC3(mask_image.astype(np.uint8))
742
  mask_image_tmp = cv2.resize(
743
+ mask_imag_ori, (W, H), interpolation=cv2.INTER_LINEAR
744
+ )
745
  mask_image = Image.fromarray(mask_image_tmp)
746
 
747
  if seed == -1:
 
751
  postive_prompt = a_prompt
752
  negative_prompt = n_prompt
753
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
754
+ self.pipe, postive_prompt, negative_prompt, "cuda"
755
+ )
756
  prompt_embeds = torch.cat([prompt_embeds] * num_samples, dim=0)
757
  negative_prompt_embeds = torch.cat(
758
+ [negative_prompt_embeds] * num_samples, dim=0
759
+ )
760
+
761
  if enable_all_generate and self.extra_inpaint:
762
  self.pipe.safety_checker = lambda images, clip_input: (
763
  images, False)
764
+ if ref_image is not None:
765
+ print("Not support yet.")
766
+ return
767
  x_samples = self.pipe(
768
+ prompt_embeds=prompt_embeds,
769
+ negative_prompt_embeds=negative_prompt_embeds,
770
  num_images_per_prompt=num_samples,
771
  num_inference_steps=ddim_steps,
772
  generator=generator,
 
775
  image=[control.type(torch.float16)],
776
  controlnet_conditioning_scale=[float(control_scale)],
777
  guidance_scale=scale,
778
+ guess_mode=guess_mode,
779
  ).images
780
  else:
781
  multi_condition_image = []
782
  multi_condition_scale = []
783
  multi_condition_image.append(control.type(torch.float16))
784
  multi_condition_scale.append(float(control_scale))
785
+ ref_multi_condition_scale = []
786
+ if ref_image is not None:
787
+ ref_multi_condition_scale.append(float(ref_sam_scale))
788
  if self.extra_inpaint:
789
  inpaint_image = make_inpaint_condition(img, mask_image_tmp)
 
790
  multi_condition_image.append(
791
  inpaint_image.type(torch.float16))
792
  multi_condition_scale.append(1.0)
793
+ if ref_image is not None:
794
+ ref_multi_condition_scale.append(
795
+ float(ref_inpaint_scale))
796
+ if use_scale_map:
797
+ scale_map_tmp = source_image["mask"]
798
+ tmp = HWC3(scale_map_tmp.astype(np.uint8))
799
+ scale_map_tmp = cv2.resize(
800
+ tmp, (W, H), interpolation=cv2.INTER_LINEAR)
801
+ scale_map_tmp = Image.fromarray(scale_map_tmp)
802
+ controlnet_conditioning_scale_map = 1.0 - \
803
+ prepare_mask_image(scale_map_tmp).float()
804
+ print('scale map:', controlnet_conditioning_scale_map.size())
805
+ else:
806
+ controlnet_conditioning_scale_map = None
807
+
808
+ if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline):
809
+ x_samples = self.pipe(
810
+ image=img,
811
+ mask_image=mask_image,
812
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
813
+ num_images_per_prompt=num_samples,
814
+ num_inference_steps=ddim_steps,
815
+ generator=generator,
816
+ controlnet_conditioning_image=multi_condition_image,
817
+ height=H,
818
+ width=W,
819
+ controlnet_conditioning_scale=multi_condition_scale,
820
+ guidance_scale=scale,
821
+ alpha_weight=alpha_weight,
822
+ controlnet_conditioning_scale_map=controlnet_conditioning_scale_map
823
+ ).images
824
+ else:
825
+ x_samples = self.pipe(
826
+ image=img,
827
+ mask_image=mask_image,
828
+ prompt_embeds=prompt_embeds,
829
+ negative_prompt_embeds=negative_prompt_embeds,
830
+ num_images_per_prompt=num_samples,
831
+ num_inference_steps=ddim_steps,
832
+ generator=generator,
833
+ controlnet_conditioning_image=multi_condition_image,
834
+ height=H,
835
+ width=W,
836
+ controlnet_conditioning_scale=multi_condition_scale,
837
+ guidance_scale=scale,
838
+ ref_image=ref_image,
839
+ ref_mask=ref_mask,
840
+ ref_prompt=ref_prompt,
841
+ attention_auto_machine_weight=attention_auto_machine_weight,
842
+ gn_auto_machine_weight=gn_auto_machine_weight,
843
+ style_fidelity=style_fidelity,
844
+ reference_attn=reference_attn,
845
+ reference_adain=reference_adain,
846
+ ref_controlnet_conditioning_scale=ref_multi_condition_scale,
847
+ guess_mode=guess_mode,
848
+ ).images
849
  results = [x_samples[i] for i in range(num_samples)]
850
 
851
  results_tile = []
852
  if enable_tile:
853
  prompt_embeds, negative_prompt_embeds = get_pipeline_embeds(
854
+ self.tile_pipe, postive_prompt, negative_prompt, "cuda"
855
+ )
856
  for i in range(num_samples):
857
+ img_tile = PIL.Image.fromarray(
858
+ resize_image(
859
+ np.array(x_samples[i]), refine_image_resolution)
860
+ )
861
  if i == 0:
862
  mask_image_tile = cv2.resize(
863
+ mask_imag_ori,
864
+ (img_tile.size[0], img_tile.size[1]),
865
+ interpolation=cv2.INTER_LINEAR,
866
+ )
867
  mask_image_tile = Image.fromarray(mask_image_tile)
868
+ if isinstance(self.pipe, StableDiffusionControlNetInpaintMixingPipeline):
869
+ x_samples_tile = self.tile_pipe(
870
+ image=img_tile,
871
+ mask_image=mask_image_tile,
872
+ prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
873
+ num_images_per_prompt=1,
874
+ num_inference_steps=ddim_steps,
875
+ generator=generator,
876
+ controlnet_conditioning_image=img_tile,
877
+ height=img_tile.size[1],
878
+ width=img_tile.size[0],
879
+ controlnet_conditioning_scale=1.0,
880
+ alignment_ratio=refine_alignment_ratio,
881
+ guidance_scale=scale,
882
+ alpha_weight=alpha_weight,
883
+ controlnet_conditioning_scale_map=controlnet_conditioning_scale_map
884
+ ).images
885
+ else:
886
+ x_samples_tile = self.tile_pipe(
887
+ image=img_tile,
888
+ mask_image=mask_image_tile,
889
+ prompt_embeds=prompt_embeds,
890
+ negative_prompt_embeds=negative_prompt_embeds,
891
+ num_images_per_prompt=1,
892
+ num_inference_steps=ddim_steps,
893
+ generator=generator,
894
+ controlnet_conditioning_image=img_tile,
895
+ height=img_tile.size[1],
896
+ width=img_tile.size[0],
897
+ controlnet_conditioning_scale=1.0,
898
+ alignment_ratio=refine_alignment_ratio,
899
+ guidance_scale=scale,
900
+ guess_mode=guess_mode,
901
+ ).images
902
  results_tile += x_samples_tile
903
 
904
  return results_tile, results, [full_segmask, mask_image], postive_prompt
editany_test.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Edit Anything trained with Stable Diffusion + ControlNet + SAM + BLIP2
2
+ import os
3
+ import gradio as gr
4
+ from diffusers.utils import load_image
5
+ from editany_lora import EditAnythingLoraModel, config_dict
6
+ from editany_demo import create_demo_template
7
+ from huggingface_hub import hf_hub_download, snapshot_download
8
+
9
+
10
+ def create_demo(process, process_image_click=None):
11
+
12
+ examples = [
13
+ [
14
+ "dudou,1girl, beautiful face, solo, candle, brown hair, long hair, <lora:flowergirl:0.9>,ulzzang-6500-v1.1,(raw photo:1.2),((photorealistic:1.4))best quality ,masterpiece, illustration, an extremely delicate and beautiful, extremely detailed ,CG ,unity ,8k wallpaper, Amazing, finely detail, masterpiece,best quality,official art,extremely detailed CG unity 8k wallpaper,absurdres, incredibly absurdres, huge filesize, ultra-detailed, highres, extremely detailed,beautiful detailed girl, extremely detailed eyes and face, beautiful detailed eyes,cinematic lighting,1girl,see-through,looking at viewer,full body,full-body shot,outdoors,arms behind back,(chinese clothes) <lora:cuteGirlMix4_v10:1>",
15
+ "(((mole))),sketches, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, bad anatomy,(long hair:1.4),DeepNegative,(fat:1.2),facing away, looking away,tilted head, lowres,bad anatomy,bad hands, text, error, missing fingers,extra digit, fewer digits, cropped, worstquality, low quality, normal quality,jpegartifacts,signature, watermark, username,blurry,bad feet,cropped,poorly drawn hands,poorly drawn face,mutation,deformed,worst quality,low quality,normal quality,jpeg artifacts,signature,watermark,extra fingers,fewer digits,extra limbs,extra arms,extra legs,malformed limbs,fused fingers,too many fingers,long neck,cross-eyed,mutated hands,polar lowres,bad body,bad proportions,gross proportions,text,error,missing fingers,missing arms,missing legs,extra digit, extra arms, extra leg, extra foot,(freckles),(mole:2)",
16
+ 5,
17
+ ],
18
+ [
19
+ "best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (multicolored|blue|pink hair: 1.2), green eyes, slender, haunting smile, (makeup:0.3), red lips, <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
20
+ "EasyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v",
21
+ 8,
22
+ ],
23
+ [
24
+ "best quality, ultra high res, (photorealistic:1.4), (detailed beautiful girl:1.4), (medium breasts:0.8), looking_at_viewer, Detailed facial details, beautiful detailed eyes, (blue|pink hair), green eyes, slender, smile, (makeup:0.4), red lips, (full body, sitting, beach), <lora:cuteGirlMix4_v10:0.7>, highly detailed clothes, (ulzzang-6500-v1.1:0.3)",
25
+ "asyNegative, paintings, sketches, ugly, 3d, (worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), skin spots, acnes, skin blemishes, age spot, manboobs, backlight,(ugly:1.3), (duplicate:1.3), (morbid:1.2), (mutilated:1.2), (tranny:1.3), mutated hands, (poorly drawn hands:1.3), blurry, (bad anatomy:1.2), (bad proportions:1.3), extra limbs, (disfigured:1.3), (more than 2 nipples:1.3), (more than 1 navel:1.3), (missing arms:1.3), (extra legs:1.3), (fused fingers:1.6), (too many fingers:1.6), (unclear eyes:1.3), bad hands, missing fingers, extra digit, (futa:1.1), bad body, double navel, mutad arms, hused arms, (puffy nipples, dark areolae, dark nipples, rei no himo, inverted nipples, long nipples), NG_DeepNegative_V1_75t, pubic hair, fat rolls, obese, bad-picture-chill-75v",
26
+ 7,
27
+ ],
28
+ [
29
+ "mix4, whole body shot, ((8k, RAW photo, highest quality, masterpiece), High detail RAW color photo professional close-up photo, shy expression, cute, beautiful detailed girl, detailed fingers, extremely detailed eyes and face, beautiful detailed nose, beautiful detailed eyes, long eyelashes, light on face, looking at viewer, (closed mouth:1.2), 1girl, cute, young, mature face, (full body:1.3), ((small breasts)), realistic face, realistic body, beautiful detailed thigh,s, same eyes color, (realistic, photo realism:1. 37), (highest quality), (best shadow), (best illustration), ultra high resolution, physics-based rendering, cinematic lighting), solo, 1girl, highly detailed, in office, detailed office, open cardigan, ponytail contorted, beautiful eyes ,sitting in office,dating, business suit, cross-laced clothes, collared shirt, beautiful breast, small breast, Chinese dress, white pantyhose, natural breasts, pink and white hair, <lora:cuteGirlMix4_v10:1>",
30
+ "paintings, sketches, (worst quality:2), (low quality:2), (normal quality:2), cloth, underwear, bra, low-res, normal quality, ((monochrome)), ((grayscale)), skin spots, acne, skin blemishes, age spots, glans, bad nipples, long nipples, bad vagina, extra fingers,fewer fingers,strange fingers,bad hand, ng_deepnegative_v1_75t, bad-picture-chill-75v",
31
+ 7,
32
+ ],
33
+ ]
34
+ INFO = f"""
35
+ ## Generate Your Beauty powered by EditAnything https://github.com/sail-sg/EditAnything
36
+ This model is good at generating beautiful female.
37
+ """
38
+ WARNING_INFO = f"""### [NOTE] the model is collected from the Internet for demo only, please do not use it for commercial purposes.
39
+ We are not responsible for possible risks using this model.
40
+ Lora model from https://civitai.com/models/14171/cutegirlmix4 Thanks!
41
+ """
42
+ demo = create_demo_template(
43
+ process,
44
+ process_image_click,
45
+ examples=examples,
46
+ INFO=INFO,
47
+ WARNING_INFO=WARNING_INFO,
48
+ )
49
+ return demo
50
+
51
+
52
+ if __name__ == "__main__":
53
+ # sd_models_path = snapshot_download("shgao/sdmodels")
54
+ # lora_model_path = hf_hub_download(
55
+ # "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors")
56
+ # model = EditAnythingLoraModel(base_model_path="andite/anything-v4.0",
57
+ # lora_model_path=None, use_blip=True, extra_inpaint=True,
58
+ # lora_weight=0.5,
59
+ # )
60
+ sd_models_path = snapshot_download("shgao/sdmodels")
61
+ lora_model_path = hf_hub_download(
62
+ "mlida/Cute_girl_mix4", "cuteGirlMix4_v10.safetensors"
63
+ )
64
+ model = EditAnythingLoraModel(
65
+ base_model_path=os.path.join(
66
+ sd_models_path, "chilloutmix_NiPrunedFp32Fix"),
67
+ lora_model_path=lora_model_path,
68
+ use_blip=True,
69
+ extra_inpaint=True,
70
+ lora_weight=0.5,
71
+ )
72
+ demo = create_demo(model.process, model.process_image_click)
73
+ demo.queue().launch(server_name="0.0.0.0")
environment.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: control
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.13.1
10
+ - torchvision=0.14.1
11
+ - numpy=1.23.1
12
+ - pip:
13
+ - gradio==3.35.2
14
+ - albumentations==1.3.0
15
+ - opencv-contrib-python==4.3.0.36
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.5.0
19
+ - omegaconf==2.1.1
20
+ - test-tube>=0.7.5
21
+ - streamlit==1.12.1
22
+ - einops==0.3.0
23
+ - webdataset==0.2.5
24
+ - kornia==0.6
25
+ - open_clip_torch==2.0.2
26
+ - invisible-watermark>=0.1.5
27
+ - streamlit-drawable-canvas==0.8.0
28
+ - torchmetrics==0.6.0
29
+ - timm==0.6.12
30
+ - addict==2.4.0
31
+ - yapf==0.32.0
32
+ - prettytable==3.6.0
33
+ - safetensors==0.2.7
34
+ - basicsr==1.4.2
35
+ - diffusers==0.17.1
36
+ - accelerate==0.17.0
37
+ - transformers==4.30.2
38
+ - xformers
font/DejaVuSans.ttf ADDED
Binary file (757 kB). View file
 
ldm/data/__init__.py ADDED
File without changes
ldm/data/util.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from ldm.modules.midas.api import load_midas_transform
4
+
5
+
6
+ class AddMiDaS(object):
7
+ def __init__(self, model_type):
8
+ super().__init__()
9
+ self.transform = load_midas_transform(model_type)
10
+
11
+ def pt2np(self, x):
12
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
13
+ return x
14
+
15
+ def np2pt(self, x):
16
+ x = torch.from_numpy(x) * 2 - 1.
17
+ return x
18
+
19
+ def __call__(self, sample):
20
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
21
+ x = self.pt2np(sample['jpg'])
22
+ x = self.transform({"image": x})["image"]
23
+ sample['midas_in'] = x
24
+ return sample
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
7
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
+
9
+ from ldm.util import instantiate_from_config
10
+ from ldm.modules.ema import LitEma
11
+
12
+
13
+ class AutoencoderKL(pl.LightningModule):
14
+ def __init__(self,
15
+ ddconfig,
16
+ lossconfig,
17
+ embed_dim,
18
+ ckpt_path=None,
19
+ ignore_keys=[],
20
+ image_key="image",
21
+ colorize_nlabels=None,
22
+ monitor=None,
23
+ ema_decay=None,
24
+ learn_logvar=False
25
+ ):
26
+ super().__init__()
27
+ self.learn_logvar = learn_logvar
28
+ self.image_key = image_key
29
+ self.encoder = Encoder(**ddconfig)
30
+ self.decoder = Decoder(**ddconfig)
31
+ self.loss = instantiate_from_config(lossconfig)
32
+ assert ddconfig["double_z"]
33
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
+ self.embed_dim = embed_dim
36
+ if colorize_nlabels is not None:
37
+ assert type(colorize_nlabels)==int
38
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
+ if monitor is not None:
40
+ self.monitor = monitor
41
+
42
+ self.use_ema = ema_decay is not None
43
+ if self.use_ema:
44
+ self.ema_decay = ema_decay
45
+ assert 0. < ema_decay < 1.
46
+ self.model_ema = LitEma(self, decay=ema_decay)
47
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
+
49
+ if ckpt_path is not None:
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
+
52
+ def init_from_ckpt(self, path, ignore_keys=list()):
53
+ sd = torch.load(path, map_location="cpu")["state_dict"]
54
+ keys = list(sd.keys())
55
+ for k in keys:
56
+ for ik in ignore_keys:
57
+ if k.startswith(ik):
58
+ print("Deleting key {} from state_dict.".format(k))
59
+ del sd[k]
60
+ self.load_state_dict(sd, strict=False)
61
+ print(f"Restored from {path}")
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def on_train_batch_end(self, *args, **kwargs):
79
+ if self.use_ema:
80
+ self.model_ema(self)
81
+
82
+ def encode(self, x):
83
+ h = self.encoder(x)
84
+ moments = self.quant_conv(h)
85
+ posterior = DiagonalGaussianDistribution(moments)
86
+ return posterior
87
+
88
+ def decode(self, z):
89
+ z = self.post_quant_conv(z)
90
+ dec = self.decoder(z)
91
+ return dec
92
+
93
+ def forward(self, input, sample_posterior=True):
94
+ posterior = self.encode(input)
95
+ if sample_posterior:
96
+ z = posterior.sample()
97
+ else:
98
+ z = posterior.mode()
99
+ dec = self.decode(z)
100
+ return dec, posterior
101
+
102
+ def get_input(self, batch, k):
103
+ x = batch[k]
104
+ if len(x.shape) == 3:
105
+ x = x[..., None]
106
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
+ return x
108
+
109
+ def training_step(self, batch, batch_idx, optimizer_idx):
110
+ inputs = self.get_input(batch, self.image_key)
111
+ reconstructions, posterior = self(inputs)
112
+
113
+ if optimizer_idx == 0:
114
+ # train encoder+decoder+logvar
115
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116
+ last_layer=self.get_last_layer(), split="train")
117
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119
+ return aeloss
120
+
121
+ if optimizer_idx == 1:
122
+ # train the discriminator
123
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124
+ last_layer=self.get_last_layer(), split="train")
125
+
126
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128
+ return discloss
129
+
130
+ def validation_step(self, batch, batch_idx):
131
+ log_dict = self._validation_step(batch, batch_idx)
132
+ with self.ema_scope():
133
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134
+ return log_dict
135
+
136
+ def _validation_step(self, batch, batch_idx, postfix=""):
137
+ inputs = self.get_input(batch, self.image_key)
138
+ reconstructions, posterior = self(inputs)
139
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140
+ last_layer=self.get_last_layer(), split="val"+postfix)
141
+
142
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143
+ last_layer=self.get_last_layer(), split="val"+postfix)
144
+
145
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146
+ self.log_dict(log_dict_ae)
147
+ self.log_dict(log_dict_disc)
148
+ return self.log_dict
149
+
150
+ def configure_optimizers(self):
151
+ lr = self.learning_rate
152
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154
+ if self.learn_logvar:
155
+ print(f"{self.__class__.__name__}: Learning logvar")
156
+ ae_params_list.append(self.loss.logvar)
157
+ opt_ae = torch.optim.Adam(ae_params_list,
158
+ lr=lr, betas=(0.5, 0.9))
159
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160
+ lr=lr, betas=(0.5, 0.9))
161
+ return [opt_ae, opt_disc], []
162
+
163
+ def get_last_layer(self):
164
+ return self.decoder.conv_out.weight
165
+
166
+ @torch.no_grad()
167
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168
+ log = dict()
169
+ x = self.get_input(batch, self.image_key)
170
+ x = x.to(self.device)
171
+ if not only_inputs:
172
+ xrec, posterior = self(x)
173
+ if x.shape[1] > 3:
174
+ # colorize with random projection
175
+ assert xrec.shape[1] > 3
176
+ x = self.to_rgb(x)
177
+ xrec = self.to_rgb(xrec)
178
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179
+ log["reconstructions"] = xrec
180
+ if log_ema or self.use_ema:
181
+ with self.ema_scope():
182
+ xrec_ema, posterior_ema = self(x)
183
+ if x.shape[1] > 3:
184
+ # colorize with random projection
185
+ assert xrec_ema.shape[1] > 3
186
+ xrec_ema = self.to_rgb(xrec_ema)
187
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188
+ log["reconstructions_ema"] = xrec_ema
189
+ log["inputs"] = x
190
+ return log
191
+
192
+ def to_rgb(self, x):
193
+ assert self.image_key == "segmentation"
194
+ if not hasattr(self, "colorize"):
195
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196
+ x = F.conv2d(x, weight=self.colorize)
197
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198
+ return x
199
+
200
+
201
+ class IdentityFirstStage(torch.nn.Module):
202
+ def __init__(self, *args, vq_interface=False, **kwargs):
203
+ self.vq_interface = vq_interface
204
+ super().__init__()
205
+
206
+ def encode(self, x, *args, **kwargs):
207
+ return x
208
+
209
+ def decode(self, x, *args, **kwargs):
210
+ return x
211
+
212
+ def quantize(self, x, *args, **kwargs):
213
+ if self.vq_interface:
214
+ return x, None, [None, None, None]
215
+ return x
216
+
217
+ def forward(self, x, *args, **kwargs):
218
+ return x
219
+
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+
7
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
+
9
+
10
+ class DDIMSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
+ alphas_cumprod = self.model.alphas_cumprod
27
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
+
30
+ self.register_buffer('betas', to_torch(self.model.betas))
31
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
+
34
+ # calculations for diffusion q(x_t | x_{t-1}) and others
35
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
+
41
+ # ddim sampling parameters
42
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
+ ddim_timesteps=self.ddim_timesteps,
44
+ eta=ddim_eta,verbose=verbose)
45
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
46
+ self.register_buffer('ddim_alphas', ddim_alphas)
47
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
+
54
+ @torch.no_grad()
55
+ def sample(self,
56
+ S,
57
+ batch_size,
58
+ shape,
59
+ conditioning=None,
60
+ callback=None,
61
+ normals_sequence=None,
62
+ img_callback=None,
63
+ quantize_x0=False,
64
+ eta=0.,
65
+ mask=None,
66
+ x0=None,
67
+ temperature=1.,
68
+ noise_dropout=0.,
69
+ score_corrector=None,
70
+ corrector_kwargs=None,
71
+ verbose=True,
72
+ x_T=None,
73
+ log_every_t=100,
74
+ unconditional_guidance_scale=1.,
75
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
+ dynamic_threshold=None,
77
+ ucg_schedule=None,
78
+ **kwargs
79
+ ):
80
+ if conditioning is not None:
81
+ if isinstance(conditioning, dict):
82
+ ctmp = conditioning[list(conditioning.keys())[0]]
83
+ while isinstance(ctmp, list): ctmp = ctmp[0]
84
+ cbs = ctmp.shape[0]
85
+ if cbs != batch_size:
86
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
+
88
+ elif isinstance(conditioning, list):
89
+ for ctmp in conditioning:
90
+ if ctmp.shape[0] != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+
93
+ else:
94
+ if conditioning.shape[0] != batch_size:
95
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
+
97
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
+ # sampling
99
+ C, H, W = shape
100
+ size = (batch_size, C, H, W)
101
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
+
103
+ samples, intermediates = self.ddim_sampling(conditioning, size,
104
+ callback=callback,
105
+ img_callback=img_callback,
106
+ quantize_denoised=quantize_x0,
107
+ mask=mask, x0=x0,
108
+ ddim_use_original_steps=False,
109
+ noise_dropout=noise_dropout,
110
+ temperature=temperature,
111
+ score_corrector=score_corrector,
112
+ corrector_kwargs=corrector_kwargs,
113
+ x_T=x_T,
114
+ log_every_t=log_every_t,
115
+ unconditional_guidance_scale=unconditional_guidance_scale,
116
+ unconditional_conditioning=unconditional_conditioning,
117
+ dynamic_threshold=dynamic_threshold,
118
+ ucg_schedule=ucg_schedule
119
+ )
120
+ return samples, intermediates
121
+
122
+ @torch.no_grad()
123
+ def ddim_sampling(self, cond, shape,
124
+ x_T=None, ddim_use_original_steps=False,
125
+ callback=None, timesteps=None, quantize_denoised=False,
126
+ mask=None, x0=None, img_callback=None, log_every_t=100,
127
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
+ ucg_schedule=None):
130
+ device = self.model.betas.device
131
+ b = shape[0]
132
+ if x_T is None:
133
+ img = torch.randn(shape, device=device)
134
+ else:
135
+ img = x_T
136
+
137
+ if timesteps is None:
138
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
+ elif timesteps is not None and not ddim_use_original_steps:
140
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
+ timesteps = self.ddim_timesteps[:subset_end]
142
+
143
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
147
+
148
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
+
150
+ for i, step in enumerate(iterator):
151
+ index = total_steps - i - 1
152
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
153
+
154
+ if mask is not None:
155
+ assert x0 is not None
156
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
+ img = img_orig * mask + (1. - mask) * img
158
+
159
+ if ucg_schedule is not None:
160
+ assert len(ucg_schedule) == len(time_range)
161
+ unconditional_guidance_scale = ucg_schedule[i]
162
+
163
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
+ quantize_denoised=quantize_denoised, temperature=temperature,
165
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
166
+ corrector_kwargs=corrector_kwargs,
167
+ unconditional_guidance_scale=unconditional_guidance_scale,
168
+ unconditional_conditioning=unconditional_conditioning,
169
+ dynamic_threshold=dynamic_threshold)
170
+ img, pred_x0 = outs
171
+ if callback: callback(i)
172
+ if img_callback: img_callback(pred_x0, i)
173
+
174
+ if index % log_every_t == 0 or index == total_steps - 1:
175
+ intermediates['x_inter'].append(img)
176
+ intermediates['pred_x0'].append(pred_x0)
177
+
178
+ return img, intermediates
179
+
180
+ @torch.no_grad()
181
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
182
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
183
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
184
+ dynamic_threshold=None):
185
+ b, *_, device = *x.shape, x.device
186
+
187
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188
+ model_output = self.model.apply_model(x, t, c)
189
+ else:
190
+ x_in = torch.cat([x] * 2)
191
+ t_in = torch.cat([t] * 2)
192
+ if isinstance(c, dict):
193
+ assert isinstance(unconditional_conditioning, dict)
194
+ c_in = dict()
195
+ for k in c:
196
+ if isinstance(c[k], list):
197
+ c_in[k] = [torch.cat([
198
+ unconditional_conditioning[k][i],
199
+ c[k][i]]) for i in range(len(c[k]))]
200
+ else:
201
+ c_in[k] = torch.cat([
202
+ unconditional_conditioning[k],
203
+ c[k]])
204
+ elif isinstance(c, list):
205
+ c_in = list()
206
+ assert isinstance(unconditional_conditioning, list)
207
+ for i in range(len(c)):
208
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
209
+ else:
210
+ c_in = torch.cat([unconditional_conditioning, c])
211
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
212
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
213
+
214
+ if self.model.parameterization == "v":
215
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
216
+ else:
217
+ e_t = model_output
218
+
219
+ if score_corrector is not None:
220
+ assert self.model.parameterization == "eps", 'not implemented'
221
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
222
+
223
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
224
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
225
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
226
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
227
+ # select parameters corresponding to the currently considered timestep
228
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
229
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
230
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
231
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
232
+
233
+ # current prediction for x_0
234
+ if self.model.parameterization != "v":
235
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
236
+ else:
237
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
238
+
239
+ if quantize_denoised:
240
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
241
+
242
+ if dynamic_threshold is not None:
243
+ raise NotImplementedError()
244
+
245
+ # direction pointing to x_t
246
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
247
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
248
+ if noise_dropout > 0.:
249
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
250
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
251
+ return x_prev, pred_x0
252
+
253
+ @torch.no_grad()
254
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
255
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
256
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
257
+
258
+ assert t_enc <= num_reference_steps
259
+ num_steps = t_enc
260
+
261
+ if use_original_steps:
262
+ alphas_next = self.alphas_cumprod[:num_steps]
263
+ alphas = self.alphas_cumprod_prev[:num_steps]
264
+ else:
265
+ alphas_next = self.ddim_alphas[:num_steps]
266
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
267
+
268
+ x_next = x0
269
+ intermediates = []
270
+ inter_steps = []
271
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
272
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
273
+ if unconditional_guidance_scale == 1.:
274
+ noise_pred = self.model.apply_model(x_next, t, c)
275
+ else:
276
+ assert unconditional_conditioning is not None
277
+ e_t_uncond, noise_pred = torch.chunk(
278
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
279
+ torch.cat((unconditional_conditioning, c))), 2)
280
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
281
+
282
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
283
+ weighted_noise_pred = alphas_next[i].sqrt() * (
284
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
285
+ x_next = xt_weighted + weighted_noise_pred
286
+ if return_intermediates and i % (
287
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
288
+ intermediates.append(x_next)
289
+ inter_steps.append(i)
290
+ elif return_intermediates and i >= num_steps - 2:
291
+ intermediates.append(x_next)
292
+ inter_steps.append(i)
293
+ if callback: callback(i)
294
+
295
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
296
+ if return_intermediates:
297
+ out.update({'intermediates': intermediates})
298
+ return x_next, out
299
+
300
+ @torch.no_grad()
301
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
302
+ # fast, but does not allow for exact reconstruction
303
+ # t serves as an index to gather the correct alphas
304
+ if use_original_steps:
305
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
306
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
307
+ else:
308
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
309
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
310
+
311
+ if noise is None:
312
+ noise = torch.randn_like(x0)
313
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
314
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
315
+
316
+ @torch.no_grad()
317
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
318
+ use_original_steps=False, callback=None):
319
+
320
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
321
+ timesteps = timesteps[:t_start]
322
+
323
+ time_range = np.flip(timesteps)
324
+ total_steps = timesteps.shape[0]
325
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
326
+
327
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
328
+ x_dec = x_latent
329
+ for i, step in enumerate(iterator):
330
+ index = total_steps - i - 1
331
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
332
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
333
+ unconditional_guidance_scale=unconditional_guidance_scale,
334
+ unconditional_conditioning=unconditional_conditioning)
335
+ if callback: callback(i)
336
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1797 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager, nullcontext
16
+ from functools import partial
17
+ import itertools
18
+ from tqdm import tqdm
19
+ from torchvision.utils import make_grid
20
+ from pytorch_lightning.utilities.distributed import rank_zero_only
21
+ from omegaconf import ListConfig
22
+
23
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
24
+ from ldm.modules.ema import LitEma
25
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
+ from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
27
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
28
+ from ldm.models.diffusion.ddim import DDIMSampler
29
+
30
+
31
+ __conditioning_keys__ = {'concat': 'c_concat',
32
+ 'crossattn': 'c_crossattn',
33
+ 'adm': 'y'}
34
+
35
+
36
+ def disabled_train(self, mode=True):
37
+ """Overwrite model.train with this function to make sure train/eval mode
38
+ does not change anymore."""
39
+ return self
40
+
41
+
42
+ def uniform_on_device(r1, r2, shape, device):
43
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
44
+
45
+
46
+ class DDPM(pl.LightningModule):
47
+ # classic DDPM with Gaussian diffusion, in image space
48
+ def __init__(self,
49
+ unet_config,
50
+ timesteps=1000,
51
+ beta_schedule="linear",
52
+ loss_type="l2",
53
+ ckpt_path=None,
54
+ ignore_keys=[],
55
+ load_only_unet=False,
56
+ monitor="val/loss",
57
+ use_ema=True,
58
+ first_stage_key="image",
59
+ image_size=256,
60
+ channels=3,
61
+ log_every_t=100,
62
+ clip_denoised=True,
63
+ linear_start=1e-4,
64
+ linear_end=2e-2,
65
+ cosine_s=8e-3,
66
+ given_betas=None,
67
+ original_elbo_weight=0.,
68
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
69
+ l_simple_weight=1.,
70
+ conditioning_key=None,
71
+ parameterization="eps", # all assuming fixed variance schedules
72
+ scheduler_config=None,
73
+ use_positional_encodings=False,
74
+ learn_logvar=False,
75
+ logvar_init=0.,
76
+ make_it_fit=False,
77
+ ucg_training=None,
78
+ reset_ema=False,
79
+ reset_num_ema_updates=False,
80
+ ):
81
+ super().__init__()
82
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
83
+ self.parameterization = parameterization
84
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
85
+ self.cond_stage_model = None
86
+ self.clip_denoised = clip_denoised
87
+ self.log_every_t = log_every_t
88
+ self.first_stage_key = first_stage_key
89
+ self.image_size = image_size # try conv?
90
+ self.channels = channels
91
+ self.use_positional_encodings = use_positional_encodings
92
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
93
+ count_params(self.model, verbose=True)
94
+ self.use_ema = use_ema
95
+ if self.use_ema:
96
+ self.model_ema = LitEma(self.model)
97
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
98
+
99
+ self.use_scheduler = scheduler_config is not None
100
+ if self.use_scheduler:
101
+ self.scheduler_config = scheduler_config
102
+
103
+ self.v_posterior = v_posterior
104
+ self.original_elbo_weight = original_elbo_weight
105
+ self.l_simple_weight = l_simple_weight
106
+
107
+ if monitor is not None:
108
+ self.monitor = monitor
109
+ self.make_it_fit = make_it_fit
110
+ if reset_ema: assert exists(ckpt_path)
111
+ if ckpt_path is not None:
112
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
113
+ if reset_ema:
114
+ assert self.use_ema
115
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
116
+ self.model_ema = LitEma(self.model)
117
+ if reset_num_ema_updates:
118
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
119
+ assert self.use_ema
120
+ self.model_ema.reset_num_updates()
121
+
122
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
123
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
124
+
125
+ self.loss_type = loss_type
126
+
127
+ self.learn_logvar = learn_logvar
128
+ logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
129
+ if self.learn_logvar:
130
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
131
+ else:
132
+ self.register_buffer('logvar', logvar)
133
+
134
+ self.ucg_training = ucg_training or dict()
135
+ if self.ucg_training:
136
+ self.ucg_prng = np.random.RandomState()
137
+
138
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
139
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
140
+ if exists(given_betas):
141
+ betas = given_betas
142
+ else:
143
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
144
+ cosine_s=cosine_s)
145
+ alphas = 1. - betas
146
+ alphas_cumprod = np.cumprod(alphas, axis=0)
147
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
148
+
149
+ timesteps, = betas.shape
150
+ self.num_timesteps = int(timesteps)
151
+ self.linear_start = linear_start
152
+ self.linear_end = linear_end
153
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
154
+
155
+ to_torch = partial(torch.tensor, dtype=torch.float32)
156
+
157
+ self.register_buffer('betas', to_torch(betas))
158
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
159
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
160
+
161
+ # calculations for diffusion q(x_t | x_{t-1}) and others
162
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
163
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
164
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
165
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
166
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
167
+
168
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
169
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
170
+ 1. - alphas_cumprod) + self.v_posterior * betas
171
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
172
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
173
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
174
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
175
+ self.register_buffer('posterior_mean_coef1', to_torch(
176
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
177
+ self.register_buffer('posterior_mean_coef2', to_torch(
178
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
179
+
180
+ if self.parameterization == "eps":
181
+ lvlb_weights = self.betas ** 2 / (
182
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
183
+ elif self.parameterization == "x0":
184
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
185
+ elif self.parameterization == "v":
186
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
187
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
188
+ else:
189
+ raise NotImplementedError("mu not supported")
190
+ lvlb_weights[0] = lvlb_weights[1]
191
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
192
+ assert not torch.isnan(self.lvlb_weights).all()
193
+
194
+ @contextmanager
195
+ def ema_scope(self, context=None):
196
+ if self.use_ema:
197
+ self.model_ema.store(self.model.parameters())
198
+ self.model_ema.copy_to(self.model)
199
+ if context is not None:
200
+ print(f"{context}: Switched to EMA weights")
201
+ try:
202
+ yield None
203
+ finally:
204
+ if self.use_ema:
205
+ self.model_ema.restore(self.model.parameters())
206
+ if context is not None:
207
+ print(f"{context}: Restored training weights")
208
+
209
+ @torch.no_grad()
210
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
211
+ sd = torch.load(path, map_location="cpu")
212
+ if "state_dict" in list(sd.keys()):
213
+ sd = sd["state_dict"]
214
+ keys = list(sd.keys())
215
+ for k in keys:
216
+ for ik in ignore_keys:
217
+ if k.startswith(ik):
218
+ print("Deleting key {} from state_dict.".format(k))
219
+ del sd[k]
220
+ if self.make_it_fit:
221
+ n_params = len([name for name, _ in
222
+ itertools.chain(self.named_parameters(),
223
+ self.named_buffers())])
224
+ for name, param in tqdm(
225
+ itertools.chain(self.named_parameters(),
226
+ self.named_buffers()),
227
+ desc="Fitting old weights to new weights",
228
+ total=n_params
229
+ ):
230
+ if not name in sd:
231
+ continue
232
+ old_shape = sd[name].shape
233
+ new_shape = param.shape
234
+ assert len(old_shape) == len(new_shape)
235
+ if len(new_shape) > 2:
236
+ # we only modify first two axes
237
+ assert new_shape[2:] == old_shape[2:]
238
+ # assumes first axis corresponds to output dim
239
+ if not new_shape == old_shape:
240
+ new_param = param.clone()
241
+ old_param = sd[name]
242
+ if len(new_shape) == 1:
243
+ for i in range(new_param.shape[0]):
244
+ new_param[i] = old_param[i % old_shape[0]]
245
+ elif len(new_shape) >= 2:
246
+ for i in range(new_param.shape[0]):
247
+ for j in range(new_param.shape[1]):
248
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
249
+
250
+ n_used_old = torch.ones(old_shape[1])
251
+ for j in range(new_param.shape[1]):
252
+ n_used_old[j % old_shape[1]] += 1
253
+ n_used_new = torch.zeros(new_shape[1])
254
+ for j in range(new_param.shape[1]):
255
+ n_used_new[j] = n_used_old[j % old_shape[1]]
256
+
257
+ n_used_new = n_used_new[None, :]
258
+ while len(n_used_new.shape) < len(new_shape):
259
+ n_used_new = n_used_new.unsqueeze(-1)
260
+ new_param /= n_used_new
261
+
262
+ sd[name] = new_param
263
+
264
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
265
+ sd, strict=False)
266
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
267
+ if len(missing) > 0:
268
+ print(f"Missing Keys:\n {missing}")
269
+ if len(unexpected) > 0:
270
+ print(f"\nUnexpected Keys:\n {unexpected}")
271
+
272
+ def q_mean_variance(self, x_start, t):
273
+ """
274
+ Get the distribution q(x_t | x_0).
275
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
276
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
277
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
278
+ """
279
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
280
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
281
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
282
+ return mean, variance, log_variance
283
+
284
+ def predict_start_from_noise(self, x_t, t, noise):
285
+ return (
286
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
287
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
288
+ )
289
+
290
+ def predict_start_from_z_and_v(self, x_t, t, v):
291
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
292
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
293
+ return (
294
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
295
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
296
+ )
297
+
298
+ def predict_eps_from_z_and_v(self, x_t, t, v):
299
+ return (
300
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
301
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
302
+ )
303
+
304
+ def q_posterior(self, x_start, x_t, t):
305
+ posterior_mean = (
306
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
307
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
308
+ )
309
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
310
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
311
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
312
+
313
+ def p_mean_variance(self, x, t, clip_denoised: bool):
314
+ model_out = self.model(x, t)
315
+ if self.parameterization == "eps":
316
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
317
+ elif self.parameterization == "x0":
318
+ x_recon = model_out
319
+ if clip_denoised:
320
+ x_recon.clamp_(-1., 1.)
321
+
322
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
323
+ return model_mean, posterior_variance, posterior_log_variance
324
+
325
+ @torch.no_grad()
326
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
327
+ b, *_, device = *x.shape, x.device
328
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
329
+ noise = noise_like(x.shape, device, repeat_noise)
330
+ # no noise when t == 0
331
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
332
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
333
+
334
+ @torch.no_grad()
335
+ def p_sample_loop(self, shape, return_intermediates=False):
336
+ device = self.betas.device
337
+ b = shape[0]
338
+ img = torch.randn(shape, device=device)
339
+ intermediates = [img]
340
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
341
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
342
+ clip_denoised=self.clip_denoised)
343
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
344
+ intermediates.append(img)
345
+ if return_intermediates:
346
+ return img, intermediates
347
+ return img
348
+
349
+ @torch.no_grad()
350
+ def sample(self, batch_size=16, return_intermediates=False):
351
+ image_size = self.image_size
352
+ channels = self.channels
353
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
354
+ return_intermediates=return_intermediates)
355
+
356
+ def q_sample(self, x_start, t, noise=None):
357
+ noise = default(noise, lambda: torch.randn_like(x_start))
358
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
359
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
360
+
361
+ def get_v(self, x, noise, t):
362
+ return (
363
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
364
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
365
+ )
366
+
367
+ def get_loss(self, pred, target, mean=True):
368
+ if self.loss_type == 'l1':
369
+ loss = (target - pred).abs()
370
+ if mean:
371
+ loss = loss.mean()
372
+ elif self.loss_type == 'l2':
373
+ if mean:
374
+ loss = torch.nn.functional.mse_loss(target, pred)
375
+ else:
376
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
377
+ else:
378
+ raise NotImplementedError("unknown loss type '{loss_type}'")
379
+
380
+ return loss
381
+
382
+ def p_losses(self, x_start, t, noise=None):
383
+ noise = default(noise, lambda: torch.randn_like(x_start))
384
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
385
+ model_out = self.model(x_noisy, t)
386
+
387
+ loss_dict = {}
388
+ if self.parameterization == "eps":
389
+ target = noise
390
+ elif self.parameterization == "x0":
391
+ target = x_start
392
+ elif self.parameterization == "v":
393
+ target = self.get_v(x_start, noise, t)
394
+ else:
395
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
396
+
397
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
398
+
399
+ log_prefix = 'train' if self.training else 'val'
400
+
401
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
402
+ loss_simple = loss.mean() * self.l_simple_weight
403
+
404
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
405
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
406
+
407
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
408
+
409
+ loss_dict.update({f'{log_prefix}/loss': loss})
410
+
411
+ return loss, loss_dict
412
+
413
+ def forward(self, x, *args, **kwargs):
414
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
415
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
416
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
417
+ return self.p_losses(x, t, *args, **kwargs)
418
+
419
+ def get_input(self, batch, k):
420
+ x = batch[k]
421
+ if len(x.shape) == 3:
422
+ x = x[..., None]
423
+ x = rearrange(x, 'b h w c -> b c h w')
424
+ x = x.to(memory_format=torch.contiguous_format).float()
425
+ return x
426
+
427
+ def shared_step(self, batch):
428
+ x = self.get_input(batch, self.first_stage_key)
429
+ loss, loss_dict = self(x)
430
+ return loss, loss_dict
431
+
432
+ def training_step(self, batch, batch_idx):
433
+ for k in self.ucg_training:
434
+ p = self.ucg_training[k]["p"]
435
+ val = self.ucg_training[k]["val"]
436
+ if val is None:
437
+ val = ""
438
+ for i in range(len(batch[k])):
439
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
440
+ batch[k][i] = val
441
+
442
+ loss, loss_dict = self.shared_step(batch)
443
+
444
+ self.log_dict(loss_dict, prog_bar=True,
445
+ logger=True, on_step=True, on_epoch=True)
446
+
447
+ self.log("global_step", self.global_step,
448
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
449
+
450
+ if self.use_scheduler:
451
+ lr = self.optimizers().param_groups[0]['lr']
452
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
453
+
454
+ return loss
455
+
456
+ @torch.no_grad()
457
+ def validation_step(self, batch, batch_idx):
458
+ _, loss_dict_no_ema = self.shared_step(batch)
459
+ with self.ema_scope():
460
+ _, loss_dict_ema = self.shared_step(batch)
461
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
462
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
463
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
464
+
465
+ def on_train_batch_end(self, *args, **kwargs):
466
+ if self.use_ema:
467
+ self.model_ema(self.model)
468
+
469
+ def _get_rows_from_list(self, samples):
470
+ n_imgs_per_row = len(samples)
471
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
472
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
473
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
474
+ return denoise_grid
475
+
476
+ @torch.no_grad()
477
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
478
+ log = dict()
479
+ x = self.get_input(batch, self.first_stage_key)
480
+ N = min(x.shape[0], N)
481
+ n_row = min(x.shape[0], n_row)
482
+ x = x.to(self.device)[:N]
483
+ log["inputs"] = x
484
+
485
+ # get diffusion row
486
+ diffusion_row = list()
487
+ x_start = x[:n_row]
488
+
489
+ for t in range(self.num_timesteps):
490
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
491
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
492
+ t = t.to(self.device).long()
493
+ noise = torch.randn_like(x_start)
494
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
495
+ diffusion_row.append(x_noisy)
496
+
497
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
498
+
499
+ if sample:
500
+ # get denoise row
501
+ with self.ema_scope("Plotting"):
502
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
503
+
504
+ log["samples"] = samples
505
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
506
+
507
+ if return_keys:
508
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
509
+ return log
510
+ else:
511
+ return {key: log[key] for key in return_keys}
512
+ return log
513
+
514
+ def configure_optimizers(self):
515
+ lr = self.learning_rate
516
+ params = list(self.model.parameters())
517
+ if self.learn_logvar:
518
+ params = params + [self.logvar]
519
+ opt = torch.optim.AdamW(params, lr=lr)
520
+ return opt
521
+
522
+
523
+ class LatentDiffusion(DDPM):
524
+ """main class"""
525
+
526
+ def __init__(self,
527
+ first_stage_config,
528
+ cond_stage_config,
529
+ num_timesteps_cond=None,
530
+ cond_stage_key="image",
531
+ cond_stage_trainable=False,
532
+ concat_mode=True,
533
+ cond_stage_forward=None,
534
+ conditioning_key=None,
535
+ scale_factor=1.0,
536
+ scale_by_std=False,
537
+ force_null_conditioning=False,
538
+ *args, **kwargs):
539
+ self.force_null_conditioning = force_null_conditioning
540
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
541
+ self.scale_by_std = scale_by_std
542
+ assert self.num_timesteps_cond <= kwargs['timesteps']
543
+ # for backwards compatibility after implementation of DiffusionWrapper
544
+ if conditioning_key is None:
545
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
546
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
547
+ conditioning_key = None
548
+ ckpt_path = kwargs.pop("ckpt_path", None)
549
+ reset_ema = kwargs.pop("reset_ema", False)
550
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
551
+ ignore_keys = kwargs.pop("ignore_keys", [])
552
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
553
+ self.concat_mode = concat_mode
554
+ self.cond_stage_trainable = cond_stage_trainable
555
+ self.cond_stage_key = cond_stage_key
556
+ try:
557
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
558
+ except:
559
+ self.num_downs = 0
560
+ if not scale_by_std:
561
+ self.scale_factor = scale_factor
562
+ else:
563
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
564
+ self.instantiate_first_stage(first_stage_config)
565
+ self.instantiate_cond_stage(cond_stage_config)
566
+ self.cond_stage_forward = cond_stage_forward
567
+ self.clip_denoised = False
568
+ self.bbox_tokenizer = None
569
+
570
+ self.restarted_from_ckpt = False
571
+ if ckpt_path is not None:
572
+ self.init_from_ckpt(ckpt_path, ignore_keys)
573
+ self.restarted_from_ckpt = True
574
+ if reset_ema:
575
+ assert self.use_ema
576
+ print(
577
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
578
+ self.model_ema = LitEma(self.model)
579
+ if reset_num_ema_updates:
580
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
581
+ assert self.use_ema
582
+ self.model_ema.reset_num_updates()
583
+
584
+ def make_cond_schedule(self, ):
585
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
586
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
587
+ self.cond_ids[:self.num_timesteps_cond] = ids
588
+
589
+ @rank_zero_only
590
+ @torch.no_grad()
591
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
592
+ # only for very first batch
593
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
594
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
595
+ # set rescale weight to 1./std of encodings
596
+ print("### USING STD-RESCALING ###")
597
+ x = super().get_input(batch, self.first_stage_key)
598
+ x = x.to(self.device)
599
+ encoder_posterior = self.encode_first_stage(x)
600
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
601
+ del self.scale_factor
602
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
603
+ print(f"setting self.scale_factor to {self.scale_factor}")
604
+ print("### USING STD-RESCALING ###")
605
+
606
+ def register_schedule(self,
607
+ given_betas=None, beta_schedule="linear", timesteps=1000,
608
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
609
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
610
+
611
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
612
+ if self.shorten_cond_schedule:
613
+ self.make_cond_schedule()
614
+
615
+ def instantiate_first_stage(self, config):
616
+ model = instantiate_from_config(config)
617
+ self.first_stage_model = model.eval()
618
+ self.first_stage_model.train = disabled_train
619
+ for param in self.first_stage_model.parameters():
620
+ param.requires_grad = False
621
+
622
+ def instantiate_cond_stage(self, config):
623
+ if not self.cond_stage_trainable:
624
+ if config == "__is_first_stage__":
625
+ print("Using first stage also as cond stage.")
626
+ self.cond_stage_model = self.first_stage_model
627
+ elif config == "__is_unconditional__":
628
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
629
+ self.cond_stage_model = None
630
+ # self.be_unconditional = True
631
+ else:
632
+ model = instantiate_from_config(config)
633
+ self.cond_stage_model = model.eval()
634
+ self.cond_stage_model.train = disabled_train
635
+ for param in self.cond_stage_model.parameters():
636
+ param.requires_grad = False
637
+ else:
638
+ assert config != '__is_first_stage__'
639
+ assert config != '__is_unconditional__'
640
+ model = instantiate_from_config(config)
641
+ self.cond_stage_model = model
642
+
643
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
644
+ denoise_row = []
645
+ for zd in tqdm(samples, desc=desc):
646
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
647
+ force_not_quantize=force_no_decoder_quantization))
648
+ n_imgs_per_row = len(denoise_row)
649
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
650
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
651
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
652
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
653
+ return denoise_grid
654
+
655
+ def get_first_stage_encoding(self, encoder_posterior):
656
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
657
+ z = encoder_posterior.sample()
658
+ elif isinstance(encoder_posterior, torch.Tensor):
659
+ z = encoder_posterior
660
+ else:
661
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
662
+ return self.scale_factor * z
663
+
664
+ def get_learned_conditioning(self, c):
665
+ if self.cond_stage_forward is None:
666
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
667
+ c = self.cond_stage_model.encode(c)
668
+ if isinstance(c, DiagonalGaussianDistribution):
669
+ c = c.mode()
670
+ else:
671
+ c = self.cond_stage_model(c)
672
+ else:
673
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
674
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
675
+ return c
676
+
677
+ def meshgrid(self, h, w):
678
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
679
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
680
+
681
+ arr = torch.cat([y, x], dim=-1)
682
+ return arr
683
+
684
+ def delta_border(self, h, w):
685
+ """
686
+ :param h: height
687
+ :param w: width
688
+ :return: normalized distance to image border,
689
+ wtith min distance = 0 at border and max dist = 0.5 at image center
690
+ """
691
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
692
+ arr = self.meshgrid(h, w) / lower_right_corner
693
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
694
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
695
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
696
+ return edge_dist
697
+
698
+ def get_weighting(self, h, w, Ly, Lx, device):
699
+ weighting = self.delta_border(h, w)
700
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
701
+ self.split_input_params["clip_max_weight"], )
702
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
703
+
704
+ if self.split_input_params["tie_braker"]:
705
+ L_weighting = self.delta_border(Ly, Lx)
706
+ L_weighting = torch.clip(L_weighting,
707
+ self.split_input_params["clip_min_tie_weight"],
708
+ self.split_input_params["clip_max_tie_weight"])
709
+
710
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
711
+ weighting = weighting * L_weighting
712
+ return weighting
713
+
714
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
715
+ """
716
+ :param x: img of size (bs, c, h, w)
717
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
718
+ """
719
+ bs, nc, h, w = x.shape
720
+
721
+ # number of crops in image
722
+ Ly = (h - kernel_size[0]) // stride[0] + 1
723
+ Lx = (w - kernel_size[1]) // stride[1] + 1
724
+
725
+ if uf == 1 and df == 1:
726
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
727
+ unfold = torch.nn.Unfold(**fold_params)
728
+
729
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
730
+
731
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
732
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
733
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
734
+
735
+ elif uf > 1 and df == 1:
736
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
737
+ unfold = torch.nn.Unfold(**fold_params)
738
+
739
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
740
+ dilation=1, padding=0,
741
+ stride=(stride[0] * uf, stride[1] * uf))
742
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
743
+
744
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
745
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
746
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
747
+
748
+ elif df > 1 and uf == 1:
749
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
750
+ unfold = torch.nn.Unfold(**fold_params)
751
+
752
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
753
+ dilation=1, padding=0,
754
+ stride=(stride[0] // df, stride[1] // df))
755
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
756
+
757
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
758
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
759
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
760
+
761
+ else:
762
+ raise NotImplementedError
763
+
764
+ return fold, unfold, normalization, weighting
765
+
766
+ @torch.no_grad()
767
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
768
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
769
+ x = super().get_input(batch, k)
770
+ if bs is not None:
771
+ x = x[:bs]
772
+ x = x.to(self.device)
773
+ encoder_posterior = self.encode_first_stage(x)
774
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
775
+
776
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
777
+ if cond_key is None:
778
+ cond_key = self.cond_stage_key
779
+ if cond_key != self.first_stage_key:
780
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
781
+ xc = batch[cond_key]
782
+ elif cond_key in ['class_label', 'cls']:
783
+ xc = batch
784
+ else:
785
+ xc = super().get_input(batch, cond_key).to(self.device)
786
+ else:
787
+ xc = x
788
+ if not self.cond_stage_trainable or force_c_encode:
789
+ if isinstance(xc, dict) or isinstance(xc, list):
790
+ c = self.get_learned_conditioning(xc)
791
+ else:
792
+ c = self.get_learned_conditioning(xc.to(self.device))
793
+ else:
794
+ c = xc
795
+ if bs is not None:
796
+ c = c[:bs]
797
+
798
+ if self.use_positional_encodings:
799
+ pos_x, pos_y = self.compute_latent_shifts(batch)
800
+ ckey = __conditioning_keys__[self.model.conditioning_key]
801
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
802
+
803
+ else:
804
+ c = None
805
+ xc = None
806
+ if self.use_positional_encodings:
807
+ pos_x, pos_y = self.compute_latent_shifts(batch)
808
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
809
+ out = [z, c]
810
+ if return_first_stage_outputs:
811
+ xrec = self.decode_first_stage(z)
812
+ out.extend([x, xrec])
813
+ if return_x:
814
+ out.extend([x])
815
+ if return_original_cond:
816
+ out.append(xc)
817
+ return out
818
+
819
+ @torch.no_grad()
820
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
821
+ if predict_cids:
822
+ if z.dim() == 4:
823
+ z = torch.argmax(z.exp(), dim=1).long()
824
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
825
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
826
+
827
+ z = 1. / self.scale_factor * z
828
+ return self.first_stage_model.decode(z)
829
+
830
+ @torch.no_grad()
831
+ def encode_first_stage(self, x):
832
+ return self.first_stage_model.encode(x)
833
+
834
+ def shared_step(self, batch, **kwargs):
835
+ x, c = self.get_input(batch, self.first_stage_key)
836
+ loss = self(x, c)
837
+ return loss
838
+
839
+ def forward(self, x, c, *args, **kwargs):
840
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
841
+ if self.model.conditioning_key is not None:
842
+ assert c is not None
843
+ if self.cond_stage_trainable:
844
+ c = self.get_learned_conditioning(c)
845
+ if self.shorten_cond_schedule: # TODO: drop this option
846
+ tc = self.cond_ids[t].to(self.device)
847
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
848
+ return self.p_losses(x, c, t, *args, **kwargs)
849
+
850
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
851
+ if isinstance(cond, dict):
852
+ # hybrid case, cond is expected to be a dict
853
+ pass
854
+ else:
855
+ if not isinstance(cond, list):
856
+ cond = [cond]
857
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
858
+ cond = {key: cond}
859
+
860
+ x_recon = self.model(x_noisy, t, **cond)
861
+
862
+ if isinstance(x_recon, tuple) and not return_ids:
863
+ return x_recon[0]
864
+ else:
865
+ return x_recon
866
+
867
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
868
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
869
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
870
+
871
+ def _prior_bpd(self, x_start):
872
+ """
873
+ Get the prior KL term for the variational lower-bound, measured in
874
+ bits-per-dim.
875
+ This term can't be optimized, as it only depends on the encoder.
876
+ :param x_start: the [N x C x ...] tensor of inputs.
877
+ :return: a batch of [N] KL values (in bits), one per batch element.
878
+ """
879
+ batch_size = x_start.shape[0]
880
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
881
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
882
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
883
+ return mean_flat(kl_prior) / np.log(2.0)
884
+
885
+ def p_losses(self, x_start, cond, t, noise=None):
886
+ noise = default(noise, lambda: torch.randn_like(x_start))
887
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
888
+ model_output = self.apply_model(x_noisy, t, cond)
889
+
890
+ loss_dict = {}
891
+ prefix = 'train' if self.training else 'val'
892
+
893
+ if self.parameterization == "x0":
894
+ target = x_start
895
+ elif self.parameterization == "eps":
896
+ target = noise
897
+ elif self.parameterization == "v":
898
+ target = self.get_v(x_start, noise, t)
899
+ else:
900
+ raise NotImplementedError()
901
+
902
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
903
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
904
+
905
+ logvar_t = self.logvar[t].to(self.device)
906
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
907
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
908
+ if self.learn_logvar:
909
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
910
+ loss_dict.update({'logvar': self.logvar.data.mean()})
911
+
912
+ loss = self.l_simple_weight * loss.mean()
913
+
914
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
915
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
916
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
917
+ loss += (self.original_elbo_weight * loss_vlb)
918
+ loss_dict.update({f'{prefix}/loss': loss})
919
+
920
+ return loss, loss_dict
921
+
922
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
923
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
924
+ t_in = t
925
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
926
+
927
+ if score_corrector is not None:
928
+ assert self.parameterization == "eps"
929
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
930
+
931
+ if return_codebook_ids:
932
+ model_out, logits = model_out
933
+
934
+ if self.parameterization == "eps":
935
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
936
+ elif self.parameterization == "x0":
937
+ x_recon = model_out
938
+ else:
939
+ raise NotImplementedError()
940
+
941
+ if clip_denoised:
942
+ x_recon.clamp_(-1., 1.)
943
+ if quantize_denoised:
944
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
945
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
946
+ if return_codebook_ids:
947
+ return model_mean, posterior_variance, posterior_log_variance, logits
948
+ elif return_x0:
949
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
950
+ else:
951
+ return model_mean, posterior_variance, posterior_log_variance
952
+
953
+ @torch.no_grad()
954
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
955
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
956
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
957
+ b, *_, device = *x.shape, x.device
958
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
959
+ return_codebook_ids=return_codebook_ids,
960
+ quantize_denoised=quantize_denoised,
961
+ return_x0=return_x0,
962
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
963
+ if return_codebook_ids:
964
+ raise DeprecationWarning("Support dropped.")
965
+ model_mean, _, model_log_variance, logits = outputs
966
+ elif return_x0:
967
+ model_mean, _, model_log_variance, x0 = outputs
968
+ else:
969
+ model_mean, _, model_log_variance = outputs
970
+
971
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
972
+ if noise_dropout > 0.:
973
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
974
+ # no noise when t == 0
975
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
976
+
977
+ if return_codebook_ids:
978
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
979
+ if return_x0:
980
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
981
+ else:
982
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
983
+
984
+ @torch.no_grad()
985
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
986
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
987
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
988
+ log_every_t=None):
989
+ if not log_every_t:
990
+ log_every_t = self.log_every_t
991
+ timesteps = self.num_timesteps
992
+ if batch_size is not None:
993
+ b = batch_size if batch_size is not None else shape[0]
994
+ shape = [batch_size] + list(shape)
995
+ else:
996
+ b = batch_size = shape[0]
997
+ if x_T is None:
998
+ img = torch.randn(shape, device=self.device)
999
+ else:
1000
+ img = x_T
1001
+ intermediates = []
1002
+ if cond is not None:
1003
+ if isinstance(cond, dict):
1004
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1005
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1006
+ else:
1007
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1008
+
1009
+ if start_T is not None:
1010
+ timesteps = min(timesteps, start_T)
1011
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1012
+ total=timesteps) if verbose else reversed(
1013
+ range(0, timesteps))
1014
+ if type(temperature) == float:
1015
+ temperature = [temperature] * timesteps
1016
+
1017
+ for i in iterator:
1018
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1019
+ if self.shorten_cond_schedule:
1020
+ assert self.model.conditioning_key != 'hybrid'
1021
+ tc = self.cond_ids[ts].to(cond.device)
1022
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1023
+
1024
+ img, x0_partial = self.p_sample(img, cond, ts,
1025
+ clip_denoised=self.clip_denoised,
1026
+ quantize_denoised=quantize_denoised, return_x0=True,
1027
+ temperature=temperature[i], noise_dropout=noise_dropout,
1028
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1029
+ if mask is not None:
1030
+ assert x0 is not None
1031
+ img_orig = self.q_sample(x0, ts)
1032
+ img = img_orig * mask + (1. - mask) * img
1033
+
1034
+ if i % log_every_t == 0 or i == timesteps - 1:
1035
+ intermediates.append(x0_partial)
1036
+ if callback: callback(i)
1037
+ if img_callback: img_callback(img, i)
1038
+ return img, intermediates
1039
+
1040
+ @torch.no_grad()
1041
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1042
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1043
+ mask=None, x0=None, img_callback=None, start_T=None,
1044
+ log_every_t=None):
1045
+
1046
+ if not log_every_t:
1047
+ log_every_t = self.log_every_t
1048
+ device = self.betas.device
1049
+ b = shape[0]
1050
+ if x_T is None:
1051
+ img = torch.randn(shape, device=device)
1052
+ else:
1053
+ img = x_T
1054
+
1055
+ intermediates = [img]
1056
+ if timesteps is None:
1057
+ timesteps = self.num_timesteps
1058
+
1059
+ if start_T is not None:
1060
+ timesteps = min(timesteps, start_T)
1061
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1062
+ range(0, timesteps))
1063
+
1064
+ if mask is not None:
1065
+ assert x0 is not None
1066
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1067
+
1068
+ for i in iterator:
1069
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1070
+ if self.shorten_cond_schedule:
1071
+ assert self.model.conditioning_key != 'hybrid'
1072
+ tc = self.cond_ids[ts].to(cond.device)
1073
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1074
+
1075
+ img = self.p_sample(img, cond, ts,
1076
+ clip_denoised=self.clip_denoised,
1077
+ quantize_denoised=quantize_denoised)
1078
+ if mask is not None:
1079
+ img_orig = self.q_sample(x0, ts)
1080
+ img = img_orig * mask + (1. - mask) * img
1081
+
1082
+ if i % log_every_t == 0 or i == timesteps - 1:
1083
+ intermediates.append(img)
1084
+ if callback: callback(i)
1085
+ if img_callback: img_callback(img, i)
1086
+
1087
+ if return_intermediates:
1088
+ return img, intermediates
1089
+ return img
1090
+
1091
+ @torch.no_grad()
1092
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1093
+ verbose=True, timesteps=None, quantize_denoised=False,
1094
+ mask=None, x0=None, shape=None, **kwargs):
1095
+ if shape is None:
1096
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1097
+ if cond is not None:
1098
+ if isinstance(cond, dict):
1099
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1100
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1101
+ else:
1102
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1103
+ return self.p_sample_loop(cond,
1104
+ shape,
1105
+ return_intermediates=return_intermediates, x_T=x_T,
1106
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1107
+ mask=mask, x0=x0)
1108
+
1109
+ @torch.no_grad()
1110
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1111
+ if ddim:
1112
+ ddim_sampler = DDIMSampler(self)
1113
+ shape = (self.channels, self.image_size, self.image_size)
1114
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
1115
+ shape, cond, verbose=False, **kwargs)
1116
+
1117
+ else:
1118
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1119
+ return_intermediates=True, **kwargs)
1120
+
1121
+ return samples, intermediates
1122
+
1123
+ @torch.no_grad()
1124
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
1125
+ if null_label is not None:
1126
+ xc = null_label
1127
+ if isinstance(xc, ListConfig):
1128
+ xc = list(xc)
1129
+ if isinstance(xc, dict) or isinstance(xc, list):
1130
+ c = self.get_learned_conditioning(xc)
1131
+ else:
1132
+ if hasattr(xc, "to"):
1133
+ xc = xc.to(self.device)
1134
+ c = self.get_learned_conditioning(xc)
1135
+ else:
1136
+ if self.cond_stage_key in ["class_label", "cls"]:
1137
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
1138
+ return self.get_learned_conditioning(xc)
1139
+ else:
1140
+ raise NotImplementedError("todo")
1141
+ if isinstance(c, list): # in case the encoder gives us a list
1142
+ for i in range(len(c)):
1143
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
1144
+ else:
1145
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
1146
+ return c
1147
+
1148
+ @torch.no_grad()
1149
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1150
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1151
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1152
+ use_ema_scope=True,
1153
+ **kwargs):
1154
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1155
+ use_ddim = ddim_steps is not None
1156
+
1157
+ log = dict()
1158
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1159
+ return_first_stage_outputs=True,
1160
+ force_c_encode=True,
1161
+ return_original_cond=True,
1162
+ bs=N)
1163
+ N = min(x.shape[0], N)
1164
+ n_row = min(x.shape[0], n_row)
1165
+ log["inputs"] = x
1166
+ log["reconstruction"] = xrec
1167
+ if self.model.conditioning_key is not None:
1168
+ if hasattr(self.cond_stage_model, "decode"):
1169
+ xc = self.cond_stage_model.decode(c)
1170
+ log["conditioning"] = xc
1171
+ elif self.cond_stage_key in ["caption", "txt"]:
1172
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1173
+ log["conditioning"] = xc
1174
+ elif self.cond_stage_key in ['class_label', "cls"]:
1175
+ try:
1176
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1177
+ log['conditioning'] = xc
1178
+ except KeyError:
1179
+ # probably no "human_label" in batch
1180
+ pass
1181
+ elif isimage(xc):
1182
+ log["conditioning"] = xc
1183
+ if ismap(xc):
1184
+ log["original_conditioning"] = self.to_rgb(xc)
1185
+
1186
+ if plot_diffusion_rows:
1187
+ # get diffusion row
1188
+ diffusion_row = list()
1189
+ z_start = z[:n_row]
1190
+ for t in range(self.num_timesteps):
1191
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1192
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1193
+ t = t.to(self.device).long()
1194
+ noise = torch.randn_like(z_start)
1195
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1196
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1197
+
1198
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1199
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1200
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1201
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1202
+ log["diffusion_row"] = diffusion_grid
1203
+
1204
+ if sample:
1205
+ # get denoise row
1206
+ with ema_scope("Sampling"):
1207
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1208
+ ddim_steps=ddim_steps, eta=ddim_eta)
1209
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1210
+ x_samples = self.decode_first_stage(samples)
1211
+ log["samples"] = x_samples
1212
+ if plot_denoise_rows:
1213
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1214
+ log["denoise_row"] = denoise_grid
1215
+
1216
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1217
+ self.first_stage_model, IdentityFirstStage):
1218
+ # also display when quantizing x0 while sampling
1219
+ with ema_scope("Plotting Quantized Denoised"):
1220
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1221
+ ddim_steps=ddim_steps, eta=ddim_eta,
1222
+ quantize_denoised=True)
1223
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1224
+ # quantize_denoised=True)
1225
+ x_samples = self.decode_first_stage(samples.to(self.device))
1226
+ log["samples_x0_quantized"] = x_samples
1227
+
1228
+ if unconditional_guidance_scale > 1.0:
1229
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1230
+ if self.model.conditioning_key == "crossattn-adm":
1231
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1232
+ with ema_scope("Sampling with classifier-free guidance"):
1233
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1234
+ ddim_steps=ddim_steps, eta=ddim_eta,
1235
+ unconditional_guidance_scale=unconditional_guidance_scale,
1236
+ unconditional_conditioning=uc,
1237
+ )
1238
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1239
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1240
+
1241
+ if inpaint:
1242
+ # make a simple center square
1243
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1244
+ mask = torch.ones(N, h, w).to(self.device)
1245
+ # zeros will be filled in
1246
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1247
+ mask = mask[:, None, ...]
1248
+ with ema_scope("Plotting Inpaint"):
1249
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1250
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1251
+ x_samples = self.decode_first_stage(samples.to(self.device))
1252
+ log["samples_inpainting"] = x_samples
1253
+ log["mask"] = mask
1254
+
1255
+ # outpaint
1256
+ mask = 1. - mask
1257
+ with ema_scope("Plotting Outpaint"):
1258
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1259
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1260
+ x_samples = self.decode_first_stage(samples.to(self.device))
1261
+ log["samples_outpainting"] = x_samples
1262
+
1263
+ if plot_progressive_rows:
1264
+ with ema_scope("Plotting Progressives"):
1265
+ img, progressives = self.progressive_denoising(c,
1266
+ shape=(self.channels, self.image_size, self.image_size),
1267
+ batch_size=N)
1268
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1269
+ log["progressive_row"] = prog_row
1270
+
1271
+ if return_keys:
1272
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1273
+ return log
1274
+ else:
1275
+ return {key: log[key] for key in return_keys}
1276
+ return log
1277
+
1278
+ def configure_optimizers(self):
1279
+ lr = self.learning_rate
1280
+ params = list(self.model.parameters())
1281
+ if self.cond_stage_trainable:
1282
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1283
+ params = params + list(self.cond_stage_model.parameters())
1284
+ if self.learn_logvar:
1285
+ print('Diffusion model optimizing logvar')
1286
+ params.append(self.logvar)
1287
+ opt = torch.optim.AdamW(params, lr=lr)
1288
+ if self.use_scheduler:
1289
+ assert 'target' in self.scheduler_config
1290
+ scheduler = instantiate_from_config(self.scheduler_config)
1291
+
1292
+ print("Setting up LambdaLR scheduler...")
1293
+ scheduler = [
1294
+ {
1295
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1296
+ 'interval': 'step',
1297
+ 'frequency': 1
1298
+ }]
1299
+ return [opt], scheduler
1300
+ return opt
1301
+
1302
+ @torch.no_grad()
1303
+ def to_rgb(self, x):
1304
+ x = x.float()
1305
+ if not hasattr(self, "colorize"):
1306
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1307
+ x = nn.functional.conv2d(x, weight=self.colorize)
1308
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1309
+ return x
1310
+
1311
+
1312
+ class DiffusionWrapper(pl.LightningModule):
1313
+ def __init__(self, diff_model_config, conditioning_key):
1314
+ super().__init__()
1315
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
1316
+ self.diffusion_model = instantiate_from_config(diff_model_config)
1317
+ self.conditioning_key = conditioning_key
1318
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1319
+
1320
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
1321
+ if self.conditioning_key is None:
1322
+ out = self.diffusion_model(x, t)
1323
+ elif self.conditioning_key == 'concat':
1324
+ xc = torch.cat([x] + c_concat, dim=1)
1325
+ out = self.diffusion_model(xc, t)
1326
+ elif self.conditioning_key == 'crossattn':
1327
+ if not self.sequential_cross_attn:
1328
+ cc = torch.cat(c_crossattn, 1)
1329
+ else:
1330
+ cc = c_crossattn
1331
+ out = self.diffusion_model(x, t, context=cc)
1332
+ elif self.conditioning_key == 'hybrid':
1333
+ xc = torch.cat([x] + c_concat, dim=1)
1334
+ cc = torch.cat(c_crossattn, 1)
1335
+ out = self.diffusion_model(xc, t, context=cc)
1336
+ elif self.conditioning_key == 'hybrid-adm':
1337
+ assert c_adm is not None
1338
+ xc = torch.cat([x] + c_concat, dim=1)
1339
+ cc = torch.cat(c_crossattn, 1)
1340
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1341
+ elif self.conditioning_key == 'crossattn-adm':
1342
+ assert c_adm is not None
1343
+ cc = torch.cat(c_crossattn, 1)
1344
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
1345
+ elif self.conditioning_key == 'adm':
1346
+ cc = c_crossattn[0]
1347
+ out = self.diffusion_model(x, t, y=cc)
1348
+ else:
1349
+ raise NotImplementedError()
1350
+
1351
+ return out
1352
+
1353
+
1354
+ class LatentUpscaleDiffusion(LatentDiffusion):
1355
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
1356
+ super().__init__(*args, **kwargs)
1357
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1358
+ assert not self.cond_stage_trainable
1359
+ self.instantiate_low_stage(low_scale_config)
1360
+ self.low_scale_key = low_scale_key
1361
+ self.noise_level_key = noise_level_key
1362
+
1363
+ def instantiate_low_stage(self, config):
1364
+ model = instantiate_from_config(config)
1365
+ self.low_scale_model = model.eval()
1366
+ self.low_scale_model.train = disabled_train
1367
+ for param in self.low_scale_model.parameters():
1368
+ param.requires_grad = False
1369
+
1370
+ @torch.no_grad()
1371
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1372
+ if not log_mode:
1373
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1374
+ else:
1375
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1376
+ force_c_encode=True, return_original_cond=True, bs=bs)
1377
+ x_low = batch[self.low_scale_key][:bs]
1378
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
1379
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
1380
+ zx, noise_level = self.low_scale_model(x_low)
1381
+ if self.noise_level_key is not None:
1382
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1383
+ raise NotImplementedError('TODO')
1384
+
1385
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1386
+ if log_mode:
1387
+ # TODO: maybe disable if too expensive
1388
+ x_low_rec = self.low_scale_model.decode(zx)
1389
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1390
+ return z, all_conds
1391
+
1392
+ @torch.no_grad()
1393
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1394
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
1395
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
1396
+ **kwargs):
1397
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1398
+ use_ddim = ddim_steps is not None
1399
+
1400
+ log = dict()
1401
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
1402
+ log_mode=True)
1403
+ N = min(x.shape[0], N)
1404
+ n_row = min(x.shape[0], n_row)
1405
+ log["inputs"] = x
1406
+ log["reconstruction"] = xrec
1407
+ log["x_lr"] = x_low
1408
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
1409
+ if self.model.conditioning_key is not None:
1410
+ if hasattr(self.cond_stage_model, "decode"):
1411
+ xc = self.cond_stage_model.decode(c)
1412
+ log["conditioning"] = xc
1413
+ elif self.cond_stage_key in ["caption", "txt"]:
1414
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1415
+ log["conditioning"] = xc
1416
+ elif self.cond_stage_key in ['class_label', 'cls']:
1417
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1418
+ log['conditioning'] = xc
1419
+ elif isimage(xc):
1420
+ log["conditioning"] = xc
1421
+ if ismap(xc):
1422
+ log["original_conditioning"] = self.to_rgb(xc)
1423
+
1424
+ if plot_diffusion_rows:
1425
+ # get diffusion row
1426
+ diffusion_row = list()
1427
+ z_start = z[:n_row]
1428
+ for t in range(self.num_timesteps):
1429
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1430
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1431
+ t = t.to(self.device).long()
1432
+ noise = torch.randn_like(z_start)
1433
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1434
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1435
+
1436
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1437
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1438
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1439
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1440
+ log["diffusion_row"] = diffusion_grid
1441
+
1442
+ if sample:
1443
+ # get denoise row
1444
+ with ema_scope("Sampling"):
1445
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1446
+ ddim_steps=ddim_steps, eta=ddim_eta)
1447
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1448
+ x_samples = self.decode_first_stage(samples)
1449
+ log["samples"] = x_samples
1450
+ if plot_denoise_rows:
1451
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1452
+ log["denoise_row"] = denoise_grid
1453
+
1454
+ if unconditional_guidance_scale > 1.0:
1455
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1456
+ # TODO explore better "unconditional" choices for the other keys
1457
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1458
+ uc = dict()
1459
+ for k in c:
1460
+ if k == "c_crossattn":
1461
+ assert isinstance(c[k], list) and len(c[k]) == 1
1462
+ uc[k] = [uc_tmp]
1463
+ elif k == "c_adm": # todo: only run with text-based guidance?
1464
+ assert isinstance(c[k], torch.Tensor)
1465
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1466
+ uc[k] = c[k]
1467
+ elif isinstance(c[k], list):
1468
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
1469
+ else:
1470
+ uc[k] = c[k]
1471
+
1472
+ with ema_scope("Sampling with classifier-free guidance"):
1473
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1474
+ ddim_steps=ddim_steps, eta=ddim_eta,
1475
+ unconditional_guidance_scale=unconditional_guidance_scale,
1476
+ unconditional_conditioning=uc,
1477
+ )
1478
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1479
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1480
+
1481
+ if plot_progressive_rows:
1482
+ with ema_scope("Plotting Progressives"):
1483
+ img, progressives = self.progressive_denoising(c,
1484
+ shape=(self.channels, self.image_size, self.image_size),
1485
+ batch_size=N)
1486
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1487
+ log["progressive_row"] = prog_row
1488
+
1489
+ return log
1490
+
1491
+
1492
+ class LatentFinetuneDiffusion(LatentDiffusion):
1493
+ """
1494
+ Basis for different finetunas, such as inpainting or depth2image
1495
+ To disable finetuning mode, set finetune_keys to None
1496
+ """
1497
+
1498
+ def __init__(self,
1499
+ concat_keys: tuple,
1500
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
1501
+ "model_ema.diffusion_modelinput_blocks00weight"
1502
+ ),
1503
+ keep_finetune_dims=4,
1504
+ # if model was trained without concat mode before and we would like to keep these channels
1505
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
1506
+ c_concat_log_end=None,
1507
+ *args, **kwargs
1508
+ ):
1509
+ ckpt_path = kwargs.pop("ckpt_path", None)
1510
+ ignore_keys = kwargs.pop("ignore_keys", list())
1511
+ super().__init__(*args, **kwargs)
1512
+ self.finetune_keys = finetune_keys
1513
+ self.concat_keys = concat_keys
1514
+ self.keep_dims = keep_finetune_dims
1515
+ self.c_concat_log_start = c_concat_log_start
1516
+ self.c_concat_log_end = c_concat_log_end
1517
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
1518
+ if exists(ckpt_path):
1519
+ self.init_from_ckpt(ckpt_path, ignore_keys)
1520
+
1521
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
1522
+ sd = torch.load(path, map_location="cpu")
1523
+ if "state_dict" in list(sd.keys()):
1524
+ sd = sd["state_dict"]
1525
+ keys = list(sd.keys())
1526
+ for k in keys:
1527
+ for ik in ignore_keys:
1528
+ if k.startswith(ik):
1529
+ print("Deleting key {} from state_dict.".format(k))
1530
+ del sd[k]
1531
+
1532
+ # make it explicit, finetune by including extra input channels
1533
+ if exists(self.finetune_keys) and k in self.finetune_keys:
1534
+ new_entry = None
1535
+ for name, param in self.named_parameters():
1536
+ if name in self.finetune_keys:
1537
+ print(
1538
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
1539
+ new_entry = torch.zeros_like(param) # zero init
1540
+ assert exists(new_entry), 'did not find matching parameter to modify'
1541
+ new_entry[:, :self.keep_dims, ...] = sd[k]
1542
+ sd[k] = new_entry
1543
+
1544
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
1545
+ sd, strict=False)
1546
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
1547
+ if len(missing) > 0:
1548
+ print(f"Missing Keys: {missing}")
1549
+ if len(unexpected) > 0:
1550
+ print(f"Unexpected Keys: {unexpected}")
1551
+
1552
+ @torch.no_grad()
1553
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1554
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1555
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1556
+ use_ema_scope=True,
1557
+ **kwargs):
1558
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
1559
+ use_ddim = ddim_steps is not None
1560
+
1561
+ log = dict()
1562
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
1563
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
1564
+ N = min(x.shape[0], N)
1565
+ n_row = min(x.shape[0], n_row)
1566
+ log["inputs"] = x
1567
+ log["reconstruction"] = xrec
1568
+ if self.model.conditioning_key is not None:
1569
+ if hasattr(self.cond_stage_model, "decode"):
1570
+ xc = self.cond_stage_model.decode(c)
1571
+ log["conditioning"] = xc
1572
+ elif self.cond_stage_key in ["caption", "txt"]:
1573
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1574
+ log["conditioning"] = xc
1575
+ elif self.cond_stage_key in ['class_label', 'cls']:
1576
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1577
+ log['conditioning'] = xc
1578
+ elif isimage(xc):
1579
+ log["conditioning"] = xc
1580
+ if ismap(xc):
1581
+ log["original_conditioning"] = self.to_rgb(xc)
1582
+
1583
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
1584
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
1585
+
1586
+ if plot_diffusion_rows:
1587
+ # get diffusion row
1588
+ diffusion_row = list()
1589
+ z_start = z[:n_row]
1590
+ for t in range(self.num_timesteps):
1591
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1592
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1593
+ t = t.to(self.device).long()
1594
+ noise = torch.randn_like(z_start)
1595
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1596
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1597
+
1598
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1599
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1600
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1601
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1602
+ log["diffusion_row"] = diffusion_grid
1603
+
1604
+ if sample:
1605
+ # get denoise row
1606
+ with ema_scope("Sampling"):
1607
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1608
+ batch_size=N, ddim=use_ddim,
1609
+ ddim_steps=ddim_steps, eta=ddim_eta)
1610
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1611
+ x_samples = self.decode_first_stage(samples)
1612
+ log["samples"] = x_samples
1613
+ if plot_denoise_rows:
1614
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1615
+ log["denoise_row"] = denoise_grid
1616
+
1617
+ if unconditional_guidance_scale > 1.0:
1618
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1619
+ uc_cat = c_cat
1620
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
1621
+ with ema_scope("Sampling with classifier-free guidance"):
1622
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1623
+ batch_size=N, ddim=use_ddim,
1624
+ ddim_steps=ddim_steps, eta=ddim_eta,
1625
+ unconditional_guidance_scale=unconditional_guidance_scale,
1626
+ unconditional_conditioning=uc_full,
1627
+ )
1628
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
1629
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1630
+
1631
+ return log
1632
+
1633
+
1634
+ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
1635
+ """
1636
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
1637
+ e.g. mask as concat and text via cross-attn.
1638
+ To disable finetuning mode, set finetune_keys to None
1639
+ """
1640
+
1641
+ def __init__(self,
1642
+ concat_keys=("mask", "masked_image"),
1643
+ masked_image_key="masked_image",
1644
+ *args, **kwargs
1645
+ ):
1646
+ super().__init__(concat_keys, *args, **kwargs)
1647
+ self.masked_image_key = masked_image_key
1648
+ assert self.masked_image_key in concat_keys
1649
+
1650
+ @torch.no_grad()
1651
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1652
+ # note: restricted to non-trainable encoders currently
1653
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
1654
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1655
+ force_c_encode=True, return_original_cond=True, bs=bs)
1656
+
1657
+ assert exists(self.concat_keys)
1658
+ c_cat = list()
1659
+ for ck in self.concat_keys:
1660
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1661
+ if bs is not None:
1662
+ cc = cc[:bs]
1663
+ cc = cc.to(self.device)
1664
+ bchw = z.shape
1665
+ if ck != self.masked_image_key:
1666
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1667
+ else:
1668
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1669
+ c_cat.append(cc)
1670
+ c_cat = torch.cat(c_cat, dim=1)
1671
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1672
+ if return_first_stage_outputs:
1673
+ return z, all_conds, x, xrec, xc
1674
+ return z, all_conds
1675
+
1676
+ @torch.no_grad()
1677
+ def log_images(self, *args, **kwargs):
1678
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
1679
+ log["masked_image"] = rearrange(args[0]["masked_image"],
1680
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1681
+ return log
1682
+
1683
+
1684
+ class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
1685
+ """
1686
+ condition on monocular depth estimation
1687
+ """
1688
+
1689
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
1690
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
1691
+ self.depth_model = instantiate_from_config(depth_stage_config)
1692
+ self.depth_stage_key = concat_keys[0]
1693
+
1694
+ @torch.no_grad()
1695
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1696
+ # note: restricted to non-trainable encoders currently
1697
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
1698
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1699
+ force_c_encode=True, return_original_cond=True, bs=bs)
1700
+
1701
+ assert exists(self.concat_keys)
1702
+ assert len(self.concat_keys) == 1
1703
+ c_cat = list()
1704
+ for ck in self.concat_keys:
1705
+ cc = batch[ck]
1706
+ if bs is not None:
1707
+ cc = cc[:bs]
1708
+ cc = cc.to(self.device)
1709
+ cc = self.depth_model(cc)
1710
+ cc = torch.nn.functional.interpolate(
1711
+ cc,
1712
+ size=z.shape[2:],
1713
+ mode="bicubic",
1714
+ align_corners=False,
1715
+ )
1716
+
1717
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
1718
+ keepdim=True)
1719
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
1720
+ c_cat.append(cc)
1721
+ c_cat = torch.cat(c_cat, dim=1)
1722
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1723
+ if return_first_stage_outputs:
1724
+ return z, all_conds, x, xrec, xc
1725
+ return z, all_conds
1726
+
1727
+ @torch.no_grad()
1728
+ def log_images(self, *args, **kwargs):
1729
+ log = super().log_images(*args, **kwargs)
1730
+ depth = self.depth_model(args[0][self.depth_stage_key])
1731
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
1732
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
1733
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
1734
+ return log
1735
+
1736
+
1737
+ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
1738
+ """
1739
+ condition on low-res image (and optionally on some spatial noise augmentation)
1740
+ """
1741
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
1742
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
1743
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
1744
+ self.reshuffle_patch_size = reshuffle_patch_size
1745
+ self.low_scale_model = None
1746
+ if low_scale_config is not None:
1747
+ print("Initializing a low-scale model")
1748
+ assert exists(low_scale_key)
1749
+ self.instantiate_low_stage(low_scale_config)
1750
+ self.low_scale_key = low_scale_key
1751
+
1752
+ def instantiate_low_stage(self, config):
1753
+ model = instantiate_from_config(config)
1754
+ self.low_scale_model = model.eval()
1755
+ self.low_scale_model.train = disabled_train
1756
+ for param in self.low_scale_model.parameters():
1757
+ param.requires_grad = False
1758
+
1759
+ @torch.no_grad()
1760
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1761
+ # note: restricted to non-trainable encoders currently
1762
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
1763
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1764
+ force_c_encode=True, return_original_cond=True, bs=bs)
1765
+
1766
+ assert exists(self.concat_keys)
1767
+ assert len(self.concat_keys) == 1
1768
+ # optionally make spatial noise_level here
1769
+ c_cat = list()
1770
+ noise_level = None
1771
+ for ck in self.concat_keys:
1772
+ cc = batch[ck]
1773
+ cc = rearrange(cc, 'b h w c -> b c h w')
1774
+ if exists(self.reshuffle_patch_size):
1775
+ assert isinstance(self.reshuffle_patch_size, int)
1776
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
1777
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
1778
+ if bs is not None:
1779
+ cc = cc[:bs]
1780
+ cc = cc.to(self.device)
1781
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
1782
+ cc, noise_level = self.low_scale_model(cc)
1783
+ c_cat.append(cc)
1784
+ c_cat = torch.cat(c_cat, dim=1)
1785
+ if exists(noise_level):
1786
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
1787
+ else:
1788
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1789
+ if return_first_stage_outputs:
1790
+ return z, all_conds, x, xrec, xc
1791
+ return z, all_conds
1792
+
1793
+ @torch.no_grad()
1794
+ def log_images(self, *args, **kwargs):
1795
+ log = super().log_images(*args, **kwargs)
1796
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1797
+ return log
ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from tqdm import tqdm
5
+
6
+
7
+ class NoiseScheduleVP:
8
+ def __init__(
9
+ self,
10
+ schedule='discrete',
11
+ betas=None,
12
+ alphas_cumprod=None,
13
+ continuous_beta_0=0.1,
14
+ continuous_beta_1=20.,
15
+ ):
16
+ """Create a wrapper class for the forward SDE (VP type).
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
+ log_alpha_t = self.marginal_log_mean_coeff(t)
25
+ sigma_t = self.marginal_std(t)
26
+ lambda_t = self.marginal_lambda(t)
27
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
+ t = self.inverse_lambda(lambda_t)
29
+ ===============================================================
30
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
+ 1. For discrete-time DPMs:
32
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
+ t_i = (i + 1) / N
34
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
+ Args:
37
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
+ and
46
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
+ 2. For continuous-time DPMs:
48
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
+ schedule are the default settings in DDPM and improved-DDPM:
50
+ Args:
51
+ beta_min: A `float` number. The smallest beta for the linear schedule.
52
+ beta_max: A `float` number. The largest beta for the linear schedule.
53
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
+ T: A `float` number. The ending time of the forward process.
56
+ ===============================================================
57
+ Args:
58
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
+ 'linear' or 'cosine' for continuous-time DPMs.
60
+ Returns:
61
+ A wrapper object of the forward SDE (VP type).
62
+
63
+ ===============================================================
64
+ Example:
65
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
+ # For continuous-time DPMs (VPSDE), linear schedule:
70
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
+ """
72
+
73
+ if schedule not in ['discrete', 'linear', 'cosine']:
74
+ raise ValueError(
75
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
+ schedule))
77
+
78
+ self.schedule = schedule
79
+ if schedule == 'discrete':
80
+ if betas is not None:
81
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
+ else:
83
+ assert alphas_cumprod is not None
84
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
85
+ self.total_N = len(log_alphas)
86
+ self.T = 1.
87
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
89
+ else:
90
+ self.total_N = 1000
91
+ self.beta_0 = continuous_beta_0
92
+ self.beta_1 = continuous_beta_1
93
+ self.cosine_s = 0.008
94
+ self.cosine_beta_max = 999.
95
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
+ 1. + self.cosine_s) / math.pi - self.cosine_s
97
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
+ self.schedule = schedule
99
+ if schedule == 'cosine':
100
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
102
+ self.T = 0.9946
103
+ else:
104
+ self.T = 1.
105
+
106
+ def marginal_log_mean_coeff(self, t):
107
+ """
108
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
109
+ """
110
+ if self.schedule == 'discrete':
111
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
112
+ self.log_alpha_array.to(t.device)).reshape((-1))
113
+ elif self.schedule == 'linear':
114
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
115
+ elif self.schedule == 'cosine':
116
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
117
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
118
+ return log_alpha_t
119
+
120
+ def marginal_alpha(self, t):
121
+ """
122
+ Compute alpha_t of a given continuous-time label t in [0, T].
123
+ """
124
+ return torch.exp(self.marginal_log_mean_coeff(t))
125
+
126
+ def marginal_std(self, t):
127
+ """
128
+ Compute sigma_t of a given continuous-time label t in [0, T].
129
+ """
130
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
131
+
132
+ def marginal_lambda(self, t):
133
+ """
134
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
135
+ """
136
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
137
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
138
+ return log_mean_coeff - log_std
139
+
140
+ def inverse_lambda(self, lamb):
141
+ """
142
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
143
+ """
144
+ if self.schedule == 'linear':
145
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
146
+ Delta = self.beta_0 ** 2 + tmp
147
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
148
+ elif self.schedule == 'discrete':
149
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
150
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
151
+ torch.flip(self.t_array.to(lamb.device), [1]))
152
+ return t.reshape((-1,))
153
+ else:
154
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
155
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
156
+ 1. + self.cosine_s) / math.pi - self.cosine_s
157
+ t = t_fn(log_alpha)
158
+ return t
159
+
160
+
161
+ def model_wrapper(
162
+ model,
163
+ noise_schedule,
164
+ model_type="noise",
165
+ model_kwargs={},
166
+ guidance_type="uncond",
167
+ condition=None,
168
+ unconditional_condition=None,
169
+ guidance_scale=1.,
170
+ classifier_fn=None,
171
+ classifier_kwargs={},
172
+ ):
173
+ """Create a wrapper function for the noise prediction model.
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
176
+ We support four types of the diffusion model by setting `model_type`:
177
+ 1. "noise": noise prediction model. (Trained by predicting noise).
178
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
+ arXiv preprint arXiv:2202.00512 (2022).
183
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
+ arXiv preprint arXiv:2210.02303 (2022).
185
+
186
+ 4. "score": marginal score function. (Trained by denoising score matching).
187
+ Note that the score function and the noise prediction model follows a simple relationship:
188
+ ```
189
+ noise(x_t, t) = -sigma_t * score(x_t, t)
190
+ ```
191
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
192
+ 1. "uncond": unconditional sampling by DPMs.
193
+ The input `model` has the following format:
194
+ ``
195
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
+ ``
197
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
+ The input `model` has the following format:
199
+ ``
200
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
+ ``
202
+ The input `classifier_fn` has the following format:
203
+ ``
204
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
+ ``
206
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
+ The input `model` has the following format:
210
+ ``
211
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
+ ``
213
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
+ arXiv preprint arXiv:2207.12598 (2022).
216
+
217
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
+ or continuous-time labels (i.e. epsilon to T).
219
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
220
+ ``
221
+ def model_fn(x, t_continuous) -> noise:
222
+ t_input = get_model_input_time(t_continuous)
223
+ return noise_pred(model, x, t_input, **model_kwargs)
224
+ ``
225
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
+ ===============================================================
227
+ Args:
228
+ model: A diffusion model with the corresponding format described above.
229
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
+ model_type: A `str`. The parameterization type of the diffusion model.
231
+ "noise" or "x_start" or "v" or "score".
232
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
+ guidance_type: A `str`. The type of the guidance for sampling.
234
+ "uncond" or "classifier" or "classifier-free".
235
+ condition: A pytorch tensor. The condition for the guided sampling.
236
+ Only used for "classifier" or "classifier-free" guidance type.
237
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
+ Only used for "classifier-free" guidance type.
239
+ guidance_scale: A `float`. The scale for the guided sampling.
240
+ classifier_fn: A classifier function. Only used for the classifier guidance.
241
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
+ Returns:
243
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
+ """
245
+
246
+ def get_model_input_time(t_continuous):
247
+ """
248
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
+ For continuous-time DPMs, we just use `t_continuous`.
251
+ """
252
+ if noise_schedule.schedule == 'discrete':
253
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
+ else:
255
+ return t_continuous
256
+
257
+ def noise_pred_fn(x, t_continuous, cond=None):
258
+ if t_continuous.reshape((-1,)).shape[0] == 1:
259
+ t_continuous = t_continuous.expand((x.shape[0]))
260
+ t_input = get_model_input_time(t_continuous)
261
+ if cond is None:
262
+ output = model(x, t_input, **model_kwargs)
263
+ else:
264
+ output = model(x, t_input, cond, **model_kwargs)
265
+ if model_type == "noise":
266
+ return output
267
+ elif model_type == "x_start":
268
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
+ dims = x.dim()
270
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
+ elif model_type == "v":
272
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
+ dims = x.dim()
274
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
+ elif model_type == "score":
276
+ sigma_t = noise_schedule.marginal_std(t_continuous)
277
+ dims = x.dim()
278
+ return -expand_dims(sigma_t, dims) * output
279
+
280
+ def cond_grad_fn(x, t_input):
281
+ """
282
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
+ """
284
+ with torch.enable_grad():
285
+ x_in = x.detach().requires_grad_(True)
286
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
+
289
+ def model_fn(x, t_continuous):
290
+ """
291
+ The noise predicition model function that is used for DPM-Solver.
292
+ """
293
+ if t_continuous.reshape((-1,)).shape[0] == 1:
294
+ t_continuous = t_continuous.expand((x.shape[0]))
295
+ if guidance_type == "uncond":
296
+ return noise_pred_fn(x, t_continuous)
297
+ elif guidance_type == "classifier":
298
+ assert classifier_fn is not None
299
+ t_input = get_model_input_time(t_continuous)
300
+ cond_grad = cond_grad_fn(x, t_input)
301
+ sigma_t = noise_schedule.marginal_std(t_continuous)
302
+ noise = noise_pred_fn(x, t_continuous)
303
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
+ elif guidance_type == "classifier-free":
305
+ if guidance_scale == 1. or unconditional_condition is None:
306
+ return noise_pred_fn(x, t_continuous, cond=condition)
307
+ else:
308
+ x_in = torch.cat([x] * 2)
309
+ t_in = torch.cat([t_continuous] * 2)
310
+ c_in = torch.cat([unconditional_condition, condition])
311
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
312
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
313
+
314
+ assert model_type in ["noise", "x_start", "v"]
315
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
316
+ return model_fn
317
+
318
+
319
+ class DPM_Solver:
320
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
321
+ """Construct a DPM-Solver.
322
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
323
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
324
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
325
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
326
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
327
+ Args:
328
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
329
+ ``
330
+ def model_fn(x, t_continuous):
331
+ return noise
332
+ ``
333
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
334
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
335
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
336
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
337
+
338
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
339
+ """
340
+ self.model = model_fn
341
+ self.noise_schedule = noise_schedule
342
+ self.predict_x0 = predict_x0
343
+ self.thresholding = thresholding
344
+ self.max_val = max_val
345
+
346
+ def noise_prediction_fn(self, x, t):
347
+ """
348
+ Return the noise prediction model.
349
+ """
350
+ return self.model(x, t)
351
+
352
+ def data_prediction_fn(self, x, t):
353
+ """
354
+ Return the data prediction model (with thresholding).
355
+ """
356
+ noise = self.noise_prediction_fn(x, t)
357
+ dims = x.dim()
358
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
359
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
360
+ if self.thresholding:
361
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
362
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
363
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
364
+ x0 = torch.clamp(x0, -s, s) / s
365
+ return x0
366
+
367
+ def model_fn(self, x, t):
368
+ """
369
+ Convert the model to the noise prediction model or the data prediction model.
370
+ """
371
+ if self.predict_x0:
372
+ return self.data_prediction_fn(x, t)
373
+ else:
374
+ return self.noise_prediction_fn(x, t)
375
+
376
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
377
+ """Compute the intermediate time steps for sampling.
378
+ Args:
379
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
380
+ - 'logSNR': uniform logSNR for the time steps.
381
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
382
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
383
+ t_T: A `float`. The starting time of the sampling (default is T).
384
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
385
+ N: A `int`. The total number of the spacing of the time steps.
386
+ device: A torch device.
387
+ Returns:
388
+ A pytorch tensor of the time steps, with the shape (N + 1,).
389
+ """
390
+ if skip_type == 'logSNR':
391
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
392
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
393
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
394
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
395
+ elif skip_type == 'time_uniform':
396
+ return torch.linspace(t_T, t_0, N + 1).to(device)
397
+ elif skip_type == 'time_quadratic':
398
+ t_order = 2
399
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
400
+ return t
401
+ else:
402
+ raise ValueError(
403
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
+
405
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
406
+ """
407
+ Get the order of each step for sampling by the singlestep DPM-Solver.
408
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
409
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
410
+ - If order == 1:
411
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
412
+ - If order == 2:
413
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
414
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
415
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
416
+ - If order == 3:
417
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
418
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
419
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
420
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
421
+ ============================================
422
+ Args:
423
+ order: A `int`. The max order for the solver (2 or 3).
424
+ steps: A `int`. The total number of function evaluations (NFE).
425
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
426
+ - 'logSNR': uniform logSNR for the time steps.
427
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
428
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
429
+ t_T: A `float`. The starting time of the sampling (default is T).
430
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
431
+ device: A torch device.
432
+ Returns:
433
+ orders: A list of the solver order of each step.
434
+ """
435
+ if order == 3:
436
+ K = steps // 3 + 1
437
+ if steps % 3 == 0:
438
+ orders = [3, ] * (K - 2) + [2, 1]
439
+ elif steps % 3 == 1:
440
+ orders = [3, ] * (K - 1) + [1]
441
+ else:
442
+ orders = [3, ] * (K - 1) + [2]
443
+ elif order == 2:
444
+ if steps % 2 == 0:
445
+ K = steps // 2
446
+ orders = [2, ] * K
447
+ else:
448
+ K = steps // 2 + 1
449
+ orders = [2, ] * (K - 1) + [1]
450
+ elif order == 1:
451
+ K = 1
452
+ orders = [1, ] * steps
453
+ else:
454
+ raise ValueError("'order' must be '1' or '2' or '3'.")
455
+ if skip_type == 'logSNR':
456
+ # To reproduce the results in DPM-Solver paper
457
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
+ else:
459
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
460
+ torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
461
+ return timesteps_outer, orders
462
+
463
+ def denoise_to_zero_fn(self, x, s):
464
+ """
465
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
466
+ """
467
+ return self.data_prediction_fn(x, s)
468
+
469
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
470
+ """
471
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
472
+ Args:
473
+ x: A pytorch tensor. The initial value at time `s`.
474
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
475
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
476
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
477
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
478
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
479
+ Returns:
480
+ x_t: A pytorch tensor. The approximated solution at time `t`.
481
+ """
482
+ ns = self.noise_schedule
483
+ dims = x.dim()
484
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
485
+ h = lambda_t - lambda_s
486
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
487
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
488
+ alpha_t = torch.exp(log_alpha_t)
489
+
490
+ if self.predict_x0:
491
+ phi_1 = torch.expm1(-h)
492
+ if model_s is None:
493
+ model_s = self.model_fn(x, s)
494
+ x_t = (
495
+ expand_dims(sigma_t / sigma_s, dims) * x
496
+ - expand_dims(alpha_t * phi_1, dims) * model_s
497
+ )
498
+ if return_intermediate:
499
+ return x_t, {'model_s': model_s}
500
+ else:
501
+ return x_t
502
+ else:
503
+ phi_1 = torch.expm1(h)
504
+ if model_s is None:
505
+ model_s = self.model_fn(x, s)
506
+ x_t = (
507
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
508
+ - expand_dims(sigma_t * phi_1, dims) * model_s
509
+ )
510
+ if return_intermediate:
511
+ return x_t, {'model_s': model_s}
512
+ else:
513
+ return x_t
514
+
515
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
516
+ solver_type='dpm_solver'):
517
+ """
518
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
519
+ Args:
520
+ x: A pytorch tensor. The initial value at time `s`.
521
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
522
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
523
+ r1: A `float`. The hyperparameter of the second-order solver.
524
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
525
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
526
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
527
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
528
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
529
+ Returns:
530
+ x_t: A pytorch tensor. The approximated solution at time `t`.
531
+ """
532
+ if solver_type not in ['dpm_solver', 'taylor']:
533
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
534
+ if r1 is None:
535
+ r1 = 0.5
536
+ ns = self.noise_schedule
537
+ dims = x.dim()
538
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
539
+ h = lambda_t - lambda_s
540
+ lambda_s1 = lambda_s + r1 * h
541
+ s1 = ns.inverse_lambda(lambda_s1)
542
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
543
+ s1), ns.marginal_log_mean_coeff(t)
544
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
545
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
546
+
547
+ if self.predict_x0:
548
+ phi_11 = torch.expm1(-r1 * h)
549
+ phi_1 = torch.expm1(-h)
550
+
551
+ if model_s is None:
552
+ model_s = self.model_fn(x, s)
553
+ x_s1 = (
554
+ expand_dims(sigma_s1 / sigma_s, dims) * x
555
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
556
+ )
557
+ model_s1 = self.model_fn(x_s1, s1)
558
+ if solver_type == 'dpm_solver':
559
+ x_t = (
560
+ expand_dims(sigma_t / sigma_s, dims) * x
561
+ - expand_dims(alpha_t * phi_1, dims) * model_s
562
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
563
+ )
564
+ elif solver_type == 'taylor':
565
+ x_t = (
566
+ expand_dims(sigma_t / sigma_s, dims) * x
567
+ - expand_dims(alpha_t * phi_1, dims) * model_s
568
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
569
+ model_s1 - model_s)
570
+ )
571
+ else:
572
+ phi_11 = torch.expm1(r1 * h)
573
+ phi_1 = torch.expm1(h)
574
+
575
+ if model_s is None:
576
+ model_s = self.model_fn(x, s)
577
+ x_s1 = (
578
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
579
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
580
+ )
581
+ model_s1 = self.model_fn(x_s1, s1)
582
+ if solver_type == 'dpm_solver':
583
+ x_t = (
584
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
585
+ - expand_dims(sigma_t * phi_1, dims) * model_s
586
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
587
+ )
588
+ elif solver_type == 'taylor':
589
+ x_t = (
590
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
591
+ - expand_dims(sigma_t * phi_1, dims) * model_s
592
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
593
+ )
594
+ if return_intermediate:
595
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
596
+ else:
597
+ return x_t
598
+
599
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
600
+ return_intermediate=False, solver_type='dpm_solver'):
601
+ """
602
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
603
+ Args:
604
+ x: A pytorch tensor. The initial value at time `s`.
605
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
606
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
607
+ r1: A `float`. The hyperparameter of the third-order solver.
608
+ r2: A `float`. The hyperparameter of the third-order solver.
609
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
610
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
612
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
613
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
614
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
615
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
616
+ Returns:
617
+ x_t: A pytorch tensor. The approximated solution at time `t`.
618
+ """
619
+ if solver_type not in ['dpm_solver', 'taylor']:
620
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
621
+ if r1 is None:
622
+ r1 = 1. / 3.
623
+ if r2 is None:
624
+ r2 = 2. / 3.
625
+ ns = self.noise_schedule
626
+ dims = x.dim()
627
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
628
+ h = lambda_t - lambda_s
629
+ lambda_s1 = lambda_s + r1 * h
630
+ lambda_s2 = lambda_s + r2 * h
631
+ s1 = ns.inverse_lambda(lambda_s1)
632
+ s2 = ns.inverse_lambda(lambda_s2)
633
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
634
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
635
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
636
+ s2), ns.marginal_std(t)
637
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
638
+
639
+ if self.predict_x0:
640
+ phi_11 = torch.expm1(-r1 * h)
641
+ phi_12 = torch.expm1(-r2 * h)
642
+ phi_1 = torch.expm1(-h)
643
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
644
+ phi_2 = phi_1 / h + 1.
645
+ phi_3 = phi_2 / h - 0.5
646
+
647
+ if model_s is None:
648
+ model_s = self.model_fn(x, s)
649
+ if model_s1 is None:
650
+ x_s1 = (
651
+ expand_dims(sigma_s1 / sigma_s, dims) * x
652
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
653
+ )
654
+ model_s1 = self.model_fn(x_s1, s1)
655
+ x_s2 = (
656
+ expand_dims(sigma_s2 / sigma_s, dims) * x
657
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
658
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
659
+ )
660
+ model_s2 = self.model_fn(x_s2, s2)
661
+ if solver_type == 'dpm_solver':
662
+ x_t = (
663
+ expand_dims(sigma_t / sigma_s, dims) * x
664
+ - expand_dims(alpha_t * phi_1, dims) * model_s
665
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
666
+ )
667
+ elif solver_type == 'taylor':
668
+ D1_0 = (1. / r1) * (model_s1 - model_s)
669
+ D1_1 = (1. / r2) * (model_s2 - model_s)
670
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
671
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
672
+ x_t = (
673
+ expand_dims(sigma_t / sigma_s, dims) * x
674
+ - expand_dims(alpha_t * phi_1, dims) * model_s
675
+ + expand_dims(alpha_t * phi_2, dims) * D1
676
+ - expand_dims(alpha_t * phi_3, dims) * D2
677
+ )
678
+ else:
679
+ phi_11 = torch.expm1(r1 * h)
680
+ phi_12 = torch.expm1(r2 * h)
681
+ phi_1 = torch.expm1(h)
682
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
683
+ phi_2 = phi_1 / h - 1.
684
+ phi_3 = phi_2 / h - 0.5
685
+
686
+ if model_s is None:
687
+ model_s = self.model_fn(x, s)
688
+ if model_s1 is None:
689
+ x_s1 = (
690
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
691
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
692
+ )
693
+ model_s1 = self.model_fn(x_s1, s1)
694
+ x_s2 = (
695
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
696
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
697
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
698
+ )
699
+ model_s2 = self.model_fn(x_s2, s2)
700
+ if solver_type == 'dpm_solver':
701
+ x_t = (
702
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
703
+ - expand_dims(sigma_t * phi_1, dims) * model_s
704
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
705
+ )
706
+ elif solver_type == 'taylor':
707
+ D1_0 = (1. / r1) * (model_s1 - model_s)
708
+ D1_1 = (1. / r2) * (model_s2 - model_s)
709
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
710
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
711
+ x_t = (
712
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
713
+ - expand_dims(sigma_t * phi_1, dims) * model_s
714
+ - expand_dims(sigma_t * phi_2, dims) * D1
715
+ - expand_dims(sigma_t * phi_3, dims) * D2
716
+ )
717
+
718
+ if return_intermediate:
719
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
720
+ else:
721
+ return x_t
722
+
723
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
724
+ """
725
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
726
+ Args:
727
+ x: A pytorch tensor. The initial value at time `s`.
728
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
729
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
730
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
731
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
732
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
733
+ Returns:
734
+ x_t: A pytorch tensor. The approximated solution at time `t`.
735
+ """
736
+ if solver_type not in ['dpm_solver', 'taylor']:
737
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
738
+ ns = self.noise_schedule
739
+ dims = x.dim()
740
+ model_prev_1, model_prev_0 = model_prev_list
741
+ t_prev_1, t_prev_0 = t_prev_list
742
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
743
+ t_prev_0), ns.marginal_lambda(t)
744
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
745
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
746
+ alpha_t = torch.exp(log_alpha_t)
747
+
748
+ h_0 = lambda_prev_0 - lambda_prev_1
749
+ h = lambda_t - lambda_prev_0
750
+ r0 = h_0 / h
751
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
752
+ if self.predict_x0:
753
+ if solver_type == 'dpm_solver':
754
+ x_t = (
755
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
756
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
757
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
758
+ )
759
+ elif solver_type == 'taylor':
760
+ x_t = (
761
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
762
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
763
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
764
+ )
765
+ else:
766
+ if solver_type == 'dpm_solver':
767
+ x_t = (
768
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
769
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
770
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
771
+ )
772
+ elif solver_type == 'taylor':
773
+ x_t = (
774
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
775
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
776
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
777
+ )
778
+ return x_t
779
+
780
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
781
+ """
782
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
783
+ Args:
784
+ x: A pytorch tensor. The initial value at time `s`.
785
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
786
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
787
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
788
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
789
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
790
+ Returns:
791
+ x_t: A pytorch tensor. The approximated solution at time `t`.
792
+ """
793
+ ns = self.noise_schedule
794
+ dims = x.dim()
795
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
796
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
797
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
798
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
799
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
800
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
801
+ alpha_t = torch.exp(log_alpha_t)
802
+
803
+ h_1 = lambda_prev_1 - lambda_prev_2
804
+ h_0 = lambda_prev_0 - lambda_prev_1
805
+ h = lambda_t - lambda_prev_0
806
+ r0, r1 = h_0 / h, h_1 / h
807
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
808
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
809
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
810
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
811
+ if self.predict_x0:
812
+ x_t = (
813
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
814
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
815
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
816
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
817
+ )
818
+ else:
819
+ x_t = (
820
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
821
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
822
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
823
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
824
+ )
825
+ return x_t
826
+
827
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
+ r2=None):
829
+ """
830
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
+ Args:
832
+ x: A pytorch tensor. The initial value at time `s`.
833
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
839
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
+ r2: A `float`. The hyperparameter of the third-order solver.
841
+ Returns:
842
+ x_t: A pytorch tensor. The approximated solution at time `t`.
843
+ """
844
+ if order == 1:
845
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
+ elif order == 2:
847
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
+ solver_type=solver_type, r1=r1)
849
+ elif order == 3:
850
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
+ solver_type=solver_type, r1=r1, r2=r2)
852
+ else:
853
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
+
855
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
+ """
857
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
+ Args:
859
+ x: A pytorch tensor. The initial value at time `s`.
860
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
861
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
866
+ Returns:
867
+ x_t: A pytorch tensor. The approximated solution at time `t`.
868
+ """
869
+ if order == 1:
870
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
+ elif order == 2:
872
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
+ elif order == 3:
874
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
+ else:
876
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
+
878
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
879
+ solver_type='dpm_solver'):
880
+ """
881
+ The adaptive step size solver based on singlestep DPM-Solver.
882
+ Args:
883
+ x: A pytorch tensor. The initial value at time `t_T`.
884
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
885
+ t_T: A `float`. The starting time of the sampling (default is T).
886
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
887
+ h_init: A `float`. The initial step size (for logSNR).
888
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
889
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
890
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
891
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
892
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
893
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
894
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
895
+ Returns:
896
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
898
+ """
899
+ ns = self.noise_schedule
900
+ s = t_T * torch.ones((x.shape[0],)).to(x)
901
+ lambda_s = ns.marginal_lambda(s)
902
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
903
+ h = h_init * torch.ones_like(s).to(x)
904
+ x_prev = x
905
+ nfe = 0
906
+ if order == 2:
907
+ r1 = 0.5
908
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
909
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
910
+ solver_type=solver_type,
911
+ **kwargs)
912
+ elif order == 3:
913
+ r1, r2 = 1. / 3., 2. / 3.
914
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
915
+ return_intermediate=True,
916
+ solver_type=solver_type)
917
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
918
+ solver_type=solver_type,
919
+ **kwargs)
920
+ else:
921
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
922
+ while torch.abs((s - t_0)).mean() > t_err:
923
+ t = ns.inverse_lambda(lambda_s + h)
924
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
925
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
926
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
927
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
928
+ E = norm_fn((x_higher - x_lower) / delta).max()
929
+ if torch.all(E <= 1.):
930
+ x = x_higher
931
+ s = t
932
+ x_prev = x_lower
933
+ lambda_s = ns.marginal_lambda(s)
934
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
935
+ nfe += order
936
+ print('adaptive solver nfe', nfe)
937
+ return x
938
+
939
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
+ atol=0.0078, rtol=0.05,
942
+ ):
943
+ """
944
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
+ =====================================================
946
+ We support the following algorithms for both noise prediction model and data prediction model:
947
+ - 'singlestep':
948
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
+ The total number of function evaluations (NFE) == `steps`.
951
+ Given a fixed NFE == `steps`, the sampling procedure is:
952
+ - If `order` == 1:
953
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
+ - If `order` == 2:
955
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
+ - If `order` == 3:
959
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
+ - 'multistep':
964
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
+ We initialize the first `order` values by lower order multistep solvers.
966
+ Given a fixed NFE == `steps`, the sampling procedure is:
967
+ Denote K = steps.
968
+ - If `order` == 1:
969
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
+ - If `order` == 2:
971
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
972
+ - If `order` == 3:
973
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
974
+ - 'singlestep_fixed':
975
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
+ - 'adaptive':
978
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
981
+ (NFE) and the sample quality.
982
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
+ =====================================================
985
+ Some advices for choosing the algorithm:
986
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
+ e.g.
989
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
+ skip_type='time_uniform', method='singlestep')
992
+ - For **guided sampling with large guidance scale** by DPMs:
993
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
+ e.g.
995
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
+ skip_type='time_uniform', method='multistep')
998
+ We support three types of `skip_type`:
999
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1000
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1001
+ - 'time_quadratic': quadratic time for the time steps.
1002
+ =====================================================
1003
+ Args:
1004
+ x: A pytorch tensor. The initial value at time `t_start`
1005
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
+ steps: A `int`. The total number of function evaluations (NFE).
1007
+ t_start: A `float`. The starting time of the sampling.
1008
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
+ t_end: A `float`. The ending time of the sampling.
1010
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
+ For discrete-time DPMs:
1013
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
+ For continuous-time DPMs:
1015
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
+ order: A `int`. The order of DPM-Solver.
1017
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1020
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1021
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1022
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1023
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1024
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1025
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1026
+ it for high-resolutional images.
1027
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1028
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1029
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1030
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1031
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
+ Returns:
1035
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
+ """
1037
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
+ t_T = self.noise_schedule.T if t_start is None else t_start
1039
+ device = x.device
1040
+ if method == 'adaptive':
1041
+ with torch.no_grad():
1042
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
+ solver_type=solver_type)
1044
+ elif method == 'multistep':
1045
+ assert steps >= order
1046
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
+ assert timesteps.shape[0] - 1 == steps
1048
+ with torch.no_grad():
1049
+ vec_t = timesteps[0].expand((x.shape[0]))
1050
+ model_prev_list = [self.model_fn(x, vec_t)]
1051
+ t_prev_list = [vec_t]
1052
+ # Init the first `order` values by lower order multistep DPM-Solver.
1053
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
+ vec_t = timesteps[init_order].expand(x.shape[0])
1055
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
+ solver_type=solver_type)
1057
+ model_prev_list.append(self.model_fn(x, vec_t))
1058
+ t_prev_list.append(vec_t)
1059
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
+ vec_t = timesteps[step].expand(x.shape[0])
1062
+ if lower_order_final and steps < 15:
1063
+ step_order = min(order, steps + 1 - step)
1064
+ else:
1065
+ step_order = order
1066
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
+ solver_type=solver_type)
1068
+ for i in range(order - 1):
1069
+ t_prev_list[i] = t_prev_list[i + 1]
1070
+ model_prev_list[i] = model_prev_list[i + 1]
1071
+ t_prev_list[-1] = vec_t
1072
+ # We do not need to evaluate the final model value.
1073
+ if step < steps:
1074
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1075
+ elif method in ['singlestep', 'singlestep_fixed']:
1076
+ if method == 'singlestep':
1077
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
+ skip_type=skip_type,
1079
+ t_T=t_T, t_0=t_0,
1080
+ device=device)
1081
+ elif method == 'singlestep_fixed':
1082
+ K = steps // order
1083
+ orders = [order, ] * K
1084
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
+ for i, order in enumerate(orders):
1086
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
+ N=order, device=device)
1089
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
+ h = lambda_inner[-1] - lambda_inner[0]
1092
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
+ if denoise_to_zero:
1096
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
+ return x
1098
+
1099
+
1100
+ #############################################################
1101
+ # other utility functions
1102
+ #############################################################
1103
+
1104
+ def interpolate_fn(x, xp, yp):
1105
+ """
1106
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1107
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1108
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1109
+ Args:
1110
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1111
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1112
+ yp: PyTorch tensor with shape [C, K].
1113
+ Returns:
1114
+ The function values f(x), with shape [N, C].
1115
+ """
1116
+ N, K = x.shape[0], xp.shape[1]
1117
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1118
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1119
+ x_idx = torch.argmin(x_indices, dim=2)
1120
+ cand_start_idx = x_idx - 1
1121
+ start_idx = torch.where(
1122
+ torch.eq(x_idx, 0),
1123
+ torch.tensor(1, device=x.device),
1124
+ torch.where(
1125
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1126
+ ),
1127
+ )
1128
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1129
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1130
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1131
+ start_idx2 = torch.where(
1132
+ torch.eq(x_idx, 0),
1133
+ torch.tensor(0, device=x.device),
1134
+ torch.where(
1135
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1136
+ ),
1137
+ )
1138
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1139
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1140
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1141
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1142
+ return cand
1143
+
1144
+
1145
+ def expand_dims(v, dims):
1146
+ """
1147
+ Expand the tensor `v` to the dim `dims`.
1148
+ Args:
1149
+ `v`: a PyTorch tensor with shape [N].
1150
+ `dim`: a `int`.
1151
+ Returns:
1152
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1153
+ """
1154
+ return v[(...,) + (None,) * (dims - 1)]
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+ import torch
3
+
4
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
5
+
6
+
7
+ MODEL_TYPES = {
8
+ "eps": "noise",
9
+ "v": "v"
10
+ }
11
+
12
+
13
+ class DPMSolverSampler(object):
14
+ def __init__(self, model, **kwargs):
15
+ super().__init__()
16
+ self.model = model
17
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
18
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
19
+
20
+ def register_buffer(self, name, attr):
21
+ if type(attr) == torch.Tensor:
22
+ if attr.device != torch.device("cuda"):
23
+ attr = attr.to(torch.device("cuda"))
24
+ setattr(self, name, attr)
25
+
26
+ @torch.no_grad()
27
+ def sample(self,
28
+ S,
29
+ batch_size,
30
+ shape,
31
+ conditioning=None,
32
+ callback=None,
33
+ normals_sequence=None,
34
+ img_callback=None,
35
+ quantize_x0=False,
36
+ eta=0.,
37
+ mask=None,
38
+ x0=None,
39
+ temperature=1.,
40
+ noise_dropout=0.,
41
+ score_corrector=None,
42
+ corrector_kwargs=None,
43
+ verbose=True,
44
+ x_T=None,
45
+ log_every_t=100,
46
+ unconditional_guidance_scale=1.,
47
+ unconditional_conditioning=None,
48
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
49
+ **kwargs
50
+ ):
51
+ if conditioning is not None:
52
+ if isinstance(conditioning, dict):
53
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
54
+ if cbs != batch_size:
55
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
56
+ else:
57
+ if conditioning.shape[0] != batch_size:
58
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
59
+
60
+ # sampling
61
+ C, H, W = shape
62
+ size = (batch_size, C, H, W)
63
+
64
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
65
+
66
+ device = self.model.betas.device
67
+ if x_T is None:
68
+ img = torch.randn(size, device=device)
69
+ else:
70
+ img = x_T
71
+
72
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
73
+
74
+ model_fn = model_wrapper(
75
+ lambda x, t, c: self.model.apply_model(x, t, c),
76
+ ns,
77
+ model_type=MODEL_TYPES[self.model.parameterization],
78
+ guidance_type="classifier-free",
79
+ condition=conditioning,
80
+ unconditional_condition=unconditional_conditioning,
81
+ guidance_scale=unconditional_guidance_scale,
82
+ )
83
+
84
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
85
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
86
+
87
+ return x.to(device), None
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+ from ldm.models.diffusion.sampling_util import norm_thresholding
10
+
11
+
12
+ class PLMSSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ if ddim_eta != 0:
27
+ raise ValueError('ddim_eta must be 0 for PLMS')
28
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
29
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
30
+ alphas_cumprod = self.model.alphas_cumprod
31
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
32
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
33
+
34
+ self.register_buffer('betas', to_torch(self.model.betas))
35
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
36
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
37
+
38
+ # calculations for diffusion q(x_t | x_{t-1}) and others
39
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
41
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
42
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
44
+
45
+ # ddim sampling parameters
46
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
47
+ ddim_timesteps=self.ddim_timesteps,
48
+ eta=ddim_eta,verbose=verbose)
49
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
50
+ self.register_buffer('ddim_alphas', ddim_alphas)
51
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
52
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
53
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
54
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
55
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
56
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
57
+
58
+ @torch.no_grad()
59
+ def sample(self,
60
+ S,
61
+ batch_size,
62
+ shape,
63
+ conditioning=None,
64
+ callback=None,
65
+ normals_sequence=None,
66
+ img_callback=None,
67
+ quantize_x0=False,
68
+ eta=0.,
69
+ mask=None,
70
+ x0=None,
71
+ temperature=1.,
72
+ noise_dropout=0.,
73
+ score_corrector=None,
74
+ corrector_kwargs=None,
75
+ verbose=True,
76
+ x_T=None,
77
+ log_every_t=100,
78
+ unconditional_guidance_scale=1.,
79
+ unconditional_conditioning=None,
80
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
81
+ dynamic_threshold=None,
82
+ **kwargs
83
+ ):
84
+ if conditioning is not None:
85
+ if isinstance(conditioning, dict):
86
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
87
+ if cbs != batch_size:
88
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
89
+ else:
90
+ if conditioning.shape[0] != batch_size:
91
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
92
+
93
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
94
+ # sampling
95
+ C, H, W = shape
96
+ size = (batch_size, C, H, W)
97
+ print(f'Data shape for PLMS sampling is {size}')
98
+
99
+ samples, intermediates = self.plms_sampling(conditioning, size,
100
+ callback=callback,
101
+ img_callback=img_callback,
102
+ quantize_denoised=quantize_x0,
103
+ mask=mask, x0=x0,
104
+ ddim_use_original_steps=False,
105
+ noise_dropout=noise_dropout,
106
+ temperature=temperature,
107
+ score_corrector=score_corrector,
108
+ corrector_kwargs=corrector_kwargs,
109
+ x_T=x_T,
110
+ log_every_t=log_every_t,
111
+ unconditional_guidance_scale=unconditional_guidance_scale,
112
+ unconditional_conditioning=unconditional_conditioning,
113
+ dynamic_threshold=dynamic_threshold,
114
+ )
115
+ return samples, intermediates
116
+
117
+ @torch.no_grad()
118
+ def plms_sampling(self, cond, shape,
119
+ x_T=None, ddim_use_original_steps=False,
120
+ callback=None, timesteps=None, quantize_denoised=False,
121
+ mask=None, x0=None, img_callback=None, log_every_t=100,
122
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
123
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
124
+ dynamic_threshold=None):
125
+ device = self.model.betas.device
126
+ b = shape[0]
127
+ if x_T is None:
128
+ img = torch.randn(shape, device=device)
129
+ else:
130
+ img = x_T
131
+
132
+ if timesteps is None:
133
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
134
+ elif timesteps is not None and not ddim_use_original_steps:
135
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
136
+ timesteps = self.ddim_timesteps[:subset_end]
137
+
138
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
139
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
140
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
141
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
142
+
143
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
144
+ old_eps = []
145
+
146
+ for i, step in enumerate(iterator):
147
+ index = total_steps - i - 1
148
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
149
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
150
+
151
+ if mask is not None:
152
+ assert x0 is not None
153
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
154
+ img = img_orig * mask + (1. - mask) * img
155
+
156
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
157
+ quantize_denoised=quantize_denoised, temperature=temperature,
158
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
159
+ corrector_kwargs=corrector_kwargs,
160
+ unconditional_guidance_scale=unconditional_guidance_scale,
161
+ unconditional_conditioning=unconditional_conditioning,
162
+ old_eps=old_eps, t_next=ts_next,
163
+ dynamic_threshold=dynamic_threshold)
164
+ img, pred_x0, e_t = outs
165
+ old_eps.append(e_t)
166
+ if len(old_eps) >= 4:
167
+ old_eps.pop(0)
168
+ if callback: callback(i)
169
+ if img_callback: img_callback(pred_x0, i)
170
+
171
+ if index % log_every_t == 0 or index == total_steps - 1:
172
+ intermediates['x_inter'].append(img)
173
+ intermediates['pred_x0'].append(pred_x0)
174
+
175
+ return img, intermediates
176
+
177
+ @torch.no_grad()
178
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
179
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
180
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
181
+ dynamic_threshold=None):
182
+ b, *_, device = *x.shape, x.device
183
+
184
+ def get_model_output(x, t):
185
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
186
+ e_t = self.model.apply_model(x, t, c)
187
+ else:
188
+ x_in = torch.cat([x] * 2)
189
+ t_in = torch.cat([t] * 2)
190
+ c_in = torch.cat([unconditional_conditioning, c])
191
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
192
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
193
+
194
+ if score_corrector is not None:
195
+ assert self.model.parameterization == "eps"
196
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
197
+
198
+ return e_t
199
+
200
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
201
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
202
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
203
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
204
+
205
+ def get_x_prev_and_pred_x0(e_t, index):
206
+ # select parameters corresponding to the currently considered timestep
207
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
208
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
209
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
210
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
211
+
212
+ # current prediction for x_0
213
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
214
+ if quantize_denoised:
215
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
216
+ if dynamic_threshold is not None:
217
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
218
+ # direction pointing to x_t
219
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
221
+ if noise_dropout > 0.:
222
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224
+ return x_prev, pred_x0
225
+
226
+ e_t = get_model_output(x, t)
227
+ if len(old_eps) == 0:
228
+ # Pseudo Improved Euler (2nd order)
229
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
230
+ e_t_next = get_model_output(x_prev, t_next)
231
+ e_t_prime = (e_t + e_t_next) / 2
232
+ elif len(old_eps) == 1:
233
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
234
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
235
+ elif len(old_eps) == 2:
236
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
237
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
238
+ elif len(old_eps) >= 3:
239
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
240
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
241
+
242
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
243
+
244
+ return x_prev, pred_x0, e_t
ldm/models/diffusion/sampling_util.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def append_dims(x, target_dims):
6
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
7
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
8
+ dims_to_append = target_dims - x.ndim
9
+ if dims_to_append < 0:
10
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
11
+ return x[(...,) + (None,) * dims_to_append]
12
+
13
+
14
+ def norm_thresholding(x0, value):
15
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
16
+ return x0 * (value / s)
17
+
18
+
19
+ def spatial_norm_thresholding(x0, value):
20
+ # b c h w
21
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
22
+ return x0 * (value / s)
ldm/modules/attention.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.diffusionmodules.util import checkpoint
10
+
11
+
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ XFORMERS_IS_AVAILBLE = True
16
+ except:
17
+ XFORMERS_IS_AVAILBLE = False
18
+
19
+ # CrossAttn precision handling
20
+ import os
21
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
+
23
+ def exists(val):
24
+ return val is not None
25
+
26
+
27
+ def uniq(arr):
28
+ return{el: True for el in arr}.keys()
29
+
30
+
31
+ def default(val, d):
32
+ if exists(val):
33
+ return val
34
+ return d() if isfunction(d) else d
35
+
36
+
37
+ def max_neg_value(t):
38
+ return -torch.finfo(t.dtype).max
39
+
40
+
41
+ def init_(tensor):
42
+ dim = tensor.shape[-1]
43
+ std = 1 / math.sqrt(dim)
44
+ tensor.uniform_(-std, std)
45
+ return tensor
46
+
47
+
48
+ # feedforward
49
+ class GEGLU(nn.Module):
50
+ def __init__(self, dim_in, dim_out):
51
+ super().__init__()
52
+ self.proj = nn.Linear(dim_in, dim_out * 2)
53
+
54
+ def forward(self, x):
55
+ x, gate = self.proj(x).chunk(2, dim=-1)
56
+ return x * F.gelu(gate)
57
+
58
+
59
+ class FeedForward(nn.Module):
60
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
61
+ super().__init__()
62
+ inner_dim = int(dim * mult)
63
+ dim_out = default(dim_out, dim)
64
+ project_in = nn.Sequential(
65
+ nn.Linear(dim, inner_dim),
66
+ nn.GELU()
67
+ ) if not glu else GEGLU(dim, inner_dim)
68
+
69
+ self.net = nn.Sequential(
70
+ project_in,
71
+ nn.Dropout(dropout),
72
+ nn.Linear(inner_dim, dim_out)
73
+ )
74
+
75
+ def forward(self, x):
76
+ return self.net(x)
77
+
78
+
79
+ def zero_module(module):
80
+ """
81
+ Zero out the parameters of a module and return it.
82
+ """
83
+ for p in module.parameters():
84
+ p.detach().zero_()
85
+ return module
86
+
87
+
88
+ def Normalize(in_channels):
89
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
90
+
91
+
92
+ class SpatialSelfAttention(nn.Module):
93
+ def __init__(self, in_channels):
94
+ super().__init__()
95
+ self.in_channels = in_channels
96
+
97
+ self.norm = Normalize(in_channels)
98
+ self.q = torch.nn.Conv2d(in_channels,
99
+ in_channels,
100
+ kernel_size=1,
101
+ stride=1,
102
+ padding=0)
103
+ self.k = torch.nn.Conv2d(in_channels,
104
+ in_channels,
105
+ kernel_size=1,
106
+ stride=1,
107
+ padding=0)
108
+ self.v = torch.nn.Conv2d(in_channels,
109
+ in_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+ self.proj_out = torch.nn.Conv2d(in_channels,
114
+ in_channels,
115
+ kernel_size=1,
116
+ stride=1,
117
+ padding=0)
118
+
119
+ def forward(self, x):
120
+ h_ = x
121
+ h_ = self.norm(h_)
122
+ q = self.q(h_)
123
+ k = self.k(h_)
124
+ v = self.v(h_)
125
+
126
+ # compute attention
127
+ b,c,h,w = q.shape
128
+ q = rearrange(q, 'b c h w -> b (h w) c')
129
+ k = rearrange(k, 'b c h w -> b c (h w)')
130
+ w_ = torch.einsum('bij,bjk->bik', q, k)
131
+
132
+ w_ = w_ * (int(c)**(-0.5))
133
+ w_ = torch.nn.functional.softmax(w_, dim=2)
134
+
135
+ # attend to values
136
+ v = rearrange(v, 'b c h w -> b c (h w)')
137
+ w_ = rearrange(w_, 'b i j -> b j i')
138
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
139
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
140
+ h_ = self.proj_out(h_)
141
+
142
+ return x+h_
143
+
144
+
145
+ class CrossAttention(nn.Module):
146
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
147
+ super().__init__()
148
+ inner_dim = dim_head * heads
149
+ context_dim = default(context_dim, query_dim)
150
+
151
+ self.scale = dim_head ** -0.5
152
+ self.heads = heads
153
+
154
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
155
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
156
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
157
+
158
+ self.to_out = nn.Sequential(
159
+ nn.Linear(inner_dim, query_dim),
160
+ nn.Dropout(dropout)
161
+ )
162
+
163
+ def forward(self, x, context=None, mask=None):
164
+ h = self.heads
165
+
166
+ q = self.to_q(x)
167
+ context = default(context, x)
168
+ k = self.to_k(context)
169
+ v = self.to_v(context)
170
+
171
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
172
+
173
+ # force cast to fp32 to avoid overflowing
174
+ if _ATTN_PRECISION =="fp32":
175
+ with torch.autocast(enabled=False, device_type = 'cuda'):
176
+ q, k = q.float(), k.float()
177
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
+ else:
179
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
180
+
181
+ del q, k
182
+
183
+ if exists(mask):
184
+ mask = rearrange(mask, 'b ... -> b (...)')
185
+ max_neg_value = -torch.finfo(sim.dtype).max
186
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
187
+ sim.masked_fill_(~mask, max_neg_value)
188
+
189
+ # attention, what we cannot get enough of
190
+ sim = sim.softmax(dim=-1)
191
+
192
+ out = einsum('b i j, b j d -> b i d', sim, v)
193
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
194
+ return self.to_out(out)
195
+
196
+
197
+ class MemoryEfficientCrossAttention(nn.Module):
198
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
199
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
200
+ super().__init__()
201
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
202
+ f"{heads} heads.")
203
+ inner_dim = dim_head * heads
204
+ context_dim = default(context_dim, query_dim)
205
+
206
+ self.heads = heads
207
+ self.dim_head = dim_head
208
+
209
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
210
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
211
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
212
+
213
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
214
+ self.attention_op: Optional[Any] = None
215
+
216
+ def forward(self, x, context=None, mask=None):
217
+ q = self.to_q(x)
218
+ context = default(context, x)
219
+ k = self.to_k(context)
220
+ v = self.to_v(context)
221
+
222
+ b, _, _ = q.shape
223
+ q, k, v = map(
224
+ lambda t: t.unsqueeze(3)
225
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
226
+ .permute(0, 2, 1, 3)
227
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
228
+ .contiguous(),
229
+ (q, k, v),
230
+ )
231
+
232
+ # actually compute the attention, what we cannot get enough of
233
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234
+
235
+ if exists(mask):
236
+ raise NotImplementedError
237
+ out = (
238
+ out.unsqueeze(0)
239
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
240
+ .permute(0, 2, 1, 3)
241
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
242
+ )
243
+ return self.to_out(out)
244
+
245
+
246
+ class BasicTransformerBlock(nn.Module):
247
+ ATTENTION_MODES = {
248
+ "softmax": CrossAttention, # vanilla attention
249
+ "softmax-xformers": MemoryEfficientCrossAttention
250
+ }
251
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
252
+ disable_self_attn=False):
253
+ super().__init__()
254
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
255
+ assert attn_mode in self.ATTENTION_MODES
256
+ attn_cls = self.ATTENTION_MODES[attn_mode]
257
+ self.disable_self_attn = disable_self_attn
258
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
259
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
260
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
261
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
262
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+ self.checkpoint = checkpoint
267
+
268
+ def forward(self, x, context=None):
269
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
270
+
271
+ def _forward(self, x, context=None):
272
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
273
+ x = self.attn2(self.norm2(x), context=context) + x
274
+ x = self.ff(self.norm3(x)) + x
275
+ return x
276
+
277
+
278
+ class SpatialTransformer(nn.Module):
279
+ """
280
+ Transformer block for image-like data.
281
+ First, project the input (aka embedding)
282
+ and reshape to b, t, d.
283
+ Then apply standard transformer action.
284
+ Finally, reshape to image
285
+ NEW: use_linear for more efficiency instead of the 1x1 convs
286
+ """
287
+ def __init__(self, in_channels, n_heads, d_head,
288
+ depth=1, dropout=0., context_dim=None,
289
+ disable_self_attn=False, use_linear=False,
290
+ use_checkpoint=True):
291
+ super().__init__()
292
+ if exists(context_dim) and not isinstance(context_dim, list):
293
+ context_dim = [context_dim]
294
+ self.in_channels = in_channels
295
+ inner_dim = n_heads * d_head
296
+ self.norm = Normalize(in_channels)
297
+ if not use_linear:
298
+ self.proj_in = nn.Conv2d(in_channels,
299
+ inner_dim,
300
+ kernel_size=1,
301
+ stride=1,
302
+ padding=0)
303
+ else:
304
+ self.proj_in = nn.Linear(in_channels, inner_dim)
305
+
306
+ self.transformer_blocks = nn.ModuleList(
307
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
308
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
309
+ for d in range(depth)]
310
+ )
311
+ if not use_linear:
312
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
313
+ in_channels,
314
+ kernel_size=1,
315
+ stride=1,
316
+ padding=0))
317
+ else:
318
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
319
+ self.use_linear = use_linear
320
+
321
+ def forward(self, x, context=None):
322
+ # note: if no context is given, cross-attention defaults to self-attention
323
+ if not isinstance(context, list):
324
+ context = [context]
325
+ b, c, h, w = x.shape
326
+ x_in = x
327
+ x = self.norm(x)
328
+ if not self.use_linear:
329
+ x = self.proj_in(x)
330
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
331
+ if self.use_linear:
332
+ x = self.proj_in(x)
333
+ for i, block in enumerate(self.transformer_blocks):
334
+ x = block(x, context=context[i])
335
+ if self.use_linear:
336
+ x = self.proj_out(x)
337
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
338
+ if not self.use_linear:
339
+ x = self.proj_out(x)
340
+ return x + x_in
341
+
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,852 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ldm.modules.attention import MemoryEfficientCrossAttention
10
+
11
+ try:
12
+ import xformers
13
+ import xformers.ops
14
+ XFORMERS_IS_AVAILBLE = True
15
+ except:
16
+ XFORMERS_IS_AVAILBLE = False
17
+ print("No module 'xformers'. Proceeding without it.")
18
+
19
+
20
+ def get_timestep_embedding(timesteps, embedding_dim):
21
+ """
22
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
23
+ From Fairseq.
24
+ Build sinusoidal embeddings.
25
+ This matches the implementation in tensor2tensor, but differs slightly
26
+ from the description in Section 3.5 of "Attention Is All You Need".
27
+ """
28
+ assert len(timesteps.shape) == 1
29
+
30
+ half_dim = embedding_dim // 2
31
+ emb = math.log(10000) / (half_dim - 1)
32
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
33
+ emb = emb.to(device=timesteps.device)
34
+ emb = timesteps.float()[:, None] * emb[None, :]
35
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
36
+ if embedding_dim % 2 == 1: # zero pad
37
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
38
+ return emb
39
+
40
+
41
+ def nonlinearity(x):
42
+ # swish
43
+ return x*torch.sigmoid(x)
44
+
45
+
46
+ def Normalize(in_channels, num_groups=32):
47
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
48
+
49
+
50
+ class Upsample(nn.Module):
51
+ def __init__(self, in_channels, with_conv):
52
+ super().__init__()
53
+ self.with_conv = with_conv
54
+ if self.with_conv:
55
+ self.conv = torch.nn.Conv2d(in_channels,
56
+ in_channels,
57
+ kernel_size=3,
58
+ stride=1,
59
+ padding=1)
60
+
61
+ def forward(self, x):
62
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
63
+ if self.with_conv:
64
+ x = self.conv(x)
65
+ return x
66
+
67
+
68
+ class Downsample(nn.Module):
69
+ def __init__(self, in_channels, with_conv):
70
+ super().__init__()
71
+ self.with_conv = with_conv
72
+ if self.with_conv:
73
+ # no asymmetric padding in torch conv, must do it ourselves
74
+ self.conv = torch.nn.Conv2d(in_channels,
75
+ in_channels,
76
+ kernel_size=3,
77
+ stride=2,
78
+ padding=0)
79
+
80
+ def forward(self, x):
81
+ if self.with_conv:
82
+ pad = (0,1,0,1)
83
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
84
+ x = self.conv(x)
85
+ else:
86
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
87
+ return x
88
+
89
+
90
+ class ResnetBlock(nn.Module):
91
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
92
+ dropout, temb_channels=512):
93
+ super().__init__()
94
+ self.in_channels = in_channels
95
+ out_channels = in_channels if out_channels is None else out_channels
96
+ self.out_channels = out_channels
97
+ self.use_conv_shortcut = conv_shortcut
98
+
99
+ self.norm1 = Normalize(in_channels)
100
+ self.conv1 = torch.nn.Conv2d(in_channels,
101
+ out_channels,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1)
105
+ if temb_channels > 0:
106
+ self.temb_proj = torch.nn.Linear(temb_channels,
107
+ out_channels)
108
+ self.norm2 = Normalize(out_channels)
109
+ self.dropout = torch.nn.Dropout(dropout)
110
+ self.conv2 = torch.nn.Conv2d(out_channels,
111
+ out_channels,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1)
115
+ if self.in_channels != self.out_channels:
116
+ if self.use_conv_shortcut:
117
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
118
+ out_channels,
119
+ kernel_size=3,
120
+ stride=1,
121
+ padding=1)
122
+ else:
123
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
124
+ out_channels,
125
+ kernel_size=1,
126
+ stride=1,
127
+ padding=0)
128
+
129
+ def forward(self, x, temb):
130
+ h = x
131
+ h = self.norm1(h)
132
+ h = nonlinearity(h)
133
+ h = self.conv1(h)
134
+
135
+ if temb is not None:
136
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
137
+
138
+ h = self.norm2(h)
139
+ h = nonlinearity(h)
140
+ h = self.dropout(h)
141
+ h = self.conv2(h)
142
+
143
+ if self.in_channels != self.out_channels:
144
+ if self.use_conv_shortcut:
145
+ x = self.conv_shortcut(x)
146
+ else:
147
+ x = self.nin_shortcut(x)
148
+
149
+ return x+h
150
+
151
+
152
+ class AttnBlock(nn.Module):
153
+ def __init__(self, in_channels):
154
+ super().__init__()
155
+ self.in_channels = in_channels
156
+
157
+ self.norm = Normalize(in_channels)
158
+ self.q = torch.nn.Conv2d(in_channels,
159
+ in_channels,
160
+ kernel_size=1,
161
+ stride=1,
162
+ padding=0)
163
+ self.k = torch.nn.Conv2d(in_channels,
164
+ in_channels,
165
+ kernel_size=1,
166
+ stride=1,
167
+ padding=0)
168
+ self.v = torch.nn.Conv2d(in_channels,
169
+ in_channels,
170
+ kernel_size=1,
171
+ stride=1,
172
+ padding=0)
173
+ self.proj_out = torch.nn.Conv2d(in_channels,
174
+ in_channels,
175
+ kernel_size=1,
176
+ stride=1,
177
+ padding=0)
178
+
179
+ def forward(self, x):
180
+ h_ = x
181
+ h_ = self.norm(h_)
182
+ q = self.q(h_)
183
+ k = self.k(h_)
184
+ v = self.v(h_)
185
+
186
+ # compute attention
187
+ b,c,h,w = q.shape
188
+ q = q.reshape(b,c,h*w)
189
+ q = q.permute(0,2,1) # b,hw,c
190
+ k = k.reshape(b,c,h*w) # b,c,hw
191
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
192
+ w_ = w_ * (int(c)**(-0.5))
193
+ w_ = torch.nn.functional.softmax(w_, dim=2)
194
+
195
+ # attend to values
196
+ v = v.reshape(b,c,h*w)
197
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
198
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
199
+ h_ = h_.reshape(b,c,h,w)
200
+
201
+ h_ = self.proj_out(h_)
202
+
203
+ return x+h_
204
+
205
+ class MemoryEfficientAttnBlock(nn.Module):
206
+ """
207
+ Uses xformers efficient implementation,
208
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
209
+ Note: this is a single-head self-attention operation
210
+ """
211
+ #
212
+ def __init__(self, in_channels):
213
+ super().__init__()
214
+ self.in_channels = in_channels
215
+
216
+ self.norm = Normalize(in_channels)
217
+ self.q = torch.nn.Conv2d(in_channels,
218
+ in_channels,
219
+ kernel_size=1,
220
+ stride=1,
221
+ padding=0)
222
+ self.k = torch.nn.Conv2d(in_channels,
223
+ in_channels,
224
+ kernel_size=1,
225
+ stride=1,
226
+ padding=0)
227
+ self.v = torch.nn.Conv2d(in_channels,
228
+ in_channels,
229
+ kernel_size=1,
230
+ stride=1,
231
+ padding=0)
232
+ self.proj_out = torch.nn.Conv2d(in_channels,
233
+ in_channels,
234
+ kernel_size=1,
235
+ stride=1,
236
+ padding=0)
237
+ self.attention_op: Optional[Any] = None
238
+
239
+ def forward(self, x):
240
+ h_ = x
241
+ h_ = self.norm(h_)
242
+ q = self.q(h_)
243
+ k = self.k(h_)
244
+ v = self.v(h_)
245
+
246
+ # compute attention
247
+ B, C, H, W = q.shape
248
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
249
+
250
+ q, k, v = map(
251
+ lambda t: t.unsqueeze(3)
252
+ .reshape(B, t.shape[1], 1, C)
253
+ .permute(0, 2, 1, 3)
254
+ .reshape(B * 1, t.shape[1], C)
255
+ .contiguous(),
256
+ (q, k, v),
257
+ )
258
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
259
+
260
+ out = (
261
+ out.unsqueeze(0)
262
+ .reshape(B, 1, out.shape[1], C)
263
+ .permute(0, 2, 1, 3)
264
+ .reshape(B, out.shape[1], C)
265
+ )
266
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
267
+ out = self.proj_out(out)
268
+ return x+out
269
+
270
+
271
+ class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
272
+ def forward(self, x, context=None, mask=None):
273
+ b, c, h, w = x.shape
274
+ x = rearrange(x, 'b c h w -> b (h w) c')
275
+ out = super().forward(x, context=context, mask=mask)
276
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
277
+ return x + out
278
+
279
+
280
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
281
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
282
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
283
+ attn_type = "vanilla-xformers"
284
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
285
+ if attn_type == "vanilla":
286
+ assert attn_kwargs is None
287
+ return AttnBlock(in_channels)
288
+ elif attn_type == "vanilla-xformers":
289
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
290
+ return MemoryEfficientAttnBlock(in_channels)
291
+ elif type == "memory-efficient-cross-attn":
292
+ attn_kwargs["query_dim"] = in_channels
293
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
294
+ elif attn_type == "none":
295
+ return nn.Identity(in_channels)
296
+ else:
297
+ raise NotImplementedError()
298
+
299
+
300
+ class Model(nn.Module):
301
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
302
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
303
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
304
+ super().__init__()
305
+ if use_linear_attn: attn_type = "linear"
306
+ self.ch = ch
307
+ self.temb_ch = self.ch*4
308
+ self.num_resolutions = len(ch_mult)
309
+ self.num_res_blocks = num_res_blocks
310
+ self.resolution = resolution
311
+ self.in_channels = in_channels
312
+
313
+ self.use_timestep = use_timestep
314
+ if self.use_timestep:
315
+ # timestep embedding
316
+ self.temb = nn.Module()
317
+ self.temb.dense = nn.ModuleList([
318
+ torch.nn.Linear(self.ch,
319
+ self.temb_ch),
320
+ torch.nn.Linear(self.temb_ch,
321
+ self.temb_ch),
322
+ ])
323
+
324
+ # downsampling
325
+ self.conv_in = torch.nn.Conv2d(in_channels,
326
+ self.ch,
327
+ kernel_size=3,
328
+ stride=1,
329
+ padding=1)
330
+
331
+ curr_res = resolution
332
+ in_ch_mult = (1,)+tuple(ch_mult)
333
+ self.down = nn.ModuleList()
334
+ for i_level in range(self.num_resolutions):
335
+ block = nn.ModuleList()
336
+ attn = nn.ModuleList()
337
+ block_in = ch*in_ch_mult[i_level]
338
+ block_out = ch*ch_mult[i_level]
339
+ for i_block in range(self.num_res_blocks):
340
+ block.append(ResnetBlock(in_channels=block_in,
341
+ out_channels=block_out,
342
+ temb_channels=self.temb_ch,
343
+ dropout=dropout))
344
+ block_in = block_out
345
+ if curr_res in attn_resolutions:
346
+ attn.append(make_attn(block_in, attn_type=attn_type))
347
+ down = nn.Module()
348
+ down.block = block
349
+ down.attn = attn
350
+ if i_level != self.num_resolutions-1:
351
+ down.downsample = Downsample(block_in, resamp_with_conv)
352
+ curr_res = curr_res // 2
353
+ self.down.append(down)
354
+
355
+ # middle
356
+ self.mid = nn.Module()
357
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
358
+ out_channels=block_in,
359
+ temb_channels=self.temb_ch,
360
+ dropout=dropout)
361
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
362
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
363
+ out_channels=block_in,
364
+ temb_channels=self.temb_ch,
365
+ dropout=dropout)
366
+
367
+ # upsampling
368
+ self.up = nn.ModuleList()
369
+ for i_level in reversed(range(self.num_resolutions)):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_out = ch*ch_mult[i_level]
373
+ skip_in = ch*ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks+1):
375
+ if i_block == self.num_res_blocks:
376
+ skip_in = ch*in_ch_mult[i_level]
377
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
378
+ out_channels=block_out,
379
+ temb_channels=self.temb_ch,
380
+ dropout=dropout))
381
+ block_in = block_out
382
+ if curr_res in attn_resolutions:
383
+ attn.append(make_attn(block_in, attn_type=attn_type))
384
+ up = nn.Module()
385
+ up.block = block
386
+ up.attn = attn
387
+ if i_level != 0:
388
+ up.upsample = Upsample(block_in, resamp_with_conv)
389
+ curr_res = curr_res * 2
390
+ self.up.insert(0, up) # prepend to get consistent order
391
+
392
+ # end
393
+ self.norm_out = Normalize(block_in)
394
+ self.conv_out = torch.nn.Conv2d(block_in,
395
+ out_ch,
396
+ kernel_size=3,
397
+ stride=1,
398
+ padding=1)
399
+
400
+ def forward(self, x, t=None, context=None):
401
+ #assert x.shape[2] == x.shape[3] == self.resolution
402
+ if context is not None:
403
+ # assume aligned context, cat along channel axis
404
+ x = torch.cat((x, context), dim=1)
405
+ if self.use_timestep:
406
+ # timestep embedding
407
+ assert t is not None
408
+ temb = get_timestep_embedding(t, self.ch)
409
+ temb = self.temb.dense[0](temb)
410
+ temb = nonlinearity(temb)
411
+ temb = self.temb.dense[1](temb)
412
+ else:
413
+ temb = None
414
+
415
+ # downsampling
416
+ hs = [self.conv_in(x)]
417
+ for i_level in range(self.num_resolutions):
418
+ for i_block in range(self.num_res_blocks):
419
+ h = self.down[i_level].block[i_block](hs[-1], temb)
420
+ if len(self.down[i_level].attn) > 0:
421
+ h = self.down[i_level].attn[i_block](h)
422
+ hs.append(h)
423
+ if i_level != self.num_resolutions-1:
424
+ hs.append(self.down[i_level].downsample(hs[-1]))
425
+
426
+ # middle
427
+ h = hs[-1]
428
+ h = self.mid.block_1(h, temb)
429
+ h = self.mid.attn_1(h)
430
+ h = self.mid.block_2(h, temb)
431
+
432
+ # upsampling
433
+ for i_level in reversed(range(self.num_resolutions)):
434
+ for i_block in range(self.num_res_blocks+1):
435
+ h = self.up[i_level].block[i_block](
436
+ torch.cat([h, hs.pop()], dim=1), temb)
437
+ if len(self.up[i_level].attn) > 0:
438
+ h = self.up[i_level].attn[i_block](h)
439
+ if i_level != 0:
440
+ h = self.up[i_level].upsample(h)
441
+
442
+ # end
443
+ h = self.norm_out(h)
444
+ h = nonlinearity(h)
445
+ h = self.conv_out(h)
446
+ return h
447
+
448
+ def get_last_layer(self):
449
+ return self.conv_out.weight
450
+
451
+
452
+ class Encoder(nn.Module):
453
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
454
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
455
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
456
+ **ignore_kwargs):
457
+ super().__init__()
458
+ if use_linear_attn: attn_type = "linear"
459
+ self.ch = ch
460
+ self.temb_ch = 0
461
+ self.num_resolutions = len(ch_mult)
462
+ self.num_res_blocks = num_res_blocks
463
+ self.resolution = resolution
464
+ self.in_channels = in_channels
465
+
466
+ # downsampling
467
+ self.conv_in = torch.nn.Conv2d(in_channels,
468
+ self.ch,
469
+ kernel_size=3,
470
+ stride=1,
471
+ padding=1)
472
+
473
+ curr_res = resolution
474
+ in_ch_mult = (1,)+tuple(ch_mult)
475
+ self.in_ch_mult = in_ch_mult
476
+ self.down = nn.ModuleList()
477
+ for i_level in range(self.num_resolutions):
478
+ block = nn.ModuleList()
479
+ attn = nn.ModuleList()
480
+ block_in = ch*in_ch_mult[i_level]
481
+ block_out = ch*ch_mult[i_level]
482
+ for i_block in range(self.num_res_blocks):
483
+ block.append(ResnetBlock(in_channels=block_in,
484
+ out_channels=block_out,
485
+ temb_channels=self.temb_ch,
486
+ dropout=dropout))
487
+ block_in = block_out
488
+ if curr_res in attn_resolutions:
489
+ attn.append(make_attn(block_in, attn_type=attn_type))
490
+ down = nn.Module()
491
+ down.block = block
492
+ down.attn = attn
493
+ if i_level != self.num_resolutions-1:
494
+ down.downsample = Downsample(block_in, resamp_with_conv)
495
+ curr_res = curr_res // 2
496
+ self.down.append(down)
497
+
498
+ # middle
499
+ self.mid = nn.Module()
500
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
501
+ out_channels=block_in,
502
+ temb_channels=self.temb_ch,
503
+ dropout=dropout)
504
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
505
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
506
+ out_channels=block_in,
507
+ temb_channels=self.temb_ch,
508
+ dropout=dropout)
509
+
510
+ # end
511
+ self.norm_out = Normalize(block_in)
512
+ self.conv_out = torch.nn.Conv2d(block_in,
513
+ 2*z_channels if double_z else z_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1)
517
+
518
+ def forward(self, x):
519
+ # timestep embedding
520
+ temb = None
521
+
522
+ # downsampling
523
+ hs = [self.conv_in(x)]
524
+ for i_level in range(self.num_resolutions):
525
+ for i_block in range(self.num_res_blocks):
526
+ h = self.down[i_level].block[i_block](hs[-1], temb)
527
+ if len(self.down[i_level].attn) > 0:
528
+ h = self.down[i_level].attn[i_block](h)
529
+ hs.append(h)
530
+ if i_level != self.num_resolutions-1:
531
+ hs.append(self.down[i_level].downsample(hs[-1]))
532
+
533
+ # middle
534
+ h = hs[-1]
535
+ h = self.mid.block_1(h, temb)
536
+ h = self.mid.attn_1(h)
537
+ h = self.mid.block_2(h, temb)
538
+
539
+ # end
540
+ h = self.norm_out(h)
541
+ h = nonlinearity(h)
542
+ h = self.conv_out(h)
543
+ return h
544
+
545
+
546
+ class Decoder(nn.Module):
547
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
548
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
549
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
550
+ attn_type="vanilla", **ignorekwargs):
551
+ super().__init__()
552
+ if use_linear_attn: attn_type = "linear"
553
+ self.ch = ch
554
+ self.temb_ch = 0
555
+ self.num_resolutions = len(ch_mult)
556
+ self.num_res_blocks = num_res_blocks
557
+ self.resolution = resolution
558
+ self.in_channels = in_channels
559
+ self.give_pre_end = give_pre_end
560
+ self.tanh_out = tanh_out
561
+
562
+ # compute in_ch_mult, block_in and curr_res at lowest res
563
+ in_ch_mult = (1,)+tuple(ch_mult)
564
+ block_in = ch*ch_mult[self.num_resolutions-1]
565
+ curr_res = resolution // 2**(self.num_resolutions-1)
566
+ self.z_shape = (1,z_channels,curr_res,curr_res)
567
+ print("Working with z of shape {} = {} dimensions.".format(
568
+ self.z_shape, np.prod(self.z_shape)))
569
+
570
+ # z to block_in
571
+ self.conv_in = torch.nn.Conv2d(z_channels,
572
+ block_in,
573
+ kernel_size=3,
574
+ stride=1,
575
+ padding=1)
576
+
577
+ # middle
578
+ self.mid = nn.Module()
579
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
580
+ out_channels=block_in,
581
+ temb_channels=self.temb_ch,
582
+ dropout=dropout)
583
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
584
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
585
+ out_channels=block_in,
586
+ temb_channels=self.temb_ch,
587
+ dropout=dropout)
588
+
589
+ # upsampling
590
+ self.up = nn.ModuleList()
591
+ for i_level in reversed(range(self.num_resolutions)):
592
+ block = nn.ModuleList()
593
+ attn = nn.ModuleList()
594
+ block_out = ch*ch_mult[i_level]
595
+ for i_block in range(self.num_res_blocks+1):
596
+ block.append(ResnetBlock(in_channels=block_in,
597
+ out_channels=block_out,
598
+ temb_channels=self.temb_ch,
599
+ dropout=dropout))
600
+ block_in = block_out
601
+ if curr_res in attn_resolutions:
602
+ attn.append(make_attn(block_in, attn_type=attn_type))
603
+ up = nn.Module()
604
+ up.block = block
605
+ up.attn = attn
606
+ if i_level != 0:
607
+ up.upsample = Upsample(block_in, resamp_with_conv)
608
+ curr_res = curr_res * 2
609
+ self.up.insert(0, up) # prepend to get consistent order
610
+
611
+ # end
612
+ self.norm_out = Normalize(block_in)
613
+ self.conv_out = torch.nn.Conv2d(block_in,
614
+ out_ch,
615
+ kernel_size=3,
616
+ stride=1,
617
+ padding=1)
618
+
619
+ def forward(self, z):
620
+ #assert z.shape[1:] == self.z_shape[1:]
621
+ self.last_z_shape = z.shape
622
+
623
+ # timestep embedding
624
+ temb = None
625
+
626
+ # z to block_in
627
+ h = self.conv_in(z)
628
+
629
+ # middle
630
+ h = self.mid.block_1(h, temb)
631
+ h = self.mid.attn_1(h)
632
+ h = self.mid.block_2(h, temb)
633
+
634
+ # upsampling
635
+ for i_level in reversed(range(self.num_resolutions)):
636
+ for i_block in range(self.num_res_blocks+1):
637
+ h = self.up[i_level].block[i_block](h, temb)
638
+ if len(self.up[i_level].attn) > 0:
639
+ h = self.up[i_level].attn[i_block](h)
640
+ if i_level != 0:
641
+ h = self.up[i_level].upsample(h)
642
+
643
+ # end
644
+ if self.give_pre_end:
645
+ return h
646
+
647
+ h = self.norm_out(h)
648
+ h = nonlinearity(h)
649
+ h = self.conv_out(h)
650
+ if self.tanh_out:
651
+ h = torch.tanh(h)
652
+ return h
653
+
654
+
655
+ class SimpleDecoder(nn.Module):
656
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
657
+ super().__init__()
658
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
659
+ ResnetBlock(in_channels=in_channels,
660
+ out_channels=2 * in_channels,
661
+ temb_channels=0, dropout=0.0),
662
+ ResnetBlock(in_channels=2 * in_channels,
663
+ out_channels=4 * in_channels,
664
+ temb_channels=0, dropout=0.0),
665
+ ResnetBlock(in_channels=4 * in_channels,
666
+ out_channels=2 * in_channels,
667
+ temb_channels=0, dropout=0.0),
668
+ nn.Conv2d(2*in_channels, in_channels, 1),
669
+ Upsample(in_channels, with_conv=True)])
670
+ # end
671
+ self.norm_out = Normalize(in_channels)
672
+ self.conv_out = torch.nn.Conv2d(in_channels,
673
+ out_channels,
674
+ kernel_size=3,
675
+ stride=1,
676
+ padding=1)
677
+
678
+ def forward(self, x):
679
+ for i, layer in enumerate(self.model):
680
+ if i in [1,2,3]:
681
+ x = layer(x, None)
682
+ else:
683
+ x = layer(x)
684
+
685
+ h = self.norm_out(x)
686
+ h = nonlinearity(h)
687
+ x = self.conv_out(h)
688
+ return x
689
+
690
+
691
+ class UpsampleDecoder(nn.Module):
692
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
693
+ ch_mult=(2,2), dropout=0.0):
694
+ super().__init__()
695
+ # upsampling
696
+ self.temb_ch = 0
697
+ self.num_resolutions = len(ch_mult)
698
+ self.num_res_blocks = num_res_blocks
699
+ block_in = in_channels
700
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
701
+ self.res_blocks = nn.ModuleList()
702
+ self.upsample_blocks = nn.ModuleList()
703
+ for i_level in range(self.num_resolutions):
704
+ res_block = []
705
+ block_out = ch * ch_mult[i_level]
706
+ for i_block in range(self.num_res_blocks + 1):
707
+ res_block.append(ResnetBlock(in_channels=block_in,
708
+ out_channels=block_out,
709
+ temb_channels=self.temb_ch,
710
+ dropout=dropout))
711
+ block_in = block_out
712
+ self.res_blocks.append(nn.ModuleList(res_block))
713
+ if i_level != self.num_resolutions - 1:
714
+ self.upsample_blocks.append(Upsample(block_in, True))
715
+ curr_res = curr_res * 2
716
+
717
+ # end
718
+ self.norm_out = Normalize(block_in)
719
+ self.conv_out = torch.nn.Conv2d(block_in,
720
+ out_channels,
721
+ kernel_size=3,
722
+ stride=1,
723
+ padding=1)
724
+
725
+ def forward(self, x):
726
+ # upsampling
727
+ h = x
728
+ for k, i_level in enumerate(range(self.num_resolutions)):
729
+ for i_block in range(self.num_res_blocks + 1):
730
+ h = self.res_blocks[i_level][i_block](h, None)
731
+ if i_level != self.num_resolutions - 1:
732
+ h = self.upsample_blocks[k](h)
733
+ h = self.norm_out(h)
734
+ h = nonlinearity(h)
735
+ h = self.conv_out(h)
736
+ return h
737
+
738
+
739
+ class LatentRescaler(nn.Module):
740
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
741
+ super().__init__()
742
+ # residual block, interpolate, residual block
743
+ self.factor = factor
744
+ self.conv_in = nn.Conv2d(in_channels,
745
+ mid_channels,
746
+ kernel_size=3,
747
+ stride=1,
748
+ padding=1)
749
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
750
+ out_channels=mid_channels,
751
+ temb_channels=0,
752
+ dropout=0.0) for _ in range(depth)])
753
+ self.attn = AttnBlock(mid_channels)
754
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
755
+ out_channels=mid_channels,
756
+ temb_channels=0,
757
+ dropout=0.0) for _ in range(depth)])
758
+
759
+ self.conv_out = nn.Conv2d(mid_channels,
760
+ out_channels,
761
+ kernel_size=1,
762
+ )
763
+
764
+ def forward(self, x):
765
+ x = self.conv_in(x)
766
+ for block in self.res_block1:
767
+ x = block(x, None)
768
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
769
+ x = self.attn(x)
770
+ for block in self.res_block2:
771
+ x = block(x, None)
772
+ x = self.conv_out(x)
773
+ return x
774
+
775
+
776
+ class MergedRescaleEncoder(nn.Module):
777
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
778
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
779
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
780
+ super().__init__()
781
+ intermediate_chn = ch * ch_mult[-1]
782
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
783
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
784
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
785
+ out_ch=None)
786
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
787
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
788
+
789
+ def forward(self, x):
790
+ x = self.encoder(x)
791
+ x = self.rescaler(x)
792
+ return x
793
+
794
+
795
+ class MergedRescaleDecoder(nn.Module):
796
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
797
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
798
+ super().__init__()
799
+ tmp_chn = z_channels*ch_mult[-1]
800
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
801
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
802
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
803
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
804
+ out_channels=tmp_chn, depth=rescale_module_depth)
805
+
806
+ def forward(self, x):
807
+ x = self.rescaler(x)
808
+ x = self.decoder(x)
809
+ return x
810
+
811
+
812
+ class Upsampler(nn.Module):
813
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
814
+ super().__init__()
815
+ assert out_size >= in_size
816
+ num_blocks = int(np.log2(out_size//in_size))+1
817
+ factor_up = 1.+ (out_size % in_size)
818
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
819
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
820
+ out_channels=in_channels)
821
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
822
+ attn_resolutions=[], in_channels=None, ch=in_channels,
823
+ ch_mult=[ch_mult for _ in range(num_blocks)])
824
+
825
+ def forward(self, x):
826
+ x = self.rescaler(x)
827
+ x = self.decoder(x)
828
+ return x
829
+
830
+
831
+ class Resize(nn.Module):
832
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
833
+ super().__init__()
834
+ self.with_conv = learned
835
+ self.mode = mode
836
+ if self.with_conv:
837
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
838
+ raise NotImplementedError()
839
+ assert in_channels is not None
840
+ # no asymmetric padding in torch conv, must do it ourselves
841
+ self.conv = torch.nn.Conv2d(in_channels,
842
+ in_channels,
843
+ kernel_size=4,
844
+ stride=2,
845
+ padding=1)
846
+
847
+ def forward(self, x, scale_factor=1.0):
848
+ if scale_factor==1.0:
849
+ return x
850
+ else:
851
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
852
+ return x
ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ldm.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from ldm.modules.attention import SpatialTransformer
19
+ from ldm.util import exists
20
+
21
+
22
+ # dummy replace
23
+ def convert_module_to_f16(x):
24
+ pass
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
45
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
46
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
47
+ self.num_heads = embed_dim // num_heads_channels
48
+ self.attention = QKVAttention(self.num_heads)
49
+
50
+ def forward(self, x):
51
+ b, c, *_spatial = x.shape
52
+ x = x.reshape(b, c, -1) # NC(HW)
53
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
54
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
55
+ x = self.qkv_proj(x)
56
+ x = self.attention(x)
57
+ x = self.c_proj(x)
58
+ return x[:, :, 0]
59
+
60
+
61
+ class TimestepBlock(nn.Module):
62
+ """
63
+ Any module where forward() takes timestep embeddings as a second argument.
64
+ """
65
+
66
+ @abstractmethod
67
+ def forward(self, x, emb):
68
+ """
69
+ Apply the module to `x` given `emb` timestep embeddings.
70
+ """
71
+
72
+
73
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
74
+ """
75
+ A sequential module that passes timestep embeddings to the children that
76
+ support it as an extra input.
77
+ """
78
+
79
+ def forward(self, x, emb, context=None):
80
+ for layer in self:
81
+ if isinstance(layer, TimestepBlock):
82
+ x = layer(x, emb)
83
+ elif isinstance(layer, SpatialTransformer):
84
+ x = layer(x, context)
85
+ else:
86
+ x = layer(x)
87
+ return x
88
+
89
+
90
+ class Upsample(nn.Module):
91
+ """
92
+ An upsampling layer with an optional convolution.
93
+ :param channels: channels in the inputs and outputs.
94
+ :param use_conv: a bool determining if a convolution is applied.
95
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
96
+ upsampling occurs in the inner-two dimensions.
97
+ """
98
+
99
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
100
+ super().__init__()
101
+ self.channels = channels
102
+ self.out_channels = out_channels or channels
103
+ self.use_conv = use_conv
104
+ self.dims = dims
105
+ if use_conv:
106
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
107
+
108
+ def forward(self, x):
109
+ assert x.shape[1] == self.channels
110
+ if self.dims == 3:
111
+ x = F.interpolate(
112
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
113
+ )
114
+ else:
115
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
116
+ if self.use_conv:
117
+ x = self.conv(x)
118
+ return x
119
+
120
+ class TransposedUpsample(nn.Module):
121
+ 'Learned 2x upsampling without padding'
122
+ def __init__(self, channels, out_channels=None, ks=5):
123
+ super().__init__()
124
+ self.channels = channels
125
+ self.out_channels = out_channels or channels
126
+
127
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
128
+
129
+ def forward(self,x):
130
+ return self.up(x)
131
+
132
+
133
+ class Downsample(nn.Module):
134
+ """
135
+ A downsampling layer with an optional convolution.
136
+ :param channels: channels in the inputs and outputs.
137
+ :param use_conv: a bool determining if a convolution is applied.
138
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
139
+ downsampling occurs in the inner-two dimensions.
140
+ """
141
+
142
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
143
+ super().__init__()
144
+ self.channels = channels
145
+ self.out_channels = out_channels or channels
146
+ self.use_conv = use_conv
147
+ self.dims = dims
148
+ stride = 2 if dims != 3 else (1, 2, 2)
149
+ if use_conv:
150
+ self.op = conv_nd(
151
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
152
+ )
153
+ else:
154
+ assert self.channels == self.out_channels
155
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
156
+
157
+ def forward(self, x):
158
+ assert x.shape[1] == self.channels
159
+ return self.op(x)
160
+
161
+
162
+ class ResBlock(TimestepBlock):
163
+ """
164
+ A residual block that can optionally change the number of channels.
165
+ :param channels: the number of input channels.
166
+ :param emb_channels: the number of timestep embedding channels.
167
+ :param dropout: the rate of dropout.
168
+ :param out_channels: if specified, the number of out channels.
169
+ :param use_conv: if True and out_channels is specified, use a spatial
170
+ convolution instead of a smaller 1x1 convolution to change the
171
+ channels in the skip connection.
172
+ :param dims: determines if the signal is 1D, 2D, or 3D.
173
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
174
+ :param up: if True, use this block for upsampling.
175
+ :param down: if True, use this block for downsampling.
176
+ """
177
+
178
+ def __init__(
179
+ self,
180
+ channels,
181
+ emb_channels,
182
+ dropout,
183
+ out_channels=None,
184
+ use_conv=False,
185
+ use_scale_shift_norm=False,
186
+ dims=2,
187
+ use_checkpoint=False,
188
+ up=False,
189
+ down=False,
190
+ ):
191
+ super().__init__()
192
+ self.channels = channels
193
+ self.emb_channels = emb_channels
194
+ self.dropout = dropout
195
+ self.out_channels = out_channels or channels
196
+ self.use_conv = use_conv
197
+ self.use_checkpoint = use_checkpoint
198
+ self.use_scale_shift_norm = use_scale_shift_norm
199
+
200
+ self.in_layers = nn.Sequential(
201
+ normalization(channels),
202
+ nn.SiLU(),
203
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
204
+ )
205
+
206
+ self.updown = up or down
207
+
208
+ if up:
209
+ self.h_upd = Upsample(channels, False, dims)
210
+ self.x_upd = Upsample(channels, False, dims)
211
+ elif down:
212
+ self.h_upd = Downsample(channels, False, dims)
213
+ self.x_upd = Downsample(channels, False, dims)
214
+ else:
215
+ self.h_upd = self.x_upd = nn.Identity()
216
+
217
+ self.emb_layers = nn.Sequential(
218
+ nn.SiLU(),
219
+ linear(
220
+ emb_channels,
221
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
222
+ ),
223
+ )
224
+ self.out_layers = nn.Sequential(
225
+ normalization(self.out_channels),
226
+ nn.SiLU(),
227
+ nn.Dropout(p=dropout),
228
+ zero_module(
229
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
230
+ ),
231
+ )
232
+
233
+ if self.out_channels == channels:
234
+ self.skip_connection = nn.Identity()
235
+ elif use_conv:
236
+ self.skip_connection = conv_nd(
237
+ dims, channels, self.out_channels, 3, padding=1
238
+ )
239
+ else:
240
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
+
242
+ def forward(self, x, emb):
243
+ """
244
+ Apply the block to a Tensor, conditioned on a timestep embedding.
245
+ :param x: an [N x C x ...] Tensor of features.
246
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
247
+ :return: an [N x C x ...] Tensor of outputs.
248
+ """
249
+ return checkpoint(
250
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
251
+ )
252
+
253
+
254
+ def _forward(self, x, emb):
255
+ if self.updown:
256
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
257
+ h = in_rest(x)
258
+ h = self.h_upd(h)
259
+ x = self.x_upd(x)
260
+ h = in_conv(h)
261
+ else:
262
+ h = self.in_layers(x)
263
+ emb_out = self.emb_layers(emb).type(h.dtype)
264
+ while len(emb_out.shape) < len(h.shape):
265
+ emb_out = emb_out[..., None]
266
+ if self.use_scale_shift_norm:
267
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
268
+ scale, shift = th.chunk(emb_out, 2, dim=1)
269
+ h = out_norm(h) * (1 + scale) + shift
270
+ h = out_rest(h)
271
+ else:
272
+ h = h + emb_out
273
+ h = self.out_layers(h)
274
+ return self.skip_connection(x) + h
275
+
276
+
277
+ class AttentionBlock(nn.Module):
278
+ """
279
+ An attention block that allows spatial positions to attend to each other.
280
+ Originally ported from here, but adapted to the N-d case.
281
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
282
+ """
283
+
284
+ def __init__(
285
+ self,
286
+ channels,
287
+ num_heads=1,
288
+ num_head_channels=-1,
289
+ use_checkpoint=False,
290
+ use_new_attention_order=False,
291
+ ):
292
+ super().__init__()
293
+ self.channels = channels
294
+ if num_head_channels == -1:
295
+ self.num_heads = num_heads
296
+ else:
297
+ assert (
298
+ channels % num_head_channels == 0
299
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
300
+ self.num_heads = channels // num_head_channels
301
+ self.use_checkpoint = use_checkpoint
302
+ self.norm = normalization(channels)
303
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
304
+ if use_new_attention_order:
305
+ # split qkv before split heads
306
+ self.attention = QKVAttention(self.num_heads)
307
+ else:
308
+ # split heads before split qkv
309
+ self.attention = QKVAttentionLegacy(self.num_heads)
310
+
311
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
312
+
313
+ def forward(self, x):
314
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
315
+ #return pt_checkpoint(self._forward, x) # pytorch
316
+
317
+ def _forward(self, x):
318
+ b, c, *spatial = x.shape
319
+ x = x.reshape(b, c, -1)
320
+ qkv = self.qkv(self.norm(x))
321
+ h = self.attention(qkv)
322
+ h = self.proj_out(h)
323
+ return (x + h).reshape(b, c, *spatial)
324
+
325
+
326
+ def count_flops_attn(model, _x, y):
327
+ """
328
+ A counter for the `thop` package to count the operations in an
329
+ attention operation.
330
+ Meant to be used like:
331
+ macs, params = thop.profile(
332
+ model,
333
+ inputs=(inputs, timestamps),
334
+ custom_ops={QKVAttention: QKVAttention.count_flops},
335
+ )
336
+ """
337
+ b, c, *spatial = y[0].shape
338
+ num_spatial = int(np.prod(spatial))
339
+ # We perform two matmuls with the same number of ops.
340
+ # The first computes the weight matrix, the second computes
341
+ # the combination of the value vectors.
342
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
343
+ model.total_ops += th.DoubleTensor([matmul_ops])
344
+
345
+
346
+ class QKVAttentionLegacy(nn.Module):
347
+ """
348
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
349
+ """
350
+
351
+ def __init__(self, n_heads):
352
+ super().__init__()
353
+ self.n_heads = n_heads
354
+
355
+ def forward(self, qkv):
356
+ """
357
+ Apply QKV attention.
358
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
359
+ :return: an [N x (H * C) x T] tensor after attention.
360
+ """
361
+ bs, width, length = qkv.shape
362
+ assert width % (3 * self.n_heads) == 0
363
+ ch = width // (3 * self.n_heads)
364
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
365
+ scale = 1 / math.sqrt(math.sqrt(ch))
366
+ weight = th.einsum(
367
+ "bct,bcs->bts", q * scale, k * scale
368
+ ) # More stable with f16 than dividing afterwards
369
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
370
+ a = th.einsum("bts,bcs->bct", weight, v)
371
+ return a.reshape(bs, -1, length)
372
+
373
+ @staticmethod
374
+ def count_flops(model, _x, y):
375
+ return count_flops_attn(model, _x, y)
376
+
377
+
378
+ class QKVAttention(nn.Module):
379
+ """
380
+ A module which performs QKV attention and splits in a different order.
381
+ """
382
+
383
+ def __init__(self, n_heads):
384
+ super().__init__()
385
+ self.n_heads = n_heads
386
+
387
+ def forward(self, qkv):
388
+ """
389
+ Apply QKV attention.
390
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
391
+ :return: an [N x (H * C) x T] tensor after attention.
392
+ """
393
+ bs, width, length = qkv.shape
394
+ assert width % (3 * self.n_heads) == 0
395
+ ch = width // (3 * self.n_heads)
396
+ q, k, v = qkv.chunk(3, dim=1)
397
+ scale = 1 / math.sqrt(math.sqrt(ch))
398
+ weight = th.einsum(
399
+ "bct,bcs->bts",
400
+ (q * scale).view(bs * self.n_heads, ch, length),
401
+ (k * scale).view(bs * self.n_heads, ch, length),
402
+ ) # More stable with f16 than dividing afterwards
403
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
404
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
405
+ return a.reshape(bs, -1, length)
406
+
407
+ @staticmethod
408
+ def count_flops(model, _x, y):
409
+ return count_flops_attn(model, _x, y)
410
+
411
+
412
+ class UNetModel(nn.Module):
413
+ """
414
+ The full UNet model with attention and timestep embedding.
415
+ :param in_channels: channels in the input Tensor.
416
+ :param model_channels: base channel count for the model.
417
+ :param out_channels: channels in the output Tensor.
418
+ :param num_res_blocks: number of residual blocks per downsample.
419
+ :param attention_resolutions: a collection of downsample rates at which
420
+ attention will take place. May be a set, list, or tuple.
421
+ For example, if this contains 4, then at 4x downsampling, attention
422
+ will be used.
423
+ :param dropout: the dropout probability.
424
+ :param channel_mult: channel multiplier for each level of the UNet.
425
+ :param conv_resample: if True, use learned convolutions for upsampling and
426
+ downsampling.
427
+ :param dims: determines if the signal is 1D, 2D, or 3D.
428
+ :param num_classes: if specified (as an int), then this model will be
429
+ class-conditional with `num_classes` classes.
430
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
431
+ :param num_heads: the number of attention heads in each attention layer.
432
+ :param num_heads_channels: if specified, ignore num_heads and instead use
433
+ a fixed channel width per attention head.
434
+ :param num_heads_upsample: works with num_heads to set a different number
435
+ of heads for upsampling. Deprecated.
436
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
437
+ :param resblock_updown: use residual blocks for up/downsampling.
438
+ :param use_new_attention_order: use a different attention pattern for potentially
439
+ increased efficiency.
440
+ """
441
+
442
+ def __init__(
443
+ self,
444
+ image_size,
445
+ in_channels,
446
+ model_channels,
447
+ out_channels,
448
+ num_res_blocks,
449
+ attention_resolutions,
450
+ dropout=0,
451
+ channel_mult=(1, 2, 4, 8),
452
+ conv_resample=True,
453
+ dims=2,
454
+ num_classes=None,
455
+ use_checkpoint=False,
456
+ use_fp16=False,
457
+ num_heads=-1,
458
+ num_head_channels=-1,
459
+ num_heads_upsample=-1,
460
+ use_scale_shift_norm=False,
461
+ resblock_updown=False,
462
+ use_new_attention_order=False,
463
+ use_spatial_transformer=False, # custom transformer support
464
+ transformer_depth=1, # custom transformer support
465
+ context_dim=None, # custom transformer support
466
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
467
+ legacy=True,
468
+ disable_self_attentions=None,
469
+ num_attention_blocks=None,
470
+ disable_middle_self_attn=False,
471
+ use_linear_in_transformer=False,
472
+ ):
473
+ super().__init__()
474
+ if use_spatial_transformer:
475
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
476
+
477
+ if context_dim is not None:
478
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
479
+ from omegaconf.listconfig import ListConfig
480
+ if type(context_dim) == ListConfig:
481
+ context_dim = list(context_dim)
482
+
483
+ if num_heads_upsample == -1:
484
+ num_heads_upsample = num_heads
485
+
486
+ if num_heads == -1:
487
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
488
+
489
+ if num_head_channels == -1:
490
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
491
+
492
+ self.image_size = image_size
493
+ self.in_channels = in_channels
494
+ self.model_channels = model_channels
495
+ self.out_channels = out_channels
496
+ if isinstance(num_res_blocks, int):
497
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
498
+ else:
499
+ if len(num_res_blocks) != len(channel_mult):
500
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
501
+ "as a list/tuple (per-level) with the same length as channel_mult")
502
+ self.num_res_blocks = num_res_blocks
503
+ if disable_self_attentions is not None:
504
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
505
+ assert len(disable_self_attentions) == len(channel_mult)
506
+ if num_attention_blocks is not None:
507
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
508
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
509
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
510
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
511
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
512
+ f"attention will still not be set.")
513
+
514
+ self.attention_resolutions = attention_resolutions
515
+ self.dropout = dropout
516
+ self.channel_mult = channel_mult
517
+ self.conv_resample = conv_resample
518
+ self.num_classes = num_classes
519
+ self.use_checkpoint = use_checkpoint
520
+ self.dtype = th.float16 if use_fp16 else th.float32
521
+ self.num_heads = num_heads
522
+ self.num_head_channels = num_head_channels
523
+ self.num_heads_upsample = num_heads_upsample
524
+ self.predict_codebook_ids = n_embed is not None
525
+
526
+ time_embed_dim = model_channels * 4
527
+ self.time_embed = nn.Sequential(
528
+ linear(model_channels, time_embed_dim),
529
+ nn.SiLU(),
530
+ linear(time_embed_dim, time_embed_dim),
531
+ )
532
+
533
+ if self.num_classes is not None:
534
+ if isinstance(self.num_classes, int):
535
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
536
+ elif self.num_classes == "continuous":
537
+ print("setting up linear c_adm embedding layer")
538
+ self.label_emb = nn.Linear(1, time_embed_dim)
539
+ else:
540
+ raise ValueError()
541
+
542
+ self.input_blocks = nn.ModuleList(
543
+ [
544
+ TimestepEmbedSequential(
545
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
546
+ )
547
+ ]
548
+ )
549
+ self._feature_size = model_channels
550
+ input_block_chans = [model_channels]
551
+ ch = model_channels
552
+ ds = 1
553
+ for level, mult in enumerate(channel_mult):
554
+ for nr in range(self.num_res_blocks[level]):
555
+ layers = [
556
+ ResBlock(
557
+ ch,
558
+ time_embed_dim,
559
+ dropout,
560
+ out_channels=mult * model_channels,
561
+ dims=dims,
562
+ use_checkpoint=use_checkpoint,
563
+ use_scale_shift_norm=use_scale_shift_norm,
564
+ )
565
+ ]
566
+ ch = mult * model_channels
567
+ if ds in attention_resolutions:
568
+ if num_head_channels == -1:
569
+ dim_head = ch // num_heads
570
+ else:
571
+ num_heads = ch // num_head_channels
572
+ dim_head = num_head_channels
573
+ if legacy:
574
+ #num_heads = 1
575
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
576
+ if exists(disable_self_attentions):
577
+ disabled_sa = disable_self_attentions[level]
578
+ else:
579
+ disabled_sa = False
580
+
581
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
582
+ layers.append(
583
+ AttentionBlock(
584
+ ch,
585
+ use_checkpoint=use_checkpoint,
586
+ num_heads=num_heads,
587
+ num_head_channels=dim_head,
588
+ use_new_attention_order=use_new_attention_order,
589
+ ) if not use_spatial_transformer else SpatialTransformer(
590
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
591
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
592
+ use_checkpoint=use_checkpoint
593
+ )
594
+ )
595
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
596
+ self._feature_size += ch
597
+ input_block_chans.append(ch)
598
+ if level != len(channel_mult) - 1:
599
+ out_ch = ch
600
+ self.input_blocks.append(
601
+ TimestepEmbedSequential(
602
+ ResBlock(
603
+ ch,
604
+ time_embed_dim,
605
+ dropout,
606
+ out_channels=out_ch,
607
+ dims=dims,
608
+ use_checkpoint=use_checkpoint,
609
+ use_scale_shift_norm=use_scale_shift_norm,
610
+ down=True,
611
+ )
612
+ if resblock_updown
613
+ else Downsample(
614
+ ch, conv_resample, dims=dims, out_channels=out_ch
615
+ )
616
+ )
617
+ )
618
+ ch = out_ch
619
+ input_block_chans.append(ch)
620
+ ds *= 2
621
+ self._feature_size += ch
622
+
623
+ if num_head_channels == -1:
624
+ dim_head = ch // num_heads
625
+ else:
626
+ num_heads = ch // num_head_channels
627
+ dim_head = num_head_channels
628
+ if legacy:
629
+ #num_heads = 1
630
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
631
+ self.middle_block = TimestepEmbedSequential(
632
+ ResBlock(
633
+ ch,
634
+ time_embed_dim,
635
+ dropout,
636
+ dims=dims,
637
+ use_checkpoint=use_checkpoint,
638
+ use_scale_shift_norm=use_scale_shift_norm,
639
+ ),
640
+ AttentionBlock(
641
+ ch,
642
+ use_checkpoint=use_checkpoint,
643
+ num_heads=num_heads,
644
+ num_head_channels=dim_head,
645
+ use_new_attention_order=use_new_attention_order,
646
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
647
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
648
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
649
+ use_checkpoint=use_checkpoint
650
+ ),
651
+ ResBlock(
652
+ ch,
653
+ time_embed_dim,
654
+ dropout,
655
+ dims=dims,
656
+ use_checkpoint=use_checkpoint,
657
+ use_scale_shift_norm=use_scale_shift_norm,
658
+ ),
659
+ )
660
+ self._feature_size += ch
661
+
662
+ self.output_blocks = nn.ModuleList([])
663
+ for level, mult in list(enumerate(channel_mult))[::-1]:
664
+ for i in range(self.num_res_blocks[level] + 1):
665
+ ich = input_block_chans.pop()
666
+ layers = [
667
+ ResBlock(
668
+ ch + ich,
669
+ time_embed_dim,
670
+ dropout,
671
+ out_channels=model_channels * mult,
672
+ dims=dims,
673
+ use_checkpoint=use_checkpoint,
674
+ use_scale_shift_norm=use_scale_shift_norm,
675
+ )
676
+ ]
677
+ ch = model_channels * mult
678
+ if ds in attention_resolutions:
679
+ if num_head_channels == -1:
680
+ dim_head = ch // num_heads
681
+ else:
682
+ num_heads = ch // num_head_channels
683
+ dim_head = num_head_channels
684
+ if legacy:
685
+ #num_heads = 1
686
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
687
+ if exists(disable_self_attentions):
688
+ disabled_sa = disable_self_attentions[level]
689
+ else:
690
+ disabled_sa = False
691
+
692
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
693
+ layers.append(
694
+ AttentionBlock(
695
+ ch,
696
+ use_checkpoint=use_checkpoint,
697
+ num_heads=num_heads_upsample,
698
+ num_head_channels=dim_head,
699
+ use_new_attention_order=use_new_attention_order,
700
+ ) if not use_spatial_transformer else SpatialTransformer(
701
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
702
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
703
+ use_checkpoint=use_checkpoint
704
+ )
705
+ )
706
+ if level and i == self.num_res_blocks[level]:
707
+ out_ch = ch
708
+ layers.append(
709
+ ResBlock(
710
+ ch,
711
+ time_embed_dim,
712
+ dropout,
713
+ out_channels=out_ch,
714
+ dims=dims,
715
+ use_checkpoint=use_checkpoint,
716
+ use_scale_shift_norm=use_scale_shift_norm,
717
+ up=True,
718
+ )
719
+ if resblock_updown
720
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
721
+ )
722
+ ds //= 2
723
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
724
+ self._feature_size += ch
725
+
726
+ self.out = nn.Sequential(
727
+ normalization(ch),
728
+ nn.SiLU(),
729
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
730
+ )
731
+ if self.predict_codebook_ids:
732
+ self.id_predictor = nn.Sequential(
733
+ normalization(ch),
734
+ conv_nd(dims, model_channels, n_embed, 1),
735
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
736
+ )
737
+
738
+ def convert_to_fp16(self):
739
+ """
740
+ Convert the torso of the model to float16.
741
+ """
742
+ self.input_blocks.apply(convert_module_to_f16)
743
+ self.middle_block.apply(convert_module_to_f16)
744
+ self.output_blocks.apply(convert_module_to_f16)
745
+
746
+ def convert_to_fp32(self):
747
+ """
748
+ Convert the torso of the model to float32.
749
+ """
750
+ self.input_blocks.apply(convert_module_to_f32)
751
+ self.middle_block.apply(convert_module_to_f32)
752
+ self.output_blocks.apply(convert_module_to_f32)
753
+
754
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
755
+ """
756
+ Apply the model to an input batch.
757
+ :param x: an [N x C x ...] Tensor of inputs.
758
+ :param timesteps: a 1-D batch of timesteps.
759
+ :param context: conditioning plugged in via crossattn
760
+ :param y: an [N] Tensor of labels, if class-conditional.
761
+ :return: an [N x C x ...] Tensor of outputs.
762
+ """
763
+ assert (y is not None) == (
764
+ self.num_classes is not None
765
+ ), "must specify y if and only if the model is class-conditional"
766
+ hs = []
767
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
768
+ emb = self.time_embed(t_emb)
769
+
770
+ if self.num_classes is not None:
771
+ assert y.shape[0] == x.shape[0]
772
+ emb = emb + self.label_emb(y)
773
+
774
+ h = x.type(self.dtype)
775
+ for module in self.input_blocks:
776
+ h = module(h, emb, context)
777
+ hs.append(h)
778
+ h = self.middle_block(h, emb, context)
779
+ for module in self.output_blocks:
780
+ h = th.cat([h, hs.pop()], dim=1)
781
+ h = module(h, emb, context)
782
+ h = h.type(x.dtype)
783
+ if self.predict_codebook_ids:
784
+ return self.id_predictor(h)
785
+ else:
786
+ return self.out(h)
ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
7
+ from ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None):
45
+ noise = default(noise, lambda: torch.randn_like(x_start))
46
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
47
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
48
+
49
+ def forward(self, x):
50
+ return x, None
51
+
52
+ def decode(self, x):
53
+ return x
54
+
55
+
56
+ class SimpleImageConcat(AbstractLowScaleModel):
57
+ # no noise level conditioning
58
+ def __init__(self):
59
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
60
+ self.max_noise_level = 0
61
+
62
+ def forward(self, x):
63
+ # fix to constant noise level
64
+ return x, torch.zeros(x.shape[0], device=x.device).long()
65
+
66
+
67
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
68
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
69
+ super().__init__(noise_schedule_config=noise_schedule_config)
70
+ self.max_noise_level = max_noise_level
71
+
72
+ def forward(self, x, noise_level=None):
73
+ if noise_level is None:
74
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
75
+ else:
76
+ assert isinstance(noise_level, torch.Tensor)
77
+ z = self.q_sample(x, noise_level)
78
+ return z, noise_level
79
+
80
+
81
+
ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from ldm.util import instantiate_from_config
19
+
20
+
21
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
22
+ if schedule == "linear":
23
+ betas = (
24
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
25
+ )
26
+
27
+ elif schedule == "cosine":
28
+ timesteps = (
29
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
30
+ )
31
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
32
+ alphas = torch.cos(alphas).pow(2)
33
+ alphas = alphas / alphas[0]
34
+ betas = 1 - alphas[1:] / alphas[:-1]
35
+ betas = np.clip(betas, a_min=0, a_max=0.999)
36
+
37
+ elif schedule == "sqrt_linear":
38
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
39
+ elif schedule == "sqrt":
40
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
41
+ else:
42
+ raise ValueError(f"schedule '{schedule}' unknown.")
43
+ return betas.numpy()
44
+
45
+
46
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
47
+ if ddim_discr_method == 'uniform':
48
+ c = num_ddpm_timesteps // num_ddim_timesteps
49
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
50
+ elif ddim_discr_method == 'quad':
51
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
52
+ else:
53
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
54
+
55
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
56
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
57
+ steps_out = ddim_timesteps + 1
58
+ if verbose:
59
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
60
+ return steps_out
61
+
62
+
63
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
64
+ # select alphas for computing the variance schedule
65
+ alphas = alphacums[ddim_timesteps]
66
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
67
+
68
+ # according the the formula provided in https://arxiv.org/abs/2010.02502
69
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
70
+ if verbose:
71
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
72
+ print(f'For the chosen value of eta, which is {eta}, '
73
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
74
+ return sigmas, alphas, alphas_prev
75
+
76
+
77
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
78
+ """
79
+ Create a beta schedule that discretizes the given alpha_t_bar function,
80
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
81
+ :param num_diffusion_timesteps: the number of betas to produce.
82
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
83
+ produces the cumulative product of (1-beta) up to that
84
+ part of the diffusion process.
85
+ :param max_beta: the maximum beta to use; use values lower than 1 to
86
+ prevent singularities.
87
+ """
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
93
+ return np.array(betas)
94
+
95
+
96
+ def extract_into_tensor(a, t, x_shape):
97
+ b, *_ = t.shape
98
+ out = a.gather(-1, t)
99
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
100
+
101
+
102
+ def checkpoint(func, inputs, params, flag):
103
+ """
104
+ Evaluate a function without caching intermediate activations, allowing for
105
+ reduced memory at the expense of extra compute in the backward pass.
106
+ :param func: the function to evaluate.
107
+ :param inputs: the argument sequence to pass to `func`.
108
+ :param params: a sequence of parameters `func` depends on but does not
109
+ explicitly take as arguments.
110
+ :param flag: if False, disable gradient checkpointing.
111
+ """
112
+ if flag:
113
+ args = tuple(inputs) + tuple(params)
114
+ return CheckpointFunction.apply(func, len(inputs), *args)
115
+ else:
116
+ return func(*inputs)
117
+
118
+
119
+ class CheckpointFunction(torch.autograd.Function):
120
+ @staticmethod
121
+ def forward(ctx, run_function, length, *args):
122
+ ctx.run_function = run_function
123
+ ctx.input_tensors = list(args[:length])
124
+ ctx.input_params = list(args[length:])
125
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
126
+ "dtype": torch.get_autocast_gpu_dtype(),
127
+ "cache_enabled": torch.is_autocast_cache_enabled()}
128
+ with torch.no_grad():
129
+ output_tensors = ctx.run_function(*ctx.input_tensors)
130
+ return output_tensors
131
+
132
+ @staticmethod
133
+ def backward(ctx, *output_grads):
134
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
135
+ with torch.enable_grad(), \
136
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
137
+ # Fixes a bug where the first op in run_function modifies the
138
+ # Tensor storage in place, which is not allowed for detach()'d
139
+ # Tensors.
140
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
141
+ output_tensors = ctx.run_function(*shallow_copies)
142
+ input_grads = torch.autograd.grad(
143
+ output_tensors,
144
+ ctx.input_tensors + ctx.input_params,
145
+ output_grads,
146
+ allow_unused=True,
147
+ )
148
+ del ctx.input_tensors
149
+ del ctx.input_params
150
+ del output_tensors
151
+ return (None, None) + input_grads
152
+
153
+
154
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
155
+ """
156
+ Create sinusoidal timestep embeddings.
157
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
158
+ These may be fractional.
159
+ :param dim: the dimension of the output.
160
+ :param max_period: controls the minimum frequency of the embeddings.
161
+ :return: an [N x dim] Tensor of positional embeddings.
162
+ """
163
+ if not repeat_only:
164
+ half = dim // 2
165
+ freqs = torch.exp(
166
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
167
+ ).to(device=timesteps.device)
168
+ args = timesteps[:, None].float() * freqs[None]
169
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
170
+ if dim % 2:
171
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
172
+ else:
173
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
174
+ return embedding
175
+
176
+
177
+ def zero_module(module):
178
+ """
179
+ Zero out the parameters of a module and return it.
180
+ """
181
+ for p in module.parameters():
182
+ p.detach().zero_()
183
+ return module
184
+
185
+
186
+ def scale_module(module, scale):
187
+ """
188
+ Scale the parameters of a module and return it.
189
+ """
190
+ for p in module.parameters():
191
+ p.detach().mul_(scale)
192
+ return module
193
+
194
+
195
+ def mean_flat(tensor):
196
+ """
197
+ Take the mean over all non-batch dimensions.
198
+ """
199
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
200
+
201
+
202
+ def normalization(channels):
203
+ """
204
+ Make a standard normalization layer.
205
+ :param channels: number of input channels.
206
+ :return: an nn.Module for normalization.
207
+ """
208
+ return GroupNorm32(32, channels)
209
+
210
+
211
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
212
+ class SiLU(nn.Module):
213
+ def forward(self, x):
214
+ return x * torch.sigmoid(x)
215
+
216
+
217
+ class GroupNorm32(nn.GroupNorm):
218
+ def forward(self, x):
219
+ return super().forward(x.float()).type(x.dtype)
220
+
221
+ def conv_nd(dims, *args, **kwargs):
222
+ """
223
+ Create a 1D, 2D, or 3D convolution module.
224
+ """
225
+ if dims == 1:
226
+ return nn.Conv1d(*args, **kwargs)
227
+ elif dims == 2:
228
+ return nn.Conv2d(*args, **kwargs)
229
+ elif dims == 3:
230
+ return nn.Conv3d(*args, **kwargs)
231
+ raise ValueError(f"unsupported dimensions: {dims}")
232
+
233
+
234
+ def linear(*args, **kwargs):
235
+ """
236
+ Create a linear module.
237
+ """
238
+ return nn.Linear(*args, **kwargs)
239
+
240
+
241
+ def avg_pool_nd(dims, *args, **kwargs):
242
+ """
243
+ Create a 1D, 2D, or 3D average pooling module.
244
+ """
245
+ if dims == 1:
246
+ return nn.AvgPool1d(*args, **kwargs)
247
+ elif dims == 2:
248
+ return nn.AvgPool2d(*args, **kwargs)
249
+ elif dims == 3:
250
+ return nn.AvgPool3d(*args, **kwargs)
251
+ raise ValueError(f"unsupported dimensions: {dims}")
252
+
253
+
254
+ class HybridConditioner(nn.Module):
255
+
256
+ def __init__(self, c_concat_config, c_crossattn_config):
257
+ super().__init__()
258
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
259
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
260
+
261
+ def forward(self, c_concat, c_crossattn):
262
+ c_concat = self.concat_conditioner(c_concat)
263
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
264
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
265
+
266
+
267
+ def noise_like(shape, device, repeat=False):
268
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
269
+ noise = lambda: torch.randn(shape, device=device)
270
+ return repeat_noise() if repeat else noise()
ldm/modules/distributions/__init__.py ADDED
File without changes
ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
ldm/modules/encoders/__init__.py ADDED
File without changes
ldm/modules/encoders/modules.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
6
+
7
+ import open_clip
8
+ from ldm.util import default, count_params
9
+
10
+
11
+ class AbstractEncoder(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def encode(self, *args, **kwargs):
16
+ raise NotImplementedError
17
+
18
+
19
+ class IdentityEncoder(AbstractEncoder):
20
+
21
+ def encode(self, x):
22
+ return x
23
+
24
+
25
+ class ClassEmbedder(nn.Module):
26
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
27
+ super().__init__()
28
+ self.key = key
29
+ self.embedding = nn.Embedding(n_classes, embed_dim)
30
+ self.n_classes = n_classes
31
+ self.ucg_rate = ucg_rate
32
+
33
+ def forward(self, batch, key=None, disable_dropout=False):
34
+ if key is None:
35
+ key = self.key
36
+ # this is for use in crossattn
37
+ c = batch[key][:, None]
38
+ if self.ucg_rate > 0. and not disable_dropout:
39
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
40
+ c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
41
+ c = c.long()
42
+ c = self.embedding(c)
43
+ return c
44
+
45
+ def get_unconditional_conditioning(self, bs, device="cuda"):
46
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
47
+ uc = torch.ones((bs,), device=device) * uc_class
48
+ uc = {self.key: uc}
49
+ return uc
50
+
51
+
52
+ def disabled_train(self, mode=True):
53
+ """Overwrite model.train with this function to make sure train/eval mode
54
+ does not change anymore."""
55
+ return self
56
+
57
+
58
+ class FrozenT5Embedder(AbstractEncoder):
59
+ """Uses the T5 transformer encoder for text"""
60
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
61
+ super().__init__()
62
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
63
+ self.transformer = T5EncoderModel.from_pretrained(version)
64
+ self.device = device
65
+ self.max_length = max_length # TODO: typical value?
66
+ if freeze:
67
+ self.freeze()
68
+
69
+ def freeze(self):
70
+ self.transformer = self.transformer.eval()
71
+ #self.train = disabled_train
72
+ for param in self.parameters():
73
+ param.requires_grad = False
74
+
75
+ def forward(self, text):
76
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
77
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
78
+ tokens = batch_encoding["input_ids"].to(self.device)
79
+ outputs = self.transformer(input_ids=tokens)
80
+
81
+ z = outputs.last_hidden_state
82
+ return z
83
+
84
+ def encode(self, text):
85
+ return self(text)
86
+
87
+
88
+ class FrozenCLIPEmbedder(AbstractEncoder):
89
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
90
+ LAYERS = [
91
+ "last",
92
+ "pooled",
93
+ "hidden"
94
+ ]
95
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
96
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
97
+ super().__init__()
98
+ assert layer in self.LAYERS
99
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
100
+ self.transformer = CLIPTextModel.from_pretrained(version)
101
+ self.device = device
102
+ self.max_length = max_length
103
+ if freeze:
104
+ self.freeze()
105
+ self.layer = layer
106
+ self.layer_idx = layer_idx
107
+ if layer == "hidden":
108
+ assert layer_idx is not None
109
+ assert 0 <= abs(layer_idx) <= 12
110
+
111
+ def freeze(self):
112
+ self.transformer = self.transformer.eval()
113
+ #self.train = disabled_train
114
+ for param in self.parameters():
115
+ param.requires_grad = False
116
+
117
+ def forward(self, text):
118
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
119
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
120
+ tokens = batch_encoding["input_ids"].to(self.device)
121
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
122
+ if self.layer == "last":
123
+ z = outputs.last_hidden_state
124
+ elif self.layer == "pooled":
125
+ z = outputs.pooler_output[:, None, :]
126
+ else:
127
+ z = outputs.hidden_states[self.layer_idx]
128
+ return z
129
+
130
+ def encode(self, text):
131
+ return self(text)
132
+
133
+
134
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
135
+ """
136
+ Uses the OpenCLIP transformer encoder for text
137
+ """
138
+ LAYERS = [
139
+ #"pooled",
140
+ "last",
141
+ "penultimate"
142
+ ]
143
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
144
+ freeze=True, layer="last"):
145
+ super().__init__()
146
+ assert layer in self.LAYERS
147
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
148
+ del model.visual
149
+ self.model = model
150
+
151
+ self.device = device
152
+ self.max_length = max_length
153
+ if freeze:
154
+ self.freeze()
155
+ self.layer = layer
156
+ if self.layer == "last":
157
+ self.layer_idx = 0
158
+ elif self.layer == "penultimate":
159
+ self.layer_idx = 1
160
+ else:
161
+ raise NotImplementedError()
162
+
163
+ def freeze(self):
164
+ self.model = self.model.eval()
165
+ for param in self.parameters():
166
+ param.requires_grad = False
167
+
168
+ def forward(self, text):
169
+ tokens = open_clip.tokenize(text)
170
+ z = self.encode_with_transformer(tokens.to(self.device))
171
+ return z
172
+
173
+ def encode_with_transformer(self, text):
174
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
175
+ x = x + self.model.positional_embedding
176
+ x = x.permute(1, 0, 2) # NLD -> LND
177
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
178
+ x = x.permute(1, 0, 2) # LND -> NLD
179
+ x = self.model.ln_final(x)
180
+ return x
181
+
182
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
183
+ for i, r in enumerate(self.model.transformer.resblocks):
184
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
185
+ break
186
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
187
+ x = checkpoint(r, x, attn_mask)
188
+ else:
189
+ x = r(x, attn_mask=attn_mask)
190
+ return x
191
+
192
+ def encode(self, text):
193
+ return self(text)
194
+
195
+
196
+ class FrozenCLIPT5Encoder(AbstractEncoder):
197
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
198
+ clip_max_length=77, t5_max_length=77):
199
+ super().__init__()
200
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
201
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
202
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
203
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
204
+
205
+ def encode(self, text):
206
+ return self(text)
207
+
208
+ def forward(self, text):
209
+ clip_z = self.clip_encoder.encode(text)
210
+ t5_z = self.t5_encoder.encode(text)
211
+ return [clip_z, t5_z]
212
+
213
+
ldm/modules/image_degradation/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2
+ from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
ldm/modules/image_degradation/bsrgan.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ # --------------------------------------------
4
+ # Super-Resolution
5
+ # --------------------------------------------
6
+ #
7
+ # Kai Zhang (cskaizhang@gmail.com)
8
+ # https://github.com/cszn
9
+ # From 2019/03--2021/08
10
+ # --------------------------------------------
11
+ """
12
+
13
+ import numpy as np
14
+ import cv2
15
+ import torch
16
+
17
+ from functools import partial
18
+ import random
19
+ from scipy import ndimage
20
+ import scipy
21
+ import scipy.stats as ss
22
+ from scipy.interpolate import interp2d
23
+ from scipy.linalg import orth
24
+ import albumentations
25
+
26
+ import ldm.modules.image_degradation.utils_image as util
27
+
28
+
29
+ def modcrop_np(img, sf):
30
+ '''
31
+ Args:
32
+ img: numpy image, WxH or WxHxC
33
+ sf: scale factor
34
+ Return:
35
+ cropped image
36
+ '''
37
+ w, h = img.shape[:2]
38
+ im = np.copy(img)
39
+ return im[:w - w % sf, :h - h % sf, ...]
40
+
41
+
42
+ """
43
+ # --------------------------------------------
44
+ # anisotropic Gaussian kernels
45
+ # --------------------------------------------
46
+ """
47
+
48
+
49
+ def analytic_kernel(k):
50
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
51
+ k_size = k.shape[0]
52
+ # Calculate the big kernels size
53
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
54
+ # Loop over the small kernel to fill the big one
55
+ for r in range(k_size):
56
+ for c in range(k_size):
57
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
58
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
59
+ crop = k_size // 2
60
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
61
+ # Normalize to 1
62
+ return cropped_big_k / cropped_big_k.sum()
63
+
64
+
65
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
66
+ """ generate an anisotropic Gaussian kernel
67
+ Args:
68
+ ksize : e.g., 15, kernel size
69
+ theta : [0, pi], rotation angle range
70
+ l1 : [0.1,50], scaling of eigenvalues
71
+ l2 : [0.1,l1], scaling of eigenvalues
72
+ If l1 = l2, will get an isotropic Gaussian kernel.
73
+ Returns:
74
+ k : kernel
75
+ """
76
+
77
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
78
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
79
+ D = np.array([[l1, 0], [0, l2]])
80
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
81
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
82
+
83
+ return k
84
+
85
+
86
+ def gm_blur_kernel(mean, cov, size=15):
87
+ center = size / 2.0 + 0.5
88
+ k = np.zeros([size, size])
89
+ for y in range(size):
90
+ for x in range(size):
91
+ cy = y - center + 1
92
+ cx = x - center + 1
93
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
94
+
95
+ k = k / np.sum(k)
96
+ return k
97
+
98
+
99
+ def shift_pixel(x, sf, upper_left=True):
100
+ """shift pixel for super-resolution with different scale factors
101
+ Args:
102
+ x: WxHxC or WxH
103
+ sf: scale factor
104
+ upper_left: shift direction
105
+ """
106
+ h, w = x.shape[:2]
107
+ shift = (sf - 1) * 0.5
108
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
109
+ if upper_left:
110
+ x1 = xv + shift
111
+ y1 = yv + shift
112
+ else:
113
+ x1 = xv - shift
114
+ y1 = yv - shift
115
+
116
+ x1 = np.clip(x1, 0, w - 1)
117
+ y1 = np.clip(y1, 0, h - 1)
118
+
119
+ if x.ndim == 2:
120
+ x = interp2d(xv, yv, x)(x1, y1)
121
+ if x.ndim == 3:
122
+ for i in range(x.shape[-1]):
123
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
124
+
125
+ return x
126
+
127
+
128
+ def blur(x, k):
129
+ '''
130
+ x: image, NxcxHxW
131
+ k: kernel, Nx1xhxw
132
+ '''
133
+ n, c = x.shape[:2]
134
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
135
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
136
+ k = k.repeat(1, c, 1, 1)
137
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
138
+ x = x.view(1, -1, x.shape[2], x.shape[3])
139
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
140
+ x = x.view(n, c, x.shape[2], x.shape[3])
141
+
142
+ return x
143
+
144
+
145
+ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
146
+ """"
147
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
148
+ # Kai Zhang
149
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
150
+ # max_var = 2.5 * sf
151
+ """
152
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
153
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
154
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
155
+ theta = np.random.rand() * np.pi # random theta
156
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
157
+
158
+ # Set COV matrix using Lambdas and Theta
159
+ LAMBDA = np.diag([lambda_1, lambda_2])
160
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
161
+ [np.sin(theta), np.cos(theta)]])
162
+ SIGMA = Q @ LAMBDA @ Q.T
163
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
164
+
165
+ # Set expectation position (shifting kernel for aligned image)
166
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
167
+ MU = MU[None, None, :, None]
168
+
169
+ # Create meshgrid for Gaussian
170
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
171
+ Z = np.stack([X, Y], 2)[:, :, :, None]
172
+
173
+ # Calcualte Gaussian for every pixel of the kernel
174
+ ZZ = Z - MU
175
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
176
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
177
+
178
+ # shift the kernel so it will be centered
179
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
180
+
181
+ # Normalize the kernel and return
182
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
183
+ kernel = raw_kernel / np.sum(raw_kernel)
184
+ return kernel
185
+
186
+
187
+ def fspecial_gaussian(hsize, sigma):
188
+ hsize = [hsize, hsize]
189
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
190
+ std = sigma
191
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
192
+ arg = -(x * x + y * y) / (2 * std * std)
193
+ h = np.exp(arg)
194
+ h[h < scipy.finfo(float).eps * h.max()] = 0
195
+ sumh = h.sum()
196
+ if sumh != 0:
197
+ h = h / sumh
198
+ return h
199
+
200
+
201
+ def fspecial_laplacian(alpha):
202
+ alpha = max([0, min([alpha, 1])])
203
+ h1 = alpha / (alpha + 1)
204
+ h2 = (1 - alpha) / (alpha + 1)
205
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
206
+ h = np.array(h)
207
+ return h
208
+
209
+
210
+ def fspecial(filter_type, *args, **kwargs):
211
+ '''
212
+ python code from:
213
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
214
+ '''
215
+ if filter_type == 'gaussian':
216
+ return fspecial_gaussian(*args, **kwargs)
217
+ if filter_type == 'laplacian':
218
+ return fspecial_laplacian(*args, **kwargs)
219
+
220
+
221
+ """
222
+ # --------------------------------------------
223
+ # degradation models
224
+ # --------------------------------------------
225
+ """
226
+
227
+
228
+ def bicubic_degradation(x, sf=3):
229
+ '''
230
+ Args:
231
+ x: HxWxC image, [0, 1]
232
+ sf: down-scale factor
233
+ Return:
234
+ bicubicly downsampled LR image
235
+ '''
236
+ x = util.imresize_np(x, scale=1 / sf)
237
+ return x
238
+
239
+
240
+ def srmd_degradation(x, k, sf=3):
241
+ ''' blur + bicubic downsampling
242
+ Args:
243
+ x: HxWxC image, [0, 1]
244
+ k: hxw, double
245
+ sf: down-scale factor
246
+ Return:
247
+ downsampled LR image
248
+ Reference:
249
+ @inproceedings{zhang2018learning,
250
+ title={Learning a single convolutional super-resolution network for multiple degradations},
251
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
252
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
253
+ pages={3262--3271},
254
+ year={2018}
255
+ }
256
+ '''
257
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
258
+ x = bicubic_degradation(x, sf=sf)
259
+ return x
260
+
261
+
262
+ def dpsr_degradation(x, k, sf=3):
263
+ ''' bicubic downsampling + blur
264
+ Args:
265
+ x: HxWxC image, [0, 1]
266
+ k: hxw, double
267
+ sf: down-scale factor
268
+ Return:
269
+ downsampled LR image
270
+ Reference:
271
+ @inproceedings{zhang2019deep,
272
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
273
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
274
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
275
+ pages={1671--1681},
276
+ year={2019}
277
+ }
278
+ '''
279
+ x = bicubic_degradation(x, sf=sf)
280
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
281
+ return x
282
+
283
+
284
+ def classical_degradation(x, k, sf=3):
285
+ ''' blur + downsampling
286
+ Args:
287
+ x: HxWxC image, [0, 1]/[0, 255]
288
+ k: hxw, double
289
+ sf: down-scale factor
290
+ Return:
291
+ downsampled LR image
292
+ '''
293
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
294
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
295
+ st = 0
296
+ return x[st::sf, st::sf, ...]
297
+
298
+
299
+ def add_sharpening(img, weight=0.5, radius=50, threshold=10):
300
+ """USM sharpening. borrowed from real-ESRGAN
301
+ Input image: I; Blurry image: B.
302
+ 1. K = I + weight * (I - B)
303
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
304
+ 3. Blur mask:
305
+ 4. Out = Mask * K + (1 - Mask) * I
306
+ Args:
307
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
308
+ weight (float): Sharp weight. Default: 1.
309
+ radius (float): Kernel size of Gaussian blur. Default: 50.
310
+ threshold (int):
311
+ """
312
+ if radius % 2 == 0:
313
+ radius += 1
314
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
315
+ residual = img - blur
316
+ mask = np.abs(residual) * 255 > threshold
317
+ mask = mask.astype('float32')
318
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
319
+
320
+ K = img + weight * residual
321
+ K = np.clip(K, 0, 1)
322
+ return soft_mask * K + (1 - soft_mask) * img
323
+
324
+
325
+ def add_blur(img, sf=4):
326
+ wd2 = 4.0 + sf
327
+ wd = 2.0 + 0.2 * sf
328
+ if random.random() < 0.5:
329
+ l1 = wd2 * random.random()
330
+ l2 = wd2 * random.random()
331
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
332
+ else:
333
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
334
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
335
+
336
+ return img
337
+
338
+
339
+ def add_resize(img, sf=4):
340
+ rnum = np.random.rand()
341
+ if rnum > 0.8: # up
342
+ sf1 = random.uniform(1, 2)
343
+ elif rnum < 0.7: # down
344
+ sf1 = random.uniform(0.5 / sf, 1)
345
+ else:
346
+ sf1 = 1.0
347
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
348
+ img = np.clip(img, 0.0, 1.0)
349
+
350
+ return img
351
+
352
+
353
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
354
+ # noise_level = random.randint(noise_level1, noise_level2)
355
+ # rnum = np.random.rand()
356
+ # if rnum > 0.6: # add color Gaussian noise
357
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
358
+ # elif rnum < 0.4: # add grayscale Gaussian noise
359
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
360
+ # else: # add noise
361
+ # L = noise_level2 / 255.
362
+ # D = np.diag(np.random.rand(3))
363
+ # U = orth(np.random.rand(3, 3))
364
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
365
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
366
+ # img = np.clip(img, 0.0, 1.0)
367
+ # return img
368
+
369
+ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
370
+ noise_level = random.randint(noise_level1, noise_level2)
371
+ rnum = np.random.rand()
372
+ if rnum > 0.6: # add color Gaussian noise
373
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
374
+ elif rnum < 0.4: # add grayscale Gaussian noise
375
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
376
+ else: # add noise
377
+ L = noise_level2 / 255.
378
+ D = np.diag(np.random.rand(3))
379
+ U = orth(np.random.rand(3, 3))
380
+ conv = np.dot(np.dot(np.transpose(U), D), U)
381
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
382
+ img = np.clip(img, 0.0, 1.0)
383
+ return img
384
+
385
+
386
+ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
387
+ noise_level = random.randint(noise_level1, noise_level2)
388
+ img = np.clip(img, 0.0, 1.0)
389
+ rnum = random.random()
390
+ if rnum > 0.6:
391
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
392
+ elif rnum < 0.4:
393
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
394
+ else:
395
+ L = noise_level2 / 255.
396
+ D = np.diag(np.random.rand(3))
397
+ U = orth(np.random.rand(3, 3))
398
+ conv = np.dot(np.dot(np.transpose(U), D), U)
399
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
400
+ img = np.clip(img, 0.0, 1.0)
401
+ return img
402
+
403
+
404
+ def add_Poisson_noise(img):
405
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
406
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
407
+ if random.random() < 0.5:
408
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
409
+ else:
410
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
411
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
412
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
413
+ img += noise_gray[:, :, np.newaxis]
414
+ img = np.clip(img, 0.0, 1.0)
415
+ return img
416
+
417
+
418
+ def add_JPEG_noise(img):
419
+ quality_factor = random.randint(30, 95)
420
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
421
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
422
+ img = cv2.imdecode(encimg, 1)
423
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
424
+ return img
425
+
426
+
427
+ def random_crop(lq, hq, sf=4, lq_patchsize=64):
428
+ h, w = lq.shape[:2]
429
+ rnd_h = random.randint(0, h - lq_patchsize)
430
+ rnd_w = random.randint(0, w - lq_patchsize)
431
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
432
+
433
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
434
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
435
+ return lq, hq
436
+
437
+
438
+ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
439
+ """
440
+ This is the degradation model of BSRGAN from the paper
441
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
442
+ ----------
443
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
444
+ sf: scale factor
445
+ isp_model: camera ISP model
446
+ Returns
447
+ -------
448
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
449
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
450
+ """
451
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
452
+ sf_ori = sf
453
+
454
+ h1, w1 = img.shape[:2]
455
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
456
+ h, w = img.shape[:2]
457
+
458
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
459
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
460
+
461
+ hq = img.copy()
462
+
463
+ if sf == 4 and random.random() < scale2_prob: # downsample1
464
+ if np.random.rand() < 0.5:
465
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
466
+ interpolation=random.choice([1, 2, 3]))
467
+ else:
468
+ img = util.imresize_np(img, 1 / 2, True)
469
+ img = np.clip(img, 0.0, 1.0)
470
+ sf = 2
471
+
472
+ shuffle_order = random.sample(range(7), 7)
473
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
474
+ if idx1 > idx2: # keep downsample3 last
475
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
476
+
477
+ for i in shuffle_order:
478
+
479
+ if i == 0:
480
+ img = add_blur(img, sf=sf)
481
+
482
+ elif i == 1:
483
+ img = add_blur(img, sf=sf)
484
+
485
+ elif i == 2:
486
+ a, b = img.shape[1], img.shape[0]
487
+ # downsample2
488
+ if random.random() < 0.75:
489
+ sf1 = random.uniform(1, 2 * sf)
490
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
491
+ interpolation=random.choice([1, 2, 3]))
492
+ else:
493
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
494
+ k_shifted = shift_pixel(k, sf)
495
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
496
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
497
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
498
+ img = np.clip(img, 0.0, 1.0)
499
+
500
+ elif i == 3:
501
+ # downsample3
502
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
503
+ img = np.clip(img, 0.0, 1.0)
504
+
505
+ elif i == 4:
506
+ # add Gaussian noise
507
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
508
+
509
+ elif i == 5:
510
+ # add JPEG noise
511
+ if random.random() < jpeg_prob:
512
+ img = add_JPEG_noise(img)
513
+
514
+ elif i == 6:
515
+ # add processed camera sensor noise
516
+ if random.random() < isp_prob and isp_model is not None:
517
+ with torch.no_grad():
518
+ img, hq = isp_model.forward(img.copy(), hq)
519
+
520
+ # add final JPEG compression noise
521
+ img = add_JPEG_noise(img)
522
+
523
+ # random crop
524
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
525
+
526
+ return img, hq
527
+
528
+
529
+ # todo no isp_model?
530
+ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
531
+ """
532
+ This is the degradation model of BSRGAN from the paper
533
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
534
+ ----------
535
+ sf: scale factor
536
+ isp_model: camera ISP model
537
+ Returns
538
+ -------
539
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
540
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
541
+ """
542
+ image = util.uint2single(image)
543
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
544
+ sf_ori = sf
545
+
546
+ h1, w1 = image.shape[:2]
547
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
548
+ h, w = image.shape[:2]
549
+
550
+ hq = image.copy()
551
+
552
+ if sf == 4 and random.random() < scale2_prob: # downsample1
553
+ if np.random.rand() < 0.5:
554
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
555
+ interpolation=random.choice([1, 2, 3]))
556
+ else:
557
+ image = util.imresize_np(image, 1 / 2, True)
558
+ image = np.clip(image, 0.0, 1.0)
559
+ sf = 2
560
+
561
+ shuffle_order = random.sample(range(7), 7)
562
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
563
+ if idx1 > idx2: # keep downsample3 last
564
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
565
+
566
+ for i in shuffle_order:
567
+
568
+ if i == 0:
569
+ image = add_blur(image, sf=sf)
570
+
571
+ elif i == 1:
572
+ image = add_blur(image, sf=sf)
573
+
574
+ elif i == 2:
575
+ a, b = image.shape[1], image.shape[0]
576
+ # downsample2
577
+ if random.random() < 0.75:
578
+ sf1 = random.uniform(1, 2 * sf)
579
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
580
+ interpolation=random.choice([1, 2, 3]))
581
+ else:
582
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
583
+ k_shifted = shift_pixel(k, sf)
584
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
585
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
586
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
587
+ image = np.clip(image, 0.0, 1.0)
588
+
589
+ elif i == 3:
590
+ # downsample3
591
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
592
+ image = np.clip(image, 0.0, 1.0)
593
+
594
+ elif i == 4:
595
+ # add Gaussian noise
596
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
597
+
598
+ elif i == 5:
599
+ # add JPEG noise
600
+ if random.random() < jpeg_prob:
601
+ image = add_JPEG_noise(image)
602
+
603
+ # elif i == 6:
604
+ # # add processed camera sensor noise
605
+ # if random.random() < isp_prob and isp_model is not None:
606
+ # with torch.no_grad():
607
+ # img, hq = isp_model.forward(img.copy(), hq)
608
+
609
+ # add final JPEG compression noise
610
+ image = add_JPEG_noise(image)
611
+ image = util.single2uint(image)
612
+ example = {"image":image}
613
+ return example
614
+
615
+
616
+ # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
617
+ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
618
+ """
619
+ This is an extended degradation model by combining
620
+ the degradation models of BSRGAN and Real-ESRGAN
621
+ ----------
622
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
623
+ sf: scale factor
624
+ use_shuffle: the degradation shuffle
625
+ use_sharp: sharpening the img
626
+ Returns
627
+ -------
628
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
629
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
630
+ """
631
+
632
+ h1, w1 = img.shape[:2]
633
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
634
+ h, w = img.shape[:2]
635
+
636
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
637
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
638
+
639
+ if use_sharp:
640
+ img = add_sharpening(img)
641
+ hq = img.copy()
642
+
643
+ if random.random() < shuffle_prob:
644
+ shuffle_order = random.sample(range(13), 13)
645
+ else:
646
+ shuffle_order = list(range(13))
647
+ # local shuffle for noise, JPEG is always the last one
648
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
649
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
650
+
651
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
652
+
653
+ for i in shuffle_order:
654
+ if i == 0:
655
+ img = add_blur(img, sf=sf)
656
+ elif i == 1:
657
+ img = add_resize(img, sf=sf)
658
+ elif i == 2:
659
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
660
+ elif i == 3:
661
+ if random.random() < poisson_prob:
662
+ img = add_Poisson_noise(img)
663
+ elif i == 4:
664
+ if random.random() < speckle_prob:
665
+ img = add_speckle_noise(img)
666
+ elif i == 5:
667
+ if random.random() < isp_prob and isp_model is not None:
668
+ with torch.no_grad():
669
+ img, hq = isp_model.forward(img.copy(), hq)
670
+ elif i == 6:
671
+ img = add_JPEG_noise(img)
672
+ elif i == 7:
673
+ img = add_blur(img, sf=sf)
674
+ elif i == 8:
675
+ img = add_resize(img, sf=sf)
676
+ elif i == 9:
677
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
678
+ elif i == 10:
679
+ if random.random() < poisson_prob:
680
+ img = add_Poisson_noise(img)
681
+ elif i == 11:
682
+ if random.random() < speckle_prob:
683
+ img = add_speckle_noise(img)
684
+ elif i == 12:
685
+ if random.random() < isp_prob and isp_model is not None:
686
+ with torch.no_grad():
687
+ img, hq = isp_model.forward(img.copy(), hq)
688
+ else:
689
+ print('check the shuffle!')
690
+
691
+ # resize to desired size
692
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
693
+ interpolation=random.choice([1, 2, 3]))
694
+
695
+ # add final JPEG compression noise
696
+ img = add_JPEG_noise(img)
697
+
698
+ # random crop
699
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
700
+
701
+ return img, hq
702
+
703
+
704
+ if __name__ == '__main__':
705
+ print("hey")
706
+ img = util.imread_uint('utils/test.png', 3)
707
+ print(img)
708
+ img = util.uint2single(img)
709
+ print(img)
710
+ img = img[:448, :448]
711
+ h = img.shape[0] // 4
712
+ print("resizing to", h)
713
+ sf = 4
714
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
715
+ for i in range(20):
716
+ print(i)
717
+ img_lq = deg_fn(img)
718
+ print(img_lq)
719
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
720
+ print(img_lq.shape)
721
+ print("bicubic", img_lq_bicubic.shape)
722
+ print(img_hq.shape)
723
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
724
+ interpolation=0)
725
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
726
+ interpolation=0)
727
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
728
+ util.imsave(img_concat, str(i) + '.png')
729
+
730
+
ldm/modules/image_degradation/bsrgan_light.py ADDED
@@ -0,0 +1,651 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+
6
+ from functools import partial
7
+ import random
8
+ from scipy import ndimage
9
+ import scipy
10
+ import scipy.stats as ss
11
+ from scipy.interpolate import interp2d
12
+ from scipy.linalg import orth
13
+ import albumentations
14
+
15
+ import ldm.modules.image_degradation.utils_image as util
16
+
17
+ """
18
+ # --------------------------------------------
19
+ # Super-Resolution
20
+ # --------------------------------------------
21
+ #
22
+ # Kai Zhang (cskaizhang@gmail.com)
23
+ # https://github.com/cszn
24
+ # From 2019/03--2021/08
25
+ # --------------------------------------------
26
+ """
27
+
28
+ def modcrop_np(img, sf):
29
+ '''
30
+ Args:
31
+ img: numpy image, WxH or WxHxC
32
+ sf: scale factor
33
+ Return:
34
+ cropped image
35
+ '''
36
+ w, h = img.shape[:2]
37
+ im = np.copy(img)
38
+ return im[:w - w % sf, :h - h % sf, ...]
39
+
40
+
41
+ """
42
+ # --------------------------------------------
43
+ # anisotropic Gaussian kernels
44
+ # --------------------------------------------
45
+ """
46
+
47
+
48
+ def analytic_kernel(k):
49
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
50
+ k_size = k.shape[0]
51
+ # Calculate the big kernels size
52
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
53
+ # Loop over the small kernel to fill the big one
54
+ for r in range(k_size):
55
+ for c in range(k_size):
56
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
57
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
58
+ crop = k_size // 2
59
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
60
+ # Normalize to 1
61
+ return cropped_big_k / cropped_big_k.sum()
62
+
63
+
64
+ def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
65
+ """ generate an anisotropic Gaussian kernel
66
+ Args:
67
+ ksize : e.g., 15, kernel size
68
+ theta : [0, pi], rotation angle range
69
+ l1 : [0.1,50], scaling of eigenvalues
70
+ l2 : [0.1,l1], scaling of eigenvalues
71
+ If l1 = l2, will get an isotropic Gaussian kernel.
72
+ Returns:
73
+ k : kernel
74
+ """
75
+
76
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
77
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
78
+ D = np.array([[l1, 0], [0, l2]])
79
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
80
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
81
+
82
+ return k
83
+
84
+
85
+ def gm_blur_kernel(mean, cov, size=15):
86
+ center = size / 2.0 + 0.5
87
+ k = np.zeros([size, size])
88
+ for y in range(size):
89
+ for x in range(size):
90
+ cy = y - center + 1
91
+ cx = x - center + 1
92
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
93
+
94
+ k = k / np.sum(k)
95
+ return k
96
+
97
+
98
+ def shift_pixel(x, sf, upper_left=True):
99
+ """shift pixel for super-resolution with different scale factors
100
+ Args:
101
+ x: WxHxC or WxH
102
+ sf: scale factor
103
+ upper_left: shift direction
104
+ """
105
+ h, w = x.shape[:2]
106
+ shift = (sf - 1) * 0.5
107
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
108
+ if upper_left:
109
+ x1 = xv + shift
110
+ y1 = yv + shift
111
+ else:
112
+ x1 = xv - shift
113
+ y1 = yv - shift
114
+
115
+ x1 = np.clip(x1, 0, w - 1)
116
+ y1 = np.clip(y1, 0, h - 1)
117
+
118
+ if x.ndim == 2:
119
+ x = interp2d(xv, yv, x)(x1, y1)
120
+ if x.ndim == 3:
121
+ for i in range(x.shape[-1]):
122
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
123
+
124
+ return x
125
+
126
+
127
+ def blur(x, k):
128
+ '''
129
+ x: image, NxcxHxW
130
+ k: kernel, Nx1xhxw
131
+ '''
132
+ n, c = x.shape[:2]
133
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
134
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
135
+ k = k.repeat(1, c, 1, 1)
136
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
137
+ x = x.view(1, -1, x.shape[2], x.shape[3])
138
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
139
+ x = x.view(n, c, x.shape[2], x.shape[3])
140
+
141
+ return x
142
+
143
+
144
+ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
145
+ """"
146
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
147
+ # Kai Zhang
148
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
149
+ # max_var = 2.5 * sf
150
+ """
151
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
152
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
153
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
154
+ theta = np.random.rand() * np.pi # random theta
155
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
156
+
157
+ # Set COV matrix using Lambdas and Theta
158
+ LAMBDA = np.diag([lambda_1, lambda_2])
159
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
160
+ [np.sin(theta), np.cos(theta)]])
161
+ SIGMA = Q @ LAMBDA @ Q.T
162
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
163
+
164
+ # Set expectation position (shifting kernel for aligned image)
165
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
166
+ MU = MU[None, None, :, None]
167
+
168
+ # Create meshgrid for Gaussian
169
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
170
+ Z = np.stack([X, Y], 2)[:, :, :, None]
171
+
172
+ # Calcualte Gaussian for every pixel of the kernel
173
+ ZZ = Z - MU
174
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
175
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
176
+
177
+ # shift the kernel so it will be centered
178
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
179
+
180
+ # Normalize the kernel and return
181
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
182
+ kernel = raw_kernel / np.sum(raw_kernel)
183
+ return kernel
184
+
185
+
186
+ def fspecial_gaussian(hsize, sigma):
187
+ hsize = [hsize, hsize]
188
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
189
+ std = sigma
190
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
191
+ arg = -(x * x + y * y) / (2 * std * std)
192
+ h = np.exp(arg)
193
+ h[h < scipy.finfo(float).eps * h.max()] = 0
194
+ sumh = h.sum()
195
+ if sumh != 0:
196
+ h = h / sumh
197
+ return h
198
+
199
+
200
+ def fspecial_laplacian(alpha):
201
+ alpha = max([0, min([alpha, 1])])
202
+ h1 = alpha / (alpha + 1)
203
+ h2 = (1 - alpha) / (alpha + 1)
204
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
205
+ h = np.array(h)
206
+ return h
207
+
208
+
209
+ def fspecial(filter_type, *args, **kwargs):
210
+ '''
211
+ python code from:
212
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
213
+ '''
214
+ if filter_type == 'gaussian':
215
+ return fspecial_gaussian(*args, **kwargs)
216
+ if filter_type == 'laplacian':
217
+ return fspecial_laplacian(*args, **kwargs)
218
+
219
+
220
+ """
221
+ # --------------------------------------------
222
+ # degradation models
223
+ # --------------------------------------------
224
+ """
225
+
226
+
227
+ def bicubic_degradation(x, sf=3):
228
+ '''
229
+ Args:
230
+ x: HxWxC image, [0, 1]
231
+ sf: down-scale factor
232
+ Return:
233
+ bicubicly downsampled LR image
234
+ '''
235
+ x = util.imresize_np(x, scale=1 / sf)
236
+ return x
237
+
238
+
239
+ def srmd_degradation(x, k, sf=3):
240
+ ''' blur + bicubic downsampling
241
+ Args:
242
+ x: HxWxC image, [0, 1]
243
+ k: hxw, double
244
+ sf: down-scale factor
245
+ Return:
246
+ downsampled LR image
247
+ Reference:
248
+ @inproceedings{zhang2018learning,
249
+ title={Learning a single convolutional super-resolution network for multiple degradations},
250
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
251
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
252
+ pages={3262--3271},
253
+ year={2018}
254
+ }
255
+ '''
256
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
257
+ x = bicubic_degradation(x, sf=sf)
258
+ return x
259
+
260
+
261
+ def dpsr_degradation(x, k, sf=3):
262
+ ''' bicubic downsampling + blur
263
+ Args:
264
+ x: HxWxC image, [0, 1]
265
+ k: hxw, double
266
+ sf: down-scale factor
267
+ Return:
268
+ downsampled LR image
269
+ Reference:
270
+ @inproceedings{zhang2019deep,
271
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
272
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
273
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
274
+ pages={1671--1681},
275
+ year={2019}
276
+ }
277
+ '''
278
+ x = bicubic_degradation(x, sf=sf)
279
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
280
+ return x
281
+
282
+
283
+ def classical_degradation(x, k, sf=3):
284
+ ''' blur + downsampling
285
+ Args:
286
+ x: HxWxC image, [0, 1]/[0, 255]
287
+ k: hxw, double
288
+ sf: down-scale factor
289
+ Return:
290
+ downsampled LR image
291
+ '''
292
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
293
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
294
+ st = 0
295
+ return x[st::sf, st::sf, ...]
296
+
297
+
298
+ def add_sharpening(img, weight=0.5, radius=50, threshold=10):
299
+ """USM sharpening. borrowed from real-ESRGAN
300
+ Input image: I; Blurry image: B.
301
+ 1. K = I + weight * (I - B)
302
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
303
+ 3. Blur mask:
304
+ 4. Out = Mask * K + (1 - Mask) * I
305
+ Args:
306
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
307
+ weight (float): Sharp weight. Default: 1.
308
+ radius (float): Kernel size of Gaussian blur. Default: 50.
309
+ threshold (int):
310
+ """
311
+ if radius % 2 == 0:
312
+ radius += 1
313
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
314
+ residual = img - blur
315
+ mask = np.abs(residual) * 255 > threshold
316
+ mask = mask.astype('float32')
317
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
318
+
319
+ K = img + weight * residual
320
+ K = np.clip(K, 0, 1)
321
+ return soft_mask * K + (1 - soft_mask) * img
322
+
323
+
324
+ def add_blur(img, sf=4):
325
+ wd2 = 4.0 + sf
326
+ wd = 2.0 + 0.2 * sf
327
+
328
+ wd2 = wd2/4
329
+ wd = wd/4
330
+
331
+ if random.random() < 0.5:
332
+ l1 = wd2 * random.random()
333
+ l2 = wd2 * random.random()
334
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
335
+ else:
336
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
337
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
338
+
339
+ return img
340
+
341
+
342
+ def add_resize(img, sf=4):
343
+ rnum = np.random.rand()
344
+ if rnum > 0.8: # up
345
+ sf1 = random.uniform(1, 2)
346
+ elif rnum < 0.7: # down
347
+ sf1 = random.uniform(0.5 / sf, 1)
348
+ else:
349
+ sf1 = 1.0
350
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
351
+ img = np.clip(img, 0.0, 1.0)
352
+
353
+ return img
354
+
355
+
356
+ # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
357
+ # noise_level = random.randint(noise_level1, noise_level2)
358
+ # rnum = np.random.rand()
359
+ # if rnum > 0.6: # add color Gaussian noise
360
+ # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
361
+ # elif rnum < 0.4: # add grayscale Gaussian noise
362
+ # img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
363
+ # else: # add noise
364
+ # L = noise_level2 / 255.
365
+ # D = np.diag(np.random.rand(3))
366
+ # U = orth(np.random.rand(3, 3))
367
+ # conv = np.dot(np.dot(np.transpose(U), D), U)
368
+ # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
369
+ # img = np.clip(img, 0.0, 1.0)
370
+ # return img
371
+
372
+ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
373
+ noise_level = random.randint(noise_level1, noise_level2)
374
+ rnum = np.random.rand()
375
+ if rnum > 0.6: # add color Gaussian noise
376
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
377
+ elif rnum < 0.4: # add grayscale Gaussian noise
378
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
379
+ else: # add noise
380
+ L = noise_level2 / 255.
381
+ D = np.diag(np.random.rand(3))
382
+ U = orth(np.random.rand(3, 3))
383
+ conv = np.dot(np.dot(np.transpose(U), D), U)
384
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
385
+ img = np.clip(img, 0.0, 1.0)
386
+ return img
387
+
388
+
389
+ def add_speckle_noise(img, noise_level1=2, noise_level2=25):
390
+ noise_level = random.randint(noise_level1, noise_level2)
391
+ img = np.clip(img, 0.0, 1.0)
392
+ rnum = random.random()
393
+ if rnum > 0.6:
394
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
395
+ elif rnum < 0.4:
396
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
397
+ else:
398
+ L = noise_level2 / 255.
399
+ D = np.diag(np.random.rand(3))
400
+ U = orth(np.random.rand(3, 3))
401
+ conv = np.dot(np.dot(np.transpose(U), D), U)
402
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
403
+ img = np.clip(img, 0.0, 1.0)
404
+ return img
405
+
406
+
407
+ def add_Poisson_noise(img):
408
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
409
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
410
+ if random.random() < 0.5:
411
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
412
+ else:
413
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
414
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
415
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
416
+ img += noise_gray[:, :, np.newaxis]
417
+ img = np.clip(img, 0.0, 1.0)
418
+ return img
419
+
420
+
421
+ def add_JPEG_noise(img):
422
+ quality_factor = random.randint(80, 95)
423
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
424
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
425
+ img = cv2.imdecode(encimg, 1)
426
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
427
+ return img
428
+
429
+
430
+ def random_crop(lq, hq, sf=4, lq_patchsize=64):
431
+ h, w = lq.shape[:2]
432
+ rnd_h = random.randint(0, h - lq_patchsize)
433
+ rnd_w = random.randint(0, w - lq_patchsize)
434
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
435
+
436
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
437
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
438
+ return lq, hq
439
+
440
+
441
+ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
442
+ """
443
+ This is the degradation model of BSRGAN from the paper
444
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
445
+ ----------
446
+ img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
447
+ sf: scale factor
448
+ isp_model: camera ISP model
449
+ Returns
450
+ -------
451
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
452
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
453
+ """
454
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
455
+ sf_ori = sf
456
+
457
+ h1, w1 = img.shape[:2]
458
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
459
+ h, w = img.shape[:2]
460
+
461
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
462
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
463
+
464
+ hq = img.copy()
465
+
466
+ if sf == 4 and random.random() < scale2_prob: # downsample1
467
+ if np.random.rand() < 0.5:
468
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
469
+ interpolation=random.choice([1, 2, 3]))
470
+ else:
471
+ img = util.imresize_np(img, 1 / 2, True)
472
+ img = np.clip(img, 0.0, 1.0)
473
+ sf = 2
474
+
475
+ shuffle_order = random.sample(range(7), 7)
476
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
477
+ if idx1 > idx2: # keep downsample3 last
478
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
479
+
480
+ for i in shuffle_order:
481
+
482
+ if i == 0:
483
+ img = add_blur(img, sf=sf)
484
+
485
+ elif i == 1:
486
+ img = add_blur(img, sf=sf)
487
+
488
+ elif i == 2:
489
+ a, b = img.shape[1], img.shape[0]
490
+ # downsample2
491
+ if random.random() < 0.75:
492
+ sf1 = random.uniform(1, 2 * sf)
493
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
494
+ interpolation=random.choice([1, 2, 3]))
495
+ else:
496
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
497
+ k_shifted = shift_pixel(k, sf)
498
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
499
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
500
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
501
+ img = np.clip(img, 0.0, 1.0)
502
+
503
+ elif i == 3:
504
+ # downsample3
505
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
506
+ img = np.clip(img, 0.0, 1.0)
507
+
508
+ elif i == 4:
509
+ # add Gaussian noise
510
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
511
+
512
+ elif i == 5:
513
+ # add JPEG noise
514
+ if random.random() < jpeg_prob:
515
+ img = add_JPEG_noise(img)
516
+
517
+ elif i == 6:
518
+ # add processed camera sensor noise
519
+ if random.random() < isp_prob and isp_model is not None:
520
+ with torch.no_grad():
521
+ img, hq = isp_model.forward(img.copy(), hq)
522
+
523
+ # add final JPEG compression noise
524
+ img = add_JPEG_noise(img)
525
+
526
+ # random crop
527
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
528
+
529
+ return img, hq
530
+
531
+
532
+ # todo no isp_model?
533
+ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
534
+ """
535
+ This is the degradation model of BSRGAN from the paper
536
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
537
+ ----------
538
+ sf: scale factor
539
+ isp_model: camera ISP model
540
+ Returns
541
+ -------
542
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
543
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
544
+ """
545
+ image = util.uint2single(image)
546
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
547
+ sf_ori = sf
548
+
549
+ h1, w1 = image.shape[:2]
550
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
551
+ h, w = image.shape[:2]
552
+
553
+ hq = image.copy()
554
+
555
+ if sf == 4 and random.random() < scale2_prob: # downsample1
556
+ if np.random.rand() < 0.5:
557
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
558
+ interpolation=random.choice([1, 2, 3]))
559
+ else:
560
+ image = util.imresize_np(image, 1 / 2, True)
561
+ image = np.clip(image, 0.0, 1.0)
562
+ sf = 2
563
+
564
+ shuffle_order = random.sample(range(7), 7)
565
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
566
+ if idx1 > idx2: # keep downsample3 last
567
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
568
+
569
+ for i in shuffle_order:
570
+
571
+ if i == 0:
572
+ image = add_blur(image, sf=sf)
573
+
574
+ # elif i == 1:
575
+ # image = add_blur(image, sf=sf)
576
+
577
+ if i == 0:
578
+ pass
579
+
580
+ elif i == 2:
581
+ a, b = image.shape[1], image.shape[0]
582
+ # downsample2
583
+ if random.random() < 0.8:
584
+ sf1 = random.uniform(1, 2 * sf)
585
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
586
+ interpolation=random.choice([1, 2, 3]))
587
+ else:
588
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
589
+ k_shifted = shift_pixel(k, sf)
590
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
591
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
592
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
593
+
594
+ image = np.clip(image, 0.0, 1.0)
595
+
596
+ elif i == 3:
597
+ # downsample3
598
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
599
+ image = np.clip(image, 0.0, 1.0)
600
+
601
+ elif i == 4:
602
+ # add Gaussian noise
603
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
604
+
605
+ elif i == 5:
606
+ # add JPEG noise
607
+ if random.random() < jpeg_prob:
608
+ image = add_JPEG_noise(image)
609
+ #
610
+ # elif i == 6:
611
+ # # add processed camera sensor noise
612
+ # if random.random() < isp_prob and isp_model is not None:
613
+ # with torch.no_grad():
614
+ # img, hq = isp_model.forward(img.copy(), hq)
615
+
616
+ # add final JPEG compression noise
617
+ image = add_JPEG_noise(image)
618
+ image = util.single2uint(image)
619
+ if up:
620
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
621
+ example = {"image": image}
622
+ return example
623
+
624
+
625
+
626
+
627
+ if __name__ == '__main__':
628
+ print("hey")
629
+ img = util.imread_uint('utils/test.png', 3)
630
+ img = img[:448, :448]
631
+ h = img.shape[0] // 4
632
+ print("resizing to", h)
633
+ sf = 4
634
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
635
+ for i in range(20):
636
+ print(i)
637
+ img_hq = img
638
+ img_lq = deg_fn(img)["image"]
639
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
640
+ print(img_lq)
641
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
642
+ print(img_lq.shape)
643
+ print("bicubic", img_lq_bicubic.shape)
644
+ print(img_hq.shape)
645
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
646
+ interpolation=0)
647
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
648
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
649
+ interpolation=0)
650
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
651
+ util.imsave(img_concat, str(i) + '.png')
ldm/modules/image_degradation/utils/test.png ADDED
ldm/modules/image_degradation/utils_image.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ from torchvision.utils import make_grid
8
+ from datetime import datetime
9
+ #import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
10
+
11
+
12
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
13
+
14
+
15
+ '''
16
+ # --------------------------------------------
17
+ # Kai Zhang (github: https://github.com/cszn)
18
+ # 03/Mar/2019
19
+ # --------------------------------------------
20
+ # https://github.com/twhui/SRGAN-pyTorch
21
+ # https://github.com/xinntao/BasicSR
22
+ # --------------------------------------------
23
+ '''
24
+
25
+
26
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
27
+
28
+
29
+ def is_image_file(filename):
30
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
31
+
32
+
33
+ def get_timestamp():
34
+ return datetime.now().strftime('%y%m%d-%H%M%S')
35
+
36
+
37
+ def imshow(x, title=None, cbar=False, figsize=None):
38
+ plt.figure(figsize=figsize)
39
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
40
+ if title:
41
+ plt.title(title)
42
+ if cbar:
43
+ plt.colorbar()
44
+ plt.show()
45
+
46
+
47
+ def surf(Z, cmap='rainbow', figsize=None):
48
+ plt.figure(figsize=figsize)
49
+ ax3 = plt.axes(projection='3d')
50
+
51
+ w, h = Z.shape[:2]
52
+ xx = np.arange(0,w,1)
53
+ yy = np.arange(0,h,1)
54
+ X, Y = np.meshgrid(xx, yy)
55
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
56
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
57
+ plt.show()
58
+
59
+
60
+ '''
61
+ # --------------------------------------------
62
+ # get image pathes
63
+ # --------------------------------------------
64
+ '''
65
+
66
+
67
+ def get_image_paths(dataroot):
68
+ paths = None # return None if dataroot is None
69
+ if dataroot is not None:
70
+ paths = sorted(_get_paths_from_images(dataroot))
71
+ return paths
72
+
73
+
74
+ def _get_paths_from_images(path):
75
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
76
+ images = []
77
+ for dirpath, _, fnames in sorted(os.walk(path)):
78
+ for fname in sorted(fnames):
79
+ if is_image_file(fname):
80
+ img_path = os.path.join(dirpath, fname)
81
+ images.append(img_path)
82
+ assert images, '{:s} has no valid image file'.format(path)
83
+ return images
84
+
85
+
86
+ '''
87
+ # --------------------------------------------
88
+ # split large images into small images
89
+ # --------------------------------------------
90
+ '''
91
+
92
+
93
+ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
94
+ w, h = img.shape[:2]
95
+ patches = []
96
+ if w > p_max and h > p_max:
97
+ w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
98
+ h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
99
+ w1.append(w-p_size)
100
+ h1.append(h-p_size)
101
+ # print(w1)
102
+ # print(h1)
103
+ for i in w1:
104
+ for j in h1:
105
+ patches.append(img[i:i+p_size, j:j+p_size,:])
106
+ else:
107
+ patches.append(img)
108
+
109
+ return patches
110
+
111
+
112
+ def imssave(imgs, img_path):
113
+ """
114
+ imgs: list, N images of size WxHxC
115
+ """
116
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
117
+
118
+ for i, img in enumerate(imgs):
119
+ if img.ndim == 3:
120
+ img = img[:, :, [2, 1, 0]]
121
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
122
+ cv2.imwrite(new_path, img)
123
+
124
+
125
+ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
126
+ """
127
+ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
128
+ and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
129
+ will be splitted.
130
+ Args:
131
+ original_dataroot:
132
+ taget_dataroot:
133
+ p_size: size of small images
134
+ p_overlap: patch size in training is a good choice
135
+ p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
136
+ """
137
+ paths = get_image_paths(original_dataroot)
138
+ for img_path in paths:
139
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
140
+ img = imread_uint(img_path, n_channels=n_channels)
141
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
142
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
143
+ #if original_dataroot == taget_dataroot:
144
+ #del img_path
145
+
146
+ '''
147
+ # --------------------------------------------
148
+ # makedir
149
+ # --------------------------------------------
150
+ '''
151
+
152
+
153
+ def mkdir(path):
154
+ if not os.path.exists(path):
155
+ os.makedirs(path)
156
+
157
+
158
+ def mkdirs(paths):
159
+ if isinstance(paths, str):
160
+ mkdir(paths)
161
+ else:
162
+ for path in paths:
163
+ mkdir(path)
164
+
165
+
166
+ def mkdir_and_rename(path):
167
+ if os.path.exists(path):
168
+ new_name = path + '_archived_' + get_timestamp()
169
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
170
+ os.rename(path, new_name)
171
+ os.makedirs(path)
172
+
173
+
174
+ '''
175
+ # --------------------------------------------
176
+ # read image from path
177
+ # opencv is fast, but read BGR numpy image
178
+ # --------------------------------------------
179
+ '''
180
+
181
+
182
+ # --------------------------------------------
183
+ # get uint8 image of size HxWxn_channles (RGB)
184
+ # --------------------------------------------
185
+ def imread_uint(path, n_channels=3):
186
+ # input: path
187
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
188
+ if n_channels == 1:
189
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
190
+ img = np.expand_dims(img, axis=2) # HxWx1
191
+ elif n_channels == 3:
192
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
193
+ if img.ndim == 2:
194
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
195
+ else:
196
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
197
+ return img
198
+
199
+
200
+ # --------------------------------------------
201
+ # matlab's imwrite
202
+ # --------------------------------------------
203
+ def imsave(img, img_path):
204
+ img = np.squeeze(img)
205
+ if img.ndim == 3:
206
+ img = img[:, :, [2, 1, 0]]
207
+ cv2.imwrite(img_path, img)
208
+
209
+ def imwrite(img, img_path):
210
+ img = np.squeeze(img)
211
+ if img.ndim == 3:
212
+ img = img[:, :, [2, 1, 0]]
213
+ cv2.imwrite(img_path, img)
214
+
215
+
216
+
217
+ # --------------------------------------------
218
+ # get single image of size HxWxn_channles (BGR)
219
+ # --------------------------------------------
220
+ def read_img(path):
221
+ # read image by cv2
222
+ # return: Numpy float32, HWC, BGR, [0,1]
223
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
224
+ img = img.astype(np.float32) / 255.
225
+ if img.ndim == 2:
226
+ img = np.expand_dims(img, axis=2)
227
+ # some images have 4 channels
228
+ if img.shape[2] > 3:
229
+ img = img[:, :, :3]
230
+ return img
231
+
232
+
233
+ '''
234
+ # --------------------------------------------
235
+ # image format conversion
236
+ # --------------------------------------------
237
+ # numpy(single) <---> numpy(unit)
238
+ # numpy(single) <---> tensor
239
+ # numpy(unit) <---> tensor
240
+ # --------------------------------------------
241
+ '''
242
+
243
+
244
+ # --------------------------------------------
245
+ # numpy(single) [0, 1] <---> numpy(unit)
246
+ # --------------------------------------------
247
+
248
+
249
+ def uint2single(img):
250
+
251
+ return np.float32(img/255.)
252
+
253
+
254
+ def single2uint(img):
255
+
256
+ return np.uint8((img.clip(0, 1)*255.).round())
257
+
258
+
259
+ def uint162single(img):
260
+
261
+ return np.float32(img/65535.)
262
+
263
+
264
+ def single2uint16(img):
265
+
266
+ return np.uint16((img.clip(0, 1)*65535.).round())
267
+
268
+
269
+ # --------------------------------------------
270
+ # numpy(unit) (HxWxC or HxW) <---> tensor
271
+ # --------------------------------------------
272
+
273
+
274
+ # convert uint to 4-dimensional torch tensor
275
+ def uint2tensor4(img):
276
+ if img.ndim == 2:
277
+ img = np.expand_dims(img, axis=2)
278
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
279
+
280
+
281
+ # convert uint to 3-dimensional torch tensor
282
+ def uint2tensor3(img):
283
+ if img.ndim == 2:
284
+ img = np.expand_dims(img, axis=2)
285
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
286
+
287
+
288
+ # convert 2/3/4-dimensional torch tensor to uint
289
+ def tensor2uint(img):
290
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
291
+ if img.ndim == 3:
292
+ img = np.transpose(img, (1, 2, 0))
293
+ return np.uint8((img*255.0).round())
294
+
295
+
296
+ # --------------------------------------------
297
+ # numpy(single) (HxWxC) <---> tensor
298
+ # --------------------------------------------
299
+
300
+
301
+ # convert single (HxWxC) to 3-dimensional torch tensor
302
+ def single2tensor3(img):
303
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
304
+
305
+
306
+ # convert single (HxWxC) to 4-dimensional torch tensor
307
+ def single2tensor4(img):
308
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
309
+
310
+
311
+ # convert torch tensor to single
312
+ def tensor2single(img):
313
+ img = img.data.squeeze().float().cpu().numpy()
314
+ if img.ndim == 3:
315
+ img = np.transpose(img, (1, 2, 0))
316
+
317
+ return img
318
+
319
+ # convert torch tensor to single
320
+ def tensor2single3(img):
321
+ img = img.data.squeeze().float().cpu().numpy()
322
+ if img.ndim == 3:
323
+ img = np.transpose(img, (1, 2, 0))
324
+ elif img.ndim == 2:
325
+ img = np.expand_dims(img, axis=2)
326
+ return img
327
+
328
+
329
+ def single2tensor5(img):
330
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
331
+
332
+
333
+ def single32tensor5(img):
334
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
335
+
336
+
337
+ def single42tensor4(img):
338
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
339
+
340
+
341
+ # from skimage.io import imread, imsave
342
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
343
+ '''
344
+ Converts a torch Tensor into an image Numpy array of BGR channel order
345
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
346
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
347
+ '''
348
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
349
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
350
+ n_dim = tensor.dim()
351
+ if n_dim == 4:
352
+ n_img = len(tensor)
353
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
354
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
355
+ elif n_dim == 3:
356
+ img_np = tensor.numpy()
357
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
358
+ elif n_dim == 2:
359
+ img_np = tensor.numpy()
360
+ else:
361
+ raise TypeError(
362
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
363
+ if out_type == np.uint8:
364
+ img_np = (img_np * 255.0).round()
365
+ # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
366
+ return img_np.astype(out_type)
367
+
368
+
369
+ '''
370
+ # --------------------------------------------
371
+ # Augmentation, flipe and/or rotate
372
+ # --------------------------------------------
373
+ # The following two are enough.
374
+ # (1) augmet_img: numpy image of WxHxC or WxH
375
+ # (2) augment_img_tensor4: tensor image 1xCxWxH
376
+ # --------------------------------------------
377
+ '''
378
+
379
+
380
+ def augment_img(img, mode=0):
381
+ '''Kai Zhang (github: https://github.com/cszn)
382
+ '''
383
+ if mode == 0:
384
+ return img
385
+ elif mode == 1:
386
+ return np.flipud(np.rot90(img))
387
+ elif mode == 2:
388
+ return np.flipud(img)
389
+ elif mode == 3:
390
+ return np.rot90(img, k=3)
391
+ elif mode == 4:
392
+ return np.flipud(np.rot90(img, k=2))
393
+ elif mode == 5:
394
+ return np.rot90(img)
395
+ elif mode == 6:
396
+ return np.rot90(img, k=2)
397
+ elif mode == 7:
398
+ return np.flipud(np.rot90(img, k=3))
399
+
400
+
401
+ def augment_img_tensor4(img, mode=0):
402
+ '''Kai Zhang (github: https://github.com/cszn)
403
+ '''
404
+ if mode == 0:
405
+ return img
406
+ elif mode == 1:
407
+ return img.rot90(1, [2, 3]).flip([2])
408
+ elif mode == 2:
409
+ return img.flip([2])
410
+ elif mode == 3:
411
+ return img.rot90(3, [2, 3])
412
+ elif mode == 4:
413
+ return img.rot90(2, [2, 3]).flip([2])
414
+ elif mode == 5:
415
+ return img.rot90(1, [2, 3])
416
+ elif mode == 6:
417
+ return img.rot90(2, [2, 3])
418
+ elif mode == 7:
419
+ return img.rot90(3, [2, 3]).flip([2])
420
+
421
+
422
+ def augment_img_tensor(img, mode=0):
423
+ '''Kai Zhang (github: https://github.com/cszn)
424
+ '''
425
+ img_size = img.size()
426
+ img_np = img.data.cpu().numpy()
427
+ if len(img_size) == 3:
428
+ img_np = np.transpose(img_np, (1, 2, 0))
429
+ elif len(img_size) == 4:
430
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
431
+ img_np = augment_img(img_np, mode=mode)
432
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
433
+ if len(img_size) == 3:
434
+ img_tensor = img_tensor.permute(2, 0, 1)
435
+ elif len(img_size) == 4:
436
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
437
+
438
+ return img_tensor.type_as(img)
439
+
440
+
441
+ def augment_img_np3(img, mode=0):
442
+ if mode == 0:
443
+ return img
444
+ elif mode == 1:
445
+ return img.transpose(1, 0, 2)
446
+ elif mode == 2:
447
+ return img[::-1, :, :]
448
+ elif mode == 3:
449
+ img = img[::-1, :, :]
450
+ img = img.transpose(1, 0, 2)
451
+ return img
452
+ elif mode == 4:
453
+ return img[:, ::-1, :]
454
+ elif mode == 5:
455
+ img = img[:, ::-1, :]
456
+ img = img.transpose(1, 0, 2)
457
+ return img
458
+ elif mode == 6:
459
+ img = img[:, ::-1, :]
460
+ img = img[::-1, :, :]
461
+ return img
462
+ elif mode == 7:
463
+ img = img[:, ::-1, :]
464
+ img = img[::-1, :, :]
465
+ img = img.transpose(1, 0, 2)
466
+ return img
467
+
468
+
469
+ def augment_imgs(img_list, hflip=True, rot=True):
470
+ # horizontal flip OR rotate
471
+ hflip = hflip and random.random() < 0.5
472
+ vflip = rot and random.random() < 0.5
473
+ rot90 = rot and random.random() < 0.5
474
+
475
+ def _augment(img):
476
+ if hflip:
477
+ img = img[:, ::-1, :]
478
+ if vflip:
479
+ img = img[::-1, :, :]
480
+ if rot90:
481
+ img = img.transpose(1, 0, 2)
482
+ return img
483
+
484
+ return [_augment(img) for img in img_list]
485
+
486
+
487
+ '''
488
+ # --------------------------------------------
489
+ # modcrop and shave
490
+ # --------------------------------------------
491
+ '''
492
+
493
+
494
+ def modcrop(img_in, scale):
495
+ # img_in: Numpy, HWC or HW
496
+ img = np.copy(img_in)
497
+ if img.ndim == 2:
498
+ H, W = img.shape
499
+ H_r, W_r = H % scale, W % scale
500
+ img = img[:H - H_r, :W - W_r]
501
+ elif img.ndim == 3:
502
+ H, W, C = img.shape
503
+ H_r, W_r = H % scale, W % scale
504
+ img = img[:H - H_r, :W - W_r, :]
505
+ else:
506
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
507
+ return img
508
+
509
+
510
+ def shave(img_in, border=0):
511
+ # img_in: Numpy, HWC or HW
512
+ img = np.copy(img_in)
513
+ h, w = img.shape[:2]
514
+ img = img[border:h-border, border:w-border]
515
+ return img
516
+
517
+
518
+ '''
519
+ # --------------------------------------------
520
+ # image processing process on numpy image
521
+ # channel_convert(in_c, tar_type, img_list):
522
+ # rgb2ycbcr(img, only_y=True):
523
+ # bgr2ycbcr(img, only_y=True):
524
+ # ycbcr2rgb(img):
525
+ # --------------------------------------------
526
+ '''
527
+
528
+
529
+ def rgb2ycbcr(img, only_y=True):
530
+ '''same as matlab rgb2ycbcr
531
+ only_y: only return Y channel
532
+ Input:
533
+ uint8, [0, 255]
534
+ float, [0, 1]
535
+ '''
536
+ in_img_type = img.dtype
537
+ img.astype(np.float32)
538
+ if in_img_type != np.uint8:
539
+ img *= 255.
540
+ # convert
541
+ if only_y:
542
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
543
+ else:
544
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
545
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
546
+ if in_img_type == np.uint8:
547
+ rlt = rlt.round()
548
+ else:
549
+ rlt /= 255.
550
+ return rlt.astype(in_img_type)
551
+
552
+
553
+ def ycbcr2rgb(img):
554
+ '''same as matlab ycbcr2rgb
555
+ Input:
556
+ uint8, [0, 255]
557
+ float, [0, 1]
558
+ '''
559
+ in_img_type = img.dtype
560
+ img.astype(np.float32)
561
+ if in_img_type != np.uint8:
562
+ img *= 255.
563
+ # convert
564
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
565
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
566
+ if in_img_type == np.uint8:
567
+ rlt = rlt.round()
568
+ else:
569
+ rlt /= 255.
570
+ return rlt.astype(in_img_type)
571
+
572
+
573
+ def bgr2ycbcr(img, only_y=True):
574
+ '''bgr version of rgb2ycbcr
575
+ only_y: only return Y channel
576
+ Input:
577
+ uint8, [0, 255]
578
+ float, [0, 1]
579
+ '''
580
+ in_img_type = img.dtype
581
+ img.astype(np.float32)
582
+ if in_img_type != np.uint8:
583
+ img *= 255.
584
+ # convert
585
+ if only_y:
586
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
587
+ else:
588
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
589
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
590
+ if in_img_type == np.uint8:
591
+ rlt = rlt.round()
592
+ else:
593
+ rlt /= 255.
594
+ return rlt.astype(in_img_type)
595
+
596
+
597
+ def channel_convert(in_c, tar_type, img_list):
598
+ # conversion among BGR, gray and y
599
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
600
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
601
+ return [np.expand_dims(img, axis=2) for img in gray_list]
602
+ elif in_c == 3 and tar_type == 'y': # BGR to y
603
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
604
+ return [np.expand_dims(img, axis=2) for img in y_list]
605
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
606
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
607
+ else:
608
+ return img_list
609
+
610
+
611
+ '''
612
+ # --------------------------------------------
613
+ # metric, PSNR and SSIM
614
+ # --------------------------------------------
615
+ '''
616
+
617
+
618
+ # --------------------------------------------
619
+ # PSNR
620
+ # --------------------------------------------
621
+ def calculate_psnr(img1, img2, border=0):
622
+ # img1 and img2 have range [0, 255]
623
+ #img1 = img1.squeeze()
624
+ #img2 = img2.squeeze()
625
+ if not img1.shape == img2.shape:
626
+ raise ValueError('Input images must have the same dimensions.')
627
+ h, w = img1.shape[:2]
628
+ img1 = img1[border:h-border, border:w-border]
629
+ img2 = img2[border:h-border, border:w-border]
630
+
631
+ img1 = img1.astype(np.float64)
632
+ img2 = img2.astype(np.float64)
633
+ mse = np.mean((img1 - img2)**2)
634
+ if mse == 0:
635
+ return float('inf')
636
+ return 20 * math.log10(255.0 / math.sqrt(mse))
637
+
638
+
639
+ # --------------------------------------------
640
+ # SSIM
641
+ # --------------------------------------------
642
+ def calculate_ssim(img1, img2, border=0):
643
+ '''calculate SSIM
644
+ the same outputs as MATLAB's
645
+ img1, img2: [0, 255]
646
+ '''
647
+ #img1 = img1.squeeze()
648
+ #img2 = img2.squeeze()
649
+ if not img1.shape == img2.shape:
650
+ raise ValueError('Input images must have the same dimensions.')
651
+ h, w = img1.shape[:2]
652
+ img1 = img1[border:h-border, border:w-border]
653
+ img2 = img2[border:h-border, border:w-border]
654
+
655
+ if img1.ndim == 2:
656
+ return ssim(img1, img2)
657
+ elif img1.ndim == 3:
658
+ if img1.shape[2] == 3:
659
+ ssims = []
660
+ for i in range(3):
661
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
662
+ return np.array(ssims).mean()
663
+ elif img1.shape[2] == 1:
664
+ return ssim(np.squeeze(img1), np.squeeze(img2))
665
+ else:
666
+ raise ValueError('Wrong input image dimensions.')
667
+
668
+
669
+ def ssim(img1, img2):
670
+ C1 = (0.01 * 255)**2
671
+ C2 = (0.03 * 255)**2
672
+
673
+ img1 = img1.astype(np.float64)
674
+ img2 = img2.astype(np.float64)
675
+ kernel = cv2.getGaussianKernel(11, 1.5)
676
+ window = np.outer(kernel, kernel.transpose())
677
+
678
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
679
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
680
+ mu1_sq = mu1**2
681
+ mu2_sq = mu2**2
682
+ mu1_mu2 = mu1 * mu2
683
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
684
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
685
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
686
+
687
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
688
+ (sigma1_sq + sigma2_sq + C2))
689
+ return ssim_map.mean()
690
+
691
+
692
+ '''
693
+ # --------------------------------------------
694
+ # matlab's bicubic imresize (numpy and torch) [0, 1]
695
+ # --------------------------------------------
696
+ '''
697
+
698
+
699
+ # matlab 'imresize' function, now only support 'bicubic'
700
+ def cubic(x):
701
+ absx = torch.abs(x)
702
+ absx2 = absx**2
703
+ absx3 = absx**3
704
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
705
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
706
+
707
+
708
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
709
+ if (scale < 1) and (antialiasing):
710
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
711
+ kernel_width = kernel_width / scale
712
+
713
+ # Output-space coordinates
714
+ x = torch.linspace(1, out_length, out_length)
715
+
716
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
717
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
718
+ # space maps to 1.5 in input space.
719
+ u = x / scale + 0.5 * (1 - 1 / scale)
720
+
721
+ # What is the left-most pixel that can be involved in the computation?
722
+ left = torch.floor(u - kernel_width / 2)
723
+
724
+ # What is the maximum number of pixels that can be involved in the
725
+ # computation? Note: it's OK to use an extra pixel here; if the
726
+ # corresponding weights are all zero, it will be eliminated at the end
727
+ # of this function.
728
+ P = math.ceil(kernel_width) + 2
729
+
730
+ # The indices of the input pixels involved in computing the k-th output
731
+ # pixel are in row k of the indices matrix.
732
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
733
+ 1, P).expand(out_length, P)
734
+
735
+ # The weights used to compute the k-th output pixel are in row k of the
736
+ # weights matrix.
737
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
738
+ # apply cubic kernel
739
+ if (scale < 1) and (antialiasing):
740
+ weights = scale * cubic(distance_to_center * scale)
741
+ else:
742
+ weights = cubic(distance_to_center)
743
+ # Normalize the weights matrix so that each row sums to 1.
744
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
745
+ weights = weights / weights_sum.expand(out_length, P)
746
+
747
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
748
+ weights_zero_tmp = torch.sum((weights == 0), 0)
749
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
750
+ indices = indices.narrow(1, 1, P - 2)
751
+ weights = weights.narrow(1, 1, P - 2)
752
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
753
+ indices = indices.narrow(1, 0, P - 2)
754
+ weights = weights.narrow(1, 0, P - 2)
755
+ weights = weights.contiguous()
756
+ indices = indices.contiguous()
757
+ sym_len_s = -indices.min() + 1
758
+ sym_len_e = indices.max() - in_length
759
+ indices = indices + sym_len_s - 1
760
+ return weights, indices, int(sym_len_s), int(sym_len_e)
761
+
762
+
763
+ # --------------------------------------------
764
+ # imresize for tensor image [0, 1]
765
+ # --------------------------------------------
766
+ def imresize(img, scale, antialiasing=True):
767
+ # Now the scale should be the same for H and W
768
+ # input: img: pytorch tensor, CHW or HW [0,1]
769
+ # output: CHW or HW [0,1] w/o round
770
+ need_squeeze = True if img.dim() == 2 else False
771
+ if need_squeeze:
772
+ img.unsqueeze_(0)
773
+ in_C, in_H, in_W = img.size()
774
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
775
+ kernel_width = 4
776
+ kernel = 'cubic'
777
+
778
+ # Return the desired dimension order for performing the resize. The
779
+ # strategy is to perform the resize first along the dimension with the
780
+ # smallest scale factor.
781
+ # Now we do not support this.
782
+
783
+ # get weights and indices
784
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
785
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
786
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
787
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
788
+ # process H dimension
789
+ # symmetric copying
790
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
791
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
792
+
793
+ sym_patch = img[:, :sym_len_Hs, :]
794
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
795
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
796
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
797
+
798
+ sym_patch = img[:, -sym_len_He:, :]
799
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
800
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
801
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
802
+
803
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
804
+ kernel_width = weights_H.size(1)
805
+ for i in range(out_H):
806
+ idx = int(indices_H[i][0])
807
+ for j in range(out_C):
808
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
809
+
810
+ # process W dimension
811
+ # symmetric copying
812
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
813
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
814
+
815
+ sym_patch = out_1[:, :, :sym_len_Ws]
816
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
817
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
818
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
819
+
820
+ sym_patch = out_1[:, :, -sym_len_We:]
821
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
822
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
823
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
824
+
825
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
826
+ kernel_width = weights_W.size(1)
827
+ for i in range(out_W):
828
+ idx = int(indices_W[i][0])
829
+ for j in range(out_C):
830
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
831
+ if need_squeeze:
832
+ out_2.squeeze_()
833
+ return out_2
834
+
835
+
836
+ # --------------------------------------------
837
+ # imresize for numpy image [0, 1]
838
+ # --------------------------------------------
839
+ def imresize_np(img, scale, antialiasing=True):
840
+ # Now the scale should be the same for H and W
841
+ # input: img: Numpy, HWC or HW [0,1]
842
+ # output: HWC or HW [0,1] w/o round
843
+ img = torch.from_numpy(img)
844
+ need_squeeze = True if img.dim() == 2 else False
845
+ if need_squeeze:
846
+ img.unsqueeze_(2)
847
+
848
+ in_H, in_W, in_C = img.size()
849
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
850
+ kernel_width = 4
851
+ kernel = 'cubic'
852
+
853
+ # Return the desired dimension order for performing the resize. The
854
+ # strategy is to perform the resize first along the dimension with the
855
+ # smallest scale factor.
856
+ # Now we do not support this.
857
+
858
+ # get weights and indices
859
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
860
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
861
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
862
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
863
+ # process H dimension
864
+ # symmetric copying
865
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
866
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
867
+
868
+ sym_patch = img[:sym_len_Hs, :, :]
869
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
870
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
871
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
872
+
873
+ sym_patch = img[-sym_len_He:, :, :]
874
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
875
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
876
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
877
+
878
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
879
+ kernel_width = weights_H.size(1)
880
+ for i in range(out_H):
881
+ idx = int(indices_H[i][0])
882
+ for j in range(out_C):
883
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
884
+
885
+ # process W dimension
886
+ # symmetric copying
887
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
888
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
889
+
890
+ sym_patch = out_1[:, :sym_len_Ws, :]
891
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
892
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
893
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
894
+
895
+ sym_patch = out_1[:, -sym_len_We:, :]
896
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
897
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
898
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
899
+
900
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
901
+ kernel_width = weights_W.size(1)
902
+ for i in range(out_W):
903
+ idx = int(indices_W[i][0])
904
+ for j in range(out_C):
905
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
906
+ if need_squeeze:
907
+ out_2.squeeze_()
908
+
909
+ return out_2.numpy()
910
+
911
+
912
+ if __name__ == '__main__':
913
+ print('---')
914
+ # img = imread_uint('test.bmp', 3)
915
+ # img = uint2single(img)
916
+ # img_bicubic = imresize_np(img, 1/4)
ldm/modules/midas/__init__.py ADDED
File without changes
ldm/modules/midas/api.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # based on https://github.com/isl-org/MiDaS
2
+
3
+ import cv2
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision.transforms import Compose
7
+
8
+ from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
9
+ from ldm.modules.midas.midas.midas_net import MidasNet
10
+ from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
11
+ from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
12
+
13
+
14
+ ISL_PATHS = {
15
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
16
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
17
+ "midas_v21": "",
18
+ "midas_v21_small": "",
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ def load_midas_transform(model_type):
29
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
30
+ # load transform only
31
+ if model_type == "dpt_large": # DPT-Large
32
+ net_w, net_h = 384, 384
33
+ resize_mode = "minimal"
34
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35
+
36
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
37
+ net_w, net_h = 384, 384
38
+ resize_mode = "minimal"
39
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
40
+
41
+ elif model_type == "midas_v21":
42
+ net_w, net_h = 384, 384
43
+ resize_mode = "upper_bound"
44
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
45
+
46
+ elif model_type == "midas_v21_small":
47
+ net_w, net_h = 256, 256
48
+ resize_mode = "upper_bound"
49
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
50
+
51
+ else:
52
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
53
+
54
+ transform = Compose(
55
+ [
56
+ Resize(
57
+ net_w,
58
+ net_h,
59
+ resize_target=None,
60
+ keep_aspect_ratio=True,
61
+ ensure_multiple_of=32,
62
+ resize_method=resize_mode,
63
+ image_interpolation_method=cv2.INTER_CUBIC,
64
+ ),
65
+ normalization,
66
+ PrepareForNet(),
67
+ ]
68
+ )
69
+
70
+ return transform
71
+
72
+
73
+ def load_model(model_type):
74
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
75
+ # load network
76
+ model_path = ISL_PATHS[model_type]
77
+ if model_type == "dpt_large": # DPT-Large
78
+ model = DPTDepthModel(
79
+ path=model_path,
80
+ backbone="vitl16_384",
81
+ non_negative=True,
82
+ )
83
+ net_w, net_h = 384, 384
84
+ resize_mode = "minimal"
85
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
86
+
87
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
88
+ model = DPTDepthModel(
89
+ path=model_path,
90
+ backbone="vitb_rn50_384",
91
+ non_negative=True,
92
+ )
93
+ net_w, net_h = 384, 384
94
+ resize_mode = "minimal"
95
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
96
+
97
+ elif model_type == "midas_v21":
98
+ model = MidasNet(model_path, non_negative=True)
99
+ net_w, net_h = 384, 384
100
+ resize_mode = "upper_bound"
101
+ normalization = NormalizeImage(
102
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
103
+ )
104
+
105
+ elif model_type == "midas_v21_small":
106
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
107
+ non_negative=True, blocks={'expand': True})
108
+ net_w, net_h = 256, 256
109
+ resize_mode = "upper_bound"
110
+ normalization = NormalizeImage(
111
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
112
+ )
113
+
114
+ else:
115
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
116
+ assert False
117
+
118
+ transform = Compose(
119
+ [
120
+ Resize(
121
+ net_w,
122
+ net_h,
123
+ resize_target=None,
124
+ keep_aspect_ratio=True,
125
+ ensure_multiple_of=32,
126
+ resize_method=resize_mode,
127
+ image_interpolation_method=cv2.INTER_CUBIC,
128
+ ),
129
+ normalization,
130
+ PrepareForNet(),
131
+ ]
132
+ )
133
+
134
+ return model.eval(), transform
135
+
136
+
137
+ class MiDaSInference(nn.Module):
138
+ MODEL_TYPES_TORCH_HUB = [
139
+ "DPT_Large",
140
+ "DPT_Hybrid",
141
+ "MiDaS_small"
142
+ ]
143
+ MODEL_TYPES_ISL = [
144
+ "dpt_large",
145
+ "dpt_hybrid",
146
+ "midas_v21",
147
+ "midas_v21_small",
148
+ ]
149
+
150
+ def __init__(self, model_type):
151
+ super().__init__()
152
+ assert (model_type in self.MODEL_TYPES_ISL)
153
+ model, _ = load_model(model_type)
154
+ self.model = model
155
+ self.model.train = disabled_train
156
+
157
+ def forward(self, x):
158
+ # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
159
+ # NOTE: we expect that the correct transform has been called during dataloading.
160
+ with torch.no_grad():
161
+ prediction = self.model(x)
162
+ prediction = torch.nn.functional.interpolate(
163
+ prediction.unsqueeze(1),
164
+ size=x.shape[2:],
165
+ mode="bicubic",
166
+ align_corners=False,
167
+ )
168
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
169
+ return prediction
170
+
ldm/modules/midas/midas/__init__.py ADDED
File without changes
ldm/modules/midas/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)