zhiweili commited on
Commit
61f3bdc
·
1 Parent(s): 336094b

fix encode error

Browse files
Files changed (1) hide show
  1. app_ddim.py +3 -1
app_ddim.py CHANGED
@@ -51,7 +51,9 @@ def image_to_image(
51
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
52
 
53
  with torch.no_grad():
54
- latent = basepipeline.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(DEVICE) * 2 - 1)
 
 
55
  l = 0.18215 * latent.latent_dist.sample()
56
  inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
57
  generated_image = sample(
 
51
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
52
 
53
  with torch.no_grad():
54
+ input_image_tensor = tfms.functional.to_tensor(input_image).unsqueeze(0).to(DEVICE)
55
+ input_image_tensor = input_image_tensor.to(dtype=torch.float16)
56
+ latent = basepipeline.vae.encode(input_image_tensor * 2 - 1)
57
  l = 0.18215 * latent.latent_dist.sample()
58
  inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
59
  generated_image = sample(