radames committed on
Commit
a492ba5
·
1 Parent(s): ea1e8bf
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +115 -36
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Unofficial SDXL Turbo Real Time Text To Image
3
  emoji: ๐Ÿ†
4
  colorFrom: yellow
5
  colorTo: purple
 
1
  ---
2
+ title: Unofficial SDXL Turbo Img2Img Txt2Img
3
  emoji: ๐Ÿ†
4
  colorFrom: yellow
5
  colorTo: purple
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from diffusers import DiffusionPipeline
2
  import torch
3
  import os
4
 
@@ -35,31 +35,79 @@ if mps_available:
35
  torch_dtype = torch.float32
36
 
37
  if SAFETY_CHECKER == "True":
38
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", revision="pr/4")
 
 
 
 
 
 
 
 
 
39
  else:
40
- pipe = DiffusionPipeline.from_pretrained(
41
- "stabilityai/sdxl-turbo", revision="pr/4", safety_checker=None
 
 
 
 
 
 
 
 
 
42
  )
43
 
44
 
45
- pipe.to(device=torch_device, dtype=torch_dtype).to(device)
46
- pipe.unet.to(memory_format=torch.channels_last)
47
- pipe.set_progress_bar_config(disable=True)
48
-
49
-
50
- def predict(prompt, steps, seed=1231231):
51
- generator = torch.manual_seed(seed)
52
- last_time = time.time()
53
- results = pipe(
54
- prompt=prompt,
55
- generator=generator,
56
- num_inference_steps=steps,
57
- guidance_scale=0.0,
58
- width=512,
59
- height=512,
60
- # original_inference_steps=params.lcm_steps,
61
- output_type="pil",
62
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  print(f"Pipe took {time.time() - last_time} seconds")
64
  nsfw_content_detected = (
65
  results.nsfw_content_detected[0]
@@ -75,7 +123,7 @@ def predict(prompt, steps, seed=1231231):
75
  css = """
76
  #container{
77
  margin: 0 auto;
78
- max-width: 40rem;
79
  }
80
  #intro{
81
  max-width: 100%;
@@ -84,9 +132,10 @@ css = """
84
  }
85
  """
86
  with gr.Blocks(css=css) as demo:
 
87
  with gr.Column(elem_id="container"):
88
  gr.Markdown(
89
- """# SDXL Turbo - Text To Image
90
  ## Unofficial Demo
91
  SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
92
  **Model**: https://huggingface.co/stabilityai/sdxl-turbo
@@ -94,18 +143,40 @@ with gr.Blocks(css=css) as demo:
94
  elem_id="intro",
95
  )
96
  with gr.Row():
97
- with gr.Row():
98
- prompt = gr.Textbox(
99
- placeholder="Insert your prompt here:", scale=5, container=False
 
 
 
 
 
 
 
 
 
100
  )
101
- generate_bt = gr.Button("Generate", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- image = gr.Image(type="filepath")
104
- with gr.Accordion("Advanced options", open=False):
105
- steps = gr.Slider(label="Steps", value=2, minimum=1, maximum=10, step=1)
106
- seed = gr.Slider(
107
- randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
108
- )
109
  with gr.Accordion("Run with diffusers"):
110
  gr.Markdown(
111
  """## Running SDXL Turbo with `diffusers`
@@ -116,7 +187,7 @@ with gr.Blocks(css=css) as demo:
116
  from diffusers import DiffusionPipeline
117
 
118
  pipe = DiffusionPipeline.from_pretrained(
119
- "stabilityai/sdxl-turbo", revision="refs/pr/4"
120
  ).to("cuda")
121
  results = pipe(
122
  prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
@@ -129,11 +200,19 @@ with gr.Blocks(css=css) as demo:
129
  """
130
  )
131
 
132
- inputs = [prompt, steps, seed]
133
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
134
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
135
  steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
136
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
 
 
 
 
 
 
 
 
137
 
138
  demo.queue()
139
  demo.launch()
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
  import torch
3
  import os
4
 
 
35
  torch_dtype = torch.float32
36
 
37
  if SAFETY_CHECKER == "True":
38
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
39
+ "stabilityai/sdxl-turbo",
40
+ torch_dtype=torch_dtype,
41
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
42
+ )
43
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
44
+ "stabilityai/sdxl-turbo",
45
+ torch_dtype=torch_dtype,
46
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
47
+ )
48
  else:
49
+ i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
50
+ "stabilityai/sdxl-turbo",
51
+ safety_checker=None,
52
+ torch_dtype=torch_dtype,
53
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
54
+ )
55
+ t2i_pipe = AutoPipelineForText2Image.from_pretrained(
56
+ "stabilityai/sdxl-turbo",
57
+ safety_checker=None,
58
+ torch_dtype=torch_dtype,
59
+ variant="fp16" if torch_dtype == torch.float16 else "fp32",
60
  )
61
 
62
 
63
+ t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
64
+ t2i_pipe.set_progress_bar_config(disable=True)
65
+ i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
66
+ i2i_pipe.set_progress_bar_config(disable=True)
67
+
68
+
69
+ def pad_image(image):
70
+ w, h = image.size
71
+ if w == h:
72
+ return image
73
+ elif w > h:
74
+ new_image = Image.new(image.mode, (w, w), (0, 0, 0))
75
+ new_image.paste(image, (0, (w - h) // 2))
76
+ return new_image
77
+ else:
78
+ new_image = Image.new(image.mode, (h, h), (0, 0, 0))
79
+ new_image.paste(image, ((h - w) // 2, 0))
80
+ return new_image
81
+
82
+
83
+ async def predict(init_image, prompt, strength, steps, seed=1231231):
84
+ if init_image is not None:
85
+ init_image = pad_image(init_image).convert("RGB").resize((512, 512))
86
+ generator = torch.manual_seed(seed)
87
+ last_time = time.time()
88
+ results = i2i_pipe(
89
+ prompt=prompt,
90
+ image=init_image,
91
+ generator=generator,
92
+ num_inference_steps=steps,
93
+ guidance_scale=0.0,
94
+ strength=strength,
95
+ width=512,
96
+ height=512,
97
+ output_type="pil",
98
+ )
99
+ else:
100
+ generator = torch.manual_seed(seed)
101
+ last_time = time.time()
102
+ results = t2i_pipe(
103
+ prompt=prompt,
104
+ generator=generator,
105
+ num_inference_steps=steps,
106
+ guidance_scale=0.0,
107
+ width=512,
108
+ height=512,
109
+ output_type="pil",
110
+ )
111
  print(f"Pipe took {time.time() - last_time} seconds")
112
  nsfw_content_detected = (
113
  results.nsfw_content_detected[0]
 
123
  css = """
124
  #container{
125
  margin: 0 auto;
126
+ max-width: 80rem;
127
  }
128
  #intro{
129
  max-width: 100%;
 
132
  }
133
  """
134
  with gr.Blocks(css=css) as demo:
135
+ init_image_state = gr.State()
136
  with gr.Column(elem_id="container"):
137
  gr.Markdown(
138
+ """# SDXL Turbo Image to Image/Text to Image
139
  ## Unofficial Demo
140
  SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
141
  **Model**: https://huggingface.co/stabilityai/sdxl-turbo
 
143
  elem_id="intro",
144
  )
145
  with gr.Row():
146
+ prompt = gr.Textbox(
147
+ placeholder="Insert your prompt here:",
148
+ scale=5,
149
+ container=False,
150
+ )
151
+ generate_bt = gr.Button("Generate", scale=1)
152
+ with gr.Row():
153
+ with gr.Column():
154
+ image_input = gr.Image(
155
+ sources=["upload", "webcam", "clipboard"],
156
+ label="Webcam",
157
+ type="pil",
158
  )
159
+ with gr.Column():
160
+ image = gr.Image(type="filepath")
161
+ with gr.Accordion("Advanced options", open=False):
162
+ strength = gr.Slider(
163
+ label="Strength",
164
+ value=0.7,
165
+ minimum=0.0,
166
+ maximum=1.0,
167
+ step=0.001,
168
+ )
169
+ steps = gr.Slider(
170
+ label="Steps", value=2, minimum=1, maximum=10, step=1
171
+ )
172
+ seed = gr.Slider(
173
+ randomize=True,
174
+ minimum=0,
175
+ maximum=12013012031030,
176
+ label="Seed",
177
+ step=1,
178
+ )
179
 
 
 
 
 
 
 
180
  with gr.Accordion("Run with diffusers"):
181
  gr.Markdown(
182
  """## Running SDXL Turbo with `diffusers`
 
187
  from diffusers import DiffusionPipeline
188
 
189
  pipe = DiffusionPipeline.from_pretrained(
190
+ "stabilityai/sdxl-turbo"
191
  ).to("cuda")
192
  results = pipe(
193
  prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
 
200
  """
201
  )
202
 
203
+ inputs = [image_input, prompt, strength, steps, seed]
204
  generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
205
  prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
206
  steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
207
  seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
208
+ strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
209
+ image_input.change(
210
+ fn=lambda x: x,
211
+ inputs=image_input,
212
+ outputs=init_image_state,
213
+ show_progress=False,
214
+ queue=False,
215
+ )
216
 
217
  demo.queue()
218
  demo.launch()