orpatashnik committed
Commit 710e5f8
1 Parent(s): f65b8d3

fix inversion

Files changed (3)
  1. gradio_app.py +1 -2
  2. main.py +6 -4
  3. src/diffusion_model_wrapper.py +5 -5
gradio_app.py CHANGED
@@ -18,7 +18,6 @@ This demo supports both generated images and real images. To modify a real image
 '''
 
 stable, stable_config = setup(LPMConfig())
-stable_for_inversion, _ = setup(LPMConfig())
 
 def main_pipeline(
         prompt: str,
@@ -48,7 +47,7 @@ def main_pipeline(
         real_image_path="" if input_image is None else input_image
     )
 
-    result_images, result_proxy_words = main(stable, stable_config, stable_for_inversion, args)
+    result_images, result_proxy_words = main(stable, stable_config, args)
     result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
     result_images = [(im * 255).astype(np.uint8) for im in result_images]
     result_images = [Image.fromarray(im) for im in result_images]
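Note: the demo now loads a single pipeline at startup and passes it straight through to main(). A minimal sketch of the resulting call flow (the prompt field on LPMConfig is an assumed name for illustration; only real_image_path appears in this diff):

stable, stable_config = setup(LPMConfig())  # one pipeline, loaded once at app startup

def run_demo(prompt, input_image=None):
    # Build the per-run config; real_image_path selects real-image editing.
    args = LPMConfig(
        prompt=prompt,  # assumed field name, not shown in this diff
        real_image_path="" if input_image is None else input_image,
    )
    # main() now takes only the shared (model, config) pair plus args.
    result_images, result_proxy_words = main(stable, stable_config, args)
    return result_images, result_proxy_words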
main.py CHANGED
@@ -1,13 +1,14 @@
 import json
 import os
+from dataclasses import dataclass, field
+from typing import List
+
 import pyrallis
 import torch
-from dataclasses import dataclass, field
 from torch.utils.data import DataLoader
 from torchvision.transforms import ToTensor
 from torchvision.utils import save_image
 from tqdm import tqdm
-from typing import List
 
 from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
     generate_original_image
@@ -34,7 +35,7 @@ def setup(args):
     return ldm_stable, ldm_stable_config
 
 
-def main(ldm_stable, ldm_stable_config, ldm_stable_inversion, args):
+def main(ldm_stable, ldm_stable_config, args):
 
     similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
     exp_path = save_args_dict(args, similar_words)
@@ -44,7 +45,8 @@ def main(ldm_stable, ldm_stable_config, ldm_stable_inversion, args):
     uncond_embeddings = None
 
     if args.real_image_path != "":
-        x_t, uncond_embeddings = invert_image(args, ldm_stable_inversion, ldm_stable_config, prompts, exp_path)
+        ldm_stable, ldm_stable_config = setup(args)
+        x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)
 
     image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
     save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
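Callers of main() must drop the third argument accordingly: inversion no longer uses a separately loaded model, but a pipeline that main() sets up on demand when a real image is supplied. A minimal sketch of a driver for the new signature (a hypothetical script; the import path of LPMConfig is an assumption, pyrallis.parse is the standard pyrallis entry point):

import pyrallis

from main import LPMConfig, main, setup  # LPMConfig location is assumed

if __name__ == "__main__":
    args = pyrallis.parse(config_class=LPMConfig)
    ldm_stable, ldm_stable_config = setup(args)
    # Inversion setup now happens inside main() when args.real_image_path != "",
    # so no separate inversion model is passed here.
    result_images, result_proxy_words = main(ldm_stable, ldm_stable_config, args)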
src/diffusion_model_wrapper.py CHANGED
@@ -1,13 +1,13 @@
-import torch
-import numpy as np
 from typing import Optional, List
 
+import numpy as np
+import torch
+from cv2 import dilate
 from diffusers import DDIMScheduler, StableDiffusionPipeline
 from tqdm import tqdm
-from cv2 import dilate
 
-from src.attention_utils import show_cross_attention
 from src.attention_based_segmentation import Segmentor
+from src.attention_utils import show_cross_attention
 from src.prompt_to_prompt_controllers import DummyController, AttentionStore
 
 
@@ -136,7 +136,7 @@ class DiffusionModelWrapper:
         if self.enbale_attn_controller_changes:
             attn = self.controller(attn, is_cross, place_in_unet)
 
-        if is_cross and context[1] is not None and self.prompt_mixing is not None:
+        if is_cross and self.prompt_mixing is not None and context[1] is not None:
             attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)
 
         if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
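The reordered guard is more than cosmetic: Python's `and` short-circuits left to right, so checking self.prompt_mixing first keeps context[1] from being evaluated at all when prompt mixing is disabled. A toy illustration of the ordering (the shape of context when mixing is off is an assumption):

prompt_mixing = None
context = [None]  # assumed: no second entry when prompt mixing is disabled
is_cross = True

# Old order would evaluate context[1] before the prompt_mixing check:
#   is_cross and context[1] is not None and prompt_mixing is not None
#   -> IndexError: list index out of range
# New order bails out at the prompt_mixing check, so context[1] never runs:
safe = is_cross and prompt_mixing is not None and context[1] is not None
print(safe)  # False, with no IndexError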