Mehdi Cherti committed
Commit • b991bac
1 Parent(s): be61cf2
update discr
score_sde/models/discriminator.py
CHANGED
@@ -252,12 +252,31 @@ class SmallCondAttnDiscriminator(nn.Module):
 class Discriminator_large(nn.Module):
     """A time-dependent discriminator for large images (CelebA, LSUN)."""
 
-    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
+    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768, attn_pool=False, attn_pool_kw=None):
         super().__init__()
         # Gaussian random feature embedding layer for time
         self.cond_proj = nn.Linear(cond_size, ngf*8)
         self.act = act
-
+        if attn_pool:
+            if attn_pool_kw is None:
+                attn_pool_kw = dict(
+                    depth=1,
+                    dim_head = 64,
+                    heads = 8,
+                    num_latents = 64,
+                    num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
+                    max_seq_len = 512,
+                    ff_mult = 4,
+                    cosine_sim_attn = False,
+                )
+            self.attn_pool = layers.PerceiverResampler(
+                dim=cond_size,
+                **attn_pool_kw,
+            )
+            max_text_len = 512
+            self.null_text_embed = torch.nn.Parameter(torch.randn(1, max_text_len, cond_size))
+        else:
+            self.attn_pool = None
         self.t_embed = TimestepEmbedding(
             embedding_dim=t_emb_dim,
             hidden_dim=t_emb_dim,
@@ -317,7 +336,21 @@ class Discriminator_large(nn.Module):
         out = self.act(out)
 
         out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-
+
+        if self.attn_pool is not None:
+            (cond_pooled, cond, cond_mask) = cond
+            if len(cond_mask.shape) == 2:
+                cond_mask = cond_mask.view(cond_mask.shape[0], cond_mask.shape[1], 1)
+            cond = torch.where(
+                cond_mask,
+                cond,
+                self.null_text_embed[:, :cond.shape[1]]
+            )
+            cond = self.attn_pool(cond)
+            cond = cond.mean(dim=1)
+        cond = self.cond_proj(cond)
+
+        out = self.end_linear(out) + (cond * out).sum(dim=1, keepdim=True)
         return out
 
 
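For reference, a minimal usage sketch of the new attention-pooling path. Everything outside the diff is an assumption: the forward signature (x_t, t, x_tp1, cond=...), the image resolution, the number of diffusion steps, and the shapes below are illustrative, not taken from the commit.

import torch
from score_sde.models.discriminator import Discriminator_large

# Hypothetical shapes: batch of 4, 77 text tokens, 64x64 RGB images.
D = Discriminator_large(nc=3, ngf=32, t_emb_dim=128, cond_size=768, attn_pool=True)

B, L = 4, 77
cond_pooled = torch.randn(B, 768)               # pooled text embedding (unpacked but unused by this path)
cond_seq = torch.randn(B, L, 768)               # per-token text embeddings
cond_mask = torch.ones(B, L, dtype=torch.bool)  # True where a real token is present

x_t = torch.randn(B, 3, 64, 64)                 # noisy sample at step t
x_tp1 = torch.randn(B, 3, 64, 64)               # noisier sample at step t+1
t = torch.randint(0, 4, (B,))                   # assumed small number of DDGAN steps

# With attn_pool=True, cond must be the (cond_pooled, cond, cond_mask) tuple the
# forward pass unpacks: masked positions are swapped for the learned
# null_text_embed, the PerceiverResampler pools the sequence down to num_latents
# vectors, and their mean is projected by cond_proj for the projection term.
score = D(x_t, t, x_tp1, cond=(cond_pooled, cond_seq, cond_mask))

Replacing padded positions with a learned null embedding, rather than zeros, mirrors the Imagen-style text conditioning that layers.PerceiverResampler appears to be borrowed from, and keeps the discriminator's conditioning well-defined when captions are dropped or shorter than max_text_len.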