Linoy Tsaban committed on
Commit 8e8b324
1 Parent(s): afb8388

Update app.py

Files changed (1)
  1. app.py +16 -14
app.py CHANGED
@@ -10,6 +10,11 @@ from inversion_utils import *
 from torch import autocast, inference_mode
 import re
 
+def randomize_seed_fn(seed, randomize_seed):
+    if randomize_seed:
+        seed = random.randint(0, np.iinfo(np.int32).max)
+    torch.manual_seed(seed)
+    return seed
 
 def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
 
@@ -151,25 +156,20 @@ with gr.Blocks(css='style.css') as demo:
              seed = 0,
              randomized_seed = True):
 
-    if randomized_seed:
-        seed = random.randint(0, np.iinfo(np.int32).max)
+    # if randomized_seed:
+    #     seed = random.randint(0, np.iinfo(np.int32).max)
 
-    torch.manual_seed(seed)
-    # offsets=(0,0,0,0)
+    # torch.manual_seed(seed)
+    # # offsets=(0,0,0,0)
     x0 = load_512(input_image, device=device)
 
-    if do_inversion:
-        # invert and retrieve noise maps and latent
+    if do_inversion or randomized_seed:
         zs_tensor, wts_tensor = invert(x0=x0, prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
-        # xt = gr.State(value=wts[skip])
-        # zs = gr.State(value=zs[skip:])
         wts = gr.State(value=wts_tensor)
         zs = gr.State(value=zs_tensor)
         do_inversion = False
 
-    # output = sample(zs.value, xt.value, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar)
-    output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
-
+    output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
     return output, wts, zs, do_inversion
 
     gr.HTML(intro)
@@ -204,10 +204,13 @@ with gr.Blocks(css='style.css') as demo:
     skip = gr.Slider(minimum=0, maximum=40, value=36, step=1, label="Skip Steps", interactive=True)
    cfg_scale_tar = gr.Slider(minimum=7, maximum=18, value=15, label=f"Target Guidance Scale", interactive=True)
     seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
-    randomize_seed = gr.Checkbox(label='Randomize seed', value=True)
+    randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
 
 
     edit_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=[seed]).then(
         fn=edit,
         inputs=[input_image,
                 do_inversion, wts, zs,
@@ -217,8 +220,7 @@ with gr.Blocks(css='style.css') as demo:
                 cfg_scale_src,
                 cfg_scale_tar,
                 skip,
-                seed,
-                randomize_seed
+                seed, randomized_seed
                 ],
         outputs=[output_image, wts, zs, do_inversion],
     )
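
The commit moves the seeding out of edit() and into a dedicated randomize_seed_fn that is wired in front of edit via Gradio's event chaining: edit_button.click(...) returns an event object exposing .then(), so the seed is resolved and written back into the Seed number box before the edit callback fires. Below is a minimal, self-contained sketch of that pattern, assuming a recent Gradio release where .then() is available on event listeners; fake_edit is a hypothetical stand-in for the app's real edit() function, not code from this repository.

# Minimal sketch of the click(...).then(...) chaining adopted in this commit.
# "fake_edit" is a hypothetical stand-in for the app's real edit() function.
import random

import gradio as gr
import numpy as np
import torch


def randomize_seed_fn(seed, randomize_seed):
    # Draw a fresh 32-bit seed when the checkbox is ticked, then seed torch globally.
    if randomize_seed:
        seed = random.randint(0, np.iinfo(np.int32).max)
    torch.manual_seed(seed)
    return seed


def fake_edit(seed):
    # Stand-in for edit(); only reports which seed it ran with.
    return f"edited with seed {seed}"


with gr.Blocks() as demo:
    seed = gr.Number(value=0, precision=0, label="Seed")
    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
    output = gr.Textbox(label="Output")
    edit_button = gr.Button("Edit")

    # The first event resolves the seed and writes it back into the Number box;
    # .then() runs fake_edit only after that update has completed.
    edit_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=[seed],
    ).then(
        fn=fake_edit,
        inputs=[seed],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()

Handling the seeding as its own chained event keeps the visible Seed value in sync with the seed the edit actually used, which is presumably why the commit comments out the old in-function seeding block and why a randomized seed now also forces a re-inversion (if do_inversion or randomized_seed:).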