Model code
- models/__init__.py +2 -0
- models/ddm.py +260 -0
- models/restoration.py +59 -0
- models/unet.py +331 -0
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from models.ddm import *
from models.restoration import *
models/ddm.py
ADDED
@@ -0,0 +1,260 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import utils
from models.unet import DiffusionUNet
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR


def data_transform(X):
    return 2 * X - 1.0


def inverse_data_transform(X):
    return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)


class EMAHelper(object):
    def __init__(self, mu=0.9999):
        self.mu = mu
        self.shadow = {}

    def register(self, module):
        if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, module, device):
        if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                self.shadow[name].data = (1. - self.mu) * param.data + self.mu * self.shadow[name].data.to(device)

    def ema(self, module):
        if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
            module = module.module
        for name, param in module.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name].data)

    def ema_copy(self, module):
        if isinstance(module, nn.DataParallel) or isinstance(module, nn.parallel.DistributedDataParallel):
            inner_module = module.module
            module_copy = type(inner_module)(inner_module.config).to(inner_module.config.device)
            module_copy.load_state_dict(inner_module.state_dict())
            module_copy = nn.DataParallel(module_copy)
        else:
            module_copy = type(module)(module.config).to(module.config.device)
            module_copy.load_state_dict(module.state_dict())
        self.ema(module_copy)
        return module_copy

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        self.shadow = state_dict


def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    if beta_schedule == "quad":
        betas = (np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_diffusion_timesteps, dtype=np.float64) ** 2)
    elif beta_schedule == "linear":
        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":  # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def noise_estimation_loss(model, x0, t, e, b):
    a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)
    x = x0[:, 3:, :, :] * a.sqrt() + e * (1.0 - a).sqrt()
    output = model(torch.cat([x0[:, :3, :, :], x], dim=1), t.float())
    return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)


class DenoisingDiffusion(object):
    def __init__(self, config, test=False):
        super().__init__()
        self.config = config
        self.device = config.device
        self.writer = SummaryWriter(config.data.tensorboard)
        self.model = DiffusionUNet(config)
        self.model.to(self.device)
        if test:
            self.model = torch.nn.DataParallel(self.model)
        else:
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[config.local_rank],
                                                                   output_device=config.local_rank)
        self.ema_helper = EMAHelper()
        self.ema_helper.register(self.model)

        self.optimizer = utils.optimize.get_optimizer(self.config, self.model.parameters())
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=config.training.n_epochs)
        self.start_epoch, self.step = 0, 0

        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )

        betas = self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

    def load_ddm_ckpt(self, load_path, ema=False):
        checkpoint = utils.logging.load_checkpoint(load_path, None)
        self.start_epoch = checkpoint['epoch']
        self.step = checkpoint['step']
        self.model.load_state_dict(checkpoint['state_dict'], strict=True)
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.ema_helper.load_state_dict(checkpoint['ema_helper'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        if ema:
            self.ema_helper.ema(self.model)
        print("=> loaded checkpoint '{}' (epoch {}, step {})".format(load_path, checkpoint['epoch'], self.step))

    def train(self, DATASET):
        cudnn.benchmark = True
        train_loader, val_loader = DATASET.get_loaders()
        pretrained_model_path = self.config.training.resume + '.pth.tar'
        if os.path.isfile(pretrained_model_path):
            self.load_ddm_ckpt(pretrained_model_path)
        dist.barrier()
        # training loop
        for epoch in range(self.start_epoch, self.config.training.n_epochs):
            if (epoch == 0) and dist.get_rank() == 0:
                utils.logging.save_checkpoint({
                    'epoch': epoch + 1,
                    'step': self.step,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'ema_helper': self.ema_helper.state_dict(),
                    'config': self.config,
                    'scheduler': self.scheduler.state_dict()
                }, filename=self.config.training.resume + '_' + str(epoch))
                utils.logging.save_checkpoint({
                    'epoch': epoch + 1,
                    'step': self.step,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'ema_helper': self.ema_helper.state_dict(),
                    'config': self.config,
                    'scheduler': self.scheduler.state_dict()
                }, filename=self.config.training.resume)
            if dist.get_rank() == 0:
                print('=> current epoch: ', epoch)
            data_start = time.time()
            data_time = 0
            train_loader.sampler.set_epoch(epoch)
            for i, (x, y) in enumerate(train_loader):
                x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
                n = x.size(0)
                data_time += time.time() - data_start
                self.model.train()
                self.step += 1

                x = x.to(self.device)
                x = data_transform(x)
                e = torch.randn_like(x[:, 3:, :, :])
                b = self.betas

                # antithetic sampling
                t = torch.randint(low=0, high=self.num_timesteps, size=(n // 2 + 1,)).to(self.device)
                t = torch.cat([t, self.num_timesteps - t - 1], dim=0)[:n]
                loss = noise_estimation_loss(self.model, x, t, e, b)
                current_lr = self.optimizer.param_groups[0]['lr']

                if self.step % 10 == 0:
                    print(
                        'rank: %d, step: %d, loss: %.6f, lr: %.6f, time consumption: %.6f' % (
                            dist.get_rank(), self.step, loss.item(), current_lr, data_time / (i + 1)))

                # update parameters
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.ema_helper.update(self.model, self.device)
                data_start = time.time()

                if self.step % self.config.training.validation_freq == 0:
                    self.model.eval()
                    self.sample_validation_patches(val_loader, self.step)

                if (self.step % 100 == 0) and dist.get_rank() == 0:
                    self.writer.add_scalar('train/loss', loss.item(), self.step)
                    self.writer.add_scalar('train/lr', current_lr, self.step)

            self.scheduler.step()
            # save checkpoints
            if (epoch % self.config.training.snapshot_freq == 0) and dist.get_rank() == 0:
                utils.logging.save_checkpoint({
                    'epoch': epoch + 1,
                    'step': self.step,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'ema_helper': self.ema_helper.state_dict(),
                    'config': self.config,
                    'scheduler': self.scheduler.state_dict()
                }, filename=self.config.training.resume + '_' + str(epoch))
                utils.logging.save_checkpoint({
                    'epoch': epoch + 1,
                    'step': self.step,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'ema_helper': self.ema_helper.state_dict(),
                    'config': self.config,
                    'scheduler': self.scheduler.state_dict()
                }, filename=self.config.training.resume)

    def sample_image(self, x_cond, x, last=True, patch_locs=None, patch_size=None):
        skip = self.config.diffusion.num_diffusion_timesteps // self.config.sampling.sampling_timesteps
        seq = range(0, self.config.diffusion.num_diffusion_timesteps, skip)
        if patch_locs is not None:
            xs = utils.sampling.generalized_steps_overlapping(x, x_cond, seq, self.model, self.betas, eta=0.,
                                                              corners=patch_locs, p_size=patch_size, device=self.device)
        else:
            xs = utils.sampling.generalized_steps(x, x_cond, seq, self.model, self.betas, eta=0., device=self.device)
        if last:
            xs = xs[0][-1]
        return xs

    def sample_validation_patches(self, val_loader, step):
        image_folder = os.path.join(self.config.data.val_save_dir, str(self.config.data.image_size))
        with torch.no_grad():
            if dist.get_rank() == 0:
                print(f"Processing a single batch of validation images at step: {step}")
            for i, (x, y) in enumerate(val_loader):
                x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
                break
            n = x.size(0)
            x_cond = x[:, :3, :, :].to(self.device)  # conditioning image
            x_cond = data_transform(x_cond)
            x = torch.randn(n, 3, self.config.data.image_size, self.config.data.image_size, device=self.device)
            x = self.sample_image(x_cond, x)
            x = inverse_data_transform(x)
            x_cond = inverse_data_transform(x_cond)

            for i in range(n):
                utils.logging.save_image(x_cond[i], os.path.join(image_folder, str(step), f"{i}_cond.png"))
                utils.logging.save_image(x[i], os.path.join(image_folder, str(step), f"{i}.png"))
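
A minimal sketch (not part of this commit) of the forward-diffusion step that noise_estimation_loss relies on: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * e, where a_bar_t is the cumulative product of (1 - beta_t). The schedule values below are assumptions chosen only for illustration.

    import torch
    from models.ddm import get_beta_schedule

    betas = torch.from_numpy(get_beta_schedule("linear", beta_start=1e-4, beta_end=2e-2,
                                               num_diffusion_timesteps=1000)).float()
    a_bar = (1 - betas).cumprod(dim=0)           # \bar{alpha}_t for every timestep
    x0 = torch.rand(4, 3, 64, 64) * 2 - 1        # a fake clean batch already in [-1, 1]
    e = torch.randn_like(x0)                     # Gaussian noise
    t = torch.randint(0, 1000, (4,))             # one random timestep per sample
    a = a_bar.index_select(0, t).view(-1, 1, 1, 1)
    x_t = x0 * a.sqrt() + e * (1.0 - a).sqrt()   # the noisy tensor the U-Net sees next to the condition

The network is then asked to recover e from (condition, x_t), which is exactly the squared-error objective computed in noise_estimation_loss.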
models/restoration.py
ADDED
@@ -0,0 +1,59 @@
import torch
import utils
import os
from tqdm import tqdm


def data_transform(X):
    return 2 * X - 1.0


def inverse_data_transform(X):
    return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)


class DiffusiveRestoration:
    def __init__(self, diffusion, config):
        super(DiffusiveRestoration, self).__init__()
        self.config = config
        self.diffusion = diffusion

        # check that the pretrained diffusion model exists
        pretrained_model_path = self.config.training.resume + '.pth.tar'
        assert os.path.isfile(pretrained_model_path), 'pretrained diffusion model path is wrong!'
        self.diffusion.load_ddm_ckpt(pretrained_model_path, ema=True)
        self.diffusion.model.eval()
        self.diffusion.model.requires_grad_(False)

    def restore(self, val_loader, r=None):
        image_folder = self.config.data.test_save_dir
        with torch.no_grad():
            for i, (x, y) in tqdm(enumerate(val_loader)):
                print(f"=> starting processing image named {y}")
                x = x.flatten(start_dim=0, end_dim=1) if x.ndim == 5 else x
                x_cond = x[:, :3, :, :].to(self.diffusion.device)
                x_output = self.diffusive_restoration(x_cond, r=r)
                x_output = inverse_data_transform(x_output)
                utils.logging.save_image(x_output, os.path.join(image_folder, f"{y[0]}.png"))

    def diffusive_restoration(self, x_cond, r=None):
        p_size = self.config.data.image_size
        h_list, w_list = self.overlapping_grid_indices(x_cond, output_size=p_size, r=r)
        corners = [(i, j) for i in h_list for j in w_list]
        x = torch.randn(x_cond.size(), device=self.diffusion.device)
        x_output = self.diffusion.sample_image(x_cond, x, patch_locs=corners, patch_size=p_size)
        return x_output

    def overlapping_grid_indices(self, x_cond, output_size, r=None):
        _, c, h, w = x_cond.shape
        r = 16 if r is None else r
        h_list = [i for i in range(0, h - output_size + 1, r)]
        w_list = [i for i in range(0, w - output_size + 1, r)]
        return h_list, w_list

    def web_restore(self, image, r=None):
        with torch.no_grad():
            image_cond = image.to(self.diffusion.device)
            image_output = self.diffusive_restoration(image_cond, r=r)
            image_output = inverse_data_transform(image_output)
            return image_output
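
A minimal sketch (not part of this commit) of the overlapping patch grid that diffusive_restoration builds through overlapping_grid_indices: each (i, j) in corners is the top-left coordinate of one patch of size p_size, and patches placed r pixels apart overlap and tile the full image. The image and patch sizes below are assumptions for illustration only.

    h, w = 128, 160        # hypothetical input resolution
    p_size, r = 64, 16     # patch size and grid stride
    h_list = list(range(0, h - p_size + 1, r))   # [0, 16, 32, 48, 64]
    w_list = list(range(0, w - p_size + 1, r))   # [0, 16, ..., 96]
    corners = [(i, j) for i in h_list for j in w_list]
    print(len(corners))    # 5 * 7 = 35 overlapping 64x64 patches covering the 128x160 image

Smaller r means more overlap (and more model evaluations per step); the denoised patches are blended by the overlapping sampler in utils.sampling.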
models/unet.py
ADDED
@@ -0,0 +1,331 @@
import math
import torch
import torch.nn as nn


def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # self-attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1).contiguous()  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class DiffusionUNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
        num_res_blocks = config.model.num_res_blocks
        attn_resolutions = config.model.attn_resolutions
        dropout = config.model.dropout
        in_channels = config.model.in_channels * 2 if config.data.conditional else config.model.in_channels
        resolution = config.data.image_size
        resamp_with_conv = config.model.resamp_with_conv

        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # timestep embedding
        self.temb = nn.Module()
        self.temb.dense = nn.ModuleList([
            torch.nn.Linear(self.ch, self.temb_ch),
            torch.nn.Linear(self.temb_ch, self.temb_ch),
        ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        block_in = None
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t):
        assert x.shape[2] == x.shape[3] == self.resolution

        # timestep embedding
        temb = get_timestep_embedding(t, self.ch)
        temb = self.temb.dense[0](temb)
        temb = nonlinearity(temb)
        temb = self.temb.dense[1](temb)

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


# net = DiffusionUNet()
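
A minimal sketch (not part of this commit) of instantiating DiffusionUNet with a hypothetical config built from SimpleNamespace; every hyperparameter value below is an assumption chosen only so the shapes line up, not the project's actual configuration.

    from types import SimpleNamespace
    import torch
    from models.unet import DiffusionUNet

    config = SimpleNamespace(
        model=SimpleNamespace(ch=64, out_ch=3, ch_mult=[1, 2, 4], num_res_blocks=2,
                              attn_resolutions=[16], dropout=0.0, in_channels=3,
                              resamp_with_conv=True),
        data=SimpleNamespace(image_size=64, conditional=True),
    )
    net = DiffusionUNet(config)

    x = torch.randn(2, 6, 64, 64)        # conditioning image (3 ch) concatenated with noisy target (3 ch)
    t = torch.randint(0, 1000, (2,))     # one diffusion timestep per sample
    eps_pred = net(x, t)                 # predicted noise
    print(eps_pred.shape)                # torch.Size([2, 3, 64, 64])

Because config.data.conditional doubles the input channels, the network consumes the 6-channel concatenation produced in noise_estimation_loss and returns a 3-channel noise estimate for the degraded half.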