userIdc2024 committed on
Commit
1e78ace
·
verified ·
1 Parent(s): 1963c84

Upload 3 files

Browse files
Files changed (3) hide show
  1. gitattributes +35 -0
  2. requirements.txt +15 -0
  3. sdapp.py +189 -0
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.23.1
2
+ transformers
3
+ gradio==4.7.1
4
+ --extra-index-url https://download.pytorch.org/whl/cu121
5
+ torch==2.1.0
6
+ fastapi==0.104.0
7
+ uvicorn==0.23.2
8
+ Pillow==10.1.0
9
+ accelerate==0.24.0
10
+ compel==2.0.2
11
+ controlnet-aux==0.0.7
12
+ peft==0.6.0
13
+ xformers
14
+ huggingface_hub==0.20.2
15
+ hf_transfer
sdapp.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
2
+ import torch
3
+ import os
4
+
5
# Intel Extension for PyTorch is optional: it accelerates XPU devices but is
# absent on most installs. Catch only ImportError — the original bare
# `except:` also swallowed KeyboardInterrupt/SystemExit, which can mask a
# Ctrl-C during startup.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except ImportError:
    pass
9
+
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gradio as gr
13
+ import psutil
14
+ import time
15
+ import math
16
+
17
# --- Runtime configuration -------------------------------------------------
# SAFETY_CHECKER env var: the string "True" enables the model's NSFW filter;
# anything else (including unset) disables it at pipeline-load time below.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)

# Accelerator detection. Apple-silicon MPS takes priority; otherwise prefer
# CUDA, then Intel XPU, then plain CPU. `hasattr` guards keep this working on
# torch builds that lack the mps/xpu backends entirely.
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()

if mps_available:
    device = torch.device("mps")
    # On MPS the pipelines are kept on CPU in float32 (fp16 support on MPS
    # is unreliable); `device` still records mps for reference.
    torch_device = "cpu"
    torch_dtype = torch.float32
elif torch.cuda.is_available():
    device = torch.device("cuda")
    torch_device = device
    torch_dtype = torch.float16
elif xpu_available:
    device = torch.device("xpu")
    torch_device = device
    torch_dtype = torch.float16
else:
    device = torch.device("cpu")
    torch_device = device
    torch_dtype = torch.float16
32
+
33
# --- Pipeline construction -------------------------------------------------
# Both SDXL-Turbo pipelines (image-to-image and text-to-image) share the same
# checkpoint and loading options, so build the kwargs once. Unless
# SAFETY_CHECKER is explicitly "True", pass safety_checker=None to skip the
# extra NSFW-classification pass on every generation.
_pipe_kwargs = {
    "torch_dtype": torch_dtype,
    "variant": "fp16" if torch_dtype == torch.float16 else "fp32",
}
if SAFETY_CHECKER != "True":
    _pipe_kwargs["safety_checker"] = None

i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/sdxl-turbo", **_pipe_kwargs
)
t2i_pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", **_pipe_kwargs
)

# Move both pipelines to the selected device/dtype and silence the per-step
# progress bars, which would otherwise flood the server log on every request.
for _pipe in (t2i_pipe, i2i_pipe):
    _pipe.to(device=torch_device, dtype=torch_dtype).to(device)
    _pipe.set_progress_bar_config(disable=True)
63
+
64
+
65
def resize_crop(image, size=512):
    """Center-crop *image* to a square and resize it to (size, size).

    The original implementation only resized to ``(size, size * h / w)`` and
    never cropped, despite the function's name — producing a non-square image
    that was then fed to a pipeline invoked with explicit ``width=512,
    height=512`` (which distorts the aspect ratio). Cropping to a square
    first keeps the subject's proportions intact.

    Args:
        image: a PIL image (any mode; converted to RGB).
        size: edge length in pixels of the square output (default 512,
            matching the generation resolution used by the pipelines).

    Returns:
        A ``size`` x ``size`` RGB PIL image.
    """
    image = image.convert("RGB")
    w, h = image.size
    if w > h:
        # Wider than tall: trim equal margins from left and right.
        left = (w - h) // 2
        image = image.crop((left, 0, left + h, h))
    elif h > w:
        # Taller than wide: trim equal margins from top and bottom.
        top = (h - w) // 2
        image = image.crop((0, top, w, top + w))
    return image.resize((size, size), Image.BICUBIC)
70
+
71
+
72
async def predict(init_image, prompt, strength, steps, seed=1231231):
    """Generate one image with SDXL-Turbo.

    Routes to the image-to-image pipeline when `init_image` is given,
    otherwise to the text-to-image pipeline. Both run with
    guidance_scale=0.0 (turbo models are distilled for unguided sampling).

    Args:
        init_image: optional PIL image used as the img2img source.
        prompt: text prompt.
        strength: img2img denoising strength in [0, 1] (ignored for t2i).
        steps: number of inference steps (slider value, 1-10).
        seed: RNG seed for reproducible generations.

    Returns:
        A PIL image — either the generated result, or a blank 512x512
        image when the (optional) safety checker flags NSFW content.
    """
    if init_image is not None:
        init_image = resize_crop(init_image)
        generator = torch.manual_seed(seed)
        last_time = time.time()

        # diffusers runs int(steps * strength) actual denoising steps for
        # img2img; bump `steps` so at least one step is executed even at
        # very low strength (strength floored at 0.10 to avoid div-by-zero).
        if int(steps * strength) < 1:
            steps = math.ceil(1 / max(0.10, strength))

        results = i2i_pipe(
            prompt=prompt,
            image=init_image,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=0.0,  # turbo: no classifier-free guidance
            strength=strength,
            width=512,
            height=512,
            output_type="pil",
        )
    else:
        generator = torch.manual_seed(seed)
        last_time = time.time()
        results = t2i_pipe(
            prompt=prompt,
            generator=generator,
            num_inference_steps=steps,
            guidance_scale=0.0,  # turbo: no classifier-free guidance
            width=512,
            height=512,
            output_type="pil",
        )
    print(f"Pipe took {time.time() - last_time} seconds")
    # NOTE(review): membership test relies on the pipeline output behaving
    # like a dict (diffusers BaseOutput); the field is absent when the
    # safety checker is disabled, in which case we default to False.
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        # Surface a toast in the UI and return a blank placeholder image.
        gr.Warning("NSFW content detected.")
        return Image.new("RGB", (512, 512))
    return results.images[0]
114
+
115
+
116
+ css = """
117
+ #container{
118
+ margin: 0 auto;
119
+ max-width: 80rem;
120
+ }
121
+ #intro{
122
+ max-width: 100%;
123
+ text-align: center;
124
+ margin: 0 auto;
125
+ }
126
+ """
127
+ with gr.Blocks(css=css) as demo:
128
+ init_image_state = gr.State()
129
+ with gr.Column(elem_id="container"):
130
+ gr.Markdown(
131
+ """# SDXL Turbo Image to Image/Text to Image
132
+ ## Unofficial Demo
133
+ SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
134
+ **Model**: https://huggingface.co/stabilityai/sdxl-turbo
135
+ """,
136
+ elem_id="intro",
137
+ )
138
+ with gr.Row():
139
+ prompt = gr.Textbox(
140
+ placeholder="Insert your prompt here:",
141
+ scale=5,
142
+ container=False,
143
+ )
144
+ generate_bt = gr.Button("Generate", scale=1)
145
+ with gr.Column():
146
+ image = gr.Image(type="filepath")
147
+ with gr.Accordion("Advanced options", open=False):
148
+ strength = gr.Slider(
149
+ label="Strength",
150
+ value=0.7,
151
+ minimum=0.0,
152
+ maximum=1.0,
153
+ step=0.001,
154
+ )
155
+ steps = gr.Slider(
156
+ label="Steps", value=2, minimum=1, maximum=10, step=1
157
+ )
158
+ seed = gr.Slider(
159
+ randomize=True,
160
+ minimum=0,
161
+ maximum=12013012031030,
162
+ label="Seed",
163
+ step=1,
164
+ )
165
+ with gr.Row():
166
+ with gr.Column():
167
+ image_input = gr.Image(
168
+ sources=["upload", "webcam", "clipboard"],
169
+ label="Webcam",
170
+ type="pil",
171
+ )
172
+
173
+
174
+ inputs = [image_input, prompt, strength, steps, seed]
175
+ generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
176
+ prompt.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
177
+ steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
178
+ seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
179
+ strength.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
180
+ image_input.change(
181
+ fn=lambda x: x,
182
+ inputs=image_input,
183
+ outputs=init_image_state,
184
+ show_progress=False,
185
+ queue=False,
186
+ )
187
+
188
+ demo.queue()
189
+ demo.launch()