LanguageBind commited on
Commit
65ca53b
1 Parent(s): e631027

Update opensora/models/ae/videobase/causal_vae/modeling_causalvae.py

Browse files
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py CHANGED
@@ -316,7 +316,8 @@ class CausalVAEModel(VideoBaseAE_PL):
316
  self.tile_sample_min_size = 256
317
  self.tile_sample_min_size_t = 65
318
  self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
319
- # self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** self.time_compress)) + 1
 
320
  self.tile_overlap_factor = 0.25
321
  self.use_tiling = False
322
 
@@ -374,8 +375,9 @@ class CausalVAEModel(VideoBaseAE_PL):
374
  if self.use_tiling and (
375
  x.shape[-1] > self.tile_sample_min_size
376
  or x.shape[-2] > self.tile_sample_min_size
 
377
  ):
378
- return self.tiled_encode2d(x)
379
  h = self.encoder(x)
380
  moments = self.quant_conv(h)
381
  posterior = DiagonalGaussianDistribution(moments)
@@ -385,8 +387,9 @@ class CausalVAEModel(VideoBaseAE_PL):
385
  if self.use_tiling and (
386
  z.shape[-1] > self.tile_latent_min_size
387
  or z.shape[-2] > self.tile_latent_min_size
 
388
  ):
389
- return self.tiled_decode2d(z)
390
  z = self.post_quant_conv(z)
391
  dec = self.decoder(z)
392
  return dec
@@ -554,7 +557,54 @@ class CausalVAEModel(VideoBaseAE_PL):
554
  ) + b[:, :, :, :, x] * (x / blend_extent)
555
  return b
556
 
557
- def tiled_encode2d(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
559
  blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
560
  row_limit = self.tile_latent_min_size - blend_extent
@@ -590,7 +640,8 @@ class CausalVAEModel(VideoBaseAE_PL):
590
 
591
  moments = torch.cat(result_rows, dim=3)
592
  posterior = DiagonalGaussianDistribution(moments)
593
-
 
594
  return posterior
595
 
596
  def tiled_decode2d(self, z):
 
316
  self.tile_sample_min_size = 256
317
  self.tile_sample_min_size_t = 65
318
  self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(hidden_size_mult) - 1)))
319
+ t_down_ratio = [i for i in encoder_temporal_downsample if len(i) > 0]
320
+ self.tile_latent_min_size_t = int((self.tile_sample_min_size_t-1) / (2 ** len(t_down_ratio))) + 1
321
  self.tile_overlap_factor = 0.25
322
  self.use_tiling = False
323
 
 
375
  if self.use_tiling and (
376
  x.shape[-1] > self.tile_sample_min_size
377
  or x.shape[-2] > self.tile_sample_min_size
378
+ or x.shape[-3] > self.tile_sample_min_size_t
379
  ):
380
+ return self.tiled_encode(x)
381
  h = self.encoder(x)
382
  moments = self.quant_conv(h)
383
  posterior = DiagonalGaussianDistribution(moments)
 
387
  if self.use_tiling and (
388
  z.shape[-1] > self.tile_latent_min_size
389
  or z.shape[-2] > self.tile_latent_min_size
390
+ or z.shape[-3] > self.tile_latent_min_size_t
391
  ):
392
+ return self.tiled_decode(z)
393
  z = self.post_quant_conv(z)
394
  dec = self.decoder(z)
395
  return dec
 
557
  ) + b[:, :, :, :, x] * (x / blend_extent)
558
  return b
559
 
560
+ def tiled_encode(self, x):
561
+ t = x.shape[2]
562
+ t_chunk_idx = [i for i in range(0, t, self.tile_sample_min_size_t-1)]
563
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
564
+ t_chunk_start_end = [[0, t]]
565
+ else:
566
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
567
+ if t_chunk_start_end[-1][-1] > t:
568
+ t_chunk_start_end[-1][-1] = t
569
+ elif t_chunk_start_end[-1][-1] < t:
570
+ last_start_end = [t_chunk_idx[-1], t]
571
+ t_chunk_start_end.append(last_start_end)
572
+ moments = []
573
+ for idx, (start, end) in enumerate(t_chunk_start_end):
574
+ chunk_x = x[:, :, start: end]
575
+ if idx != 0:
576
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)[:, :, 1:]
577
+ else:
578
+ moment = self.tiled_encode2d(chunk_x, return_moments=True)
579
+ moments.append(moment)
580
+ moments = torch.cat(moments, dim=2)
581
+ posterior = DiagonalGaussianDistribution(moments)
582
+ return posterior
583
+
584
+ def tiled_decode(self, x):
585
+ t = x.shape[2]
586
+ t_chunk_idx = [i for i in range(0, t, self.tile_latent_min_size_t-1)]
587
+ if len(t_chunk_idx) == 1 and t_chunk_idx[0] == 0:
588
+ t_chunk_start_end = [[0, t]]
589
+ else:
590
+ t_chunk_start_end = [[t_chunk_idx[i], t_chunk_idx[i+1]+1] for i in range(len(t_chunk_idx)-1)]
591
+ if t_chunk_start_end[-1][-1] > t:
592
+ t_chunk_start_end[-1][-1] = t
593
+ elif t_chunk_start_end[-1][-1] < t:
594
+ last_start_end = [t_chunk_idx[-1], t]
595
+ t_chunk_start_end.append(last_start_end)
596
+ dec_ = []
597
+ for idx, (start, end) in enumerate(t_chunk_start_end):
598
+ chunk_x = x[:, :, start: end]
599
+ if idx != 0:
600
+ dec = self.tiled_decode2d(chunk_x)[:, :, 1:]
601
+ else:
602
+ dec = self.tiled_decode2d(chunk_x)
603
+ dec_.append(dec)
604
+ dec_ = torch.cat(dec_, dim=2)
605
+ return dec_
606
+
607
+ def tiled_encode2d(self, x, return_moments=False):
608
  overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
609
  blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
610
  row_limit = self.tile_latent_min_size - blend_extent
 
640
 
641
  moments = torch.cat(result_rows, dim=3)
642
  posterior = DiagonalGaussianDistribution(moments)
643
+ if return_moments:
644
+ return moments
645
  return posterior
646
 
647
  def tiled_decode2d(self, z):