linoyts HF staff commited on
Commit
eb522ee
1 Parent(s): 78e99c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
- os.system("pip uninstall -y gradio")
3
- os.system('pip install gradio==3.43.1')
4
 
5
  import torch
6
  import torchvision
@@ -21,7 +21,7 @@ from editing import get_direction, debias
21
  from sampling import sample_weights
22
  from lora_w2w import LoRAw2w
23
  from huggingface_hub import snapshot_download
24
-
25
  global device
26
  global generator
27
  global unet
@@ -32,7 +32,7 @@ global noise_scheduler
32
  global network
33
  device = "cuda:0"
34
  generator = torch.Generator(device=device)
35
-
36
 
37
 
38
 
@@ -61,7 +61,7 @@ def sample_model():
61
 
62
 
63
  @torch.no_grad()
64
- def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
65
  global device
66
  global generator
67
  global unet
@@ -113,7 +113,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
113
 
114
 
115
  @torch.no_grad()
116
- def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
117
 
118
  global device
119
  global generator
@@ -196,7 +196,7 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
196
  network.proj = torch.nn.Parameter(original_weights)
197
  network.reset()
198
 
199
- return image
200
 
201
 
202
  def sample_then_run():
@@ -435,13 +435,16 @@ with gr.Blocks(css="style.css") as demo:
435
  with gr.Column():
436
  with gr.Row():
437
  with gr.Column():
438
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
439
- height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
 
 
440
  with gr.Row():
441
  sample = gr.Button("Sample New Model")
442
  invert_button = gr.Button("Invert")
443
  with gr.Column():
444
- gallery1 = gr.Image(label="Identity from Original Model",height=512, width=512, interactive=False)
 
445
 
446
  prompt1 = gr.Textbox(label="Prompt",
447
  info="Make sure to include 'sks person'" ,
@@ -471,18 +474,22 @@ with gr.Blocks(css="style.css") as demo:
471
 
472
  with gr.Accordion("Advanced Options", open=False):
473
  with gr.Tab("Inversion"):
474
- with gr.Column():
475
  lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
476
  pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
 
477
  epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
478
  weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
479
  with gr.Tab("Sampling"):
480
- with gr.Column():
481
  cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
482
  steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
483
- negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
484
  seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
 
 
485
  injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
 
 
486
  # with gr.Tab("Editing"):
487
  # with gr.Column():
488
  # cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
@@ -498,10 +505,10 @@ with gr.Blocks(css="style.css") as demo:
498
 
499
 
500
 
501
- # gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
502
 
503
  with gr.Row():
504
- file_output = gr.File(label="Download Sampled Model", container=True, interactive=False, info="After sampling a new model or inverting, you can download the model below")
505
 
506
 
507
 
@@ -509,18 +516,21 @@ with gr.Blocks(css="style.css") as demo:
509
 
510
  invert_button.click(fn=run_inversion,
511
  inputs=[input_image, pcs, epochs, weight_decay,lr],
512
- outputs = [gallery1, file_output])
513
 
514
 
515
- sample.click(fn=sample_then_run, outputs=[gallery1, file_output])
516
 
517
  # submit1.click(fn=inference,
518
  # inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1],
519
  # outputs=gallery1)
520
- submit1.click(fn=edit_inference,
521
- inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4],
522
- outputs=gallery1)
523
- file_input.change(fn=file_upload, inputs=file_input, outputs = gallery1)
 
 
 
524
 
525
 
526
 
 
1
  import os
2
+ # os.system("pip uninstall -y gradio")
3
+ # #os.system('pip install gradio==3.43.1')
4
 
5
  import torch
6
  import torchvision
 
21
  from sampling import sample_weights
22
  from lora_w2w import LoRAw2w
23
  from huggingface_hub import snapshot_download
24
+ import numpy as np
25
  global device
26
  global generator
27
  global unet
 
32
  global network
33
  device = "cuda:0"
34
  generator = torch.Generator(device=device)
35
+ from gradio_imageslider import ImageSlider
36
 
37
 
38
 
 
61
 
62
 
63
  @torch.no_grad()
64
+ def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
65
  global device
66
  global generator
67
  global unet
 
113
 
114
 
115
  @torch.no_grad()
116
+ def edit_inference(input_image, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
117
 
118
  global device
119
  global generator
 
196
  network.proj = torch.nn.Parameter(original_weights)
197
  network.reset()
198
 
199
+ return (image, input_image['composite'])
200
 
201
 
202
  def sample_then_run():
 
435
  with gr.Column():
436
  with gr.Row():
437
  with gr.Column():
438
+ # input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
439
+ # height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
440
+ input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Upload image and draw to define mask",
441
+ height=512, width=512, brush=gr.Brush(), layers=False)
442
  with gr.Row():
443
  sample = gr.Button("Sample New Model")
444
  invert_button = gr.Button("Invert")
445
  with gr.Column():
446
+ image_slider = ImageSlider(position=1., type="pil", height=512, width=512)
447
+ # gallery1 = gr.Image(label="Identity from Original Model",height=512, width=512, interactive=False)
448
 
449
  prompt1 = gr.Textbox(label="Prompt",
450
  info="Make sure to include 'sks person'" ,
 
474
 
475
  with gr.Accordion("Advanced Options", open=False):
476
  with gr.Tab("Inversion"):
477
+ with gr.Row():
478
  lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
479
  pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
480
+ with gr.Row():
481
  epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
482
  weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
483
  with gr.Tab("Sampling"):
484
+ with gr.Row():
485
  cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
486
  steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
 
487
  seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
488
+ with gr.Row():
489
+ negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
490
  injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
491
+
492
+
493
  # with gr.Tab("Editing"):
494
  # with gr.Column():
495
  # cfg2 = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
 
505
 
506
 
507
 
508
+ gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
509
 
510
  with gr.Row():
511
+ file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
512
 
513
 
514
 
 
516
 
517
  invert_button.click(fn=run_inversion,
518
  inputs=[input_image, pcs, epochs, weight_decay,lr],
519
+ outputs = [input_image, file_output])
520
 
521
 
522
+ sample.click(fn=sample_then_run, outputs=[input_image, file_output])
523
 
524
  # submit1.click(fn=inference,
525
  # inputs=[prompt1, negative_prompt1, cfg1, steps1, seed1],
526
  # outputs=gallery1)
527
+ # submit1.click(fn=edit_inference,
528
+ # inputs=[input_image, prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4],
529
+ # outputs=image_slider)
530
+ submit1.click(
531
+ fn=edit_inference, inputs=[input_image, prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step, a1, a2, a3, a4], outputs=[image_slider]
532
+ )
533
+ file_input.change(fn=file_upload, inputs=file_input, outputs = input_image)
534
 
535
 
536