misterbrainley committed
Commit b28391f
1 Parent(s): 45df722

added utils

Files changed (1)
utils.py +251 -0
utils.py ADDED
@@ -0,0 +1,251 @@
import glob
import numpy as np
import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.utils import make_grid


img_path = '/home/alan/Projects/gen_dnd_art/filtered_images/im128/*pkl'
img_files = glob.glob(img_path)

# determine class names from filenames of the form <species>_<class>_<gender>_*
labels = np.array([i.split('/')[-1].split('_')[:3] for i in img_files])
species = np.unique(labels[:, 0]).tolist()
classes = np.unique(labels[:, 1]).tolist()
genders = np.unique(labels[:, 2]).tolist()

class ImSet(Dataset):
    def __init__(self, img_path=img_path):
        super().__init__()
        self.img_files = glob.glob(img_path)
        self.transform = T.Compose([
            T.ToTensor(),
            T.ColorJitter(0.1, 0.1, 0.1, 0.1),
            T.RandomHorizontalFlip(),
            # add random noise and clip to [0, 1]
            lambda x: torch.clip(torch.randn(x.shape) / 20 + x, 0, 1),
            T.Normalize(0.5, 0.5)
        ])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        img_file = self.img_files[i]

        # load image
        with open(img_file, 'rb') as fid:
            img = pickle.load(fid)

        # apply transforms
        img = self.transform(img).float()

        # extract class labels from the filename
        img_fname = img_file.split('/')[-1]
        species_, class_, gender_, _, _ = img_fname.split('_')
        species_ = species.index(species_)
        class_ = classes.index(class_)
        gender_ = genders.index(gender_)

        return (img_fname, img, species_, class_, gender_)

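(Aside, not part of the committed file: a minimal sketch of how the dataset would be consumed, assuming the pickled 128x128 RGB images exist at img_path; the batch size and shuffle flag are illustrative values.)

# illustrative only, not part of utils.py
dataset = ImSet()
loader = DataLoader(dataset, batch_size=32, shuffle=True)
fnames, imgs, species_idx, class_idx, gender_idx = next(iter(loader))
# imgs: (32, 3, 128, 128), normalized to roughly [-1, 1] by the transform pipeline
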
class VariationalEncoder(nn.Module):
    def __init__(self, input_channels=3, latent_size=2048):
        super().__init__()

        self.latent_size = latent_size

        self.net = nn.Sequential(
            # 128 -> 63
            nn.Conv2d(input_channels, 8, 4, 2),
            nn.LeakyReLU(0.2),

            # 63 -> 31
            nn.Conv2d(8, 16, 3, 2),
            nn.LeakyReLU(0.2),

            # 31 -> 15
            nn.Conv2d(16, 32, 3, 2),
            nn.LeakyReLU(0.2),

            # 15 -> 7
            nn.Conv2d(32, 64, 3, 2),
            nn.LeakyReLU(0.2),

            # 7 -> 5
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2),

            # 5 -> 4
            nn.Conv2d(128, 256, 2, 1),
            nn.LeakyReLU(0.2),

            # 4 -> 3
            nn.Conv2d(256, 512, 2, 1),
            nn.LeakyReLU(0.2),

            # 3 -> 2
            nn.Conv2d(512, 1024, 2, 1),
            nn.LeakyReLU(0.2),

            # 2 -> 1
            nn.Conv2d(1024, latent_size, 2, 1),
            nn.LeakyReLU(0.2),

            nn.Flatten(),
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4)
        )

        # parameters for the variational bottleneck
        self.mu = nn.Linear(latent_size, latent_size)
        self.sigma = nn.Linear(latent_size, latent_size)
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda()      # sample the prior on the GPU (requires CUDA)
        self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1)
        x = mu + sigma * self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()

        return x

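(Aside, not part of the committed file: the forward pass samples z = mu + sigma * eps with eps drawn from a standard normal, so gradients flow through mu and sigma. The self.kl term penalizes the divergence from the standard normal prior; for reference, the usual closed form is

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) = \tfrac{1}{2}\sum_i\big(\sigma_i^2 + \mu_i^2 - 2\ln\sigma_i - 1\big)

while the expression in the code is a common tutorial variant that weights the sigma^2 and mu^2 terms by 1 rather than 1/2; it plays the same regularizing role.)
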
class ConditionalEncoder(VariationalEncoder):
    def __init__(self, latent_size=2048):
        super().__init__(input_channels=4, latent_size=latent_size)

        # label embeddings are sized so their concatenation fills one 128x128 plane
        self.emb_species = nn.Embedding(len(species), 128**2 // 3 + 128**2 % 3)
        self.emb_class = nn.Embedding(len(classes), 128**2 // 3)
        self.emb_gender = nn.Embedding(len(genders), 128**2 // 3)
        self.emb_reshape = nn.Unflatten(1, (1, 128, 128))

    def forward(self, img, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)

        # build the conditioning plane and stack it onto the image as a 4th channel
        x = torch.concat([x, y, z], dim=1)
        x = self.emb_reshape(x)

        x = torch.concat([img, x], dim=1)
        x = self.net(x)

        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        x = mu + sigma * self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        return x


class Decoder(nn.Module):
    def __init__(self, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(

            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4),

            nn.Unflatten(1, (latent_size, 1, 1)),

            # 1 -> 2
            nn.ConvTranspose2d(latent_size, 1024, 2, 1),
            nn.LeakyReLU(0.2),

            # 2 -> 3
            nn.ConvTranspose2d(1024, 512, 2, 1),
            nn.LeakyReLU(0.2),

            # 3 -> 4
            nn.ConvTranspose2d(512, 256, 2, 1),
            nn.LeakyReLU(0.2),

            # 4 -> 5
            nn.ConvTranspose2d(256, 128, 2, 1),
            nn.LeakyReLU(0.2),

            # 5 -> 7
            nn.ConvTranspose2d(128, 64, 3, 1),
            nn.LeakyReLU(0.2),

            # 7 -> 15
            nn.ConvTranspose2d(64, 32, 3, 2),
            nn.LeakyReLU(0.2),

            # 15 -> 31
            nn.ConvTranspose2d(32, 16, 3, 2),
            nn.LeakyReLU(0.2),

            # 31 -> 63
            nn.ConvTranspose2d(16, 8, 3, 2),
            nn.LeakyReLU(0.2),

            # 63 -> 128
            nn.ConvTranspose2d(8, 3, 4, 2),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)


class ConditionalDecoder(Decoder):
    def __init__(self, latent_size=1024):
        super().__init__(latent_size)

        # label embeddings sum to latent_size and are concatenated with the latent code
        self.emb_species = nn.Embedding(len(species), latent_size // 3 + latent_size % 3)
        self.emb_class = nn.Embedding(len(classes), latent_size // 3)
        self.emb_gender = nn.Embedding(len(genders), latent_size // 3)
        self.label_net = nn.Linear(2*latent_size, latent_size)

    def forward(self, Z, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)

        x = torch.concat([Z, x, y, z], dim=1)
        x = self.label_net(x)
        x = self.net(x)
        return x


class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        # pass latent_size by keyword so it is not mistaken for input_channels
        self.enc = VariationalEncoder(latent_size=latent_size)
        self.dec = Decoder(latent_size)

    def forward(self, x):
        return self.dec(self.enc(x))

class ConditionalVariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        self.enc = ConditionalEncoder(latent_size)
        self.dec = ConditionalDecoder(latent_size)

    def forward(self, img, species_, class_, gender_):
        Z = self.enc(img, species_, class_, gender_)
        x = self.dec(Z, species_, class_, gender_)
        return x

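(Aside, not part of the committed file: a minimal training-step sketch for the conditional model. The optimizer, learning rate, batch size, and squared-error reconstruction term are assumptions; CUDA is assumed because the encoder samples its prior on the GPU.)

# illustrative only, not part of utils.py
model = ConditionalVariationalAutoEncoder(latent_size=1024).cuda()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)  # lr is an assumption

for fnames, imgs, sp, cl, ge in DataLoader(ImSet(), batch_size=32, shuffle=True):
    imgs, sp, cl, ge = imgs.cuda(), sp.cuda(), cl.cuda(), ge.cuda()
    recon = model(imgs, sp, cl, ge)
    # reconstruction term plus the KL penalty accumulated by the encoder
    loss = ((recon - imgs) ** 2).sum() + model.enc.kl
    opt.zero_grad()
    loss.backward()
    opt.step()
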
def show_tensor(Z, ax, **kwargs):
    # drop the batch dimension if a batch was passed
    if len(Z.shape) > 3:
        Z = Z[0]

    # rescale Tanh / Normalize output from [-1, 1] back to [0, 1]
    if Z.min() < 0:
        Z = (Z + 1) / 2

    Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
    ax.imshow(Z, **kwargs)
    return ax
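
(Aside, not part of the committed file: continuing the training sketch above, show_tensor can be combined with the already-imported make_grid to inspect inputs and reconstructions; the matplotlib usage is an assumption.)

# illustrative only, not part of utils.py
import matplotlib.pyplot as plt

fig, (ax0, ax1) = plt.subplots(1, 2)
show_tensor(imgs, ax0)                           # first image of the input batch
show_tensor(make_grid(recon[:4], nrow=2), ax1)   # 2x2 grid of reconstructions
plt.show()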