fffiloni commited on
Commit
001fe3b
1 Parent(s): 17e300c

Convert tile to torch.cuda.FloatTensor if it's not already of that type

Browse files
opensora/models/ae/videobase/causal_vae/modeling_causalvae.py CHANGED
@@ -610,6 +610,11 @@ class CausalVAEModel(VideoBaseAE_PL):
610
  i : i + self.tile_latent_min_size,
611
  j : j + self.tile_latent_min_size,
612
  ]
 
 
 
 
 
613
  tile = self.post_quant_conv(tile)
614
  decoded = self.decoder(tile)
615
  row.append(decoded)
 
610
  i : i + self.tile_latent_min_size,
611
  j : j + self.tile_latent_min_size,
612
  ]
613
+
614
+ # Convert tile to torch.cuda.FloatTensor if it's not already of that type
615
+ if tile.dtype != torch.float32:
616
+ tile = tile.float()
617
+
618
  tile = self.post_quant_conv(tile)
619
  decoded = self.decoder(tile)
620
  row.append(decoded)