VIVEK JAYARAM committed
Commit 3e0a809
Parent: d8f7287

Basic working example

cdim/diffusion/diffusion_pipeline.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ from tqdm import tqdm
+
+ from cdim.image_utils import randn_tensor
+
+
+ @torch.no_grad()
+ def run_diffusion(
+     model,
+     scheduler,
+     noisy_observation,
+     operator,
+     noise_function,
+     device,
+     num_inference_steps: int = 1000,
+     K=5,
+     image_dim=256,
+     image_channels=3
+ ):
+     batch_size = noisy_observation.shape[0]
+     image_shape = (batch_size, image_channels, image_dim, image_dim)
+     image = randn_tensor(image_shape, device=device)
+
+     scheduler.set_timesteps(num_inference_steps, device=device)
+     t_skip = scheduler.timesteps[0] - scheduler.timesteps[1]
+
+     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps), desc="Processing timesteps"):
+         # 1. Predict the noise residual epsilon_theta(x_t, t).
+         model_output = model(image, t.unsqueeze(0).to(device))[:, :3]
+
+         # 2. Compute the previous image: x_t -> x_{t-1}.
+         image = scheduler.step(model_output, t, image).prev_sample
+         image.requires_grad_()
+         alpha_prod_t_prev = scheduler.alphas_cumprod[t - t_skip] if t - t_skip >= 0 else 1
+         beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+         # 3. Up to K gradient steps on x_{t-1} to enforce consistency with the observation.
+         for j in range(K):
+             if t <= 0: break
+
+             with torch.enable_grad():
+                 # Calculate x^hat_0, the predicted clean image, from the current x_{t-1}.
+                 model_output = model(image, (t - t_skip).unsqueeze(0).to(device))[:, :3]
+                 x_0 = (image - beta_prod_t_prev ** 0.5 * model_output) / alpha_prod_t_prev ** 0.5
+
+                 distance = operator(x_0) - noisy_observation
+                 if (distance ** 2).mean() < noise_function.sigma ** 2:
+                     break
+                 loss = (distance ** 2).mean()
+                 print(loss.item())
+                 loss.backward()
+
+             # Normalized gradient step; clear the gradient so it does not
+             # accumulate across the K inner iterations.
+             image -= 10 / torch.linalg.norm(image.grad) * image.grad
+             image.grad = None
+
+     return image
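
Note on usage (not part of the commit): `run_diffusion` only needs an epsilon-prediction UNet, the DDIM scheduler added below, a measurement operator, and a noise model exposing `sigma`. The sketch below is a minimal, assumed wiring; the UNet hyperparameters, checkpoint path, identity operator, and noise level are placeholders rather than values taken from this repository.

import types
import torch

from cdim.diffusion.diffusion_pipeline import run_diffusion
from cdim.diffusion.scheduling_ddim import DDIMScheduler
from cdim.dps_model.dps_unet import create_model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed 256x256 DPS-style UNet config; the checkpoint path is hypothetical
# (create_model falls back to random weights if loading fails).
model = create_model(
    image_size=256, num_channels=128, num_res_blocks=1,
    attention_resolutions="16", num_head_channels=64, learn_sigma=True,
    use_scale_shift_norm=True, resblock_updown=True,
    model_path="checkpoints/ffhq_10m.pt",
).to(device).eval()

operator = lambda x: x                              # placeholder: identity measurement A(x) = x
noise_function = types.SimpleNamespace(sigma=0.05)  # placeholder noise model with the attribute run_diffusion reads

# Synthetic noisy observation y = A(x) + n.
x_true = torch.rand(1, 3, 256, 256, device=device) * 2 - 1
y = operator(x_true) + noise_function.sigma * torch.randn_like(x_true)

scheduler = DDIMScheduler(num_train_timesteps=1000, beta_schedule="linear")
restored = run_diffusion(
    model, scheduler, y, operator, noise_function, device,
    num_inference_steps=100, K=5,
)
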
cdim/diffusion/scheduling_ddim.py ADDED
@@ -0,0 +1,513 @@
1
+ # Copyright 2024 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from collections import OrderedDict
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from cdim.image_utils import randn_tensor
27
+
28
+
29
+ class FrozenDict(OrderedDict):
30
+ def __init__(self, *args, **kwargs):
31
+ super().__init__(*args, **kwargs)
32
+
33
+ for key, value in self.items():
34
+ setattr(self, key, value)
35
+
36
+ self.__frozen = True
37
+
38
+ def __delitem__(self, *args, **kwargs):
39
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
40
+
41
+ def setdefault(self, *args, **kwargs):
42
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
43
+
44
+ def pop(self, *args, **kwargs):
45
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
46
+
47
+ def update(self, *args, **kwargs):
48
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
49
+
50
+ def __setattr__(self, name, value):
51
+ if hasattr(self, "__frozen") and self.__frozen:
52
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
53
+ super().__setattr__(name, value)
54
+
55
+ def __setitem__(self, name, value):
56
+ if hasattr(self, "__frozen") and self.__frozen:
57
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
58
+ super().__setitem__(name, value)
59
+
60
+
61
+ @dataclass
62
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
63
+ class DDIMSchedulerOutput:
64
+ """
65
+ Output class for the scheduler's `step` function output.
66
+
67
+ Args:
68
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
69
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
70
+ denoising loop.
71
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
72
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
73
+ `pred_original_sample` can be used to preview progress or for guidance.
74
+ """
75
+
76
+ prev_sample: torch.FloatTensor
77
+ pred_original_sample: Optional[torch.FloatTensor] = None
78
+
79
+
80
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
81
+ def betas_for_alpha_bar(
82
+ num_diffusion_timesteps,
83
+ max_beta=0.999,
84
+ alpha_transform_type="cosine",
85
+ ):
86
+ """
87
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
88
+ (1-beta) over time from t = [0,1].
89
+
90
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
91
+ to that part of the diffusion process.
92
+
93
+
94
+ Args:
95
+ num_diffusion_timesteps (`int`): the number of betas to produce.
96
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
97
+ prevent singularities.
98
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
99
+ Choose from `cosine` or `exp`
100
+
101
+ Returns:
102
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
103
+ """
104
+ if alpha_transform_type == "cosine":
105
+
106
+ def alpha_bar_fn(t):
107
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
108
+
109
+ elif alpha_transform_type == "exp":
110
+
111
+ def alpha_bar_fn(t):
112
+ return math.exp(t * -12.0)
113
+
114
+ else:
115
+ raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
116
+
117
+ betas = []
118
+ for i in range(num_diffusion_timesteps):
119
+ t1 = i / num_diffusion_timesteps
120
+ t2 = (i + 1) / num_diffusion_timesteps
121
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
122
+ return torch.tensor(betas, dtype=torch.float32)
123
+
124
+
125
+ class DDIMScheduler:
126
+ """
127
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
128
+ non-Markovian guidance.
129
+
130
+ This is a standalone port of the diffusers `DDIMScheduler`; it does not inherit from [`SchedulerMixin`] or
+ [`ConfigMixin`], so only the methods defined in this file are available.
132
+
133
+ Args:
134
+ num_train_timesteps (`int`, defaults to 1000):
135
+ The number of diffusion steps to train the model.
136
+ beta_start (`float`, defaults to 0.0001):
137
+ The starting `beta` value of inference.
138
+ beta_end (`float`, defaults to 0.02):
139
+ The final `beta` value.
140
+ beta_schedule (`str`, defaults to `"linear"`):
141
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
142
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
143
+ trained_betas (`np.ndarray`, *optional*):
144
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
145
+ clip_sample (`bool`, defaults to `True`):
146
+ Clip the predicted sample for numerical stability.
147
+ clip_sample_range (`float`, defaults to 1.0):
148
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
149
+ set_alpha_to_one (`bool`, defaults to `True`):
150
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
151
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
152
+ otherwise it uses the alpha value at step 0.
153
+ steps_offset (`int`, defaults to 0):
154
+ An offset added to the inference steps, as required by some model families.
155
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
156
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
157
+ `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
158
+ Video](https://imagen.research.google/video/paper.pdf) paper).
159
+ thresholding (`bool`, defaults to `False`):
160
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
161
+ as Stable Diffusion.
162
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
163
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
164
+ sample_max_value (`float`, defaults to 1.0):
165
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
166
+ timestep_spacing (`str`, defaults to `"leading"`):
167
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
168
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
169
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
170
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
171
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
172
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
173
+ """
174
+ def __init__(
175
+ self,
176
+ num_train_timesteps: int = 1000,
177
+ beta_start: float = 0.0001,
178
+ beta_end: float = 0.02,
179
+ beta_schedule: str = "linear",
180
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
181
+ clip_sample: bool = True,
182
+ set_alpha_to_one: bool = True,
183
+ steps_offset: int = 0,
184
+ prediction_type: str = "epsilon",
185
+ thresholding: bool = False,
186
+ dynamic_thresholding_ratio: float = 0.995,
187
+ clip_sample_range: float = 1.0,
188
+ sample_max_value: float = 1.0,
189
+ timestep_spacing: str = "leading",
190
+ ):
191
+
192
+ # Hacky way to replicate diffusers register to config
193
+ self.config = FrozenDict(
194
+ {key: value for key, value in locals().items() if key != "self"}
195
+ )
196
+ if trained_betas is not None:
197
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
198
+ elif beta_schedule == "linear":
199
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
200
+ elif beta_schedule == "scaled_linear":
201
+ # this schedule is very specific to the latent diffusion model.
202
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
203
+ elif beta_schedule == "squaredcos_cap_v2":
204
+ # Glide cosine schedule
205
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
206
+ else:
207
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
208
+
209
+ self.alphas = 1.0 - self.betas
210
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
211
+
212
+ # At every step in ddim, we are looking into the previous alphas_cumprod
213
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
214
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
215
+ # whether we use the final alpha of the "non-previous" one.
216
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
217
+
218
+ # standard deviation of the initial noise distribution
219
+ self.init_noise_sigma = 1.0
220
+
221
+ # setable values
222
+ self.num_inference_steps = None
223
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
224
+
225
+
226
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
227
+ """
228
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
229
+ current timestep.
230
+
231
+ Args:
232
+ sample (`torch.FloatTensor`):
233
+ The input sample.
234
+ timestep (`int`, *optional*):
235
+ The current timestep in the diffusion chain.
236
+
237
+ Returns:
238
+ `torch.FloatTensor`:
239
+ A scaled input sample.
240
+ """
241
+ return sample
242
+
243
+ def _get_variance(self, timestep, prev_timestep):
244
+ alpha_prod_t = self.alphas_cumprod[timestep]
245
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
246
+ beta_prod_t = 1 - alpha_prod_t
247
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
248
+
249
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
250
+
251
+ return variance
252
+
253
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
254
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
255
+ """
256
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
257
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
258
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
259
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
260
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
261
+
262
+ https://arxiv.org/abs/2205.11487
263
+ """
264
+ dtype = sample.dtype
265
+ batch_size, channels, *remaining_dims = sample.shape
266
+
267
+ if dtype not in (torch.float32, torch.float64):
268
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
269
+
270
+ # Flatten sample for doing quantile calculation along each image
271
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
272
+
273
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
274
+
275
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
276
+ s = torch.clamp(
277
+ s, min=1, max=self.config.sample_max_value
278
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
279
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
280
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
281
+
282
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
283
+ sample = sample.to(dtype)
284
+
285
+ return sample
286
+
287
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
288
+ """
289
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
290
+
291
+ Args:
292
+ num_inference_steps (`int`):
293
+ The number of diffusion steps used when generating samples with a pre-trained model.
294
+ """
295
+
296
+ if num_inference_steps > self.config.num_train_timesteps:
297
+ raise ValueError(
298
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
299
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
300
+ f" maximal {self.config.num_train_timesteps} timesteps."
301
+ )
302
+
303
+ self.num_inference_steps = num_inference_steps
304
+
305
+ # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
306
+ if self.config.timestep_spacing == "linspace":
307
+ timesteps = (
308
+ np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
309
+ .round()[::-1]
310
+ .copy()
311
+ .astype(np.int64)
312
+ )
313
+ elif self.config.timestep_spacing == "leading":
314
+ step_ratio = self.config.num_train_timesteps // self.num_inference_steps
315
+ # creates integer timesteps by multiplying by ratio
316
+ # casting to int to avoid issues when num_inference_step is power of 3
317
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
318
+ timesteps += self.config.steps_offset
319
+ elif self.config.timestep_spacing == "trailing":
320
+ step_ratio = self.config.num_train_timesteps / self.num_inference_steps
321
+ # creates integer timesteps by multiplying by ratio
322
+ # casting to int to avoid issues when num_inference_step is power of 3
323
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
324
+ timesteps -= 1
325
+ else:
326
+ raise ValueError(
327
+ f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
328
+ )
329
+
330
+ self.timesteps = torch.from_numpy(timesteps).to(device)
331
+
332
+ def step(
333
+ self,
334
+ model_output: torch.FloatTensor,
335
+ timestep: int,
336
+ sample: torch.FloatTensor,
337
+ eta: float = 0.0,
338
+ use_clipped_model_output: bool = False,
339
+ generator=None,
340
+ variance_noise: Optional[torch.FloatTensor] = None,
341
+ return_dict: bool = True,
342
+ original_image = None
343
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
344
+ """
345
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
346
+ process from the learned model outputs (most often the predicted noise).
347
+
348
+ Args:
349
+ model_output (`torch.FloatTensor`):
350
+ The direct output from learned diffusion model.
351
+ timestep (`float`):
352
+ The current discrete timestep in the diffusion chain.
353
+ sample (`torch.FloatTensor`):
354
+ A current instance of a sample created by the diffusion process.
355
+ eta (`float`):
356
+ The weight of noise for added noise in diffusion step.
357
+ use_clipped_model_output (`bool`, defaults to `False`):
358
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
359
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
360
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
361
+ `use_clipped_model_output` has no effect.
362
+ generator (`torch.Generator`, *optional*):
363
+ A random number generator.
364
+ variance_noise (`torch.FloatTensor`):
365
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
366
+ itself. Useful for methods such as [`CycleDiffusion`].
367
+ return_dict (`bool`, *optional*, defaults to `True`):
368
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
369
+
370
+ Returns:
371
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
372
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
373
+ tuple is returned where the first element is the sample tensor.
374
+
375
+ """
376
+ if self.num_inference_steps is None:
377
+ raise ValueError(
378
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
379
+ )
380
+
381
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
382
+ # Ideally, read DDIM paper in-detail understanding
383
+
384
+ # Notation (<variable name> -> <name in paper>)
385
+ # - pred_noise_t -> e_theta(x_t, t)
386
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
387
+ # - std_dev_t -> sigma_t
388
+ # - eta -> η
389
+ # - pred_sample_direction -> "direction pointing to x_t"
390
+ # - pred_prev_sample -> "x_t-1"
391
+
392
+ # 1. get previous step value (=t-1)
393
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
394
+
395
+ # 2. compute alphas, betas
396
+ alpha_prod_t = self.alphas_cumprod[timestep]
397
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
398
+
399
+ beta_prod_t = 1 - alpha_prod_t
400
+
401
+ # 3. compute predicted original sample from predicted noise also called
402
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
403
+ if self.config.prediction_type == "epsilon":
404
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
405
+ pred_epsilon = model_output
406
+ elif self.config.prediction_type == "sample":
407
+ pred_original_sample = model_output
408
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
409
+ elif self.config.prediction_type == "v_prediction":
410
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
411
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
412
+ else:
413
+ raise ValueError(
414
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
415
+ " `v_prediction`"
416
+ )
417
+
418
+ # 4. Clip or threshold "predicted x_0"
419
+ if self.config.thresholding:
420
+ pred_original_sample = self._threshold_sample(pred_original_sample)
421
+ elif self.config.clip_sample:
422
+ pred_original_sample = pred_original_sample.clamp(
423
+ -self.config.clip_sample_range, self.config.clip_sample_range
424
+ )
425
+
426
+ # pred_original_sample[:, :, 128:, :] = original_image[:, :, 128:, :]
427
+
428
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
429
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
430
+ variance = self._get_variance(timestep, prev_timestep)
431
+ std_dev_t = eta * variance ** (0.5)
432
+
433
+ if use_clipped_model_output:
434
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
435
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
436
+
437
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
438
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
439
+
440
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
441
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
442
+
443
+ if eta > 0:
444
+ if variance_noise is not None and generator is not None:
445
+ raise ValueError(
446
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
447
+ " `variance_noise` stays `None`."
448
+ )
449
+
450
+ if variance_noise is None:
451
+ variance_noise = randn_tensor(
452
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
453
+ )
454
+ variance = std_dev_t * variance_noise
455
+
456
+ prev_sample = prev_sample + variance
457
+
458
+ if not return_dict:
459
+ return (prev_sample,)
460
+
461
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
462
+
463
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
464
+ def add_noise(
465
+ self,
466
+ original_samples: torch.FloatTensor,
467
+ noise: torch.FloatTensor,
468
+ timesteps: torch.IntTensor,
469
+ ) -> torch.FloatTensor:
470
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
471
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
472
+ # for the subsequent add_noise calls
473
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device)
474
+ alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype)
475
+ timesteps = timesteps.to(original_samples.device)
476
+
477
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
478
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
479
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
480
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
481
+
482
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
483
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
484
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
485
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
486
+
487
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
488
+ return noisy_samples
489
+
490
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
491
+ def get_velocity(
492
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
493
+ ) -> torch.FloatTensor:
494
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
495
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device)
496
+ alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype)
497
+ timesteps = timesteps.to(sample.device)
498
+
499
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
500
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
501
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
502
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
503
+
504
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
505
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
506
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
507
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
508
+
509
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
510
+ return velocity
511
+
512
+ def __len__(self):
513
+ return self.config.num_train_timesteps
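
For reference, `DDIMScheduler.step` above implements Eq. (12) of the DDIM paper (https://arxiv.org/abs/2010.02502). Writing $\alpha_t$ for the cumulative product `alphas_cumprod[t]` and $\epsilon_\theta$ for the predicted noise,

$$
\hat{x}_0 = \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}},
\qquad
x_{t-1} = \sqrt{\alpha_{t-1}}\,\hat{x}_0
          + \sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t)
          + \sigma_t\,\varepsilon,
$$

where $\sigma_t = \eta\,\sqrt{(1-\alpha_{t-1})/(1-\alpha_t)}\,\sqrt{1-\alpha_t/\alpha_{t-1}}$ is the variance of Eq. (16) and $\varepsilon \sim \mathcal{N}(0, I)$; with the default `eta=0.0` the update is deterministic. The same $\hat{x}_0$ expression is what `run_diffusion` above evaluates (with $\alpha_{t-1}$ in place of $\alpha_t$) to measure consistency with the observation.
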
cdim/dps_model/dps_unet.py ADDED
@@ -0,0 +1,1118 @@
1
+ # Code based on https://github.com/DPS2022/diffusion-posterior-sampling
2
+ from abc import abstractmethod
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+ import torch as th
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import functools
11
+
12
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
13
+ from .nn import (
14
+ checkpoint,
15
+ conv_nd,
16
+ linear,
17
+ avg_pool_nd,
18
+ zero_module,
19
+ normalization,
20
+ timestep_embedding,
21
+ )
22
+
23
+
24
+ NUM_CLASSES = 1000
25
+
26
+ def create_model(
27
+ image_size,
28
+ num_channels,
29
+ num_res_blocks,
30
+ channel_mult="",
31
+ learn_sigma=False,
32
+ class_cond=False,
33
+ use_checkpoint=False,
34
+ attention_resolutions="16",
35
+ num_heads=1,
36
+ num_head_channels=-1,
37
+ num_heads_upsample=-1,
38
+ use_scale_shift_norm=False,
39
+ dropout=0,
40
+ resblock_updown=False,
41
+ use_fp16=False,
42
+ use_new_attention_order=False,
43
+ model_path='',
44
+ ):
45
+ if channel_mult == "":
46
+ if image_size == 512:
47
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
48
+ elif image_size == 256:
49
+ channel_mult = (1, 1, 2, 2, 4, 4)
50
+ elif image_size == 128:
51
+ channel_mult = (1, 1, 2, 3, 4)
52
+ elif image_size == 64:
53
+ channel_mult = (1, 2, 3, 4)
54
+ else:
55
+ raise ValueError(f"unsupported image size: {image_size}")
56
+ else:
57
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
58
+
59
+ attention_ds = []
60
+ if isinstance(attention_resolutions, int):
61
+ attention_ds.append(image_size // attention_resolutions)
62
+ elif isinstance(attention_resolutions, str):
63
+ for res in attention_resolutions.split(","):
64
+ attention_ds.append(image_size // int(res))
65
+ else:
66
+ raise NotImplementedError
67
+
68
+ model = UNetModel(
69
+ image_size=image_size,
70
+ in_channels=3,
71
+ model_channels=num_channels,
72
+ out_channels=(3 if not learn_sigma else 6),
73
+ num_res_blocks=num_res_blocks,
74
+ attention_resolutions=tuple(attention_ds),
75
+ dropout=dropout,
76
+ channel_mult=channel_mult,
77
+ num_classes=(NUM_CLASSES if class_cond else None),
78
+ use_checkpoint=use_checkpoint,
79
+ use_fp16=use_fp16,
80
+ num_heads=num_heads,
81
+ num_head_channels=num_head_channels,
82
+ num_heads_upsample=num_heads_upsample,
83
+ use_scale_shift_norm=use_scale_shift_norm,
84
+ resblock_updown=resblock_updown,
85
+ use_new_attention_order=use_new_attention_order,
86
+ )
87
+
88
+ try:
89
+ model.load_state_dict(th.load(model_path, map_location='cpu'))
90
+ except Exception as e:
91
+ print(f"Got exception: {e} / Randomly initialize")
92
+ return model
93
+
94
+ class AttentionPool2d(nn.Module):
95
+ """
96
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ spacial_dim: int,
102
+ embed_dim: int,
103
+ num_heads_channels: int,
104
+ output_dim: int = None,
105
+ ):
106
+ super().__init__()
107
+ self.positional_embedding = nn.Parameter(
108
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
109
+ )
110
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
111
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
112
+ self.num_heads = embed_dim // num_heads_channels
113
+ self.attention = QKVAttention(self.num_heads)
114
+
115
+ def forward(self, x):
116
+ b, c, *_spatial = x.shape
117
+ x = x.reshape(b, c, -1) # NC(HW)
118
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
119
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
120
+ x = self.qkv_proj(x)
121
+ x = self.attention(x)
122
+ x = self.c_proj(x)
123
+ return x[:, :, 0]
124
+
125
+
126
+ class TimestepBlock(nn.Module):
127
+ """
128
+ Any module where forward() takes timestep embeddings as a second argument.
129
+ """
130
+
131
+ @abstractmethod
132
+ def forward(self, x, emb):
133
+ """
134
+ Apply the module to `x` given `emb` timestep embeddings.
135
+ """
136
+
137
+
138
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
139
+ """
140
+ A sequential module that passes timestep embeddings to the children that
141
+ support it as an extra input.
142
+ """
143
+
144
+ def forward(self, x, emb):
145
+ for layer in self:
146
+ if isinstance(layer, TimestepBlock):
147
+ x = layer(x, emb)
148
+ else:
149
+ x = layer(x)
150
+ return x
151
+
152
+
153
+ class Upsample(nn.Module):
154
+ """
155
+ An upsampling layer with an optional convolution.
156
+
157
+ :param channels: channels in the inputs and outputs.
158
+ :param use_conv: a bool determining if a convolution is applied.
159
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
160
+ upsampling occurs in the inner-two dimensions.
161
+ """
162
+
163
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
164
+ super().__init__()
165
+ self.channels = channels
166
+ self.out_channels = out_channels or channels
167
+ self.use_conv = use_conv
168
+ self.dims = dims
169
+ if use_conv:
170
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
171
+
172
+ def forward(self, x):
173
+ assert x.shape[1] == self.channels
174
+ if self.dims == 3:
175
+ x = F.interpolate(
176
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
177
+ )
178
+ else:
179
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
180
+ if self.use_conv:
181
+ x = self.conv(x)
182
+ return x
183
+
184
+
185
+ class Downsample(nn.Module):
186
+ """
187
+ A downsampling layer with an optional convolution.
188
+
189
+ :param channels: channels in the inputs and outputs.
190
+ :param use_conv: a bool determining if a convolution is applied.
191
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
192
+ downsampling occurs in the inner-two dimensions.
193
+ """
194
+
195
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
196
+ super().__init__()
197
+ self.channels = channels
198
+ self.out_channels = out_channels or channels
199
+ self.use_conv = use_conv
200
+ self.dims = dims
201
+ stride = 2 if dims != 3 else (1, 2, 2)
202
+ if use_conv:
203
+ self.op = conv_nd(
204
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
205
+ )
206
+ else:
207
+ assert self.channels == self.out_channels
208
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
209
+
210
+ def forward(self, x):
211
+ assert x.shape[1] == self.channels
212
+ return self.op(x)
213
+
214
+
215
+ class ResBlock(TimestepBlock):
216
+ """
217
+ A residual block that can optionally change the number of channels.
218
+
219
+ :param channels: the number of input channels.
220
+ :param emb_channels: the number of timestep embedding channels.
221
+ :param dropout: the rate of dropout.
222
+ :param out_channels: if specified, the number of out channels.
223
+ :param use_conv: if True and out_channels is specified, use a spatial
224
+ convolution instead of a smaller 1x1 convolution to change the
225
+ channels in the skip connection.
226
+ :param dims: determines if the signal is 1D, 2D, or 3D.
227
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
228
+ :param up: if True, use this block for upsampling.
229
+ :param down: if True, use this block for downsampling.
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ channels,
235
+ emb_channels,
236
+ dropout,
237
+ out_channels=None,
238
+ use_conv=False,
239
+ use_scale_shift_norm=False,
240
+ dims=2,
241
+ use_checkpoint=False,
242
+ up=False,
243
+ down=False,
244
+ ):
245
+ super().__init__()
246
+ self.channels = channels
247
+ self.emb_channels = emb_channels
248
+ self.dropout = dropout
249
+ self.out_channels = out_channels or channels
250
+ self.use_conv = use_conv
251
+ self.use_checkpoint = use_checkpoint
252
+ self.use_scale_shift_norm = use_scale_shift_norm
253
+
254
+ self.in_layers = nn.Sequential(
255
+ normalization(channels),
256
+ nn.SiLU(),
257
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
258
+ )
259
+
260
+ self.updown = up or down
261
+
262
+ if up:
263
+ self.h_upd = Upsample(channels, False, dims)
264
+ self.x_upd = Upsample(channels, False, dims)
265
+ elif down:
266
+ self.h_upd = Downsample(channels, False, dims)
267
+ self.x_upd = Downsample(channels, False, dims)
268
+ else:
269
+ self.h_upd = self.x_upd = nn.Identity()
270
+
271
+ self.emb_layers = nn.Sequential(
272
+ nn.SiLU(),
273
+ linear(
274
+ emb_channels,
275
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
276
+ ),
277
+ )
278
+ self.out_layers = nn.Sequential(
279
+ normalization(self.out_channels),
280
+ nn.SiLU(),
281
+ nn.Dropout(p=dropout),
282
+ zero_module(
283
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
284
+ ),
285
+ )
286
+
287
+ if self.out_channels == channels:
288
+ self.skip_connection = nn.Identity()
289
+ elif use_conv:
290
+ self.skip_connection = conv_nd(
291
+ dims, channels, self.out_channels, 3, padding=1
292
+ )
293
+ else:
294
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
295
+
296
+ def forward(self, x, emb):
297
+ """
298
+ Apply the block to a Tensor, conditioned on a timestep embedding.
299
+
300
+ :param x: an [N x C x ...] Tensor of features.
301
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
302
+ :return: an [N x C x ...] Tensor of outputs.
303
+ """
304
+ return checkpoint(
305
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
306
+ )
307
+
308
+ def _forward(self, x, emb):
309
+ if self.updown:
310
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
311
+ h = in_rest(x)
312
+ h = self.h_upd(h)
313
+ x = self.x_upd(x)
314
+ h = in_conv(h)
315
+ else:
316
+ h = self.in_layers(x)
317
+ emb_out = self.emb_layers(emb).type(h.dtype)
318
+ while len(emb_out.shape) < len(h.shape):
319
+ emb_out = emb_out[..., None]
320
+ if self.use_scale_shift_norm:
321
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
322
+ scale, shift = th.chunk(emb_out, 2, dim=1)
323
+ h = out_norm(h) * (1 + scale) + shift
324
+ h = out_rest(h)
325
+ else:
326
+ h = h + emb_out
327
+ h = self.out_layers(h)
328
+ return self.skip_connection(x) + h
329
+
330
+
331
+ class AttentionBlock(nn.Module):
332
+ """
333
+ An attention block that allows spatial positions to attend to each other.
334
+
335
+ Originally ported from here, but adapted to the N-d case.
336
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
337
+ """
338
+
339
+ def __init__(
340
+ self,
341
+ channels,
342
+ num_heads=1,
343
+ num_head_channels=-1,
344
+ use_checkpoint=False,
345
+ use_new_attention_order=False,
346
+ ):
347
+ super().__init__()
348
+ self.channels = channels
349
+ if num_head_channels == -1:
350
+ self.num_heads = num_heads
351
+ else:
352
+ assert (
353
+ channels % num_head_channels == 0
354
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
355
+ self.num_heads = channels // num_head_channels
356
+ self.use_checkpoint = use_checkpoint
357
+ self.norm = normalization(channels)
358
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
359
+ if use_new_attention_order:
360
+ # split qkv before split heads
361
+ self.attention = QKVAttention(self.num_heads)
362
+ else:
363
+ # split heads before split qkv
364
+ self.attention = QKVAttentionLegacy(self.num_heads)
365
+
366
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
367
+
368
+ def forward(self, x):
369
+ return checkpoint(self._forward, (x,), self.parameters(), True)
370
+
371
+ def _forward(self, x):
372
+ b, c, *spatial = x.shape
373
+ x = x.reshape(b, c, -1)
374
+ qkv = self.qkv(self.norm(x))
375
+ h = self.attention(qkv)
376
+ h = self.proj_out(h)
377
+ return (x + h).reshape(b, c, *spatial)
378
+
379
+
380
+ def count_flops_attn(model, _x, y):
381
+ """
382
+ A counter for the `thop` package to count the operations in an
383
+ attention operation.
384
+ Meant to be used like:
385
+ macs, params = thop.profile(
386
+ model,
387
+ inputs=(inputs, timestamps),
388
+ custom_ops={QKVAttention: QKVAttention.count_flops},
389
+ )
390
+ """
391
+ b, c, *spatial = y[0].shape
392
+ num_spatial = int(np.prod(spatial))
393
+ # We perform two matmuls with the same number of ops.
394
+ # The first computes the weight matrix, the second computes
395
+ # the combination of the value vectors.
396
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
397
+ model.total_ops += th.DoubleTensor([matmul_ops])
398
+
399
+
400
+ class QKVAttentionLegacy(nn.Module):
401
+ """
402
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
403
+ """
404
+
405
+ def __init__(self, n_heads):
406
+ super().__init__()
407
+ self.n_heads = n_heads
408
+
409
+ def forward(self, qkv):
410
+ """
411
+ Apply QKV attention.
412
+
413
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
414
+ :return: an [N x (H * C) x T] tensor after attention.
415
+ """
416
+ bs, width, length = qkv.shape
417
+ assert width % (3 * self.n_heads) == 0
418
+ ch = width // (3 * self.n_heads)
419
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
420
+ scale = 1 / math.sqrt(math.sqrt(ch))
421
+ weight = th.einsum(
422
+ "bct,bcs->bts", q * scale, k * scale
423
+ ) # More stable with f16 than dividing afterwards
424
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
425
+ a = th.einsum("bts,bcs->bct", weight, v)
426
+ return a.reshape(bs, -1, length)
427
+
428
+ @staticmethod
429
+ def count_flops(model, _x, y):
430
+ return count_flops_attn(model, _x, y)
431
+
432
+
433
+ class QKVAttention(nn.Module):
434
+ """
435
+ A module which performs QKV attention and splits in a different order.
436
+ """
437
+
438
+ def __init__(self, n_heads):
439
+ super().__init__()
440
+ self.n_heads = n_heads
441
+
442
+ def forward(self, qkv):
443
+ """
444
+ Apply QKV attention.
445
+
446
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
447
+ :return: an [N x (H * C) x T] tensor after attention.
448
+ """
449
+ bs, width, length = qkv.shape
450
+ assert width % (3 * self.n_heads) == 0
451
+ ch = width // (3 * self.n_heads)
452
+ q, k, v = qkv.chunk(3, dim=1)
453
+ scale = 1 / math.sqrt(math.sqrt(ch))
454
+ weight = th.einsum(
455
+ "bct,bcs->bts",
456
+ (q * scale).view(bs * self.n_heads, ch, length),
457
+ (k * scale).view(bs * self.n_heads, ch, length),
458
+ ) # More stable with f16 than dividing afterwards
459
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
460
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
461
+ return a.reshape(bs, -1, length)
462
+
463
+ @staticmethod
464
+ def count_flops(model, _x, y):
465
+ return count_flops_attn(model, _x, y)
466
+
467
+
468
+ class UNetModel(nn.Module):
469
+ """
470
+ The full UNet model with attention and timestep embedding.
471
+
472
+ :param in_channels: channels in the input Tensor.
473
+ :param model_channels: base channel count for the model.
474
+ :param out_channels: channels in the output Tensor.
475
+ :param num_res_blocks: number of residual blocks per downsample.
476
+ :param attention_resolutions: a collection of downsample rates at which
477
+ attention will take place. May be a set, list, or tuple.
478
+ For example, if this contains 4, then at 4x downsampling, attention
479
+ will be used.
480
+ :param dropout: the dropout probability.
481
+ :param channel_mult: channel multiplier for each level of the UNet.
482
+ :param conv_resample: if True, use learned convolutions for upsampling and
483
+ downsampling.
484
+ :param dims: determines if the signal is 1D, 2D, or 3D.
485
+ :param num_classes: if specified (as an int), then this model will be
486
+ class-conditional with `num_classes` classes.
487
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
488
+ :param num_heads: the number of attention heads in each attention layer.
489
+ :param num_heads_channels: if specified, ignore num_heads and instead use
490
+ a fixed channel width per attention head.
491
+ :param num_heads_upsample: works with num_heads to set a different number
492
+ of heads for upsampling. Deprecated.
493
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
494
+ :param resblock_updown: use residual blocks for up/downsampling.
495
+ :param use_new_attention_order: use a different attention pattern for potentially
496
+ increased efficiency.
497
+ """
498
+
499
+ def __init__(
500
+ self,
501
+ image_size,
502
+ in_channels,
503
+ model_channels,
504
+ out_channels,
505
+ num_res_blocks,
506
+ attention_resolutions,
507
+ dropout=0,
508
+ channel_mult=(1, 2, 4, 8),
509
+ conv_resample=True,
510
+ dims=2,
511
+ num_classes=None,
512
+ use_checkpoint=False,
513
+ use_fp16=False,
514
+ num_heads=1,
515
+ num_head_channels=-1,
516
+ num_heads_upsample=-1,
517
+ use_scale_shift_norm=False,
518
+ resblock_updown=False,
519
+ use_new_attention_order=False,
520
+ ):
521
+ super().__init__()
522
+
523
+ if num_heads_upsample == -1:
524
+ num_heads_upsample = num_heads
525
+
526
+ self.image_size = image_size
527
+ self.in_channels = in_channels
528
+ self.model_channels = model_channels
529
+ self.out_channels = out_channels
530
+ self.num_res_blocks = num_res_blocks
531
+ self.attention_resolutions = attention_resolutions
532
+ self.dropout = dropout
533
+ self.channel_mult = channel_mult
534
+ self.conv_resample = conv_resample
535
+ self.num_classes = num_classes
536
+ self.use_checkpoint = use_checkpoint
537
+ self.dtype = th.float16 if use_fp16 else th.float32
538
+ self.num_heads = num_heads
539
+ self.num_head_channels = num_head_channels
540
+ self.num_heads_upsample = num_heads_upsample
541
+
542
+ time_embed_dim = model_channels * 4
543
+ self.time_embed = nn.Sequential(
544
+ linear(model_channels, time_embed_dim),
545
+ nn.SiLU(),
546
+ linear(time_embed_dim, time_embed_dim),
547
+ )
548
+
549
+ if self.num_classes is not None:
550
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
551
+
552
+ ch = input_ch = int(channel_mult[0] * model_channels)
553
+ self.input_blocks = nn.ModuleList(
554
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
555
+ )
556
+ self._feature_size = ch
557
+ input_block_chans = [ch]
558
+ ds = 1
559
+ for level, mult in enumerate(channel_mult):
560
+ for _ in range(num_res_blocks):
561
+ layers = [
562
+ ResBlock(
563
+ ch,
564
+ time_embed_dim,
565
+ dropout,
566
+ out_channels=int(mult * model_channels),
567
+ dims=dims,
568
+ use_checkpoint=use_checkpoint,
569
+ use_scale_shift_norm=use_scale_shift_norm,
570
+ )
571
+ ]
572
+ ch = int(mult * model_channels)
573
+ if ds in attention_resolutions:
574
+ layers.append(
575
+ AttentionBlock(
576
+ ch,
577
+ use_checkpoint=use_checkpoint,
578
+ num_heads=num_heads,
579
+ num_head_channels=num_head_channels,
580
+ use_new_attention_order=use_new_attention_order,
581
+ )
582
+ )
583
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
584
+ self._feature_size += ch
585
+ input_block_chans.append(ch)
586
+ if level != len(channel_mult) - 1:
587
+ out_ch = ch
588
+ self.input_blocks.append(
589
+ TimestepEmbedSequential(
590
+ ResBlock(
591
+ ch,
592
+ time_embed_dim,
593
+ dropout,
594
+ out_channels=out_ch,
595
+ dims=dims,
596
+ use_checkpoint=use_checkpoint,
597
+ use_scale_shift_norm=use_scale_shift_norm,
598
+ down=True,
599
+ )
600
+ if resblock_updown
601
+ else Downsample(
602
+ ch, conv_resample, dims=dims, out_channels=out_ch
603
+ )
604
+ )
605
+ )
606
+ ch = out_ch
607
+ input_block_chans.append(ch)
608
+ ds *= 2
609
+ self._feature_size += ch
610
+
611
+ self.middle_block = TimestepEmbedSequential(
612
+ ResBlock(
613
+ ch,
614
+ time_embed_dim,
615
+ dropout,
616
+ dims=dims,
617
+ use_checkpoint=use_checkpoint,
618
+ use_scale_shift_norm=use_scale_shift_norm,
619
+ ),
620
+ AttentionBlock(
621
+ ch,
622
+ use_checkpoint=use_checkpoint,
623
+ num_heads=num_heads,
624
+ num_head_channels=num_head_channels,
625
+ use_new_attention_order=use_new_attention_order,
626
+ ),
627
+ ResBlock(
628
+ ch,
629
+ time_embed_dim,
630
+ dropout,
631
+ dims=dims,
632
+ use_checkpoint=use_checkpoint,
633
+ use_scale_shift_norm=use_scale_shift_norm,
634
+ ),
635
+ )
636
+ self._feature_size += ch
637
+
638
+ self.output_blocks = nn.ModuleList([])
639
+ for level, mult in list(enumerate(channel_mult))[::-1]:
640
+ for i in range(num_res_blocks + 1):
641
+ ich = input_block_chans.pop()
642
+ layers = [
643
+ ResBlock(
644
+ ch + ich,
645
+ time_embed_dim,
646
+ dropout,
647
+ out_channels=int(model_channels * mult),
648
+ dims=dims,
649
+ use_checkpoint=use_checkpoint,
650
+ use_scale_shift_norm=use_scale_shift_norm,
651
+ )
652
+ ]
653
+ ch = int(model_channels * mult)
654
+ if ds in attention_resolutions:
655
+ layers.append(
656
+ AttentionBlock(
657
+ ch,
658
+ use_checkpoint=use_checkpoint,
659
+ num_heads=num_heads_upsample,
660
+ num_head_channels=num_head_channels,
661
+ use_new_attention_order=use_new_attention_order,
662
+ )
663
+ )
664
+ if level and i == num_res_blocks:
665
+ out_ch = ch
666
+ layers.append(
667
+ ResBlock(
668
+ ch,
669
+ time_embed_dim,
670
+ dropout,
671
+ out_channels=out_ch,
672
+ dims=dims,
673
+ use_checkpoint=use_checkpoint,
674
+ use_scale_shift_norm=use_scale_shift_norm,
675
+ up=True,
676
+ )
677
+ if resblock_updown
678
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
679
+ )
680
+ ds //= 2
681
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
682
+ self._feature_size += ch
683
+
684
+ self.out = nn.Sequential(
685
+ normalization(ch),
686
+ nn.SiLU(),
687
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
688
+ )
689
+
690
+ def convert_to_fp16(self):
691
+ """
692
+ Convert the torso of the model to float16.
693
+ """
694
+ self.input_blocks.apply(convert_module_to_f16)
695
+ self.middle_block.apply(convert_module_to_f16)
696
+ self.output_blocks.apply(convert_module_to_f16)
697
+
698
+ def convert_to_fp32(self):
699
+ """
700
+ Convert the torso of the model to float32.
701
+ """
702
+ self.input_blocks.apply(convert_module_to_f32)
703
+ self.middle_block.apply(convert_module_to_f32)
704
+ self.output_blocks.apply(convert_module_to_f32)
705
+
706
+ def forward(self, x, timesteps, y=None):
707
+ """
708
+ Apply the model to an input batch.
709
+
710
+ :param x: an [N x C x ...] Tensor of inputs.
711
+ :param timesteps: a 1-D batch of timesteps.
712
+ :param y: an [N] Tensor of labels, if class-conditional.
713
+ :return: an [N x C x ...] Tensor of outputs.
714
+ """
715
+ assert (y is not None) == (
716
+ self.num_classes is not None
717
+ ), "must specify y if and only if the model is class-conditional"
718
+
719
+ hs = []
720
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
721
+
722
+ if self.num_classes is not None:
723
+ assert y.shape == (x.shape[0],)
724
+ emb = emb + self.label_emb(y)
725
+
726
+ h = x.type(self.dtype)
727
+ for module in self.input_blocks:
728
+ h = module(h, emb)
729
+ hs.append(h)
730
+ h = self.middle_block(h, emb)
731
+ for module in self.output_blocks:
732
+ h = th.cat([h, hs.pop()], dim=1)
733
+ h = module(h, emb)
734
+ h = h.type(x.dtype)
735
+ return self.out(h)
736
+
737
+
738
+ class SuperResModel(UNetModel):
739
+ """
740
+ A UNetModel that performs super-resolution.
741
+
742
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
743
+ """
744
+
745
+ def __init__(self, image_size, in_channels, *args, **kwargs):
746
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
747
+
748
+ def forward(self, x, timesteps, low_res=None, **kwargs):
749
+ _, _, new_height, new_width = x.shape
750
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
751
+ x = th.cat([x, upsampled], dim=1)
752
+ return super().forward(x, timesteps, **kwargs)
753
+
754
+
755
+ class EncoderUNetModel(nn.Module):
756
+ """
757
+ The half UNet model with attention and timestep embedding.
758
+
759
+ For usage, see UNet.
760
+ """
761
+
762
+ def __init__(
763
+ self,
764
+ image_size,
765
+ in_channels,
766
+ model_channels,
767
+ out_channels,
768
+ num_res_blocks,
769
+ attention_resolutions,
770
+ dropout=0,
771
+ channel_mult=(1, 2, 4, 8),
772
+ conv_resample=True,
773
+ dims=2,
774
+ use_checkpoint=False,
775
+ use_fp16=False,
776
+ num_heads=1,
777
+ num_head_channels=-1,
778
+ num_heads_upsample=-1,
779
+ use_scale_shift_norm=False,
780
+ resblock_updown=False,
781
+ use_new_attention_order=False,
782
+ pool="adaptive",
783
+ ):
784
+ super().__init__()
785
+
786
+ if num_heads_upsample == -1:
787
+ num_heads_upsample = num_heads
788
+
789
+ self.in_channels = in_channels
790
+ self.model_channels = model_channels
791
+ self.out_channels = out_channels
792
+ self.num_res_blocks = num_res_blocks
793
+ self.attention_resolutions = attention_resolutions
794
+ self.dropout = dropout
795
+ self.channel_mult = channel_mult
796
+ self.conv_resample = conv_resample
797
+ self.use_checkpoint = use_checkpoint
798
+ self.dtype = th.float16 if use_fp16 else th.float32
799
+ self.num_heads = num_heads
800
+ self.num_head_channels = num_head_channels
801
+ self.num_heads_upsample = num_heads_upsample
802
+
803
+ time_embed_dim = model_channels * 4
804
+ self.time_embed = nn.Sequential(
805
+ linear(model_channels, time_embed_dim),
806
+ nn.SiLU(),
807
+ linear(time_embed_dim, time_embed_dim),
808
+ )
809
+
810
+ ch = int(channel_mult[0] * model_channels)
811
+ self.input_blocks = nn.ModuleList(
812
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
813
+ )
814
+ self._feature_size = ch
815
+ input_block_chans = [ch]
816
+ ds = 1
817
+ for level, mult in enumerate(channel_mult):
818
+ for _ in range(num_res_blocks):
819
+ layers = [
820
+ ResBlock(
821
+ ch,
822
+ time_embed_dim,
823
+ dropout,
824
+ out_channels=int(mult * model_channels),
825
+ dims=dims,
826
+ use_checkpoint=use_checkpoint,
827
+ use_scale_shift_norm=use_scale_shift_norm,
828
+ )
829
+ ]
830
+ ch = int(mult * model_channels)
831
+ if ds in attention_resolutions:
832
+ layers.append(
833
+ AttentionBlock(
834
+ ch,
835
+ use_checkpoint=use_checkpoint,
836
+ num_heads=num_heads,
837
+ num_head_channels=num_head_channels,
838
+ use_new_attention_order=use_new_attention_order,
839
+ )
840
+ )
841
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
842
+ self._feature_size += ch
843
+ input_block_chans.append(ch)
844
+ if level != len(channel_mult) - 1:
845
+ out_ch = ch
846
+ self.input_blocks.append(
847
+ TimestepEmbedSequential(
848
+ ResBlock(
849
+ ch,
850
+ time_embed_dim,
851
+ dropout,
852
+ out_channels=out_ch,
853
+ dims=dims,
854
+ use_checkpoint=use_checkpoint,
855
+ use_scale_shift_norm=use_scale_shift_norm,
856
+ down=True,
857
+ )
858
+ if resblock_updown
859
+ else Downsample(
860
+ ch, conv_resample, dims=dims, out_channels=out_ch
861
+ )
862
+ )
863
+ )
864
+ ch = out_ch
865
+ input_block_chans.append(ch)
866
+ ds *= 2
867
+ self._feature_size += ch
868
+
869
+ self.middle_block = TimestepEmbedSequential(
870
+ ResBlock(
871
+ ch,
872
+ time_embed_dim,
873
+ dropout,
874
+ dims=dims,
875
+ use_checkpoint=use_checkpoint,
876
+ use_scale_shift_norm=use_scale_shift_norm,
877
+ ),
878
+ AttentionBlock(
879
+ ch,
880
+ use_checkpoint=use_checkpoint,
881
+ num_heads=num_heads,
882
+ num_head_channels=num_head_channels,
883
+ use_new_attention_order=use_new_attention_order,
884
+ ),
885
+ ResBlock(
886
+ ch,
887
+ time_embed_dim,
888
+ dropout,
889
+ dims=dims,
890
+ use_checkpoint=use_checkpoint,
891
+ use_scale_shift_norm=use_scale_shift_norm,
892
+ ),
893
+ )
894
+ self._feature_size += ch
895
+ self.pool = pool
896
+ if pool == "adaptive":
897
+ self.out = nn.Sequential(
898
+ normalization(ch),
899
+ nn.SiLU(),
900
+ nn.AdaptiveAvgPool2d((1, 1)),
901
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
902
+ nn.Flatten(),
903
+ )
904
+ elif pool == "attention":
905
+ assert num_head_channels != -1
906
+ self.out = nn.Sequential(
907
+ normalization(ch),
908
+ nn.SiLU(),
909
+ AttentionPool2d(
910
+ (image_size // ds), ch, num_head_channels, out_channels
911
+ ),
912
+ )
913
+ elif pool == "spatial":
914
+ self.out = nn.Sequential(
915
+ nn.Linear(self._feature_size, 2048),
916
+ nn.ReLU(),
917
+ nn.Linear(2048, self.out_channels),
918
+ )
919
+ elif pool == "spatial_v2":
920
+ self.out = nn.Sequential(
921
+ nn.Linear(self._feature_size, 2048),
922
+ normalization(2048),
923
+ nn.SiLU(),
924
+ nn.Linear(2048, self.out_channels),
925
+ )
926
+ else:
927
+ raise NotImplementedError(f"Unexpected {pool} pooling")
928
+
929
+ def convert_to_fp16(self):
930
+ """
931
+ Convert the torso of the model to float16.
932
+ """
933
+ self.input_blocks.apply(convert_module_to_f16)
934
+ self.middle_block.apply(convert_module_to_f16)
935
+
936
+ def convert_to_fp32(self):
937
+ """
938
+ Convert the torso of the model to float32.
939
+ """
940
+ self.input_blocks.apply(convert_module_to_f32)
941
+ self.middle_block.apply(convert_module_to_f32)
942
+
943
+ def forward(self, x, timesteps):
944
+ """
945
+ Apply the model to an input batch.
946
+
947
+ :param x: an [N x C x ...] Tensor of inputs.
948
+ :param timesteps: a 1-D batch of timesteps.
949
+ :return: an [N x K] Tensor of outputs.
950
+ """
951
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
952
+
953
+ results = []
954
+ h = x.type(self.dtype)
955
+ for module in self.input_blocks:
956
+ h = module(h, emb)
957
+ if self.pool.startswith("spatial"):
958
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
959
+ h = self.middle_block(h, emb)
960
+ if self.pool.startswith("spatial"):
961
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
962
+ h = th.cat(results, axis=-1)
963
+ return self.out(h)
964
+ else:
965
+ h = h.type(x.dtype)
966
+ return self.out(h)
967
+
968
+
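As a rough usage sketch (every hyperparameter below is illustrative, not taken from any config in this commit), the encoder maps an image batch plus per-sample timesteps to a pooled output vector:

# Hypothetical settings for a small guidance-style classifier.
encoder = EncoderUNetModel(
    image_size=64, in_channels=3, model_channels=128, out_channels=1000,
    num_res_blocks=2, attention_resolutions=(8, 16), pool="adaptive",
)
x = th.randn(4, 3, 64, 64)              # [N x C x H x W]
t = th.randint(0, 1000, (4,))           # one diffusion timestep per sample
logits = encoder(x, t)                  # [N x out_channels]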
969
+ class NLayerDiscriminator(nn.Module):
970
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
971
+ super(NLayerDiscriminator, self).__init__()
972
+ if type(norm_layer) == functools.partial:
973
+ use_bias = norm_layer.func == nn.InstanceNorm2d
974
+ else:
975
+ use_bias = norm_layer == nn.InstanceNorm2d
976
+
977
+ kw = 4
978
+ padw = 1
979
+ sequence = [
980
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
981
+ nn.LeakyReLU(0.2, True)
982
+ ]
983
+
984
+ nf_mult = 1
985
+ nf_mult_prev = 1
986
+ for n in range(1, n_layers):
987
+ nf_mult_prev = nf_mult
988
+ nf_mult = min(2**n, 8)
989
+ sequence += [
990
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
991
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
992
+ norm_layer(ndf * nf_mult),
993
+ nn.LeakyReLU(0.2, True)
994
+ ]
995
+
996
+ nf_mult_prev = nf_mult
997
+ nf_mult = min(2**n_layers, 8)
998
+ sequence += [
999
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
1000
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1001
+ norm_layer(ndf * nf_mult),
1002
+ nn.LeakyReLU(0.2, True)
1003
+ ]
1004
+
1005
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=2, padding=padw)] + [nn.Dropout(0.5)]
1006
+ if use_sigmoid:
1007
+ sequence += [nn.Sigmoid()]
1008
+
1009
+ self.model = nn.Sequential(*sequence)
1010
+
1011
+ def forward(self, input):
1012
+ return self.model(input)
1013
+
1014
+
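A minimal sketch of driving this PatchGAN-style discriminator (shapes are illustrative):

netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
fake = th.randn(2, 3, 256, 256)
patch_logits = netD(fake)               # one logit per patch; here the shape works out to [2, 1, 8, 8]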
1015
+ class GANLoss(nn.Module):
1016
+ """Define different GAN objectives.
1017
+
1018
+ The GANLoss class abstracts away the need to create the target label tensor
1019
+ that has the same size as the input.
1020
+ """
1021
+
1022
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
1023
+ """ Initialize the GANLoss class.
1024
+
1025
+ Parameters:
1026
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
1027
+ target_real_label (float) - - label for a real image
1028
+ target_fake_label (float) - - label for a fake image
1029
+
1030
+ Note: Do not use sigmoid as the last layer of Discriminator.
1031
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
1032
+ """
1033
+ super(GANLoss, self).__init__()
1034
+ self.register_buffer('real_label', th.tensor(target_real_label))
1035
+ self.register_buffer('fake_label', th.tensor(target_fake_label))
1036
+ self.gan_mode = gan_mode
1037
+ if gan_mode == 'lsgan':
1038
+ self.loss = nn.MSELoss()
1039
+ elif gan_mode == 'vanilla':
1040
+ self.loss = nn.BCEWithLogitsLoss()
1041
+ elif gan_mode in ['wgangp']:
1042
+ self.loss = None
1043
+ else:
1044
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
1045
+
1046
+ def get_target_tensor(self, prediction, target_is_real):
1047
+ """Create label tensors with the same size as the input.
1048
+
1049
+ Parameters:
1050
+ prediction (tensor) - - typically the prediction from a discriminator
1051
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1052
+
1053
+ Returns:
1054
+ A label tensor filled with ground truth label, and with the size of the input
1055
+ """
1056
+
1057
+ if target_is_real:
1058
+ target_tensor = self.real_label
1059
+ else:
1060
+ target_tensor = self.fake_label
1061
+ return target_tensor.expand_as(prediction)
1062
+
1063
+ def __call__(self, prediction, target_is_real):
1064
+ """Calculate loss given Discriminator's output and grount truth labels.
1065
+
1066
+ Parameters:
1067
+ prediction (tensor) - - typically the prediction output from a discriminator
1068
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1069
+
1070
+ Returns:
1071
+ the calculated loss.
1072
+ """
1073
+ if self.gan_mode in ['lsgan', 'vanilla']:
1074
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
1075
+ loss = self.loss(prediction, target_tensor)
1076
+ elif self.gan_mode == 'wgangp':
1077
+ if target_is_real:
1078
+ loss = -prediction.mean()
1079
+ else:
1080
+ loss = prediction.mean()
1081
+ return loss
1082
+
1083
+
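A hedged sketch of a discriminator update using GANLoss (variable names are illustrative):

criterion = GANLoss('lsgan')                       # or 'vanilla' / 'wgangp'
pred_real = netD(real_images)
pred_fake = netD(fake_images.detach())             # detach so the generator is not updated here
loss_d = criterion(pred_real, True) + criterion(pred_fake, False)
loss_d.backward()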
1084
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
1085
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
1086
+
1087
+ Arguments:
1088
+ netD (network) -- discriminator network
1089
+ real_data (tensor array) -- real images
1090
+ fake_data (tensor array) -- generated images from the generator
1091
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
1092
+ type (str) -- whether to use real data, fake data, or a mix of the two [real | fake | mixed]
1093
+ constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
1094
+ lambda_gp (float) -- weight for this loss
1095
+
1096
+ Returns the gradient penalty loss
1097
+ """
1098
+ if lambda_gp > 0.0:
1099
+ if type == 'real': # either use real images, fake images, or a linear interpolation of the two
1100
+ interpolatesv = real_data
1101
+ elif type == 'fake':
1102
+ interpolatesv = fake_data
1103
+ elif type == 'mixed':
1104
+ alpha = th.rand(real_data.shape[0], 1, device=device)
1105
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
1106
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
1107
+ else:
1108
+ raise NotImplementedError('{} not implemented'.format(type))
1109
+ interpolatesv.requires_grad_(True)
1110
+ disc_interpolates = netD(interpolatesv)
1111
+ gradients = th.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
1112
+ grad_outputs=th.ones(disc_interpolates.size()).to(device),
1113
+ create_graph=True, retain_graph=True, only_inputs=True)
1114
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
1115
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
1116
+ return gradient_penalty, gradients
1117
+ else:
1118
+ return 0.0, None
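In the 'wgangp' mode the penalty above is added to the critic loss; a minimal sketch (names are illustrative):

criterion = GANLoss('wgangp')
loss_d = criterion(netD(real_images), True) + criterion(netD(fake_images.detach()), False)
gp, _ = cal_gradient_penalty(netD, real_images, fake_images.detach(), device, type='mixed', lambda_gp=10.0)
(loss_d + gp).backward()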
cdim/dps_model/fp16_util.py ADDED
@@ -0,0 +1,234 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ INITIAL_LOG_LOSS_SCALE = 20.0
11
+
12
+
13
+ def convert_module_to_f16(l):
14
+ """
15
+ Convert primitive modules to float16.
16
+ """
17
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
18
+ l.weight.data = l.weight.data.half()
19
+ if l.bias is not None:
20
+ l.bias.data = l.bias.data.half()
21
+
22
+
23
+ def convert_module_to_f32(l):
24
+ """
25
+ Convert primitive modules to float32, undoing convert_module_to_f16().
26
+ """
27
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
28
+ l.weight.data = l.weight.data.float()
29
+ if l.bias is not None:
30
+ l.bias.data = l.bias.data.float()
31
+
32
+
33
+ def make_master_params(param_groups_and_shapes):
34
+ """
35
+ Copy model parameters into a (differently-shaped) list of full-precision
36
+ parameters.
37
+ """
38
+ master_params = []
39
+ for param_group, shape in param_groups_and_shapes:
40
+ master_param = nn.Parameter(
41
+ _flatten_dense_tensors(
42
+ [param.detach().float() for (_, param) in param_group]
43
+ ).view(shape)
44
+ )
45
+ master_param.requires_grad = True
46
+ master_params.append(master_param)
47
+ return master_params
48
+
49
+
50
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
51
+ """
52
+ Copy the gradients from the model parameters into the master parameters
53
+ from make_master_params().
54
+ """
55
+ for master_param, (param_group, shape) in zip(
56
+ master_params, param_groups_and_shapes
57
+ ):
58
+ master_param.grad = _flatten_dense_tensors(
59
+ [param_grad_or_zeros(param) for (_, param) in param_group]
60
+ ).view(shape)
61
+
62
+
63
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
64
+ """
65
+ Copy the master parameter data back into the model parameters.
66
+ """
67
+ # Without copying to a list, if a generator is passed, this will
68
+ # silently not copy any parameters.
69
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
70
+ for (_, param), unflat_master_param in zip(
71
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
72
+ ):
73
+ param.detach().copy_(unflat_master_param)
74
+
75
+
76
+ def unflatten_master_params(param_group, master_param):
77
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
78
+
79
+
80
+ def get_param_groups_and_shapes(named_model_params):
81
+ named_model_params = list(named_model_params)
82
+ scalar_vector_named_params = (
83
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
84
+ (-1),
85
+ )
86
+ matrix_named_params = (
87
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
88
+ (1, -1),
89
+ )
90
+ return [scalar_vector_named_params, matrix_named_params]
91
+
92
+
93
+ def master_params_to_state_dict(
94
+ model, param_groups_and_shapes, master_params, use_fp16
95
+ ):
96
+ if use_fp16:
97
+ state_dict = model.state_dict()
98
+ for master_param, (param_group, _) in zip(
99
+ master_params, param_groups_and_shapes
100
+ ):
101
+ for (name, _), unflat_master_param in zip(
102
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
103
+ ):
104
+ assert name in state_dict
105
+ state_dict[name] = unflat_master_param
106
+ else:
107
+ state_dict = model.state_dict()
108
+ for i, (name, _value) in enumerate(model.named_parameters()):
109
+ assert name in state_dict
110
+ state_dict[name] = master_params[i]
111
+ return state_dict
112
+
113
+
114
+ def state_dict_to_master_params(model, state_dict, use_fp16):
115
+ if use_fp16:
116
+ named_model_params = [
117
+ (name, state_dict[name]) for name, _ in model.named_parameters()
118
+ ]
119
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
120
+ master_params = make_master_params(param_groups_and_shapes)
121
+ else:
122
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
123
+ return master_params
124
+
125
+
126
+ def zero_master_grads(master_params):
127
+ for param in master_params:
128
+ param.grad = None
129
+
130
+
131
+ def zero_grad(model_params):
132
+ for param in model_params:
133
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
134
+ if param.grad is not None:
135
+ param.grad.detach_()
136
+ param.grad.zero_()
137
+
138
+
139
+ def param_grad_or_zeros(param):
140
+ if param.grad is not None:
141
+ return param.grad.data.detach()
142
+ else:
143
+ return th.zeros_like(param)
144
+
145
+
146
+ class MixedPrecisionTrainer:
147
+ def __init__(
148
+ self,
149
+ *,
150
+ model,
151
+ use_fp16=False,
152
+ fp16_scale_growth=1e-3,
153
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
154
+ ):
155
+ self.model = model
156
+ self.use_fp16 = use_fp16
157
+ self.fp16_scale_growth = fp16_scale_growth
158
+
159
+ self.model_params = list(self.model.parameters())
160
+ self.master_params = self.model_params
161
+ self.param_groups_and_shapes = None
162
+ self.lg_loss_scale = initial_lg_loss_scale
163
+
164
+ if self.use_fp16:
165
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
166
+ self.model.named_parameters()
167
+ )
168
+ self.master_params = make_master_params(self.param_groups_and_shapes)
169
+ self.model.convert_to_fp16()
170
+
171
+ def zero_grad(self):
172
+ zero_grad(self.model_params)
173
+
174
+ def backward(self, loss: th.Tensor):
175
+ if self.use_fp16:
176
+ loss_scale = 2 ** self.lg_loss_scale
177
+ (loss * loss_scale).backward()
178
+ else:
179
+ loss.backward()
180
+
181
+ def optimize(self, opt: th.optim.Optimizer):
182
+ if self.use_fp16:
183
+ return self._optimize_fp16(opt)
184
+ else:
185
+ return self._optimize_normal(opt)
186
+
187
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
188
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
189
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
190
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
191
+ if check_overflow(grad_norm):
192
+ self.lg_loss_scale -= 1
193
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
194
+ zero_master_grads(self.master_params)
195
+ return False
196
+
197
+ logger.logkv_mean("grad_norm", grad_norm)
198
+ logger.logkv_mean("param_norm", param_norm)
199
+
200
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
201
+ opt.step()
202
+ zero_master_grads(self.master_params)
203
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
204
+ self.lg_loss_scale += self.fp16_scale_growth
205
+ return True
206
+
207
+ def _optimize_normal(self, opt: th.optim.Optimizer):
208
+ grad_norm, param_norm = self._compute_norms()
209
+ logger.logkv_mean("grad_norm", grad_norm)
210
+ logger.logkv_mean("param_norm", param_norm)
211
+ opt.step()
212
+ return True
213
+
214
+ def _compute_norms(self, grad_scale=1.0):
215
+ grad_norm = 0.0
216
+ param_norm = 0.0
217
+ for p in self.master_params:
218
+ with th.no_grad():
219
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
220
+ if p.grad is not None:
221
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
222
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
223
+
224
+ def master_params_to_state_dict(self, master_params):
225
+ return master_params_to_state_dict(
226
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
227
+ )
228
+
229
+ def state_dict_to_master_params(self, state_dict):
230
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
231
+
232
+
233
+ def check_overflow(value):
234
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
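Note that optimize() calls a module-level logger with logkv_mean/log (as in OpenAI's guided-diffusion); that import is not present in this file and would need to be supplied before use. With that caveat, a typical loop looks roughly like:

trainer = MixedPrecisionTrainer(model=model, use_fp16=True)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)
for batch in data_loader:                  # data_loader / compute_loss are hypothetical
    trainer.zero_grad()
    loss = compute_loss(model, batch)
    trainer.backward(loss)                 # scales the loss by 2**lg_loss_scale when fp16 is on
    trainer.optimize(opt)                  # unscales grads and skips the step on overflow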
cdim/dps_model/nn.py ADDED
@@ -0,0 +1,170 @@
1
+ """
2
+ Various utilities for neural networks.
3
+ """
4
+
5
+ import math
6
+
7
+ import torch as th
8
+ import torch.nn as nn
9
+
10
+
11
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12
+ class SiLU(nn.Module):
13
+ def forward(self, x):
14
+ return x * th.sigmoid(x)
15
+
16
+
17
+ class GroupNorm32(nn.GroupNorm):
18
+ def forward(self, x):
19
+ return super().forward(x.float()).type(x.dtype)
20
+
21
+
22
+ def conv_nd(dims, *args, **kwargs):
23
+ """
24
+ Create a 1D, 2D, or 3D convolution module.
25
+ """
26
+ if dims == 1:
27
+ return nn.Conv1d(*args, **kwargs)
28
+ elif dims == 2:
29
+ return nn.Conv2d(*args, **kwargs)
30
+ elif dims == 3:
31
+ return nn.Conv3d(*args, **kwargs)
32
+ raise ValueError(f"unsupported dimensions: {dims}")
33
+
34
+
35
+ def linear(*args, **kwargs):
36
+ """
37
+ Create a linear module.
38
+ """
39
+ return nn.Linear(*args, **kwargs)
40
+
41
+
42
+ def avg_pool_nd(dims, *args, **kwargs):
43
+ """
44
+ Create a 1D, 2D, or 3D average pooling module.
45
+ """
46
+ if dims == 1:
47
+ return nn.AvgPool1d(*args, **kwargs)
48
+ elif dims == 2:
49
+ return nn.AvgPool2d(*args, **kwargs)
50
+ elif dims == 3:
51
+ return nn.AvgPool3d(*args, **kwargs)
52
+ raise ValueError(f"unsupported dimensions: {dims}")
53
+
54
+
55
+ def update_ema(target_params, source_params, rate=0.99):
56
+ """
57
+ Update target parameters to be closer to those of source parameters using
58
+ an exponential moving average.
59
+
60
+ :param target_params: the target parameter sequence.
61
+ :param source_params: the source parameter sequence.
62
+ :param rate: the EMA rate (closer to 1 means slower).
63
+ """
64
+ for targ, src in zip(target_params, source_params):
65
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66
+
67
+
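A minimal sketch of the intended call pattern (the EMA copy below is hypothetical):

ema_params = [p.clone().detach() for p in model.parameters()]
# after every optimizer step:
update_ema(ema_params, model.parameters(), rate=0.9999)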
68
+ def zero_module(module):
69
+ """
70
+ Zero out the parameters of a module and return it.
71
+ """
72
+ for p in module.parameters():
73
+ p.detach().zero_()
74
+ return module
75
+
76
+
77
+ def scale_module(module, scale):
78
+ """
79
+ Scale the parameters of a module and return it.
80
+ """
81
+ for p in module.parameters():
82
+ p.detach().mul_(scale)
83
+ return module
84
+
85
+
86
+ def mean_flat(tensor):
87
+ """
88
+ Take the mean over all non-batch dimensions.
89
+ """
90
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
91
+
92
+
93
+ def normalization(channels):
94
+ """
95
+ Make a standard normalization layer.
96
+
97
+ :param channels: number of input channels.
98
+ :return: an nn.Module for normalization.
99
+ """
100
+ return GroupNorm32(32, channels)
101
+
102
+
103
+ def timestep_embedding(timesteps, dim, max_period=10000):
104
+ """
105
+ Create sinusoidal timestep embeddings.
106
+
107
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
108
+ These may be fractional.
109
+ :param dim: the dimension of the output.
110
+ :param max_period: controls the minimum frequency of the embeddings.
111
+ :return: an [N x dim] Tensor of positional embeddings.
112
+ """
113
+ half = dim // 2
114
+ freqs = th.exp(
115
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116
+ ).to(device=timesteps.device)
117
+ args = timesteps[:, None].float() * freqs[None]
118
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121
+ return embedding
122
+
123
+
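Concretely, each timestep t is mapped to [cos(t·f_1), ..., cos(t·f_half), sin(t·f_1), ..., sin(t·f_half)] with geometrically spaced frequencies; a quick shape check:

t = th.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=128)       # [3, 128]: 64 cosine dims followed by 64 sine dims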
124
+ def checkpoint(func, inputs, params, flag):
125
+ """
126
+ Evaluate a function without caching intermediate activations, allowing for
127
+ reduced memory at the expense of extra compute in the backward pass.
128
+
129
+ :param func: the function to evaluate.
130
+ :param inputs: the argument sequence to pass to `func`.
131
+ :param params: a sequence of parameters `func` depends on but does not
132
+ explicitly take as arguments.
133
+ :param flag: if False, disable gradient checkpointing.
134
+ """
135
+ if flag:
136
+ args = tuple(inputs) + tuple(params)
137
+ return CheckpointFunction.apply(func, len(inputs), *args)
138
+ else:
139
+ return func(*inputs)
140
+
141
+
142
+ class CheckpointFunction(th.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, run_function, length, *args):
145
+ ctx.run_function = run_function
146
+ ctx.input_tensors = list(args[:length])
147
+ ctx.input_params = list(args[length:])
148
+ with th.no_grad():
149
+ output_tensors = ctx.run_function(*ctx.input_tensors)
150
+ return output_tensors
151
+
152
+ @staticmethod
153
+ def backward(ctx, *output_grads):
154
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155
+ with th.enable_grad():
156
+ # Fixes a bug where the first op in run_function modifies the
157
+ # Tensor storage in place, which is not allowed for detach()'d
158
+ # Tensors.
159
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160
+ output_tensors = ctx.run_function(*shallow_copies)
161
+ input_grads = th.autograd.grad(
162
+ output_tensors,
163
+ ctx.input_tensors + ctx.input_params,
164
+ output_grads,
165
+ allow_unused=True,
166
+ )
167
+ del ctx.input_tensors
168
+ del ctx.input_params
169
+ del output_tensors
170
+ return (None, None) + input_grads
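A small illustration of the checkpoint() wrapper with a standalone module (the block itself is hypothetical); activations are recomputed in the backward pass instead of being stored:

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = th.randn(4, 16, requires_grad=True)
y = checkpoint(block, (x,), list(block.parameters()), flag=True)
y.sum().backward()                         # x.grad and the block's parameter grads are populated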
cdim/image_utils.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  from torchvision.transforms import ToPILImage
2
 
3
  def save_to_image(tensor, filename):
@@ -15,3 +18,51 @@ def save_to_image(tensor, filename):
15
 
16
  # Save the image
17
  img.save(filename)
1
+ from typing import List, Optional, Tuple, Union
2
+
3
+ import torch
4
  from torchvision.transforms import ToPILImage
5
 
6
  def save_to_image(tensor, filename):
 
18
 
19
  # Save the image
20
  img.save(filename)
21
+
22
+
23
+ def randn_tensor(
24
+ shape: Union[Tuple, List],
25
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
26
+ device: Optional["torch.device"] = None,
27
+ dtype: Optional["torch.dtype"] = None,
28
+ layout: Optional["torch.layout"] = None,
29
+ ):
30
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
31
+ passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
32
+ is always created on the CPU.
33
+ """
34
+ # device on which tensor is created defaults to device
35
+ rand_device = device
36
+ batch_size = shape[0]
37
+
38
+ layout = layout or torch.strided
39
+ device = device or torch.device("cpu")
40
+
41
+ if generator is not None:
42
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
43
+ if gen_device_type != device.type and gen_device_type == "cpu":
44
+ rand_device = "cpu"
45
+ if device != "mps":
46
+ print(  # no logger is defined in this module, so fall back to print
47
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
48
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
49
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
50
+ )
51
+ elif gen_device_type != device.type and gen_device_type == "cuda":
52
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
53
+
54
+ # make sure generator list of length 1 is treated like a non-list
55
+ if isinstance(generator, list) and len(generator) == 1:
56
+ generator = generator[0]
57
+
58
+ if isinstance(generator, list):
59
+ shape = (1,) + shape[1:]
60
+ latents = [
61
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
62
+ for i in range(batch_size)
63
+ ]
64
+ latents = torch.cat(latents, dim=0).to(device)
65
+ else:
66
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
67
+
68
+ return latents
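A short seeded example (run_diffusion above calls this without a generator; the one here is illustrative):

gen = torch.Generator().manual_seed(8)
latents = randn_tensor((2, 3, 256, 256), generator=gen, device=torch.device("cpu"))
# latents.shape == (2, 3, 256, 256); a list of generators would seed each batch element separately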
cdim/noise.py CHANGED
@@ -33,18 +33,20 @@ class Noise(ABC):
33
  @register_noise(name='gaussian')
34
  class GaussianNoise(Noise):
35
  def __init__(self, sigma):
36
- self.sigma = sigma
37
-
38
- def __call__(self, data):
39
  # Important! We scale sigma by 2 because the config assumes images are in [0, 1]
40
  # but actually this model uses images in [-1, 1]
41
- return data + torch.randn_like(data, device=data.device) * self.sigma * 2
 
 
 
 
42
 
43
 
44
  @register_noise(name='poisson')
45
  class PoissonNoise(Noise):
46
  def __init__(self, rate):
47
  self.rate = rate
 
48
 
49
  def __call__(self, data):
50
  import numpy as np
 
33
  @register_noise(name='gaussian')
34
  class GaussianNoise(Noise):
35
  def __init__(self, sigma):
 
 
 
36
  # Important! We scale sigma by 2 because the config assumes images are in [0, 1]
37
  # but actually this model uses images in [-1, 1]
38
+ self.sigma = 2 * sigma
39
+ self.name = 'gaussian'
40
+
41
+ def __call__(self, data):
42
+ return data + torch.randn_like(data, device=data.device) * self.sigma
43
 
44
 
45
  @register_noise(name='poisson')
46
  class PoissonNoise(Noise):
47
  def __init__(self, rate):
48
  self.rate = rate
49
+ self.name = 'poisson'
50
 
51
  def __call__(self, data):
52
  import numpy as np
cdim/operators/__init__.py CHANGED
@@ -21,4 +21,5 @@ def get_operator(name: str, **kwargs):
21
 
22
  # Import everything to make sure they register
23
  from .random_box_masker import RandomBoxMasker
 
24
  from .identity_operator import IdentityOperator
 
21
 
22
  # Import everything to make sure they register
23
  from .random_box_masker import RandomBoxMasker
24
+ from .random_pixel_masker import RandomPixelMasker
25
  from .identity_operator import IdentityOperator
cdim/operators/random_pixel_masker.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+ from cdim.operators import register_operator
4
+
5
+ @register_operator(name='random_inpainting')
6
+ class RandomPixelMasker:
7
+ def __init__(self, height=256, width=256, channels=3, fraction=0.08, device='cpu'):
8
+ """
9
+ Initialize the RandomPixelMasker.
10
+
11
+ Args:
12
+ height (int): Height of the input tensors (default: 256)
13
+ width (int): Width of the input tensors (default: 256)
14
+ channels (int): Number of channels in the input tensors (default: 3)
15
+ fraction (float): Fraction of pixels to keep (default: 0.08 for 8%)
16
+ device (str): Device to create the mask on (default: 'cpu')
17
+ """
18
+ self.height = height
19
+ self.width = width
20
+ self.channels = channels
21
+ self.fraction = fraction
22
+ self.device = device
23
+
24
+ # Create a binary mask for pixel selection
25
+ num_pixels = height * width
26
+ num_selected = int(num_pixels * fraction)
27
+ self.mask = torch.zeros((1, channels, height, width), device=device)
28
+
29
+ # Randomly select pixel indices
30
+ selected_indices = torch.randperm(num_pixels)[:num_selected]
31
+
32
+ # Convert indices to 2D coordinates
33
+ selected_y = selected_indices // width
34
+ selected_x = selected_indices % width
35
+
36
+ # Set selected pixels in the mask to 1
37
+ self.mask[0, :, selected_y, selected_x] = 1
38
+
39
+ def __call__(self, tensor):
40
+ """
41
+ Apply the consistent random pixel selection to the input tensor.
42
+
43
+ Args:
44
+ tensor (torch.Tensor): Input tensor of shape (b, channels, height, width)
45
+
46
+ Returns:
47
+ torch.Tensor: Tensor with the same shape as input, but with only selected pixels
48
+ """
49
+ b, c, h, w = tensor.shape
50
+ assert c == self.channels and h == self.height and w == self.width, \
51
+ f"Input tensor must be of shape (b, {self.channels}, {self.height}, {self.width})"
52
+
53
+ # Move the mask to the same device as the input tensor if necessary
54
+ if tensor.device != self.mask.device:
55
+ self.mask = self.mask.to(tensor.device)
56
+
57
+ # Apply the mask to the input tensor
58
+ return tensor * self.mask
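With the defaults, roughly 8% of pixel positions survive, and the same positions are kept in every channel and batch element; a small sketch:

masker = RandomPixelMasker(height=256, width=256, channels=3, fraction=0.08)
x = torch.randn(1, 3, 256, 256)
y = masker(x)                              # zero everywhere except the kept pixel positions
kept = (y != 0).float().mean()             # close to 0.08 (up to exact zeros already in x)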
inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import argparse
2
  import os
3
  import yaml
 
4
 
5
  from PIL import Image
6
  import numpy as np
@@ -9,7 +10,11 @@ import torch
9
  from cdim.noise import get_noise
10
  from cdim.operators import get_operator
11
  from cdim.image_utils import save_to_image
 
 
 
12
 
 
13
 
14
  def load_image(path):
15
  """
@@ -40,17 +45,43 @@ def main(args):
40
  # Load the noise function
41
  noise_config = load_yaml(args.noise_config)
42
  noise_function = get_noise(**noise_config)
43
- print(noise_function)
44
 
45
  # Load the measurement function A
46
  operator_config = load_yaml(args.operator_config)
47
  operator_config["device"] = device
48
  operator = get_operator(**operator_config)
49
- print(operator)
 
50
 
51
  noisy_measurement = noise_function(operator(original_image))
52
  save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
53
 
 
 
 
 
 
 
 
 
 
 
54
  if __name__ == '__main__':
55
  parser = argparse.ArgumentParser()
56
  parser.add_argument("input_image", type=str)
@@ -59,6 +90,7 @@ if __name__ == '__main__':
59
  parser.add_argument("model", type=str)
60
  parser.add_argument("operator_config", type=str)
61
  parser.add_argument("noise_config", type=str)
 
62
  parser.add_argument("--output-dir", default=".", type=str)
63
  parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
64
 
 
1
  import argparse
2
  import os
3
  import yaml
4
+ import time
5
 
6
  from PIL import Image
7
  import numpy as np
 
10
  from cdim.noise import get_noise
11
  from cdim.operators import get_operator
12
  from cdim.image_utils import save_to_image
13
+ from cdim.dps_model.dps_unet import create_model
14
+ from cdim.diffusion.scheduling_ddim import DDIMScheduler
15
+ from cdim.diffusion.diffusion_pipeline import run_diffusion
16
 
17
+ torch.manual_seed(8)
18
 
19
  def load_image(path):
20
  """
 
45
  # Load the noise function
46
  noise_config = load_yaml(args.noise_config)
47
  noise_function = get_noise(**noise_config)
 
48
 
49
  # Load the measurement function A
50
  operator_config = load_yaml(args.operator_config)
51
  operator_config["device"] = device
52
  operator = get_operator(**operator_config)
53
+
54
+ # Load the model
55
+ model_config = load_yaml(args.model_config)
56
+ model = create_model(**model_config)
57
+ model = model.to(device)
58
+ model.eval()
59
+
60
+ # All the models use the same scheduler;
61
+ # you can change this for a different model.
62
+ ddim_scheduler = DDIMScheduler(
63
+ num_train_timesteps=1000,
64
+ beta_start=0.0001,
65
+ beta_end=0.02,
66
+ beta_schedule="linear",
67
+ prediction_type="epsilon",
68
+ timestep_spacing="leading",
69
+ steps_offset=0,
70
+ )
71
 
72
  noisy_measurement = noise_function(operator(original_image))
73
  save_to_image(noisy_measurement, os.path.join(args.output_dir, "noisy_measurement.png"))
74
 
75
+ t0 = time.time()
76
+ output_image = run_diffusion(
77
+ model, ddim_scheduler,
78
+ noisy_measurement, operator, noise_function, device,
79
+ num_inference_steps=args.T,
80
+ K=args.K)
81
+ print(f"total time {time.time() - t0}")
82
+
83
+ save_to_image(output_image, os.path.join(args.output_dir, "output.png"))
84
+
85
  if __name__ == '__main__':
86
  parser = argparse.ArgumentParser()
87
  parser.add_argument("input_image", type=str)
 
90
  parser.add_argument("model", type=str)
91
  parser.add_argument("operator_config", type=str)
92
  parser.add_argument("noise_config", type=str)
93
+ parser.add_argument("model_config", type=str)
94
  parser.add_argument("--output-dir", default=".", type=str)
95
  parser.add_argument("--cuda", default=True, action=argparse.BooleanOptionalAction)
96
 
models/ffhq_model_config.yaml ADDED
@@ -0,0 +1,20 @@
1
+ # Defaults for image training.
2
+
3
+ image_size: 256
4
+ num_channels: 128
5
+ num_res_blocks: 1
6
+ channel_mult: ""
7
+ learn_sigma: True
8
+ class_cond: False
9
+ use_checkpoint: False
10
+ attention_resolutions: 16
11
+ num_heads: 4
12
+ num_head_channels: 64
13
+ num_heads_upsample: -1
14
+ use_scale_shift_norm: True
15
+ dropout: 0.0
16
+ resblock_updown: True
17
+ use_fp16: False
18
+ use_new_attention_order: False
19
+
20
+ model_path: models/ffhq_10m.pt
operator_configs/random_inpainting_config.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ name: random_inpainting
2
+ fraction: 0.08 # Fraction of pixels to keep
3
+ height: 256
4
+ width: 256
5
+ channels: 3