Spaces:
Runtime error
Runtime error
Update utils.py
Browse files
utils.py
CHANGED
@@ -230,7 +230,7 @@ def generate_image(prompt, method, seed, loss_apply=False):
|
|
230 |
latents = torch.randn(
|
231 |
(batch_size, pipe.unet.config.in_channels, height// 8, width //8),
|
232 |
generator = generator,
|
233 |
-
).to(torch.
|
234 |
|
235 |
|
236 |
latents = latents.to(torch_device)
|
@@ -243,7 +243,7 @@ def generate_image(prompt, method, seed, loss_apply=False):
|
|
243 |
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
244 |
|
245 |
with torch.no_grad():
|
246 |
-
noise_pred = pipe.unet(latent_model_input.to(torch.
|
247 |
|
248 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
249 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
|
|
230 |
latents = torch.randn(
|
231 |
(batch_size, pipe.unet.config.in_channels, height// 8, width //8),
|
232 |
generator = generator,
|
233 |
+
).to(torch.float32)
|
234 |
|
235 |
|
236 |
latents = latents.to(torch_device)
|
|
|
243 |
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
|
244 |
|
245 |
with torch.no_grad():
|
246 |
+
noise_pred = pipe.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]
|
247 |
|
248 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
249 |
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|