linoyts HF staff commited on
Commit
efbe74e
Β·
verified Β·
1 Parent(s): 266a724

add inversion [WIP] (#3)

Browse files

- add inversion [WIP] (7064ccacec247e1942f14fc82053553fc4c4f0bb)
- Update app.py (b867d00c8257250adeeb4955df03ba66ec66aec5)
- Update clip_slider_pipeline.py (d9d655868b3b10b688060ceb4b2de9e2815bf310)
- Update app.py (127f1911abbee95d7fa6022517f4c265d6240f45)

Files changed (2) hide show
  1. app.py +70 -1
  2. clip_slider_pipeline.py +9 -2
app.py CHANGED
@@ -71,7 +71,9 @@ def generate(slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale
71
  avg_diff_x_1, avg_diff_x_2,
72
  avg_diff_y_1, avg_diff_y_2,
73
  img2img_type = None, img = None,
74
- controlnet_scale= None, ip_adapter_scale=None):
 
 
75
 
76
  start_time = time.time()
77
  # check if avg diff for directions need to be re-calculated
@@ -101,6 +103,8 @@ def generate(slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale
101
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
102
  elif img2img_type=="ip adapter" and img is not None:
103
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
 
 
104
  else: # text to image
105
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
106
 
@@ -153,6 +157,18 @@ def update_y(x,y,prompt, seed, steps,
153
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
154
  return image
155
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  css = '''
157
  #group {
158
  position: relative;
@@ -188,6 +204,10 @@ with gr.Blocks(css=css) as demo:
188
  avg_diff_x_2 = gr.State()
189
  avg_diff_y_1 = gr.State()
190
  avg_diff_y_2 = gr.State()
 
 
 
 
191
 
192
  with gr.Tab("text2image"):
193
  with gr.Row():
@@ -257,13 +277,62 @@ with gr.Blocks(css=css) as demo:
257
  value=0.8,
258
  )
259
  seed_a = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  submit.click(fn=generate,
262
  inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
263
  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
 
 
 
 
 
264
 
265
  generate_butt.click(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
266
  generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
 
267
  #x.change(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
268
  #y.change(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
269
  submit_a.click(fn=generate,
 
71
  avg_diff_x_1, avg_diff_x_2,
72
  avg_diff_y_1, avg_diff_y_2,
73
  img2img_type = None, img = None,
74
+ controlnet_scale= None, ip_adapter_scale=None,
75
+ edit_threshold=None, edit_guidance_scale = None,
76
+ init_latents=None, zs=None):
77
 
78
  start_time = time.time()
79
  # check if avg diff for directions need to be re-calculated
 
103
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
104
  elif img2img_type=="ip adapter" and img is not None:
105
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
106
+ elif img2img_type=="inversion":
107
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1), init_latents = init_latents, zs=zs)
108
  else: # text to image
109
  image = clip_slider.generate(prompt, guidance_scale=guidance_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=(avg_diff_0,avg_diff_1), avg_diff_2nd=(avg_diff_2nd_0,avg_diff_2nd_1))
110
 
 
157
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
158
  return image
159
 
160
+ @spaces.GPU
161
+ def invert(image, num_inversion_steps=50, skip=0.3):
162
+ _ = clip_slider_inv.pipe.invert(
163
+ source_prompt = "",
164
+ image = image,
165
+ num_inversion_steps = num_inversion_steps,
166
+ skip = skip
167
+ )
168
+ return clip_slider_inv.pipe.init_latents, lip_slider_inv.pipe.zs
169
+
170
+ def reset_do_inversion():
171
+ return True
172
  css = '''
173
  #group {
174
  position: relative;
 
204
  avg_diff_x_2 = gr.State()
205
  avg_diff_y_1 = gr.State()
206
  avg_diff_y_2 = gr.State()
207
+
208
+ do_inversion = gr.State()
209
+ init_latents = gr.State()
210
+ zs = gr.State()
211
 
212
  with gr.Tab("text2image"):
213
  with gr.Row():
 
277
  value=0.8,
278
  )
279
  seed_a = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
280
+
281
+ with gr.Tab(label="inversion"):
282
+ with gr.Row():
283
+ with gr.Column():
284
+ image_inv = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
285
+ slider_x_inv = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
286
+ slider_y_inv = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
287
+ prompt_inv = gr.Textbox(label="Prompt")
288
+ submit_inv = gr.Button("Submit")
289
+ with gr.Column():
290
+ with gr.Group(elem_id="group"):
291
+ x_inv = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
292
+ y_inv = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
293
+ output_image_inv = gr.Image(elem_id="image_out")
294
+ generate_butt_inv = gr.Button("generate")
295
+
296
+ with gr.Accordion(label="advanced options", open=False):
297
+ iterations_inv = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=300)
298
+ steps_inv = gr.Slider(label = "num inference steps", minimum=1, value=8, maximum=30)
299
+ guidance_scale_inv = gr.Slider(
300
+ label="Guidance scale",
301
+ minimum=0.1,
302
+ maximum=10.0,
303
+ step=0.1,
304
+ value=5,
305
+ )
306
+ # edit_threshold=None, edit_guidance_scale = None,
307
+ # init_latents=None, zs=None
308
+ edit_threshold = gr.Slider(
309
+ label="edit threshold",
310
+ minimum=0.01,
311
+ maximum=0.99,
312
+ step=0.1,
313
+ value=0.3,
314
+ )
315
+ edit_guidance_scale = gr.Slider(
316
+ label="edit guidance scale",
317
+ minimum=0,
318
+ maximum=20,
319
+ step=0.25,
320
+ value=5,
321
+ )
322
+ seed_inv = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
323
 
324
  submit.click(fn=generate,
325
  inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
326
  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
327
+
328
+ image_inv.change(fn=rest_do_inversion, outputs=[do_inversion]).then(fn=invert, inputs=[image_inv], outputs=[init_latents,zs])
329
+ submit_inv.click(fn=generate,
330
+ inputs=[slider_x_inv, slider_y_inv, prompt_inv, seed_inv, iterations_inv, steps_inv, guidance_scale_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2],
331
+ outputs=[x_inv, y_inv, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image_inv])
332
 
333
  generate_butt.click(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
334
  generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
335
+ generate_butt_inv.click(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, "inversion", None, None, None,edit_threshold, edit_guidance_scale, init_latents, zs], outputs=[output_image])
336
  #x.change(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
337
  #y.change(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
338
  submit_a.click(fn=generate,
clip_slider_pipeline.py CHANGED
@@ -209,7 +209,9 @@ class CLIPSliderXL(CLIPSlider):
209
  normalize_scales = False,
210
  correlation_weight_factor = 1.0,
211
  avg_diff = None,
212
- avg_diff_2nd = None,
 
 
213
  **pipeline_kwargs
214
  ):
215
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
@@ -287,8 +289,13 @@ class CLIPSliderXL(CLIPSlider):
287
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
288
  torch.manual_seed(seed)
289
  start_time = time.time()
290
- image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
 
 
291
  **pipeline_kwargs).images[0]
 
 
 
292
  end_time = time.time()
293
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")
294
 
 
209
  normalize_scales = False,
210
  correlation_weight_factor = 1.0,
211
  avg_diff = None,
212
+ avg_diff_2nd = None,
213
+ init_latents = None, # inversion
214
+ zs = None, # inversion
215
  **pipeline_kwargs
216
  ):
217
  # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
 
289
  print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
290
  torch.manual_seed(seed)
291
  start_time = time.time()
292
+ if init_latents is not None: # inversion
293
+ image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
294
+ avg_diff=avg_diff, avg_diff_2=avg_diff2, scale=scale,
295
  **pipeline_kwargs).images[0]
296
+ else:
297
+ image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
298
+ **pipeline_kwargs).images[0]
299
  end_time = time.time()
300
  print(f"generation time - pipe: {end_time - start_time:.2f} ms")
301