eitanrich committed on
Commit
028b6a1
·
1 Parent(s): 645fba0

VAE: Support returning intermediate features for 3D perceptual loss

Browse files
xora/models/autoencoders/video_autoencoder.py CHANGED
@@ -310,7 +310,9 @@ class Encoder(nn.Module):
310
  * self.patch_size
311
  )
312
 
313
- def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
 
 
314
  r"""The forward method of the `Encoder` class."""
315
 
316
  downsample_in_time = sample.shape[2] != 1
@@ -332,10 +334,14 @@ class Encoder(nn.Module):
332
  else lambda x: x
333
  )
334
 
 
 
335
  for down_block in self.down_blocks:
336
  sample = checkpoint_fn(down_block)(
337
  sample, downsample_in_time=downsample_in_time
338
  )
 
 
339
 
340
  sample = checkpoint_fn(self.mid_block)(sample)
341
 
@@ -363,6 +369,11 @@ class Encoder(nn.Module):
363
  else:
364
  raise ValueError(f"Invalid input shape: {sample.shape}")
365
 
 
 
 
 
 
366
  return sample
367
 
368
 
 
310
  * self.patch_size
311
  )
312
 
313
+ def forward(
314
+ self, sample: torch.FloatTensor, return_features=False
315
+ ) -> torch.FloatTensor:
316
  r"""The forward method of the `Encoder` class."""
317
 
318
  downsample_in_time = sample.shape[2] != 1
 
334
  else lambda x: x
335
  )
336
 
337
+ if return_features:
338
+ features = []
339
  for down_block in self.down_blocks:
340
  sample = checkpoint_fn(down_block)(
341
  sample, downsample_in_time=downsample_in_time
342
  )
343
+ if return_features:
344
+ features.append(sample)
345
 
346
  sample = checkpoint_fn(self.mid_block)(sample)
347
 
 
369
  else:
370
  raise ValueError(f"Invalid input shape: {sample.shape}")
371
 
372
+ if return_features:
373
+ features.append(
374
+ sample[:, sample.shape[1] // 2, ...]
375
+ ) # Add the latent means as final feature
376
+ return sample, features
377
  return sample
378
 
379