Linoy Tsaban committed on
Commit
9cd2450
1 Parent(s): f634660

Update app.py

Browse files

Split the call-to-action (CTA) into two buttons: Invert and Edit

Files changed (1) hide show
  1. app.py +72 -14
app.py CHANGED
@@ -96,24 +96,24 @@ def get_example():
96
  'examples/ddpm_sega_glass_walls_gian_elephant.png'
97
  ]]
98
  return case
99
-
100
- def edit(input_image,
 
 
 
101
  src_prompt ="",
102
  tar_prompt="",
103
  steps=100,
104
  # src_cfg_scale,
105
  skip=36,
106
  tar_cfg_scale=15,
107
- edit_concept="",
108
- sega_edit_guidance=0,
109
- warm_up=None,
110
  # neg_guidance=False,
111
  left = 0,
112
  right = 0,
113
  top = 0,
114
- bottom = 0):
115
-
116
- # offsets=(0,0,0,0)
117
  x0 = load_512(input_image, left,right, top, bottom, device)
118
 
119
 
@@ -127,9 +127,48 @@ def edit(input_image,
127
  #pure DDPM output
128
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
129
  cfg_scale_tar=tar_cfg_scale, skip=skip)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- if not edit_concept or not sega_edit_guidance:
132
- return pure_ddpm_out, pure_ddpm_out
 
 
 
 
 
 
 
 
133
 
134
  # SEGA
135
  # parse concepts and neg guidance
@@ -169,7 +208,7 @@ def edit(input_image,
169
  num_images_per_prompt=1,
170
  num_inference_steps=steps,
171
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
172
- return pure_ddpm_out,sega_out.images[0]
173
 
174
  ########
175
  # demo #
@@ -206,7 +245,8 @@ with gr.Blocks() as demo:
206
 
207
  with gr.Row():
208
  with gr.Column(scale=1, min_width=100):
209
- generate_button = gr.Button("Run")
 
210
 
211
  with gr.Accordion("Advanced Options", open=False):
212
  with gr.Row():
@@ -236,7 +276,25 @@ with gr.Blocks() as demo:
236
 
237
  # gr.Markdown(help_text)
238
 
239
- generate_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  fn=edit,
241
  inputs=[input_image,
242
  src_prompt,
@@ -254,7 +312,7 @@ with gr.Blocks() as demo:
254
  top,
255
  bottom
256
  ],
257
- outputs=[ddpm_edited_image, sega_edited_image],
258
  )
259
 
260
  gr.Examples(
 
96
  'examples/ddpm_sega_glass_walls_gian_elephant.png'
97
  ]]
98
  return case
99
+
100
+ inversion_map = dict()
101
+
102
+ def invert_and_reconstruct(
103
+ input_image,
104
  src_prompt ="",
105
  tar_prompt="",
106
  steps=100,
107
  # src_cfg_scale,
108
  skip=36,
109
  tar_cfg_scale=15,
 
 
 
110
  # neg_guidance=False,
111
  left = 0,
112
  right = 0,
113
  top = 0,
114
+ bottom = 0
115
+ ):
116
+ # offsets=(0,0,0,0)
117
  x0 = load_512(input_image, left,right, top, bottom, device)
118
 
119
 
 
127
  #pure DDPM output
128
  pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
129
  cfg_scale_tar=tar_cfg_scale, skip=skip)
130
+ inversion_map['wt'] = wt
131
+ inversion_map['zs'] = zs
132
+ inversion_map['wts'] = wts
133
+
134
+ return pure_ddpm_out
135
+
136
+ def edit(input_image,
137
+ src_prompt ="",
138
+ tar_prompt="",
139
+ steps=100,
140
+ # src_cfg_scale,
141
+ skip=36,
142
+ tar_cfg_scale=15,
143
+ edit_concept="",
144
+ sega_edit_guidance=0,
145
+ warm_up=None,
146
+ # neg_guidance=False,
147
+ left = 0,
148
+ right = 0,
149
+ top = 0,
150
+ bottom = 0):
151
+
152
+ # # offsets=(0,0,0,0)
153
+ # x0 = load_512(input_image, left,right, top, bottom, device)
154
+
155
+
156
+ # # invert
157
+ # # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
158
+ # wt, zs, wts = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps)
159
+
160
+ # latnets = wts[skip].expand(1, -1, -1, -1)
161
 
162
+
163
+ # #pure DDPM output
164
+ # pure_ddpm_out = sample(wt, zs, wts, prompt_tar=tar_prompt,
165
+ # cfg_scale_tar=tar_cfg_scale, skip=skip)
166
+
167
+ # if not edit_concept or not sega_edit_guidance:
168
+ # return pure_ddpm_out, pure_ddpm_out
169
+ if not bool(inversion_map):
170
+ raise gr.Error("Must invert before editing")
171
+ wt, zs, wts = inversion_map['wt'],inversion_map['zs'],inversion_map['wts']
172
 
173
  # SEGA
174
  # parse concepts and neg guidance
 
208
  num_images_per_prompt=1,
209
  num_inference_steps=steps,
210
  use_ddpm=True, wts=wts, zs=zs[skip:], **editing_args)
211
+ return sega_out.images[0]
212
 
213
  ########
214
  # demo #
 
245
 
246
  with gr.Row():
247
  with gr.Column(scale=1, min_width=100):
248
+ invert_button = gr.Button("Invert")
249
+ edit_button = gr.Button("Edit")
250
 
251
  with gr.Accordion("Advanced Options", open=False):
252
  with gr.Row():
 
276
 
277
  # gr.Markdown(help_text)
278
 
279
+ invert_button.click(
280
+ fn=invert_and_reconstruct,
281
+ inputs=[input_image,
282
+ src_prompt,
283
+ tar_prompt,
284
+ steps,
285
+ # src_cfg_scale,
286
+ skip,
287
+ tar_cfg_scale,
288
+ # neg_guidance,
289
+ left,
290
+ right,
291
+ top,
292
+ bottom
293
+ ],
294
+ outputs=[ddpm_edited_image],
295
+ )
296
+
297
+ edit_button.click(
298
  fn=edit,
299
  inputs=[input_image,
300
  src_prompt,
 
312
  top,
313
  bottom
314
  ],
315
+ outputs=[sega_edited_image],
316
  )
317
 
318
  gr.Examples(