eggarsway commited on
Commit
05c9228
1 Parent(s): 3a84324
Files changed (1) hide show
  1. app.py +75 -53
app.py CHANGED
@@ -5,47 +5,54 @@ import random
5
  import gradio as gr
6
  import itertools
7
  from PIL import Image, ImageFont, ImageDraw
8
- import sys
9
-
10
- sys.path.append("source")
11
-
12
  import DirectedDiffusion
13
 
14
- EX1 = [
15
- "A painting of a tiger, on the wall in the living room",
16
- "0.2,0.6,0.0,0.5",
17
- "1,5",
18
- 5,
19
- 15,
20
- 1.0,
21
- 2094889,
22
- ]
23
 
24
 
25
- def fake_gan(a, b, c):
26
- print(a, b, c)
27
- images = [
28
- (
29
- random.choice(
30
- [
31
- "https://images.unsplash.com/photo-1507003211169-0a1dd7228f2d?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
32
- "https://images.unsplash.com/photo-1554151228-14d9def656e4?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=386&q=80",
33
- "https://images.unsplash.com/photo-1542909168-82c3e7fdca5c?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxzZWFyY2h8MXx8aHVtYW4lMjBmYWNlfGVufDB8fDB8fA%3D%3D&w=1000&q=80",
34
- "https://images.unsplash.com/photo-1546456073-92b9f0a8d413?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=387&q=80",
35
- "https://images.unsplash.com/photo-1601412436009-d964bd02edbc?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=464&q=80",
36
- ]
37
- ),
38
- f"label {i}" if i != 0 else "label" * 50,
39
- )
40
- for i in range(3)
41
- ]
42
- return images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
 
45
  model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
46
- model_path_diffusion="CompVis/stable-diffusion-v1-4"
47
  )
48
 
 
 
49
 
50
  def directed_diffusion(
51
  in_prompt,
@@ -78,14 +85,12 @@ def directed_diffusion(
78
  is_save_attn=False,
79
  is_save_recons=False,
80
  )
81
- print(img.size)
82
  if is_draw_bbox and in_slider_ddsteps > 0:
83
  for r in roi:
84
  x0, y0, x1, y1 = [int(r_ * 512) for r_ in r]
85
- print(x0, y0, x1, y1)
86
  image_editable = ImageDraw.Draw(img)
87
  image_editable.rectangle(
88
- xy=[x0, y0, x1, y1], outline=(255, 0, 0, 255), width=5
89
  )
90
 
91
  return img
@@ -103,14 +108,14 @@ def run_it(
103
  is_grid_search,
104
  progress=gr.Progress(),
105
  ):
106
-
107
  num_affected_steps = [in_slider_ddsteps]
108
  noise_scale = [in_slider_gcoef]
109
  num_trailing_attn = [in_slider_trailings]
110
  if is_grid_search:
111
  num_affected_steps = [5, 10]
112
- #noise_scale = [1.0, 1.5, 2.5]
113
- #num_trailing_attn = [10, 20, 30, 40]
114
 
115
  param_list = [num_affected_steps, noise_scale, num_trailing_attn]
116
  param_list = list(itertools.product(*param_list))
@@ -145,10 +150,23 @@ def run_it(
145
  ),
146
  )
147
  )
148
- return results
 
 
 
 
 
 
 
149
 
150
 
151
  with gr.Blocks() as demo:
 
 
 
 
 
 
152
  with gr.Row(variant="panel"):
153
  with gr.Column(variant="compact"):
154
  in_prompt = gr.Textbox(
@@ -167,7 +185,7 @@ with gr.Blocks() as demo:
167
  placeholder="e.g., 0.1,0.5,0.3,0.6",
168
  )
169
  in_token_ids = gr.Textbox(
170
- label="Token idices",
171
  show_label=True,
172
  max_lines=1,
173
  placeholder="e.g., 1,2,3",
@@ -178,7 +196,7 @@ with gr.Blocks() as demo:
178
  with gr.Row(variant="compact"):
179
  is_grid_search = gr.Checkbox(
180
  value=False,
181
- label="Grid search? (Checked then sliders are ignored)",
182
  )
183
  is_draw_bbox = gr.Checkbox(
184
  value=True,
@@ -186,15 +204,24 @@ with gr.Blocks() as demo:
186
  )
187
  with gr.Row(variant="compact"):
188
  in_slider_trailings = gr.Slider(
189
- minimum=1, maximum=30, value=10, step=1, label="#trailings"
190
  )
191
  in_slider_ddsteps = gr.Slider(
192
- minimum=0, maximum=20, value=10, step=1, label="#DDSteps"
193
  )
194
  in_slider_gcoef = gr.Slider(
195
- minimum=1, maximum=5, value=1.0, step=0.1, label="GaussianCoef"
196
  )
197
- btn = gr.Button("Generate image").style(full_width=False)
 
 
 
 
 
 
 
 
 
198
 
199
  gallery = gr.Gallery(
200
  label="Generated images", show_label=False, elem_id="gallery"
@@ -211,15 +238,10 @@ with gr.Blocks() as demo:
211
  is_draw_bbox,
212
  is_grid_search,
213
  ]
214
-
215
- btn.click(
216
- run_it,
217
- inputs=args,
218
- outputs=gallery,
219
- )
220
-
221
  examples = gr.Examples(
222
- examples=[EX1],
223
  inputs=args,
224
  )
225
 
 
5
  import gradio as gr
6
  import itertools
7
  from PIL import Image, ImageFont, ImageDraw
 
 
 
 
8
  import DirectedDiffusion
9
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
+ # prompt
13
+ # boundingbox
14
+ # prompt indices for region
15
+ # number of trailing attention
16
+ # number of DD steps
17
+ # gaussian coefficient
18
+ # seed
19
+ EXAMPLES = [
20
+ [
21
+ "A painting of a tiger, on the wall in the living room",
22
+ "0.2,0.6,0.0,0.5",
23
+ "1,5",
24
+ 5,
25
+ 15,
26
+ 1.0,
27
+ 2094889,
28
+ ],
29
+ [
30
+ "a dog diving into a pool in sunny day",
31
+ "0.0,0.5,0.0,0.5",
32
+ "1,2",
33
+ 10,
34
+ 20,
35
+ 5.0,
36
+ 2483964026826,
37
+ ],
38
+ [
39
+ "A red cube above a blue sphere",
40
+ "0.4,0.7,0.0,0.5 0.4,0.7,0.5,1.0",
41
+ "2,3 6,7",
42
+ 10,
43
+ 20,
44
+ 1.0,
45
+ 1213698,
46
+ ],
47
+ ]
48
 
49
 
50
  model_bundle = DirectedDiffusion.AttnEditorUtils.load_all_models(
51
+ model_path_diffusion="../DirectedDiffusion/assets/models/stable-diffusion-v1-4"
52
  )
53
 
54
+ ALL_OUTPUT = []
55
+
56
 
57
  def directed_diffusion(
58
  in_prompt,
 
85
  is_save_attn=False,
86
  is_save_recons=False,
87
  )
 
88
  if is_draw_bbox and in_slider_ddsteps > 0:
89
  for r in roi:
90
  x0, y0, x1, y1 = [int(r_ * 512) for r_ in r]
 
91
  image_editable = ImageDraw.Draw(img)
92
  image_editable.rectangle(
93
+ xy=[x0, x1, y0, y1], outline=(255, 0, 0, 255), width=5
94
  )
95
 
96
  return img
 
108
  is_grid_search,
109
  progress=gr.Progress(),
110
  ):
111
+ global ALL_OUTPUT
112
  num_affected_steps = [in_slider_ddsteps]
113
  noise_scale = [in_slider_gcoef]
114
  num_trailing_attn = [in_slider_trailings]
115
  if is_grid_search:
116
  num_affected_steps = [5, 10]
117
+ noise_scale = [1.0, 1.5, 2.5]
118
+ num_trailing_attn = [10, 20, 30, 40]
119
 
120
  param_list = [num_affected_steps, noise_scale, num_trailing_attn]
121
  param_list = list(itertools.product(*param_list))
 
150
  ),
151
  )
152
  )
153
+ ALL_OUTPUT += results
154
+ return ALL_OUTPUT
155
+
156
+ def clean_gallery():
157
+ global ALL_OUTPUT
158
+ ALL_OUTPUT = []
159
+ return ALL_OUTPUT
160
+
161
 
162
 
163
  with gr.Blocks() as demo:
164
+ gr.Markdown(
165
+ """
166
+ # Directed Diffusion
167
+ Let's pin the object in the prompt as you wish!
168
+ """
169
+ )
170
  with gr.Row(variant="panel"):
171
  with gr.Column(variant="compact"):
172
  in_prompt = gr.Textbox(
 
185
  placeholder="e.g., 0.1,0.5,0.3,0.6",
186
  )
187
  in_token_ids = gr.Textbox(
188
+ label="Token indices",
189
  show_label=True,
190
  max_lines=1,
191
  placeholder="e.g., 1,2,3",
 
196
  with gr.Row(variant="compact"):
197
  is_grid_search = gr.Checkbox(
198
  value=False,
199
+ label="Grid search? (If checked then sliders are ignored)",
200
  )
201
  is_draw_bbox = gr.Checkbox(
202
  value=True,
 
204
  )
205
  with gr.Row(variant="compact"):
206
  in_slider_trailings = gr.Slider(
207
+ minimum=0, maximum=30, value=10, step=1, label="#trailings"
208
  )
209
  in_slider_ddsteps = gr.Slider(
210
+ minimum=0, maximum=30, value=10, step=1, label="#DDSteps"
211
  )
212
  in_slider_gcoef = gr.Slider(
213
+ minimum=0, maximum=10, value=1.0, step=0.1, label="GaussianCoef"
214
  )
215
+ with gr.Row(variant="compact"):
216
+ btn_run = gr.Button("Generate image").style(full_width=True)
217
+ btn_clean = gr.Button("Clean Gallery").style(full_width=True)
218
+
219
+ gr.Markdown(
220
+ """ Note:
221
+ 1) Please click one of the examples below for quick setup.
222
+ 2) if #DDsteps==0, it means the SD process runs without DD.
223
+ """
224
+ )
225
 
226
  gallery = gr.Gallery(
227
  label="Generated images", show_label=False, elem_id="gallery"
 
238
  is_draw_bbox,
239
  is_grid_search,
240
  ]
241
+ btn_run.click(run_it, inputs=args, outputs=gallery)
242
+ btn_clean.click(clean_gallery, outputs=gallery)
 
 
 
 
 
243
  examples = gr.Examples(
244
+ examples=EXAMPLES,
245
  inputs=args,
246
  )
247