dg845 commited on
Commit
70417b7
1 Parent(s): e51f900

Update unidiffuser/sample_v1_test.py

Browse files
Files changed (1) hide show
  1. unidiffuser/sample_v1_test.py +3 -3
unidiffuser/sample_v1_test.py CHANGED
@@ -59,7 +59,7 @@ class TestAutoencoderKL(nn.Module):
59
  def sample(self, moments, noise=None, generator=None, device="cuda"):
60
  mean, logvar = torch.chunk(moments, 2, dim=1)
61
  if noise is None:
62
- noise = randn_tensor(mean.shape, generator=generator, device=generator.device)
63
  noise = noise.to(device)
64
  logvar = torch.clamp(logvar, -30.0, 20.0)
65
  std = torch.exp(0.5 * logvar)
@@ -176,8 +176,8 @@ def prepare_latents(
176
  ):
177
  resolution = config.z_shape[-1] * vae_scale_factor
178
  # Fix device to CPU for reproducibility.
179
- # latent_device = "cpu"
180
- latent_device = device
181
  latent_torch_device = torch.device(latent_device)
182
  generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
183
 
 
59
  def sample(self, moments, noise=None, generator=None, device="cuda"):
60
  mean, logvar = torch.chunk(moments, 2, dim=1)
61
  if noise is None:
62
+ noise = randn_tensor(mean.shape, generator=generator)
63
  noise = noise.to(device)
64
  logvar = torch.clamp(logvar, -30.0, 20.0)
65
  std = torch.exp(0.5 * logvar)
 
176
  ):
177
  resolution = config.z_shape[-1] * vae_scale_factor
178
  # Fix device to CPU for reproducibility.
179
+ latent_device = "cpu"
180
+ # latent_device = device
181
  latent_torch_device = torch.device(latent_device)
182
  generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
183