Ngô Đức Bảo
commited on
Commit
•
2a9ad6d
1
Parent(s):
53e6466
Upload 320 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Utils.ipynb +0 -0
- note.txt +10 -0
- resnet/DDPM_ResNet.ipynb +0 -0
- resnet/DDPM_ResNet.py +952 -0
- resnet/DDPM_ResNet_sample.py +856 -0
- resnet/log/info.log +585 -0
- resnet/log/iter_1000.png +0 -0
- resnet/log/iter_10000.png +0 -0
- resnet/log/iter_11000.png +0 -0
- resnet/log/iter_12000.png +0 -0
- resnet/log/iter_13000.png +0 -0
- resnet/log/iter_14000.png +0 -0
- resnet/log/iter_15000.png +0 -0
- resnet/log/iter_16000.png +0 -0
- resnet/log/iter_17000.png +0 -0
- resnet/log/iter_18000.png +0 -0
- resnet/log/iter_19000.png +0 -0
- resnet/log/iter_2000.png +0 -0
- resnet/log/iter_20000.png +0 -0
- resnet/log/iter_21000.png +0 -0
- resnet/log/iter_22000.png +0 -0
- resnet/log/iter_23000.png +0 -0
- resnet/log/iter_24000.png +0 -0
- resnet/log/iter_25000.png +0 -0
- resnet/log/iter_26000.png +0 -0
- resnet/log/iter_27000.png +0 -0
- resnet/log/iter_28000.png +0 -0
- resnet/log/iter_29000.png +0 -0
- resnet/log/iter_3000.png +0 -0
- resnet/log/iter_30000.png +0 -0
- resnet/log/iter_31000.png +0 -0
- resnet/log/iter_32000.png +0 -0
- resnet/log/iter_33000.png +0 -0
- resnet/log/iter_34000.png +0 -0
- resnet/log/iter_35000.png +0 -0
- resnet/log/iter_36000.png +0 -0
- resnet/log/iter_37000.png +0 -0
- resnet/log/iter_38000.png +0 -0
- resnet/log/iter_39000.png +0 -0
- resnet/log/iter_4000.png +0 -0
- resnet/log/iter_40000.png +0 -0
- resnet/log/iter_41000.png +0 -0
- resnet/log/iter_42000.png +0 -0
- resnet/log/iter_43000.png +0 -0
- resnet/log/iter_44000.png +0 -0
- resnet/log/iter_45000.png +0 -0
- resnet/log/iter_46000.png +0 -0
- resnet/log/iter_47000.png +0 -0
- resnet/log/iter_48000.png +0 -0
- resnet/log/iter_49000.png +0 -0
Utils.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
note.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
File Utils.ipynb bao gồm:
|
2 |
+
- Di chuyển ảnh: Copy ngẫu nhiên 5k ảnh trong 10k ảnh được sample (cho việc tính FID 10k) từ các mô hình để tính IS 5k.
|
3 |
+
- Tính FID 10k và IS 5k bằng thư viện torch-fidelity (https://github.com/toshas/torch-fidelity)
|
4 |
+
|
5 |
+
Trong mỗi thư mục (ví dụ resnet) gồm:
|
6 |
+
- Một thư mục "model" chứa checkpoint của mô hình cùng tên với thư mục gốc (ở đây là resnet) tại epoch thứ 30.
|
7 |
+
- Một thư mục "log" chứa log và ảnh sample sau mỗi 1000 iter. Số lượng ảnh sample có thể không bằng nhau do ban đầu để max_epoch là 50.
|
8 |
+
- Một tệp "DDPM_ResNet.ipynb", ở đây, ResNet chỉ là 1 ví dụ, với các mô hình khác sẽ có tên là "DDPM_ResNet_wo_t.ipynb" (mô hình Res-Net không sử dụng thời gian t), "DDPM_UNet.ipynb" (mô hình U-Net), "DDPM_UNet_wo_t.ipynb" (mô hình U-Net không có thời gian t). Trong đây sẽ tách rõ các phần của mô hình, code dùng để train, ... Mục đích chính của tệp này là dùng để huấn luyện mô hình.
|
9 |
+
- Một tệp "DDPM_ResNet.py", tên thay đổi theo mô hình như trên. Đây chỉ là bản convert từ một tệp ".ipynb" sang ".py" do treo máy nhà qua đêm, chạy trên tệp ".py" bằng terminal sẽ nhẹ nhàng hơn.
|
10 |
+
- Một tệp "DDPM_ResNet_sample.py", tên thay đổi theo mô hình như trên. Đây là bản chỉnh sửa từ tệp ".py", xoá hết tất cả các code về gọi data, huấn luyện, save log, ... và thay thế bằng code dùng để sample và lưu ảnh.
|
resnet/DDPM_ResNet.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
resnet/DDPM_ResNet.py
ADDED
@@ -0,0 +1,952 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# # Library
|
5 |
+
|
6 |
+
# In[1]:
|
7 |
+
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from torch.cuda.amp import autocast
|
13 |
+
|
14 |
+
import torchvision
|
15 |
+
from torchvision.transforms import transforms
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
|
18 |
+
from torch.optim import Adam
|
19 |
+
|
20 |
+
from einops import rearrange, reduce, repeat
|
21 |
+
import math
|
22 |
+
from random import random
|
23 |
+
|
24 |
+
from collections import namedtuple
|
25 |
+
from functools import partial
|
26 |
+
from tqdm.auto import tqdm
|
27 |
+
import logging
|
28 |
+
import os
|
29 |
+
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision import utils
|
32 |
+
|
33 |
+
|
34 |
+
# # Helper
|
35 |
+
|
36 |
+
# ### Constant
|
37 |
+
|
38 |
+
# In[2]:
|
39 |
+
|
40 |
+
|
41 |
+
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
|
42 |
+
|
43 |
+
|
44 |
+
# ### Functions
|
45 |
+
|
46 |
+
# In[3]:
|
47 |
+
|
48 |
+
|
49 |
+
def exists(x):
|
50 |
+
return x is not None
|
51 |
+
|
52 |
+
def default(val, d):
|
53 |
+
if exists(val):
|
54 |
+
return val
|
55 |
+
return d() if callable(d) else d
|
56 |
+
|
57 |
+
|
58 |
+
# In[4]:
|
59 |
+
|
60 |
+
|
61 |
+
def cast_tuple(t, length = 1):
|
62 |
+
if isinstance(t, tuple):
|
63 |
+
return t
|
64 |
+
return ((t,) * length)
|
65 |
+
|
66 |
+
|
67 |
+
# In[5]:
|
68 |
+
|
69 |
+
|
70 |
+
def divisible_by(numer, denom):
|
71 |
+
return (numer % denom) == 0
|
72 |
+
|
73 |
+
|
74 |
+
# In[6]:
|
75 |
+
|
76 |
+
|
77 |
+
def identity(t, *args, **kwargs):
|
78 |
+
return t
|
79 |
+
|
80 |
+
|
81 |
+
# In[7]:
|
82 |
+
|
83 |
+
|
84 |
+
def cycle(dl):
|
85 |
+
while True:
|
86 |
+
for data in dl:
|
87 |
+
yield data
|
88 |
+
|
89 |
+
|
90 |
+
# In[8]:
|
91 |
+
|
92 |
+
|
93 |
+
def has_int_squareroot(num):
|
94 |
+
return (math.sqrt(num) ** 2) == num
|
95 |
+
|
96 |
+
|
97 |
+
# In[9]:
|
98 |
+
|
99 |
+
|
100 |
+
def num_to_groups(num, divisor):
|
101 |
+
groups = num // divisor
|
102 |
+
remainder = num % divisor
|
103 |
+
arr = [divisor] * groups
|
104 |
+
if remainder > 0:
|
105 |
+
arr.append(remainder)
|
106 |
+
return arr
|
107 |
+
|
108 |
+
|
109 |
+
# In[10]:
|
110 |
+
|
111 |
+
|
112 |
+
def convert_image_to_fn(img_type, image):
|
113 |
+
if image.mode != img_type:
|
114 |
+
return image.convert(img_type)
|
115 |
+
return image
|
116 |
+
|
117 |
+
|
118 |
+
# In[11]:
|
119 |
+
|
120 |
+
|
121 |
+
def extract(a, t, x_shape):
|
122 |
+
b, *_ = t.shape
|
123 |
+
out = a.gather(-1, t)
|
124 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
125 |
+
|
126 |
+
|
127 |
+
# ### Normalization Functions
|
128 |
+
|
129 |
+
# In[12]:
|
130 |
+
|
131 |
+
|
132 |
+
def normalize_to_neg_one_to_one(img):
|
133 |
+
return img * 2 - 1
|
134 |
+
|
135 |
+
def unnormalize_to_zero_to_one(t):
|
136 |
+
return (t + 1) * 0.5
|
137 |
+
|
138 |
+
|
139 |
+
# ### Sinusoidal positional embeds
|
140 |
+
|
141 |
+
# In[13]:
|
142 |
+
|
143 |
+
|
144 |
+
class SinusoidalPosEmb(nn.Module):
|
145 |
+
def __init__(self, dim, theta = 10000):
|
146 |
+
super().__init__()
|
147 |
+
self.dim = dim
|
148 |
+
self.theta = theta
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
device = x.device
|
152 |
+
half_dim = self.dim // 2
|
153 |
+
emb = math.log(self.theta) / (half_dim - 1)
|
154 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
155 |
+
emb = x[:, None] * emb[None, :]
|
156 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
157 |
+
return emb
|
158 |
+
|
159 |
+
|
160 |
+
# In[14]:
|
161 |
+
|
162 |
+
|
163 |
+
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
|
164 |
+
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
|
165 |
+
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
|
166 |
+
|
167 |
+
def __init__(self, dim, is_random = False):
|
168 |
+
super().__init__()
|
169 |
+
assert divisible_by(dim, 2)
|
170 |
+
half_dim = dim // 2
|
171 |
+
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = rearrange(x, 'b -> b 1')
|
175 |
+
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
|
176 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
|
177 |
+
fouriered = torch.cat((x, fouriered), dim = -1)
|
178 |
+
return fouriered
|
179 |
+
|
180 |
+
|
181 |
+
# ### Schedule
|
182 |
+
|
183 |
+
# In[15]:
|
184 |
+
|
185 |
+
|
186 |
+
def linear_beta_schedule(timesteps):
|
187 |
+
"""
|
188 |
+
linear schedule, proposed in original ddpm paper
|
189 |
+
"""
|
190 |
+
scale = 1000 / timesteps
|
191 |
+
beta_start = scale * 0.0001
|
192 |
+
beta_end = scale * 0.02
|
193 |
+
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
194 |
+
|
195 |
+
|
196 |
+
# In[16]:
|
197 |
+
|
198 |
+
|
199 |
+
def cosine_beta_schedule(timesteps, s = 0.008):
|
200 |
+
"""
|
201 |
+
cosine schedule
|
202 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
203 |
+
"""
|
204 |
+
steps = timesteps + 1
|
205 |
+
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
|
206 |
+
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
|
207 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
208 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
209 |
+
return torch.clip(betas, 0, 0.999)
|
210 |
+
|
211 |
+
|
212 |
+
# In[17]:
|
213 |
+
|
214 |
+
|
215 |
+
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
|
216 |
+
"""
|
217 |
+
sigmoid schedule
|
218 |
+
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
|
219 |
+
better for images > 64x64, when used during training
|
220 |
+
"""
|
221 |
+
steps = timesteps + 1
|
222 |
+
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
|
223 |
+
v_start = torch.tensor(start / tau).sigmoid()
|
224 |
+
v_end = torch.tensor(end / tau).sigmoid()
|
225 |
+
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
|
226 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
227 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
228 |
+
return torch.clip(betas, 0, 0.999)
|
229 |
+
|
230 |
+
|
231 |
+
# # Diffusion model
|
232 |
+
|
233 |
+
# In[18]:
|
234 |
+
|
235 |
+
|
236 |
+
class GaussianDiffusion(nn.Module):
|
237 |
+
# Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
model,
|
242 |
+
*,
|
243 |
+
image_size,
|
244 |
+
timesteps = 1000,
|
245 |
+
sampling_timesteps = None,
|
246 |
+
objective = 'pred_noise',
|
247 |
+
beta_schedule = 'linear',
|
248 |
+
schedule_fn_kwargs = dict(),
|
249 |
+
ddim_sampling_eta = 0.,
|
250 |
+
auto_normalize = True,
|
251 |
+
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
252 |
+
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
|
253 |
+
min_snr_gamma = 5
|
254 |
+
):
|
255 |
+
super().__init__()
|
256 |
+
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
|
257 |
+
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
|
258 |
+
|
259 |
+
self.model = model
|
260 |
+
|
261 |
+
self.channels = self.model.channels
|
262 |
+
self.self_condition = self.model.self_condition
|
263 |
+
|
264 |
+
self.image_size = image_size
|
265 |
+
|
266 |
+
self.objective = objective
|
267 |
+
|
268 |
+
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
|
269 |
+
|
270 |
+
if beta_schedule == 'linear':
|
271 |
+
beta_schedule_fn = linear_beta_schedule
|
272 |
+
elif beta_schedule == 'cosine':
|
273 |
+
beta_schedule_fn = cosine_beta_schedule
|
274 |
+
elif beta_schedule == 'sigmoid':
|
275 |
+
beta_schedule_fn = sigmoid_beta_schedule
|
276 |
+
else:
|
277 |
+
raise ValueError(f'unknown beta schedule {beta_schedule}')
|
278 |
+
|
279 |
+
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
|
280 |
+
|
281 |
+
alphas = 1. - betas
|
282 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
283 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
284 |
+
|
285 |
+
timesteps, = betas.shape
|
286 |
+
self.num_timesteps = int(timesteps)
|
287 |
+
|
288 |
+
# sampling related parameters
|
289 |
+
|
290 |
+
self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
|
291 |
+
|
292 |
+
assert self.sampling_timesteps <= timesteps
|
293 |
+
self.is_ddim_sampling = self.sampling_timesteps < timesteps
|
294 |
+
self.ddim_sampling_eta = ddim_sampling_eta
|
295 |
+
|
296 |
+
# helper function to register buffer from float64 to float32
|
297 |
+
|
298 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
299 |
+
|
300 |
+
register_buffer('betas', betas)
|
301 |
+
register_buffer('alphas_cumprod', alphas_cumprod)
|
302 |
+
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
303 |
+
|
304 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
305 |
+
|
306 |
+
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
307 |
+
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
308 |
+
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
309 |
+
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
310 |
+
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
311 |
+
|
312 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
313 |
+
|
314 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
315 |
+
|
316 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
317 |
+
|
318 |
+
register_buffer('posterior_variance', posterior_variance)
|
319 |
+
|
320 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
321 |
+
|
322 |
+
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
323 |
+
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
324 |
+
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
325 |
+
|
326 |
+
# offset noise strength - in blogpost, they claimed 0.1 was ideal
|
327 |
+
|
328 |
+
self.offset_noise_strength = offset_noise_strength
|
329 |
+
|
330 |
+
# derive loss weight
|
331 |
+
# snr - signal noise ratio
|
332 |
+
|
333 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
334 |
+
|
335 |
+
# https://arxiv.org/abs/2303.09556
|
336 |
+
|
337 |
+
maybe_clipped_snr = snr.clone()
|
338 |
+
if min_snr_loss_weight:
|
339 |
+
maybe_clipped_snr.clamp_(max = min_snr_gamma)
|
340 |
+
|
341 |
+
if objective == 'pred_noise':
|
342 |
+
register_buffer('loss_weight', maybe_clipped_snr / snr)
|
343 |
+
elif objective == 'pred_x0':
|
344 |
+
register_buffer('loss_weight', maybe_clipped_snr)
|
345 |
+
elif objective == 'pred_v':
|
346 |
+
register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
|
347 |
+
|
348 |
+
# auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
|
349 |
+
|
350 |
+
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
|
351 |
+
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
|
352 |
+
|
353 |
+
@property
|
354 |
+
def device(self):
|
355 |
+
return self.betas.device
|
356 |
+
|
357 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
358 |
+
return (
|
359 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
360 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
361 |
+
)
|
362 |
+
|
363 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
364 |
+
return (
|
365 |
+
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
366 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
367 |
+
)
|
368 |
+
|
369 |
+
def predict_v(self, x_start, t, noise):
|
370 |
+
return (
|
371 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
|
372 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
373 |
+
)
|
374 |
+
|
375 |
+
def predict_start_from_v(self, x_t, t, v):
|
376 |
+
return (
|
377 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
|
378 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
379 |
+
)
|
380 |
+
|
381 |
+
def q_posterior(self, x_start, x_t, t):
|
382 |
+
posterior_mean = (
|
383 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
384 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
385 |
+
)
|
386 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
387 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
388 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
389 |
+
|
390 |
+
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
|
391 |
+
model_output = self.model(x, t, x_self_cond)
|
392 |
+
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
|
393 |
+
|
394 |
+
if self.objective == 'pred_noise':
|
395 |
+
pred_noise = model_output
|
396 |
+
x_start = self.predict_start_from_noise(x, t, pred_noise)
|
397 |
+
x_start = maybe_clip(x_start)
|
398 |
+
|
399 |
+
if clip_x_start and rederive_pred_noise:
|
400 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
401 |
+
|
402 |
+
elif self.objective == 'pred_x0':
|
403 |
+
x_start = model_output
|
404 |
+
x_start = maybe_clip(x_start)
|
405 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
406 |
+
|
407 |
+
elif self.objective == 'pred_v':
|
408 |
+
v = model_output
|
409 |
+
x_start = self.predict_start_from_v(x, t, v)
|
410 |
+
x_start = maybe_clip(x_start)
|
411 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
412 |
+
|
413 |
+
return ModelPrediction(pred_noise, x_start)
|
414 |
+
|
415 |
+
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
|
416 |
+
preds = self.model_predictions(x, t, x_self_cond)
|
417 |
+
x_start = preds.pred_x_start
|
418 |
+
|
419 |
+
if clip_denoised:
|
420 |
+
x_start.clamp_(-1., 1.)
|
421 |
+
|
422 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
|
423 |
+
return model_mean, posterior_variance, posterior_log_variance, x_start
|
424 |
+
|
425 |
+
@torch.inference_mode()
|
426 |
+
def p_sample(self, x, t: int, x_self_cond = None):
|
427 |
+
b, *_, device = *x.shape, self.device
|
428 |
+
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
|
429 |
+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
|
430 |
+
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
|
431 |
+
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
|
432 |
+
return pred_img, x_start
|
433 |
+
|
434 |
+
@torch.inference_mode()
|
435 |
+
def p_sample_loop(self, shape, return_all_timesteps = False):
|
436 |
+
batch, device = shape[0], self.device
|
437 |
+
|
438 |
+
img = torch.randn(shape, device = device)
|
439 |
+
imgs = [img]
|
440 |
+
|
441 |
+
x_start = None
|
442 |
+
|
443 |
+
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
444 |
+
self_cond = x_start if self.self_condition else None
|
445 |
+
img, x_start = self.p_sample(img, t, self_cond)
|
446 |
+
imgs.append(img)
|
447 |
+
|
448 |
+
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
|
449 |
+
|
450 |
+
ret = self.unnormalize(ret)
|
451 |
+
return ret
|
452 |
+
|
453 |
+
@torch.inference_mode()
|
454 |
+
def ddim_sample(self, shape, return_all_timesteps = False):
|
455 |
+
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
|
456 |
+
|
457 |
+
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
|
458 |
+
times = list(reversed(times.int().tolist()))
|
459 |
+
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
|
460 |
+
|
461 |
+
img = torch.randn(shape, device = device)
|
462 |
+
imgs = [img]
|
463 |
+
|
464 |
+
x_start = None
|
465 |
+
|
466 |
+
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
467 |
+
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
468 |
+
self_cond = x_start if self.self_condition else None
|
469 |
+
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
|
470 |
+
|
471 |
+
if time_next < 0:
|
472 |
+
img = x_start
|
473 |
+
imgs.append(img)
|
474 |
+
continue
|
475 |
+
|
476 |
+
alpha = self.alphas_cumprod[time]
|
477 |
+
alpha_next = self.alphas_cumprod[time_next]
|
478 |
+
|
479 |
+
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
480 |
+
c = (1 - alpha_next - sigma ** 2).sqrt()
|
481 |
+
|
482 |
+
noise = torch.randn_like(img)
|
483 |
+
|
484 |
+
img = x_start * alpha_next.sqrt() + \
|
485 |
+
c * pred_noise + \
|
486 |
+
sigma * noise
|
487 |
+
|
488 |
+
imgs.append(img)
|
489 |
+
|
490 |
+
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
|
491 |
+
|
492 |
+
ret = self.unnormalize(ret)
|
493 |
+
return ret
|
494 |
+
|
495 |
+
@torch.inference_mode()
|
496 |
+
def sample(self, batch_size = 16, return_all_timesteps = False):
|
497 |
+
image_size, channels = self.image_size, self.channels
|
498 |
+
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
|
499 |
+
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
|
500 |
+
|
501 |
+
@torch.inference_mode()
|
502 |
+
def interpolate(self, x1, x2, t = None, lam = 0.5):
|
503 |
+
b, *_, device = *x1.shape, x1.device
|
504 |
+
t = default(t, self.num_timesteps - 1)
|
505 |
+
|
506 |
+
assert x1.shape == x2.shape
|
507 |
+
|
508 |
+
t_batched = torch.full((b,), t, device = device)
|
509 |
+
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
|
510 |
+
|
511 |
+
img = (1 - lam) * xt1 + lam * xt2
|
512 |
+
|
513 |
+
x_start = None
|
514 |
+
|
515 |
+
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
|
516 |
+
self_cond = x_start if self.self_condition else None
|
517 |
+
img, x_start = self.p_sample(img, i, self_cond)
|
518 |
+
|
519 |
+
return img
|
520 |
+
|
521 |
+
@autocast(enabled = False)
|
522 |
+
def q_sample(self, x_start, t, noise = None):
|
523 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
524 |
+
|
525 |
+
return (
|
526 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
527 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
528 |
+
)
|
529 |
+
|
530 |
+
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
|
531 |
+
b, c, h, w = x_start.shape
|
532 |
+
|
533 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
534 |
+
|
535 |
+
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
536 |
+
|
537 |
+
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
|
538 |
+
|
539 |
+
if offset_noise_strength > 0.:
|
540 |
+
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
|
541 |
+
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
|
542 |
+
|
543 |
+
# noise sample
|
544 |
+
|
545 |
+
x = self.q_sample(x_start = x_start, t = t, noise = noise)
|
546 |
+
|
547 |
+
# if doing self-conditioning, 50% of the time, predict x_start from current set of times
|
548 |
+
# and condition with unet with that
|
549 |
+
# this technique will slow down training by 25%, but seems to lower FID significantly
|
550 |
+
|
551 |
+
x_self_cond = None
|
552 |
+
if self.self_condition and random() < 0.5:
|
553 |
+
with torch.no_grad():
|
554 |
+
x_self_cond = self.model_predictions(x, t).pred_x_start
|
555 |
+
x_self_cond.detach_()
|
556 |
+
|
557 |
+
# predict and take gradient step
|
558 |
+
|
559 |
+
model_out = self.model(x, t, x_self_cond)
|
560 |
+
|
561 |
+
if self.objective == 'pred_noise':
|
562 |
+
target = noise
|
563 |
+
elif self.objective == 'pred_x0':
|
564 |
+
target = x_start
|
565 |
+
elif self.objective == 'pred_v':
|
566 |
+
v = self.predict_v(x_start, t, noise)
|
567 |
+
target = v
|
568 |
+
else:
|
569 |
+
raise ValueError(f'unknown objective {self.objective}')
|
570 |
+
|
571 |
+
loss = F.mse_loss(model_out, target, reduction = 'none')
|
572 |
+
loss = reduce(loss, 'b ... -> b', 'mean')
|
573 |
+
|
574 |
+
loss = loss * extract(self.loss_weight, t, loss.shape)
|
575 |
+
return loss.mean()
|
576 |
+
|
577 |
+
def forward(self, img, *args, **kwargs):
|
578 |
+
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
|
579 |
+
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
580 |
+
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
581 |
+
|
582 |
+
img = self.normalize(img)
|
583 |
+
return self.p_losses(img, t, *args, **kwargs)
|
584 |
+
|
585 |
+
|
586 |
+
# # Resnet Model
|
587 |
+
|
588 |
+
# In[19]:
|
589 |
+
|
590 |
+
|
591 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
592 |
+
return nn.Conv2d(
|
593 |
+
in_channels, out_channels, kernel_size,
|
594 |
+
padding=(kernel_size//2), bias=bias)
|
595 |
+
|
596 |
+
|
597 |
+
# In[20]:
|
598 |
+
|
599 |
+
|
600 |
+
class Swish(nn.Module):
|
601 |
+
def forward(self, x):
|
602 |
+
return x * torch.sigmoid(x)
|
603 |
+
|
604 |
+
|
605 |
+
# In[21]:
|
606 |
+
|
607 |
+
|
608 |
+
class AttnBlock(nn.Module):
|
609 |
+
def __init__(self, in_ch):
|
610 |
+
super().__init__()
|
611 |
+
self.group_norm = nn.GroupNorm(32, in_ch)
|
612 |
+
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
613 |
+
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
614 |
+
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
615 |
+
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
616 |
+
|
617 |
+
def forward(self, x):
|
618 |
+
B, C, H, W = x.shape
|
619 |
+
h = self.group_norm(x)
|
620 |
+
q = self.proj_q(h)
|
621 |
+
k = self.proj_k(h)
|
622 |
+
v = self.proj_v(h)
|
623 |
+
|
624 |
+
q = q.permute(0, 2, 3, 1).view(B, H * W, C)
|
625 |
+
k = k.view(B, C, H * W)
|
626 |
+
w = torch.bmm(q, k) * (int(C) ** (-0.5))
|
627 |
+
assert list(w.shape) == [B, H * W, H * W]
|
628 |
+
w = F.softmax(w, dim=-1)
|
629 |
+
|
630 |
+
v = v.permute(0, 2, 3, 1).view(B, H * W, C)
|
631 |
+
h = torch.bmm(w, v)
|
632 |
+
assert list(h.shape) == [B, H * W, C]
|
633 |
+
h = h.view(B, H, W, C).permute(0, 3, 1, 2)
|
634 |
+
h = self.proj(h)
|
635 |
+
|
636 |
+
return x + h
|
637 |
+
|
638 |
+
|
639 |
+
# In[22]:
|
640 |
+
|
641 |
+
|
642 |
+
class ResBlock(nn.Module):
|
643 |
+
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
|
644 |
+
super().__init__()
|
645 |
+
self.block1 = nn.Sequential(
|
646 |
+
nn.GroupNorm(32, in_ch),
|
647 |
+
Swish(),
|
648 |
+
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
|
649 |
+
)
|
650 |
+
self.temb_proj = nn.Sequential(
|
651 |
+
Swish(),
|
652 |
+
nn.Linear(tdim, out_ch),
|
653 |
+
)
|
654 |
+
self.block2 = nn.Sequential(
|
655 |
+
nn.GroupNorm(32, out_ch),
|
656 |
+
Swish(),
|
657 |
+
nn.Dropout(dropout),
|
658 |
+
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
|
659 |
+
)
|
660 |
+
if in_ch != out_ch:
|
661 |
+
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
|
662 |
+
else:
|
663 |
+
self.shortcut = nn.Identity()
|
664 |
+
if attn:
|
665 |
+
self.attn = AttnBlock(out_ch)
|
666 |
+
else:
|
667 |
+
self.attn = nn.Identity()
|
668 |
+
|
669 |
+
def forward(self, x, temb):
|
670 |
+
h = self.block1(x)
|
671 |
+
h += self.temb_proj(temb)[:, :, None, None]
|
672 |
+
h = self.block2(h)
|
673 |
+
|
674 |
+
h = h + self.shortcut(x)
|
675 |
+
h = self.attn(h)
|
676 |
+
return h
|
677 |
+
|
678 |
+
|
679 |
+
# In[23]:
|
680 |
+
|
681 |
+
|
682 |
+
class EDSR(nn.Module):
|
683 |
+
# Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
|
684 |
+
|
685 |
+
def __init__(self,
|
686 |
+
resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'],
|
687 |
+
n_feats=128,
|
688 |
+
t_dim=256,
|
689 |
+
dropout=0.1,
|
690 |
+
channels=1,
|
691 |
+
out_dim=1,
|
692 |
+
self_condition = False,
|
693 |
+
learned_sinusoidal_cond=False,
|
694 |
+
random_fourier_features=False,
|
695 |
+
learned_sinusoidal_dim=16,
|
696 |
+
sinusoidal_pos_emb_theta=10000,
|
697 |
+
conv=default_conv):
|
698 |
+
super(EDSR, self).__init__()
|
699 |
+
|
700 |
+
self.resblocks = resblocks
|
701 |
+
self.n_feats = n_feats
|
702 |
+
self.t_dim = t_dim
|
703 |
+
self.dropout = dropout
|
704 |
+
self.channels = channels
|
705 |
+
self.out_dim = out_dim
|
706 |
+
self.self_condition = self_condition
|
707 |
+
self.kernel_size = 3
|
708 |
+
|
709 |
+
# define time embedding
|
710 |
+
if learned_sinusoidal_cond:
|
711 |
+
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
|
712 |
+
fourier_dim = learned_sinusoidal_dim + 1
|
713 |
+
else:
|
714 |
+
sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
|
715 |
+
fourier_dim = self.n_feats
|
716 |
+
|
717 |
+
self.time_mlp = nn.Sequential(
|
718 |
+
sinu_pos_emb,
|
719 |
+
nn.Linear(fourier_dim, self.t_dim),
|
720 |
+
nn.GELU(),
|
721 |
+
nn.Linear(self.t_dim, self.t_dim)
|
722 |
+
)
|
723 |
+
|
724 |
+
# define head module
|
725 |
+
self.head = conv(self.channels, self.n_feats, self.kernel_size)
|
726 |
+
|
727 |
+
# define body module
|
728 |
+
self.body = nn.ModuleList()
|
729 |
+
for block in resblocks:
|
730 |
+
if block == "ResBlock":
|
731 |
+
self.body.append(
|
732 |
+
ResBlock(in_ch=self.n_feats,
|
733 |
+
out_ch=self.n_feats,
|
734 |
+
tdim=self.t_dim,
|
735 |
+
dropout=self.dropout,
|
736 |
+
attn=False))
|
737 |
+
elif block == "AttnBlock":
|
738 |
+
self.body.append(
|
739 |
+
ResBlock(in_ch=self.n_feats,
|
740 |
+
out_ch=self.n_feats,
|
741 |
+
tdim=self.t_dim,
|
742 |
+
dropout=self.dropout,
|
743 |
+
attn=True))
|
744 |
+
else:
|
745 |
+
raise NotImplementedError("Model currently doesn't support this kind of block!")
|
746 |
+
self.body.append(conv(self.n_feats, self.n_feats, self.kernel_size))
|
747 |
+
|
748 |
+
# define tail module
|
749 |
+
self.tail = conv(self.n_feats, self.out_dim, self.kernel_size)
|
750 |
+
|
751 |
+
|
752 |
+
def forward(self, x, t, cond=None):
|
753 |
+
t = self.time_mlp(t)
|
754 |
+
|
755 |
+
x = self.head(x)
|
756 |
+
|
757 |
+
res = x
|
758 |
+
for block in self.body:
|
759 |
+
if isinstance(block, ResBlock):
|
760 |
+
res = block(res, t)
|
761 |
+
else:
|
762 |
+
res = block(res)
|
763 |
+
res += x
|
764 |
+
|
765 |
+
x = self.tail(res)
|
766 |
+
|
767 |
+
return x
|
768 |
+
|
769 |
+
|
770 |
+
# # Train
|
771 |
+
|
772 |
+
# In[24]:
|
773 |
+
|
774 |
+
|
775 |
+
# output dir
|
776 |
+
save_path = 'resnet/model'
|
777 |
+
log_path = 'resnet/log'
|
778 |
+
|
779 |
+
if not os.path.exists(log_path):
|
780 |
+
os.mkdir(log_path)
|
781 |
+
if not os.path.exists(save_path):
|
782 |
+
os.mkdir(save_path)
|
783 |
+
|
784 |
+
|
785 |
+
# In[25]:
|
786 |
+
|
787 |
+
|
788 |
+
# setup logging
|
789 |
+
|
790 |
+
# Setup logging to file
|
791 |
+
logging.basicConfig(
|
792 |
+
filename=os.path.join(log_path, 'info.log'),
|
793 |
+
filemode="w",
|
794 |
+
level=logging.DEBUG,
|
795 |
+
format= '[%(asctime)s] %(levelname)s - %(message)s',
|
796 |
+
datefmt='%H:%M:%S',
|
797 |
+
force=True
|
798 |
+
)
|
799 |
+
|
800 |
+
|
801 |
+
# Stop PIL from printing to file
|
802 |
+
pil_logger = logging.getLogger('PIL')
|
803 |
+
pil_logger.setLevel(logging.INFO)
|
804 |
+
|
805 |
+
# write and print at the same time
|
806 |
+
console = logging.StreamHandler()
|
807 |
+
console.setLevel(logging.INFO)
|
808 |
+
logging.getLogger().addHandler(console)
|
809 |
+
|
810 |
+
logger = logging.getLogger('Diffusion_Resnet')
|
811 |
+
|
812 |
+
|
813 |
+
# In[26]:
|
814 |
+
|
815 |
+
|
816 |
+
# define model
|
817 |
+
model = EDSR(
|
818 |
+
resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock',
|
819 |
+
'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock',],
|
820 |
+
n_feats=256,
|
821 |
+
t_dim=512,
|
822 |
+
dropout=0.1,
|
823 |
+
channels=1, # MNIST
|
824 |
+
out_dim=1, # MNIST
|
825 |
+
learned_sinusoidal_cond=False,
|
826 |
+
random_fourier_features=False,
|
827 |
+
learned_sinusoidal_dim=16,
|
828 |
+
sinusoidal_pos_emb_theta=10000,)
|
829 |
+
|
830 |
+
diffusion_model = GaussianDiffusion(
|
831 |
+
model,
|
832 |
+
image_size=28, # MNIST
|
833 |
+
timesteps=1000,
|
834 |
+
sampling_timesteps=None,
|
835 |
+
objective ='pred_noise',
|
836 |
+
beta_schedule ='linear',
|
837 |
+
schedule_fn_kwargs=dict(),
|
838 |
+
ddim_sampling_eta= 0.,
|
839 |
+
auto_normalize = True,
|
840 |
+
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
841 |
+
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
|
842 |
+
min_snr_gamma = 5)
|
843 |
+
|
844 |
+
|
845 |
+
# In[27]:
|
846 |
+
|
847 |
+
|
848 |
+
# define dataset
|
849 |
+
transform = transforms.Compose([
|
850 |
+
transforms.ToTensor(),
|
851 |
+
# v2.Normalize((0.1307,), (0.3081,)), # https://stackoverflow.com/questions/70892017/normalize-mnist-in-pytorch
|
852 |
+
])
|
853 |
+
|
854 |
+
train_dataset = torchvision.datasets.MNIST(root='.', train=True,
|
855 |
+
download=True, transform=transform)
|
856 |
+
# test_dataset = torchvision.datasets.MNIST(root='.', train=True,
|
857 |
+
# download=True, transform=transform)
|
858 |
+
|
859 |
+
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
860 |
+
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)
|
861 |
+
|
862 |
+
|
863 |
+
# In[28]:
|
864 |
+
|
865 |
+
|
866 |
+
# define optimizer
|
867 |
+
train_lr = 1e-4
|
868 |
+
adam_betas = (0.9, 0.99)
|
869 |
+
|
870 |
+
optimizer = Adam(diffusion_model.parameters(),
|
871 |
+
lr=train_lr,
|
872 |
+
betas=adam_betas)
|
873 |
+
|
874 |
+
|
875 |
+
# In[29]:
|
876 |
+
|
877 |
+
|
878 |
+
# device
|
879 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
880 |
+
|
881 |
+
|
882 |
+
# In[30]:
|
883 |
+
|
884 |
+
|
885 |
+
# trainer
|
886 |
+
max_epoches = 50
|
887 |
+
iter_print = 100
|
888 |
+
iter_sample = 1000
|
889 |
+
save_each = 1
|
890 |
+
|
891 |
+
diffusion_model = diffusion_model.to(device)
|
892 |
+
|
893 |
+
last_trained_path = None
|
894 |
+
if last_trained_path:
|
895 |
+
data = torch.load(os.path.join(last_trained_path))
|
896 |
+
diffusion_model.load_state_dict(data['model'])
|
897 |
+
optimizer.load_state_dict(data['opt'])
|
898 |
+
count = data['step']
|
899 |
+
start_epoch = data['epoch']
|
900 |
+
log_loss = data['loss']
|
901 |
+
else:
|
902 |
+
count = 0
|
903 |
+
start_epoch = 1
|
904 |
+
log_loss = []
|
905 |
+
|
906 |
+
for epoch in range(start_epoch, max_epoches+1):
|
907 |
+
diffusion_model.train()
|
908 |
+
for img, _ in train_dataloader:
|
909 |
+
img = img.to(device)
|
910 |
+
|
911 |
+
loss = diffusion_model(img)
|
912 |
+
|
913 |
+
optimizer.zero_grad()
|
914 |
+
loss.backward()
|
915 |
+
optimizer.step()
|
916 |
+
|
917 |
+
if count % iter_print == 0 or count == 0:
|
918 |
+
logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(
|
919 |
+
epoch,
|
920 |
+
max_epoches,
|
921 |
+
count,
|
922 |
+
loss.mean().item(),
|
923 |
+
train_lr,
|
924 |
+
))
|
925 |
+
|
926 |
+
log_loss.append(loss.mean().item())
|
927 |
+
|
928 |
+
loss = None
|
929 |
+
|
930 |
+
count += 1
|
931 |
+
|
932 |
+
if count % iter_sample == 0:
|
933 |
+
diffusion_model.eval()
|
934 |
+
|
935 |
+
sample_imgs = diffusion_model.sample(batch_size=16)
|
936 |
+
|
937 |
+
utils.save_image(sample_imgs,
|
938 |
+
os.path.join(log_path, f"iter_{count}.png"),
|
939 |
+
nrow = int(math.sqrt(16)))
|
940 |
+
|
941 |
+
|
942 |
+
if epoch % save_each == 0:
|
943 |
+
data = {
|
944 |
+
'model': diffusion_model.state_dict(),
|
945 |
+
'opt': optimizer.state_dict(),
|
946 |
+
'step': count,
|
947 |
+
'epoch': epoch,
|
948 |
+
'loss': log_loss,
|
949 |
+
}
|
950 |
+
|
951 |
+
torch.save(data, os.path.join(save_path, f"epoch_{epoch}.pth"))
|
952 |
+
|
resnet/DDPM_ResNet_sample.py
ADDED
@@ -0,0 +1,856 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
# # Library
|
5 |
+
|
6 |
+
# In[1]:
|
7 |
+
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch import nn
|
12 |
+
from torch.cuda.amp import autocast
|
13 |
+
|
14 |
+
import torchvision
|
15 |
+
from torchvision.transforms import transforms
|
16 |
+
from torch.utils.data import DataLoader
|
17 |
+
|
18 |
+
from torch.optim import Adam
|
19 |
+
|
20 |
+
from einops import rearrange, reduce, repeat
|
21 |
+
import math
|
22 |
+
from random import random
|
23 |
+
|
24 |
+
from collections import namedtuple
|
25 |
+
from functools import partial
|
26 |
+
from tqdm.auto import tqdm
|
27 |
+
import logging
|
28 |
+
import os
|
29 |
+
|
30 |
+
from PIL import Image
|
31 |
+
from torchvision import utils
|
32 |
+
|
33 |
+
|
34 |
+
# # Helper
|
35 |
+
|
36 |
+
# ### Constant
|
37 |
+
|
38 |
+
# In[2]:
|
39 |
+
|
40 |
+
|
41 |
+
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
|
42 |
+
|
43 |
+
|
44 |
+
# ### Functions
|
45 |
+
|
46 |
+
# In[3]:
|
47 |
+
|
48 |
+
|
49 |
+
def exists(x):
|
50 |
+
return x is not None
|
51 |
+
|
52 |
+
def default(val, d):
|
53 |
+
if exists(val):
|
54 |
+
return val
|
55 |
+
return d() if callable(d) else d
|
56 |
+
|
57 |
+
|
58 |
+
# In[4]:
|
59 |
+
|
60 |
+
|
61 |
+
def cast_tuple(t, length = 1):
|
62 |
+
if isinstance(t, tuple):
|
63 |
+
return t
|
64 |
+
return ((t,) * length)
|
65 |
+
|
66 |
+
|
67 |
+
# In[5]:
|
68 |
+
|
69 |
+
|
70 |
+
def divisible_by(numer, denom):
|
71 |
+
return (numer % denom) == 0
|
72 |
+
|
73 |
+
|
74 |
+
# In[6]:
|
75 |
+
|
76 |
+
|
77 |
+
def identity(t, *args, **kwargs):
|
78 |
+
return t
|
79 |
+
|
80 |
+
|
81 |
+
# In[7]:
|
82 |
+
|
83 |
+
|
84 |
+
def cycle(dl):
|
85 |
+
while True:
|
86 |
+
for data in dl:
|
87 |
+
yield data
|
88 |
+
|
89 |
+
|
90 |
+
# In[8]:
|
91 |
+
|
92 |
+
|
93 |
+
def has_int_squareroot(num):
|
94 |
+
return (math.sqrt(num) ** 2) == num
|
95 |
+
|
96 |
+
|
97 |
+
# In[9]:
|
98 |
+
|
99 |
+
|
100 |
+
def num_to_groups(num, divisor):
|
101 |
+
groups = num // divisor
|
102 |
+
remainder = num % divisor
|
103 |
+
arr = [divisor] * groups
|
104 |
+
if remainder > 0:
|
105 |
+
arr.append(remainder)
|
106 |
+
return arr
|
107 |
+
|
108 |
+
|
109 |
+
# In[10]:
|
110 |
+
|
111 |
+
|
112 |
+
def convert_image_to_fn(img_type, image):
|
113 |
+
if image.mode != img_type:
|
114 |
+
return image.convert(img_type)
|
115 |
+
return image
|
116 |
+
|
117 |
+
|
118 |
+
# In[11]:
|
119 |
+
|
120 |
+
|
121 |
+
def extract(a, t, x_shape):
|
122 |
+
b, *_ = t.shape
|
123 |
+
out = a.gather(-1, t)
|
124 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
125 |
+
|
126 |
+
|
127 |
+
# ### Normalization Functions
|
128 |
+
|
129 |
+
# In[12]:
|
130 |
+
|
131 |
+
|
132 |
+
def normalize_to_neg_one_to_one(img):
|
133 |
+
return img * 2 - 1
|
134 |
+
|
135 |
+
def unnormalize_to_zero_to_one(t):
|
136 |
+
return (t + 1) * 0.5
|
137 |
+
|
138 |
+
|
139 |
+
# ### Sinusoidal positional embeds
|
140 |
+
|
141 |
+
# In[13]:
|
142 |
+
|
143 |
+
|
144 |
+
class SinusoidalPosEmb(nn.Module):
|
145 |
+
def __init__(self, dim, theta = 10000):
|
146 |
+
super().__init__()
|
147 |
+
self.dim = dim
|
148 |
+
self.theta = theta
|
149 |
+
|
150 |
+
def forward(self, x):
|
151 |
+
device = x.device
|
152 |
+
half_dim = self.dim // 2
|
153 |
+
emb = math.log(self.theta) / (half_dim - 1)
|
154 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
155 |
+
emb = x[:, None] * emb[None, :]
|
156 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
157 |
+
return emb
|
158 |
+
|
159 |
+
|
160 |
+
# In[14]:
|
161 |
+
|
162 |
+
|
163 |
+
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
|
164 |
+
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
|
165 |
+
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
|
166 |
+
|
167 |
+
def __init__(self, dim, is_random = False):
|
168 |
+
super().__init__()
|
169 |
+
assert divisible_by(dim, 2)
|
170 |
+
half_dim = dim // 2
|
171 |
+
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = rearrange(x, 'b -> b 1')
|
175 |
+
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
|
176 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
|
177 |
+
fouriered = torch.cat((x, fouriered), dim = -1)
|
178 |
+
return fouriered
|
179 |
+
|
180 |
+
|
181 |
+
# ### Schedule
|
182 |
+
|
183 |
+
# In[15]:
|
184 |
+
|
185 |
+
|
186 |
+
def linear_beta_schedule(timesteps):
|
187 |
+
"""
|
188 |
+
linear schedule, proposed in original ddpm paper
|
189 |
+
"""
|
190 |
+
scale = 1000 / timesteps
|
191 |
+
beta_start = scale * 0.0001
|
192 |
+
beta_end = scale * 0.02
|
193 |
+
return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
|
194 |
+
|
195 |
+
|
196 |
+
# In[16]:
|
197 |
+
|
198 |
+
|
199 |
+
def cosine_beta_schedule(timesteps, s = 0.008):
|
200 |
+
"""
|
201 |
+
cosine schedule
|
202 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
203 |
+
"""
|
204 |
+
steps = timesteps + 1
|
205 |
+
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
|
206 |
+
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
|
207 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
208 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
209 |
+
return torch.clip(betas, 0, 0.999)
|
210 |
+
|
211 |
+
|
212 |
+
# In[17]:
|
213 |
+
|
214 |
+
|
215 |
+
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
|
216 |
+
"""
|
217 |
+
sigmoid schedule
|
218 |
+
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
|
219 |
+
better for images > 64x64, when used during training
|
220 |
+
"""
|
221 |
+
steps = timesteps + 1
|
222 |
+
t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
|
223 |
+
v_start = torch.tensor(start / tau).sigmoid()
|
224 |
+
v_end = torch.tensor(end / tau).sigmoid()
|
225 |
+
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
|
226 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
227 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
228 |
+
return torch.clip(betas, 0, 0.999)
|
229 |
+
|
230 |
+
|
231 |
+
# # Diffusion model
|
232 |
+
|
233 |
+
# In[18]:
|
234 |
+
|
235 |
+
|
236 |
+
class GaussianDiffusion(nn.Module):
|
237 |
+
# Copy from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L163
|
238 |
+
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
model,
|
242 |
+
*,
|
243 |
+
image_size,
|
244 |
+
timesteps = 1000,
|
245 |
+
sampling_timesteps = None,
|
246 |
+
objective = 'pred_noise',
|
247 |
+
beta_schedule = 'linear',
|
248 |
+
schedule_fn_kwargs = dict(),
|
249 |
+
ddim_sampling_eta = 0.,
|
250 |
+
auto_normalize = True,
|
251 |
+
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
252 |
+
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
|
253 |
+
min_snr_gamma = 5
|
254 |
+
):
|
255 |
+
super().__init__()
|
256 |
+
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
|
257 |
+
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
|
258 |
+
|
259 |
+
self.model = model
|
260 |
+
|
261 |
+
self.channels = self.model.channels
|
262 |
+
self.self_condition = self.model.self_condition
|
263 |
+
|
264 |
+
self.image_size = image_size
|
265 |
+
|
266 |
+
self.objective = objective
|
267 |
+
|
268 |
+
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
|
269 |
+
|
270 |
+
if beta_schedule == 'linear':
|
271 |
+
beta_schedule_fn = linear_beta_schedule
|
272 |
+
elif beta_schedule == 'cosine':
|
273 |
+
beta_schedule_fn = cosine_beta_schedule
|
274 |
+
elif beta_schedule == 'sigmoid':
|
275 |
+
beta_schedule_fn = sigmoid_beta_schedule
|
276 |
+
else:
|
277 |
+
raise ValueError(f'unknown beta schedule {beta_schedule}')
|
278 |
+
|
279 |
+
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
|
280 |
+
|
281 |
+
alphas = 1. - betas
|
282 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
283 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
284 |
+
|
285 |
+
timesteps, = betas.shape
|
286 |
+
self.num_timesteps = int(timesteps)
|
287 |
+
|
288 |
+
# sampling related parameters
|
289 |
+
|
290 |
+
self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
|
291 |
+
|
292 |
+
assert self.sampling_timesteps <= timesteps
|
293 |
+
self.is_ddim_sampling = self.sampling_timesteps < timesteps
|
294 |
+
self.ddim_sampling_eta = ddim_sampling_eta
|
295 |
+
|
296 |
+
# helper function to register buffer from float64 to float32
|
297 |
+
|
298 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
299 |
+
|
300 |
+
register_buffer('betas', betas)
|
301 |
+
register_buffer('alphas_cumprod', alphas_cumprod)
|
302 |
+
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
303 |
+
|
304 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
305 |
+
|
306 |
+
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
307 |
+
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
308 |
+
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
309 |
+
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
310 |
+
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
311 |
+
|
312 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
313 |
+
|
314 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
315 |
+
|
316 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
317 |
+
|
318 |
+
register_buffer('posterior_variance', posterior_variance)
|
319 |
+
|
320 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
321 |
+
|
322 |
+
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
323 |
+
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
324 |
+
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
325 |
+
|
326 |
+
# offset noise strength - in blogpost, they claimed 0.1 was ideal
|
327 |
+
|
328 |
+
self.offset_noise_strength = offset_noise_strength
|
329 |
+
|
330 |
+
# derive loss weight
|
331 |
+
# snr - signal noise ratio
|
332 |
+
|
333 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
334 |
+
|
335 |
+
# https://arxiv.org/abs/2303.09556
|
336 |
+
|
337 |
+
maybe_clipped_snr = snr.clone()
|
338 |
+
if min_snr_loss_weight:
|
339 |
+
maybe_clipped_snr.clamp_(max = min_snr_gamma)
|
340 |
+
|
341 |
+
if objective == 'pred_noise':
|
342 |
+
register_buffer('loss_weight', maybe_clipped_snr / snr)
|
343 |
+
elif objective == 'pred_x0':
|
344 |
+
register_buffer('loss_weight', maybe_clipped_snr)
|
345 |
+
elif objective == 'pred_v':
|
346 |
+
register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
|
347 |
+
|
348 |
+
# auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
|
349 |
+
|
350 |
+
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
|
351 |
+
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
|
352 |
+
|
353 |
+
@property
|
354 |
+
def device(self):
|
355 |
+
return self.betas.device
|
356 |
+
|
357 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
358 |
+
return (
|
359 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
360 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
361 |
+
)
|
362 |
+
|
363 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
364 |
+
return (
|
365 |
+
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
|
366 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
367 |
+
)
|
368 |
+
|
369 |
+
def predict_v(self, x_start, t, noise):
|
370 |
+
return (
|
371 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
|
372 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
373 |
+
)
|
374 |
+
|
375 |
+
def predict_start_from_v(self, x_t, t, v):
|
376 |
+
return (
|
377 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
|
378 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
379 |
+
)
|
380 |
+
|
381 |
+
def q_posterior(self, x_start, x_t, t):
|
382 |
+
posterior_mean = (
|
383 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
384 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
385 |
+
)
|
386 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
387 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
388 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
389 |
+
|
390 |
+
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
|
391 |
+
model_output = self.model(x, t, x_self_cond)
|
392 |
+
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
|
393 |
+
|
394 |
+
if self.objective == 'pred_noise':
|
395 |
+
pred_noise = model_output
|
396 |
+
x_start = self.predict_start_from_noise(x, t, pred_noise)
|
397 |
+
x_start = maybe_clip(x_start)
|
398 |
+
|
399 |
+
if clip_x_start and rederive_pred_noise:
|
400 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
401 |
+
|
402 |
+
elif self.objective == 'pred_x0':
|
403 |
+
x_start = model_output
|
404 |
+
x_start = maybe_clip(x_start)
|
405 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
406 |
+
|
407 |
+
elif self.objective == 'pred_v':
|
408 |
+
v = model_output
|
409 |
+
x_start = self.predict_start_from_v(x, t, v)
|
410 |
+
x_start = maybe_clip(x_start)
|
411 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
412 |
+
|
413 |
+
return ModelPrediction(pred_noise, x_start)
|
414 |
+
|
415 |
+
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
|
416 |
+
preds = self.model_predictions(x, t, x_self_cond)
|
417 |
+
x_start = preds.pred_x_start
|
418 |
+
|
419 |
+
if clip_denoised:
|
420 |
+
x_start.clamp_(-1., 1.)
|
421 |
+
|
422 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
|
423 |
+
return model_mean, posterior_variance, posterior_log_variance, x_start
|
424 |
+
|
425 |
+
@torch.inference_mode()
|
426 |
+
def p_sample(self, x, t: int, x_self_cond = None):
|
427 |
+
b, *_, device = *x.shape, self.device
|
428 |
+
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
|
429 |
+
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
|
430 |
+
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
|
431 |
+
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
|
432 |
+
return pred_img, x_start
|
433 |
+
|
434 |
+
@torch.inference_mode()
|
435 |
+
def p_sample_loop(self, shape, return_all_timesteps = False):
|
436 |
+
batch, device = shape[0], self.device
|
437 |
+
|
438 |
+
img = torch.randn(shape, device = device)
|
439 |
+
imgs = [img]
|
440 |
+
|
441 |
+
x_start = None
|
442 |
+
|
443 |
+
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
444 |
+
self_cond = x_start if self.self_condition else None
|
445 |
+
img, x_start = self.p_sample(img, t, self_cond)
|
446 |
+
imgs.append(img)
|
447 |
+
|
448 |
+
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
|
449 |
+
|
450 |
+
ret = self.unnormalize(ret)
|
451 |
+
return ret
|
452 |
+
|
453 |
+
@torch.inference_mode()
|
454 |
+
def ddim_sample(self, shape, return_all_timesteps = False):
|
455 |
+
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
|
456 |
+
|
457 |
+
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
|
458 |
+
times = list(reversed(times.int().tolist()))
|
459 |
+
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
|
460 |
+
|
461 |
+
img = torch.randn(shape, device = device)
|
462 |
+
imgs = [img]
|
463 |
+
|
464 |
+
x_start = None
|
465 |
+
|
466 |
+
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
|
467 |
+
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
468 |
+
self_cond = x_start if self.self_condition else None
|
469 |
+
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
|
470 |
+
|
471 |
+
if time_next < 0:
|
472 |
+
img = x_start
|
473 |
+
imgs.append(img)
|
474 |
+
continue
|
475 |
+
|
476 |
+
alpha = self.alphas_cumprod[time]
|
477 |
+
alpha_next = self.alphas_cumprod[time_next]
|
478 |
+
|
479 |
+
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
|
480 |
+
c = (1 - alpha_next - sigma ** 2).sqrt()
|
481 |
+
|
482 |
+
noise = torch.randn_like(img)
|
483 |
+
|
484 |
+
img = x_start * alpha_next.sqrt() + \
|
485 |
+
c * pred_noise + \
|
486 |
+
sigma * noise
|
487 |
+
|
488 |
+
imgs.append(img)
|
489 |
+
|
490 |
+
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
|
491 |
+
|
492 |
+
ret = self.unnormalize(ret)
|
493 |
+
return ret
|
494 |
+
|
495 |
+
@torch.inference_mode()
|
496 |
+
def sample(self, batch_size = 16, return_all_timesteps = False):
|
497 |
+
image_size, channels = self.image_size, self.channels
|
498 |
+
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
|
499 |
+
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
|
500 |
+
|
501 |
+
@torch.inference_mode()
|
502 |
+
def interpolate(self, x1, x2, t = None, lam = 0.5):
|
503 |
+
b, *_, device = *x1.shape, x1.device
|
504 |
+
t = default(t, self.num_timesteps - 1)
|
505 |
+
|
506 |
+
assert x1.shape == x2.shape
|
507 |
+
|
508 |
+
t_batched = torch.full((b,), t, device = device)
|
509 |
+
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
|
510 |
+
|
511 |
+
img = (1 - lam) * xt1 + lam * xt2
|
512 |
+
|
513 |
+
x_start = None
|
514 |
+
|
515 |
+
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
|
516 |
+
self_cond = x_start if self.self_condition else None
|
517 |
+
img, x_start = self.p_sample(img, i, self_cond)
|
518 |
+
|
519 |
+
return img
|
520 |
+
|
521 |
+
@autocast(enabled = False)
|
522 |
+
def q_sample(self, x_start, t, noise = None):
|
523 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
524 |
+
|
525 |
+
return (
|
526 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
527 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
528 |
+
)
|
529 |
+
|
530 |
+
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
|
531 |
+
b, c, h, w = x_start.shape
|
532 |
+
|
533 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
534 |
+
|
535 |
+
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
536 |
+
|
537 |
+
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
|
538 |
+
|
539 |
+
if offset_noise_strength > 0.:
|
540 |
+
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
|
541 |
+
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
|
542 |
+
|
543 |
+
# noise sample
|
544 |
+
|
545 |
+
x = self.q_sample(x_start = x_start, t = t, noise = noise)
|
546 |
+
|
547 |
+
# if doing self-conditioning, 50% of the time, predict x_start from current set of times
|
548 |
+
# and condition with unet with that
|
549 |
+
# this technique will slow down training by 25%, but seems to lower FID significantly
|
550 |
+
|
551 |
+
x_self_cond = None
|
552 |
+
if self.self_condition and random() < 0.5:
|
553 |
+
with torch.no_grad():
|
554 |
+
x_self_cond = self.model_predictions(x, t).pred_x_start
|
555 |
+
x_self_cond.detach_()
|
556 |
+
|
557 |
+
# predict and take gradient step
|
558 |
+
|
559 |
+
model_out = self.model(x, t, x_self_cond)
|
560 |
+
|
561 |
+
if self.objective == 'pred_noise':
|
562 |
+
target = noise
|
563 |
+
elif self.objective == 'pred_x0':
|
564 |
+
target = x_start
|
565 |
+
elif self.objective == 'pred_v':
|
566 |
+
v = self.predict_v(x_start, t, noise)
|
567 |
+
target = v
|
568 |
+
else:
|
569 |
+
raise ValueError(f'unknown objective {self.objective}')
|
570 |
+
|
571 |
+
loss = F.mse_loss(model_out, target, reduction = 'none')
|
572 |
+
loss = reduce(loss, 'b ... -> b', 'mean')
|
573 |
+
|
574 |
+
loss = loss * extract(self.loss_weight, t, loss.shape)
|
575 |
+
return loss.mean()
|
576 |
+
|
577 |
+
def forward(self, img, *args, **kwargs):
|
578 |
+
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
|
579 |
+
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
|
580 |
+
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
581 |
+
|
582 |
+
img = self.normalize(img)
|
583 |
+
return self.p_losses(img, t, *args, **kwargs)
|
584 |
+
|
585 |
+
|
586 |
+
# # Resnet Model
|
587 |
+
|
588 |
+
# In[19]:
|
589 |
+
|
590 |
+
|
591 |
+
def default_conv(in_channels, out_channels, kernel_size, bias=True):
|
592 |
+
return nn.Conv2d(
|
593 |
+
in_channels, out_channels, kernel_size,
|
594 |
+
padding=(kernel_size//2), bias=bias)
|
595 |
+
|
596 |
+
|
597 |
+
# In[20]:
|
598 |
+
|
599 |
+
|
600 |
+
class Swish(nn.Module):
|
601 |
+
def forward(self, x):
|
602 |
+
return x * torch.sigmoid(x)
|
603 |
+
|
604 |
+
|
605 |
+
# In[21]:
|
606 |
+
|
607 |
+
|
608 |
+
class AttnBlock(nn.Module):
|
609 |
+
def __init__(self, in_ch):
|
610 |
+
super().__init__()
|
611 |
+
self.group_norm = nn.GroupNorm(32, in_ch)
|
612 |
+
self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
613 |
+
self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
614 |
+
self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
615 |
+
self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
|
616 |
+
|
617 |
+
def forward(self, x):
|
618 |
+
B, C, H, W = x.shape
|
619 |
+
h = self.group_norm(x)
|
620 |
+
q = self.proj_q(h)
|
621 |
+
k = self.proj_k(h)
|
622 |
+
v = self.proj_v(h)
|
623 |
+
|
624 |
+
q = q.permute(0, 2, 3, 1).view(B, H * W, C)
|
625 |
+
k = k.view(B, C, H * W)
|
626 |
+
w = torch.bmm(q, k) * (int(C) ** (-0.5))
|
627 |
+
assert list(w.shape) == [B, H * W, H * W]
|
628 |
+
w = F.softmax(w, dim=-1)
|
629 |
+
|
630 |
+
v = v.permute(0, 2, 3, 1).view(B, H * W, C)
|
631 |
+
h = torch.bmm(w, v)
|
632 |
+
assert list(h.shape) == [B, H * W, C]
|
633 |
+
h = h.view(B, H, W, C).permute(0, 3, 1, 2)
|
634 |
+
h = self.proj(h)
|
635 |
+
|
636 |
+
return x + h
|
637 |
+
|
638 |
+
|
639 |
+
# In[22]:
|
640 |
+
|
641 |
+
|
642 |
+
class ResBlock(nn.Module):
|
643 |
+
def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
|
644 |
+
super().__init__()
|
645 |
+
self.block1 = nn.Sequential(
|
646 |
+
nn.GroupNorm(32, in_ch),
|
647 |
+
Swish(),
|
648 |
+
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
|
649 |
+
)
|
650 |
+
self.temb_proj = nn.Sequential(
|
651 |
+
Swish(),
|
652 |
+
nn.Linear(tdim, out_ch),
|
653 |
+
)
|
654 |
+
self.block2 = nn.Sequential(
|
655 |
+
nn.GroupNorm(32, out_ch),
|
656 |
+
Swish(),
|
657 |
+
nn.Dropout(dropout),
|
658 |
+
nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
|
659 |
+
)
|
660 |
+
if in_ch != out_ch:
|
661 |
+
self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
|
662 |
+
else:
|
663 |
+
self.shortcut = nn.Identity()
|
664 |
+
if attn:
|
665 |
+
self.attn = AttnBlock(out_ch)
|
666 |
+
else:
|
667 |
+
self.attn = nn.Identity()
|
668 |
+
|
669 |
+
def forward(self, x, temb):
|
670 |
+
h = self.block1(x)
|
671 |
+
h += self.temb_proj(temb)[:, :, None, None]
|
672 |
+
h = self.block2(h)
|
673 |
+
|
674 |
+
h = h + self.shortcut(x)
|
675 |
+
h = self.attn(h)
|
676 |
+
return h
|
677 |
+
|
678 |
+
|
679 |
+
# In[23]:
|
680 |
+
|
681 |
+
|
682 |
+
class EDSR(nn.Module):
|
683 |
+
# Modified from https://github.com/sanghyun-son/EDSR-PyTorch/blob/master/src/model/edsr.py#L31
|
684 |
+
|
685 |
+
def __init__(self,
|
686 |
+
resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'],
|
687 |
+
n_feats=128,
|
688 |
+
t_dim=256,
|
689 |
+
dropout=0.1,
|
690 |
+
channels=1,
|
691 |
+
out_dim=1,
|
692 |
+
self_condition = False,
|
693 |
+
learned_sinusoidal_cond=False,
|
694 |
+
random_fourier_features=False,
|
695 |
+
learned_sinusoidal_dim=16,
|
696 |
+
sinusoidal_pos_emb_theta=10000,
|
697 |
+
conv=default_conv):
|
698 |
+
super(EDSR, self).__init__()
|
699 |
+
|
700 |
+
self.resblocks = resblocks
|
701 |
+
self.n_feats = n_feats
|
702 |
+
self.t_dim = t_dim
|
703 |
+
self.dropout = dropout
|
704 |
+
self.channels = channels
|
705 |
+
self.out_dim = out_dim
|
706 |
+
self.self_condition = self_condition
|
707 |
+
self.kernel_size = 3
|
708 |
+
|
709 |
+
# define time embedding
|
710 |
+
if learned_sinusoidal_cond:
|
711 |
+
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
|
712 |
+
fourier_dim = learned_sinusoidal_dim + 1
|
713 |
+
else:
|
714 |
+
sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
|
715 |
+
fourier_dim = self.n_feats
|
716 |
+
|
717 |
+
self.time_mlp = nn.Sequential(
|
718 |
+
sinu_pos_emb,
|
719 |
+
nn.Linear(fourier_dim, self.t_dim),
|
720 |
+
nn.GELU(),
|
721 |
+
nn.Linear(self.t_dim, self.t_dim)
|
722 |
+
)
|
723 |
+
|
724 |
+
# define head module
|
725 |
+
self.head = conv(self.channels, self.n_feats, self.kernel_size)
|
726 |
+
|
727 |
+
# define body module
|
728 |
+
self.body = nn.ModuleList()
|
729 |
+
for block in resblocks:
|
730 |
+
if block == "ResBlock":
|
731 |
+
self.body.append(
|
732 |
+
ResBlock(in_ch=self.n_feats,
|
733 |
+
out_ch=self.n_feats,
|
734 |
+
tdim=self.t_dim,
|
735 |
+
dropout=self.dropout,
|
736 |
+
attn=False))
|
737 |
+
elif block == "AttnBlock":
|
738 |
+
self.body.append(
|
739 |
+
ResBlock(in_ch=self.n_feats,
|
740 |
+
out_ch=self.n_feats,
|
741 |
+
tdim=self.t_dim,
|
742 |
+
dropout=self.dropout,
|
743 |
+
attn=True))
|
744 |
+
else:
|
745 |
+
raise NotImplementedError("Model currently doesn't support this kind of block!")
|
746 |
+
self.body.append(conv(self.n_feats, self.n_feats, self.kernel_size))
|
747 |
+
|
748 |
+
# define tail module
|
749 |
+
self.tail = conv(self.n_feats, self.out_dim, self.kernel_size)
|
750 |
+
|
751 |
+
|
752 |
+
def forward(self, x, t, cond=None):
|
753 |
+
t = self.time_mlp(t)
|
754 |
+
|
755 |
+
x = self.head(x)
|
756 |
+
|
757 |
+
res = x
|
758 |
+
for block in self.body:
|
759 |
+
if isinstance(block, ResBlock):
|
760 |
+
res = block(res, t)
|
761 |
+
else:
|
762 |
+
res = block(res)
|
763 |
+
res += x
|
764 |
+
|
765 |
+
x = self.tail(res)
|
766 |
+
|
767 |
+
return x
|
768 |
+
|
769 |
+
|
770 |
+
# # Train
|
771 |
+
|
772 |
+
# In[24]:
|
773 |
+
|
774 |
+
|
775 |
+
# In[25]:
|
776 |
+
|
777 |
+
|
778 |
+
|
779 |
+
# In[26]:
|
780 |
+
|
781 |
+
|
782 |
+
# define model
|
783 |
+
model = EDSR(
|
784 |
+
resblocks=['ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock',
|
785 |
+
'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock',],
|
786 |
+
n_feats=256,
|
787 |
+
t_dim=512,
|
788 |
+
dropout=0.1,
|
789 |
+
channels=1, # MNIST
|
790 |
+
out_dim=1, # MNIST
|
791 |
+
learned_sinusoidal_cond=False,
|
792 |
+
random_fourier_features=False,
|
793 |
+
learned_sinusoidal_dim=16,
|
794 |
+
sinusoidal_pos_emb_theta=10000,)
|
795 |
+
|
796 |
+
diffusion_model = GaussianDiffusion(
|
797 |
+
model,
|
798 |
+
image_size=28, # MNIST
|
799 |
+
timesteps=1000,
|
800 |
+
sampling_timesteps=None,
|
801 |
+
objective ='pred_noise',
|
802 |
+
beta_schedule ='linear',
|
803 |
+
schedule_fn_kwargs=dict(),
|
804 |
+
ddim_sampling_eta= 0.,
|
805 |
+
auto_normalize = True,
|
806 |
+
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
|
807 |
+
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
|
808 |
+
min_snr_gamma = 5)
|
809 |
+
|
810 |
+
|
811 |
+
# In[27]:
|
812 |
+
|
813 |
+
|
814 |
+
# In[28]:
|
815 |
+
|
816 |
+
|
817 |
+
|
818 |
+
# In[29]:
|
819 |
+
|
820 |
+
|
821 |
+
# device
|
822 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
823 |
+
|
824 |
+
|
825 |
+
# In[30]:
|
826 |
+
|
827 |
+
|
828 |
+
# trainer
|
829 |
+
max_epoches = 50
|
830 |
+
iter_print = 100
|
831 |
+
iter_sample = 1000
|
832 |
+
save_each = 1
|
833 |
+
|
834 |
+
diffusion_model = diffusion_model.to(device)
|
835 |
+
|
836 |
+
last_trained_path = 'resnet\model\epoch_30.pth'
|
837 |
+
diffusion_model.load_state_dict(torch.load(os.path.join(last_trained_path))['model'])
|
838 |
+
|
839 |
+
sample_path = 'resnet/sample2'
|
840 |
+
|
841 |
+
if not os.path.exists(sample_path):
|
842 |
+
os.mkdir(sample_path)
|
843 |
+
|
844 |
+
num_sample = 500
|
845 |
+
sample_batch = 16
|
846 |
+
count = 0
|
847 |
+
|
848 |
+
if num_sample % sample_batch != 0:
|
849 |
+
num_sample = num_sample + (sample_batch - (num_sample % sample_batch))
|
850 |
+
|
851 |
+
for batch in range(num_sample//sample_batch):
|
852 |
+
imgs = diffusion_model.sample(batch_size=sample_batch, return_all_timesteps=False)
|
853 |
+
for i in range(imgs.size(0)):
|
854 |
+
torchvision.utils.save_image(imgs[i, :, :, :], os.path.join(sample_path ,f'{count}.png'))
|
855 |
+
count += 1
|
856 |
+
|
resnet/log/info.log
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[02:18:34] INFO - Epoch 1/50, Iter 0: Loss = 1.23480224609375, lr = 0.0001
|
2 |
+
[02:18:59] INFO - Epoch 1/50, Iter 100: Loss = 0.08826638758182526, lr = 0.0001
|
3 |
+
[02:19:38] INFO - Epoch 1/50, Iter 200: Loss = 0.0545695461332798, lr = 0.0001
|
4 |
+
[02:20:35] INFO - Epoch 1/50, Iter 300: Loss = 0.06626827269792557, lr = 0.0001
|
5 |
+
[02:21:29] INFO - Epoch 1/50, Iter 400: Loss = 0.07271286845207214, lr = 0.0001
|
6 |
+
[02:22:24] INFO - Epoch 1/50, Iter 500: Loss = 0.027932994067668915, lr = 0.0001
|
7 |
+
[02:23:19] INFO - Epoch 1/50, Iter 600: Loss = 0.037907857447862625, lr = 0.0001
|
8 |
+
[02:24:14] INFO - Epoch 1/50, Iter 700: Loss = 0.03283434733748436, lr = 0.0001
|
9 |
+
[02:25:10] INFO - Epoch 1/50, Iter 800: Loss = 0.0401763841509819, lr = 0.0001
|
10 |
+
[02:26:05] INFO - Epoch 1/50, Iter 900: Loss = 0.02380681410431862, lr = 0.0001
|
11 |
+
[02:28:36] INFO - Epoch 1/50, Iter 1000: Loss = 0.03142669051885605, lr = 0.0001
|
12 |
+
[02:29:30] INFO - Epoch 1/50, Iter 1100: Loss = 0.021915458142757416, lr = 0.0001
|
13 |
+
[02:30:25] INFO - Epoch 1/50, Iter 1200: Loss = 0.03710126131772995, lr = 0.0001
|
14 |
+
[02:31:19] INFO - Epoch 1/50, Iter 1300: Loss = 0.017894160002470016, lr = 0.0001
|
15 |
+
[02:32:14] INFO - Epoch 1/50, Iter 1400: Loss = 0.032229095697402954, lr = 0.0001
|
16 |
+
[02:33:08] INFO - Epoch 1/50, Iter 1500: Loss = 0.022246181964874268, lr = 0.0001
|
17 |
+
[02:34:04] INFO - Epoch 1/50, Iter 1600: Loss = 0.02387898601591587, lr = 0.0001
|
18 |
+
[02:34:58] INFO - Epoch 1/50, Iter 1700: Loss = 0.033216990530490875, lr = 0.0001
|
19 |
+
[02:35:53] INFO - Epoch 1/50, Iter 1800: Loss = 0.03182423859834671, lr = 0.0001
|
20 |
+
[02:36:48] INFO - Epoch 2/50, Iter 1900: Loss = 0.027017910033464432, lr = 0.0001
|
21 |
+
[02:39:20] INFO - Epoch 2/50, Iter 2000: Loss = 0.03848206251859665, lr = 0.0001
|
22 |
+
[02:40:14] INFO - Epoch 2/50, Iter 2100: Loss = 0.02826070785522461, lr = 0.0001
|
23 |
+
[02:41:09] INFO - Epoch 2/50, Iter 2200: Loss = 0.03657548129558563, lr = 0.0001
|
24 |
+
[02:42:04] INFO - Epoch 2/50, Iter 2300: Loss = 0.03236750513315201, lr = 0.0001
|
25 |
+
[02:42:59] INFO - Epoch 2/50, Iter 2400: Loss = 0.02394908107817173, lr = 0.0001
|
26 |
+
[02:43:54] INFO - Epoch 2/50, Iter 2500: Loss = 0.028264183551073074, lr = 0.0001
|
27 |
+
[02:44:48] INFO - Epoch 2/50, Iter 2600: Loss = 0.034485459327697754, lr = 0.0001
|
28 |
+
[02:45:44] INFO - Epoch 2/50, Iter 2700: Loss = 0.02295440435409546, lr = 0.0001
|
29 |
+
[02:46:38] INFO - Epoch 2/50, Iter 2800: Loss = 0.03146759420633316, lr = 0.0001
|
30 |
+
[02:47:33] INFO - Epoch 2/50, Iter 2900: Loss = 0.022224590182304382, lr = 0.0001
|
31 |
+
[02:50:06] INFO - Epoch 2/50, Iter 3000: Loss = 0.03717297315597534, lr = 0.0001
|
32 |
+
[02:51:00] INFO - Epoch 2/50, Iter 3100: Loss = 0.023568114265799522, lr = 0.0001
|
33 |
+
[02:51:55] INFO - Epoch 2/50, Iter 3200: Loss = 0.01752738654613495, lr = 0.0001
|
34 |
+
[02:52:49] INFO - Epoch 2/50, Iter 3300: Loss = 0.024697361513972282, lr = 0.0001
|
35 |
+
[02:53:44] INFO - Epoch 2/50, Iter 3400: Loss = 0.027621649205684662, lr = 0.0001
|
36 |
+
[02:54:39] INFO - Epoch 2/50, Iter 3500: Loss = 0.03197108209133148, lr = 0.0001
|
37 |
+
[02:55:34] INFO - Epoch 2/50, Iter 3600: Loss = 0.034603990614414215, lr = 0.0001
|
38 |
+
[02:56:29] INFO - Epoch 2/50, Iter 3700: Loss = 0.024781333282589912, lr = 0.0001
|
39 |
+
[02:57:25] INFO - Epoch 3/50, Iter 3800: Loss = 0.029720211401581764, lr = 0.0001
|
40 |
+
[02:58:20] INFO - Epoch 3/50, Iter 3900: Loss = 0.050903625786304474, lr = 0.0001
|
41 |
+
[03:00:52] INFO - Epoch 3/50, Iter 4000: Loss = 0.022276397794485092, lr = 0.0001
|
42 |
+
[03:01:48] INFO - Epoch 3/50, Iter 4100: Loss = 0.02051287144422531, lr = 0.0001
|
43 |
+
[03:02:42] INFO - Epoch 3/50, Iter 4200: Loss = 0.02138718217611313, lr = 0.0001
|
44 |
+
[03:03:37] INFO - Epoch 3/50, Iter 4300: Loss = 0.013692906126379967, lr = 0.0001
|
45 |
+
[03:04:31] INFO - Epoch 3/50, Iter 4400: Loss = 0.026416348293423653, lr = 0.0001
|
46 |
+
[03:05:26] INFO - Epoch 3/50, Iter 4500: Loss = 0.02263474650681019, lr = 0.0001
|
47 |
+
[03:06:21] INFO - Epoch 3/50, Iter 4600: Loss = 0.02561156451702118, lr = 0.0001
|
48 |
+
[03:07:15] INFO - Epoch 3/50, Iter 4700: Loss = 0.022007182240486145, lr = 0.0001
|
49 |
+
[03:08:11] INFO - Epoch 3/50, Iter 4800: Loss = 0.024828705936670303, lr = 0.0001
|
50 |
+
[03:09:05] INFO - Epoch 3/50, Iter 4900: Loss = 0.0277644544839859, lr = 0.0001
|
51 |
+
[03:11:37] INFO - Epoch 3/50, Iter 5000: Loss = 0.022669199854135513, lr = 0.0001
|
52 |
+
[03:12:31] INFO - Epoch 3/50, Iter 5100: Loss = 0.03488582372665405, lr = 0.0001
|
53 |
+
[03:13:27] INFO - Epoch 3/50, Iter 5200: Loss = 0.033707648515701294, lr = 0.0001
|
54 |
+
[03:14:21] INFO - Epoch 3/50, Iter 5300: Loss = 0.034617647528648376, lr = 0.0001
|
55 |
+
[03:15:16] INFO - Epoch 3/50, Iter 5400: Loss = 0.015979502350091934, lr = 0.0001
|
56 |
+
[03:16:11] INFO - Epoch 3/50, Iter 5500: Loss = 0.017885394394397736, lr = 0.0001
|
57 |
+
[03:17:05] INFO - Epoch 3/50, Iter 5600: Loss = 0.013684597797691822, lr = 0.0001
|
58 |
+
[03:18:01] INFO - Epoch 4/50, Iter 5700: Loss = 0.018592171370983124, lr = 0.0001
|
59 |
+
[03:18:56] INFO - Epoch 4/50, Iter 5800: Loss = 0.019852526485919952, lr = 0.0001
|
60 |
+
[03:19:52] INFO - Epoch 4/50, Iter 5900: Loss = 0.014810988679528236, lr = 0.0001
|
61 |
+
[03:22:25] INFO - Epoch 4/50, Iter 6000: Loss = 0.022946510463953018, lr = 0.0001
|
62 |
+
[03:23:19] INFO - Epoch 4/50, Iter 6100: Loss = 0.022477544844150543, lr = 0.0001
|
63 |
+
[03:24:14] INFO - Epoch 4/50, Iter 6200: Loss = 0.021514300256967545, lr = 0.0001
|
64 |
+
[03:25:09] INFO - Epoch 4/50, Iter 6300: Loss = 0.017631331458687782, lr = 0.0001
|
65 |
+
[03:26:03] INFO - Epoch 4/50, Iter 6400: Loss = 0.02970929630100727, lr = 0.0001
|
66 |
+
[03:26:58] INFO - Epoch 4/50, Iter 6500: Loss = 0.02417093515396118, lr = 0.0001
|
67 |
+
[03:27:52] INFO - Epoch 4/50, Iter 6600: Loss = 0.028470398858189583, lr = 0.0001
|
68 |
+
[03:28:48] INFO - Epoch 4/50, Iter 6700: Loss = 0.02186693623661995, lr = 0.0001
|
69 |
+
[03:29:42] INFO - Epoch 4/50, Iter 6800: Loss = 0.021022997796535492, lr = 0.0001
|
70 |
+
[03:30:37] INFO - Epoch 4/50, Iter 6900: Loss = 0.02663368172943592, lr = 0.0001
|
71 |
+
[03:33:08] INFO - Epoch 4/50, Iter 7000: Loss = 0.0202815942466259, lr = 0.0001
|
72 |
+
[03:34:04] INFO - Epoch 4/50, Iter 7100: Loss = 0.017694229260087013, lr = 0.0001
|
73 |
+
[03:34:58] INFO - Epoch 4/50, Iter 7200: Loss = 0.03217596560716629, lr = 0.0001
|
74 |
+
[03:35:53] INFO - Epoch 4/50, Iter 7300: Loss = 0.027110356837511063, lr = 0.0001
|
75 |
+
[03:36:47] INFO - Epoch 4/50, Iter 7400: Loss = 0.02598414570093155, lr = 0.0001
|
76 |
+
[03:37:42] INFO - Epoch 5/50, Iter 7500: Loss = 0.031232168897986412, lr = 0.0001
|
77 |
+
[03:38:37] INFO - Epoch 5/50, Iter 7600: Loss = 0.0394064262509346, lr = 0.0001
|
78 |
+
[03:39:33] INFO - Epoch 5/50, Iter 7700: Loss = 0.017326747998595238, lr = 0.0001
|
79 |
+
[03:40:28] INFO - Epoch 5/50, Iter 7800: Loss = 0.029284335672855377, lr = 0.0001
|
80 |
+
[03:41:24] INFO - Epoch 5/50, Iter 7900: Loss = 0.01525358110666275, lr = 0.0001
|
81 |
+
[03:43:56] INFO - Epoch 5/50, Iter 8000: Loss = 0.019312670454382896, lr = 0.0001
|
82 |
+
[03:44:51] INFO - Epoch 5/50, Iter 8100: Loss = 0.022943828254938126, lr = 0.0001
|
83 |
+
[03:45:46] INFO - Epoch 5/50, Iter 8200: Loss = 0.014834869652986526, lr = 0.0001
|
84 |
+
[03:46:40] INFO - Epoch 5/50, Iter 8300: Loss = 0.013647425919771194, lr = 0.0001
|
85 |
+
[03:47:35] INFO - Epoch 5/50, Iter 8400: Loss = 0.012797506526112556, lr = 0.0001
|
86 |
+
[03:48:29] INFO - Epoch 5/50, Iter 8500: Loss = 0.028487099334597588, lr = 0.0001
|
87 |
+
[03:49:25] INFO - Epoch 5/50, Iter 8600: Loss = 0.0326717309653759, lr = 0.0001
|
88 |
+
[03:50:20] INFO - Epoch 5/50, Iter 8700: Loss = 0.018652349710464478, lr = 0.0001
|
89 |
+
[03:51:14] INFO - Epoch 5/50, Iter 8800: Loss = 0.026515061035752296, lr = 0.0001
|
90 |
+
[03:52:09] INFO - Epoch 5/50, Iter 8900: Loss = 0.02715548872947693, lr = 0.0001
|
91 |
+
[03:54:40] INFO - Epoch 5/50, Iter 9000: Loss = 0.025071512907743454, lr = 0.0001
|
92 |
+
[03:55:34] INFO - Epoch 5/50, Iter 9100: Loss = 0.02286442741751671, lr = 0.0001
|
93 |
+
[03:56:29] INFO - Epoch 5/50, Iter 9200: Loss = 0.024927817285060883, lr = 0.0001
|
94 |
+
[03:57:25] INFO - Epoch 5/50, Iter 9300: Loss = 0.02016012743115425, lr = 0.0001
|
95 |
+
[03:58:20] INFO - Epoch 6/50, Iter 9400: Loss = 0.016080211848020554, lr = 0.0001
|
96 |
+
[03:59:16] INFO - Epoch 6/50, Iter 9500: Loss = 0.03025580570101738, lr = 0.0001
|
97 |
+
[04:00:10] INFO - Epoch 6/50, Iter 9600: Loss = 0.034918542951345444, lr = 0.0001
|
98 |
+
[04:01:06] INFO - Epoch 6/50, Iter 9700: Loss = 0.024010658264160156, lr = 0.0001
|
99 |
+
[04:02:01] INFO - Epoch 6/50, Iter 9800: Loss = 0.024768657982349396, lr = 0.0001
|
100 |
+
[04:02:57] INFO - Epoch 6/50, Iter 9900: Loss = 0.02912471443414688, lr = 0.0001
|
101 |
+
[04:05:28] INFO - Epoch 6/50, Iter 10000: Loss = 0.013935514725744724, lr = 0.0001
|
102 |
+
[04:06:24] INFO - Epoch 6/50, Iter 10100: Loss = 0.024383660405874252, lr = 0.0001
|
103 |
+
[04:07:19] INFO - Epoch 6/50, Iter 10200: Loss = 0.02626352570950985, lr = 0.0001
|
104 |
+
[04:08:13] INFO - Epoch 6/50, Iter 10300: Loss = 0.02143704704940319, lr = 0.0001
|
105 |
+
[04:09:08] INFO - Epoch 6/50, Iter 10400: Loss = 0.022659476846456528, lr = 0.0001
|
106 |
+
[04:10:02] INFO - Epoch 6/50, Iter 10500: Loss = 0.020370323210954666, lr = 0.0001
|
107 |
+
[04:10:57] INFO - Epoch 6/50, Iter 10600: Loss = 0.02100287191569805, lr = 0.0001
|
108 |
+
[04:11:52] INFO - Epoch 6/50, Iter 10700: Loss = 0.01825377717614174, lr = 0.0001
|
109 |
+
[04:12:46] INFO - Epoch 6/50, Iter 10800: Loss = 0.026205215603113174, lr = 0.0001
|
110 |
+
[04:13:42] INFO - Epoch 6/50, Iter 10900: Loss = 0.03552094101905823, lr = 0.0001
|
111 |
+
[04:16:13] INFO - Epoch 6/50, Iter 11000: Loss = 0.016668759286403656, lr = 0.0001
|
112 |
+
[04:17:07] INFO - Epoch 6/50, Iter 11100: Loss = 0.018555857241153717, lr = 0.0001
|
113 |
+
[04:18:02] INFO - Epoch 6/50, Iter 11200: Loss = 0.01698373258113861, lr = 0.0001
|
114 |
+
[04:18:58] INFO - Epoch 7/50, Iter 11300: Loss = 0.021595774218440056, lr = 0.0001
|
115 |
+
[04:19:53] INFO - Epoch 7/50, Iter 11400: Loss = 0.029402505606412888, lr = 0.0001
|
116 |
+
[04:20:49] INFO - Epoch 7/50, Iter 11500: Loss = 0.017380326986312866, lr = 0.0001
|
117 |
+
[04:21:44] INFO - Epoch 7/50, Iter 11600: Loss = 0.022462423890829086, lr = 0.0001
|
118 |
+
[04:22:40] INFO - Epoch 7/50, Iter 11700: Loss = 0.024359144270420074, lr = 0.0001
|
119 |
+
[04:23:35] INFO - Epoch 7/50, Iter 11800: Loss = 0.025637302547693253, lr = 0.0001
|
120 |
+
[04:24:31] INFO - Epoch 7/50, Iter 11900: Loss = 0.027863897383213043, lr = 0.0001
|
121 |
+
[04:27:02] INFO - Epoch 7/50, Iter 12000: Loss = 0.025426337495446205, lr = 0.0001
|
122 |
+
[04:27:58] INFO - Epoch 7/50, Iter 12100: Loss = 0.03268758952617645, lr = 0.0001
|
123 |
+
[04:28:52] INFO - Epoch 7/50, Iter 12200: Loss = 0.016548998653888702, lr = 0.0001
|
124 |
+
[04:29:47] INFO - Epoch 7/50, Iter 12300: Loss = 0.02512863650918007, lr = 0.0001
|
125 |
+
[04:30:41] INFO - Epoch 7/50, Iter 12400: Loss = 0.0246925987303257, lr = 0.0001
|
126 |
+
[04:31:36] INFO - Epoch 7/50, Iter 12500: Loss = 0.018600817769765854, lr = 0.0001
|
127 |
+
[04:32:31] INFO - Epoch 7/50, Iter 12600: Loss = 0.01979782059788704, lr = 0.0001
|
128 |
+
[04:33:25] INFO - Epoch 7/50, Iter 12700: Loss = 0.021152257919311523, lr = 0.0001
|
129 |
+
[04:34:21] INFO - Epoch 7/50, Iter 12800: Loss = 0.02903410792350769, lr = 0.0001
|
130 |
+
[04:35:16] INFO - Epoch 7/50, Iter 12900: Loss = 0.03196360170841217, lr = 0.0001
|
131 |
+
[04:37:46] INFO - Epoch 7/50, Iter 13000: Loss = 0.019338594749569893, lr = 0.0001
|
132 |
+
[04:38:41] INFO - Epoch 7/50, Iter 13100: Loss = 0.027051424607634544, lr = 0.0001
|
133 |
+
[04:39:37] INFO - Epoch 8/50, Iter 13200: Loss = 0.0238485224545002, lr = 0.0001
|
134 |
+
[04:40:32] INFO - Epoch 8/50, Iter 13300: Loss = 0.02585774101316929, lr = 0.0001
|
135 |
+
[04:41:28] INFO - Epoch 8/50, Iter 13400: Loss = 0.01865781843662262, lr = 0.0001
|
136 |
+
[04:42:23] INFO - Epoch 8/50, Iter 13500: Loss = 0.03003603406250477, lr = 0.0001
|
137 |
+
[04:43:17] INFO - Epoch 8/50, Iter 13600: Loss = 0.02756107971072197, lr = 0.0001
|
138 |
+
[04:44:13] INFO - Epoch 8/50, Iter 13700: Loss = 0.018252156674861908, lr = 0.0001
|
139 |
+
[04:45:09] INFO - Epoch 8/50, Iter 13800: Loss = 0.0232943594455719, lr = 0.0001
|
140 |
+
[04:46:04] INFO - Epoch 8/50, Iter 13900: Loss = 0.03505060076713562, lr = 0.0001
|
141 |
+
[04:48:36] INFO - Epoch 8/50, Iter 14000: Loss = 0.015609338879585266, lr = 0.0001
|
142 |
+
[04:49:31] INFO - Epoch 8/50, Iter 14100: Loss = 0.024727653712034225, lr = 0.0001
|
143 |
+
[04:50:25] INFO - Epoch 8/50, Iter 14200: Loss = 0.01343458704650402, lr = 0.0001
|
144 |
+
[04:51:20] INFO - Epoch 8/50, Iter 14300: Loss = 0.02276020497083664, lr = 0.0001
|
145 |
+
[04:52:15] INFO - Epoch 8/50, Iter 14400: Loss = 0.030666548758745193, lr = 0.0001
|
146 |
+
[04:53:09] INFO - Epoch 8/50, Iter 14500: Loss = 0.027710841968655586, lr = 0.0001
|
147 |
+
[04:54:04] INFO - Epoch 8/50, Iter 14600: Loss = 0.02813234180212021, lr = 0.0001
|
148 |
+
[04:54:58] INFO - Epoch 8/50, Iter 14700: Loss = 0.0154835544526577, lr = 0.0001
|
149 |
+
[04:55:54] INFO - Epoch 8/50, Iter 14800: Loss = 0.0330531969666481, lr = 0.0001
|
150 |
+
[04:56:49] INFO - Epoch 8/50, Iter 14900: Loss = 0.02566523663699627, lr = 0.0001
|
151 |
+
[04:59:20] INFO - Epoch 9/50, Iter 15000: Loss = 0.03587709367275238, lr = 0.0001
|
152 |
+
[05:00:16] INFO - Epoch 9/50, Iter 15100: Loss = 0.011817749589681625, lr = 0.0001
|
153 |
+
[05:01:11] INFO - Epoch 9/50, Iter 15200: Loss = 0.019955918192863464, lr = 0.0001
|
154 |
+
[05:02:06] INFO - Epoch 9/50, Iter 15300: Loss = 0.01926155760884285, lr = 0.0001
|
155 |
+
[05:03:01] INFO - Epoch 9/50, Iter 15400: Loss = 0.025760915130376816, lr = 0.0001
|
156 |
+
[05:03:57] INFO - Epoch 9/50, Iter 15500: Loss = 0.023390091955661774, lr = 0.0001
|
157 |
+
[05:04:52] INFO - Epoch 9/50, Iter 15600: Loss = 0.03382980450987816, lr = 0.0001
|
158 |
+
[05:05:48] INFO - Epoch 9/50, Iter 15700: Loss = 0.019686255604028702, lr = 0.0001
|
159 |
+
[05:06:43] INFO - Epoch 9/50, Iter 15800: Loss = 0.017689798027276993, lr = 0.0001
|
160 |
+
[05:07:39] INFO - Epoch 9/50, Iter 15900: Loss = 0.02643013373017311, lr = 0.0001
|
161 |
+
[05:10:10] INFO - Epoch 9/50, Iter 16000: Loss = 0.01975519210100174, lr = 0.0001
|
162 |
+
[05:11:05] INFO - Epoch 9/50, Iter 16100: Loss = 0.02566615864634514, lr = 0.0001
|
163 |
+
[05:12:01] INFO - Epoch 9/50, Iter 16200: Loss = 0.023744797334074974, lr = 0.0001
|
164 |
+
[05:12:54] INFO - Epoch 9/50, Iter 16300: Loss = 0.029149867594242096, lr = 0.0001
|
165 |
+
[05:13:50] INFO - Epoch 9/50, Iter 16400: Loss = 0.024619584903120995, lr = 0.0001
|
166 |
+
[05:14:44] INFO - Epoch 9/50, Iter 16500: Loss = 0.017802121117711067, lr = 0.0001
|
167 |
+
[05:15:39] INFO - Epoch 9/50, Iter 16600: Loss = 0.030343685299158096, lr = 0.0001
|
168 |
+
[05:16:34] INFO - Epoch 9/50, Iter 16700: Loss = 0.028128691017627716, lr = 0.0001
|
169 |
+
[05:17:28] INFO - Epoch 9/50, Iter 16800: Loss = 0.013130296021699905, lr = 0.0001
|
170 |
+
[05:18:23] INFO - Epoch 10/50, Iter 16900: Loss = 0.015325885266065598, lr = 0.0001
|
171 |
+
[05:20:55] INFO - Epoch 10/50, Iter 17000: Loss = 0.02369626611471176, lr = 0.0001
|
172 |
+
[05:21:50] INFO - Epoch 10/50, Iter 17100: Loss = 0.03911880403757095, lr = 0.0001
|
173 |
+
[05:22:44] INFO - Epoch 10/50, Iter 17200: Loss = 0.019555510953068733, lr = 0.0001
|
174 |
+
[05:23:40] INFO - Epoch 10/50, Iter 17300: Loss = 0.026994436979293823, lr = 0.0001
|
175 |
+
[05:24:35] INFO - Epoch 10/50, Iter 17400: Loss = 0.014918794855475426, lr = 0.0001
|
176 |
+
[05:25:29] INFO - Epoch 10/50, Iter 17500: Loss = 0.015928588807582855, lr = 0.0001
|
177 |
+
[05:26:24] INFO - Epoch 10/50, Iter 17600: Loss = 0.026111863553524017, lr = 0.0001
|
178 |
+
[05:27:19] INFO - Epoch 10/50, Iter 17700: Loss = 0.023383410647511482, lr = 0.0001
|
179 |
+
[05:28:13] INFO - Epoch 10/50, Iter 17800: Loss = 0.022820118814706802, lr = 0.0001
|
180 |
+
[05:29:08] INFO - Epoch 10/50, Iter 17900: Loss = 0.016951140016317368, lr = 0.0001
|
181 |
+
[05:31:40] INFO - Epoch 10/50, Iter 18000: Loss = 0.021106135100126266, lr = 0.0001
|
182 |
+
[05:32:34] INFO - Epoch 10/50, Iter 18100: Loss = 0.015148286707699299, lr = 0.0001
|
183 |
+
[05:33:29] INFO - Epoch 10/50, Iter 18200: Loss = 0.019842375069856644, lr = 0.0001
|
184 |
+
[05:34:24] INFO - Epoch 10/50, Iter 18300: Loss = 0.022392811253666878, lr = 0.0001
|
185 |
+
[05:35:18] INFO - Epoch 10/50, Iter 18400: Loss = 0.02733965963125229, lr = 0.0001
|
186 |
+
[05:36:13] INFO - Epoch 10/50, Iter 18500: Loss = 0.02087550237774849, lr = 0.0001
|
187 |
+
[05:37:08] INFO - Epoch 10/50, Iter 18600: Loss = 0.02672572433948517, lr = 0.0001
|
188 |
+
[05:38:03] INFO - Epoch 10/50, Iter 18700: Loss = 0.02076902985572815, lr = 0.0001
|
189 |
+
[05:38:59] INFO - Epoch 11/50, Iter 18800: Loss = 0.0208309106528759, lr = 0.0001
|
190 |
+
[05:39:54] INFO - Epoch 11/50, Iter 18900: Loss = 0.01603943109512329, lr = 0.0001
|
191 |
+
[05:42:26] INFO - Epoch 11/50, Iter 19000: Loss = 0.018146460875868797, lr = 0.0001
|
192 |
+
[05:43:20] INFO - Epoch 11/50, Iter 19100: Loss = 0.03146671503782272, lr = 0.0001
|
193 |
+
[05:44:15] INFO - Epoch 11/50, Iter 19200: Loss = 0.017263440415263176, lr = 0.0001
|
194 |
+
[05:45:10] INFO - Epoch 11/50, Iter 19300: Loss = 0.021944427862763405, lr = 0.0001
|
195 |
+
[05:46:04] INFO - Epoch 11/50, Iter 19400: Loss = 0.017847534269094467, lr = 0.0001
|
196 |
+
[05:46:59] INFO - Epoch 11/50, Iter 19500: Loss = 0.021428382024168968, lr = 0.0001
|
197 |
+
[05:47:55] INFO - Epoch 11/50, Iter 19600: Loss = 0.020893530920147896, lr = 0.0001
|
198 |
+
[05:48:49] INFO - Epoch 11/50, Iter 19700: Loss = 0.02261212095618248, lr = 0.0001
|
199 |
+
[05:49:44] INFO - Epoch 11/50, Iter 19800: Loss = 0.017424296587705612, lr = 0.0001
|
200 |
+
[05:50:39] INFO - Epoch 11/50, Iter 19900: Loss = 0.025077205151319504, lr = 0.0001
|
201 |
+
[05:53:10] INFO - Epoch 11/50, Iter 20000: Loss = 0.029975447803735733, lr = 0.0001
|
202 |
+
[05:54:04] INFO - Epoch 11/50, Iter 20100: Loss = 0.019458118826150894, lr = 0.0001
|
203 |
+
[05:54:59] INFO - Epoch 11/50, Iter 20200: Loss = 0.0232146717607975, lr = 0.0001
|
204 |
+
[05:55:53] INFO - Epoch 11/50, Iter 20300: Loss = 0.02360851876437664, lr = 0.0001
|
205 |
+
[05:56:48] INFO - Epoch 11/50, Iter 20400: Loss = 0.024858074262738228, lr = 0.0001
|
206 |
+
[05:57:44] INFO - Epoch 11/50, Iter 20500: Loss = 0.044195011258125305, lr = 0.0001
|
207 |
+
[05:58:38] INFO - Epoch 11/50, Iter 20600: Loss = 0.018540263175964355, lr = 0.0001
|
208 |
+
[05:59:33] INFO - Epoch 12/50, Iter 20700: Loss = 0.021583855152130127, lr = 0.0001
|
209 |
+
[06:00:29] INFO - Epoch 12/50, Iter 20800: Loss = 0.02421833947300911, lr = 0.0001
|
210 |
+
[06:01:24] INFO - Epoch 12/50, Iter 20900: Loss = 0.026535984128713608, lr = 0.0001
|
211 |
+
[06:03:57] INFO - Epoch 12/50, Iter 21000: Loss = 0.01781940832734108, lr = 0.0001
|
212 |
+
[06:04:51] INFO - Epoch 12/50, Iter 21100: Loss = 0.023128725588321686, lr = 0.0001
|
213 |
+
[06:05:46] INFO - Epoch 12/50, Iter 21200: Loss = 0.02317957766354084, lr = 0.0001
|
214 |
+
[06:06:40] INFO - Epoch 12/50, Iter 21300: Loss = 0.016345253214240074, lr = 0.0001
|
215 |
+
[06:07:36] INFO - Epoch 12/50, Iter 21400: Loss = 0.02558373659849167, lr = 0.0001
|
216 |
+
[06:08:31] INFO - Epoch 12/50, Iter 21500: Loss = 0.026121504604816437, lr = 0.0001
|
217 |
+
[06:09:25] INFO - Epoch 12/50, Iter 21600: Loss = 0.022759977728128433, lr = 0.0001
|
218 |
+
[06:10:20] INFO - Epoch 12/50, Iter 21700: Loss = 0.026271792128682137, lr = 0.0001
|
219 |
+
[06:11:14] INFO - Epoch 12/50, Iter 21800: Loss = 0.027187272906303406, lr = 0.0001
|
220 |
+
[06:12:09] INFO - Epoch 12/50, Iter 21900: Loss = 0.023094702512025833, lr = 0.0001
|
221 |
+
[06:14:40] INFO - Epoch 12/50, Iter 22000: Loss = 0.016669970005750656, lr = 0.0001
|
222 |
+
[06:15:36] INFO - Epoch 12/50, Iter 22100: Loss = 0.026704635471105576, lr = 0.0001
|
223 |
+
[06:16:30] INFO - Epoch 12/50, Iter 22200: Loss = 0.02754068374633789, lr = 0.0001
|
224 |
+
[06:17:25] INFO - Epoch 12/50, Iter 22300: Loss = 0.025661129504442215, lr = 0.0001
|
225 |
+
[06:18:19] INFO - Epoch 12/50, Iter 22400: Loss = 0.025509830564260483, lr = 0.0001
|
226 |
+
[06:19:14] INFO - Epoch 13/50, Iter 22500: Loss = 0.025348283350467682, lr = 0.0001
|
227 |
+
[06:20:10] INFO - Epoch 13/50, Iter 22600: Loss = 0.026772376149892807, lr = 0.0001
|
228 |
+
[06:21:05] INFO - Epoch 13/50, Iter 22700: Loss = 0.01741105318069458, lr = 0.0001
|
229 |
+
[06:22:01] INFO - Epoch 13/50, Iter 22800: Loss = 0.02285039983689785, lr = 0.0001
|
230 |
+
[06:22:56] INFO - Epoch 13/50, Iter 22900: Loss = 0.027282923460006714, lr = 0.0001
|
231 |
+
[06:25:28] INFO - Epoch 13/50, Iter 23000: Loss = 0.012414131313562393, lr = 0.0001
|
232 |
+
[06:26:23] INFO - Epoch 13/50, Iter 23100: Loss = 0.019650613889098167, lr = 0.0001
|
233 |
+
[06:27:18] INFO - Epoch 13/50, Iter 23200: Loss = 0.02651660516858101, lr = 0.0001
|
234 |
+
[06:28:12] INFO - Epoch 13/50, Iter 23300: Loss = 0.026138421148061752, lr = 0.0001
|
235 |
+
[06:29:07] INFO - Epoch 13/50, Iter 23400: Loss = 0.018627706915140152, lr = 0.0001
|
236 |
+
[06:30:01] INFO - Epoch 13/50, Iter 23500: Loss = 0.028943434357643127, lr = 0.0001
|
237 |
+
[06:30:57] INFO - Epoch 13/50, Iter 23600: Loss = 0.01649133488535881, lr = 0.0001
|
238 |
+
[06:31:51] INFO - Epoch 13/50, Iter 23700: Loss = 0.01378883421421051, lr = 0.0001
|
239 |
+
[06:32:46] INFO - Epoch 13/50, Iter 23800: Loss = 0.02124626189470291, lr = 0.0001
|
240 |
+
[06:33:41] INFO - Epoch 13/50, Iter 23900: Loss = 0.017396021634340286, lr = 0.0001
|
241 |
+
[06:36:12] INFO - Epoch 13/50, Iter 24000: Loss = 0.01732352189719677, lr = 0.0001
|
242 |
+
[06:37:06] INFO - Epoch 13/50, Iter 24100: Loss = 0.014166954904794693, lr = 0.0001
|
243 |
+
[06:38:02] INFO - Epoch 13/50, Iter 24200: Loss = 0.02176068350672722, lr = 0.0001
|
244 |
+
[06:38:57] INFO - Epoch 13/50, Iter 24300: Loss = 0.019656777381896973, lr = 0.0001
|
245 |
+
[06:39:51] INFO - Epoch 14/50, Iter 24400: Loss = 0.02193061262369156, lr = 0.0001
|
246 |
+
[06:40:46] INFO - Epoch 14/50, Iter 24500: Loss = 0.018643012270331383, lr = 0.0001
|
247 |
+
[06:41:42] INFO - Epoch 14/50, Iter 24600: Loss = 0.012337702326476574, lr = 0.0001
|
248 |
+
[06:42:37] INFO - Epoch 14/50, Iter 24700: Loss = 0.016973398625850677, lr = 0.0001
|
249 |
+
[06:43:33] INFO - Epoch 14/50, Iter 24800: Loss = 0.025368668138980865, lr = 0.0001
|
250 |
+
[06:44:28] INFO - Epoch 14/50, Iter 24900: Loss = 0.02520618960261345, lr = 0.0001
|
251 |
+
[06:47:00] INFO - Epoch 14/50, Iter 25000: Loss = 0.01767529547214508, lr = 0.0001
|
252 |
+
[06:47:55] INFO - Epoch 14/50, Iter 25100: Loss = 0.021381141617894173, lr = 0.0001
|
253 |
+
[06:48:49] INFO - Epoch 14/50, Iter 25200: Loss = 0.021116536110639572, lr = 0.0001
|
254 |
+
[06:49:44] INFO - Epoch 14/50, Iter 25300: Loss = 0.017928242683410645, lr = 0.0001
|
255 |
+
[06:50:39] INFO - Epoch 14/50, Iter 25400: Loss = 0.021284624934196472, lr = 0.0001
|
256 |
+
[06:51:33] INFO - Epoch 14/50, Iter 25500: Loss = 0.013009730726480484, lr = 0.0001
|
257 |
+
[06:52:28] INFO - Epoch 14/50, Iter 25600: Loss = 0.018284976482391357, lr = 0.0001
|
258 |
+
[06:53:22] INFO - Epoch 14/50, Iter 25700: Loss = 0.019000139087438583, lr = 0.0001
|
259 |
+
[06:54:18] INFO - Epoch 14/50, Iter 25800: Loss = 0.01757623441517353, lr = 0.0001
|
260 |
+
[06:55:12] INFO - Epoch 14/50, Iter 25900: Loss = 0.019956346601247787, lr = 0.0001
|
261 |
+
[06:57:43] INFO - Epoch 14/50, Iter 26000: Loss = 0.025380369275808334, lr = 0.0001
|
262 |
+
[06:58:38] INFO - Epoch 14/50, Iter 26100: Loss = 0.02575628086924553, lr = 0.0001
|
263 |
+
[06:59:32] INFO - Epoch 14/50, Iter 26200: Loss = 0.02441999688744545, lr = 0.0001
|
264 |
+
[07:00:28] INFO - Epoch 15/50, Iter 26300: Loss = 0.015507195144891739, lr = 0.0001
|
265 |
+
[07:01:23] INFO - Epoch 15/50, Iter 26400: Loss = 0.018518857657909393, lr = 0.0001
|
266 |
+
[07:02:18] INFO - Epoch 15/50, Iter 26500: Loss = 0.0218639075756073, lr = 0.0001
|
267 |
+
[07:03:14] INFO - Epoch 15/50, Iter 26600: Loss = 0.01484048180282116, lr = 0.0001
|
268 |
+
[07:04:09] INFO - Epoch 15/50, Iter 26700: Loss = 0.020309407263994217, lr = 0.0001
|
269 |
+
[07:05:05] INFO - Epoch 15/50, Iter 26800: Loss = 0.02281174622476101, lr = 0.0001
|
270 |
+
[07:06:00] INFO - Epoch 15/50, Iter 26900: Loss = 0.022504042834043503, lr = 0.0001
|
271 |
+
[07:08:32] INFO - Epoch 15/50, Iter 27000: Loss = 0.016440019011497498, lr = 0.0001
|
272 |
+
[07:09:27] INFO - Epoch 15/50, Iter 27100: Loss = 0.015486285090446472, lr = 0.0001
|
273 |
+
[07:10:21] INFO - Epoch 15/50, Iter 27200: Loss = 0.01972173899412155, lr = 0.0001
|
274 |
+
[07:11:16] INFO - Epoch 15/50, Iter 27300: Loss = 0.018617577850818634, lr = 0.0001
|
275 |
+
[07:12:11] INFO - Epoch 15/50, Iter 27400: Loss = 0.02082516998052597, lr = 0.0001
|
276 |
+
[07:13:05] INFO - Epoch 15/50, Iter 27500: Loss = 0.01791219785809517, lr = 0.0001
|
277 |
+
[07:14:00] INFO - Epoch 15/50, Iter 27600: Loss = 0.02241137996315956, lr = 0.0001
|
278 |
+
[07:14:54] INFO - Epoch 15/50, Iter 27700: Loss = 0.020293384790420532, lr = 0.0001
|
279 |
+
[07:15:49] INFO - Epoch 15/50, Iter 27800: Loss = 0.029861796647310257, lr = 0.0001
|
280 |
+
[07:16:44] INFO - Epoch 15/50, Iter 27900: Loss = 0.02275857701897621, lr = 0.0001
|
281 |
+
[07:19:16] INFO - Epoch 15/50, Iter 28000: Loss = 0.015355780720710754, lr = 0.0001
|
282 |
+
[07:20:10] INFO - Epoch 15/50, Iter 28100: Loss = 0.019503731280565262, lr = 0.0001
|
283 |
+
[07:21:05] INFO - Epoch 16/50, Iter 28200: Loss = 0.024656936526298523, lr = 0.0001
|
284 |
+
[07:22:01] INFO - Epoch 16/50, Iter 28300: Loss = 0.016661042347550392, lr = 0.0001
|
285 |
+
[07:22:56] INFO - Epoch 16/50, Iter 28400: Loss = 0.017921866849064827, lr = 0.0001
|
286 |
+
[07:23:52] INFO - Epoch 16/50, Iter 28500: Loss = 0.020502446219325066, lr = 0.0001
|
287 |
+
[07:24:47] INFO - Epoch 16/50, Iter 28600: Loss = 0.012834666296839714, lr = 0.0001
|
288 |
+
[07:25:42] INFO - Epoch 16/50, Iter 28700: Loss = 0.017596762627363205, lr = 0.0001
|
289 |
+
[07:26:37] INFO - Epoch 16/50, Iter 28800: Loss = 0.02352038025856018, lr = 0.0001
|
290 |
+
[07:27:32] INFO - Epoch 16/50, Iter 28900: Loss = 0.022114895284175873, lr = 0.0001
|
291 |
+
[07:30:05] INFO - Epoch 16/50, Iter 29000: Loss = 0.018584776669740677, lr = 0.0001
|
292 |
+
[07:30:59] INFO - Epoch 16/50, Iter 29100: Loss = 0.021322712302207947, lr = 0.0001
|
293 |
+
[07:31:54] INFO - Epoch 16/50, Iter 29200: Loss = 0.01889413595199585, lr = 0.0001
|
294 |
+
[07:32:48] INFO - Epoch 16/50, Iter 29300: Loss = 0.027229465544223785, lr = 0.0001
|
295 |
+
[07:33:43] INFO - Epoch 16/50, Iter 29400: Loss = 0.026700954884290695, lr = 0.0001
|
296 |
+
[07:34:37] INFO - Epoch 16/50, Iter 29500: Loss = 0.026901915669441223, lr = 0.0001
|
297 |
+
[07:35:32] INFO - Epoch 16/50, Iter 29600: Loss = 0.0257167499512434, lr = 0.0001
|
298 |
+
[07:36:27] INFO - Epoch 16/50, Iter 29700: Loss = 0.023790445178747177, lr = 0.0001
|
299 |
+
[07:37:21] INFO - Epoch 16/50, Iter 29800: Loss = 0.010275682434439659, lr = 0.0001
|
300 |
+
[07:38:17] INFO - Epoch 16/50, Iter 29900: Loss = 0.024285804480314255, lr = 0.0001
|
301 |
+
[07:40:48] INFO - Epoch 17/50, Iter 30000: Loss = 0.01686658337712288, lr = 0.0001
|
302 |
+
[07:41:44] INFO - Epoch 17/50, Iter 30100: Loss = 0.019942965358495712, lr = 0.0001
|
303 |
+
[07:42:39] INFO - Epoch 17/50, Iter 30200: Loss = 0.032290853559970856, lr = 0.0001
|
304 |
+
[07:43:35] INFO - Epoch 17/50, Iter 30300: Loss = 0.02391435205936432, lr = 0.0001
|
305 |
+
[07:44:29] INFO - Epoch 17/50, Iter 30400: Loss = 0.022961270064115524, lr = 0.0001
|
306 |
+
[07:45:24] INFO - Epoch 17/50, Iter 30500: Loss = 0.02686147764325142, lr = 0.0001
|
307 |
+
[07:46:20] INFO - Epoch 17/50, Iter 30600: Loss = 0.021469425410032272, lr = 0.0001
|
308 |
+
[07:47:15] INFO - Epoch 17/50, Iter 30700: Loss = 0.019237644970417023, lr = 0.0001
|
309 |
+
[07:48:11] INFO - Epoch 17/50, Iter 30800: Loss = 0.01243587676435709, lr = 0.0001
|
310 |
+
[07:49:06] INFO - Epoch 17/50, Iter 30900: Loss = 0.019927412271499634, lr = 0.0001
|
311 |
+
[07:51:38] INFO - Epoch 17/50, Iter 31000: Loss = 0.021345121785998344, lr = 0.0001
|
312 |
+
[07:52:33] INFO - Epoch 17/50, Iter 31100: Loss = 0.0189402773976326, lr = 0.0001
|
313 |
+
[07:53:28] INFO - Epoch 17/50, Iter 31200: Loss = 0.022389506921172142, lr = 0.0001
|
314 |
+
[07:54:22] INFO - Epoch 17/50, Iter 31300: Loss = 0.019248703494668007, lr = 0.0001
|
315 |
+
[07:55:18] INFO - Epoch 17/50, Iter 31400: Loss = 0.020908750593662262, lr = 0.0001
|
316 |
+
[07:56:12] INFO - Epoch 17/50, Iter 31500: Loss = 0.029640033841133118, lr = 0.0001
|
317 |
+
[07:57:07] INFO - Epoch 17/50, Iter 31600: Loss = 0.026583340018987656, lr = 0.0001
|
318 |
+
[07:58:02] INFO - Epoch 17/50, Iter 31700: Loss = 0.01729031279683113, lr = 0.0001
|
319 |
+
[07:58:56] INFO - Epoch 17/50, Iter 31800: Loss = 0.026669491082429886, lr = 0.0001
|
320 |
+
[07:59:51] INFO - Epoch 18/50, Iter 31900: Loss = 0.015399916097521782, lr = 0.0001
|
321 |
+
[08:02:23] INFO - Epoch 18/50, Iter 32000: Loss = 0.027698248624801636, lr = 0.0001
|
322 |
+
[08:03:18] INFO - Epoch 18/50, Iter 32100: Loss = 0.020098572596907616, lr = 0.0001
|
323 |
+
[08:04:12] INFO - Epoch 18/50, Iter 32200: Loss = 0.023418741300702095, lr = 0.0001
|
324 |
+
[08:05:07] INFO - Epoch 18/50, Iter 32300: Loss = 0.015688564628362656, lr = 0.0001
|
325 |
+
[08:06:02] INFO - Epoch 18/50, Iter 32400: Loss = 0.013760192319750786, lr = 0.0001
|
326 |
+
[08:06:56] INFO - Epoch 18/50, Iter 32500: Loss = 0.018602928146719933, lr = 0.0001
|
327 |
+
[08:07:52] INFO - Epoch 18/50, Iter 32600: Loss = 0.0171047393232584, lr = 0.0001
|
328 |
+
[08:08:46] INFO - Epoch 18/50, Iter 32700: Loss = 0.02287128195166588, lr = 0.0001
|
329 |
+
[08:09:41] INFO - Epoch 18/50, Iter 32800: Loss = 0.01747080124914646, lr = 0.0001
|
330 |
+
[08:10:35] INFO - Epoch 18/50, Iter 32900: Loss = 0.032003749161958694, lr = 0.0001
|
331 |
+
[08:13:06] INFO - Epoch 18/50, Iter 33000: Loss = 0.021088197827339172, lr = 0.0001
|
332 |
+
[08:14:01] INFO - Epoch 18/50, Iter 33100: Loss = 0.0243061576038599, lr = 0.0001
|
333 |
+
[08:14:55] INFO - Epoch 18/50, Iter 33200: Loss = 0.017390495166182518, lr = 0.0001
|
334 |
+
[08:15:50] INFO - Epoch 18/50, Iter 33300: Loss = 0.027531778439879417, lr = 0.0001
|
335 |
+
[08:16:45] INFO - Epoch 18/50, Iter 33400: Loss = 0.01495380699634552, lr = 0.0001
|
336 |
+
[08:17:39] INFO - Epoch 18/50, Iter 33500: Loss = 0.02041369117796421, lr = 0.0001
|
337 |
+
[08:18:35] INFO - Epoch 18/50, Iter 33600: Loss = 0.016778916120529175, lr = 0.0001
|
338 |
+
[08:19:29] INFO - Epoch 18/50, Iter 33700: Loss = 0.0185483880341053, lr = 0.0001
|
339 |
+
[08:20:24] INFO - Epoch 19/50, Iter 33800: Loss = 0.017258750274777412, lr = 0.0001
|
340 |
+
[08:21:20] INFO - Epoch 19/50, Iter 33900: Loss = 0.013514120131731033, lr = 0.0001
|
341 |
+
[08:23:52] INFO - Epoch 19/50, Iter 34000: Loss = 0.017329292371869087, lr = 0.0001
|
342 |
+
[08:24:46] INFO - Epoch 19/50, Iter 34100: Loss = 0.03175392746925354, lr = 0.0001
|
343 |
+
[08:25:42] INFO - Epoch 19/50, Iter 34200: Loss = 0.024144772440195084, lr = 0.0001
|
344 |
+
[08:26:36] INFO - Epoch 19/50, Iter 34300: Loss = 0.025116432458162308, lr = 0.0001
|
345 |
+
[08:27:31] INFO - Epoch 19/50, Iter 34400: Loss = 0.023968493565917015, lr = 0.0001
|
346 |
+
[08:28:26] INFO - Epoch 19/50, Iter 34500: Loss = 0.023263823240995407, lr = 0.0001
|
347 |
+
[08:29:20] INFO - Epoch 19/50, Iter 34600: Loss = 0.015572518110275269, lr = 0.0001
|
348 |
+
[08:30:15] INFO - Epoch 19/50, Iter 34700: Loss = 0.011077907867729664, lr = 0.0001
|
349 |
+
[08:31:10] INFO - Epoch 19/50, Iter 34800: Loss = 0.019685542210936546, lr = 0.0001
|
350 |
+
[08:32:04] INFO - Epoch 19/50, Iter 34900: Loss = 0.026246516034007072, lr = 0.0001
|
351 |
+
[08:34:35] INFO - Epoch 19/50, Iter 35000: Loss = 0.0264703631401062, lr = 0.0001
|
352 |
+
[08:35:31] INFO - Epoch 19/50, Iter 35100: Loss = 0.018090050667524338, lr = 0.0001
|
353 |
+
[08:36:25] INFO - Epoch 19/50, Iter 35200: Loss = 0.014332180842757225, lr = 0.0001
|
354 |
+
[08:37:20] INFO - Epoch 19/50, Iter 35300: Loss = 0.03227975219488144, lr = 0.0001
|
355 |
+
[08:38:15] INFO - Epoch 19/50, Iter 35400: Loss = 0.017180195078253746, lr = 0.0001
|
356 |
+
[08:39:09] INFO - Epoch 19/50, Iter 35500: Loss = 0.01773938722908497, lr = 0.0001
|
357 |
+
[08:40:04] INFO - Epoch 19/50, Iter 35600: Loss = 0.02321586385369301, lr = 0.0001
|
358 |
+
[08:41:00] INFO - Epoch 20/50, Iter 35700: Loss = 0.018052995204925537, lr = 0.0001
|
359 |
+
[08:41:55] INFO - Epoch 20/50, Iter 35800: Loss = 0.02333519607782364, lr = 0.0001
|
360 |
+
[08:42:51] INFO - Epoch 20/50, Iter 35900: Loss = 0.023782718926668167, lr = 0.0001
|
361 |
+
[08:45:22] INFO - Epoch 20/50, Iter 36000: Loss = 0.021948453038930893, lr = 0.0001
|
362 |
+
[08:46:17] INFO - Epoch 20/50, Iter 36100: Loss = 0.01616925373673439, lr = 0.0001
|
363 |
+
[08:47:11] INFO - Epoch 20/50, Iter 36200: Loss = 0.0195147804915905, lr = 0.0001
|
364 |
+
[08:48:07] INFO - Epoch 20/50, Iter 36300: Loss = 0.02167724072933197, lr = 0.0001
|
365 |
+
[08:49:02] INFO - Epoch 20/50, Iter 36400: Loss = 0.017993919551372528, lr = 0.0001
|
366 |
+
[08:49:56] INFO - Epoch 20/50, Iter 36500: Loss = 0.024179894477128983, lr = 0.0001
|
367 |
+
[08:50:51] INFO - Epoch 20/50, Iter 36600: Loss = 0.029972080141305923, lr = 0.0001
|
368 |
+
[08:51:45] INFO - Epoch 20/50, Iter 36700: Loss = 0.02250525914132595, lr = 0.0001
|
369 |
+
[08:52:40] INFO - Epoch 20/50, Iter 36800: Loss = 0.016068585216999054, lr = 0.0001
|
370 |
+
[08:53:35] INFO - Epoch 20/50, Iter 36900: Loss = 0.02062491700053215, lr = 0.0001
|
371 |
+
[08:56:07] INFO - Epoch 20/50, Iter 37000: Loss = 0.026054339483380318, lr = 0.0001
|
372 |
+
[08:57:01] INFO - Epoch 20/50, Iter 37100: Loss = 0.01617574132978916, lr = 0.0001
|
373 |
+
[08:57:56] INFO - Epoch 20/50, Iter 37200: Loss = 0.01841990277171135, lr = 0.0001
|
374 |
+
[08:58:51] INFO - Epoch 20/50, Iter 37300: Loss = 0.016723550856113434, lr = 0.0001
|
375 |
+
[08:59:45] INFO - Epoch 20/50, Iter 37400: Loss = 0.015482468530535698, lr = 0.0001
|
376 |
+
[09:00:41] INFO - Epoch 21/50, Iter 37500: Loss = 0.028426745906472206, lr = 0.0001
|
377 |
+
[09:01:36] INFO - Epoch 21/50, Iter 37600: Loss = 0.026276376098394394, lr = 0.0001
|
378 |
+
[09:02:32] INFO - Epoch 21/50, Iter 37700: Loss = 0.026483114808797836, lr = 0.0001
|
379 |
+
[09:03:27] INFO - Epoch 21/50, Iter 37800: Loss = 0.021477442234754562, lr = 0.0001
|
380 |
+
[09:04:21] INFO - Epoch 21/50, Iter 37900: Loss = 0.015382439829409122, lr = 0.0001
|
381 |
+
[09:06:54] INFO - Epoch 21/50, Iter 38000: Loss = 0.013858610764145851, lr = 0.0001
|
382 |
+
[09:07:48] INFO - Epoch 21/50, Iter 38100: Loss = 0.022090336307883263, lr = 0.0001
|
383 |
+
[09:08:44] INFO - Epoch 21/50, Iter 38200: Loss = 0.025041067972779274, lr = 0.0001
|
384 |
+
[09:09:39] INFO - Epoch 21/50, Iter 38300: Loss = 0.01404337864369154, lr = 0.0001
|
385 |
+
[09:10:33] INFO - Epoch 21/50, Iter 38400: Loss = 0.022372154518961906, lr = 0.0001
|
386 |
+
[09:11:28] INFO - Epoch 21/50, Iter 38500: Loss = 0.022488964721560478, lr = 0.0001
|
387 |
+
[09:12:22] INFO - Epoch 21/50, Iter 38600: Loss = 0.018394947052001953, lr = 0.0001
|
388 |
+
[09:13:17] INFO - Epoch 21/50, Iter 38700: Loss = 0.019345279783010483, lr = 0.0001
|
389 |
+
[09:14:12] INFO - Epoch 21/50, Iter 38800: Loss = 0.013524915091693401, lr = 0.0001
|
390 |
+
[09:15:06] INFO - Epoch 21/50, Iter 38900: Loss = 0.023479681462049484, lr = 0.0001
|
391 |
+
[09:17:38] INFO - Epoch 21/50, Iter 39000: Loss = 0.018239330500364304, lr = 0.0001
|
392 |
+
[09:18:33] INFO - Epoch 21/50, Iter 39100: Loss = 0.014270618557929993, lr = 0.0001
|
393 |
+
[09:19:27] INFO - Epoch 21/50, Iter 39200: Loss = 0.012470152229070663, lr = 0.0001
|
394 |
+
[09:20:22] INFO - Epoch 21/50, Iter 39300: Loss = 0.024510135874152184, lr = 0.0001
|
395 |
+
[09:21:18] INFO - Epoch 22/50, Iter 39400: Loss = 0.01967580057680607, lr = 0.0001
|
396 |
+
[09:22:13] INFO - Epoch 22/50, Iter 39500: Loss = 0.02651473507285118, lr = 0.0001
|
397 |
+
[09:23:09] INFO - Epoch 22/50, Iter 39600: Loss = 0.014456840232014656, lr = 0.0001
|
398 |
+
[09:24:03] INFO - Epoch 22/50, Iter 39700: Loss = 0.013815360143780708, lr = 0.0001
|
399 |
+
[09:24:58] INFO - Epoch 22/50, Iter 39800: Loss = 0.026865314692258835, lr = 0.0001
|
400 |
+
[09:25:54] INFO - Epoch 22/50, Iter 39900: Loss = 0.022365324199199677, lr = 0.0001
|
401 |
+
[09:28:27] INFO - Epoch 22/50, Iter 40000: Loss = 0.02029530331492424, lr = 0.0001
|
402 |
+
[09:29:21] INFO - Epoch 22/50, Iter 40100: Loss = 0.021116379648447037, lr = 0.0001
|
403 |
+
[09:30:16] INFO - Epoch 22/50, Iter 40200: Loss = 0.02509278617799282, lr = 0.0001
|
404 |
+
[09:31:11] INFO - Epoch 22/50, Iter 40300: Loss = 0.02551993355154991, lr = 0.0001
|
405 |
+
[09:32:05] INFO - Epoch 22/50, Iter 40400: Loss = 0.020986683666706085, lr = 0.0001
|
406 |
+
[09:33:00] INFO - Epoch 22/50, Iter 40500: Loss = 0.020868226885795593, lr = 0.0001
|
407 |
+
[09:33:54] INFO - Epoch 22/50, Iter 40600: Loss = 0.017478734254837036, lr = 0.0001
|
408 |
+
[09:34:49] INFO - Epoch 22/50, Iter 40700: Loss = 0.027790624648332596, lr = 0.0001
|
409 |
+
[09:35:45] INFO - Epoch 22/50, Iter 40800: Loss = 0.022644832730293274, lr = 0.0001
|
410 |
+
[09:36:39] INFO - Epoch 22/50, Iter 40900: Loss = 0.024670612066984177, lr = 0.0001
|
411 |
+
[09:39:10] INFO - Epoch 22/50, Iter 41000: Loss = 0.026195334270596504, lr = 0.0001
|
412 |
+
[09:40:05] INFO - Epoch 22/50, Iter 41100: Loss = 0.021374046802520752, lr = 0.0001
|
413 |
+
[09:41:00] INFO - Epoch 22/50, Iter 41200: Loss = 0.02115592733025551, lr = 0.0001
|
414 |
+
[09:41:56] INFO - Epoch 23/50, Iter 41300: Loss = 0.01633710041642189, lr = 0.0001
|
415 |
+
[09:42:52] INFO - Epoch 23/50, Iter 41400: Loss = 0.02131003886461258, lr = 0.0001
|
416 |
+
[09:43:46] INFO - Epoch 23/50, Iter 41500: Loss = 0.022764872759580612, lr = 0.0001
|
417 |
+
[09:44:41] INFO - Epoch 23/50, Iter 41600: Loss = 0.01728042960166931, lr = 0.0001
|
418 |
+
[09:45:37] INFO - Epoch 23/50, Iter 41700: Loss = 0.0162839163094759, lr = 0.0001
|
419 |
+
[09:46:32] INFO - Epoch 23/50, Iter 41800: Loss = 0.014318926259875298, lr = 0.0001
|
420 |
+
[09:47:28] INFO - Epoch 23/50, Iter 41900: Loss = 0.018346164375543594, lr = 0.0001
|
421 |
+
[09:49:59] INFO - Epoch 23/50, Iter 42000: Loss = 0.027812600135803223, lr = 0.0001
|
422 |
+
[09:50:55] INFO - Epoch 23/50, Iter 42100: Loss = 0.026753295212984085, lr = 0.0001
|
423 |
+
[09:51:50] INFO - Epoch 23/50, Iter 42200: Loss = 0.018069680780172348, lr = 0.0001
|
424 |
+
[09:52:44] INFO - Epoch 23/50, Iter 42300: Loss = 0.03101518750190735, lr = 0.0001
|
425 |
+
[09:53:39] INFO - Epoch 23/50, Iter 42400: Loss = 0.025507837533950806, lr = 0.0001
|
426 |
+
[09:54:34] INFO - Epoch 23/50, Iter 42500: Loss = 0.017935875803232193, lr = 0.0001
|
427 |
+
[09:55:28] INFO - Epoch 23/50, Iter 42600: Loss = 0.022867443040013313, lr = 0.0001
|
428 |
+
[09:56:23] INFO - Epoch 23/50, Iter 42700: Loss = 0.02030709572136402, lr = 0.0001
|
429 |
+
[09:57:18] INFO - Epoch 23/50, Iter 42800: Loss = 0.013310606591403484, lr = 0.0001
|
430 |
+
[09:58:13] INFO - Epoch 23/50, Iter 42900: Loss = 0.014713610522449017, lr = 0.0001
|
431 |
+
[10:00:44] INFO - Epoch 23/50, Iter 43000: Loss = 0.02300114557147026, lr = 0.0001
|
432 |
+
[10:01:39] INFO - Epoch 23/50, Iter 43100: Loss = 0.02343389019370079, lr = 0.0001
|
433 |
+
[10:02:35] INFO - Epoch 24/50, Iter 43200: Loss = 0.019669387489557266, lr = 0.0001
|
434 |
+
[10:03:30] INFO - Epoch 24/50, Iter 43300: Loss = 0.025514639914035797, lr = 0.0001
|
435 |
+
[10:04:25] INFO - Epoch 24/50, Iter 43400: Loss = 0.027034897357225418, lr = 0.0001
|
436 |
+
[10:05:20] INFO - Epoch 24/50, Iter 43500: Loss = 0.026066435500979424, lr = 0.0001
|
437 |
+
[10:06:16] INFO - Epoch 24/50, Iter 43600: Loss = 0.022791586816310883, lr = 0.0001
|
438 |
+
[10:07:11] INFO - Epoch 24/50, Iter 43700: Loss = 0.01600833050906658, lr = 0.0001
|
439 |
+
[10:08:07] INFO - Epoch 24/50, Iter 43800: Loss = 0.01834738627076149, lr = 0.0001
|
440 |
+
[10:09:02] INFO - Epoch 24/50, Iter 43900: Loss = 0.026411669328808784, lr = 0.0001
|
441 |
+
[10:11:34] INFO - Epoch 24/50, Iter 44000: Loss = 0.01697351410984993, lr = 0.0001
|
442 |
+
[10:12:29] INFO - Epoch 24/50, Iter 44100: Loss = 0.025164766237139702, lr = 0.0001
|
443 |
+
[10:13:24] INFO - Epoch 24/50, Iter 44200: Loss = 0.023120088502764702, lr = 0.0001
|
444 |
+
[10:14:18] INFO - Epoch 24/50, Iter 44300: Loss = 0.016470227390527725, lr = 0.0001
|
445 |
+
[10:15:13] INFO - Epoch 24/50, Iter 44400: Loss = 0.02092874050140381, lr = 0.0001
|
446 |
+
[10:16:09] INFO - Epoch 24/50, Iter 44500: Loss = 0.017084982246160507, lr = 0.0001
|
447 |
+
[10:17:02] INFO - Epoch 24/50, Iter 44600: Loss = 0.01771422289311886, lr = 0.0001
|
448 |
+
[10:17:58] INFO - Epoch 24/50, Iter 44700: Loss = 0.01557396911084652, lr = 0.0001
|
449 |
+
[10:18:52] INFO - Epoch 24/50, Iter 44800: Loss = 0.01830480992794037, lr = 0.0001
|
450 |
+
[10:19:47] INFO - Epoch 24/50, Iter 44900: Loss = 0.03161770850419998, lr = 0.0001
|
451 |
+
[10:22:19] INFO - Epoch 25/50, Iter 45000: Loss = 0.013423663564026356, lr = 0.0001
|
452 |
+
[10:23:14] INFO - Epoch 25/50, Iter 45100: Loss = 0.0297955684363842, lr = 0.0001
|
453 |
+
[10:24:10] INFO - Epoch 25/50, Iter 45200: Loss = 0.02846469357609749, lr = 0.0001
|
454 |
+
[10:25:06] INFO - Epoch 25/50, Iter 45300: Loss = 0.015436829067766666, lr = 0.0001
|
455 |
+
[10:26:01] INFO - Epoch 25/50, Iter 45400: Loss = 0.024918153882026672, lr = 0.0001
|
456 |
+
[10:26:57] INFO - Epoch 25/50, Iter 45500: Loss = 0.02270306646823883, lr = 0.0001
|
457 |
+
[10:27:52] INFO - Epoch 25/50, Iter 45600: Loss = 0.015784474089741707, lr = 0.0001
|
458 |
+
[10:28:46] INFO - Epoch 25/50, Iter 45700: Loss = 0.011514103971421719, lr = 0.0001
|
459 |
+
[10:29:42] INFO - Epoch 25/50, Iter 45800: Loss = 0.024075977504253387, lr = 0.0001
|
460 |
+
[10:30:37] INFO - Epoch 25/50, Iter 45900: Loss = 0.018384993076324463, lr = 0.0001
|
461 |
+
[10:33:10] INFO - Epoch 25/50, Iter 46000: Loss = 0.024563699960708618, lr = 0.0001
|
462 |
+
[10:34:04] INFO - Epoch 25/50, Iter 46100: Loss = 0.015144889242947102, lr = 0.0001
|
463 |
+
[10:34:59] INFO - Epoch 25/50, Iter 46200: Loss = 0.022055502980947495, lr = 0.0001
|
464 |
+
[10:35:55] INFO - Epoch 25/50, Iter 46300: Loss = 0.013236483559012413, lr = 0.0001
|
465 |
+
[10:36:49] INFO - Epoch 25/50, Iter 46400: Loss = 0.016789842396974564, lr = 0.0001
|
466 |
+
[10:37:44] INFO - Epoch 25/50, Iter 46500: Loss = 0.018810316920280457, lr = 0.0001
|
467 |
+
[10:38:38] INFO - Epoch 25/50, Iter 46600: Loss = 0.01891239359974861, lr = 0.0001
|
468 |
+
[10:39:33] INFO - Epoch 25/50, Iter 46700: Loss = 0.03200780227780342, lr = 0.0001
|
469 |
+
[10:40:28] INFO - Epoch 25/50, Iter 46800: Loss = 0.025489578023552895, lr = 0.0001
|
470 |
+
[10:41:24] INFO - Epoch 26/50, Iter 46900: Loss = 0.02214771881699562, lr = 0.0001
|
471 |
+
[10:43:56] INFO - Epoch 26/50, Iter 47000: Loss = 0.01889549382030964, lr = 0.0001
|
472 |
+
[10:44:51] INFO - Epoch 26/50, Iter 47100: Loss = 0.015227919444441795, lr = 0.0001
|
473 |
+
[10:45:45] INFO - Epoch 26/50, Iter 47200: Loss = 0.01975785568356514, lr = 0.0001
|
474 |
+
[10:46:40] INFO - Epoch 26/50, Iter 47300: Loss = 0.021548938006162643, lr = 0.0001
|
475 |
+
[10:47:35] INFO - Epoch 26/50, Iter 47400: Loss = 0.018300775438547134, lr = 0.0001
|
476 |
+
[10:48:29] INFO - Epoch 26/50, Iter 47500: Loss = 0.02168145403265953, lr = 0.0001
|
477 |
+
[10:49:24] INFO - Epoch 26/50, Iter 47600: Loss = 0.02841881290078163, lr = 0.0001
|
478 |
+
[10:50:18] INFO - Epoch 26/50, Iter 47700: Loss = 0.01804378256201744, lr = 0.0001
|
479 |
+
[10:51:13] INFO - Epoch 26/50, Iter 47800: Loss = 0.026898138225078583, lr = 0.0001
|
480 |
+
[10:52:09] INFO - Epoch 26/50, Iter 47900: Loss = 0.018523452803492546, lr = 0.0001
|
481 |
+
[10:54:40] INFO - Epoch 26/50, Iter 48000: Loss = 0.016216814517974854, lr = 0.0001
|
482 |
+
[10:55:34] INFO - Epoch 26/50, Iter 48100: Loss = 0.02262328565120697, lr = 0.0001
|
483 |
+
[10:56:29] INFO - Epoch 26/50, Iter 48200: Loss = 0.015000266954302788, lr = 0.0001
|
484 |
+
[10:57:25] INFO - Epoch 26/50, Iter 48300: Loss = 0.02180442586541176, lr = 0.0001
|
485 |
+
[10:58:20] INFO - Epoch 26/50, Iter 48400: Loss = 0.025278791785240173, lr = 0.0001
|
486 |
+
[10:59:14] INFO - Epoch 26/50, Iter 48500: Loss = 0.03473420441150665, lr = 0.0001
|
487 |
+
[11:00:09] INFO - Epoch 26/50, Iter 48600: Loss = 0.017245961353182793, lr = 0.0001
|
488 |
+
[11:01:03] INFO - Epoch 26/50, Iter 48700: Loss = 0.03179230913519859, lr = 0.0001
|
489 |
+
[11:01:59] INFO - Epoch 27/50, Iter 48800: Loss = 0.015805833041667938, lr = 0.0001
|
490 |
+
[11:02:54] INFO - Epoch 27/50, Iter 48900: Loss = 0.02080763876438141, lr = 0.0001
|
491 |
+
[11:05:26] INFO - Epoch 27/50, Iter 49000: Loss = 0.020735610276460648, lr = 0.0001
|
492 |
+
[11:06:21] INFO - Epoch 27/50, Iter 49100: Loss = 0.024737179279327393, lr = 0.0001
|
493 |
+
[11:07:16] INFO - Epoch 27/50, Iter 49200: Loss = 0.026094382628798485, lr = 0.0001
|
494 |
+
[11:08:10] INFO - Epoch 27/50, Iter 49300: Loss = 0.021053478121757507, lr = 0.0001
|
495 |
+
[11:09:05] INFO - Epoch 27/50, Iter 49400: Loss = 0.014476573094725609, lr = 0.0001
|
496 |
+
[11:10:01] INFO - Epoch 27/50, Iter 49500: Loss = 0.030272990465164185, lr = 0.0001
|
497 |
+
[11:10:55] INFO - Epoch 27/50, Iter 49600: Loss = 0.022585971280932426, lr = 0.0001
|
498 |
+
[11:11:50] INFO - Epoch 27/50, Iter 49700: Loss = 0.01895831525325775, lr = 0.0001
|
499 |
+
[11:12:44] INFO - Epoch 27/50, Iter 49800: Loss = 0.018344363197684288, lr = 0.0001
|
500 |
+
[11:13:39] INFO - Epoch 27/50, Iter 49900: Loss = 0.022272832691669464, lr = 0.0001
|
501 |
+
[11:16:10] INFO - Epoch 27/50, Iter 50000: Loss = 0.022018130868673325, lr = 0.0001
|
502 |
+
[11:17:06] INFO - Epoch 27/50, Iter 50100: Loss = 0.027774281799793243, lr = 0.0001
|
503 |
+
[11:18:00] INFO - Epoch 27/50, Iter 50200: Loss = 0.014724764972925186, lr = 0.0001
|
504 |
+
[11:18:55] INFO - Epoch 27/50, Iter 50300: Loss = 0.018815312534570694, lr = 0.0001
|
505 |
+
[11:19:50] INFO - Epoch 27/50, Iter 50400: Loss = 0.019056078046560287, lr = 0.0001
|
506 |
+
[11:20:44] INFO - Epoch 27/50, Iter 50500: Loss = 0.01948639005422592, lr = 0.0001
|
507 |
+
[11:21:39] INFO - Epoch 27/50, Iter 50600: Loss = 0.02332192286849022, lr = 0.0001
|
508 |
+
[11:22:35] INFO - Epoch 28/50, Iter 50700: Loss = 0.02340688183903694, lr = 0.0001
|
509 |
+
[11:23:31] INFO - Epoch 28/50, Iter 50800: Loss = 0.02822597697377205, lr = 0.0001
|
510 |
+
[11:24:26] INFO - Epoch 28/50, Iter 50900: Loss = 0.02604568563401699, lr = 0.0001
|
511 |
+
[11:26:58] INFO - Epoch 28/50, Iter 51000: Loss = 0.015130102634429932, lr = 0.0001
|
512 |
+
[11:27:53] INFO - Epoch 28/50, Iter 51100: Loss = 0.020247958600521088, lr = 0.0001
|
513 |
+
[11:28:47] INFO - Epoch 28/50, Iter 51200: Loss = 0.021361518651247025, lr = 0.0001
|
514 |
+
[11:29:42] INFO - Epoch 28/50, Iter 51300: Loss = 0.0154896704480052, lr = 0.0001
|
515 |
+
[11:30:36] INFO - Epoch 28/50, Iter 51400: Loss = 0.020418627187609673, lr = 0.0001
|
516 |
+
[11:31:31] INFO - Epoch 28/50, Iter 51500: Loss = 0.016209501773118973, lr = 0.0001
|
517 |
+
[11:32:26] INFO - Epoch 28/50, Iter 51600: Loss = 0.021547267213463783, lr = 0.0001
|
518 |
+
[11:33:20] INFO - Epoch 28/50, Iter 51700: Loss = 0.03097592294216156, lr = 0.0001
|
519 |
+
[11:34:16] INFO - Epoch 28/50, Iter 51800: Loss = 0.01853656955063343, lr = 0.0001
|
520 |
+
[11:35:11] INFO - Epoch 28/50, Iter 51900: Loss = 0.025320153683423996, lr = 0.0001
|
521 |
+
[11:37:42] INFO - Epoch 28/50, Iter 52000: Loss = 0.01918005384504795, lr = 0.0001
|
522 |
+
[11:38:36] INFO - Epoch 28/50, Iter 52100: Loss = 0.02268061600625515, lr = 0.0001
|
523 |
+
[11:39:32] INFO - Epoch 28/50, Iter 52200: Loss = 0.024810226634144783, lr = 0.0001
|
524 |
+
[11:40:27] INFO - Epoch 28/50, Iter 52300: Loss = 0.02219560742378235, lr = 0.0001
|
525 |
+
[11:41:21] INFO - Epoch 28/50, Iter 52400: Loss = 0.027511518448591232, lr = 0.0001
|
526 |
+
[11:42:16] INFO - Epoch 29/50, Iter 52500: Loss = 0.016894716769456863, lr = 0.0001
|
527 |
+
[11:43:12] INFO - Epoch 29/50, Iter 52600: Loss = 0.01918671280145645, lr = 0.0001
|
528 |
+
[11:44:07] INFO - Epoch 29/50, Iter 52700: Loss = 0.021322811022400856, lr = 0.0001
|
529 |
+
[11:45:03] INFO - Epoch 29/50, Iter 52800: Loss = 0.01693873107433319, lr = 0.0001
|
530 |
+
[11:45:58] INFO - Epoch 29/50, Iter 52900: Loss = 0.028586234897375107, lr = 0.0001
|
531 |
+
[11:48:30] INFO - Epoch 29/50, Iter 53000: Loss = 0.02094537392258644, lr = 0.0001
|
532 |
+
[11:49:25] INFO - Epoch 29/50, Iter 53100: Loss = 0.025890830904245377, lr = 0.0001
|
533 |
+
[11:50:20] INFO - Epoch 29/50, Iter 53200: Loss = 0.019293418154120445, lr = 0.0001
|
534 |
+
[11:51:14] INFO - Epoch 29/50, Iter 53300: Loss = 0.013301231898367405, lr = 0.0001
|
535 |
+
[11:52:10] INFO - Epoch 29/50, Iter 53400: Loss = 0.024367133155465126, lr = 0.0001
|
536 |
+
[11:53:04] INFO - Epoch 29/50, Iter 53500: Loss = 0.013333385810256004, lr = 0.0001
|
537 |
+
[11:53:59] INFO - Epoch 29/50, Iter 53600: Loss = 0.021088868379592896, lr = 0.0001
|
538 |
+
[11:54:53] INFO - Epoch 29/50, Iter 53700: Loss = 0.014782575890421867, lr = 0.0001
|
539 |
+
[11:55:48] INFO - Epoch 29/50, Iter 53800: Loss = 0.019235175102949142, lr = 0.0001
|
540 |
+
[11:56:43] INFO - Epoch 29/50, Iter 53900: Loss = 0.02775110863149166, lr = 0.0001
|
541 |
+
[11:59:15] INFO - Epoch 29/50, Iter 54000: Loss = 0.014202380552887917, lr = 0.0001
|
542 |
+
[12:00:10] INFO - Epoch 29/50, Iter 54100: Loss = 0.021274959668517113, lr = 0.0001
|
543 |
+
[12:01:04] INFO - Epoch 29/50, Iter 54200: Loss = 0.028708720579743385, lr = 0.0001
|
544 |
+
[12:01:59] INFO - Epoch 29/50, Iter 54300: Loss = 0.024009495973587036, lr = 0.0001
|
545 |
+
[12:02:55] INFO - Epoch 30/50, Iter 54400: Loss = 0.018383020535111427, lr = 0.0001
|
546 |
+
[12:03:50] INFO - Epoch 30/50, Iter 54500: Loss = 0.012869146652519703, lr = 0.0001
|
547 |
+
[12:04:46] INFO - Epoch 30/50, Iter 54600: Loss = 0.015052242204546928, lr = 0.0001
|
548 |
+
[12:05:41] INFO - Epoch 30/50, Iter 54700: Loss = 0.021794060245156288, lr = 0.0001
|
549 |
+
[12:06:37] INFO - Epoch 30/50, Iter 54800: Loss = 0.021674180403351784, lr = 0.0001
|
550 |
+
[12:07:31] INFO - Epoch 30/50, Iter 54900: Loss = 0.0307894479483366, lr = 0.0001
|
551 |
+
[12:10:04] INFO - Epoch 30/50, Iter 55000: Loss = 0.023494703695178032, lr = 0.0001
|
552 |
+
[12:10:59] INFO - Epoch 30/50, Iter 55100: Loss = 0.025401834398508072, lr = 0.0001
|
553 |
+
[12:11:54] INFO - Epoch 30/50, Iter 55200: Loss = 0.021761178970336914, lr = 0.0001
|
554 |
+
[12:12:49] INFO - Epoch 30/50, Iter 55300: Loss = 0.02898026630282402, lr = 0.0001
|
555 |
+
[12:13:44] INFO - Epoch 30/50, Iter 55400: Loss = 0.02216275781393051, lr = 0.0001
|
556 |
+
[12:14:38] INFO - Epoch 30/50, Iter 55500: Loss = 0.00930317398160696, lr = 0.0001
|
557 |
+
[12:15:33] INFO - Epoch 30/50, Iter 55600: Loss = 0.024549826979637146, lr = 0.0001
|
558 |
+
[12:16:29] INFO - Epoch 30/50, Iter 55700: Loss = 0.016341213136911392, lr = 0.0001
|
559 |
+
[12:17:23] INFO - Epoch 30/50, Iter 55800: Loss = 0.015864314511418343, lr = 0.0001
|
560 |
+
[12:18:18] INFO - Epoch 30/50, Iter 55900: Loss = 0.034297745674848557, lr = 0.0001
|
561 |
+
[12:20:49] INFO - Epoch 30/50, Iter 56000: Loss = 0.02956249937415123, lr = 0.0001
|
562 |
+
[12:21:44] INFO - Epoch 30/50, Iter 56100: Loss = 0.02114814706146717, lr = 0.0001
|
563 |
+
[12:22:40] INFO - Epoch 30/50, Iter 56200: Loss = 0.0200330913066864, lr = 0.0001
|
564 |
+
[12:23:35] INFO - Epoch 31/50, Iter 56300: Loss = 0.026903297752141953, lr = 0.0001
|
565 |
+
[12:24:31] INFO - Epoch 31/50, Iter 56400: Loss = 0.02994358167052269, lr = 0.0001
|
566 |
+
[12:25:26] INFO - Epoch 31/50, Iter 56500: Loss = 0.016208231449127197, lr = 0.0001
|
567 |
+
[12:26:22] INFO - Epoch 31/50, Iter 56600: Loss = 0.029720913618803024, lr = 0.0001
|
568 |
+
[12:27:16] INFO - Epoch 31/50, Iter 56700: Loss = 0.021973680704832077, lr = 0.0001
|
569 |
+
[12:28:11] INFO - Epoch 31/50, Iter 56800: Loss = 0.017940720543265343, lr = 0.0001
|
570 |
+
[12:29:07] INFO - Epoch 31/50, Iter 56900: Loss = 0.022731531411409378, lr = 0.0001
|
571 |
+
[12:31:40] INFO - Epoch 31/50, Iter 57000: Loss = 0.016729535534977913, lr = 0.0001
|
572 |
+
[12:32:35] INFO - Epoch 31/50, Iter 57100: Loss = 0.026968562975525856, lr = 0.0001
|
573 |
+
[12:33:29] INFO - Epoch 31/50, Iter 57200: Loss = 0.015602253377437592, lr = 0.0001
|
574 |
+
[12:34:24] INFO - Epoch 31/50, Iter 57300: Loss = 0.028429606929421425, lr = 0.0001
|
575 |
+
[12:35:20] INFO - Epoch 31/50, Iter 57400: Loss = 0.021183405071496964, lr = 0.0001
|
576 |
+
[12:36:14] INFO - Epoch 31/50, Iter 57500: Loss = 0.024300210177898407, lr = 0.0001
|
577 |
+
[12:37:09] INFO - Epoch 31/50, Iter 57600: Loss = 0.017051223665475845, lr = 0.0001
|
578 |
+
[12:38:03] INFO - Epoch 31/50, Iter 57700: Loss = 0.016109324991703033, lr = 0.0001
|
579 |
+
[12:38:58] INFO - Epoch 31/50, Iter 57800: Loss = 0.019427603110671043, lr = 0.0001
|
580 |
+
[12:39:53] INFO - Epoch 31/50, Iter 57900: Loss = 0.030664775520563126, lr = 0.0001
|
581 |
+
[12:42:25] INFO - Epoch 31/50, Iter 58000: Loss = 0.021199747920036316, lr = 0.0001
|
582 |
+
[12:43:20] INFO - Epoch 31/50, Iter 58100: Loss = 0.01854831352829933, lr = 0.0001
|
583 |
+
[12:44:16] INFO - Epoch 32/50, Iter 58200: Loss = 0.01928992196917534, lr = 0.0001
|
584 |
+
[12:45:11] INFO - Epoch 32/50, Iter 58300: Loss = 0.018576214089989662, lr = 0.0001
|
585 |
+
[12:46:00] INFO - Epoch 32/50, Iter 58400: Loss = 0.019123028963804245, lr = 0.0001
|
resnet/log/iter_1000.png
ADDED
resnet/log/iter_10000.png
ADDED
resnet/log/iter_11000.png
ADDED
resnet/log/iter_12000.png
ADDED
resnet/log/iter_13000.png
ADDED
resnet/log/iter_14000.png
ADDED
resnet/log/iter_15000.png
ADDED
resnet/log/iter_16000.png
ADDED
resnet/log/iter_17000.png
ADDED
resnet/log/iter_18000.png
ADDED
resnet/log/iter_19000.png
ADDED
resnet/log/iter_2000.png
ADDED
resnet/log/iter_20000.png
ADDED
resnet/log/iter_21000.png
ADDED
resnet/log/iter_22000.png
ADDED
resnet/log/iter_23000.png
ADDED
resnet/log/iter_24000.png
ADDED
resnet/log/iter_25000.png
ADDED
resnet/log/iter_26000.png
ADDED
resnet/log/iter_27000.png
ADDED
resnet/log/iter_28000.png
ADDED
resnet/log/iter_29000.png
ADDED
resnet/log/iter_3000.png
ADDED
resnet/log/iter_30000.png
ADDED
resnet/log/iter_31000.png
ADDED
resnet/log/iter_32000.png
ADDED
resnet/log/iter_33000.png
ADDED
resnet/log/iter_34000.png
ADDED
resnet/log/iter_35000.png
ADDED
resnet/log/iter_36000.png
ADDED
resnet/log/iter_37000.png
ADDED
resnet/log/iter_38000.png
ADDED
resnet/log/iter_39000.png
ADDED
resnet/log/iter_4000.png
ADDED
resnet/log/iter_40000.png
ADDED
resnet/log/iter_41000.png
ADDED
resnet/log/iter_42000.png
ADDED
resnet/log/iter_43000.png
ADDED
resnet/log/iter_44000.png
ADDED
resnet/log/iter_45000.png
ADDED
resnet/log/iter_46000.png
ADDED
resnet/log/iter_47000.png
ADDED
resnet/log/iter_48000.png
ADDED
resnet/log/iter_49000.png
ADDED