Ngô Đức Bảo committed on
Commit
2a9ad6d
1 Parent(s): 53e6466

Upload 320 files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. Utils.ipynb +0 -0
  2. note.txt +10 -0
  3. resnet/DDPM_ResNet.ipynb +0 -0
  4. resnet/DDPM_ResNet.py +952 -0
  5. resnet/DDPM_ResNet_sample.py +856 -0
  6. resnet/log/info.log +585 -0
  7. resnet/log/iter_1000.png +0 -0
  8. resnet/log/iter_10000.png +0 -0
  9. resnet/log/iter_11000.png +0 -0
  10. resnet/log/iter_12000.png +0 -0
  11. resnet/log/iter_13000.png +0 -0
  12. resnet/log/iter_14000.png +0 -0
  13. resnet/log/iter_15000.png +0 -0
  14. resnet/log/iter_16000.png +0 -0
  15. resnet/log/iter_17000.png +0 -0
  16. resnet/log/iter_18000.png +0 -0
  17. resnet/log/iter_19000.png +0 -0
  18. resnet/log/iter_2000.png +0 -0
  19. resnet/log/iter_20000.png +0 -0
  20. resnet/log/iter_21000.png +0 -0
  21. resnet/log/iter_22000.png +0 -0
  22. resnet/log/iter_23000.png +0 -0
  23. resnet/log/iter_24000.png +0 -0
  24. resnet/log/iter_25000.png +0 -0
  25. resnet/log/iter_26000.png +0 -0
  26. resnet/log/iter_27000.png +0 -0
  27. resnet/log/iter_28000.png +0 -0
  28. resnet/log/iter_29000.png +0 -0
  29. resnet/log/iter_3000.png +0 -0
  30. resnet/log/iter_30000.png +0 -0
  31. resnet/log/iter_31000.png +0 -0
  32. resnet/log/iter_32000.png +0 -0
  33. resnet/log/iter_33000.png +0 -0
  34. resnet/log/iter_34000.png +0 -0
  35. resnet/log/iter_35000.png +0 -0
  36. resnet/log/iter_36000.png +0 -0
  37. resnet/log/iter_37000.png +0 -0
  38. resnet/log/iter_38000.png +0 -0
  39. resnet/log/iter_39000.png +0 -0
  40. resnet/log/iter_4000.png +0 -0
  41. resnet/log/iter_40000.png +0 -0
  42. resnet/log/iter_41000.png +0 -0
  43. resnet/log/iter_42000.png +0 -0
  44. resnet/log/iter_43000.png +0 -0
  45. resnet/log/iter_44000.png +0 -0
  46. resnet/log/iter_45000.png +0 -0
  47. resnet/log/iter_46000.png +0 -0
  48. resnet/log/iter_47000.png +0 -0
  49. resnet/log/iter_48000.png +0 -0
  50. resnet/log/iter_49000.png +0 -0
Utils.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
note.txt ADDED
@@ -0,0 +1,10 @@
+ The file Utils.ipynb covers:
+ - Moving images: randomly copies 5k of the 10k images sampled from each model (the 10k are used to compute FID 10k) in order to compute IS 5k.
+ - Computing FID 10k and IS 5k with the torch-fidelity library (https://github.com/toshas/torch-fidelity).
+
+ Each directory (for example resnet) contains:
+ - A "model" directory holding the checkpoint, at epoch 30, of the model that shares its name with the parent directory (here, resnet).
+ - A "log" directory holding the log and the images sampled every 1000 iterations. The number of sampled images may differ between models because max_epoch was originally set to 50.
+ - A file "DDPM_ResNet.ipynb"; ResNet is just one example here, and for the other models the file is named "DDPM_ResNet_wo_t.ipynb" (the ResNet model without the timestep t), "DDPM_UNet.ipynb" (the U-Net model), or "DDPM_UNet_wo_t.ipynb" (the U-Net model without the timestep t). The notebook clearly separates the parts of the model, the training code, etc. Its main purpose is to train the model.
+ - A file "DDPM_ResNet.py", named after the model as above. This is just the ".ipynb" file converted to ".py": training runs overnight on a home machine, and running the ".py" file from a terminal is lighter-weight.
+ - A file "DDPM_ResNet_sample.py", named after the model as above. This is an edited copy of the ".py" file with all the data-loading, training, and log-saving code removed and replaced by code that samples and saves images.
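For reference, a minimal sketch of the evaluation workflow note.txt describes, using the real torch-fidelity API (torch_fidelity.calculate_metrics). The folder names samples_10k/, samples_5k/, and mnist_png/ are placeholder assumptions for illustration, not paths from this commit:

# Sketch of the Utils.ipynb workflow described in note.txt (paths are assumed).
import os
import random
import shutil

from torch_fidelity import calculate_metrics

src, dst = 'samples_10k', 'samples_5k'
os.makedirs(dst, exist_ok=True)

# Randomly copy 5k of the 10k sampled images for the IS 5k computation.
for name in random.sample(os.listdir(src), 5000):
    shutil.copy(os.path.join(src, name), os.path.join(dst, name))

# FID 10k: all 10k sampled images against a folder of real images.
print(calculate_metrics(input1=src, input2='mnist_png', cuda=True, fid=True))

# IS 5k: the Inception Score needs only the generated images.
print(calculate_metrics(input1=dst, cuda=True, isc=True))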
resnet/DDPM_ResNet.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
resnet/DDPM_ResNet.py ADDED
@@ -0,0 +1,952 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # # Library
+
+ # In[1]:
+
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.cuda.amp import autocast
+
+ import torchvision
+ from torchvision.transforms import transforms
+ from torch.utils.data import DataLoader
+
+ from torch.optim import Adam
+
+ from einops import rearrange, reduce, repeat
+ import math
+ from random import random
+
+ from collections import namedtuple
+ from functools import partial
+ from tqdm.auto import tqdm
+ import logging
+ import os
+
+ from PIL import Image
+ from torchvision import utils
+
+
+ # # Helper
+
+ # ### Constant
+
+ # In[2]:
+
+
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
+
+
+ # ### Functions
+
+ # In[3]:
+
+
+ def exists(x):
+     return x is not None
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+
+ # In[4]:
+
+
+ def cast_tuple(t, length = 1):
+     if isinstance(t, tuple):
+         return t
+     return ((t,) * length)
+
+
+ # In[5]:
+
+
+ def divisible_by(numer, denom):
+     return (numer % denom) == 0
+
+
+ # In[6]:
+
+
+ def identity(t, *args, **kwargs):
+     return t
+
+
+ # In[7]:
+
+
+ def cycle(dl):
+     while True:
+         for data in dl:
+             yield data
+
+
+ # In[8]:
+
+
+ def has_int_squareroot(num):
+     return (math.sqrt(num) ** 2) == num
+
+
+ # In[9]:
+
+
+ def num_to_groups(num, divisor):
+     groups = num // divisor
+     remainder = num % divisor
+     arr = [divisor] * groups
+     if remainder > 0:
+         arr.append(remainder)
+     return arr
+
+
+ # In[10]:
+
+
+ def convert_image_to_fn(img_type, image):
+     if image.mode != img_type:
+         return image.convert(img_type)
+     return image
+
+
+ # In[11]:
+
+
+ def extract(a, t, x_shape):
+     b, *_ = t.shape
+     out = a.gather(-1, t)
+     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+ # ### Normalization Functions
+
+ # In[12]:
+
+
+ def normalize_to_neg_one_to_one(img):
+     return img * 2 - 1
+
+ def unnormalize_to_zero_to_one(t):
+     return (t + 1) * 0.5
+
+
+ # ### Sinusoidal positional embeds
+
+ # In[13]:
+
+
+ class SinusoidalPosEmb(nn.Module):
+     def __init__(self, dim, theta = 10000):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def forward(self, x):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(self.theta) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+         emb = x[:, None] * emb[None, :]
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+
+
+ # In[14]:
+
+
+ class RandomOrLearnedSinusoidalPosEmb(nn.Module):
+     """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
+     """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+     def __init__(self, dim, is_random = False):
+         super().__init__()
+         assert divisible_by(dim, 2)
+         half_dim = dim // 2
+         self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
+
+     def forward(self, x):
+         x = rearrange(x, 'b -> b 1')
+         freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
+         fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
+         fouriered = torch.cat((x, fouriered), dim = -1)
+         return fouriered
+
+
+ # ### Schedule
+
+ # In[15]:
+
+
+ def linear_beta_schedule(timesteps):
+     """
+     linear schedule, proposed in original ddpm paper
+     """
+     scale = 1000 / timesteps
+     beta_start = scale * 0.0001
+     beta_end = scale * 0.02
+     return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
+
+
+ # In[16]:
+
+
+ def cosine_beta_schedule(timesteps, s = 0.008):
+     """
+     cosine schedule
+     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+     """
+     steps = timesteps + 1
+     t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
+     alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
+     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+     return torch.clip(betas, 0, 0.999)
+
+
+ # In[17]:
+
+
+ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
+     """
+     sigmoid schedule
+     proposed in https://arxiv.org/abs/2212.11972 - Figure 8
+     better for images > 64x64, when used during training
+     """
+     steps = timesteps + 1
+     t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
+     v_start = torch.tensor(start / tau).sigmoid()
+     v_end = torch.tensor(end / tau).sigmoid()
+     alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
+     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+     return torch.clip(betas, 0, 0.999)
+
+
+ # # Diffusion model
+
+ # In[18]:
+
+
+ class GaussianDiffusion(nn.Module):
+     # Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
+
+     def __init__(
+         self,
+         model,
+         *,
+         image_size,
+         timesteps = 1000,
+         sampling_timesteps = None,
+         objective = 'pred_noise',
+         beta_schedule = 'linear',
+         schedule_fn_kwargs = dict(),
+         ddim_sampling_eta = 0.,
+         auto_normalize = True,
+         offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
+         min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
+         min_snr_gamma = 5
+     ):
+         super().__init__()
+         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
+         assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
+
+         self.model = model
+
+         self.channels = self.model.channels
+         self.self_condition = self.model.self_condition
+
+         self.image_size = image_size
+
+         self.objective = objective
+
+         assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
+
+         if beta_schedule == 'linear':
+             beta_schedule_fn = linear_beta_schedule
+         elif beta_schedule == 'cosine':
+             beta_schedule_fn = cosine_beta_schedule
+         elif beta_schedule == 'sigmoid':
+             beta_schedule_fn = sigmoid_beta_schedule
+         else:
+             raise ValueError(f'unknown beta schedule {beta_schedule}')
+
+         betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
+
+         alphas = 1. - betas
+         alphas_cumprod = torch.cumprod(alphas, dim=0)
+         alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+         timesteps, = betas.shape
+         self.num_timesteps = int(timesteps)
+
+         # sampling related parameters
+
+         self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
+
+         assert self.sampling_timesteps <= timesteps
+         self.is_ddim_sampling = self.sampling_timesteps < timesteps
+         self.ddim_sampling_eta = ddim_sampling_eta
+
+         # helper function to register buffer from float64 to float32
+
+         register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+         register_buffer('betas', betas)
+         register_buffer('alphas_cumprod', alphas_cumprod)
+         register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+
+         register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+         register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+         register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+         register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+         register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+         posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+         register_buffer('posterior_variance', posterior_variance)
+
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+         register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
+         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+         # offset noise strength - in blogpost, they claimed 0.1 was ideal
+
+         self.offset_noise_strength = offset_noise_strength
+
+         # derive loss weight
+         # snr - signal noise ratio
+
+         snr = alphas_cumprod / (1 - alphas_cumprod)
+
+         # https://arxiv.org/abs/2303.09556
+
+         maybe_clipped_snr = snr.clone()
+         if min_snr_loss_weight:
+             maybe_clipped_snr.clamp_(max = min_snr_gamma)
+
+         if objective == 'pred_noise':
+             register_buffer('loss_weight', maybe_clipped_snr / snr)
+         elif objective == 'pred_x0':
+             register_buffer('loss_weight', maybe_clipped_snr)
+         elif objective == 'pred_v':
+             register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
+
+         # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
+
+         self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
+         self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
+
+     @property
+     def device(self):
+         return self.betas.device
+
+     def predict_start_from_noise(self, x_t, t, noise):
+         return (
+             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+         )
+
+     def predict_noise_from_start(self, x_t, t, x0):
+         return (
+             (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+         )
+
+     def predict_v(self, x_start, t, noise):
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+         )
+
+     def predict_start_from_v(self, x_t, t, v):
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+         )
+
+     def q_posterior(self, x_start, x_t, t):
+         posterior_mean = (
+             extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+             extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+         )
+         posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+         return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+     def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
+         model_output = self.model(x, t, x_self_cond)
+         maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
+
+         if self.objective == 'pred_noise':
+             pred_noise = model_output
+             x_start = self.predict_start_from_noise(x, t, pred_noise)
+             x_start = maybe_clip(x_start)
+
+             if clip_x_start and rederive_pred_noise:
+                 pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         elif self.objective == 'pred_x0':
+             x_start = model_output
+             x_start = maybe_clip(x_start)
+             pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         elif self.objective == 'pred_v':
+             v = model_output
+             x_start = self.predict_start_from_v(x, t, v)
+             x_start = maybe_clip(x_start)
+             pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+         return ModelPrediction(pred_noise, x_start)
+
+     def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
+         preds = self.model_predictions(x, t, x_self_cond)
+         x_start = preds.pred_x_start
+
+         if clip_denoised:
+             x_start.clamp_(-1., 1.)
+
+         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
+         return model_mean, posterior_variance, posterior_log_variance, x_start
+
+     @torch.inference_mode()
+     def p_sample(self, x, t: int, x_self_cond = None):
+         b, *_, device = *x.shape, self.device
+         batched_times = torch.full((b,), t, device = device, dtype = torch.long)
+         model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
+         noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
+         pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
+         return pred_img, x_start
+
+     @torch.inference_mode()
+     def p_sample_loop(self, shape, return_all_timesteps = False):
+         batch, device = shape[0], self.device
+
+         img = torch.randn(shape, device = device)
+         imgs = [img]
+
+         x_start = None
+
+         for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+             self_cond = x_start if self.self_condition else None
+             img, x_start = self.p_sample(img, t, self_cond)
+             imgs.append(img)
+
+         ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
+
+         ret = self.unnormalize(ret)
+         return ret
+
+     @torch.inference_mode()
+     def ddim_sample(self, shape, return_all_timesteps = False):
+         batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
+
+         times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
+         times = list(reversed(times.int().tolist()))
+         time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
+
+         img = torch.randn(shape, device = device)
+         imgs = [img]
+
+         x_start = None
+
+         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+             self_cond = x_start if self.self_condition else None
+             pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
+
+             if time_next < 0:
+                 img = x_start
+                 imgs.append(img)
+                 continue
+
+             alpha = self.alphas_cumprod[time]
+             alpha_next = self.alphas_cumprod[time_next]
+
+             sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+             c = (1 - alpha_next - sigma ** 2).sqrt()
+
+             noise = torch.randn_like(img)
+
+             img = x_start * alpha_next.sqrt() + \
+                   c * pred_noise + \
+                   sigma * noise
+
+             imgs.append(img)
+
+         ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
+
+         ret = self.unnormalize(ret)
+         return ret
+
+     @torch.inference_mode()
+     def sample(self, batch_size = 16, return_all_timesteps = False):
+         image_size, channels = self.image_size, self.channels
+         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
+         return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
+
+     @torch.inference_mode()
+     def interpolate(self, x1, x2, t = None, lam = 0.5):
+         b, *_, device = *x1.shape, x1.device
+         t = default(t, self.num_timesteps - 1)
+
+         assert x1.shape == x2.shape
+
+         t_batched = torch.full((b,), t, device = device)
+         xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
+
+         img = (1 - lam) * xt1 + lam * xt2
+
+         x_start = None
+
+         for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
+             self_cond = x_start if self.self_condition else None
+             img, x_start = self.p_sample(img, i, self_cond)
+
+         return img
+
+     @autocast(enabled = False)
+     def q_sample(self, x_start, t, noise = None):
+         noise = default(noise, lambda: torch.randn_like(x_start))
+
+         return (
+             extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+         )
+
+     def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
+         b, c, h, w = x_start.shape
+
+         noise = default(noise, lambda: torch.randn_like(x_start))
+
+         # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
+
+         offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
+
+         if offset_noise_strength > 0.:
+             offset_noise = torch.randn(x_start.shape[:2], device = self.device)
+             noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
+
+         # noise sample
+
+         x = self.q_sample(x_start = x_start, t = t, noise = noise)
+
+         # if doing self-conditioning, 50% of the time, predict x_start from current set of times
+         # and condition with unet with that
+         # this technique will slow down training by 25%, but seems to lower FID significantly
+
+         x_self_cond = None
+         if self.self_condition and random() < 0.5:
+             with torch.no_grad():
+                 x_self_cond = self.model_predictions(x, t).pred_x_start
+                 x_self_cond.detach_()
+
+         # predict and take gradient step
+
+         model_out = self.model(x, t, x_self_cond)
+
+         if self.objective == 'pred_noise':
+             target = noise
+         elif self.objective == 'pred_x0':
+             target = x_start
+         elif self.objective == 'pred_v':
+             v = self.predict_v(x_start, t, noise)
+             target = v
+         else:
+             raise ValueError(f'unknown objective {self.objective}')
+
+         loss = F.mse_loss(model_out, target, reduction = 'none')
+         loss = reduce(loss, 'b ... -> b', 'mean')
+
+         loss = loss * extract(self.loss_weight, t, loss.shape)
+         return loss.mean()
+
+     def forward(self, img, *args, **kwargs):
+         b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
+         assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
+
+         img = self.normalize(img)
+         return self.p_losses(img, t, *args, **kwargs)
+
+
+ # # Resnet Model
+
+ # In[19]:
+
+
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
+     return nn.Conv2d(
+         in_channels, out_channels, kernel_size,
+         padding=(kernel_size//2), bias=bias)
+
+
+ # In[20]:
+
+
+ class Swish(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ # In[21]:
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_ch):
+         super().__init__()
+         self.group_norm = nn.GroupNorm(32, in_ch)
+         self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+         self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+         self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+         self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         h = self.group_norm(x)
+         q = self.proj_q(h)
+         k = self.proj_k(h)
+         v = self.proj_v(h)
+
+         q = q.permute(0, 2, 3, 1).view(B, H * W, C)
+         k = k.view(B, C, H * W)
+         w = torch.bmm(q, k) * (int(C) ** (-0.5))
+         assert list(w.shape) == [B, H * W, H * W]
+         w = F.softmax(w, dim=-1)
+
+         v = v.permute(0, 2, 3, 1).view(B, H * W, C)
+         h = torch.bmm(w, v)
+         assert list(h.shape) == [B, H * W, C]
+         h = h.view(B, H, W, C).permute(0, 3, 1, 2)
+         h = self.proj(h)
+
+         return x + h
+
+
+ # In[22]:
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
+         super().__init__()
+         self.block1 = nn.Sequential(
+             nn.GroupNorm(32, in_ch),
+             Swish(),
+             nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
+         )
+         self.temb_proj = nn.Sequential(
+             Swish(),
+             nn.Linear(tdim, out_ch),
+         )
+         self.block2 = nn.Sequential(
+             nn.GroupNorm(32, out_ch),
+             Swish(),
+             nn.Dropout(dropout),
+             nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
+         )
+         if in_ch != out_ch:
+             self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
+         else:
+             self.shortcut = nn.Identity()
+         if attn:
+             self.attn = AttnBlock(out_ch)
+         else:
+             self.attn = nn.Identity()
+
+     def forward(self, x, temb):
+         h = self.block1(x)
+         h += self.temb_proj(temb)[:, :, None, None]
+         h = self.block2(h)
+
+         h = h + self.shortcut(x)
+         h = self.attn(h)
+         return h
+
+
+ # In[23]:
+
+
+ class EDSR(nn.Module):
+     # Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
+
+     def __init__(self,
+                  resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'],
+                  n_feats=128,
+                  t_dim=256,
+                  dropout=0.1,
+                  channels=1,
+                  out_dim=1,
+                  self_condition = False,
+                  learned_sinusoidal_cond=False,
+                  random_fourier_features=False,
+                  learned_sinusoidal_dim=16,
+                  sinusoidal_pos_emb_theta=10000,
+                  conv=default_conv):
+         super(EDSR, self).__init__()
+
+         self.resblocks = resblocks
+         self.n_feats = n_feats
+         self.t_dim = t_dim
+         self.dropout = dropout
+         self.channels = channels
+         self.out_dim = out_dim
+         self.self_condition = self_condition
+         self.kernel_size = 3
+
+         # define time embedding
+         if learned_sinusoidal_cond:
+             sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
+             fourier_dim = learned_sinusoidal_dim + 1
+         else:
+             sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
+             fourier_dim = self.n_feats
+
+         self.time_mlp = nn.Sequential(
+             sinu_pos_emb,
+             nn.Linear(fourier_dim, self.t_dim),
+             nn.GELU(),
+             nn.Linear(self.t_dim, self.t_dim)
+         )
+
+         # define head module
+         self.head = conv(self.channels, self.n_feats, self.kernel_size)
+
+         # define body module
+         self.body = nn.ModuleList()
+         for block in resblocks:
+             if block == "ResBlock":
+                 self.body.append(
+                     ResBlock(in_ch=self.n_feats,
+                              out_ch=self.n_feats,
+                              tdim=self.t_dim,
+                              dropout=self.dropout,
+                              attn=False))
+             elif block == "AttnBlock":
+                 self.body.append(
+                     ResBlock(in_ch=self.n_feats,
+                              out_ch=self.n_feats,
+                              tdim=self.t_dim,
+                              dropout=self.dropout,
+                              attn=True))
+             else:
+                 raise NotImplementedError("Model currently doesn't support this kind of block!")
+         self.body.append(conv(self.n_feats, self.n_feats, self.kernel_size))
+
+         # define tail module
+         self.tail = conv(self.n_feats, self.out_dim, self.kernel_size)
+
+
+     def forward(self, x, t, cond=None):
+         t = self.time_mlp(t)
+
+         x = self.head(x)
+
+         res = x
+         for block in self.body:
+             if isinstance(block, ResBlock):
+                 res = block(res, t)
+             else:
+                 res = block(res)
+         res += x
+
+         x = self.tail(res)
+
+         return x
+
+
+ # # Train
+
+ # In[24]:
+
+
+ # output dir
+ save_path = 'resnet/model'
+ log_path = 'resnet/log'
+
+ if not os.path.exists(log_path):
+     os.mkdir(log_path)
+ if not os.path.exists(save_path):
+     os.mkdir(save_path)
+
+
+ # In[25]:
+
+
+ # setup logging
+
+ # Setup logging to file
+ logging.basicConfig(
+     filename=os.path.join(log_path, 'info.log'),
+     filemode="w",
+     level=logging.DEBUG,
+     format= '[%(asctime)s] %(levelname)s - %(message)s',
+     datefmt='%H:%M:%S',
+     force=True
+ )
+
+
+ # Stop PIL from printing to file
+ pil_logger = logging.getLogger('PIL')
+ pil_logger.setLevel(logging.INFO)
+
+ # write and print at the same time
+ console = logging.StreamHandler()
+ console.setLevel(logging.INFO)
+ logging.getLogger().addHandler(console)
+
+ logger = logging.getLogger('Diffusion_Resnet')
+
+
+ # In[26]:
+
+
+ # define model
+ model = EDSR(
+     resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock',
+                'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock',],
+     n_feats=256,
+     t_dim=512,
+     dropout=0.1,
+     channels=1, # MNIST
+     out_dim=1, # MNIST
+     learned_sinusoidal_cond=False,
+     random_fourier_features=False,
+     learned_sinusoidal_dim=16,
+     sinusoidal_pos_emb_theta=10000,)
+
+ diffusion_model = GaussianDiffusion(
+     model,
+     image_size=28, # MNIST
+     timesteps=1000,
+     sampling_timesteps=None,
+     objective ='pred_noise',
+     beta_schedule ='linear',
+     schedule_fn_kwargs=dict(),
+     ddim_sampling_eta= 0.,
+     auto_normalize = True,
+     offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
+     min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
+     min_snr_gamma = 5)
+
+
+ # In[27]:
+
+
+ # define dataset
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     # v2.Normalize((0.1307,), (0.3081,)), # https://stackoverflow.com/questions/70892017/normalize-mnist-in-pytorch
+ ])
+
+ train_dataset = torchvision.datasets.MNIST(root='.', train=True,
+                                            download=True, transform=transform)
+ # test_dataset = torchvision.datasets.MNIST(root='.', train=True,
+ #                                           download=True, transform=transform)
+
+ train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
+ # test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
+
+
+ # In[28]:
+
+
+ # define optimizer
+ train_lr = 1e-4
+ adam_betas = (0.9, 0.99)
+
+ optimizer = Adam(diffusion_model.parameters(),
+                  lr=train_lr,
+                  betas=adam_betas)
+
+
+ # In[29]:
+
+
+ # device
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+ # In[30]:
+
+
+ # trainer
+ max_epoches = 50
+ iter_print = 100
+ iter_sample = 1000
+ save_each = 1
+
+ diffusion_model = diffusion_model.to(device)
+
+ last_trained_path = None
+ if last_trained_path:
+     data = torch.load(os.path.join(last_trained_path))
+     diffusion_model.load_state_dict(data['model'])
+     optimizer.load_state_dict(data['opt'])
+     count = data['step']
+     start_epoch = data['epoch']
+     log_loss = data['loss']
+ else:
+     count = 0
+     start_epoch = 1
+     log_loss = []
+
+ for epoch in range(start_epoch, max_epoches+1):
+     diffusion_model.train()
+     for img, _ in train_dataloader:
+         img = img.to(device)
+
+         loss = diffusion_model(img)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         if count % iter_print == 0 or count == 0:
+             logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(
+                 epoch,
+                 max_epoches,
+                 count,
+                 loss.mean().item(),
+                 train_lr,
+             ))
+
+         log_loss.append(loss.mean().item())
+
+         loss = None
+
+         count += 1
+
+         if count % iter_sample == 0:
+             diffusion_model.eval()
+
+             sample_imgs = diffusion_model.sample(batch_size=16)
+
+             utils.save_image(sample_imgs,
+                              os.path.join(log_path, f"iter_{count}.png"),
+                              nrow = int(math.sqrt(16)))
+
+
+     if epoch % save_each == 0:
+         data = {
+             'model': diffusion_model.state_dict(),
+             'opt': optimizer.state_dict(),
+             'step': count,
+             'epoch': epoch,
+             'loss': log_loss,
+         }
+
+         torch.save(data, os.path.join(save_path, f"epoch_{epoch}.pth"))
+
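For reference, a minimal sketch of how a checkpoint saved by the training script above might be reloaded for sampling; this is an assumption based on note.txt's description of the "model" directory, not code from this commit. EDSR and GaussianDiffusion are the classes defined in DDPM_ResNet.py, the constructor arguments must match the ones used at training time, and the checkpoint path resnet/model/epoch_30.pth is inferred from note.txt:

import torch
from torchvision import utils

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the network with the same hyperparameters as the training script.
model = EDSR(
    resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock',
               'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'],
    n_feats=256, t_dim=512, dropout=0.1, channels=1, out_dim=1)
diffusion_model = GaussianDiffusion(model, image_size=28, timesteps=1000).to(device)

# Load the epoch-30 checkpoint described in note.txt (path assumed).
data = torch.load('resnet/model/epoch_30.pth', map_location=device)
diffusion_model.load_state_dict(data['model'])
diffusion_model.eval()

# sample() already runs under torch.inference_mode() and returns images in [0, 1].
imgs = diffusion_model.sample(batch_size=16)
utils.save_image(imgs, 'samples.png', nrow=4)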
resnet/DDPM_ResNet_sample.py ADDED
@@ -0,0 +1,856 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # # Library
5
+
6
+ # In[1]:
7
+
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.cuda.amp import autocast
13
+
14
+ import torchvision
15
+ from torchvision.transforms import transforms
16
+ from torch.utils.data import DataLoader
17
+
18
+ from torch.optim import Adam
19
+
20
+ from einops import rearrange, reduce, repeat
21
+ import math
22
+ from random import random
23
+
24
+ from collections import namedtuple
25
+ from functools import partial
26
+ from tqdm.auto import tqdm
27
+ import logging
28
+ import os
29
+
30
+ from PIL import Image
31
+ from torchvision import utils
32
+
33
+
34
+ # # Helper
35
+
36
+ # ### Constant
37
+
38
+ # In[2]:
39
+
40
+
41
+ ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
42
+
43
+
44
+ # ### Functions
45
+
46
+ # In[3]:
47
+
48
+
49
+ def exists(x):
50
+ return x is not None
51
+
52
+ def default(val, d):
53
+ if exists(val):
54
+ return val
55
+ return d() if callable(d) else d
56
+
57
+
58
+ # In[4]:
59
+
60
+
61
+ def cast_tuple(t, length = 1):
62
+ if isinstance(t, tuple):
63
+ return t
64
+ return ((t,) * length)
65
+
66
+
67
+ # In[5]:
68
+
69
+
70
+ def divisible_by(numer, denom):
71
+ return (numer % denom) == 0
72
+
73
+
74
+ # In[6]:
75
+
76
+
77
+ def identity(t, *args, **kwargs):
78
+ return t
79
+
80
+
81
+ # In[7]:
82
+
83
+
84
+ def cycle(dl):
85
+ while True:
86
+ for data in dl:
87
+ yield data
88
+
89
+
90
+ # In[8]:
91
+
92
+
93
+ def has_int_squareroot(num):
94
+ return (math.sqrt(num) ** 2) == num
95
+
96
+
97
+ # In[9]:
98
+
99
+
100
+ def num_to_groups(num, divisor):
101
+ groups = num // divisor
102
+ remainder = num % divisor
103
+ arr = [divisor] * groups
104
+ if remainder > 0:
105
+ arr.append(remainder)
106
+ return arr
107
+
108
+
109
+ # In[10]:
110
+
111
+
112
+ def convert_image_to_fn(img_type, image):
113
+ if image.mode != img_type:
114
+ return image.convert(img_type)
115
+ return image
116
+
117
+
118
+ # In[11]:
119
+
120
+
121
+ def extract(a, t, x_shape):
122
+ b, *_ = t.shape
123
+ out = a.gather(-1, t)
124
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
125
+
126
+
127
+ # ### Normalization Functions
128
+
129
+ # In[12]:
130
+
131
+
132
+ def normalize_to_neg_one_to_one(img):
133
+ return img * 2 - 1
134
+
135
+ def unnormalize_to_zero_to_one(t):
136
+ return (t + 1) * 0.5
137
+
138
+
139
+ # ### Sinusoidal positional embeds
140
+
141
+ # In[13]:
142
+
143
+
144
+ class SinusoidalPosEmb(nn.Module):
145
+ def __init__(self, dim, theta = 10000):
146
+ super().__init__()
147
+ self.dim = dim
148
+ self.theta = theta
149
+
150
+ def forward(self, x):
151
+ device = x.device
152
+ half_dim = self.dim // 2
153
+ emb = math.log(self.theta) / (half_dim - 1)
154
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
155
+ emb = x[:, None] * emb[None, :]
156
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
157
+ return emb
158
+
159
+
160
+ # In[14]:
161
+
162
+
163
+ class RandomOrLearnedSinusoidalPosEmb(nn.Module):
164
+ """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
165
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
166
+
167
+ def __init__(self, dim, is_random = False):
168
+ super().__init__()
169
+ assert divisible_by(dim, 2)
170
+ half_dim = dim // 2
171
+ self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
172
+
173
+ def forward(self, x):
174
+ x = rearrange(x, 'b -> b 1')
175
+ freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
176
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
177
+ fouriered = torch.cat((x, fouriered), dim = -1)
178
+ return fouriered
179
+
180
+
181
+ # ### Schedule
182
+
183
+ # In[15]:
184
+
185
+
186
+ def linear_beta_schedule(timesteps):
187
+ """
188
+ linear schedule, proposed in original ddpm paper
189
+ """
190
+ scale = 1000 / timesteps
191
+ beta_start = scale * 0.0001
192
+ beta_end = scale * 0.02
193
+ return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
194
+
195
+
196
+ # In[16]:
197
+
198
+
199
+ def cosine_beta_schedule(timesteps, s = 0.008):
200
+ """
201
+ cosine schedule
202
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
203
+ """
204
+ steps = timesteps + 1
205
+ t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
206
+ alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
207
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
208
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
209
+ return torch.clip(betas, 0, 0.999)
210
+
211
+
212
+ # In[17]:
213
+
214
+
215
+ def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
216
+ """
217
+ sigmoid schedule
218
+ proposed in https://arxiv.org/abs/2212.11972 - Figure 8
219
+ better for images > 64x64, when used during training
220
+ """
221
+ steps = timesteps + 1
222
+ t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
223
+ v_start = torch.tensor(start / tau).sigmoid()
224
+ v_end = torch.tensor(end / tau).sigmoid()
225
+ alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
226
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
227
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
228
+ return torch.clip(betas, 0, 0.999)
229
+
230
+
231
+ # # Diffusion model
232
+
233
+ # In[18]:
234
+
235
+
236
+ class GaussianDiffusion(nn.Module):
237
+ # Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
238
+
239
+ def __init__(
240
+ self,
241
+ model,
242
+ *,
243
+ image_size,
244
+ timesteps = 1000,
245
+ sampling_timesteps = None,
246
+ objective = 'pred_noise',
247
+ beta_schedule = 'linear',
248
+ schedule_fn_kwargs = dict(),
249
+ ddim_sampling_eta = 0.,
250
+ auto_normalize = True,
251
+ offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
252
+ min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
253
+ min_snr_gamma = 5
254
+ ):
255
+ super().__init__()
256
+ assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
257
+ assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
258
+
259
+ self.model = model
260
+
261
+ self.channels = self.model.channels
262
+ self.self_condition = self.model.self_condition
263
+
264
+ self.image_size = image_size
265
+
266
+ self.objective = objective
267
+
268
+ assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
269
+
270
+ if beta_schedule == 'linear':
271
+ beta_schedule_fn = linear_beta_schedule
272
+ elif beta_schedule == 'cosine':
273
+ beta_schedule_fn = cosine_beta_schedule
274
+ elif beta_schedule == 'sigmoid':
275
+ beta_schedule_fn = sigmoid_beta_schedule
276
+ else:
277
+ raise ValueError(f'unknown beta schedule {beta_schedule}')
278
+
279
+ betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
280
+
281
+ alphas = 1. - betas
282
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
283
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
284
+
285
+ timesteps, = betas.shape
286
+ self.num_timesteps = int(timesteps)
287
+
288
+ # sampling related parameters
289
+
290
+ self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
291
+
292
+ assert self.sampling_timesteps <= timesteps
293
+ self.is_ddim_sampling = self.sampling_timesteps < timesteps
294
+ self.ddim_sampling_eta = ddim_sampling_eta
295
+
296
+ # helper function to register buffer from float64 to float32
297
+
298
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
299
+
300
+ register_buffer('betas', betas)
301
+ register_buffer('alphas_cumprod', alphas_cumprod)
302
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
303
+
304
+ # calculations for diffusion q(x_t | x_{t-1}) and others
305
+
306
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
307
+ register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
308
+ register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
309
+ register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
310
+ register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
311
+
312
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
313
+
314
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
315
+
316
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
317
+
318
+ register_buffer('posterior_variance', posterior_variance)
319
+
320
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
321
+
322
+ register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
323
+ register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
324
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
325
+
326
+ # offset noise strength - in blogpost, they claimed 0.1 was ideal
327
+
328
+ self.offset_noise_strength = offset_noise_strength
329
+
330
+ # derive loss weight
331
+ # snr - signal noise ratio
332
+
333
+ snr = alphas_cumprod / (1 - alphas_cumprod)
334
+
335
+ # https://arxiv.org/abs/2303.09556
336
+
337
+ maybe_clipped_snr = snr.clone()
338
+ if min_snr_loss_weight:
339
+ maybe_clipped_snr.clamp_(max = min_snr_gamma)
340
+
341
+ if objective == 'pred_noise':
342
+ register_buffer('loss_weight', maybe_clipped_snr / snr)
343
+ elif objective == 'pred_x0':
344
+ register_buffer('loss_weight', maybe_clipped_snr)
345
+ elif objective == 'pred_v':
346
+ register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
347
+
348
+ # auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
349
+
350
+ self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
351
+ self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
352
+
353
+ @property
354
+ def device(self):
355
+ return self.betas.device
356
+
357
+ def predict_start_from_noise(self, x_t, t, noise):
358
+ return (
359
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
360
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
361
+ )
362
+
363
+ def predict_noise_from_start(self, x_t, t, x0):
364
+ return (
365
+ (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
366
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
367
+ )
368
+
369
+ def predict_v(self, x_start, t, noise):
370
+ return (
371
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
372
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
373
+ )
374
+
375
+ def predict_start_from_v(self, x_t, t, v):
376
+ return (
377
+ extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
378
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
379
+ )
380
+
381
+ def q_posterior(self, x_start, x_t, t):
382
+ posterior_mean = (
383
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
384
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
385
+ )
386
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
387
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
388
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
389
+
390
+ def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
391
+ model_output = self.model(x, t, x_self_cond)
392
+ maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
393
+
394
+ if self.objective == 'pred_noise':
395
+ pred_noise = model_output
396
+ x_start = self.predict_start_from_noise(x, t, pred_noise)
397
+ x_start = maybe_clip(x_start)
398
+
399
+ if clip_x_start and rederive_pred_noise:
400
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
401
+
402
+ elif self.objective == 'pred_x0':
403
+ x_start = model_output
404
+ x_start = maybe_clip(x_start)
405
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
406
+
407
+ elif self.objective == 'pred_v':
408
+ v = model_output
409
+ x_start = self.predict_start_from_v(x, t, v)
410
+ x_start = maybe_clip(x_start)
411
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
412
+
413
+ return ModelPrediction(pred_noise, x_start)
414
+
415
+ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
416
+ preds = self.model_predictions(x, t, x_self_cond)
417
+ x_start = preds.pred_x_start
418
+
419
+ if clip_denoised:
420
+ x_start.clamp_(-1., 1.)
421
+
422
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
423
+ return model_mean, posterior_variance, posterior_log_variance, x_start
424
+
425
+ @torch.inference_mode()
426
+ def p_sample(self, x, t: int, x_self_cond = None):
427
+ b, *_, device = *x.shape, self.device
428
+ batched_times = torch.full((b,), t, device = device, dtype = torch.long)
429
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
430
+ noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
431
+ pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
432
+ return pred_img, x_start
433
+
434
+ @torch.inference_mode()
435
+ def p_sample_loop(self, shape, return_all_timesteps = False):
436
+ batch, device = shape[0], self.device
437
+
438
+ img = torch.randn(shape, device = device)
439
+ imgs = [img]
440
+
441
+ x_start = None
442
+
443
+ for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
444
+ self_cond = x_start if self.self_condition else None
445
+ img, x_start = self.p_sample(img, t, self_cond)
446
+ imgs.append(img)
447
+
448
+ ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
449
+
450
+ ret = self.unnormalize(ret)
451
+ return ret
452
+
453
+ @torch.inference_mode()
454
+ def ddim_sample(self, shape, return_all_timesteps = False):
455
+ batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
456
+
457
+ times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
458
+ times = list(reversed(times.int().tolist()))
459
+ time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
460
+
461
+ img = torch.randn(shape, device = device)
462
+ imgs = [img]
463
+
464
+ x_start = None
465
+
466
+ for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
467
+ time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
468
+ self_cond = x_start if self.self_condition else None
469
+ pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
470
+
471
+ if time_next < 0:
472
+ img = x_start
473
+ imgs.append(img)
474
+ continue
475
+
476
+ alpha = self.alphas_cumprod[time]
477
+ alpha_next = self.alphas_cumprod[time_next]
478
+
479
+ sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
480
+ c = (1 - alpha_next - sigma ** 2).sqrt()
481
+
482
+ noise = torch.randn_like(img)
483
+
484
+ img = x_start * alpha_next.sqrt() + \
485
+ c * pred_noise + \
486
+ sigma * noise
487
+
488
+ imgs.append(img)
489
+
490
+ ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
491
+
492
+ ret = self.unnormalize(ret)
493
+ return ret
494
+
495
+ @torch.inference_mode()
496
+ def sample(self, batch_size = 16, return_all_timesteps = False):
497
+ image_size, channels = self.image_size, self.channels
498
+ sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
499
+ return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
500
+
501
+ @torch.inference_mode()
502
+ def interpolate(self, x1, x2, t = None, lam = 0.5):
503
+ b, *_, device = *x1.shape, x1.device
504
+ t = default(t, self.num_timesteps - 1)
505
+
506
+ assert x1.shape == x2.shape
507
+
508
+ t_batched = torch.full((b,), t, device = device)
509
+ xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
510
+
511
+ img = (1 - lam) * xt1 + lam * xt2
512
+
513
+ x_start = None
514
+
515
+ for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
516
+ self_cond = x_start if self.self_condition else None
517
+ img, x_start = self.p_sample(img, i, self_cond)
518
+
519
+ return img
520
+
521
+ @autocast(enabled = False)
522
+ def q_sample(self, x_start, t, noise = None):
523
+ noise = default(noise, lambda: torch.randn_like(x_start))
524
+
525
+ return (
526
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
527
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
528
+ )
529
+
530
+ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
531
+ b, c, h, w = x_start.shape
532
+
533
+ noise = default(noise, lambda: torch.randn_like(x_start))
534
+
535
+ # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
536
+
537
+ offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
538
+
539
+ if offset_noise_strength > 0.:
540
+ offset_noise = torch.randn(x_start.shape[:2], device = self.device)
541
+ noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
542
+
543
+ # noise sample
544
+
545
+ x = self.q_sample(x_start = x_start, t = t, noise = noise)
546
+
547
+ # if doing self-conditioning, 50% of the time, predict x_start from current set of times
548
+ # and condition with unet with that
549
+ # this technique will slow down training by 25%, but seems to lower FID significantly
550
+
551
+ x_self_cond = None
552
+ if self.self_condition and random() < 0.5:
553
+ with torch.no_grad():
554
+ x_self_cond = self.model_predictions(x, t).pred_x_start
555
+ x_self_cond.detach_()
556
+
557
+ # predict and take gradient step
558
+
559
+ model_out = self.model(x, t, x_self_cond)
560
+
561
+ if self.objective == 'pred_noise':
562
+ target = noise
563
+ elif self.objective == 'pred_x0':
564
+ target = x_start
565
+ elif self.objective == 'pred_v':
566
+ v = self.predict_v(x_start, t, noise)
567
+ target = v
568
+ else:
569
+ raise ValueError(f'unknown objective {self.objective}')
570
+
571
+ loss = F.mse_loss(model_out, target, reduction = 'none')
572
+ loss = reduce(loss, 'b ... -> b', 'mean')
573
+
574
+ loss = loss * extract(self.loss_weight, t, loss.shape)
575
+ return loss.mean()
576
+
577
+ def forward(self, img, *args, **kwargs):
578
+ b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
579
+ assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
580
+ t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
581
+
582
+ img = self.normalize(img)
583
+ return self.p_losses(img, t, *args, **kwargs)
584
+
585
+
586
+ # # Resnet Model
587
+
588
+ # In[19]:
589
+
590
+
591
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
592
+ return nn.Conv2d(
593
+ in_channels, out_channels, kernel_size,
594
+ padding=(kernel_size//2), bias=bias)
595
+
596
+
597
+ # In[20]:
598
+
599
+
600
+ class Swish(nn.Module):
601
+ def forward(self, x):
602
+ return x * torch.sigmoid(x)
603
+
604
+
605
+ # In[21]:
606
+
607
+
608
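+ # single-head self-attention over the H*W spatial positions: 1x1 convs produce
+ # q, k and v, the attention weights are softmax(q k^T / sqrt(C)), and the output
+ # is projected and added back to the input as a residual.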
+ class AttnBlock(nn.Module):
609
+ def __init__(self, in_ch):
610
+ super().__init__()
611
+ self.group_norm = nn.GroupNorm(32, in_ch)
612
+ self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
613
+ self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
614
+ self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
615
+ self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
616
+
617
+ def forward(self, x):
618
+ B, C, H, W = x.shape
619
+ h = self.group_norm(x)
620
+ q = self.proj_q(h)
621
+ k = self.proj_k(h)
622
+ v = self.proj_v(h)
623
+
624
+ q = q.permute(0, 2, 3, 1).view(B, H * W, C)
625
+ k = k.view(B, C, H * W)
626
+ w = torch.bmm(q, k) * (int(C) ** (-0.5))
627
+ assert list(w.shape) == [B, H * W, H * W]
628
+ w = F.softmax(w, dim=-1)
629
+
630
+ v = v.permute(0, 2, 3, 1).view(B, H * W, C)
631
+ h = torch.bmm(w, v)
632
+ assert list(h.shape) == [B, H * W, C]
633
+ h = h.view(B, H, W, C).permute(0, 3, 1, 2)
634
+ h = self.proj(h)
635
+
636
+ return x + h
637
+
638
+
639
+ # In[22]:
640
+
641
+
642
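+ # pre-activation residual block (GroupNorm -> Swish -> Conv). The time embedding is
+ # projected to out_ch and added as a per-channel bias between the two convolutions;
+ # an optional AttnBlock is applied after the residual sum.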
+ class ResBlock(nn.Module):
643
+ def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
644
+ super().__init__()
645
+ self.block1 = nn.Sequential(
646
+ nn.GroupNorm(32, in_ch),
647
+ Swish(),
648
+ nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
649
+ )
650
+ self.temb_proj = nn.Sequential(
651
+ Swish(),
652
+ nn.Linear(tdim, out_ch),
653
+ )
654
+ self.block2 = nn.Sequential(
655
+ nn.GroupNorm(32, out_ch),
656
+ Swish(),
657
+ nn.Dropout(dropout),
658
+ nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
659
+ )
660
+ if in_ch != out_ch:
661
+ self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
662
+ else:
663
+ self.shortcut = nn.Identity()
664
+ if attn:
665
+ self.attn = AttnBlock(out_ch)
666
+ else:
667
+ self.attn = nn.Identity()
668
+
669
+ def forward(self, x, temb):
670
+ h = self.block1(x)
671
+ h += self.temb_proj(temb)[:, :, None, None]
672
+ h = self.block2(h)
673
+
674
+ h = h + self.shortcut(x)
675
+ h = self.attn(h)
676
+ return h
677
+
678
+
679
+ # In[23]:
680
+
681
+
682
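+ # EDSR-style backbone: unlike a U-Net there is no down/upsampling, just a conv head,
+ # a full-resolution body of residual (optionally attentive) blocks with the time
+ # embedding injected into every ResBlock, and a conv tail. The cond argument of
+ # forward() is accepted so GaussianDiffusion can pass its self-conditioning input,
+ # but this model ignores it.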
+ class EDSR(nn.Module):
683
+ # Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
684
+
685
+ def __init__(self,
686
+ resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'],
687
+ n_feats=128,
688
+ t_dim=256,
689
+ dropout=0.1,
690
+ channels=1,
691
+ out_dim=1,
692
+ self_condition = False,
693
+ learned_sinusoidal_cond=False,
694
+ random_fourier_features=False,
695
+ learned_sinusoidal_dim=16,
696
+ sinusoidal_pos_emb_theta=10000,
697
+ conv=default_conv):
698
+ super().__init__()
699
+
700
+ self.resblocks = resblocks
701
+ self.n_feats = n_feats
702
+ self.t_dim = t_dim
703
+ self.dropout = dropout
704
+ self.channels = channels
705
+ self.out_dim = out_dim
706
+ self.self_condition = self_condition
707
+ self.kernel_size = 3
708
+
709
+ # define time embedding
710
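+ # the integer timestep is mapped to sin/cos features at geometrically spaced
+ # frequencies (theta = 10000, as in Transformer positional encodings), then passed
+ # through a two-layer MLP to produce the t_dim-dimensional embedding.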
+ if learned_sinusoidal_cond:
711
+ sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
712
+ fourier_dim = learned_sinusoidal_dim + 1
713
+ else:
714
+ sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
715
+ fourier_dim = self.n_feats
716
+
717
+ self.time_mlp = nn.Sequential(
718
+ sinu_pos_emb,
719
+ nn.Linear(fourier_dim, self.t_dim),
720
+ nn.GELU(),
721
+ nn.Linear(self.t_dim, self.t_dim)
722
+ )
723
+
724
+ # define head module
725
+ self.head = conv(self.channels, self.n_feats, self.kernel_size)
726
+
727
+ # define body module
728
+ self.body = nn.ModuleList()
729
+ for block in resblocks:
730
+ if block == "ResBlock":
731
+ self.body.append(
732
+ ResBlock(in_ch=self.n_feats,
733
+ out_ch=self.n_feats,
734
+ tdim=self.t_dim,
735
+ dropout=self.dropout,
736
+ attn=False))
737
+ elif block == "AttnBlock":
738
+ self.body.append(
739
+ ResBlock(in_ch=self.n_feats,
740
+ out_ch=self.n_feats,
741
+ tdim=self.t_dim,
742
+ dropout=self.dropout,
743
+ attn=True))
744
+ else:
745
+ raise NotImplementedError("Model currently doesn't support this kind of block!")
746
+ self.body.append(conv(self.n_feats, self.n_feats, self.kernel_size))
747
+
748
+ # define tail module
749
+ self.tail = conv(self.n_feats, self.out_dim, self.kernel_size)
750
+
751
+
752
+ def forward(self, x, t, cond=None):
753
+ t = self.time_mlp(t)
754
+
755
+ x = self.head(x)
756
+
757
+ res = x
758
+ for block in self.body:
759
+ if isinstance(block, ResBlock):
760
+ res = block(res, t)
761
+ else:
762
+ res = block(res)
763
+ res += x
764
+
765
+ x = self.tail(res)
766
+
767
+ return x
768
+
769
+
770
+ # # Sample
771
+
772
+ # In[24]:
773
+
774
+
775
+ # In[25]:
776
+
777
+
778
+
779
+ # In[26]:
780
+
781
+
782
+ # define model
783
+ model = EDSR(
784
+ resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock',
785
+ 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock',],
786
+ n_feats=256,
787
+ t_dim=512,
788
+ dropout=0.1,
789
+ channels=1, # MNIST
790
+ out_dim=1, # MNIST
791
+ learned_sinusoidal_cond=False,
792
+ random_fourier_features=False,
793
+ learned_sinusoidal_dim=16,
794
+ sinusoidal_pos_emb_theta=10000,)
795
+
796
+ diffusion_model = GaussianDiffusion(
797
+ model,
798
+ image_size=28, # MNIST
799
+ timesteps=1000,
800
+ sampling_timesteps=None,
801
+ objective ='pred_noise',
802
+ beta_schedule ='linear',
803
+ schedule_fn_kwargs=dict(),
804
+ ddim_sampling_eta= 0.,
805
+ auto_normalize = True,
806
+ offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
807
+ min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
808
+ min_snr_gamma = 5)
809
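+ # with sampling_timesteps=None the diffusion model keeps the full 1000 steps, so
+ # is_ddim_sampling is False and sample() below runs the ancestral p_sample_loop
+ # rather than the accelerated DDIM path.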
+
810
+
811
+ # In[27]:
812
+
813
+
814
+ # In[28]:
815
+
816
+
817
+
818
+ # In[29]:
819
+
820
+
821
+ # device
822
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
823
+
824
+
825
+ # In[30]:
826
+
827
+
828
+ # trainer settings (leftovers from the training script; unused when only sampling)
829
+ max_epoches = 50
830
+ iter_print = 100
831
+ iter_sample = 1000
832
+ save_each = 1
833
+
834
+ diffusion_model = diffusion_model.to(device)
835
+
836
+ last_trained_path = 'resnet/model/epoch_30.pth'  # forward slashes work on Windows and POSIX alike
837
+ diffusion_model.load_state_dict(torch.load(last_trained_path, map_location=device)['model'])
838
+
839
+ sample_path = 'resnet/sample2'
840
+
841
+ if not os.path.exists(sample_path):
842
+ os.makedirs(sample_path)  # makedirs also creates missing parent directories
843
+
844
+ num_sample = 500
845
+ sample_batch = 16
846
+ count = 0
847
+
848
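+ # round num_sample up to the next multiple of sample_batch so every batch is full:
+ # 500 % 16 = 4, hence num_sample becomes 500 + (16 - 4) = 512.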
+ if num_sample % sample_batch != 0:
849
+ num_sample = num_sample + (sample_batch - (num_sample % sample_batch))
850
+
851
+ for batch in range(num_sample//sample_batch):
852
+ imgs = diffusion_model.sample(batch_size=sample_batch, return_all_timesteps=False)
853
+ for i in range(imgs.size(0)):
854
+ torchvision.utils.save_image(imgs[i], os.path.join(sample_path, f'{count}.png'))
855
+ count += 1
856
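+ # result: 512 / 16 = 32 batches are generated and the images are written as
+ # 0.png ... 511.png inside resnet/sample2.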
+
resnet/log/info.log ADDED
@@ -0,0 +1,585 @@
1
+ [02:18:34] INFO - Epoch 1/50, Iter 0: Loss = 1.23480224609375, lr = 0.0001
2
+ [02:18:59] INFO - Epoch 1/50, Iter 100: Loss = 0.08826638758182526, lr = 0.0001
3
+ [02:19:38] INFO - Epoch 1/50, Iter 200: Loss = 0.0545695461332798, lr = 0.0001
4
+ [02:20:35] INFO - Epoch 1/50, Iter 300: Loss = 0.06626827269792557, lr = 0.0001
5
+ [02:21:29] INFO - Epoch 1/50, Iter 400: Loss = 0.07271286845207214, lr = 0.0001
6
+ [02:22:24] INFO - Epoch 1/50, Iter 500: Loss = 0.027932994067668915, lr = 0.0001
7
+ [02:23:19] INFO - Epoch 1/50, Iter 600: Loss = 0.037907857447862625, lr = 0.0001
8
+ [02:24:14] INFO - Epoch 1/50, Iter 700: Loss = 0.03283434733748436, lr = 0.0001
9
+ [02:25:10] INFO - Epoch 1/50, Iter 800: Loss = 0.0401763841509819, lr = 0.0001
10
+ [02:26:05] INFO - Epoch 1/50, Iter 900: Loss = 0.02380681410431862, lr = 0.0001
11
+ [02:28:36] INFO - Epoch 1/50, Iter 1000: Loss = 0.03142669051885605, lr = 0.0001
12
+ [02:29:30] INFO - Epoch 1/50, Iter 1100: Loss = 0.021915458142757416, lr = 0.0001
13
+ [02:30:25] INFO - Epoch 1/50, Iter 1200: Loss = 0.03710126131772995, lr = 0.0001
14
+ [02:31:19] INFO - Epoch 1/50, Iter 1300: Loss = 0.017894160002470016, lr = 0.0001
15
+ [02:32:14] INFO - Epoch 1/50, Iter 1400: Loss = 0.032229095697402954, lr = 0.0001
16
+ [02:33:08] INFO - Epoch 1/50, Iter 1500: Loss = 0.022246181964874268, lr = 0.0001
17
+ [02:34:04] INFO - Epoch 1/50, Iter 1600: Loss = 0.02387898601591587, lr = 0.0001
18
+ [02:34:58] INFO - Epoch 1/50, Iter 1700: Loss = 0.033216990530490875, lr = 0.0001
19
+ [02:35:53] INFO - Epoch 1/50, Iter 1800: Loss = 0.03182423859834671, lr = 0.0001
20
+ [02:36:48] INFO - Epoch 2/50, Iter 1900: Loss = 0.027017910033464432, lr = 0.0001
21
+ [02:39:20] INFO - Epoch 2/50, Iter 2000: Loss = 0.03848206251859665, lr = 0.0001
22
+ [02:40:14] INFO - Epoch 2/50, Iter 2100: Loss = 0.02826070785522461, lr = 0.0001
23
+ [02:41:09] INFO - Epoch 2/50, Iter 2200: Loss = 0.03657548129558563, lr = 0.0001
24
+ [02:42:04] INFO - Epoch 2/50, Iter 2300: Loss = 0.03236750513315201, lr = 0.0001
25
+ [02:42:59] INFO - Epoch 2/50, Iter 2400: Loss = 0.02394908107817173, lr = 0.0001
26
+ [02:43:54] INFO - Epoch 2/50, Iter 2500: Loss = 0.028264183551073074, lr = 0.0001
27
+ [02:44:48] INFO - Epoch 2/50, Iter 2600: Loss = 0.034485459327697754, lr = 0.0001
28
+ [02:45:44] INFO - Epoch 2/50, Iter 2700: Loss = 0.02295440435409546, lr = 0.0001
29
+ [02:46:38] INFO - Epoch 2/50, Iter 2800: Loss = 0.03146759420633316, lr = 0.0001
30
+ [02:47:33] INFO - Epoch 2/50, Iter 2900: Loss = 0.022224590182304382, lr = 0.0001
31
+ [02:50:06] INFO - Epoch 2/50, Iter 3000: Loss = 0.03717297315597534, lr = 0.0001
32
+ [02:51:00] INFO - Epoch 2/50, Iter 3100: Loss = 0.023568114265799522, lr = 0.0001
33
+ [02:51:55] INFO - Epoch 2/50, Iter 3200: Loss = 0.01752738654613495, lr = 0.0001
34
+ [02:52:49] INFO - Epoch 2/50, Iter 3300: Loss = 0.024697361513972282, lr = 0.0001
35
+ [02:53:44] INFO - Epoch 2/50, Iter 3400: Loss = 0.027621649205684662, lr = 0.0001
36
+ [02:54:39] INFO - Epoch 2/50, Iter 3500: Loss = 0.03197108209133148, lr = 0.0001
37
+ [02:55:34] INFO - Epoch 2/50, Iter 3600: Loss = 0.034603990614414215, lr = 0.0001
38
+ [02:56:29] INFO - Epoch 2/50, Iter 3700: Loss = 0.024781333282589912, lr = 0.0001
39
+ [02:57:25] INFO - Epoch 3/50, Iter 3800: Loss = 0.029720211401581764, lr = 0.0001
40
+ [02:58:20] INFO - Epoch 3/50, Iter 3900: Loss = 0.050903625786304474, lr = 0.0001
41
+ [03:00:52] INFO - Epoch 3/50, Iter 4000: Loss = 0.022276397794485092, lr = 0.0001
42
+ [03:01:48] INFO - Epoch 3/50, Iter 4100: Loss = 0.02051287144422531, lr = 0.0001
43
+ [03:02:42] INFO - Epoch 3/50, Iter 4200: Loss = 0.02138718217611313, lr = 0.0001
44
+ [03:03:37] INFO - Epoch 3/50, Iter 4300: Loss = 0.013692906126379967, lr = 0.0001
45
+ [03:04:31] INFO - Epoch 3/50, Iter 4400: Loss = 0.026416348293423653, lr = 0.0001
46
+ [03:05:26] INFO - Epoch 3/50, Iter 4500: Loss = 0.02263474650681019, lr = 0.0001
47
+ [03:06:21] INFO - Epoch 3/50, Iter 4600: Loss = 0.02561156451702118, lr = 0.0001
48
+ [03:07:15] INFO - Epoch 3/50, Iter 4700: Loss = 0.022007182240486145, lr = 0.0001
49
+ [03:08:11] INFO - Epoch 3/50, Iter 4800: Loss = 0.024828705936670303, lr = 0.0001
50
+ [03:09:05] INFO - Epoch 3/50, Iter 4900: Loss = 0.0277644544839859, lr = 0.0001
51
+ [03:11:37] INFO - Epoch 3/50, Iter 5000: Loss = 0.022669199854135513, lr = 0.0001
52
+ [03:12:31] INFO - Epoch 3/50, Iter 5100: Loss = 0.03488582372665405, lr = 0.0001
53
+ [03:13:27] INFO - Epoch 3/50, Iter 5200: Loss = 0.033707648515701294, lr = 0.0001
54
+ [03:14:21] INFO - Epoch 3/50, Iter 5300: Loss = 0.034617647528648376, lr = 0.0001
55
+ [03:15:16] INFO - Epoch 3/50, Iter 5400: Loss = 0.015979502350091934, lr = 0.0001
56
+ [03:16:11] INFO - Epoch 3/50, Iter 5500: Loss = 0.017885394394397736, lr = 0.0001
57
+ [03:17:05] INFO - Epoch 3/50, Iter 5600: Loss = 0.013684597797691822, lr = 0.0001
58
+ [03:18:01] INFO - Epoch 4/50, Iter 5700: Loss = 0.018592171370983124, lr = 0.0001
59
+ [03:18:56] INFO - Epoch 4/50, Iter 5800: Loss = 0.019852526485919952, lr = 0.0001
60
+ [03:19:52] INFO - Epoch 4/50, Iter 5900: Loss = 0.014810988679528236, lr = 0.0001
61
+ [03:22:25] INFO - Epoch 4/50, Iter 6000: Loss = 0.022946510463953018, lr = 0.0001
62
+ [03:23:19] INFO - Epoch 4/50, Iter 6100: Loss = 0.022477544844150543, lr = 0.0001
63
+ [03:24:14] INFO - Epoch 4/50, Iter 6200: Loss = 0.021514300256967545, lr = 0.0001
64
+ [03:25:09] INFO - Epoch 4/50, Iter 6300: Loss = 0.017631331458687782, lr = 0.0001
65
+ [03:26:03] INFO - Epoch 4/50, Iter 6400: Loss = 0.02970929630100727, lr = 0.0001
66
+ [03:26:58] INFO - Epoch 4/50, Iter 6500: Loss = 0.02417093515396118, lr = 0.0001
67
+ [03:27:52] INFO - Epoch 4/50, Iter 6600: Loss = 0.028470398858189583, lr = 0.0001
68
+ [03:28:48] INFO - Epoch 4/50, Iter 6700: Loss = 0.02186693623661995, lr = 0.0001
69
+ [03:29:42] INFO - Epoch 4/50, Iter 6800: Loss = 0.021022997796535492, lr = 0.0001
70
+ [03:30:37] INFO - Epoch 4/50, Iter 6900: Loss = 0.02663368172943592, lr = 0.0001
71
+ [03:33:08] INFO - Epoch 4/50, Iter 7000: Loss = 0.0202815942466259, lr = 0.0001
72
+ [03:34:04] INFO - Epoch 4/50, Iter 7100: Loss = 0.017694229260087013, lr = 0.0001
73
+ [03:34:58] INFO - Epoch 4/50, Iter 7200: Loss = 0.03217596560716629, lr = 0.0001
74
+ [03:35:53] INFO - Epoch 4/50, Iter 7300: Loss = 0.027110356837511063, lr = 0.0001
75
+ [03:36:47] INFO - Epoch 4/50, Iter 7400: Loss = 0.02598414570093155, lr = 0.0001
76
+ [03:37:42] INFO - Epoch 5/50, Iter 7500: Loss = 0.031232168897986412, lr = 0.0001
77
+ [03:38:37] INFO - Epoch 5/50, Iter 7600: Loss = 0.0394064262509346, lr = 0.0001
78
+ [03:39:33] INFO - Epoch 5/50, Iter 7700: Loss = 0.017326747998595238, lr = 0.0001
79
+ [03:40:28] INFO - Epoch 5/50, Iter 7800: Loss = 0.029284335672855377, lr = 0.0001
80
+ [03:41:24] INFO - Epoch 5/50, Iter 7900: Loss = 0.01525358110666275, lr = 0.0001
81
+ [03:43:56] INFO - Epoch 5/50, Iter 8000: Loss = 0.019312670454382896, lr = 0.0001
82
+ [03:44:51] INFO - Epoch 5/50, Iter 8100: Loss = 0.022943828254938126, lr = 0.0001
83
+ [03:45:46] INFO - Epoch 5/50, Iter 8200: Loss = 0.014834869652986526, lr = 0.0001
84
+ [03:46:40] INFO - Epoch 5/50, Iter 8300: Loss = 0.013647425919771194, lr = 0.0001
85
+ [03:47:35] INFO - Epoch 5/50, Iter 8400: Loss = 0.012797506526112556, lr = 0.0001
86
+ [03:48:29] INFO - Epoch 5/50, Iter 8500: Loss = 0.028487099334597588, lr = 0.0001
87
+ [03:49:25] INFO - Epoch 5/50, Iter 8600: Loss = 0.0326717309653759, lr = 0.0001
88
+ [03:50:20] INFO - Epoch 5/50, Iter 8700: Loss = 0.018652349710464478, lr = 0.0001
89
+ [03:51:14] INFO - Epoch 5/50, Iter 8800: Loss = 0.026515061035752296, lr = 0.0001
90
+ [03:52:09] INFO - Epoch 5/50, Iter 8900: Loss = 0.02715548872947693, lr = 0.0001
91
+ [03:54:40] INFO - Epoch 5/50, Iter 9000: Loss = 0.025071512907743454, lr = 0.0001
92
+ [03:55:34] INFO - Epoch 5/50, Iter 9100: Loss = 0.02286442741751671, lr = 0.0001
93
+ [03:56:29] INFO - Epoch 5/50, Iter 9200: Loss = 0.024927817285060883, lr = 0.0001
94
+ [03:57:25] INFO - Epoch 5/50, Iter 9300: Loss = 0.02016012743115425, lr = 0.0001
95
+ [03:58:20] INFO - Epoch 6/50, Iter 9400: Loss = 0.016080211848020554, lr = 0.0001
96
+ [03:59:16] INFO - Epoch 6/50, Iter 9500: Loss = 0.03025580570101738, lr = 0.0001
97
+ [04:00:10] INFO - Epoch 6/50, Iter 9600: Loss = 0.034918542951345444, lr = 0.0001
98
+ [04:01:06] INFO - Epoch 6/50, Iter 9700: Loss = 0.024010658264160156, lr = 0.0001
99
+ [04:02:01] INFO - Epoch 6/50, Iter 9800: Loss = 0.024768657982349396, lr = 0.0001
100
+ [04:02:57] INFO - Epoch 6/50, Iter 9900: Loss = 0.02912471443414688, lr = 0.0001
101
+ [04:05:28] INFO - Epoch 6/50, Iter 10000: Loss = 0.013935514725744724, lr = 0.0001
102
+ [04:06:24] INFO - Epoch 6/50, Iter 10100: Loss = 0.024383660405874252, lr = 0.0001
103
+ [04:07:19] INFO - Epoch 6/50, Iter 10200: Loss = 0.02626352570950985, lr = 0.0001
104
+ [04:08:13] INFO - Epoch 6/50, Iter 10300: Loss = 0.02143704704940319, lr = 0.0001
105
+ [04:09:08] INFO - Epoch 6/50, Iter 10400: Loss = 0.022659476846456528, lr = 0.0001
106
+ [04:10:02] INFO - Epoch 6/50, Iter 10500: Loss = 0.020370323210954666, lr = 0.0001
107
+ [04:10:57] INFO - Epoch 6/50, Iter 10600: Loss = 0.02100287191569805, lr = 0.0001
108
+ [04:11:52] INFO - Epoch 6/50, Iter 10700: Loss = 0.01825377717614174, lr = 0.0001
109
+ [04:12:46] INFO - Epoch 6/50, Iter 10800: Loss = 0.026205215603113174, lr = 0.0001
110
+ [04:13:42] INFO - Epoch 6/50, Iter 10900: Loss = 0.03552094101905823, lr = 0.0001
111
+ [04:16:13] INFO - Epoch 6/50, Iter 11000: Loss = 0.016668759286403656, lr = 0.0001
112
+ [04:17:07] INFO - Epoch 6/50, Iter 11100: Loss = 0.018555857241153717, lr = 0.0001
113
+ [04:18:02] INFO - Epoch 6/50, Iter 11200: Loss = 0.01698373258113861, lr = 0.0001
114
+ [04:18:58] INFO - Epoch 7/50, Iter 11300: Loss = 0.021595774218440056, lr = 0.0001
115
+ [04:19:53] INFO - Epoch 7/50, Iter 11400: Loss = 0.029402505606412888, lr = 0.0001
116
+ [04:20:49] INFO - Epoch 7/50, Iter 11500: Loss = 0.017380326986312866, lr = 0.0001
117
+ [04:21:44] INFO - Epoch 7/50, Iter 11600: Loss = 0.022462423890829086, lr = 0.0001
118
+ [04:22:40] INFO - Epoch 7/50, Iter 11700: Loss = 0.024359144270420074, lr = 0.0001
119
+ [04:23:35] INFO - Epoch 7/50, Iter 11800: Loss = 0.025637302547693253, lr = 0.0001
120
+ [04:24:31] INFO - Epoch 7/50, Iter 11900: Loss = 0.027863897383213043, lr = 0.0001
121
+ [04:27:02] INFO - Epoch 7/50, Iter 12000: Loss = 0.025426337495446205, lr = 0.0001
122
+ [04:27:58] INFO - Epoch 7/50, Iter 12100: Loss = 0.03268758952617645, lr = 0.0001
123
+ [04:28:52] INFO - Epoch 7/50, Iter 12200: Loss = 0.016548998653888702, lr = 0.0001
124
+ [04:29:47] INFO - Epoch 7/50, Iter 12300: Loss = 0.02512863650918007, lr = 0.0001
125
+ [04:30:41] INFO - Epoch 7/50, Iter 12400: Loss = 0.0246925987303257, lr = 0.0001
126
+ [04:31:36] INFO - Epoch 7/50, Iter 12500: Loss = 0.018600817769765854, lr = 0.0001
127
+ [04:32:31] INFO - Epoch 7/50, Iter 12600: Loss = 0.01979782059788704, lr = 0.0001
128
+ [04:33:25] INFO - Epoch 7/50, Iter 12700: Loss = 0.021152257919311523, lr = 0.0001
129
+ [04:34:21] INFO - Epoch 7/50, Iter 12800: Loss = 0.02903410792350769, lr = 0.0001
130
+ [04:35:16] INFO - Epoch 7/50, Iter 12900: Loss = 0.03196360170841217, lr = 0.0001
131
+ [04:37:46] INFO - Epoch 7/50, Iter 13000: Loss = 0.019338594749569893, lr = 0.0001
132
+ [04:38:41] INFO - Epoch 7/50, Iter 13100: Loss = 0.027051424607634544, lr = 0.0001
133
+ [04:39:37] INFO - Epoch 8/50, Iter 13200: Loss = 0.0238485224545002, lr = 0.0001
134
+ [04:40:32] INFO - Epoch 8/50, Iter 13300: Loss = 0.02585774101316929, lr = 0.0001
135
+ [04:41:28] INFO - Epoch 8/50, Iter 13400: Loss = 0.01865781843662262, lr = 0.0001
136
+ [04:42:23] INFO - Epoch 8/50, Iter 13500: Loss = 0.03003603406250477, lr = 0.0001
137
+ [04:43:17] INFO - Epoch 8/50, Iter 13600: Loss = 0.02756107971072197, lr = 0.0001
138
+ [04:44:13] INFO - Epoch 8/50, Iter 13700: Loss = 0.018252156674861908, lr = 0.0001
139
+ [04:45:09] INFO - Epoch 8/50, Iter 13800: Loss = 0.0232943594455719, lr = 0.0001
140
+ [04:46:04] INFO - Epoch 8/50, Iter 13900: Loss = 0.03505060076713562, lr = 0.0001
141
+ [04:48:36] INFO - Epoch 8/50, Iter 14000: Loss = 0.015609338879585266, lr = 0.0001
142
+ [04:49:31] INFO - Epoch 8/50, Iter 14100: Loss = 0.024727653712034225, lr = 0.0001
143
+ [04:50:25] INFO - Epoch 8/50, Iter 14200: Loss = 0.01343458704650402, lr = 0.0001
144
+ [04:51:20] INFO - Epoch 8/50, Iter 14300: Loss = 0.02276020497083664, lr = 0.0001
145
+ [04:52:15] INFO - Epoch 8/50, Iter 14400: Loss = 0.030666548758745193, lr = 0.0001
146
+ [04:53:09] INFO - Epoch 8/50, Iter 14500: Loss = 0.027710841968655586, lr = 0.0001
147
+ [04:54:04] INFO - Epoch 8/50, Iter 14600: Loss = 0.02813234180212021, lr = 0.0001
148
+ [04:54:58] INFO - Epoch 8/50, Iter 14700: Loss = 0.0154835544526577, lr = 0.0001
149
+ [04:55:54] INFO - Epoch 8/50, Iter 14800: Loss = 0.0330531969666481, lr = 0.0001
150
+ [04:56:49] INFO - Epoch 8/50, Iter 14900: Loss = 0.02566523663699627, lr = 0.0001
151
+ [04:59:20] INFO - Epoch 9/50, Iter 15000: Loss = 0.03587709367275238, lr = 0.0001
152
+ [05:00:16] INFO - Epoch 9/50, Iter 15100: Loss = 0.011817749589681625, lr = 0.0001
153
+ [05:01:11] INFO - Epoch 9/50, Iter 15200: Loss = 0.019955918192863464, lr = 0.0001
154
+ [05:02:06] INFO - Epoch 9/50, Iter 15300: Loss = 0.01926155760884285, lr = 0.0001
155
+ [05:03:01] INFO - Epoch 9/50, Iter 15400: Loss = 0.025760915130376816, lr = 0.0001
156
+ [05:03:57] INFO - Epoch 9/50, Iter 15500: Loss = 0.023390091955661774, lr = 0.0001
157
+ [05:04:52] INFO - Epoch 9/50, Iter 15600: Loss = 0.03382980450987816, lr = 0.0001
158
+ [05:05:48] INFO - Epoch 9/50, Iter 15700: Loss = 0.019686255604028702, lr = 0.0001
159
+ [05:06:43] INFO - Epoch 9/50, Iter 15800: Loss = 0.017689798027276993, lr = 0.0001
160
+ [05:07:39] INFO - Epoch 9/50, Iter 15900: Loss = 0.02643013373017311, lr = 0.0001
161
+ [05:10:10] INFO - Epoch 9/50, Iter 16000: Loss = 0.01975519210100174, lr = 0.0001
162
+ [05:11:05] INFO - Epoch 9/50, Iter 16100: Loss = 0.02566615864634514, lr = 0.0001
163
+ [05:12:01] INFO - Epoch 9/50, Iter 16200: Loss = 0.023744797334074974, lr = 0.0001
164
+ [05:12:54] INFO - Epoch 9/50, Iter 16300: Loss = 0.029149867594242096, lr = 0.0001
165
+ [05:13:50] INFO - Epoch 9/50, Iter 16400: Loss = 0.024619584903120995, lr = 0.0001
166
+ [05:14:44] INFO - Epoch 9/50, Iter 16500: Loss = 0.017802121117711067, lr = 0.0001
167
+ [05:15:39] INFO - Epoch 9/50, Iter 16600: Loss = 0.030343685299158096, lr = 0.0001
168
+ [05:16:34] INFO - Epoch 9/50, Iter 16700: Loss = 0.028128691017627716, lr = 0.0001
169
+ [05:17:28] INFO - Epoch 9/50, Iter 16800: Loss = 0.013130296021699905, lr = 0.0001
170
+ [05:18:23] INFO - Epoch 10/50, Iter 16900: Loss = 0.015325885266065598, lr = 0.0001
171
+ [05:20:55] INFO - Epoch 10/50, Iter 17000: Loss = 0.02369626611471176, lr = 0.0001
172
+ [05:21:50] INFO - Epoch 10/50, Iter 17100: Loss = 0.03911880403757095, lr = 0.0001
173
+ [05:22:44] INFO - Epoch 10/50, Iter 17200: Loss = 0.019555510953068733, lr = 0.0001
174
+ [05:23:40] INFO - Epoch 10/50, Iter 17300: Loss = 0.026994436979293823, lr = 0.0001
175
+ [05:24:35] INFO - Epoch 10/50, Iter 17400: Loss = 0.014918794855475426, lr = 0.0001
176
+ [05:25:29] INFO - Epoch 10/50, Iter 17500: Loss = 0.015928588807582855, lr = 0.0001
177
+ [05:26:24] INFO - Epoch 10/50, Iter 17600: Loss = 0.026111863553524017, lr = 0.0001
178
+ [05:27:19] INFO - Epoch 10/50, Iter 17700: Loss = 0.023383410647511482, lr = 0.0001
179
+ [05:28:13] INFO - Epoch 10/50, Iter 17800: Loss = 0.022820118814706802, lr = 0.0001
180
+ [05:29:08] INFO - Epoch 10/50, Iter 17900: Loss = 0.016951140016317368, lr = 0.0001
181
+ [05:31:40] INFO - Epoch 10/50, Iter 18000: Loss = 0.021106135100126266, lr = 0.0001
182
+ [05:32:34] INFO - Epoch 10/50, Iter 18100: Loss = 0.015148286707699299, lr = 0.0001
183
+ [05:33:29] INFO - Epoch 10/50, Iter 18200: Loss = 0.019842375069856644, lr = 0.0001
184
+ [05:34:24] INFO - Epoch 10/50, Iter 18300: Loss = 0.022392811253666878, lr = 0.0001
185
+ [05:35:18] INFO - Epoch 10/50, Iter 18400: Loss = 0.02733965963125229, lr = 0.0001
186
+ [05:36:13] INFO - Epoch 10/50, Iter 18500: Loss = 0.02087550237774849, lr = 0.0001
187
+ [05:37:08] INFO - Epoch 10/50, Iter 18600: Loss = 0.02672572433948517, lr = 0.0001
188
+ [05:38:03] INFO - Epoch 10/50, Iter 18700: Loss = 0.02076902985572815, lr = 0.0001
189
+ [05:38:59] INFO - Epoch 11/50, Iter 18800: Loss = 0.0208309106528759, lr = 0.0001
190
+ [05:39:54] INFO - Epoch 11/50, Iter 18900: Loss = 0.01603943109512329, lr = 0.0001
191
+ [05:42:26] INFO - Epoch 11/50, Iter 19000: Loss = 0.018146460875868797, lr = 0.0001
192
+ [05:43:20] INFO - Epoch 11/50, Iter 19100: Loss = 0.03146671503782272, lr = 0.0001
193
+ [05:44:15] INFO - Epoch 11/50, Iter 19200: Loss = 0.017263440415263176, lr = 0.0001
194
+ [05:45:10] INFO - Epoch 11/50, Iter 19300: Loss = 0.021944427862763405, lr = 0.0001
195
+ [05:46:04] INFO - Epoch 11/50, Iter 19400: Loss = 0.017847534269094467, lr = 0.0001
196
+ [05:46:59] INFO - Epoch 11/50, Iter 19500: Loss = 0.021428382024168968, lr = 0.0001
197
+ [05:47:55] INFO - Epoch 11/50, Iter 19600: Loss = 0.020893530920147896, lr = 0.0001
198
+ [05:48:49] INFO - Epoch 11/50, Iter 19700: Loss = 0.02261212095618248, lr = 0.0001
199
+ [05:49:44] INFO - Epoch 11/50, Iter 19800: Loss = 0.017424296587705612, lr = 0.0001
200
+ [05:50:39] INFO - Epoch 11/50, Iter 19900: Loss = 0.025077205151319504, lr = 0.0001
201
+ [05:53:10] INFO - Epoch 11/50, Iter 20000: Loss = 0.029975447803735733, lr = 0.0001
202
+ [05:54:04] INFO - Epoch 11/50, Iter 20100: Loss = 0.019458118826150894, lr = 0.0001
203
+ [05:54:59] INFO - Epoch 11/50, Iter 20200: Loss = 0.0232146717607975, lr = 0.0001
204
+ [05:55:53] INFO - Epoch 11/50, Iter 20300: Loss = 0.02360851876437664, lr = 0.0001
205
+ [05:56:48] INFO - Epoch 11/50, Iter 20400: Loss = 0.024858074262738228, lr = 0.0001
206
+ [05:57:44] INFO - Epoch 11/50, Iter 20500: Loss = 0.044195011258125305, lr = 0.0001
207
+ [05:58:38] INFO - Epoch 11/50, Iter 20600: Loss = 0.018540263175964355, lr = 0.0001
208
+ [05:59:33] INFO - Epoch 12/50, Iter 20700: Loss = 0.021583855152130127, lr = 0.0001
209
+ [06:00:29] INFO - Epoch 12/50, Iter 20800: Loss = 0.02421833947300911, lr = 0.0001
210
+ [06:01:24] INFO - Epoch 12/50, Iter 20900: Loss = 0.026535984128713608, lr = 0.0001
211
+ [06:03:57] INFO - Epoch 12/50, Iter 21000: Loss = 0.01781940832734108, lr = 0.0001
212
+ [06:04:51] INFO - Epoch 12/50, Iter 21100: Loss = 0.023128725588321686, lr = 0.0001
213
+ [06:05:46] INFO - Epoch 12/50, Iter 21200: Loss = 0.02317957766354084, lr = 0.0001
214
+ [06:06:40] INFO - Epoch 12/50, Iter 21300: Loss = 0.016345253214240074, lr = 0.0001
215
+ [06:07:36] INFO - Epoch 12/50, Iter 21400: Loss = 0.02558373659849167, lr = 0.0001
216
+ [06:08:31] INFO - Epoch 12/50, Iter 21500: Loss = 0.026121504604816437, lr = 0.0001
217
+ [06:09:25] INFO - Epoch 12/50, Iter 21600: Loss = 0.022759977728128433, lr = 0.0001
218
+ [06:10:20] INFO - Epoch 12/50, Iter 21700: Loss = 0.026271792128682137, lr = 0.0001
219
+ [06:11:14] INFO - Epoch 12/50, Iter 21800: Loss = 0.027187272906303406, lr = 0.0001
220
+ [06:12:09] INFO - Epoch 12/50, Iter 21900: Loss = 0.023094702512025833, lr = 0.0001
221
+ [06:14:40] INFO - Epoch 12/50, Iter 22000: Loss = 0.016669970005750656, lr = 0.0001
222
+ [06:15:36] INFO - Epoch 12/50, Iter 22100: Loss = 0.026704635471105576, lr = 0.0001
223
+ [06:16:30] INFO - Epoch 12/50, Iter 22200: Loss = 0.02754068374633789, lr = 0.0001
224
+ [06:17:25] INFO - Epoch 12/50, Iter 22300: Loss = 0.025661129504442215, lr = 0.0001
225
+ [06:18:19] INFO - Epoch 12/50, Iter 22400: Loss = 0.025509830564260483, lr = 0.0001
226
+ [06:19:14] INFO - Epoch 13/50, Iter 22500: Loss = 0.025348283350467682, lr = 0.0001
227
+ [06:20:10] INFO - Epoch 13/50, Iter 22600: Loss = 0.026772376149892807, lr = 0.0001
228
+ [06:21:05] INFO - Epoch 13/50, Iter 22700: Loss = 0.01741105318069458, lr = 0.0001
229
+ [06:22:01] INFO - Epoch 13/50, Iter 22800: Loss = 0.02285039983689785, lr = 0.0001
230
+ [06:22:56] INFO - Epoch 13/50, Iter 22900: Loss = 0.027282923460006714, lr = 0.0001
231
+ [06:25:28] INFO - Epoch 13/50, Iter 23000: Loss = 0.012414131313562393, lr = 0.0001
232
+ [06:26:23] INFO - Epoch 13/50, Iter 23100: Loss = 0.019650613889098167, lr = 0.0001
233
+ [06:27:18] INFO - Epoch 13/50, Iter 23200: Loss = 0.02651660516858101, lr = 0.0001
234
+ [06:28:12] INFO - Epoch 13/50, Iter 23300: Loss = 0.026138421148061752, lr = 0.0001
235
+ [06:29:07] INFO - Epoch 13/50, Iter 23400: Loss = 0.018627706915140152, lr = 0.0001
236
+ [06:30:01] INFO - Epoch 13/50, Iter 23500: Loss = 0.028943434357643127, lr = 0.0001
237
+ [06:30:57] INFO - Epoch 13/50, Iter 23600: Loss = 0.01649133488535881, lr = 0.0001
238
+ [06:31:51] INFO - Epoch 13/50, Iter 23700: Loss = 0.01378883421421051, lr = 0.0001
239
+ [06:32:46] INFO - Epoch 13/50, Iter 23800: Loss = 0.02124626189470291, lr = 0.0001
240
+ [06:33:41] INFO - Epoch 13/50, Iter 23900: Loss = 0.017396021634340286, lr = 0.0001
241
+ [06:36:12] INFO - Epoch 13/50, Iter 24000: Loss = 0.01732352189719677, lr = 0.0001
242
+ [06:37:06] INFO - Epoch 13/50, Iter 24100: Loss = 0.014166954904794693, lr = 0.0001
243
+ [06:38:02] INFO - Epoch 13/50, Iter 24200: Loss = 0.02176068350672722, lr = 0.0001
244
+ [06:38:57] INFO - Epoch 13/50, Iter 24300: Loss = 0.019656777381896973, lr = 0.0001
245
+ [06:39:51] INFO - Epoch 14/50, Iter 24400: Loss = 0.02193061262369156, lr = 0.0001
246
+ [06:40:46] INFO - Epoch 14/50, Iter 24500: Loss = 0.018643012270331383, lr = 0.0001
247
+ [06:41:42] INFO - Epoch 14/50, Iter 24600: Loss = 0.012337702326476574, lr = 0.0001
248
+ [06:42:37] INFO - Epoch 14/50, Iter 24700: Loss = 0.016973398625850677, lr = 0.0001
249
+ [06:43:33] INFO - Epoch 14/50, Iter 24800: Loss = 0.025368668138980865, lr = 0.0001
250
+ [06:44:28] INFO - Epoch 14/50, Iter 24900: Loss = 0.02520618960261345, lr = 0.0001
251
+ [06:47:00] INFO - Epoch 14/50, Iter 25000: Loss = 0.01767529547214508, lr = 0.0001
252
+ [06:47:55] INFO - Epoch 14/50, Iter 25100: Loss = 0.021381141617894173, lr = 0.0001
253
+ [06:48:49] INFO - Epoch 14/50, Iter 25200: Loss = 0.021116536110639572, lr = 0.0001
254
+ [06:49:44] INFO - Epoch 14/50, Iter 25300: Loss = 0.017928242683410645, lr = 0.0001
255
+ [06:50:39] INFO - Epoch 14/50, Iter 25400: Loss = 0.021284624934196472, lr = 0.0001
256
+ [06:51:33] INFO - Epoch 14/50, Iter 25500: Loss = 0.013009730726480484, lr = 0.0001
257
+ [06:52:28] INFO - Epoch 14/50, Iter 25600: Loss = 0.018284976482391357, lr = 0.0001
258
+ [06:53:22] INFO - Epoch 14/50, Iter 25700: Loss = 0.019000139087438583, lr = 0.0001
259
+ [06:54:18] INFO - Epoch 14/50, Iter 25800: Loss = 0.01757623441517353, lr = 0.0001
260
+ [06:55:12] INFO - Epoch 14/50, Iter 25900: Loss = 0.019956346601247787, lr = 0.0001
261
+ [06:57:43] INFO - Epoch 14/50, Iter 26000: Loss = 0.025380369275808334, lr = 0.0001
262
+ [06:58:38] INFO - Epoch 14/50, Iter 26100: Loss = 0.02575628086924553, lr = 0.0001
263
+ [06:59:32] INFO - Epoch 14/50, Iter 26200: Loss = 0.02441999688744545, lr = 0.0001
264
+ [07:00:28] INFO - Epoch 15/50, Iter 26300: Loss = 0.015507195144891739, lr = 0.0001
265
+ [07:01:23] INFO - Epoch 15/50, Iter 26400: Loss = 0.018518857657909393, lr = 0.0001
266
+ [07:02:18] INFO - Epoch 15/50, Iter 26500: Loss = 0.0218639075756073, lr = 0.0001
267
+ [07:03:14] INFO - Epoch 15/50, Iter 26600: Loss = 0.01484048180282116, lr = 0.0001
268
+ [07:04:09] INFO - Epoch 15/50, Iter 26700: Loss = 0.020309407263994217, lr = 0.0001
269
+ [07:05:05] INFO - Epoch 15/50, Iter 26800: Loss = 0.02281174622476101, lr = 0.0001
270
+ [07:06:00] INFO - Epoch 15/50, Iter 26900: Loss = 0.022504042834043503, lr = 0.0001
271
+ [07:08:32] INFO - Epoch 15/50, Iter 27000: Loss = 0.016440019011497498, lr = 0.0001
272
+ [07:09:27] INFO - Epoch 15/50, Iter 27100: Loss = 0.015486285090446472, lr = 0.0001
273
+ [07:10:21] INFO - Epoch 15/50, Iter 27200: Loss = 0.01972173899412155, lr = 0.0001
274
+ [07:11:16] INFO - Epoch 15/50, Iter 27300: Loss = 0.018617577850818634, lr = 0.0001
275
+ [07:12:11] INFO - Epoch 15/50, Iter 27400: Loss = 0.02082516998052597, lr = 0.0001
276
+ [07:13:05] INFO - Epoch 15/50, Iter 27500: Loss = 0.01791219785809517, lr = 0.0001
277
+ [07:14:00] INFO - Epoch 15/50, Iter 27600: Loss = 0.02241137996315956, lr = 0.0001
278
+ [07:14:54] INFO - Epoch 15/50, Iter 27700: Loss = 0.020293384790420532, lr = 0.0001
279
+ [07:15:49] INFO - Epoch 15/50, Iter 27800: Loss = 0.029861796647310257, lr = 0.0001
280
+ [07:16:44] INFO - Epoch 15/50, Iter 27900: Loss = 0.02275857701897621, lr = 0.0001
281
+ [07:19:16] INFO - Epoch 15/50, Iter 28000: Loss = 0.015355780720710754, lr = 0.0001
282
+ [07:20:10] INFO - Epoch 15/50, Iter 28100: Loss = 0.019503731280565262, lr = 0.0001
283
+ [07:21:05] INFO - Epoch 16/50, Iter 28200: Loss = 0.024656936526298523, lr = 0.0001
284
+ [07:22:01] INFO - Epoch 16/50, Iter 28300: Loss = 0.016661042347550392, lr = 0.0001
285
+ [07:22:56] INFO - Epoch 16/50, Iter 28400: Loss = 0.017921866849064827, lr = 0.0001
286
+ [07:23:52] INFO - Epoch 16/50, Iter 28500: Loss = 0.020502446219325066, lr = 0.0001
287
+ [07:24:47] INFO - Epoch 16/50, Iter 28600: Loss = 0.012834666296839714, lr = 0.0001
288
+ [07:25:42] INFO - Epoch 16/50, Iter 28700: Loss = 0.017596762627363205, lr = 0.0001
289
+ [07:26:37] INFO - Epoch 16/50, Iter 28800: Loss = 0.02352038025856018, lr = 0.0001
290
+ [07:27:32] INFO - Epoch 16/50, Iter 28900: Loss = 0.022114895284175873, lr = 0.0001
291
+ [07:30:05] INFO - Epoch 16/50, Iter 29000: Loss = 0.018584776669740677, lr = 0.0001
292
+ [07:30:59] INFO - Epoch 16/50, Iter 29100: Loss = 0.021322712302207947, lr = 0.0001
293
+ [07:31:54] INFO - Epoch 16/50, Iter 29200: Loss = 0.01889413595199585, lr = 0.0001
294
+ [07:32:48] INFO - Epoch 16/50, Iter 29300: Loss = 0.027229465544223785, lr = 0.0001
295
+ [07:33:43] INFO - Epoch 16/50, Iter 29400: Loss = 0.026700954884290695, lr = 0.0001
296
+ [07:34:37] INFO - Epoch 16/50, Iter 29500: Loss = 0.026901915669441223, lr = 0.0001
297
+ [07:35:32] INFO - Epoch 16/50, Iter 29600: Loss = 0.0257167499512434, lr = 0.0001
298
+ [07:36:27] INFO - Epoch 16/50, Iter 29700: Loss = 0.023790445178747177, lr = 0.0001
299
+ [07:37:21] INFO - Epoch 16/50, Iter 29800: Loss = 0.010275682434439659, lr = 0.0001
300
+ [07:38:17] INFO - Epoch 16/50, Iter 29900: Loss = 0.024285804480314255, lr = 0.0001
301
+ [07:40:48] INFO - Epoch 17/50, Iter 30000: Loss = 0.01686658337712288, lr = 0.0001
302
+ [07:41:44] INFO - Epoch 17/50, Iter 30100: Loss = 0.019942965358495712, lr = 0.0001
303
+ [07:42:39] INFO - Epoch 17/50, Iter 30200: Loss = 0.032290853559970856, lr = 0.0001
304
+ [07:43:35] INFO - Epoch 17/50, Iter 30300: Loss = 0.02391435205936432, lr = 0.0001
305
+ [07:44:29] INFO - Epoch 17/50, Iter 30400: Loss = 0.022961270064115524, lr = 0.0001
306
+ [07:45:24] INFO - Epoch 17/50, Iter 30500: Loss = 0.02686147764325142, lr = 0.0001
307
+ [07:46:20] INFO - Epoch 17/50, Iter 30600: Loss = 0.021469425410032272, lr = 0.0001
308
+ [07:47:15] INFO - Epoch 17/50, Iter 30700: Loss = 0.019237644970417023, lr = 0.0001
309
+ [07:48:11] INFO - Epoch 17/50, Iter 30800: Loss = 0.01243587676435709, lr = 0.0001
310
+ [07:49:06] INFO - Epoch 17/50, Iter 30900: Loss = 0.019927412271499634, lr = 0.0001
311
+ [07:51:38] INFO - Epoch 17/50, Iter 31000: Loss = 0.021345121785998344, lr = 0.0001
312
+ [07:52:33] INFO - Epoch 17/50, Iter 31100: Loss = 0.0189402773976326, lr = 0.0001
313
+ [07:53:28] INFO - Epoch 17/50, Iter 31200: Loss = 0.022389506921172142, lr = 0.0001
314
+ [07:54:22] INFO - Epoch 17/50, Iter 31300: Loss = 0.019248703494668007, lr = 0.0001
315
+ [07:55:18] INFO - Epoch 17/50, Iter 31400: Loss = 0.020908750593662262, lr = 0.0001
316
+ [07:56:12] INFO - Epoch 17/50, Iter 31500: Loss = 0.029640033841133118, lr = 0.0001
317
+ [07:57:07] INFO - Epoch 17/50, Iter 31600: Loss = 0.026583340018987656, lr = 0.0001
318
+ [07:58:02] INFO - Epoch 17/50, Iter 31700: Loss = 0.01729031279683113, lr = 0.0001
319
+ [07:58:56] INFO - Epoch 17/50, Iter 31800: Loss = 0.026669491082429886, lr = 0.0001
320
+ [07:59:51] INFO - Epoch 18/50, Iter 31900: Loss = 0.015399916097521782, lr = 0.0001
321
+ [08:02:23] INFO - Epoch 18/50, Iter 32000: Loss = 0.027698248624801636, lr = 0.0001
322
+ [08:03:18] INFO - Epoch 18/50, Iter 32100: Loss = 0.020098572596907616, lr = 0.0001
323
+ [08:04:12] INFO - Epoch 18/50, Iter 32200: Loss = 0.023418741300702095, lr = 0.0001
324
+ [08:05:07] INFO - Epoch 18/50, Iter 32300: Loss = 0.015688564628362656, lr = 0.0001
325
+ [08:06:02] INFO - Epoch 18/50, Iter 32400: Loss = 0.013760192319750786, lr = 0.0001
326
+ [08:06:56] INFO - Epoch 18/50, Iter 32500: Loss = 0.018602928146719933, lr = 0.0001
327
+ [08:07:52] INFO - Epoch 18/50, Iter 32600: Loss = 0.0171047393232584, lr = 0.0001
328
+ [08:08:46] INFO - Epoch 18/50, Iter 32700: Loss = 0.02287128195166588, lr = 0.0001
329
+ [08:09:41] INFO - Epoch 18/50, Iter 32800: Loss = 0.01747080124914646, lr = 0.0001
330
+ [08:10:35] INFO - Epoch 18/50, Iter 32900: Loss = 0.032003749161958694, lr = 0.0001
331
+ [08:13:06] INFO - Epoch 18/50, Iter 33000: Loss = 0.021088197827339172, lr = 0.0001
332
+ [08:14:01] INFO - Epoch 18/50, Iter 33100: Loss = 0.0243061576038599, lr = 0.0001
333
+ [08:14:55] INFO - Epoch 18/50, Iter 33200: Loss = 0.017390495166182518, lr = 0.0001
334
+ [08:15:50] INFO - Epoch 18/50, Iter 33300: Loss = 0.027531778439879417, lr = 0.0001
335
+ [08:16:45] INFO - Epoch 18/50, Iter 33400: Loss = 0.01495380699634552, lr = 0.0001
336
+ [08:17:39] INFO - Epoch 18/50, Iter 33500: Loss = 0.02041369117796421, lr = 0.0001
337
+ [08:18:35] INFO - Epoch 18/50, Iter 33600: Loss = 0.016778916120529175, lr = 0.0001
338
+ [08:19:29] INFO - Epoch 18/50, Iter 33700: Loss = 0.0185483880341053, lr = 0.0001
339
+ [08:20:24] INFO - Epoch 19/50, Iter 33800: Loss = 0.017258750274777412, lr = 0.0001
340
+ [08:21:20] INFO - Epoch 19/50, Iter 33900: Loss = 0.013514120131731033, lr = 0.0001
341
+ [08:23:52] INFO - Epoch 19/50, Iter 34000: Loss = 0.017329292371869087, lr = 0.0001
342
+ [08:24:46] INFO - Epoch 19/50, Iter 34100: Loss = 0.03175392746925354, lr = 0.0001
343
+ [08:25:42] INFO - Epoch 19/50, Iter 34200: Loss = 0.024144772440195084, lr = 0.0001
344
+ [08:26:36] INFO - Epoch 19/50, Iter 34300: Loss = 0.025116432458162308, lr = 0.0001
345
+ [08:27:31] INFO - Epoch 19/50, Iter 34400: Loss = 0.023968493565917015, lr = 0.0001
346
+ [08:28:26] INFO - Epoch 19/50, Iter 34500: Loss = 0.023263823240995407, lr = 0.0001
347
+ [08:29:20] INFO - Epoch 19/50, Iter 34600: Loss = 0.015572518110275269, lr = 0.0001
348
+ [08:30:15] INFO - Epoch 19/50, Iter 34700: Loss = 0.011077907867729664, lr = 0.0001
349
+ [08:31:10] INFO - Epoch 19/50, Iter 34800: Loss = 0.019685542210936546, lr = 0.0001
350
+ [08:32:04] INFO - Epoch 19/50, Iter 34900: Loss = 0.026246516034007072, lr = 0.0001
351
+ [08:34:35] INFO - Epoch 19/50, Iter 35000: Loss = 0.0264703631401062, lr = 0.0001
352
+ [08:35:31] INFO - Epoch 19/50, Iter 35100: Loss = 0.018090050667524338, lr = 0.0001
353
+ [08:36:25] INFO - Epoch 19/50, Iter 35200: Loss = 0.014332180842757225, lr = 0.0001
354
+ [08:37:20] INFO - Epoch 19/50, Iter 35300: Loss = 0.03227975219488144, lr = 0.0001
355
+ [08:38:15] INFO - Epoch 19/50, Iter 35400: Loss = 0.017180195078253746, lr = 0.0001
356
+ [08:39:09] INFO - Epoch 19/50, Iter 35500: Loss = 0.01773938722908497, lr = 0.0001
357
+ [08:40:04] INFO - Epoch 19/50, Iter 35600: Loss = 0.02321586385369301, lr = 0.0001
358
+ [08:41:00] INFO - Epoch 20/50, Iter 35700: Loss = 0.018052995204925537, lr = 0.0001
359
+ [08:41:55] INFO - Epoch 20/50, Iter 35800: Loss = 0.02333519607782364, lr = 0.0001
360
+ [08:42:51] INFO - Epoch 20/50, Iter 35900: Loss = 0.023782718926668167, lr = 0.0001
361
+ [08:45:22] INFO - Epoch 20/50, Iter 36000: Loss = 0.021948453038930893, lr = 0.0001
362
+ [08:46:17] INFO - Epoch 20/50, Iter 36100: Loss = 0.01616925373673439, lr = 0.0001
363
+ [08:47:11] INFO - Epoch 20/50, Iter 36200: Loss = 0.0195147804915905, lr = 0.0001
364
+ [08:48:07] INFO - Epoch 20/50, Iter 36300: Loss = 0.02167724072933197, lr = 0.0001
365
+ [08:49:02] INFO - Epoch 20/50, Iter 36400: Loss = 0.017993919551372528, lr = 0.0001
366
+ [08:49:56] INFO - Epoch 20/50, Iter 36500: Loss = 0.024179894477128983, lr = 0.0001
367
+ [08:50:51] INFO - Epoch 20/50, Iter 36600: Loss = 0.029972080141305923, lr = 0.0001
368
+ [08:51:45] INFO - Epoch 20/50, Iter 36700: Loss = 0.02250525914132595, lr = 0.0001
369
+ [08:52:40] INFO - Epoch 20/50, Iter 36800: Loss = 0.016068585216999054, lr = 0.0001
370
+ [08:53:35] INFO - Epoch 20/50, Iter 36900: Loss = 0.02062491700053215, lr = 0.0001
371
+ [08:56:07] INFO - Epoch 20/50, Iter 37000: Loss = 0.026054339483380318, lr = 0.0001
372
+ [08:57:01] INFO - Epoch 20/50, Iter 37100: Loss = 0.01617574132978916, lr = 0.0001
373
+ [08:57:56] INFO - Epoch 20/50, Iter 37200: Loss = 0.01841990277171135, lr = 0.0001
374
+ [08:58:51] INFO - Epoch 20/50, Iter 37300: Loss = 0.016723550856113434, lr = 0.0001
375
+ [08:59:45] INFO - Epoch 20/50, Iter 37400: Loss = 0.015482468530535698, lr = 0.0001
376
+ [09:00:41] INFO - Epoch 21/50, Iter 37500: Loss = 0.028426745906472206, lr = 0.0001
377
+ [09:01:36] INFO - Epoch 21/50, Iter 37600: Loss = 0.026276376098394394, lr = 0.0001
378
+ [09:02:32] INFO - Epoch 21/50, Iter 37700: Loss = 0.026483114808797836, lr = 0.0001
379
+ [09:03:27] INFO - Epoch 21/50, Iter 37800: Loss = 0.021477442234754562, lr = 0.0001
380
+ [09:04:21] INFO - Epoch 21/50, Iter 37900: Loss = 0.015382439829409122, lr = 0.0001
381
+ [09:06:54] INFO - Epoch 21/50, Iter 38000: Loss = 0.013858610764145851, lr = 0.0001
382
+ [09:07:48] INFO - Epoch 21/50, Iter 38100: Loss = 0.022090336307883263, lr = 0.0001
383
+ [09:08:44] INFO - Epoch 21/50, Iter 38200: Loss = 0.025041067972779274, lr = 0.0001
384
+ [09:09:39] INFO - Epoch 21/50, Iter 38300: Loss = 0.01404337864369154, lr = 0.0001
385
+ [09:10:33] INFO - Epoch 21/50, Iter 38400: Loss = 0.022372154518961906, lr = 0.0001
386
+ [09:11:28] INFO - Epoch 21/50, Iter 38500: Loss = 0.022488964721560478, lr = 0.0001
387
+ [09:12:22] INFO - Epoch 21/50, Iter 38600: Loss = 0.018394947052001953, lr = 0.0001
388
+ [09:13:17] INFO - Epoch 21/50, Iter 38700: Loss = 0.019345279783010483, lr = 0.0001
389
+ [09:14:12] INFO - Epoch 21/50, Iter 38800: Loss = 0.013524915091693401, lr = 0.0001
390
+ [09:15:06] INFO - Epoch 21/50, Iter 38900: Loss = 0.023479681462049484, lr = 0.0001
391
+ [09:17:38] INFO - Epoch 21/50, Iter 39000: Loss = 0.018239330500364304, lr = 0.0001
392
+ [09:18:33] INFO - Epoch 21/50, Iter 39100: Loss = 0.014270618557929993, lr = 0.0001
393
+ [09:19:27] INFO - Epoch 21/50, Iter 39200: Loss = 0.012470152229070663, lr = 0.0001
394
+ [09:20:22] INFO - Epoch 21/50, Iter 39300: Loss = 0.024510135874152184, lr = 0.0001
395
+ [09:21:18] INFO - Epoch 22/50, Iter 39400: Loss = 0.01967580057680607, lr = 0.0001
396
+ [09:22:13] INFO - Epoch 22/50, Iter 39500: Loss = 0.02651473507285118, lr = 0.0001
397
+ [09:23:09] INFO - Epoch 22/50, Iter 39600: Loss = 0.014456840232014656, lr = 0.0001
398
+ [09:24:03] INFO - Epoch 22/50, Iter 39700: Loss = 0.013815360143780708, lr = 0.0001
399
+ [09:24:58] INFO - Epoch 22/50, Iter 39800: Loss = 0.026865314692258835, lr = 0.0001
400
+ [09:25:54] INFO - Epoch 22/50, Iter 39900: Loss = 0.022365324199199677, lr = 0.0001
401
+ [09:28:27] INFO - Epoch 22/50, Iter 40000: Loss = 0.02029530331492424, lr = 0.0001
402
+ [09:29:21] INFO - Epoch 22/50, Iter 40100: Loss = 0.021116379648447037, lr = 0.0001
403
+ [09:30:16] INFO - Epoch 22/50, Iter 40200: Loss = 0.02509278617799282, lr = 0.0001
404
+ [09:31:11] INFO - Epoch 22/50, Iter 40300: Loss = 0.02551993355154991, lr = 0.0001
405
+ [09:32:05] INFO - Epoch 22/50, Iter 40400: Loss = 0.020986683666706085, lr = 0.0001
406
+ [09:33:00] INFO - Epoch 22/50, Iter 40500: Loss = 0.020868226885795593, lr = 0.0001
407
+ [09:33:54] INFO - Epoch 22/50, Iter 40600: Loss = 0.017478734254837036, lr = 0.0001
408
+ [09:34:49] INFO - Epoch 22/50, Iter 40700: Loss = 0.027790624648332596, lr = 0.0001
409
+ [09:35:45] INFO - Epoch 22/50, Iter 40800: Loss = 0.022644832730293274, lr = 0.0001
410
+ [09:36:39] INFO - Epoch 22/50, Iter 40900: Loss = 0.024670612066984177, lr = 0.0001
411
+ [09:39:10] INFO - Epoch 22/50, Iter 41000: Loss = 0.026195334270596504, lr = 0.0001
412
+ [09:40:05] INFO - Epoch 22/50, Iter 41100: Loss = 0.021374046802520752, lr = 0.0001
413
+ [09:41:00] INFO - Epoch 22/50, Iter 41200: Loss = 0.02115592733025551, lr = 0.0001
414
+ [09:41:56] INFO - Epoch 23/50, Iter 41300: Loss = 0.01633710041642189, lr = 0.0001
415
+ [09:42:52] INFO - Epoch 23/50, Iter 41400: Loss = 0.02131003886461258, lr = 0.0001
416
+ [09:43:46] INFO - Epoch 23/50, Iter 41500: Loss = 0.022764872759580612, lr = 0.0001
417
+ [09:44:41] INFO - Epoch 23/50, Iter 41600: Loss = 0.01728042960166931, lr = 0.0001
418
+ [09:45:37] INFO - Epoch 23/50, Iter 41700: Loss = 0.0162839163094759, lr = 0.0001
419
+ [09:46:32] INFO - Epoch 23/50, Iter 41800: Loss = 0.014318926259875298, lr = 0.0001
420
+ [09:47:28] INFO - Epoch 23/50, Iter 41900: Loss = 0.018346164375543594, lr = 0.0001
421
+ [09:49:59] INFO - Epoch 23/50, Iter 42000: Loss = 0.027812600135803223, lr = 0.0001
422
+ [09:50:55] INFO - Epoch 23/50, Iter 42100: Loss = 0.026753295212984085, lr = 0.0001
423
+ [09:51:50] INFO - Epoch 23/50, Iter 42200: Loss = 0.018069680780172348, lr = 0.0001
424
+ [09:52:44] INFO - Epoch 23/50, Iter 42300: Loss = 0.03101518750190735, lr = 0.0001
425
+ [09:53:39] INFO - Epoch 23/50, Iter 42400: Loss = 0.025507837533950806, lr = 0.0001
426
+ [09:54:34] INFO - Epoch 23/50, Iter 42500: Loss = 0.017935875803232193, lr = 0.0001
427
+ [09:55:28] INFO - Epoch 23/50, Iter 42600: Loss = 0.022867443040013313, lr = 0.0001
428
+ [09:56:23] INFO - Epoch 23/50, Iter 42700: Loss = 0.02030709572136402, lr = 0.0001
429
+ [09:57:18] INFO - Epoch 23/50, Iter 42800: Loss = 0.013310606591403484, lr = 0.0001
430
+ [09:58:13] INFO - Epoch 23/50, Iter 42900: Loss = 0.014713610522449017, lr = 0.0001
431
+ [10:00:44] INFO - Epoch 23/50, Iter 43000: Loss = 0.02300114557147026, lr = 0.0001
432
+ [10:01:39] INFO - Epoch 23/50, Iter 43100: Loss = 0.02343389019370079, lr = 0.0001
433
+ [10:02:35] INFO - Epoch 24/50, Iter 43200: Loss = 0.019669387489557266, lr = 0.0001
434
+ [10:03:30] INFO - Epoch 24/50, Iter 43300: Loss = 0.025514639914035797, lr = 0.0001
435
+ [10:04:25] INFO - Epoch 24/50, Iter 43400: Loss = 0.027034897357225418, lr = 0.0001
436
+ [10:05:20] INFO - Epoch 24/50, Iter 43500: Loss = 0.026066435500979424, lr = 0.0001
437
+ [10:06:16] INFO - Epoch 24/50, Iter 43600: Loss = 0.022791586816310883, lr = 0.0001
438
+ [10:07:11] INFO - Epoch 24/50, Iter 43700: Loss = 0.01600833050906658, lr = 0.0001
439
+ [10:08:07] INFO - Epoch 24/50, Iter 43800: Loss = 0.01834738627076149, lr = 0.0001
440
+ [10:09:02] INFO - Epoch 24/50, Iter 43900: Loss = 0.026411669328808784, lr = 0.0001
441
+ [10:11:34] INFO - Epoch 24/50, Iter 44000: Loss = 0.01697351410984993, lr = 0.0001
442
+ [10:12:29] INFO - Epoch 24/50, Iter 44100: Loss = 0.025164766237139702, lr = 0.0001
443
+ [10:13:24] INFO - Epoch 24/50, Iter 44200: Loss = 0.023120088502764702, lr = 0.0001
444
+ [10:14:18] INFO - Epoch 24/50, Iter 44300: Loss = 0.016470227390527725, lr = 0.0001
445
+ [10:15:13] INFO - Epoch 24/50, Iter 44400: Loss = 0.02092874050140381, lr = 0.0001
446
+ [10:16:09] INFO - Epoch 24/50, Iter 44500: Loss = 0.017084982246160507, lr = 0.0001
447
+ [10:17:02] INFO - Epoch 24/50, Iter 44600: Loss = 0.01771422289311886, lr = 0.0001
448
+ [10:17:58] INFO - Epoch 24/50, Iter 44700: Loss = 0.01557396911084652, lr = 0.0001
449
+ [10:18:52] INFO - Epoch 24/50, Iter 44800: Loss = 0.01830480992794037, lr = 0.0001
450
+ [10:19:47] INFO - Epoch 24/50, Iter 44900: Loss = 0.03161770850419998, lr = 0.0001
451
+ [10:22:19] INFO - Epoch 25/50, Iter 45000: Loss = 0.013423663564026356, lr = 0.0001
452
+ [10:23:14] INFO - Epoch 25/50, Iter 45100: Loss = 0.0297955684363842, lr = 0.0001
453
+ [10:24:10] INFO - Epoch 25/50, Iter 45200: Loss = 0.02846469357609749, lr = 0.0001
454
+ [10:25:06] INFO - Epoch 25/50, Iter 45300: Loss = 0.015436829067766666, lr = 0.0001
455
+ [10:26:01] INFO - Epoch 25/50, Iter 45400: Loss = 0.024918153882026672, lr = 0.0001
456
+ [10:26:57] INFO - Epoch 25/50, Iter 45500: Loss = 0.02270306646823883, lr = 0.0001
457
+ [10:27:52] INFO - Epoch 25/50, Iter 45600: Loss = 0.015784474089741707, lr = 0.0001
458
+ [10:28:46] INFO - Epoch 25/50, Iter 45700: Loss = 0.011514103971421719, lr = 0.0001
459
+ [10:29:42] INFO - Epoch 25/50, Iter 45800: Loss = 0.024075977504253387, lr = 0.0001
460
+ [10:30:37] INFO - Epoch 25/50, Iter 45900: Loss = 0.018384993076324463, lr = 0.0001
461
+ [10:33:10] INFO - Epoch 25/50, Iter 46000: Loss = 0.024563699960708618, lr = 0.0001
462
+ [10:34:04] INFO - Epoch 25/50, Iter 46100: Loss = 0.015144889242947102, lr = 0.0001
463
+ [10:34:59] INFO - Epoch 25/50, Iter 46200: Loss = 0.022055502980947495, lr = 0.0001
464
+ [10:35:55] INFO - Epoch 25/50, Iter 46300: Loss = 0.013236483559012413, lr = 0.0001
465
+ [10:36:49] INFO - Epoch 25/50, Iter 46400: Loss = 0.016789842396974564, lr = 0.0001
466
+ [10:37:44] INFO - Epoch 25/50, Iter 46500: Loss = 0.018810316920280457, lr = 0.0001
467
+ [10:38:38] INFO - Epoch 25/50, Iter 46600: Loss = 0.01891239359974861, lr = 0.0001
468
+ [10:39:33] INFO - Epoch 25/50, Iter 46700: Loss = 0.03200780227780342, lr = 0.0001
469
+ [10:40:28] INFO - Epoch 25/50, Iter 46800: Loss = 0.025489578023552895, lr = 0.0001
470
+ [10:41:24] INFO - Epoch 26/50, Iter 46900: Loss = 0.02214771881699562, lr = 0.0001
471
+ [10:43:56] INFO - Epoch 26/50, Iter 47000: Loss = 0.01889549382030964, lr = 0.0001
472
+ [10:44:51] INFO - Epoch 26/50, Iter 47100: Loss = 0.015227919444441795, lr = 0.0001
473
+ [10:45:45] INFO - Epoch 26/50, Iter 47200: Loss = 0.01975785568356514, lr = 0.0001
474
+ [10:46:40] INFO - Epoch 26/50, Iter 47300: Loss = 0.021548938006162643, lr = 0.0001
475
+ [10:47:35] INFO - Epoch 26/50, Iter 47400: Loss = 0.018300775438547134, lr = 0.0001
476
+ [10:48:29] INFO - Epoch 26/50, Iter 47500: Loss = 0.02168145403265953, lr = 0.0001
477
+ [10:49:24] INFO - Epoch 26/50, Iter 47600: Loss = 0.02841881290078163, lr = 0.0001
478
+ [10:50:18] INFO - Epoch 26/50, Iter 47700: Loss = 0.01804378256201744, lr = 0.0001
479
+ [10:51:13] INFO - Epoch 26/50, Iter 47800: Loss = 0.026898138225078583, lr = 0.0001
480
+ [10:52:09] INFO - Epoch 26/50, Iter 47900: Loss = 0.018523452803492546, lr = 0.0001
481
+ [10:54:40] INFO - Epoch 26/50, Iter 48000: Loss = 0.016216814517974854, lr = 0.0001
482
+ [10:55:34] INFO - Epoch 26/50, Iter 48100: Loss = 0.02262328565120697, lr = 0.0001
483
+ [10:56:29] INFO - Epoch 26/50, Iter 48200: Loss = 0.015000266954302788, lr = 0.0001
484
+ [10:57:25] INFO - Epoch 26/50, Iter 48300: Loss = 0.02180442586541176, lr = 0.0001
485
+ [10:58:20] INFO - Epoch 26/50, Iter 48400: Loss = 0.025278791785240173, lr = 0.0001
486
+ [10:59:14] INFO - Epoch 26/50, Iter 48500: Loss = 0.03473420441150665, lr = 0.0001
487
+ [11:00:09] INFO - Epoch 26/50, Iter 48600: Loss = 0.017245961353182793, lr = 0.0001
488
+ [11:01:03] INFO - Epoch 26/50, Iter 48700: Loss = 0.03179230913519859, lr = 0.0001
489
+ [11:01:59] INFO - Epoch 27/50, Iter 48800: Loss = 0.015805833041667938, lr = 0.0001
490
+ [11:02:54] INFO - Epoch 27/50, Iter 48900: Loss = 0.02080763876438141, lr = 0.0001
491
+ [11:05:26] INFO - Epoch 27/50, Iter 49000: Loss = 0.020735610276460648, lr = 0.0001
492
+ [11:06:21] INFO - Epoch 27/50, Iter 49100: Loss = 0.024737179279327393, lr = 0.0001
493
+ [11:07:16] INFO - Epoch 27/50, Iter 49200: Loss = 0.026094382628798485, lr = 0.0001
494
+ [11:08:10] INFO - Epoch 27/50, Iter 49300: Loss = 0.021053478121757507, lr = 0.0001
495
+ [11:09:05] INFO - Epoch 27/50, Iter 49400: Loss = 0.014476573094725609, lr = 0.0001
496
+ [11:10:01] INFO - Epoch 27/50, Iter 49500: Loss = 0.030272990465164185, lr = 0.0001
497
+ [11:10:55] INFO - Epoch 27/50, Iter 49600: Loss = 0.022585971280932426, lr = 0.0001
498
+ [11:11:50] INFO - Epoch 27/50, Iter 49700: Loss = 0.01895831525325775, lr = 0.0001
499
+ [11:12:44] INFO - Epoch 27/50, Iter 49800: Loss = 0.018344363197684288, lr = 0.0001
500
+ [11:13:39] INFO - Epoch 27/50, Iter 49900: Loss = 0.022272832691669464, lr = 0.0001
501
+ [11:16:10] INFO - Epoch 27/50, Iter 50000: Loss = 0.022018130868673325, lr = 0.0001
502
+ [11:17:06] INFO - Epoch 27/50, Iter 50100: Loss = 0.027774281799793243, lr = 0.0001
503
+ [11:18:00] INFO - Epoch 27/50, Iter 50200: Loss = 0.014724764972925186, lr = 0.0001
504
+ [11:18:55] INFO - Epoch 27/50, Iter 50300: Loss = 0.018815312534570694, lr = 0.0001
505
+ [11:19:50] INFO - Epoch 27/50, Iter 50400: Loss = 0.019056078046560287, lr = 0.0001
506
+ [11:20:44] INFO - Epoch 27/50, Iter 50500: Loss = 0.01948639005422592, lr = 0.0001
507
+ [11:21:39] INFO - Epoch 27/50, Iter 50600: Loss = 0.02332192286849022, lr = 0.0001
508
+ [11:22:35] INFO - Epoch 28/50, Iter 50700: Loss = 0.02340688183903694, lr = 0.0001
509
+ [11:23:31] INFO - Epoch 28/50, Iter 50800: Loss = 0.02822597697377205, lr = 0.0001
510
+ [11:24:26] INFO - Epoch 28/50, Iter 50900: Loss = 0.02604568563401699, lr = 0.0001
511
+ [11:26:58] INFO - Epoch 28/50, Iter 51000: Loss = 0.015130102634429932, lr = 0.0001
512
+ [11:27:53] INFO - Epoch 28/50, Iter 51100: Loss = 0.020247958600521088, lr = 0.0001
513
+ [11:28:47] INFO - Epoch 28/50, Iter 51200: Loss = 0.021361518651247025, lr = 0.0001
514
+ [11:29:42] INFO - Epoch 28/50, Iter 51300: Loss = 0.0154896704480052, lr = 0.0001
515
+ [11:30:36] INFO - Epoch 28/50, Iter 51400: Loss = 0.020418627187609673, lr = 0.0001
516
+ [11:31:31] INFO - Epoch 28/50, Iter 51500: Loss = 0.016209501773118973, lr = 0.0001
517
+ [11:32:26] INFO - Epoch 28/50, Iter 51600: Loss = 0.021547267213463783, lr = 0.0001
518
+ [11:33:20] INFO - Epoch 28/50, Iter 51700: Loss = 0.03097592294216156, lr = 0.0001
519
+ [11:34:16] INFO - Epoch 28/50, Iter 51800: Loss = 0.01853656955063343, lr = 0.0001
520
+ [11:35:11] INFO - Epoch 28/50, Iter 51900: Loss = 0.025320153683423996, lr = 0.0001
521
+ [11:37:42] INFO - Epoch 28/50, Iter 52000: Loss = 0.01918005384504795, lr = 0.0001
522
+ [11:38:36] INFO - Epoch 28/50, Iter 52100: Loss = 0.02268061600625515, lr = 0.0001
523
+ [11:39:32] INFO - Epoch 28/50, Iter 52200: Loss = 0.024810226634144783, lr = 0.0001
524
+ [11:40:27] INFO - Epoch 28/50, Iter 52300: Loss = 0.02219560742378235, lr = 0.0001
525
+ [11:41:21] INFO - Epoch 28/50, Iter 52400: Loss = 0.027511518448591232, lr = 0.0001
526
+ [11:42:16] INFO - Epoch 29/50, Iter 52500: Loss = 0.016894716769456863, lr = 0.0001
527
+ [11:43:12] INFO - Epoch 29/50, Iter 52600: Loss = 0.01918671280145645, lr = 0.0001
528
+ [11:44:07] INFO - Epoch 29/50, Iter 52700: Loss = 0.021322811022400856, lr = 0.0001
529
+ [11:45:03] INFO - Epoch 29/50, Iter 52800: Loss = 0.01693873107433319, lr = 0.0001
530
+ [11:45:58] INFO - Epoch 29/50, Iter 52900: Loss = 0.028586234897375107, lr = 0.0001
531
+ [11:48:30] INFO - Epoch 29/50, Iter 53000: Loss = 0.02094537392258644, lr = 0.0001
532
+ [11:49:25] INFO - Epoch 29/50, Iter 53100: Loss = 0.025890830904245377, lr = 0.0001
533
+ [11:50:20] INFO - Epoch 29/50, Iter 53200: Loss = 0.019293418154120445, lr = 0.0001
534
+ [11:51:14] INFO - Epoch 29/50, Iter 53300: Loss = 0.013301231898367405, lr = 0.0001
535
+ [11:52:10] INFO - Epoch 29/50, Iter 53400: Loss = 0.024367133155465126, lr = 0.0001
536
+ [11:53:04] INFO - Epoch 29/50, Iter 53500: Loss = 0.013333385810256004, lr = 0.0001
537
+ [11:53:59] INFO - Epoch 29/50, Iter 53600: Loss = 0.021088868379592896, lr = 0.0001
538
+ [11:54:53] INFO - Epoch 29/50, Iter 53700: Loss = 0.014782575890421867, lr = 0.0001
539
+ [11:55:48] INFO - Epoch 29/50, Iter 53800: Loss = 0.019235175102949142, lr = 0.0001
540
+ [11:56:43] INFO - Epoch 29/50, Iter 53900: Loss = 0.02775110863149166, lr = 0.0001
541
+ [11:59:15] INFO - Epoch 29/50, Iter 54000: Loss = 0.014202380552887917, lr = 0.0001
542
+ [12:00:10] INFO - Epoch 29/50, Iter 54100: Loss = 0.021274959668517113, lr = 0.0001
543
+ [12:01:04] INFO - Epoch 29/50, Iter 54200: Loss = 0.028708720579743385, lr = 0.0001
544
+ [12:01:59] INFO - Epoch 29/50, Iter 54300: Loss = 0.024009495973587036, lr = 0.0001
545
+ [12:02:55] INFO - Epoch 30/50, Iter 54400: Loss = 0.018383020535111427, lr = 0.0001
546
+ [12:03:50] INFO - Epoch 30/50, Iter 54500: Loss = 0.012869146652519703, lr = 0.0001
547
+ [12:04:46] INFO - Epoch 30/50, Iter 54600: Loss = 0.015052242204546928, lr = 0.0001
+ [12:05:41] INFO - Epoch 30/50, Iter 54700: Loss = 0.021794060245156288, lr = 0.0001
+ [12:06:37] INFO - Epoch 30/50, Iter 54800: Loss = 0.021674180403351784, lr = 0.0001
+ [12:07:31] INFO - Epoch 30/50, Iter 54900: Loss = 0.0307894479483366, lr = 0.0001
+ [12:10:04] INFO - Epoch 30/50, Iter 55000: Loss = 0.023494703695178032, lr = 0.0001
+ [12:10:59] INFO - Epoch 30/50, Iter 55100: Loss = 0.025401834398508072, lr = 0.0001
+ [12:11:54] INFO - Epoch 30/50, Iter 55200: Loss = 0.021761178970336914, lr = 0.0001
+ [12:12:49] INFO - Epoch 30/50, Iter 55300: Loss = 0.02898026630282402, lr = 0.0001
+ [12:13:44] INFO - Epoch 30/50, Iter 55400: Loss = 0.02216275781393051, lr = 0.0001
+ [12:14:38] INFO - Epoch 30/50, Iter 55500: Loss = 0.00930317398160696, lr = 0.0001
+ [12:15:33] INFO - Epoch 30/50, Iter 55600: Loss = 0.024549826979637146, lr = 0.0001
+ [12:16:29] INFO - Epoch 30/50, Iter 55700: Loss = 0.016341213136911392, lr = 0.0001
+ [12:17:23] INFO - Epoch 30/50, Iter 55800: Loss = 0.015864314511418343, lr = 0.0001
+ [12:18:18] INFO - Epoch 30/50, Iter 55900: Loss = 0.034297745674848557, lr = 0.0001
+ [12:20:49] INFO - Epoch 30/50, Iter 56000: Loss = 0.02956249937415123, lr = 0.0001
+ [12:21:44] INFO - Epoch 30/50, Iter 56100: Loss = 0.02114814706146717, lr = 0.0001
+ [12:22:40] INFO - Epoch 30/50, Iter 56200: Loss = 0.0200330913066864, lr = 0.0001
+ [12:23:35] INFO - Epoch 31/50, Iter 56300: Loss = 0.026903297752141953, lr = 0.0001
+ [12:24:31] INFO - Epoch 31/50, Iter 56400: Loss = 0.02994358167052269, lr = 0.0001
+ [12:25:26] INFO - Epoch 31/50, Iter 56500: Loss = 0.016208231449127197, lr = 0.0001
+ [12:26:22] INFO - Epoch 31/50, Iter 56600: Loss = 0.029720913618803024, lr = 0.0001
+ [12:27:16] INFO - Epoch 31/50, Iter 56700: Loss = 0.021973680704832077, lr = 0.0001
+ [12:28:11] INFO - Epoch 31/50, Iter 56800: Loss = 0.017940720543265343, lr = 0.0001
+ [12:29:07] INFO - Epoch 31/50, Iter 56900: Loss = 0.022731531411409378, lr = 0.0001
+ [12:31:40] INFO - Epoch 31/50, Iter 57000: Loss = 0.016729535534977913, lr = 0.0001
+ [12:32:35] INFO - Epoch 31/50, Iter 57100: Loss = 0.026968562975525856, lr = 0.0001
+ [12:33:29] INFO - Epoch 31/50, Iter 57200: Loss = 0.015602253377437592, lr = 0.0001
+ [12:34:24] INFO - Epoch 31/50, Iter 57300: Loss = 0.028429606929421425, lr = 0.0001
+ [12:35:20] INFO - Epoch 31/50, Iter 57400: Loss = 0.021183405071496964, lr = 0.0001
+ [12:36:14] INFO - Epoch 31/50, Iter 57500: Loss = 0.024300210177898407, lr = 0.0001
+ [12:37:09] INFO - Epoch 31/50, Iter 57600: Loss = 0.017051223665475845, lr = 0.0001
+ [12:38:03] INFO - Epoch 31/50, Iter 57700: Loss = 0.016109324991703033, lr = 0.0001
+ [12:38:58] INFO - Epoch 31/50, Iter 57800: Loss = 0.019427603110671043, lr = 0.0001
+ [12:39:53] INFO - Epoch 31/50, Iter 57900: Loss = 0.030664775520563126, lr = 0.0001
+ [12:42:25] INFO - Epoch 31/50, Iter 58000: Loss = 0.021199747920036316, lr = 0.0001
+ [12:43:20] INFO - Epoch 31/50, Iter 58100: Loss = 0.01854831352829933, lr = 0.0001
+ [12:44:16] INFO - Epoch 32/50, Iter 58200: Loss = 0.01928992196917534, lr = 0.0001
+ [12:45:11] INFO - Epoch 32/50, Iter 58300: Loss = 0.018576214089989662, lr = 0.0001
+ [12:46:00] INFO - Epoch 32/50, Iter 58400: Loss = 0.019123028963804245, lr = 0.0001
resnet/log/iter_1000.png ADDED
resnet/log/iter_10000.png ADDED
resnet/log/iter_11000.png ADDED
resnet/log/iter_12000.png ADDED
resnet/log/iter_13000.png ADDED
resnet/log/iter_14000.png ADDED
resnet/log/iter_15000.png ADDED
resnet/log/iter_16000.png ADDED
resnet/log/iter_17000.png ADDED
resnet/log/iter_18000.png ADDED
resnet/log/iter_19000.png ADDED
resnet/log/iter_2000.png ADDED
resnet/log/iter_20000.png ADDED
resnet/log/iter_21000.png ADDED
resnet/log/iter_22000.png ADDED
resnet/log/iter_23000.png ADDED
resnet/log/iter_24000.png ADDED
resnet/log/iter_25000.png ADDED
resnet/log/iter_26000.png ADDED
resnet/log/iter_27000.png ADDED
resnet/log/iter_28000.png ADDED
resnet/log/iter_29000.png ADDED
resnet/log/iter_3000.png ADDED
resnet/log/iter_30000.png ADDED
resnet/log/iter_31000.png ADDED
resnet/log/iter_32000.png ADDED
resnet/log/iter_33000.png ADDED
resnet/log/iter_34000.png ADDED
resnet/log/iter_35000.png ADDED
resnet/log/iter_36000.png ADDED
resnet/log/iter_37000.png ADDED
resnet/log/iter_38000.png ADDED
resnet/log/iter_39000.png ADDED
resnet/log/iter_4000.png ADDED
resnet/log/iter_40000.png ADDED
resnet/log/iter_41000.png ADDED
resnet/log/iter_42000.png ADDED
resnet/log/iter_43000.png ADDED
resnet/log/iter_44000.png ADDED
resnet/log/iter_45000.png ADDED
resnet/log/iter_46000.png ADDED
resnet/log/iter_47000.png ADDED
resnet/log/iter_48000.png ADDED
resnet/log/iter_49000.png ADDED