Mehdi Cherti committed
Commit b991bac
1 Parent(s): be61cf2

update discr

Files changed (1)
  1. score_sde/models/discriminator.py +36 -3
score_sde/models/discriminator.py CHANGED
@@ -252,12 +252,31 @@ class SmallCondAttnDiscriminator(nn.Module):
 class Discriminator_large(nn.Module):
     """A time-dependent discriminator for large images (CelebA, LSUN)."""
 
-    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
+    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768, attn_pool=False, attn_pool_kw=None):
         super().__init__()
         # Gaussian random feature embedding layer for time
         self.cond_proj = nn.Linear(cond_size, ngf*8)
         self.act = act
-
+        if attn_pool:
+            if attn_pool_kw is None:
+                attn_pool_kw = dict(
+                    depth=1,
+                    dim_head = 64,
+                    heads = 8,
+                    num_latents = 64,
+                    num_latents_mean_pooled = 4,  # number of latents derived from mean pooled representation of the sequence
+                    max_seq_len = 512,
+                    ff_mult = 4,
+                    cosine_sim_attn = False,
+                )
+            self.attn_pool = layers.PerceiverResampler(
+                dim=cond_size,
+                **attn_pool_kw,
+            )
+            max_text_len = 512
+            self.null_text_embed = torch.nn.Parameter(torch.randn(1, max_text_len, cond_size))
+        else:
+            self.attn_pool = None
         self.t_embed = TimestepEmbedding(
             embedding_dim=t_emb_dim,
             hidden_dim=t_emb_dim,
@@ -317,7 +336,21 @@ class Discriminator_large(nn.Module):
         out = self.act(out)
 
         out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-        out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
+
+        if self.attn_pool is not None:
+            (cond_pooled, cond, cond_mask) = cond
+            if len(cond_mask.shape) == 2:
+                cond_mask = cond_mask.view(cond_mask.shape[0], cond_mask.shape[1], 1)
+            cond = torch.where(
+                cond_mask,
+                cond,
+                self.null_text_embed[:, :cond.shape[1]]
+            )
+            cond = self.attn_pool(cond)
+            cond = cond.mean(dim=1)
+            cond = self.cond_proj(cond)
+
+        out = self.end_linear(out) + (cond * out).sum(dim=1, keepdim=True)
         return out
 
 
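A note on what the forward-pass change does: with attn_pool enabled, cond arrives as a (cond_pooled, cond, cond_mask) triple; padded token positions are swapped for the learned null_text_embed via torch.where, the token sequence is compressed by the Perceiver resampler, mean-pooled, projected to the discriminator's channel width by cond_proj, and folded into the logit through (cond * out).sum(dim=1, keepdim=True), a projection-discriminator-style term. The standalone sketch below only mirrors that logic shape for shape; DummyResampler, the batch size, and the token length are illustrative assumptions and are not part of the commit.

    # Standalone sketch of the new conditioning path (shapes only).
    # DummyResampler is a hypothetical stand-in for layers.PerceiverResampler.
    import torch
    import torch.nn as nn

    class DummyResampler(nn.Module):
        """Maps (B, L, D) token features to (B, num_latents, D); placeholder for the real resampler."""
        def __init__(self, dim, num_latents=64):
            super().__init__()
            self.latents = nn.Parameter(torch.randn(num_latents, dim))
            self.proj = nn.Linear(dim, dim)

        def forward(self, tokens):
            # Crude stand-in: one dot-product attention step from latents to tokens.
            attn = torch.softmax(self.latents @ self.proj(tokens).transpose(1, 2), dim=-1)  # (B, num_latents, L)
            return attn @ tokens                                                            # (B, num_latents, D)

    B, L, cond_size, ngf = 2, 77, 768, 32
    tokens = torch.randn(B, L, cond_size)                 # per-token text embeddings
    cond_mask = torch.zeros(B, L, dtype=torch.bool)
    cond_mask[:, :10] = True                              # first 10 tokens real, rest padding
    null_text_embed = torch.randn(1, 512, cond_size)      # learned null embedding in the module

    # Same steps as the committed forward pass:
    mask = cond_mask.view(B, L, 1)
    cond = torch.where(mask, tokens, null_text_embed[:, :L])  # padded tokens replaced by the null embedding
    cond = DummyResampler(cond_size)(cond)                    # (B, num_latents, cond_size)
    cond = cond.mean(dim=1)                                   # (B, cond_size)
    cond = nn.Linear(cond_size, ngf * 8)(cond)                # cond_proj -> (B, ngf*8)

    out = torch.randn(B, ngf * 8)                             # spatially summed discriminator features
    end_linear = nn.Linear(ngf * 8, 1)
    score = end_linear(out) + (cond * out).sum(dim=1, keepdim=True)  # projection-style logit
    print(score.shape)  # torch.Size([2, 1])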
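On the constructor side the new path is opt-in: attn_pool=True builds the resampler and the null embedding, and leaving attn_pool_kw=None falls back to the defaults hard-coded in this commit. A hedged usage sketch follows; the import path and the dummy tensors standing in for a real text encoder are assumptions for illustration, and the discriminator's forward call is deliberately left out since its full signature is not shown in this diff.

    # Hedged usage sketch (not from the commit): how the new constructor flags and the
    # expected cond format fit together.
    import torch
    from score_sde.models.discriminator import Discriminator_large  # assumed import path

    disc = Discriminator_large(
        nc=1,            # image channels, per the signature's default
        ngf=32,
        t_emb_dim=128,
        cond_size=768,   # must match the width of the text token embeddings
        attn_pool=True,  # builds layers.PerceiverResampler and the learned null_text_embed
        # attn_pool_kw=None -> uses the defaults hard-coded in this commit
    )

    B, L = 4, 77
    cond_pooled = torch.randn(B, 768)               # pooled text embedding
    cond_tokens = torch.randn(B, L, 768)            # per-token text embeddings
    cond_mask = torch.ones(B, L, dtype=torch.bool)  # True for real tokens, False for padding

    # With attn_pool=True the forward pass unpacks cond as this triple:
    cond = (cond_pooled, cond_tokens, cond_mask)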