Linoy Tsaban commited on
Commit
0c1b1f7
1 Parent(s): 5edff5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -13
app.py CHANGED
@@ -95,14 +95,14 @@ For faster inference without waiting in queue, you may duplicate the space and u
95
  <p/>"""
96
  with gr.Blocks(css='style.css') as demo:
97
 
98
- def reset_latents():
99
- wts = gr.State(value=False)
100
- zs = gr.State(value=False)
101
- return wts, zs
102
 
103
 
104
  def edit(input_image,
105
  wts, zs,
 
106
  src_prompt ="",
107
  tar_prompt="",
108
  steps=100,
@@ -119,23 +119,25 @@ with gr.Blocks(css='style.css') as demo:
119
  # offsets=(0,0,0,0)
120
  x0 = load_512(input_image, device=device)
121
 
122
- if not wts:
123
  # invert and retrieve noise maps and latent
124
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
125
  # xt = gr.State(value=wts[skip])
126
  # zs = gr.State(value=zs[skip:])
127
  wts = gr.State(value=wts_tensor)
128
  zs = gr.State(value=zs_tensor)
 
129
 
130
  # output = sample(zs.value, xt.value, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar)
131
  output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
132
 
133
- return output, wts, zs
134
 
135
  gr.HTML(intro)
136
  # xt = gr.State(value=False)
137
- wts = gr.State(value=False)
138
- zs = gr.State(value=False)
 
139
  with gr.Row():
140
  input_image = gr.Image(label="Input Image", interactive=True)
141
  input_image.style(height=512, width=512)
@@ -179,17 +181,17 @@ with gr.Blocks(css='style.css') as demo:
179
  seed,
180
  randomize_seed
181
  ],
182
- outputs=[output_image, wts, zs],
183
  )
184
 
185
  input_image.change(
186
- fn = reset_latents,
187
- outputs = [wts, zs]
188
  )
189
 
190
  src_prompt.change(
191
- fn = reset_latents,
192
- outputs = [wts, zs]
193
  )
194
 
195
  # skip.change(
 
95
  <p/>"""
96
  with gr.Blocks(css='style.css') as demo:
97
 
98
+ def reset_do_inversion():
99
+ do_inversion = True
100
+ return do_inversion
 
101
 
102
 
103
  def edit(input_image,
104
  wts, zs,
105
+ do_inversion,
106
  src_prompt ="",
107
  tar_prompt="",
108
  steps=100,
 
119
  # offsets=(0,0,0,0)
120
  x0 = load_512(input_image, device=device)
121
 
122
+ if do_inversion:
123
  # invert and retrieve noise maps and latent
124
  zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
125
  # xt = gr.State(value=wts[skip])
126
  # zs = gr.State(value=zs[skip:])
127
  wts = gr.State(value=wts_tensor)
128
  zs = gr.State(value=zs_tensor)
129
+ do_inversion = False
130
 
131
  # output = sample(zs.value, xt.value, prompt_tar=tar_prompt, cfg_scale_tar=cfg_scale_tar)
132
  output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
133
 
134
+ return output, wts, zs, do_inversion
135
 
136
  gr.HTML(intro)
137
  # xt = gr.State(value=False)
138
+ wts = gr.State()
139
+ zs = gr.State()
140
+ do_inversion = gr.State(value=True)
141
  with gr.Row():
142
  input_image = gr.Image(label="Input Image", interactive=True)
143
  input_image.style(height=512, width=512)
 
181
  seed,
182
  randomize_seed
183
  ],
184
+ outputs=[output_image, wts, zs, do_inversion],
185
  )
186
 
187
  input_image.change(
188
+ fn = reset_do_inversion,
189
+ outputs = [do_inversion]
190
  )
191
 
192
  src_prompt.change(
193
+ fn = reset_do_inversion,
194
+ outputs = [do_inversion]
195
  )
196
 
197
  # skip.change(