LituRout committed on
Commit 399d9fc
1 Parent(s): 425b6ce

add dps gd

diffusion-posterior-sampling/guided_diffusion/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """
+ Codebase for "Improved Denoising Diffusion Probabilistic Models".
+ """
diffusion-posterior-sampling/guided_diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (256 Bytes).

diffusion-posterior-sampling/guided_diffusion/__pycache__/condition_methods.cpython-38.pyc ADDED
Binary file (4.69 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/fp16_util.cpython-38.pyc ADDED
Binary file (7.78 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/gaussian_diffusion.cpython-38.pyc ADDED
Binary file (15.4 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/measurements.cpython-38.pyc ADDED
Binary file (11 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/nn.cpython-38.pyc ADDED
Binary file (5.9 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/posterior_mean_variance.cpython-38.pyc ADDED
Binary file (9.17 kB).

diffusion-posterior-sampling/guided_diffusion/__pycache__/unet.cpython-38.pyc ADDED
Binary file (28.2 kB).
diffusion-posterior-sampling/guided_diffusion/condition_methods.py ADDED
@@ -0,0 +1,106 @@
+ from abc import ABC, abstractmethod
+ import torch
+
+ __CONDITIONING_METHOD__ = {}
+
+ def register_conditioning_method(name: str):
+     def wrapper(cls):
+         if __CONDITIONING_METHOD__.get(name, None):
+             raise NameError(f"Name {name} is already registered!")
+         __CONDITIONING_METHOD__[name] = cls
+         return cls
+     return wrapper
+
+ def get_conditioning_method(name: str, operator, noiser, **kwargs):
+     if __CONDITIONING_METHOD__.get(name, None) is None:
+         raise NameError(f"Name {name} is not defined!")
+     return __CONDITIONING_METHOD__[name](operator=operator, noiser=noiser, **kwargs)
+
+
+ class ConditioningMethod(ABC):
+     def __init__(self, operator, noiser, **kwargs):
+         self.operator = operator
+         self.noiser = noiser
+
+     def project(self, data, noisy_measurement, **kwargs):
+         return self.operator.project(data=data, measurement=noisy_measurement, **kwargs)
+
+     def grad_and_value(self, x_prev, x_0_hat, measurement, **kwargs):
+         if self.noiser.__name__ == 'gaussian':
+             difference = measurement - self.operator.forward(x_0_hat, **kwargs)
+             norm = torch.linalg.norm(difference)
+             norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+
+         elif self.noiser.__name__ == 'poisson':
+             Ax = self.operator.forward(x_0_hat, **kwargs)
+             difference = measurement - Ax
+             norm = torch.linalg.norm(difference) / measurement.abs()
+             norm = norm.mean()
+             norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+
+         else:
+             raise NotImplementedError
+
+         return norm_grad, norm
+
+     @abstractmethod
+     def conditioning(self, x_t, measurement, noisy_measurement=None, **kwargs):
+         pass
+
+ @register_conditioning_method(name='vanilla')
+ class Identity(ConditioningMethod):
+     # just pass the input without conditioning
+     def conditioning(self, x_t):
+         return x_t
+
+ @register_conditioning_method(name='projection')
+ class Projection(ConditioningMethod):
+     def conditioning(self, x_t, noisy_measurement, **kwargs):
+         x_t = self.project(data=x_t, noisy_measurement=noisy_measurement)
+         return x_t
+
+
+ @register_conditioning_method(name='mcg')
+ class ManifoldConstraintGradient(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, noisy_measurement, **kwargs):
+         # posterior sampling
+         norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
+         x_t -= norm_grad * self.scale
+
+         # projection
+         x_t = self.project(data=x_t, noisy_measurement=noisy_measurement, **kwargs)
+         return x_t, norm
+
+ @register_conditioning_method(name='ps')
+ class PosteriorSampling(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
+         norm_grad, norm = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
+         x_t -= norm_grad * self.scale
+         return x_t, norm
+
+ @register_conditioning_method(name='ps+')
+ class PosteriorSamplingPlus(ConditioningMethod):
+     def __init__(self, operator, noiser, **kwargs):
+         super().__init__(operator, noiser)
+         self.num_sampling = kwargs.get('num_sampling', 5)
+         self.scale = kwargs.get('scale', 1.0)
+
+     def conditioning(self, x_prev, x_t, x_0_hat, measurement, **kwargs):
+         norm = 0
+         for _ in range(self.num_sampling):
+             # TODO: use noiser?
+             x_0_hat_noise = x_0_hat + 0.05 * torch.rand_like(x_0_hat)
+             difference = measurement - self.operator.forward(x_0_hat_noise)
+             norm += torch.linalg.norm(difference) / self.num_sampling
+
+         norm_grad = torch.autograd.grad(outputs=norm, inputs=x_prev)[0]
+         x_t -= norm_grad * self.scale
+         return x_t, norm
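The 'ps' method above is the DPS (diffusion posterior sampling) guidance step: it differentiates the residual norm ||y - A(x_0_hat)|| with respect to x_prev and moves x_t by `scale` times that gradient. Below is a minimal, hypothetical wiring sketch (not part of this commit); it assumes the operator and noise registries from measurements.py in this same commit, and placeholder tensors x_prev, x_t, x_0_hat and y produced by a sampler.

# Hypothetical usage sketch; x_prev, x_t, x_0_hat and y are placeholders.
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_operator, get_noise

operator = get_operator(name='gaussian_blur', kernel_size=61, intensity=3.0, device='cuda')  # A(.)
noiser = get_noise(name='gaussian', sigma=0.05)                                              # n
cond_method = get_conditioning_method('ps', operator, noiser, scale=0.3)

# x_prev must carry requires_grad so grad_and_value can backpropagate
# the residual norm through the denoiser output x_0_hat.
# x_t, distance = cond_method.conditioning(x_prev=x_prev, x_t=x_t,
#                                          x_0_hat=x_0_hat, measurement=y)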
diffusion-posterior-sampling/guided_diffusion/fp16_util.py ADDED
@@ -0,0 +1,234 @@
+ """
+ Helpers to train with 16-bit precision.
+ """
+
+ import numpy as np
+ import torch as th
+ import torch.nn as nn
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+ INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+ def convert_module_to_f16(l):
+     """
+     Convert primitive modules to float16.
+     """
+     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         l.weight.data = l.weight.data.half()
+         if l.bias is not None:
+             l.bias.data = l.bias.data.half()
+
+
+ def convert_module_to_f32(l):
+     """
+     Convert primitive modules to float32, undoing convert_module_to_f16().
+     """
+     if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+         l.weight.data = l.weight.data.float()
+         if l.bias is not None:
+             l.bias.data = l.bias.data.float()
+
+
+ def make_master_params(param_groups_and_shapes):
+     """
+     Copy model parameters into a (differently-shaped) list of full-precision
+     parameters.
+     """
+     master_params = []
+     for param_group, shape in param_groups_and_shapes:
+         master_param = nn.Parameter(
+             _flatten_dense_tensors(
+                 [param.detach().float() for (_, param) in param_group]
+             ).view(shape)
+         )
+         master_param.requires_grad = True
+         master_params.append(master_param)
+     return master_params
+
+
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
+     """
+     Copy the gradients from the model parameters into the master parameters
+     from make_master_params().
+     """
+     for master_param, (param_group, shape) in zip(
+         master_params, param_groups_and_shapes
+     ):
+         master_param.grad = _flatten_dense_tensors(
+             [param_grad_or_zeros(param) for (_, param) in param_group]
+         ).view(shape)
+
+
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
+     """
+     Copy the master parameter data back into the model parameters.
+     """
+     # Without copying to a list, if a generator is passed, this will
+     # silently not copy any parameters.
+     for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
+         for (_, param), unflat_master_param in zip(
+             param_group, unflatten_master_params(param_group, master_param.view(-1))
+         ):
+             param.detach().copy_(unflat_master_param)
+
+
+ def unflatten_master_params(param_group, master_param):
+     return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
+
+
+ def get_param_groups_and_shapes(named_model_params):
+     named_model_params = list(named_model_params)
+     scalar_vector_named_params = (
+         [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
+         (-1),
+     )
+     matrix_named_params = (
+         [(n, p) for (n, p) in named_model_params if p.ndim > 1],
+         (1, -1),
+     )
+     return [scalar_vector_named_params, matrix_named_params]
+
+
+ def master_params_to_state_dict(
+     model, param_groups_and_shapes, master_params, use_fp16
+ ):
+     if use_fp16:
+         state_dict = model.state_dict()
+         for master_param, (param_group, _) in zip(
+             master_params, param_groups_and_shapes
+         ):
+             for (name, _), unflat_master_param in zip(
+                 param_group, unflatten_master_params(param_group, master_param.view(-1))
+             ):
+                 assert name in state_dict
+                 state_dict[name] = unflat_master_param
+     else:
+         state_dict = model.state_dict()
+         for i, (name, _value) in enumerate(model.named_parameters()):
+             assert name in state_dict
+             state_dict[name] = master_params[i]
+     return state_dict
+
+
+ def state_dict_to_master_params(model, state_dict, use_fp16):
+     if use_fp16:
+         named_model_params = [
+             (name, state_dict[name]) for name, _ in model.named_parameters()
+         ]
+         param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
+         master_params = make_master_params(param_groups_and_shapes)
+     else:
+         master_params = [state_dict[name] for name, _ in model.named_parameters()]
+     return master_params
+
+
+ def zero_master_grads(master_params):
+     for param in master_params:
+         param.grad = None
+
+
+ def zero_grad(model_params):
+     for param in model_params:
+         # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+         if param.grad is not None:
+             param.grad.detach_()
+             param.grad.zero_()
+
+
+ def param_grad_or_zeros(param):
+     if param.grad is not None:
+         return param.grad.data.detach()
+     else:
+         return th.zeros_like(param)
+
+
+ class MixedPrecisionTrainer:
+     def __init__(
+         self,
+         *,
+         model,
+         use_fp16=False,
+         fp16_scale_growth=1e-3,
+         initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
+     ):
+         self.model = model
+         self.use_fp16 = use_fp16
+         self.fp16_scale_growth = fp16_scale_growth
+
+         self.model_params = list(self.model.parameters())
+         self.master_params = self.model_params
+         self.param_groups_and_shapes = None
+         self.lg_loss_scale = initial_lg_loss_scale
+
+         if self.use_fp16:
+             self.param_groups_and_shapes = get_param_groups_and_shapes(
+                 self.model.named_parameters()
+             )
+             self.master_params = make_master_params(self.param_groups_and_shapes)
+             self.model.convert_to_fp16()
+
+     def zero_grad(self):
+         zero_grad(self.model_params)
+
+     def backward(self, loss: th.Tensor):
+         if self.use_fp16:
+             loss_scale = 2 ** self.lg_loss_scale
+             (loss * loss_scale).backward()
+         else:
+             loss.backward()
+
+     def optimize(self, opt: th.optim.Optimizer):
+         if self.use_fp16:
+             return self._optimize_fp16(opt)
+         else:
+             return self._optimize_normal(opt)
+
+     def _optimize_fp16(self, opt: th.optim.Optimizer):
+         logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
+         model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
+         grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
+         if check_overflow(grad_norm):
+             self.lg_loss_scale -= 1
+             logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+             zero_master_grads(self.master_params)
+             return False
+
+         logger.logkv_mean("grad_norm", grad_norm)
+         logger.logkv_mean("param_norm", param_norm)
+
+         self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+         opt.step()
+         zero_master_grads(self.master_params)
+         master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
+         self.lg_loss_scale += self.fp16_scale_growth
+         return True
+
+     def _optimize_normal(self, opt: th.optim.Optimizer):
+         grad_norm, param_norm = self._compute_norms()
+         logger.logkv_mean("grad_norm", grad_norm)
+         logger.logkv_mean("param_norm", param_norm)
+         opt.step()
+         return True
+
+     def _compute_norms(self, grad_scale=1.0):
+         grad_norm = 0.0
+         param_norm = 0.0
+         for p in self.master_params:
+             with th.no_grad():
+                 param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
+                 if p.grad is not None:
+                     grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
+         return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
+
+     def master_params_to_state_dict(self, master_params):
+         return master_params_to_state_dict(
+             self.model, self.param_groups_and_shapes, master_params, self.use_fp16
+         )
+
+     def state_dict_to_master_params(self, state_dict):
+         return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
+
+
+ def check_overflow(value):
+     return (value == float("inf")) or (value == -float("inf")) or (value != value)
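Note that `_optimize_fp16` and `_optimize_normal` above reference a `logger` object this file never imports (in OpenAI's guided-diffusion it comes from `from . import logger`), so `optimize()` will raise a NameError unless that module is provided. With that caveat, a minimal illustrative sketch of the intended loss-scaling flow (toy model and data, not part of this commit) looks like this:

# Illustrative sketch only; the Conv2d model and random data are placeholders.
import torch as th
from guided_diffusion.fp16_util import MixedPrecisionTrainer

model = th.nn.Conv2d(3, 3, 3)          # use_fp16=True additionally requires model.convert_to_fp16()
trainer = MixedPrecisionTrainer(model=model, use_fp16=False)
opt = th.optim.Adam(trainer.master_params, lr=1e-4)

loss = (model(th.randn(4, 3, 64, 64)) ** 2).mean()
trainer.zero_grad()
trainer.backward(loss)   # with use_fp16=True the loss is multiplied by 2 ** lg_loss_scale
trainer.optimize(opt)    # unscales grads, checks for overflow, adjusts lg_loss_scale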
diffusion-posterior-sampling/guided_diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,495 @@
+ import math
+ import os
+ from functools import partial
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from tqdm.auto import tqdm
+
+ from util.img_utils import clear_color
+ from .posterior_mean_variance import get_mean_processor, get_var_processor
+
+
+ __SAMPLER__ = {}
+
+ def register_sampler(name: str):
+     def wrapper(cls):
+         if __SAMPLER__.get(name, None):
+             raise NameError(f"Name {name} is already registered!")
+         __SAMPLER__[name] = cls
+         return cls
+     return wrapper
+
+
+ def get_sampler(name: str):
+     if __SAMPLER__.get(name, None) is None:
+         raise NameError(f"Name {name} is not defined!")
+     return __SAMPLER__[name]
+
+
+ def create_sampler(sampler,
+                    steps,
+                    noise_schedule,
+                    model_mean_type,
+                    model_var_type,
+                    dynamic_threshold,
+                    clip_denoised,
+                    rescale_timesteps,
+                    timestep_respacing=""):
+
+     sampler = get_sampler(name=sampler)
+
+     betas = get_named_beta_schedule(noise_schedule, steps)
+     if not timestep_respacing:
+         timestep_respacing = [steps]
+
+     return sampler(use_timesteps=space_timesteps(steps, timestep_respacing),
+                    betas=betas,
+                    model_mean_type=model_mean_type,
+                    model_var_type=model_var_type,
+                    dynamic_threshold=dynamic_threshold,
+                    clip_denoised=clip_denoised,
+                    rescale_timesteps=rescale_timesteps)
+
+
+ class GaussianDiffusion:
+     def __init__(self,
+                  betas,
+                  model_mean_type,
+                  model_var_type,
+                  dynamic_threshold,
+                  clip_denoised,
+                  rescale_timesteps
+                  ):
+
+         # use float64 for accuracy.
+         betas = np.array(betas, dtype=np.float64)
+         self.betas = betas
+         assert self.betas.ndim == 1, "betas must be 1-D"
+         assert (0 < self.betas).all() and (self.betas <= 1).all(), "betas must be in (0..1]"
+
+         self.num_timesteps = int(self.betas.shape[0])
+         self.rescale_timesteps = rescale_timesteps
+
+         alphas = 1.0 - self.betas
+         self.alphas_cumprod = np.cumprod(alphas, axis=0)
+         self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+         self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+         assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+         self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+         self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+         self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+         self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         self.posterior_variance = (
+             betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         # log calculation clipped because the posterior variance is 0 at the
+         # beginning of the diffusion chain.
+         self.posterior_log_variance_clipped = np.log(
+             np.append(self.posterior_variance[1], self.posterior_variance[1:])
+         )
+         self.posterior_mean_coef1 = (
+             betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+         )
+         self.posterior_mean_coef2 = (
+             (1.0 - self.alphas_cumprod_prev)
+             * np.sqrt(alphas)
+             / (1.0 - self.alphas_cumprod)
+         )
+
+         self.mean_processor = get_mean_processor(model_mean_type,
+                                                  betas=betas,
+                                                  dynamic_threshold=dynamic_threshold,
+                                                  clip_denoised=clip_denoised)
+
+         self.var_processor = get_var_processor(model_var_type,
+                                                betas=betas)
+
+     def q_mean_variance(self, x_start, t):
+         """
+         Get the distribution q(x_t | x_0).
+
+         :param x_start: the [N x C x ...] tensor of noiseless inputs.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+         """
+
+         mean = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start) * x_start
+         variance = extract_and_expand(1.0 - self.alphas_cumprod, t, x_start)
+         log_variance = extract_and_expand(self.log_one_minus_alphas_cumprod, t, x_start)
+
+         return mean, variance, log_variance
+
+     def q_sample(self, x_start, t):
+         """
+         Diffuse the data for a given number of diffusion steps.
+
+         In other words, sample from q(x_t | x_0).
+
+         :param x_start: the initial data batch.
+         :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+         :param noise: if specified, the split-out normal noise.
+         :return: A noisy version of x_start.
+         """
+         noise = torch.randn_like(x_start)
+         assert noise.shape == x_start.shape
+
+         coef1 = extract_and_expand(self.sqrt_alphas_cumprod, t, x_start)
+         coef2 = extract_and_expand(self.sqrt_one_minus_alphas_cumprod, t, x_start)
+
+         return coef1 * x_start + coef2 * noise
+
+     def q_posterior_mean_variance(self, x_start, x_t, t):
+         """
+         Compute the mean and variance of the diffusion posterior:
+
+             q(x_{t-1} | x_t, x_0)
+
+         """
+         assert x_start.shape == x_t.shape
+         coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
+         coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
+         posterior_mean = coef1 * x_start + coef2 * x_t
+         posterior_variance = extract_and_expand(self.posterior_variance, t, x_t)
+         posterior_log_variance_clipped = extract_and_expand(self.posterior_log_variance_clipped, t, x_t)
+
+         assert (
+             posterior_mean.shape[0]
+             == posterior_variance.shape[0]
+             == posterior_log_variance_clipped.shape[0]
+             == x_start.shape[0]
+         )
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def p_sample_loop(self,
+                       model,
+                       x_start,
+                       measurement,
+                       measurement_cond_fn,
+                       record,
+                       save_root):
+         """
+         The function used for sampling from noise.
+         """
+         img = x_start
+         device = x_start.device
+
+         pbar = tqdm(list(range(self.num_timesteps))[::-1])
+         for idx in pbar:
+             time = torch.tensor([idx] * img.shape[0], device=device)
+
+             img = img.requires_grad_()
+             out = self.p_sample(x=img, t=time, model=model)
+
+             # Give condition.
+             noisy_measurement = self.q_sample(measurement, t=time)
+
+             # TODO: how can we handle argument for different condition method?
+             img, distance = measurement_cond_fn(x_t=out['sample'],
+                                                 measurement=measurement,
+                                                 noisy_measurement=noisy_measurement,
+                                                 x_prev=img,
+                                                 x_0_hat=out['pred_xstart'])
+             img = img.detach_()
+
+             pbar.set_postfix({'distance': distance.item()}, refresh=False)
+             if record:
+                 if idx % 10 == 0:
+                     file_path = os.path.join(save_root, f"progress/x_{str(idx).zfill(4)}.png")
+                     plt.imsave(file_path, clear_color(img))
+
+         return img
+
+     def p_sample(self, model, x, t):
+         raise NotImplementedError
+
+     def p_mean_variance(self, model, x, t):
+         model_output = model(x, self._scale_timesteps(t))
+
+         # In the case of "learned" variance, model will give twice channels.
+         if model_output.shape[1] == 2 * x.shape[1]:
+             model_output, model_var_values = torch.split(model_output, x.shape[1], dim=1)
+         else:
+             # The name of variable is wrong.
+             # This will just provide shape information, and
+             # will not be used for calculating something important in variance.
+             model_var_values = model_output
+
+         model_mean, pred_xstart = self.mean_processor.get_mean_and_xstart(x, t, model_output)
+         model_variance, model_log_variance = self.var_processor.get_variance(model_var_values, t)
+
+         assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+
+         return {'mean': model_mean,
+                 'variance': model_variance,
+                 'log_variance': model_log_variance,
+                 'pred_xstart': pred_xstart}
+
+     def _scale_timesteps(self, t):
+         if self.rescale_timesteps:
+             return t.float() * (1000.0 / self.num_timesteps)
+         return t
+
+
+ def space_timesteps(num_timesteps, section_counts):
+     """
+     Create a list of timesteps to use from an original diffusion process,
+     given the number of timesteps we want to take from equally-sized portions
+     of the original process.
+     For example, if there's 300 timesteps and the section counts are [10,15,20]
+     then the first 100 timesteps are strided to be 10 timesteps, the second 100
+     are strided to be 15 timesteps, and the final 100 are strided to be 20.
+     If the stride is a string starting with "ddim", then the fixed striding
+     from the DDIM paper is used, and only one section is allowed.
+     :param num_timesteps: the number of diffusion steps in the original
+                           process to divide up.
+     :param section_counts: either a list of numbers, or a string containing
+                            comma-separated numbers, indicating the step count
+                            per section. As a special case, use "ddimN" where N
+                            is a number of steps to use the striding from the
+                            DDIM paper.
+     :return: a set of diffusion steps from the original process to use.
+     """
+     if isinstance(section_counts, str):
+         if section_counts.startswith("ddim"):
+             desired_count = int(section_counts[len("ddim"):])
+             for i in range(1, num_timesteps):
+                 if len(range(0, num_timesteps, i)) == desired_count:
+                     return set(range(0, num_timesteps, i))
+             raise ValueError(
+                 f"cannot create exactly {num_timesteps} steps with an integer stride"
+             )
+         section_counts = [int(x) for x in section_counts.split(",")]
+     elif isinstance(section_counts, int):
+         section_counts = [section_counts]
+
+     size_per = num_timesteps // len(section_counts)
+     extra = num_timesteps % len(section_counts)
+     start_idx = 0
+     all_steps = []
+     for i, section_count in enumerate(section_counts):
+         size = size_per + (1 if i < extra else 0)
+         if size < section_count:
+             raise ValueError(
+                 f"cannot divide section of {size} steps into {section_count}"
+             )
+         if section_count <= 1:
+             frac_stride = 1
+         else:
+             frac_stride = (size - 1) / (section_count - 1)
+         cur_idx = 0.0
+         taken_steps = []
+         for _ in range(section_count):
+             taken_steps.append(start_idx + round(cur_idx))
+             cur_idx += frac_stride
+         all_steps += taken_steps
+         start_idx += size
+     return set(all_steps)
+
+
+ class SpacedDiffusion(GaussianDiffusion):
+     """
+     A diffusion process which can skip steps in a base diffusion process.
+     :param use_timesteps: a collection (sequence or set) of timesteps from the
+                           original diffusion process to retain.
+     :param kwargs: the kwargs to create the base diffusion process.
+     """
+
+     def __init__(self, use_timesteps, **kwargs):
+         self.use_timesteps = set(use_timesteps)
+         self.timestep_map = []
+         self.original_num_steps = len(kwargs["betas"])
+
+         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
+         last_alpha_cumprod = 1.0
+         new_betas = []
+         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+             if i in self.use_timesteps:
+                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+                 last_alpha_cumprod = alpha_cumprod
+                 self.timestep_map.append(i)
+         kwargs["betas"] = np.array(new_betas)
+         super().__init__(**kwargs)
+
+     def p_mean_variance(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+     def training_losses(
+         self, model, *args, **kwargs
+     ):  # pylint: disable=signature-differs
+         return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+     def condition_mean(self, cond_fn, *args, **kwargs):
+         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def condition_score(self, cond_fn, *args, **kwargs):
+         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+     def _wrap_model(self, model):
+         if isinstance(model, _WrappedModel):
+             return model
+         return _WrappedModel(
+             model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+         )
+
+     def _scale_timesteps(self, t):
+         # Scaling is done by the wrapped model.
+         return t
+
+
+ class _WrappedModel:
+     def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+         self.model = model
+         self.timestep_map = timestep_map
+         self.rescale_timesteps = rescale_timesteps
+         self.original_num_steps = original_num_steps
+
+     def __call__(self, x, ts, **kwargs):
+         map_tensor = torch.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+         new_ts = map_tensor[ts]
+         if self.rescale_timesteps:
+             new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+         return self.model(x, new_ts, **kwargs)
+
+
+ @register_sampler(name='ddpm')
+ class DDPM(SpacedDiffusion):
+     def p_sample(self, model, x, t):
+         out = self.p_mean_variance(model, x, t)
+         sample = out['mean']
+
+         noise = torch.randn_like(x)
+         if t != 0:  # no noise when t == 0
+             sample += torch.exp(0.5 * out['log_variance']) * noise
+
+         return {'sample': sample, 'pred_xstart': out['pred_xstart']}
+
+
+ @register_sampler(name='ddim')
+ class DDIM(SpacedDiffusion):
+     def p_sample(self, model, x, t, eta=0.0):
+         out = self.p_mean_variance(model, x, t)
+
+         eps = self.predict_eps_from_x_start(x, t, out['pred_xstart'])
+
+         alpha_bar = extract_and_expand(self.alphas_cumprod, t, x)
+         alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, t, x)
+         sigma = (
+             eta
+             * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+             * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
+         )
+         # Equation 12.
+         noise = torch.randn_like(x)
+         mean_pred = (
+             out["pred_xstart"] * torch.sqrt(alpha_bar_prev)
+             + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+         )
+
+         sample = mean_pred
+         if t != 0:
+             sample += sigma * noise
+
+         return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+     def predict_eps_from_x_start(self, x_t, t, pred_xstart):
+         coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
+         coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, x_t)
+         return (coef1 * x_t - pred_xstart) / coef2
+
+
+ # =================
+ # Helper functions
+ # =================
+
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+     """
+     Get a pre-defined beta schedule for the given name.
+
+     The beta schedule library consists of beta schedules which remain similar
+     in the limit of num_diffusion_timesteps.
+     Beta schedules may be added, but should not be removed or changed once
+     they are committed to maintain backwards compatibility.
+     """
+     if schedule_name == "linear":
+         # Linear schedule from Ho et al, extended to work for any number of
+         # diffusion steps.
+         scale = 1000 / num_diffusion_timesteps
+         beta_start = scale * 0.0001
+         beta_end = scale * 0.02
+         return np.linspace(
+             beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+         )
+     elif schedule_name == "cosine":
+         return betas_for_alpha_bar(
+             num_diffusion_timesteps,
+             lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+         )
+     else:
+         raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+
+ # ================
+ # Helper function
+ # ================
+
+ def extract_and_expand(array, time, target):
+     array = torch.from_numpy(array).to(target.device)[time].float()
+     while array.ndim < target.ndim:
+         array = array.unsqueeze(-1)
+     return array.expand_as(target)
+
+
+ def expand_as(array, target):
+     if isinstance(array, np.ndarray):
+         array = torch.from_numpy(array)
+     elif isinstance(array, np.float):
+         array = torch.tensor([array])
+
+     while array.ndim < target.ndim:
+         array = array.unsqueeze(-1)
+
+     return array.expand_as(target).to(target.device)
+
+
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
+     """
+     Extract values from a 1-D numpy array for a batch of indices.
+
+     :param arr: the 1-D numpy array.
+     :param timesteps: a tensor of indices into the array to extract.
+     :param broadcast_shape: a larger shape of K dimensions with the batch
+                             dimension equal to the length of timesteps.
+     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+     """
+     res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+     while len(res.shape) < len(broadcast_shape):
+         res = res[..., None]
+     return res.expand(broadcast_shape)
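A hypothetical end-to-end sketch of how this sampler is meant to be driven (placeholder names: `model` is a pretrained epsilon-prediction UNet, `y` a degraded measurement, and `cond_method` a conditioning method from condition_methods.py); it mirrors p_sample_loop's signature above rather than any script in this commit:

import torch
from functools import partial
from guided_diffusion.gaussian_diffusion import create_sampler

sampler = create_sampler(sampler='ddpm', steps=1000, noise_schedule='linear',
                         model_mean_type='epsilon', model_var_type='learned_range',
                         dynamic_threshold=False, clip_denoised=True,
                         rescale_timesteps=False, timestep_respacing="")

x_T = torch.randn(1, 3, 256, 256, device='cuda')    # start from pure noise
sample_fn = partial(sampler.p_sample_loop, model=model,
                    measurement_cond_fn=cond_method.conditioning,
                    record=False, save_root='./results')
x_0 = sample_fn(x_start=x_T, measurement=y)          # reconstructed image in [-1, 1]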
diffusion-posterior-sampling/guided_diffusion/measurements.py ADDED
@@ -0,0 +1,290 @@
+ '''This module handles task-dependent operations (A) and noises (n) to simulate a measurement y=Ax+n.'''
+
+ from abc import ABC, abstractmethod
+ from functools import partial
+ import yaml
+ from torch.nn import functional as F
+ from torchvision import torch
+ from motionblur.motionblur import Kernel
+
+ from util.resizer import Resizer
+ from util.img_utils import Blurkernel, fft2_m
+
+
+ # =================
+ # Operation classes
+ # =================
+
+ __OPERATOR__ = {}
+
+ def register_operator(name: str):
+     def wrapper(cls):
+         if __OPERATOR__.get(name, None):
+             raise NameError(f"Name {name} is already registered!")
+         __OPERATOR__[name] = cls
+         return cls
+     return wrapper
+
+
+ def get_operator(name: str, **kwargs):
+     if __OPERATOR__.get(name, None) is None:
+         raise NameError(f"Name {name} is not defined.")
+     return __OPERATOR__[name](**kwargs)
+
+
+ class LinearOperator(ABC):
+     @abstractmethod
+     def forward(self, data, **kwargs):
+         # calculate A * X
+         pass
+
+     @abstractmethod
+     def transpose(self, data, **kwargs):
+         # calculate A^T * X
+         pass
+
+     def ortho_project(self, data, **kwargs):
+         # calculate (I - A^T * A)X
+         return data - self.transpose(self.forward(data, **kwargs), **kwargs)
+
+     def project(self, data, measurement, **kwargs):
+         # calculate (I - A^T * A)Y - AX
+         return self.ortho_project(measurement, **kwargs) - self.forward(data, **kwargs)
+
+
+ @register_operator(name='noise')
+ class DenoiseOperator(LinearOperator):
+     def __init__(self, device):
+         self.device = device
+
+     def forward(self, data):
+         return data
+
+     def transpose(self, data):
+         return data
+
+     def ortho_project(self, data):
+         return data
+
+     def project(self, data):
+         return data
+
+
+ @register_operator(name='super_resolution')
+ class SuperResolutionOperator(LinearOperator):
+     def __init__(self, in_shape, scale_factor, device):
+         self.device = device
+         self.up_sample = partial(F.interpolate, scale_factor=scale_factor)
+         self.down_sample = Resizer(in_shape, 1/scale_factor).to(device)
+
+     def forward(self, data, **kwargs):
+         return self.down_sample(data)
+
+     def transpose(self, data, **kwargs):
+         return self.up_sample(data)
+
+     def project(self, data, measurement, **kwargs):
+         return data - self.transpose(self.forward(data)) + self.transpose(measurement)
+
+ @register_operator(name='motion_blur')
+ class MotionBlurOperator(LinearOperator):
+     def __init__(self, kernel_size, intensity, device):
+         self.device = device
+         self.kernel_size = kernel_size
+         self.conv = Blurkernel(blur_type='motion',
+                                kernel_size=kernel_size,
+                                std=intensity,
+                                device=device).to(device)  # should we keep this device term?
+
+         self.kernel = Kernel(size=(kernel_size, kernel_size), intensity=intensity)
+         kernel = torch.tensor(self.kernel.kernelMatrix, dtype=torch.float32)
+         self.conv.update_weights(kernel)
+
+     def forward(self, data, **kwargs):
+         # A^T * A
+         return self.conv(data)
+
+     def transpose(self, data, **kwargs):
+         return data
+
+     def get_kernel(self):
+         kernel = self.kernel.kernelMatrix.type(torch.float32).to(self.device)
+         return kernel.view(1, 1, self.kernel_size, self.kernel_size)
+
+
+ @register_operator(name='gaussian_blur')
+ class GaussialBlurOperator(LinearOperator):
+     def __init__(self, kernel_size, intensity, device):
+         self.device = device
+         self.kernel_size = kernel_size
+         self.conv = Blurkernel(blur_type='gaussian',
+                                kernel_size=kernel_size,
+                                std=intensity,
+                                device=device).to(device)
+         self.kernel = self.conv.get_kernel()
+         self.conv.update_weights(self.kernel.type(torch.float32))
+
+     def forward(self, data, **kwargs):
+         return self.conv(data)
+
+     def transpose(self, data, **kwargs):
+         return data
+
+     def get_kernel(self):
+         return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
+
+ @register_operator(name='inpainting')
+ class InpaintingOperator(LinearOperator):
+     '''This operator gets a pre-defined mask and returns the masked image.'''
+     def __init__(self, device):
+         self.device = device
+
+     def forward(self, data, **kwargs):
+         try:
+             return data * kwargs.get('mask', None).to(self.device)
+         except:
+             raise ValueError("Require mask")
+
+     def transpose(self, data, **kwargs):
+         return data
+
+     def ortho_project(self, data, **kwargs):
+         return data - self.forward(data, **kwargs)
+
+
+ class NonLinearOperator(ABC):
+     @abstractmethod
+     def forward(self, data, **kwargs):
+         pass
+
+     def project(self, data, measurement, **kwargs):
+         return data + measurement - self.forward(data)
+
+ @register_operator(name='phase_retrieval')
+ class PhaseRetrievalOperator(NonLinearOperator):
+     def __init__(self, oversample, device):
+         self.pad = int((oversample / 8.0) * 256)
+         self.device = device
+
+     def forward(self, data, **kwargs):
+         padded = F.pad(data, (self.pad, self.pad, self.pad, self.pad))
+         amplitude = fft2_m(padded).abs()
+         return amplitude
+
+ @register_operator(name='nonlinear_blur')
+ class NonlinearBlurOperator(NonLinearOperator):
+     def __init__(self, opt_yml_path, device):
+         self.device = device
+         self.blur_model = self.prepare_nonlinear_blur_model(opt_yml_path)
+
+     def prepare_nonlinear_blur_model(self, opt_yml_path):
+         '''
+         Nonlinear deblur requires external codes (bkse).
+         '''
+         from bkse.models.kernel_encoding.kernel_wizard import KernelWizard
+
+         with open(opt_yml_path, "r") as f:
+             opt = yaml.safe_load(f)["KernelWizard"]
+             model_path = opt["pretrained"]
+         blur_model = KernelWizard(opt)
+         blur_model.eval()
+         blur_model.load_state_dict(torch.load(model_path))
+         blur_model = blur_model.to(self.device)
+         return blur_model
+
+     def forward(self, data, **kwargs):
+         random_kernel = torch.randn(1, 512, 2, 2).to(self.device) * 1.2
+         data = (data + 1.0) / 2.0  # [-1, 1] -> [0, 1]
+         blurred = self.blur_model.adaptKernel(data, kernel=random_kernel)
+         blurred = (blurred * 2.0 - 1.0).clamp(-1, 1)  # [0, 1] -> [-1, 1]
+         return blurred
+
+ # =============
+ # Noise classes
+ # =============
+
+
+ __NOISE__ = {}
+
+ def register_noise(name: str):
+     def wrapper(cls):
+         if __NOISE__.get(name, None):
+             raise NameError(f"Name {name} is already defined!")
+         __NOISE__[name] = cls
+         return cls
+     return wrapper
+
+ def get_noise(name: str, **kwargs):
+     if __NOISE__.get(name, None) is None:
+         raise NameError(f"Name {name} is not defined.")
+     noiser = __NOISE__[name](**kwargs)
+     noiser.__name__ = name
+     return noiser
+
+ class Noise(ABC):
+     def __call__(self, data):
+         return self.forward(data)
+
+     @abstractmethod
+     def forward(self, data):
+         pass
+
+ @register_noise(name='clean')
+ class Clean(Noise):
+     def forward(self, data):
+         return data
+
+ @register_noise(name='gaussian')
+ class GaussianNoise(Noise):
+     def __init__(self, sigma):
+         self.sigma = sigma
+
+     def forward(self, data):
+         return data + torch.randn_like(data, device=data.device) * self.sigma
+
+
+ @register_noise(name='poisson')
+ class PoissonNoise(Noise):
+     def __init__(self, rate):
+         self.rate = rate
+
+     def forward(self, data):
+         '''
+         Follow skimage.util.random_noise.
+         '''
+
+         # TODO: set one version of poisson
+
+         # version 3 (stack-overflow)
+         import numpy as np
+         data = (data + 1.0) / 2.0
+         data = data.clamp(0, 1)
+         device = data.device
+         data = data.detach().cpu()
+         data = torch.from_numpy(np.random.poisson(data * 255.0 * self.rate) / 255.0 / self.rate)
+         data = data * 2.0 - 1.0
+         data = data.clamp(-1, 1)
+         return data.to(device)
+
+         # version 2 (skimage)
+         # if data.min() < 0:
+         #     low_clip = -1
+         # else:
+         #     low_clip = 0
+
+         # # Determine unique values in image & calculate the next power of two
+         # vals = torch.Tensor([len(torch.unique(data))])
+         # vals = 2 ** torch.ceil(torch.log2(vals))
+         # vals = vals.to(data.device)
+
+         # if low_clip == -1:
+         #     old_max = data.max()
+         #     data = (data + 1.0) / (old_max + 1.0)
+
+         # data = torch.poisson(data * vals) / float(vals)
+
+         # if low_clip == -1:
+         #     data = data * (old_max + 1.0) - 1.0
+
+         # return data.clamp(low_clip, 1.0)
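Putting the module docstring into code, here is a small illustrative sketch of y = A(x) + n built from the registries above (x is a placeholder image batch in [-1, 1]; the 4x super-resolution settings are stand-ins, not values fixed by this commit):

import torch
from guided_diffusion.measurements import get_operator, get_noise

x = torch.rand(1, 3, 256, 256) * 2 - 1                      # placeholder clean image
A = get_operator(name='super_resolution', in_shape=(1, 3, 256, 256),
                 scale_factor=4, device='cpu')
noiser = get_noise(name='gaussian', sigma=0.05)

y = noiser(A.forward(x))   # degraded measurement handed to the conditioning method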
diffusion-posterior-sampling/guided_diffusion/nn.py ADDED
@@ -0,0 +1,170 @@
+ """
+ Various utilities for neural networks.
+ """
+
+ import math
+
+ import torch as th
+ import torch.nn as nn
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * th.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def update_ema(target_params, source_params, rate=0.99):
+     """
+     Update target parameters to be closer to those of source parameters using
+     an exponential moving average.
+
+     :param target_params: the target parameter sequence.
+     :param source_params: the source parameter sequence.
+     :param rate: the EMA rate (closer to 1 means slower).
+     """
+     for targ, src in zip(target_params, source_params):
+         targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000):
+     """
+     Create sinusoidal timestep embeddings.
+
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     half = dim // 2
+     freqs = th.exp(
+         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+     ).to(device=timesteps.device)
+     args = timesteps[:, None].float() * freqs[None]
+     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+     return embedding
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(th.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+         with th.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with th.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = th.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
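For orientation, a quick illustrative shape check of the sinusoidal embedding helper that the UNet consumes (not part of this commit):

import torch as th
from guided_diffusion.nn import timestep_embedding

t = th.arange(8)                        # one timestep index per batch element
emb = timestep_embedding(t, dim=128)    # sinusoidal features, shape [8, 128]
assert emb.shape == (8, 128)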
diffusion-posterior-sampling/guided_diffusion/posterior_mean_variance.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from util.img_utils import dynamic_thresholding
7
+
8
+
9
+
10
+ # ====================
11
+ # Model Mean Processor
12
+ # ====================
13
+
14
+ __MODEL_MEAN_PROCESSOR__ = {}
15
+
16
+ def register_mean_processor(name: str):
17
+ def wrapper(cls):
18
+ if __MODEL_MEAN_PROCESSOR__.get(name, None):
19
+ raise NameError(f"Name {name} is already registerd.")
20
+ __MODEL_MEAN_PROCESSOR__[name] = cls
21
+ return cls
22
+ return wrapper
23
+
24
+ def get_mean_processor(name: str, **kwargs):
25
+ if __MODEL_MEAN_PROCESSOR__.get(name, None) is None:
26
+ raise NameError(f"Name {name} is not defined.")
27
+ return __MODEL_MEAN_PROCESSOR__[name](**kwargs)
28
+
29
+ class MeanProcessor(ABC):
30
+ """Predict x_start and calculate mean value"""
31
+ @abstractmethod
32
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
33
+ self.dynamic_threshold = dynamic_threshold
34
+ self.clip_denoised = clip_denoised
35
+
36
+ @abstractmethod
37
+ def get_mean_and_xstart(self, x, t, model_output):
38
+ pass
39
+
40
+ def process_xstart(self, x):
41
+ if self.dynamic_threshold:
42
+ x = dynamic_thresholding(x, s=0.95)
43
+ if self.clip_denoised:
44
+ x = x.clamp(-1, 1)
45
+ return x
46
+
47
+ @register_mean_processor(name='previous_x')
48
+ class PreviousXMeanProcessor(MeanProcessor):
49
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
50
+ super().__init__(betas, dynamic_threshold, clip_denoised)
51
+ alphas = 1.0 - betas
52
+ alphas_cumprod = np.cumprod(alphas, axis=0)
53
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
54
+
55
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
56
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
57
+
58
+ def predict_xstart(self, x_t, t, x_prev):
59
+ coef1 = extract_and_expand(1.0/self.posterior_mean_coef1, t, x_t)
60
+ coef2 = extract_and_expand(self.posterior_mean_coef2/self.posterior_mean_coef1, t, x_t)
61
+ return coef1 * x_prev - coef2 * x_t
62
+
63
+ def get_mean_and_xstart(self, x, t, model_output):
64
+ mean = model_output
65
+ pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
66
+ return mean, pred_xstart
67
+
68
+ @register_mean_processor(name='start_x')
69
+ class StartXMeanProcessor(MeanProcessor):
70
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
71
+ super().__init__(betas, dynamic_threshold, clip_denoised)
72
+ alphas = 1.0 - betas
73
+ alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
75
+
76
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
77
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
78
+
79
+ def q_posterior_mean(self, x_start, x_t, t):
80
+ """
81
+ Compute the mean of the diffusion posteriro:
82
+ q(x_{t-1} | x_t, x_0)
83
+ """
84
+ assert x_start.shape == x_t.shape
85
+ coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
86
+ coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
87
+
88
+ return coef1 * x_start + coef2 * x_t
89
+
90
+ def get_mean_and_xstart(self, x, t, model_output):
91
+ pred_xstart = self.process_xstart(model_output)
92
+ mean = self.q_posterior_mean(x_start=pred_xstart, x_t=x, t=t)
93
+
94
+ return mean, pred_xstart
95
+
96
+ @register_mean_processor(name='epsilon')
97
+ class EpsilonXMeanProcessor(MeanProcessor):
98
+ def __init__(self, betas, dynamic_threshold, clip_denoised):
99
+ super().__init__(betas, dynamic_threshold, clip_denoised)
100
+ alphas = 1.0 - betas
101
+ alphas_cumprod = np.cumprod(alphas, axis=0)
102
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
103
+
104
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
105
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
106
+ self.posterior_mean_coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0-alphas_cumprod)
107
+ self.posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
108
+
109
+
110
+ def q_posterior_mean(self, x_start, x_t, t):
111
+ """
112
+ Compute the mean of the diffusion posteriro:
113
+ q(x_{t-1} | x_t, x_0)
114
+ """
115
+ assert x_start.shape == x_t.shape
116
+ coef1 = extract_and_expand(self.posterior_mean_coef1, t, x_start)
117
+ coef2 = extract_and_expand(self.posterior_mean_coef2, t, x_t)
118
+ return coef1 * x_start + coef2 * x_t
119
+
120
+ def predict_xstart(self, x_t, t, eps):
121
+ coef1 = extract_and_expand(self.sqrt_recip_alphas_cumprod, t, x_t)
122
+ coef2 = extract_and_expand(self.sqrt_recipm1_alphas_cumprod, t, eps)
123
+ return coef1 * x_t - coef2 * eps
124
+
125
+ def get_mean_and_xstart(self, x, t, model_output):
126
+ pred_xstart = self.process_xstart(self.predict_xstart(x, t, model_output))
127
+ mean = self.q_posterior_mean(pred_xstart, x, t)
128
+
129
+ return mean, pred_xstart
130
+
131
+ # =========================
132
+ # Model Variance Processor
133
+ # =========================
134
+
135
+ __MODEL_VAR_PROCESSOR__ = {}
136
+
137
+ def register_var_processor(name: str):
138
+ def wrapper(cls):
139
+ if __MODEL_VAR_PROCESSOR__.get(name, None):
140
+ raise NameError(f"Name {name} is already registerd.")
141
+ __MODEL_VAR_PROCESSOR__[name] = cls
142
+ return cls
143
+ return wrapper
144
+
145
+ def get_var_processor(name: str, **kwargs):
146
+ if __MODEL_VAR_PROCESSOR__.get(name, None) is None:
147
+ raise NameError(f"Name {name} is not defined.")
148
+ return __MODEL_VAR_PROCESSOR__[name](**kwargs)
149
+
150
+ class VarianceProcessor(ABC):
151
+ @abstractmethod
152
+ def __init__(self, betas):
153
+ pass
154
+
155
+ @abstractmethod
156
+ def get_variance(self, x, t):
157
+ pass
158
+
159
+ @register_var_processor(name='fixed_small')
160
+ class FixedSmallVarianceProcessor(VarianceProcessor):
161
+ def __init__(self, betas):
162
+ alphas = 1.0 - betas
163
+ alphas_cumprod = np.cumprod(alphas, axis=0)
164
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
165
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
166
+ self.posterior_variance = (
167
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
168
+ )
169
+
170
+ def get_variance(self, x, t):
171
+ model_variance = self.posterior_variance
172
+ model_log_variance = np.log(model_variance)
173
+
174
+ model_variance = extract_and_expand(model_variance, t, x)
175
+ model_log_variance = extract_and_expand(model_log_variance, t, x)
176
+
177
+ return model_variance, model_log_variance
178
+
179
+ @register_var_processor(name='fixed_large')
180
+ class FixedLargeVarianceProcessor(VarianceProcessor):
181
+ def __init__(self, betas):
182
+ self.betas = betas
183
+
184
+ alphas = 1.0 - betas
185
+ alphas_cumprod = np.cumprod(alphas, axis=0)
186
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
190
+ )
191
+
192
+ def get_variance(self, x, t):
193
+ model_variance = np.append(self.posterior_variance[1], self.betas[1:])
194
+ model_log_variance = np.log(model_variance)
195
+
196
+ model_variance = extract_and_expand(model_variance, t, x)
197
+ model_log_variance = extract_and_expand(model_log_variance, t, x)
198
+
199
+ return model_variance, model_log_variance
200
+
201
+ @register_var_processor(name='learned')
202
+ class LearnedVarianceProcessor(VarianceProcessor):
203
+ def __init__(self, betas):
204
+ pass
205
+
206
+ def get_variance(self, x, t):
207
+ model_log_variance = x
208
+ model_variance = torch.exp(model_log_variance)
209
+ return model_variance, model_log_variance
210
+
211
+ @register_var_processor(name='learned_range')
212
+ class LearnedRangeVarianceProcessor(VarianceProcessor):
213
+ def __init__(self, betas):
214
+ self.betas = betas
215
+
216
+ alphas = 1.0 - betas
217
+ alphas_cumprod = np.cumprod(alphas, axis=0)
218
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
219
+
220
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
221
+ posterior_variance = (
222
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
223
+ )
224
+ # log calculation clipped because the posterior variance is 0 at the
225
+ # beginning of the diffusion chain.
226
+ self.posterior_log_variance_clipped = np.log(
227
+ np.append(posterior_variance[1], posterior_variance[1:])
228
+ )
229
+
230
+ def get_variance(self, x, t):
231
+ model_var_values = x
232
+ min_log = self.posterior_log_variance_clipped
233
+ max_log = np.log(self.betas)
234
+
235
+ min_log = extract_and_expand(min_log, t, x)
236
+ max_log = extract_and_expand(max_log, t, x)
237
+
238
+ # The model_var_values is [-1, 1] for [min_var, max_var]
239
+ frac = (model_var_values + 1.0) / 2.0
240
+ model_log_variance = frac * max_log + (1-frac) * min_log
241
+ model_variance = torch.exp(model_log_variance)
242
+ return model_variance, model_log_variance
243
+
244
+ # ================
245
+ # Helper function
246
+ # ================
247
+
248
+ def extract_and_expand(array, time, target):
249
+ array = torch.from_numpy(array).to(target.device)[time].float()
250
+ while array.ndim < target.ndim:
251
+ array = array.unsqueeze(-1)
252
+ return array.expand_as(target)
253
+
254
+
255
+ def expand_as(array, target):
256
+ if isinstance(array, np.ndarray):
257
+ array = torch.from_numpy(array)
258
+ elif isinstance(array, (float, np.floating)):  # np.float was removed in NumPy >= 1.24
259
+ array = torch.tensor([array])
260
+
261
+ while array.ndim < target.ndim:
262
+ array = array.unsqueeze(-1)
263
+
264
+ return array.expand_as(target).to(target.device)
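A minimal usage sketch of the variance-processor registry defined above, assuming get_var_processor is imported from guided_diffusion.posterior_mean_variance and that a linear beta schedule is in use; all shapes and values are illustrative:

import numpy as np
import torch

betas = np.linspace(1e-4, 0.02, 1000)                 # illustrative linear schedule
var_proc = get_var_processor('fixed_small', betas=betas)

x = torch.randn(1, 3, 256, 256)                        # dummy model output
t = torch.tensor([10])                                 # current timestep index
variance, log_variance = var_proc.get_variance(x, t)   # broadcast to x's shape via extract_and_expand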
diffusion-posterior-sampling/guided_diffusion/unet.py ADDED
@@ -0,0 +1,1117 @@
1
+ from abc import abstractmethod
2
+
3
+ import math
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import functools
10
+
11
+ from .fp16_util import convert_module_to_f16, convert_module_to_f32
12
+ from .nn import (
13
+ checkpoint,
14
+ conv_nd,
15
+ linear,
16
+ avg_pool_nd,
17
+ zero_module,
18
+ normalization,
19
+ timestep_embedding,
20
+ )
21
+
22
+
23
+ NUM_CLASSES = 1000
24
+
25
+ def create_model(
26
+ image_size,
27
+ num_channels,
28
+ num_res_blocks,
29
+ channel_mult="",
30
+ learn_sigma=False,
31
+ class_cond=False,
32
+ use_checkpoint=False,
33
+ attention_resolutions="16",
34
+ num_heads=1,
35
+ num_head_channels=-1,
36
+ num_heads_upsample=-1,
37
+ use_scale_shift_norm=False,
38
+ dropout=0,
39
+ resblock_updown=False,
40
+ use_fp16=False,
41
+ use_new_attention_order=False,
42
+ model_path='',
43
+ ):
44
+ if channel_mult == "":
45
+ if image_size == 512:
46
+ channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
47
+ elif image_size == 256:
48
+ channel_mult = (1, 1, 2, 2, 4, 4)
49
+ elif image_size == 128:
50
+ channel_mult = (1, 1, 2, 3, 4)
51
+ elif image_size == 64:
52
+ channel_mult = (1, 2, 3, 4)
53
+ else:
54
+ raise ValueError(f"unsupported image size: {image_size}")
55
+ else:
56
+ channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
57
+
58
+ attention_ds = []
59
+ if isinstance(attention_resolutions, int):
60
+ attention_ds.append(image_size // attention_resolutions)
61
+ elif isinstance(attention_resolutions, str):
62
+ for res in attention_resolutions.split(","):
63
+ attention_ds.append(image_size // int(res))
64
+ else:
65
+ raise NotImplementedError
66
+
67
+ model= UNetModel(
68
+ image_size=image_size,
69
+ in_channels=3,
70
+ model_channels=num_channels,
71
+ out_channels=(3 if not learn_sigma else 6),
72
+ num_res_blocks=num_res_blocks,
73
+ attention_resolutions=tuple(attention_ds),
74
+ dropout=dropout,
75
+ channel_mult=channel_mult,
76
+ num_classes=(NUM_CLASSES if class_cond else None),
77
+ use_checkpoint=use_checkpoint,
78
+ use_fp16=use_fp16,
79
+ num_heads=num_heads,
80
+ num_head_channels=num_head_channels,
81
+ num_heads_upsample=num_heads_upsample,
82
+ use_scale_shift_norm=use_scale_shift_norm,
83
+ resblock_updown=resblock_updown,
84
+ use_new_attention_order=use_new_attention_order,
85
+ )
86
+
87
+ try:
88
+ model.load_state_dict(th.load(model_path, map_location='cpu'))
89
+ except Exception as e:
90
+ print(f"Got exception: {e} / Randomly initialize")
91
+ return model
92
+
93
+ class AttentionPool2d(nn.Module):
94
+ """
95
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
96
+ """
97
+
98
+ def __init__(
99
+ self,
100
+ spacial_dim: int,
101
+ embed_dim: int,
102
+ num_heads_channels: int,
103
+ output_dim: int = None,
104
+ ):
105
+ super().__init__()
106
+ self.positional_embedding = nn.Parameter(
107
+ th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5
108
+ )
109
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
110
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
111
+ self.num_heads = embed_dim // num_heads_channels
112
+ self.attention = QKVAttention(self.num_heads)
113
+
114
+ def forward(self, x):
115
+ b, c, *_spatial = x.shape
116
+ x = x.reshape(b, c, -1) # NC(HW)
117
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
118
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
119
+ x = self.qkv_proj(x)
120
+ x = self.attention(x)
121
+ x = self.c_proj(x)
122
+ return x[:, :, 0]
123
+
124
+
125
+ class TimestepBlock(nn.Module):
126
+ """
127
+ Any module where forward() takes timestep embeddings as a second argument.
128
+ """
129
+
130
+ @abstractmethod
131
+ def forward(self, x, emb):
132
+ """
133
+ Apply the module to `x` given `emb` timestep embeddings.
134
+ """
135
+
136
+
137
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
138
+ """
139
+ A sequential module that passes timestep embeddings to the children that
140
+ support it as an extra input.
141
+ """
142
+
143
+ def forward(self, x, emb):
144
+ for layer in self:
145
+ if isinstance(layer, TimestepBlock):
146
+ x = layer(x, emb)
147
+ else:
148
+ x = layer(x)
149
+ return x
150
+
151
+
152
+ class Upsample(nn.Module):
153
+ """
154
+ An upsampling layer with an optional convolution.
155
+
156
+ :param channels: channels in the inputs and outputs.
157
+ :param use_conv: a bool determining if a convolution is applied.
158
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
159
+ upsampling occurs in the inner-two dimensions.
160
+ """
161
+
162
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
163
+ super().__init__()
164
+ self.channels = channels
165
+ self.out_channels = out_channels or channels
166
+ self.use_conv = use_conv
167
+ self.dims = dims
168
+ if use_conv:
169
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
170
+
171
+ def forward(self, x):
172
+ assert x.shape[1] == self.channels
173
+ if self.dims == 3:
174
+ x = F.interpolate(
175
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
176
+ )
177
+ else:
178
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
179
+ if self.use_conv:
180
+ x = self.conv(x)
181
+ return x
182
+
183
+
184
+ class Downsample(nn.Module):
185
+ """
186
+ A downsampling layer with an optional convolution.
187
+
188
+ :param channels: channels in the inputs and outputs.
189
+ :param use_conv: a bool determining if a convolution is applied.
190
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
191
+ downsampling occurs in the inner-two dimensions.
192
+ """
193
+
194
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
195
+ super().__init__()
196
+ self.channels = channels
197
+ self.out_channels = out_channels or channels
198
+ self.use_conv = use_conv
199
+ self.dims = dims
200
+ stride = 2 if dims != 3 else (1, 2, 2)
201
+ if use_conv:
202
+ self.op = conv_nd(
203
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
204
+ )
205
+ else:
206
+ assert self.channels == self.out_channels
207
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
208
+
209
+ def forward(self, x):
210
+ assert x.shape[1] == self.channels
211
+ return self.op(x)
212
+
213
+
214
+ class ResBlock(TimestepBlock):
215
+ """
216
+ A residual block that can optionally change the number of channels.
217
+
218
+ :param channels: the number of input channels.
219
+ :param emb_channels: the number of timestep embedding channels.
220
+ :param dropout: the rate of dropout.
221
+ :param out_channels: if specified, the number of out channels.
222
+ :param use_conv: if True and out_channels is specified, use a spatial
223
+ convolution instead of a smaller 1x1 convolution to change the
224
+ channels in the skip connection.
225
+ :param dims: determines if the signal is 1D, 2D, or 3D.
226
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
227
+ :param up: if True, use this block for upsampling.
228
+ :param down: if True, use this block for downsampling.
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ channels,
234
+ emb_channels,
235
+ dropout,
236
+ out_channels=None,
237
+ use_conv=False,
238
+ use_scale_shift_norm=False,
239
+ dims=2,
240
+ use_checkpoint=False,
241
+ up=False,
242
+ down=False,
243
+ ):
244
+ super().__init__()
245
+ self.channels = channels
246
+ self.emb_channels = emb_channels
247
+ self.dropout = dropout
248
+ self.out_channels = out_channels or channels
249
+ self.use_conv = use_conv
250
+ self.use_checkpoint = use_checkpoint
251
+ self.use_scale_shift_norm = use_scale_shift_norm
252
+
253
+ self.in_layers = nn.Sequential(
254
+ normalization(channels),
255
+ nn.SiLU(),
256
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
257
+ )
258
+
259
+ self.updown = up or down
260
+
261
+ if up:
262
+ self.h_upd = Upsample(channels, False, dims)
263
+ self.x_upd = Upsample(channels, False, dims)
264
+ elif down:
265
+ self.h_upd = Downsample(channels, False, dims)
266
+ self.x_upd = Downsample(channels, False, dims)
267
+ else:
268
+ self.h_upd = self.x_upd = nn.Identity()
269
+
270
+ self.emb_layers = nn.Sequential(
271
+ nn.SiLU(),
272
+ linear(
273
+ emb_channels,
274
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
275
+ ),
276
+ )
277
+ self.out_layers = nn.Sequential(
278
+ normalization(self.out_channels),
279
+ nn.SiLU(),
280
+ nn.Dropout(p=dropout),
281
+ zero_module(
282
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
283
+ ),
284
+ )
285
+
286
+ if self.out_channels == channels:
287
+ self.skip_connection = nn.Identity()
288
+ elif use_conv:
289
+ self.skip_connection = conv_nd(
290
+ dims, channels, self.out_channels, 3, padding=1
291
+ )
292
+ else:
293
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
294
+
295
+ def forward(self, x, emb):
296
+ """
297
+ Apply the block to a Tensor, conditioned on a timestep embedding.
298
+
299
+ :param x: an [N x C x ...] Tensor of features.
300
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
301
+ :return: an [N x C x ...] Tensor of outputs.
302
+ """
303
+ return checkpoint(
304
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
305
+ )
306
+
307
+ def _forward(self, x, emb):
308
+ if self.updown:
309
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
310
+ h = in_rest(x)
311
+ h = self.h_upd(h)
312
+ x = self.x_upd(x)
313
+ h = in_conv(h)
314
+ else:
315
+ h = self.in_layers(x)
316
+ emb_out = self.emb_layers(emb).type(h.dtype)
317
+ while len(emb_out.shape) < len(h.shape):
318
+ emb_out = emb_out[..., None]
319
+ if self.use_scale_shift_norm:
320
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
321
+ scale, shift = th.chunk(emb_out, 2, dim=1)
322
+ h = out_norm(h) * (1 + scale) + shift
323
+ h = out_rest(h)
324
+ else:
325
+ h = h + emb_out
326
+ h = self.out_layers(h)
327
+ return self.skip_connection(x) + h
328
+
329
+
330
+ class AttentionBlock(nn.Module):
331
+ """
332
+ An attention block that allows spatial positions to attend to each other.
333
+
334
+ Originally ported from here, but adapted to the N-d case.
335
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ channels,
341
+ num_heads=1,
342
+ num_head_channels=-1,
343
+ use_checkpoint=False,
344
+ use_new_attention_order=False,
345
+ ):
346
+ super().__init__()
347
+ self.channels = channels
348
+ if num_head_channels == -1:
349
+ self.num_heads = num_heads
350
+ else:
351
+ assert (
352
+ channels % num_head_channels == 0
353
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
354
+ self.num_heads = channels // num_head_channels
355
+ self.use_checkpoint = use_checkpoint
356
+ self.norm = normalization(channels)
357
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
358
+ if use_new_attention_order:
359
+ # split qkv before split heads
360
+ self.attention = QKVAttention(self.num_heads)
361
+ else:
362
+ # split heads before split qkv
363
+ self.attention = QKVAttentionLegacy(self.num_heads)
364
+
365
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
366
+
367
+ def forward(self, x):
368
+ return checkpoint(self._forward, (x,), self.parameters(), True)
369
+
370
+ def _forward(self, x):
371
+ b, c, *spatial = x.shape
372
+ x = x.reshape(b, c, -1)
373
+ qkv = self.qkv(self.norm(x))
374
+ h = self.attention(qkv)
375
+ h = self.proj_out(h)
376
+ return (x + h).reshape(b, c, *spatial)
377
+
378
+
379
+ def count_flops_attn(model, _x, y):
380
+ """
381
+ A counter for the `thop` package to count the operations in an
382
+ attention operation.
383
+ Meant to be used like:
384
+ macs, params = thop.profile(
385
+ model,
386
+ inputs=(inputs, timestamps),
387
+ custom_ops={QKVAttention: QKVAttention.count_flops},
388
+ )
389
+ """
390
+ b, c, *spatial = y[0].shape
391
+ num_spatial = int(np.prod(spatial))
392
+ # We perform two matmuls with the same number of ops.
393
+ # The first computes the weight matrix, the second computes
394
+ # the combination of the value vectors.
395
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
396
+ model.total_ops += th.DoubleTensor([matmul_ops])
397
+
398
+
399
+ class QKVAttentionLegacy(nn.Module):
400
+ """
401
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
402
+ """
403
+
404
+ def __init__(self, n_heads):
405
+ super().__init__()
406
+ self.n_heads = n_heads
407
+
408
+ def forward(self, qkv):
409
+ """
410
+ Apply QKV attention.
411
+
412
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
413
+ :return: an [N x (H * C) x T] tensor after attention.
414
+ """
415
+ bs, width, length = qkv.shape
416
+ assert width % (3 * self.n_heads) == 0
417
+ ch = width // (3 * self.n_heads)
418
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
419
+ scale = 1 / math.sqrt(math.sqrt(ch))
420
+ weight = th.einsum(
421
+ "bct,bcs->bts", q * scale, k * scale
422
+ ) # More stable with f16 than dividing afterwards
423
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
424
+ a = th.einsum("bts,bcs->bct", weight, v)
425
+ return a.reshape(bs, -1, length)
426
+
427
+ @staticmethod
428
+ def count_flops(model, _x, y):
429
+ return count_flops_attn(model, _x, y)
430
+
431
+
432
+ class QKVAttention(nn.Module):
433
+ """
434
+ A module which performs QKV attention and splits in a different order.
435
+ """
436
+
437
+ def __init__(self, n_heads):
438
+ super().__init__()
439
+ self.n_heads = n_heads
440
+
441
+ def forward(self, qkv):
442
+ """
443
+ Apply QKV attention.
444
+
445
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
446
+ :return: an [N x (H * C) x T] tensor after attention.
447
+ """
448
+ bs, width, length = qkv.shape
449
+ assert width % (3 * self.n_heads) == 0
450
+ ch = width // (3 * self.n_heads)
451
+ q, k, v = qkv.chunk(3, dim=1)
452
+ scale = 1 / math.sqrt(math.sqrt(ch))
453
+ weight = th.einsum(
454
+ "bct,bcs->bts",
455
+ (q * scale).view(bs * self.n_heads, ch, length),
456
+ (k * scale).view(bs * self.n_heads, ch, length),
457
+ ) # More stable with f16 than dividing afterwards
458
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
459
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
460
+ return a.reshape(bs, -1, length)
461
+
462
+ @staticmethod
463
+ def count_flops(model, _x, y):
464
+ return count_flops_attn(model, _x, y)
465
+
466
+
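+ # A small shape sanity-check for the attention modules above (a sketch only;
+ # QKVAttention is the class defined in this file, sizes are arbitrary):
+ #
+ #     n_heads, ch, length, bs = 4, 16, 64, 2
+ #     attn = QKVAttention(n_heads)
+ #     qkv = th.randn(bs, 3 * n_heads * ch, length)    # [N x (3*H*C) x T]
+ #     out = attn(qkv)
+ #     assert out.shape == (bs, n_heads * ch, length)   # [N x (H*C) x T]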
467
+ class UNetModel(nn.Module):
468
+ """
469
+ The full UNet model with attention and timestep embedding.
470
+
471
+ :param in_channels: channels in the input Tensor.
472
+ :param model_channels: base channel count for the model.
473
+ :param out_channels: channels in the output Tensor.
474
+ :param num_res_blocks: number of residual blocks per downsample.
475
+ :param attention_resolutions: a collection of downsample rates at which
476
+ attention will take place. May be a set, list, or tuple.
477
+ For example, if this contains 4, then at 4x downsampling, attention
478
+ will be used.
479
+ :param dropout: the dropout probability.
480
+ :param channel_mult: channel multiplier for each level of the UNet.
481
+ :param conv_resample: if True, use learned convolutions for upsampling and
482
+ downsampling.
483
+ :param dims: determines if the signal is 1D, 2D, or 3D.
484
+ :param num_classes: if specified (as an int), then this model will be
485
+ class-conditional with `num_classes` classes.
486
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
487
+ :param num_heads: the number of attention heads in each attention layer.
488
+ :param num_head_channels: if specified, ignore num_heads and instead use
489
+ a fixed channel width per attention head.
490
+ :param num_heads_upsample: works with num_heads to set a different number
491
+ of heads for upsampling. Deprecated.
492
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
493
+ :param resblock_updown: use residual blocks for up/downsampling.
494
+ :param use_new_attention_order: use a different attention pattern for potentially
495
+ increased efficiency.
496
+ """
497
+
498
+ def __init__(
499
+ self,
500
+ image_size,
501
+ in_channels,
502
+ model_channels,
503
+ out_channels,
504
+ num_res_blocks,
505
+ attention_resolutions,
506
+ dropout=0,
507
+ channel_mult=(1, 2, 4, 8),
508
+ conv_resample=True,
509
+ dims=2,
510
+ num_classes=None,
511
+ use_checkpoint=False,
512
+ use_fp16=False,
513
+ num_heads=1,
514
+ num_head_channels=-1,
515
+ num_heads_upsample=-1,
516
+ use_scale_shift_norm=False,
517
+ resblock_updown=False,
518
+ use_new_attention_order=False,
519
+ ):
520
+ super().__init__()
521
+
522
+ if num_heads_upsample == -1:
523
+ num_heads_upsample = num_heads
524
+
525
+ self.image_size = image_size
526
+ self.in_channels = in_channels
527
+ self.model_channels = model_channels
528
+ self.out_channels = out_channels
529
+ self.num_res_blocks = num_res_blocks
530
+ self.attention_resolutions = attention_resolutions
531
+ self.dropout = dropout
532
+ self.channel_mult = channel_mult
533
+ self.conv_resample = conv_resample
534
+ self.num_classes = num_classes
535
+ self.use_checkpoint = use_checkpoint
536
+ self.dtype = th.float16 if use_fp16 else th.float32
537
+ self.num_heads = num_heads
538
+ self.num_head_channels = num_head_channels
539
+ self.num_heads_upsample = num_heads_upsample
540
+
541
+ time_embed_dim = model_channels * 4
542
+ self.time_embed = nn.Sequential(
543
+ linear(model_channels, time_embed_dim),
544
+ nn.SiLU(),
545
+ linear(time_embed_dim, time_embed_dim),
546
+ )
547
+
548
+ if self.num_classes is not None:
549
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
550
+
551
+ ch = input_ch = int(channel_mult[0] * model_channels)
552
+ self.input_blocks = nn.ModuleList(
553
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
554
+ )
555
+ self._feature_size = ch
556
+ input_block_chans = [ch]
557
+ ds = 1
558
+ for level, mult in enumerate(channel_mult):
559
+ for _ in range(num_res_blocks):
560
+ layers = [
561
+ ResBlock(
562
+ ch,
563
+ time_embed_dim,
564
+ dropout,
565
+ out_channels=int(mult * model_channels),
566
+ dims=dims,
567
+ use_checkpoint=use_checkpoint,
568
+ use_scale_shift_norm=use_scale_shift_norm,
569
+ )
570
+ ]
571
+ ch = int(mult * model_channels)
572
+ if ds in attention_resolutions:
573
+ layers.append(
574
+ AttentionBlock(
575
+ ch,
576
+ use_checkpoint=use_checkpoint,
577
+ num_heads=num_heads,
578
+ num_head_channels=num_head_channels,
579
+ use_new_attention_order=use_new_attention_order,
580
+ )
581
+ )
582
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
583
+ self._feature_size += ch
584
+ input_block_chans.append(ch)
585
+ if level != len(channel_mult) - 1:
586
+ out_ch = ch
587
+ self.input_blocks.append(
588
+ TimestepEmbedSequential(
589
+ ResBlock(
590
+ ch,
591
+ time_embed_dim,
592
+ dropout,
593
+ out_channels=out_ch,
594
+ dims=dims,
595
+ use_checkpoint=use_checkpoint,
596
+ use_scale_shift_norm=use_scale_shift_norm,
597
+ down=True,
598
+ )
599
+ if resblock_updown
600
+ else Downsample(
601
+ ch, conv_resample, dims=dims, out_channels=out_ch
602
+ )
603
+ )
604
+ )
605
+ ch = out_ch
606
+ input_block_chans.append(ch)
607
+ ds *= 2
608
+ self._feature_size += ch
609
+
610
+ self.middle_block = TimestepEmbedSequential(
611
+ ResBlock(
612
+ ch,
613
+ time_embed_dim,
614
+ dropout,
615
+ dims=dims,
616
+ use_checkpoint=use_checkpoint,
617
+ use_scale_shift_norm=use_scale_shift_norm,
618
+ ),
619
+ AttentionBlock(
620
+ ch,
621
+ use_checkpoint=use_checkpoint,
622
+ num_heads=num_heads,
623
+ num_head_channels=num_head_channels,
624
+ use_new_attention_order=use_new_attention_order,
625
+ ),
626
+ ResBlock(
627
+ ch,
628
+ time_embed_dim,
629
+ dropout,
630
+ dims=dims,
631
+ use_checkpoint=use_checkpoint,
632
+ use_scale_shift_norm=use_scale_shift_norm,
633
+ ),
634
+ )
635
+ self._feature_size += ch
636
+
637
+ self.output_blocks = nn.ModuleList([])
638
+ for level, mult in list(enumerate(channel_mult))[::-1]:
639
+ for i in range(num_res_blocks + 1):
640
+ ich = input_block_chans.pop()
641
+ layers = [
642
+ ResBlock(
643
+ ch + ich,
644
+ time_embed_dim,
645
+ dropout,
646
+ out_channels=int(model_channels * mult),
647
+ dims=dims,
648
+ use_checkpoint=use_checkpoint,
649
+ use_scale_shift_norm=use_scale_shift_norm,
650
+ )
651
+ ]
652
+ ch = int(model_channels * mult)
653
+ if ds in attention_resolutions:
654
+ layers.append(
655
+ AttentionBlock(
656
+ ch,
657
+ use_checkpoint=use_checkpoint,
658
+ num_heads=num_heads_upsample,
659
+ num_head_channels=num_head_channels,
660
+ use_new_attention_order=use_new_attention_order,
661
+ )
662
+ )
663
+ if level and i == num_res_blocks:
664
+ out_ch = ch
665
+ layers.append(
666
+ ResBlock(
667
+ ch,
668
+ time_embed_dim,
669
+ dropout,
670
+ out_channels=out_ch,
671
+ dims=dims,
672
+ use_checkpoint=use_checkpoint,
673
+ use_scale_shift_norm=use_scale_shift_norm,
674
+ up=True,
675
+ )
676
+ if resblock_updown
677
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
678
+ )
679
+ ds //= 2
680
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
681
+ self._feature_size += ch
682
+
683
+ self.out = nn.Sequential(
684
+ normalization(ch),
685
+ nn.SiLU(),
686
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
687
+ )
688
+
689
+ def convert_to_fp16(self):
690
+ """
691
+ Convert the torso of the model to float16.
692
+ """
693
+ self.input_blocks.apply(convert_module_to_f16)
694
+ self.middle_block.apply(convert_module_to_f16)
695
+ self.output_blocks.apply(convert_module_to_f16)
696
+
697
+ def convert_to_fp32(self):
698
+ """
699
+ Convert the torso of the model to float32.
700
+ """
701
+ self.input_blocks.apply(convert_module_to_f32)
702
+ self.middle_block.apply(convert_module_to_f32)
703
+ self.output_blocks.apply(convert_module_to_f32)
704
+
705
+ def forward(self, x, timesteps, y=None):
706
+ """
707
+ Apply the model to an input batch.
708
+
709
+ :param x: an [N x C x ...] Tensor of inputs.
710
+ :param timesteps: a 1-D batch of timesteps.
711
+ :param y: an [N] Tensor of labels, if class-conditional.
712
+ :return: an [N x C x ...] Tensor of outputs.
713
+ """
714
+ assert (y is not None) == (
715
+ self.num_classes is not None
716
+ ), "must specify y if and only if the model is class-conditional"
717
+
718
+ hs = []
719
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
720
+
721
+ if self.num_classes is not None:
722
+ assert y.shape == (x.shape[0],)
723
+ emb = emb + self.label_emb(y)
724
+
725
+ h = x.type(self.dtype)
726
+ for module in self.input_blocks:
727
+ h = module(h, emb)
728
+ hs.append(h)
729
+ h = self.middle_block(h, emb)
730
+ for module in self.output_blocks:
731
+ h = th.cat([h, hs.pop()], dim=1)
732
+ h = module(h, emb)
733
+ h = h.type(x.dtype)
734
+ return self.out(h)
735
+
736
+
737
+ class SuperResModel(UNetModel):
738
+ """
739
+ A UNetModel that performs super-resolution.
740
+
741
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
742
+ """
743
+
744
+ def __init__(self, image_size, in_channels, *args, **kwargs):
745
+ super().__init__(image_size, in_channels * 2, *args, **kwargs)
746
+
747
+ def forward(self, x, timesteps, low_res=None, **kwargs):
748
+ _, _, new_height, new_width = x.shape
749
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
750
+ x = th.cat([x, upsampled], dim=1)
751
+ return super().forward(x, timesteps, **kwargs)
752
+
753
+
754
+ class EncoderUNetModel(nn.Module):
755
+ """
756
+ The half UNet model with attention and timestep embedding.
757
+
758
+ For usage, see UNet.
759
+ """
760
+
761
+ def __init__(
762
+ self,
763
+ image_size,
764
+ in_channels,
765
+ model_channels,
766
+ out_channels,
767
+ num_res_blocks,
768
+ attention_resolutions,
769
+ dropout=0,
770
+ channel_mult=(1, 2, 4, 8),
771
+ conv_resample=True,
772
+ dims=2,
773
+ use_checkpoint=False,
774
+ use_fp16=False,
775
+ num_heads=1,
776
+ num_head_channels=-1,
777
+ num_heads_upsample=-1,
778
+ use_scale_shift_norm=False,
779
+ resblock_updown=False,
780
+ use_new_attention_order=False,
781
+ pool="adaptive",
782
+ ):
783
+ super().__init__()
784
+
785
+ if num_heads_upsample == -1:
786
+ num_heads_upsample = num_heads
787
+
788
+ self.in_channels = in_channels
789
+ self.model_channels = model_channels
790
+ self.out_channels = out_channels
791
+ self.num_res_blocks = num_res_blocks
792
+ self.attention_resolutions = attention_resolutions
793
+ self.dropout = dropout
794
+ self.channel_mult = channel_mult
795
+ self.conv_resample = conv_resample
796
+ self.use_checkpoint = use_checkpoint
797
+ self.dtype = th.float16 if use_fp16 else th.float32
798
+ self.num_heads = num_heads
799
+ self.num_head_channels = num_head_channels
800
+ self.num_heads_upsample = num_heads_upsample
801
+
802
+ time_embed_dim = model_channels * 4
803
+ self.time_embed = nn.Sequential(
804
+ linear(model_channels, time_embed_dim),
805
+ nn.SiLU(),
806
+ linear(time_embed_dim, time_embed_dim),
807
+ )
808
+
809
+ ch = int(channel_mult[0] * model_channels)
810
+ self.input_blocks = nn.ModuleList(
811
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
812
+ )
813
+ self._feature_size = ch
814
+ input_block_chans = [ch]
815
+ ds = 1
816
+ for level, mult in enumerate(channel_mult):
817
+ for _ in range(num_res_blocks):
818
+ layers = [
819
+ ResBlock(
820
+ ch,
821
+ time_embed_dim,
822
+ dropout,
823
+ out_channels=int(mult * model_channels),
824
+ dims=dims,
825
+ use_checkpoint=use_checkpoint,
826
+ use_scale_shift_norm=use_scale_shift_norm,
827
+ )
828
+ ]
829
+ ch = int(mult * model_channels)
830
+ if ds in attention_resolutions:
831
+ layers.append(
832
+ AttentionBlock(
833
+ ch,
834
+ use_checkpoint=use_checkpoint,
835
+ num_heads=num_heads,
836
+ num_head_channels=num_head_channels,
837
+ use_new_attention_order=use_new_attention_order,
838
+ )
839
+ )
840
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
841
+ self._feature_size += ch
842
+ input_block_chans.append(ch)
843
+ if level != len(channel_mult) - 1:
844
+ out_ch = ch
845
+ self.input_blocks.append(
846
+ TimestepEmbedSequential(
847
+ ResBlock(
848
+ ch,
849
+ time_embed_dim,
850
+ dropout,
851
+ out_channels=out_ch,
852
+ dims=dims,
853
+ use_checkpoint=use_checkpoint,
854
+ use_scale_shift_norm=use_scale_shift_norm,
855
+ down=True,
856
+ )
857
+ if resblock_updown
858
+ else Downsample(
859
+ ch, conv_resample, dims=dims, out_channels=out_ch
860
+ )
861
+ )
862
+ )
863
+ ch = out_ch
864
+ input_block_chans.append(ch)
865
+ ds *= 2
866
+ self._feature_size += ch
867
+
868
+ self.middle_block = TimestepEmbedSequential(
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_checkpoint=use_checkpoint,
875
+ use_scale_shift_norm=use_scale_shift_norm,
876
+ ),
877
+ AttentionBlock(
878
+ ch,
879
+ use_checkpoint=use_checkpoint,
880
+ num_heads=num_heads,
881
+ num_head_channels=num_head_channels,
882
+ use_new_attention_order=use_new_attention_order,
883
+ ),
884
+ ResBlock(
885
+ ch,
886
+ time_embed_dim,
887
+ dropout,
888
+ dims=dims,
889
+ use_checkpoint=use_checkpoint,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ ),
892
+ )
893
+ self._feature_size += ch
894
+ self.pool = pool
895
+ if pool == "adaptive":
896
+ self.out = nn.Sequential(
897
+ normalization(ch),
898
+ nn.SiLU(),
899
+ nn.AdaptiveAvgPool2d((1, 1)),
900
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
901
+ nn.Flatten(),
902
+ )
903
+ elif pool == "attention":
904
+ assert num_head_channels != -1
905
+ self.out = nn.Sequential(
906
+ normalization(ch),
907
+ nn.SiLU(),
908
+ AttentionPool2d(
909
+ (image_size // ds), ch, num_head_channels, out_channels
910
+ ),
911
+ )
912
+ elif pool == "spatial":
913
+ self.out = nn.Sequential(
914
+ nn.Linear(self._feature_size, 2048),
915
+ nn.ReLU(),
916
+ nn.Linear(2048, self.out_channels),
917
+ )
918
+ elif pool == "spatial_v2":
919
+ self.out = nn.Sequential(
920
+ nn.Linear(self._feature_size, 2048),
921
+ normalization(2048),
922
+ nn.SiLU(),
923
+ nn.Linear(2048, self.out_channels),
924
+ )
925
+ else:
926
+ raise NotImplementedError(f"Unexpected {pool} pooling")
927
+
928
+ def convert_to_fp16(self):
929
+ """
930
+ Convert the torso of the model to float16.
931
+ """
932
+ self.input_blocks.apply(convert_module_to_f16)
933
+ self.middle_block.apply(convert_module_to_f16)
934
+
935
+ def convert_to_fp32(self):
936
+ """
937
+ Convert the torso of the model to float32.
938
+ """
939
+ self.input_blocks.apply(convert_module_to_f32)
940
+ self.middle_block.apply(convert_module_to_f32)
941
+
942
+ def forward(self, x, timesteps):
943
+ """
944
+ Apply the model to an input batch.
945
+
946
+ :param x: an [N x C x ...] Tensor of inputs.
947
+ :param timesteps: a 1-D batch of timesteps.
948
+ :return: an [N x K] Tensor of outputs.
949
+ """
950
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
951
+
952
+ results = []
953
+ h = x.type(self.dtype)
954
+ for module in self.input_blocks:
955
+ h = module(h, emb)
956
+ if self.pool.startswith("spatial"):
957
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
958
+ h = self.middle_block(h, emb)
959
+ if self.pool.startswith("spatial"):
960
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
961
+ h = th.cat(results, axis=-1)
962
+ return self.out(h)
963
+ else:
964
+ h = h.type(x.dtype)
965
+ return self.out(h)
966
+
967
+
968
+ class NLayerDiscriminator(nn.Module):
969
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
970
+ super(NLayerDiscriminator, self).__init__()
971
+ if type(norm_layer) == functools.partial:
972
+ use_bias = norm_layer.func == nn.InstanceNorm2d
973
+ else:
974
+ use_bias = norm_layer == nn.InstanceNorm2d
975
+
976
+ kw = 4
977
+ padw = 1
978
+ sequence = [
979
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
980
+ nn.LeakyReLU(0.2, True)
981
+ ]
982
+
983
+ nf_mult = 1
984
+ nf_mult_prev = 1
985
+ for n in range(1, n_layers):
986
+ nf_mult_prev = nf_mult
987
+ nf_mult = min(2**n, 8)
988
+ sequence += [
989
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
990
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
991
+ norm_layer(ndf * nf_mult),
992
+ nn.LeakyReLU(0.2, True)
993
+ ]
994
+
995
+ nf_mult_prev = nf_mult
996
+ nf_mult = min(2**n_layers, 8)
997
+ sequence += [
998
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
999
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
1000
+ norm_layer(ndf * nf_mult),
1001
+ nn.LeakyReLU(0.2, True)
1002
+ ]
1003
+
1004
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=2, padding=padw)] + [nn.Dropout(0.5)]
1005
+ if use_sigmoid:
1006
+ sequence += [nn.Sigmoid()]
1007
+
1008
+ self.model = nn.Sequential(*sequence)
1009
+
1010
+ def forward(self, input):
1011
+ return self.model(input)
1012
+
1013
+
1014
+ class GANLoss(nn.Module):
1015
+ """Define different GAN objectives.
1016
+
1017
+ The GANLoss class abstracts away the need to create the target label tensor
1018
+ that has the same size as the input.
1019
+ """
1020
+
1021
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
1022
+ """ Initialize the GANLoss class.
1023
+
1024
+ Parameters:
1025
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
1026
+ target_real_label (bool) - - label for a real image
1027
+ target_fake_label (bool) - - label of a fake image
1028
+
1029
+ Note: Do not use sigmoid as the last layer of Discriminator.
1030
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
1031
+ """
1032
+ super(GANLoss, self).__init__()
1033
+ self.register_buffer('real_label', th.tensor(target_real_label))
1034
+ self.register_buffer('fake_label', th.tensor(target_fake_label))
1035
+ self.gan_mode = gan_mode
1036
+ if gan_mode == 'lsgan':
1037
+ self.loss = nn.MSELoss()
1038
+ elif gan_mode == 'vanilla':
1039
+ self.loss = nn.BCEWithLogitsLoss()
1040
+ elif gan_mode in ['wgangp']:
1041
+ self.loss = None
1042
+ else:
1043
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
1044
+
1045
+ def get_target_tensor(self, prediction, target_is_real):
1046
+ """Create label tensors with the same size as the input.
1047
+
1048
+ Parameters:
1049
+ prediction (tensor) - - typically the prediction from a discriminator
1050
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1051
+
1052
+ Returns:
1053
+ A label tensor filled with ground truth label, and with the size of the input
1054
+ """
1055
+
1056
+ if target_is_real:
1057
+ target_tensor = self.real_label
1058
+ else:
1059
+ target_tensor = self.fake_label
1060
+ return target_tensor.expand_as(prediction)
1061
+
1062
+ def __call__(self, prediction, target_is_real):
1063
+ """Calculate loss given Discriminator's output and grount truth labels.
1064
+
1065
+ Parameters:
1066
+ prediction (tensor) - - typically the prediction output from a discriminator
1067
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
1068
+
1069
+ Returns:
1070
+ the calculated loss.
1071
+ """
1072
+ if self.gan_mode in ['lsgan', 'vanilla']:
1073
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
1074
+ loss = self.loss(prediction, target_tensor)
1075
+ elif self.gan_mode == 'wgangp':
1076
+ if target_is_real:
1077
+ loss = -prediction.mean()
1078
+ else:
1079
+ loss = prediction.mean()
1080
+ return loss
1081
+
1082
+
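+ # A short sketch of how NLayerDiscriminator and GANLoss might be combined
+ # (both classes are defined above; batch size and resolution are illustrative):
+ #
+ #     netD = NLayerDiscriminator(input_nc=3)
+ #     criterion = GANLoss('lsgan')
+ #     real = th.randn(4, 3, 256, 256)
+ #     fake = th.randn(4, 3, 256, 256)
+ #     loss_D = 0.5 * (criterion(netD(real), True) + criterion(netD(fake), False))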
1083
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
1084
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
1085
+
1086
+ Arguments:
1087
+ netD (network) -- discriminator network
1088
+ real_data (tensor array) -- real images
1089
+ fake_data (tensor array) -- generated images from the generator
1090
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
1091
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
1092
+ constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
1093
+ lambda_gp (float) -- weight for this loss
1094
+
1095
+ Returns the gradient penalty loss
1096
+ """
1097
+ if lambda_gp > 0.0:
1098
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
1099
+ interpolatesv = real_data
1100
+ elif type == 'fake':
1101
+ interpolatesv = fake_data
1102
+ elif type == 'mixed':
1103
+ alpha = th.rand(real_data.shape[0], 1, device=device)
1104
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
1105
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
1106
+ else:
1107
+ raise NotImplementedError('{} not implemented'.format(type))
1108
+ interpolatesv.requires_grad_(True)
1109
+ disc_interpolates = netD(interpolatesv)
1110
+ gradients = th.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
1111
+ grad_outputs=th.ones(disc_interpolates.size()).to(device),
1112
+ create_graph=True, retain_graph=True, only_inputs=True)
1113
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
1114
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
1115
+ return gradient_penalty, gradients
1116
+ else:
1117
+ return 0.0, None
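
Finally, a hedged end-to-end sketch of instantiating the UNet defined in this file, assuming the package is importable as guided_diffusion.unet; with an empty model_path the weights stay randomly initialized, as in the try/except in create_model above:

import torch as th
from guided_diffusion.unet import create_model

model = create_model(
    image_size=256,
    num_channels=128,
    num_res_blocks=1,
    model_path='',          # placeholder: point this at a real checkpoint in practice
)
model.eval()

x = th.randn(1, 3, 256, 256)
t = th.tensor([999])
with th.no_grad():
    eps = model(x, t)       # [1, 3, 256, 256] noise prediction (learn_sigma=False)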