Update unidiffuser/sample_v1_test.py
Browse files
unidiffuser/sample_v1_test.py
CHANGED
@@ -59,7 +59,7 @@ class TestAutoencoderKL(nn.Module):
|
|
59 |
def sample(self, moments, noise=None, generator=None, device="cuda"):
|
60 |
mean, logvar = torch.chunk(moments, 2, dim=1)
|
61 |
if noise is None:
|
62 |
-
noise = randn_tensor(mean.shape, generator=generator
|
63 |
noise = noise.to(device)
|
64 |
logvar = torch.clamp(logvar, -30.0, 20.0)
|
65 |
std = torch.exp(0.5 * logvar)
|
@@ -176,8 +176,8 @@ def prepare_latents(
|
|
176 |
):
|
177 |
resolution = config.z_shape[-1] * vae_scale_factor
|
178 |
# Fix device to CPU for reproducibility.
|
179 |
-
|
180 |
-
latent_device = device
|
181 |
latent_torch_device = torch.device(latent_device)
|
182 |
generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
|
183 |
|
|
|
59 |
def sample(self, moments, noise=None, generator=None, device="cuda"):
|
60 |
mean, logvar = torch.chunk(moments, 2, dim=1)
|
61 |
if noise is None:
|
62 |
+
noise = randn_tensor(mean.shape, generator=generator)
|
63 |
noise = noise.to(device)
|
64 |
logvar = torch.clamp(logvar, -30.0, 20.0)
|
65 |
std = torch.exp(0.5 * logvar)
|
|
|
176 |
):
|
177 |
resolution = config.z_shape[-1] * vae_scale_factor
|
178 |
# Fix device to CPU for reproducibility.
|
179 |
+
latent_device = "cpu"
|
180 |
+
# latent_device = device
|
181 |
latent_torch_device = torch.device(latent_device)
|
182 |
generator = torch.Generator(device=latent_torch_device).manual_seed(config.seed)
|
183 |
|