Mehdi Cherti committed
Commit bc53ac3
1 Parent(s): c81908d

support cond attn based discriminator
pytorch_fid/fid_score.py CHANGED
@@ -148,7 +148,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
 
     for batch in tqdm(dataloader):
         batch = batch.to(device)
-        print(batch.shape, batch.min(), batch.max)
+        #print(batch.shape, batch.min(), batch.max)
         with torch.no_grad():
             pred = model(batch)[0]
 
score_sde/models/discriminator.py CHANGED
@@ -167,6 +167,87 @@ class Discriminator_small(nn.Module):
 
         return out
 
+class SmallCondAttnDiscriminator(nn.Module):
+    """A time-dependent discriminator for small images (CIFAR10, StackMNIST)."""
+
+    def __init__(self, nc=3, ngf=64, t_emb_dim=128, act=nn.LeakyReLU(0.2), cond_size=768):
+        super().__init__()
+        # Gaussian random feature embedding layer for time
+        self.act = act
+        self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
+
+        self.t_embed = TimestepEmbedding(
+            embedding_dim=t_emb_dim,
+            hidden_dim=t_emb_dim,
+            output_dim=t_emb_dim,
+            act=act,
+        )
+
+        # Encoding layers where the resolution decreases
+        self.start_conv = conv2d(nc, ngf*2, 1, padding=0)
+        self.conv1 = DownConvBlock(ngf*2, ngf*2, t_emb_dim=t_emb_dim, act=act)
+        self.conv2 = DownConvBlock(ngf*2, ngf*4, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv3 = DownConvBlock(ngf*4, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+
+        self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3, padding=1, init_scale=0.)
+        self.end_linear = dense(ngf*8, 1)
+        self.end_linear_cond = dense(ngf*8, 1)
+
+        self.stddev_group = 4
+        self.stddev_feat = 1
+
+    def forward(self, x, t, x_t, cond=None):
+        t_embed = self.t_embed(t)
+        t_embed = self.act(t_embed)
+        input_x = torch.cat((x, x_t), dim=1)
+
+        h0 = self.start_conv(input_x)
+        h1 = self.conv1(h0, t_embed)
+        h2 = self.conv2(h1, t_embed)
+        h3 = self.conv3(h2, t_embed)
+        out = self.conv4(h3, t_embed)
+
+        batch, channel, height, width = out.shape
+        group = min(batch, self.stddev_group)
+        stddev = out.view(
+            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+        )
+        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+        stddev = stddev.repeat(group, 1, height, width)
+        out = torch.cat([out, stddev], 1)
+
+        out = self.final_conv(out)
+        out = self.act(out)
+
+        cond_pooled, cond, cond_mask = cond
+        out_cond = self.cond_attn(out, cond, cond_mask)
+
+        out = out.view(out.shape[0], out.shape[1], -1).mean(2)
+        out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
+        out = self.end_linear(out) + self.end_linear_cond(out_cond)
+        return out
+
+
 
 class Discriminator_large(nn.Module):
     """A time-dependent discriminator for large images (CelebA, LSUN)."""
@@ -239,3 +320,81 @@ class Discriminator_large(nn.Module):
         out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
         return out
 
+
+class CondAttnDiscriminator(nn.Module):
+    """A time-dependent discriminator for large images (CelebA, LSUN)."""
+
+    def __init__(self, nc=1, ngf=32, t_emb_dim=128, act=nn.LeakyReLU(0.2), cond_size=768):
+        super().__init__()
+        # Gaussian random feature embedding layer for time
+        self.act = act
+        self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
+
+        self.t_embed = TimestepEmbedding(
+            embedding_dim=t_emb_dim,
+            hidden_dim=t_emb_dim,
+            output_dim=t_emb_dim,
+            act=act,
+        )
+
+        self.start_conv = conv2d(nc, ngf*2, 1, padding=0)
+        self.conv1 = DownConvBlock(ngf*2, ngf*4, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv2 = DownConvBlock(ngf*4, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv3 = DownConvBlock(ngf*8, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv5 = DownConvBlock(ngf*8, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+        self.conv6 = DownConvBlock(ngf*8, ngf*8, t_emb_dim=t_emb_dim, downsample=True, act=act)
+
+        self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3, padding=1)
+        self.end_linear = dense(ngf*8, 1)
+        self.end_linear_cond = dense(ngf*8, 1)
+
+        self.stddev_group = 4
+        self.stddev_feat = 1
+
+    def forward(self, x, t, x_t, cond=None):
+        cond_pooled, cond, cond_mask = cond
+
+        t_embed = self.t_embed(t)
+        t_embed = self.act(t_embed)
+
+        input_x = torch.cat((x, x_t), dim=1)
+
+        h = self.start_conv(input_x)
+        h = self.conv1(h, t_embed)
+        h = self.conv2(h, t_embed)
+        h = self.conv3(h, t_embed)
+        h = self.conv4(h, t_embed)
+        h = self.conv5(h, t_embed)
+        out = self.conv6(h, t_embed)
+
+        batch, channel, height, width = out.shape
+        group = min(batch, self.stddev_group)
+        stddev = out.view(
+            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+        )
+        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+        stddev = stddev.repeat(group, 1, height, width)
+        out = torch.cat([out, stddev], 1)
+
+        out = self.final_conv(out)
+        out = self.act(out)
+
+        out_cond = self.cond_attn(out, cond, cond_mask)
+
+        out = out.view(out.shape[0], out.shape[1], -1).mean(2)
+        out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
+        out = self.end_linear(out) + self.end_linear_cond(out_cond)
+        return out
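Taken together, the two new classes add conditional cross-attention to the time-dependent discriminators: image features still receive time conditioning through `t_embed`, while `cond_attn` cross-attends from the final feature map to an external embedding sequence (e.g. text-encoder tokens), and the conditional logit from `end_linear_cond` is added to the unconditional one from `end_linear`. A minimal usage sketch follows; the module path, batch shapes, token count, and the choice nc = 2 * image channels (since `forward` concatenates `x` and `x_t`) are assumptions, not part of the commit. Note that despite the `cond=None` default, `forward` unpacks `cond` unconditionally, so the `(pooled, tokens, mask)` triple is required:

    # Hypothetical usage sketch; shapes and module path are assumptions.
    import torch
    from score_sde.models.discriminator import SmallCondAttnDiscriminator

    # nc counts the channels of cat((x, x_t)), i.e. twice the image channels
    disc = SmallCondAttnDiscriminator(nc=2 * 3, ngf=64, t_emb_dim=128, cond_size=768)

    x   = torch.randn(4, 3, 32, 32)          # candidate clean sample (CIFAR-10 sized)
    x_t = torch.randn(4, 3, 32, 32)          # noisy sample at step t
    t   = torch.randint(0, 4, (4,)).float()  # diffusion timesteps

    cond_pooled = torch.randn(4, 768)                  # pooled embedding (unused in forward)
    cond_tokens = torch.randn(4, 77, 768)              # per-token embeddings, width = cond_size
    cond_mask   = torch.ones(4, 77, dtype=torch.bool)  # all tokens valid

    logit = disc(x, t, x_t, cond=(cond_pooled, cond_tokens, cond_mask))
    print(logit.shape)  # torch.Size([4, 1]): one real/fake score per sample

Design-wise, summing the two heads keeps the unconditional real/fake path intact, so the discriminator can still penalize poor samples even when the conditioning signal carries little information.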