linoyts HF staff commited on
Commit
e4e5057
·
verified ·
1 Parent(s): 164edec

add img2img support [wip]

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -11,9 +11,21 @@ pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("
11
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
12
  clip_slider = CLIPSliderXL(pipe, device=torch.device("cuda"))
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  @spaces.GPU
16
- def generate(slider_x, slider_y, prompt, seed, iterations, steps,
17
  x_concept_1, x_concept_2, y_concept_1, y_concept_2,
18
  avg_diff_x_1, avg_diff_x_2,
19
  avg_diff_y_1, avg_diff_y_2):
@@ -92,7 +104,7 @@ with gr.Blocks(css=css) as demo:
92
  avg_diff_y_1 = gr.State()
93
  avg_diff_y_2 = gr.State()
94
 
95
- with gr.Tab():
96
  with gr.Row():
97
  with gr.Column():
98
  slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
@@ -114,9 +126,10 @@ with gr.Blocks(css=css) as demo:
114
  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
115
  x.change(fn=update_x, inputs=[x,y, prompt, seed, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
116
  y.change(fn=update_y, inputs=[x,y, prompt, seed, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
117
- with gr.Tab(label="IP Apater"):
118
  with gr.Row():
119
  with gr.Column():
 
120
  slider_x_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
121
  slider_y_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
122
  prompt_a = gr.Textbox(label="Prompt")
 
11
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
12
  clip_slider = CLIPSliderXL(pipe, device=torch.device("cuda"))
13
 
14
+ pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
15
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
16
+ clip_slider = CLIPSliderXL(pipe, device=torch.device("cuda"))
17
+
18
+ pipe_adapter = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash").to("cuda", torch.float16)
19
+ pipe_adapter.scheduler = EulerDiscreteScheduler.from_config(pipe_adapter.scheduler.config)
20
+ pipe_adapter.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
21
+ # scale = 0.8
22
+ # pipe_adapter.set_ip_adapter_scale(scale)
23
+
24
+ clip_slider_ip = CLIPSliderXL(sd_pipe=pipe_adapter,
25
+ device=torch.device("cuda"))
26
 
27
  @spaces.GPU
28
+ def generate(clip_slider, slider_x, slider_y, prompt, seed, iterations, steps,
29
  x_concept_1, x_concept_2, y_concept_1, y_concept_2,
30
  avg_diff_x_1, avg_diff_x_2,
31
  avg_diff_y_1, avg_diff_y_2):
 
104
  avg_diff_y_1 = gr.State()
105
  avg_diff_y_2 = gr.State()
106
 
107
+ with gr.Tab(""):
108
  with gr.Row():
109
  with gr.Column():
110
  slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
 
126
  outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2, output_image])
127
  x.change(fn=update_x, inputs=[x,y, prompt, seed, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
128
  y.change(fn=update_y, inputs=[x,y, prompt, seed, steps, avg_diff_x_1, avg_diff_x_2, avg_diff_y_1, avg_diff_y_2], outputs=[output_image])
129
+ with gr.Tab(label="image2image"):
130
  with gr.Row():
131
  with gr.Column():
132
+ image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
133
  slider_x_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
134
  slider_y_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
135
  prompt_a = gr.Textbox(label="Prompt")