hysts HF staff committed on
Commit
e3c9822
1 Parent(s): a853b3a
Files changed (2) hide show
  1. app.py +42 -18
  2. model.py +2 -4
app.py CHANGED
@@ -27,8 +27,11 @@ def process_example(
27
  model_id = 'CompVis/stable-diffusion-v1-4'
28
  num_steps = 50
29
  guidance_scale = 7.5
30
- return model.run(model_id, prompt, indices_to_alter_str, seed,
31
- apply_attend_and_excite, num_steps, guidance_scale)
 
 
 
32
 
33
 
34
  with gr.Blocks(css='style.css') as demo:
@@ -166,12 +169,12 @@ with gr.Blocks(css='style.css') as demo:
166
  cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
167
  examples_per_page=20)
168
 
169
- show_token_indices_button.click(fn=model.get_token_table,
170
- inputs=[
171
- model_id,
172
- prompt,
173
- ],
174
- outputs=token_indices_table)
175
 
176
  inputs = [
177
  model_id,
@@ -182,15 +185,36 @@ with gr.Blocks(css='style.css') as demo:
182
  num_steps,
183
  guidance_scale,
184
  ]
185
- outputs = [
186
- token_indices_table,
187
- result,
188
- ]
189
- prompt.submit(fn=model.run, inputs=inputs, outputs=outputs)
190
- token_indices_str.submit(fn=model.run, inputs=inputs, outputs=outputs)
191
- run_button.click(fn=model.run,
192
- inputs=inputs,
193
- outputs=outputs,
194
- api_name='run')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  demo.queue(max_size=10).launch()
 
27
  model_id = 'CompVis/stable-diffusion-v1-4'
28
  num_steps = 50
29
  guidance_scale = 7.5
30
+
31
+ token_table = model.get_token_table(model_id, prompt)
32
+ result = model.run(model_id, prompt, indices_to_alter_str, seed,
33
+ apply_attend_and_excite, num_steps, guidance_scale)
34
+ return token_table, result
35
 
36
 
37
  with gr.Blocks(css='style.css') as demo:
 
169
  cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
170
  examples_per_page=20)
171
 
172
+ show_token_indices_button.click(
173
+ fn=model.get_token_table,
174
+ inputs=[model_id, prompt],
175
+ outputs=token_indices_table,
176
+ queue=False,
177
+ )
178
 
179
  inputs = [
180
  model_id,
 
185
  num_steps,
186
  guidance_scale,
187
  ]
188
+ prompt.submit(
189
+ fn=model.get_token_table,
190
+ inputs=[model_id, prompt],
191
+ outputs=token_indices_table,
192
+ queue=False,
193
+ ).then(
194
+ fn=model.run,
195
+ inputs=inputs,
196
+ outputs=result,
197
+ )
198
+ token_indices_str.submit(
199
+ fn=model.get_token_table,
200
+ inputs=[model_id, prompt],
201
+ outputs=token_indices_table,
202
+ queue=False,
203
+ ).then(
204
+ fn=model.run,
205
+ inputs=inputs,
206
+ outputs=result,
207
+ )
208
+ run_button.click(
209
+ fn=model.get_token_table,
210
+ inputs=[model_id, prompt],
211
+ outputs=token_indices_table,
212
+ queue=False,
213
+ ).then(
214
+ fn=model.run,
215
+ inputs=inputs,
216
+ outputs=result,
217
+ api_name='run',
218
+ )
219
 
220
  demo.queue(max_size=10).launch()
model.py CHANGED
@@ -56,7 +56,7 @@ class Model:
56
  20: 0.8
57
  },
58
  max_iter_to_alter: int = 25,
59
- ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
60
  generator = torch.Generator(device=self.device).manual_seed(seed)
61
  try:
62
  indices_to_alter = list(map(int, indices_to_alter_str.split(',')))
@@ -65,8 +65,6 @@ class Model:
65
 
66
  self.load_model(model_id)
67
 
68
- token_table = self.get_token_table(model_id, prompt)
69
-
70
  controller = AttentionStore()
71
  config = RunConfig(prompt=prompt,
72
  n_inference_steps=num_steps,
@@ -82,4 +80,4 @@ class Model:
82
  seed=generator,
83
  config=config)
84
 
85
- return token_table, image
 
56
  20: 0.8
57
  },
58
  max_iter_to_alter: int = 25,
59
+ ) -> PIL.Image.Image:
60
  generator = torch.Generator(device=self.device).manual_seed(seed)
61
  try:
62
  indices_to_alter = list(map(int, indices_to_alter_str.split(',')))
 
65
 
66
  self.load_model(model_id)
67
 
 
 
68
  controller = AttentionStore()
69
  config = RunConfig(prompt=prompt,
70
  n_inference_steps=num_steps,
 
80
  seed=generator,
81
  config=config)
82
 
83
+ return image