orpatashnik commited on
Commit
b5f570d
1 Parent(s): e9baf2d

optimize setup

Browse files
Files changed (2) hide show
  1. gradio_app.py +5 -4
  2. main.py +14 -11
gradio_app.py CHANGED
@@ -1,15 +1,14 @@
1
  from __future__ import annotations
2
 
3
-
4
  import gradio as gr
 
5
  import numpy as np
6
  from PIL import Image
7
 
8
- import nltk
9
  nltk.download('punkt')
10
  nltk.download('averaged_perceptron_tagger')
11
 
12
- from main import LPMConfig, main
13
 
14
  DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models
15
  This is a demo for our ''Localizing Object-level Shape Variations with Text-to-Image Diffusion Models'' [paper](https://arxiv.org/abs/2303.11306).
@@ -17,6 +16,8 @@ We introduce a method that generates object-level shape variation for a given im
17
  This demo allows using a real image as well as a generated image. For a real image, a matching prompt is required.
18
  '''
19
 
 
 
20
  def main_pipeline(
21
  prompt: str,
22
  object_of_interest: str,
@@ -47,7 +48,7 @@ def main_pipeline(
47
  real_image_path="" if input_image is None else input_image
48
  )
49
 
50
- result_images, result_proxy_words = main(args)
51
  result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
52
  result_images = [(im * 255).astype(np.uint8) for im in result_images]
53
  result_images = [Image.fromarray(im) for im in result_images]
 
1
  from __future__ import annotations
2
 
 
3
  import gradio as gr
4
+ import nltk
5
  import numpy as np
6
  from PIL import Image
7
 
 
8
  nltk.download('punkt')
9
  nltk.download('averaged_perceptron_tagger')
10
 
11
+ from main import LPMConfig, main, setup
12
 
13
  DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models
14
  This is a demo for our ''Localizing Object-level Shape Variations with Text-to-Image Diffusion Models'' [paper](https://arxiv.org/abs/2303.11306).
 
16
  This demo allows using a real image as well as a generated image. For a real image, a matching prompt is required.
17
  '''
18
 
19
+ stable, stable_config = setup(LPMConfig())
20
+
21
  def main_pipeline(
22
  prompt: str,
23
  object_of_interest: str,
 
48
  real_image_path="" if input_image is None else input_image
49
  )
50
 
51
+ result_images, result_proxy_words = main(stable, stable_config, args)
52
  result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
53
  result_images = [(im * 255).astype(np.uint8) for im in result_images]
54
  result_images = [Image.fromarray(im) for im in result_images]
main.py CHANGED
@@ -1,21 +1,20 @@
1
  import json
2
  import os
3
- from dataclasses import dataclass, field
4
- from typing import List
5
-
6
  import pyrallis
7
  import torch
 
8
  from torch.utils.data import DataLoader
9
- from torchvision.utils import save_image
10
  from torchvision.transforms import ToTensor
 
11
  from tqdm import tqdm
 
12
 
13
- from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
14
- from src.null_text_inversion import invert_image
15
- from src.prompt_utils import get_proxy_prompts
16
- from src.prompt_mixing import PromptMixing
17
  from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
18
  generate_original_image
 
 
 
 
19
 
20
 
21
  def save_args_dict(args, similar_words):
@@ -29,10 +28,13 @@ def save_args_dict(args, similar_words):
29
 
30
  return exp_path
31
 
32
-
33
- def main(args):
34
  ldm_stable = get_stable_diffusion_model(args)
35
  ldm_stable_config = get_stable_diffusion_config(args)
 
 
 
 
36
 
37
  similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
38
  exp_path = save_args_dict(args, similar_words)
@@ -147,4 +149,5 @@ if __name__ == '__main__':
147
  args = pyrallis.parse(config_class=LPMConfig)
148
 
149
  print(args)
150
- main(args)
 
 
1
  import json
2
  import os
 
 
 
3
  import pyrallis
4
  import torch
5
+ from dataclasses import dataclass, field
6
  from torch.utils.data import DataLoader
 
7
  from torchvision.transforms import ToTensor
8
+ from torchvision.utils import save_image
9
  from tqdm import tqdm
10
+ from typing import List
11
 
 
 
 
 
12
  from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
13
  generate_original_image
14
+ from src.null_text_inversion import invert_image
15
+ from src.prompt_mixing import PromptMixing
16
+ from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
17
+ from src.prompt_utils import get_proxy_prompts
18
 
19
 
20
  def save_args_dict(args, similar_words):
 
28
 
29
  return exp_path
30
 
31
+ def setup(args):
 
32
  ldm_stable = get_stable_diffusion_model(args)
33
  ldm_stable_config = get_stable_diffusion_config(args)
34
+ return ldm_stable, ldm_stable_config
35
+
36
+
37
+ def main(ldm_stable, ldm_stable_config, args):
38
 
39
  similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
40
  exp_path = save_args_dict(args, similar_words)
 
149
  args = pyrallis.parse(config_class=LPMConfig)
150
 
151
  print(args)
152
+ stable, stable_config = setup(args)
153
+ main(stable, stable_config, args)