Linoy Tsaban commited on
Commit
8b5d4bf
·
1 Parent(s): d19d91b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -28,7 +28,7 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
28
  def caption_image(input_image):
29
  inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
30
  pixel_values = inputs.pixel_values
31
-
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
33
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
  return generated_caption
@@ -38,9 +38,9 @@ def caption_image(input_image):
38
  ## DDPM INVERSION AND SAMPLING ##
39
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
40
 
41
- # inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
42
  # based on the code in https://github.com/inbarhub/DDPM_inversion
43
-
44
  # returns wt, zs, wts:
45
  # wt - inverted latent
46
  # wts - intermediate inverted latents
@@ -50,7 +50,7 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
50
 
51
  # vae encode image
52
  with inference_mode():
53
- w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
54
 
55
  # find Zs and wts - forward process
56
  wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
@@ -61,10 +61,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
61
 
62
  # reverse process (via Zs and wT)
63
  w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
64
-
65
  # vae decode image
66
  with inference_mode():
67
- x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
68
  if x0_dec.dim()<4:
69
  x0_dec = x0_dec[None,:,:,:]
70
  img = image_grid(x0_dec)
@@ -142,7 +142,7 @@ def edit(input_image,
142
  src_cfg_scale):
143
 
144
  if do_inversion or randomize_seed:
145
- x0 = load_512(input_image, device=device)
146
  # invert and retrieve noise maps and latent
147
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
148
  wts = gr.State(value=wts_tensor)
 
28
  def caption_image(input_image):
29
  inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
30
  pixel_values = inputs.pixel_values
31
+
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
33
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
  return generated_caption
 
38
  ## DDPM INVERSION AND SAMPLING ##
39
  def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
40
 
41
+ # inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
42
  # based on the code in https://github.com/inbarhub/DDPM_inversion
43
+
44
  # returns wt, zs, wts:
45
  # wt - inverted latent
46
  # wts - intermediate inverted latents
 
50
 
51
  # vae encode image
52
  with inference_mode():
53
+ w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215)
54
 
55
  # find Zs and wts - forward process
56
  wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
 
61
 
62
  # reverse process (via Zs and wT)
63
  w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
64
+
65
  # vae decode image
66
  with inference_mode():
67
+ x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
68
  if x0_dec.dim()<4:
69
  x0_dec = x0_dec[None,:,:,:]
70
  img = image_grid(x0_dec)
 
142
  src_cfg_scale):
143
 
144
  if do_inversion or randomize_seed:
145
+ x0 = load_512(input_image, device=device).to(torch.float16)
146
  # invert and retrieve noise maps and latent
147
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
148
  wts = gr.State(value=wts_tensor)