nkanungo commited on
Commit
fc189d1
1 Parent(s): 2c362e5

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +2 -2
utils.py CHANGED
@@ -230,7 +230,7 @@ def generate_image(prompt, method, seed, loss_apply=False):
230
  latents = torch.randn(
231
  (batch_size, pipe.unet.config.in_channels, height// 8, width //8),
232
  generator = generator,
233
- ).to(torch.float16)
234
 
235
 
236
  latents = latents.to(torch_device)
@@ -243,7 +243,7 @@ def generate_image(prompt, method, seed, loss_apply=False):
243
  latent_model_input = scheduler.scale_model_input(latent_model_input, t)
244
 
245
  with torch.no_grad():
246
- noise_pred = pipe.unet(latent_model_input.to(torch.float16), t, encoder_hidden_states=text_embeddings)["sample"]
247
 
248
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
249
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
230
  latents = torch.randn(
231
  (batch_size, pipe.unet.config.in_channels, height// 8, width //8),
232
  generator = generator,
233
+ ).to(torch.float32)
234
 
235
 
236
  latents = latents.to(torch_device)
 
243
  latent_model_input = scheduler.scale_model_input(latent_model_input, t)
244
 
245
  with torch.no_grad():
246
+ noise_pred = pipe.unet(latent_model_input.to(torch.float32), t, encoder_hidden_states=text_embeddings)["sample"]
247
 
248
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
249
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)