Spaces: Runtime error
misterbrainley committed
Commit b28391f · Parent(s): 45df722
added utils
utils.py
ADDED
@@ -0,0 +1,251 @@
import glob
import numpy as np
import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.utils import make_grid


img_path = '/home/alan/Projects/gen_dnd_art/filtered_images/im128/*pkl'
img_files = glob.glob(img_path)

# determine class names
labels = np.array([i.split('/')[-1].split('_')[:3] for i in img_files])
species = np.unique(labels[:, 0]).tolist()
classes = np.unique(labels[:, 1]).tolist()
genders = np.unique(labels[:, 2]).tolist()
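
# Illustrative note (hypothetical filename, not from this commit): a file named
# 'dwarf_fighter_m_0001_abc.pkl' would contribute the labels
# ['dwarf', 'fighter', 'm'] above, matching the five-field split in __getitem__.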
class ImSet(Dataset):
    def __init__(self, img_path=img_path):
        super().__init__()
        self.img_files = glob.glob(img_path)
        self.transform = T.Compose([
            T.ToTensor(),
            T.ColorJitter(0.1, 0.1, 0.1, 0.1),
            T.RandomHorizontalFlip(),
            # add random noise and clip
            lambda x: torch.clip(torch.randn(x.shape) / 20 + x, 0, 1),
            T.Normalize(0.5, 0.5)
        ])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        img_file = self.img_files[i]

        # load image
        with open(img_file, 'rb') as fid:
            img = pickle.load(fid)

        # apply transforms
        img = self.transform(img).float()

        # extract class label
        img_fname = img_file.split('/')[-1]
        species_, class_, gender_, _, _ = img_fname.split('_')
        species_ = species.index(species_)
        class_ = classes.index(class_)
        gender_ = genders.index(gender_)

        return (img_fname, img, species_, class_, gender_)

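# A minimal convenience sketch, not part of the original commit: wrap ImSet in
# a DataLoader (the batch size here is an illustrative assumption). Each batch
# unpacks as (fnames, imgs, species_idx, class_idx, gender_idx), with imgs of
# shape (batch_size, 3, 128, 128) scaled to [-1, 1].
def make_loader(batch_size=32, shuffle=True):
    return DataLoader(ImSet(), batch_size=batch_size, shuffle=shuffle)
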
class VariationalEncoder(nn.Module):
    def __init__(self, input_channels=3, latent_size=2048):
        super().__init__()

        self.latent_size = latent_size

        self.net = nn.Sequential(
            # 128 -> 63
            nn.Conv2d(input_channels, 8, 4, 2),
            nn.LeakyReLU(0.2),

            # 63 -> 31
            nn.Conv2d(8, 16, 3, 2),
            nn.LeakyReLU(0.2),

            # 31 -> 15
            nn.Conv2d(16, 32, 3, 2),
            nn.LeakyReLU(0.2),

            # 15 -> 7
            nn.Conv2d(32, 64, 3, 2),
            nn.LeakyReLU(0.2),

            # 7 -> 5
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2),

            # 5 -> 4
            nn.Conv2d(128, 256, 2, 1),
            nn.LeakyReLU(0.2),

            # 4 -> 3
            nn.Conv2d(256, 512, 2, 1),
            nn.LeakyReLU(0.2),

            # 3 -> 2
            nn.Conv2d(512, 1024, 2, 1),
            nn.LeakyReLU(0.2),

            # 2 -> 1
            nn.Conv2d(1024, latent_size, 2, 1),
            nn.LeakyReLU(0.2),

            nn.Flatten(),
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4)
        )

        # parameters for variational autoencoder
        self.mu = nn.Linear(latent_size, latent_size)
        self.sigma = nn.Linear(latent_size, latent_size)
        self.N = torch.distributions.Normal(0, 1)
        if torch.cuda.is_available():
            # keep the sampling distribution on the GPU when one is present;
            # calling .cuda() unconditionally crashes on CPU-only machines
            self.N.loc = self.N.loc.cuda()
            self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1)
        x = mu + sigma * self.N.sample(mu.shape)
        # closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latents
        self.kl = (sigma**2 + mu**2 - 2 * torch.log(sigma) - 1).sum() / 2

        return x

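# Hedged sanity-check sketch (illustrative, not from the commit): the encoder
# maps a batch of 128x128 RGB images to flat latent vectors. On a CUDA machine,
# move the encoder and input to the GPU first, since self.N samples there.
def _check_encoder_shapes(batch_size=2, latent_size=2048):
    enc = VariationalEncoder(latent_size=latent_size)
    z = enc(torch.randn(batch_size, 3, 128, 128))
    assert z.shape == (batch_size, latent_size)
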
class ConditionalEncoder(VariationalEncoder):
    def __init__(self, latent_size=2048):
        # label embeddings are stacked into one extra 128x128 input plane
        super().__init__(input_channels=4, latent_size=latent_size)

        # the three embedding widths sum to 128**2 (the remainder goes to species)
        self.emb_species = nn.Embedding(len(species), 128**2 // 3 + 128**2 % 3)
        self.emb_class = nn.Embedding(len(classes), 128**2 // 3)
        self.emb_gender = nn.Embedding(len(genders), 128**2 // 3)
        self.emb_reshape = nn.Unflatten(1, (1, 128, 128))

    def forward(self, img, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)

        # concatenate the label embeddings into one 128x128 conditioning channel
        x = torch.concat([x, y, z], dim=1)
        x = self.emb_reshape(x)

        x = torch.concat([img, x], dim=1)
        x = self.net(x)

        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        x = mu + sigma * self.N.sample(mu.shape)
        # same closed-form KL as in the parent class
        self.kl = (sigma**2 + mu**2 - 2 * torch.log(sigma) - 1).sum() / 2
        return x

class Decoder(nn.Module):
    def __init__(self, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4),

            nn.Unflatten(1, (latent_size, 1, 1)),

            # 1 -> 2
            nn.ConvTranspose2d(latent_size, 1024, 2, 1),
            nn.LeakyReLU(0.2),

            # 2 -> 3
            nn.ConvTranspose2d(1024, 512, 2, 1),
            nn.LeakyReLU(0.2),

            # 3 -> 4
            nn.ConvTranspose2d(512, 256, 2, 1),
            nn.LeakyReLU(0.2),

            # 4 -> 5
            nn.ConvTranspose2d(256, 128, 2, 1),
            nn.LeakyReLU(0.2),

            # 5 -> 7
            nn.ConvTranspose2d(128, 64, 3, 1),
            nn.LeakyReLU(0.2),

            # 7 -> 15
            nn.ConvTranspose2d(64, 32, 3, 2),
            nn.LeakyReLU(0.2),

            # 15 -> 31
            nn.ConvTranspose2d(32, 16, 3, 2),
            nn.LeakyReLU(0.2),

            # 31 -> 63
            nn.ConvTranspose2d(16, 8, 3, 2),
            nn.LeakyReLU(0.2),

            # 63 -> 128
            nn.ConvTranspose2d(8, 3, 4, 2),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)

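# Mirror-image check of the sketch above (illustrative, not from the commit):
# the decoder walks the spatial ladder back up, 1x1 -> 128x128.
def _check_decoder_shapes(batch_size=2, latent_size=2048):
    img = Decoder(latent_size=latent_size)(torch.randn(batch_size, latent_size))
    assert img.shape == (batch_size, 3, 128, 128)
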
class ConditionalDecoder(Decoder):
    def __init__(self, latent_size=1024):
        super().__init__(latent_size)

        self.emb_species = nn.Embedding(len(species), latent_size // 3 + latent_size % 3)
        self.emb_class = nn.Embedding(len(classes), latent_size // 3)
        self.emb_gender = nn.Embedding(len(genders), latent_size // 3)
        self.label_net = nn.Linear(2 * latent_size, latent_size)

    def forward(self, Z, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)

        x = torch.concat([Z, x, y, z], dim=1)
        x = self.label_net(x)
        x = self.net(x)
        return x

class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        # pass latent_size by keyword: the encoder's first positional
        # argument is input_channels, not latent_size
        self.enc = VariationalEncoder(latent_size=latent_size)
        self.dec = Decoder(latent_size)

    def forward(self, x):
        return self.dec(self.enc(x))

class ConditionalVariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        self.enc = ConditionalEncoder(latent_size)
        self.dec = ConditionalDecoder(latent_size)

    def forward(self, img, species_, class_, gender_):
        Z = self.enc(img, species_, class_, gender_)
        x = self.dec(Z, species_, class_, gender_)
        return x

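# A minimal training-loop sketch, not from this commit: the optimizer choice,
# learning rate, KL weight, and epoch count are all illustrative assumptions.
# The loss is squared reconstruction error plus the KL term that the encoder
# stores on each forward pass.
def train_cvae(model, loader, epochs=10, lr=1e-4, kl_weight=1.0):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for _, img, species_, class_, gender_ in loader:
            recon = model(img, species_, class_, gender_)
            loss = ((recon - img) ** 2).sum() + kl_weight * model.enc.kl
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model
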
def show_tensor(Z, ax, **kwargs):
    # drop the batch dimension if one is present
    if len(Z.shape) > 3:
        Z = Z[0]

    # rescale tanh/Normalize outputs from [-1, 1] back to [0, 1]
    if Z.min() < 0:
        Z = (Z + 1) / 2

    Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
    ax.imshow(Z, **kwargs)
    return ax
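
# Hedged usage sketch (the matplotlib scaffolding is an assumption, not part of
# this commit): display one dataset sample on a fresh axis.
#
#     import matplotlib.pyplot as plt
#     fig, ax = plt.subplots()
#     _, img, _, _, _ = ImSet()[0]
#     show_tensor(img, ax)
#     plt.show()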