radames HF staff commited on
Commit
7c7890b
1 Parent(s): 8908c78
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +139 -0
  3. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ *.py[cod]
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline
2
+ import torch
3
+ import os
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex
7
+ except:
8
+ pass
9
+
10
+ from PIL import Image
11
+ import numpy as np
12
+ import gradio as gr
13
+ import psutil
14
+ import time
15
+
16
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
17
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
18
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
+ # check if MPS is available OSX only M1/M2/M3 chips
20
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
21
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
22
+ device = torch.device(
23
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
24
+ )
25
+ torch_device = device
26
+ torch_dtype = torch.float16
27
+
28
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
29
+ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
30
+ print(f"device: {device}")
31
+
32
+ if mps_available:
33
+ device = torch.device("mps")
34
+ torch_device = "cpu"
35
+ torch_dtype = torch.float32
36
+
37
+ if SAFETY_CHECKER == "True":
38
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", revision="pr/4")
39
+ else:
40
+ pipe = DiffusionPipeline.from_pretrained(
41
+ "stabilityai/sdxl-turbo", revision="pr/4", safety_checker=None
42
+ )
43
+
44
+
45
+ pipe.to(device=torch_device, dtype=torch_dtype).to(device)
46
+ pipe.unet.to(memory_format=torch.channels_last)
47
+ pipe.set_progress_bar_config(disable=True)
48
+
49
+
50
+ def predict(prompt, steps, seed=1231231):
51
+ generator = torch.manual_seed(seed)
52
+ last_time = time.time()
53
+ results = pipe(
54
+ prompt=prompt,
55
+ generator=generator,
56
+ num_inference_steps=steps,
57
+ guidance_scale=0.0,
58
+ width=512,
59
+ height=512,
60
+ # original_inference_steps=params.lcm_steps,
61
+ output_type="pil",
62
+ )
63
+ print(f"Pipe took {time.time() - last_time} seconds")
64
+ nsfw_content_detected = (
65
+ results.nsfw_content_detected[0]
66
+ if "nsfw_content_detected" in results
67
+ else False
68
+ )
69
+ if nsfw_content_detected:
70
+ gr.Warning("NSFW content detected.")
71
+ return Image.new("RGB", (512, 512))
72
+ return results.images[0]
73
+
74
+
75
+ css = """
76
+ #container{
77
+ margin: 0 auto;
78
+ max-width: 40rem;
79
+ }
80
+ #intro{
81
+ max-width: 100%;
82
+ text-align: center;
83
+ margin: 0 auto;
84
+ }
85
+ """
86
+ with gr.Blocks(css=css) as demo:
87
+ with gr.Column(elem_id="container"):
88
+ gr.Markdown(
89
+ """# SDXL Turbo - Text To Image
90
+ ## Unofficial Demo
91
+ SDXL Turbo model can generate high quality images in a single pass read more on [stability.ai post](https://stability.ai/news/stability-ai-sdxl-turbo).
92
+ **Model**: https://huggingface.co/stabilityai/sdxl-turbo
93
+ """,
94
+ elem_id="intro",
95
+ )
96
+ with gr.Row():
97
+ with gr.Row():
98
+ prompt = gr.Textbox(
99
+ placeholder="Insert your prompt here:", scale=5, container=False
100
+ )
101
+ generate_bt = gr.Button("Generate", scale=1)
102
+
103
+ image = gr.Image(type="filepath")
104
+ with gr.Accordion("Advanced options", open=False):
105
+ steps = gr.Slider(label="Steps", value=2, minimum=1, maximum=10, step=1)
106
+ seed = gr.Slider(
107
+ randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
108
+ )
109
+ with gr.Accordion("Run with diffusers"):
110
+ gr.Markdown(
111
+ """## Running SDXL Turbo with `diffusers`
112
+ ```bash
113
+ pip install diffusers==0.23.1
114
+ ```
115
+ ```py
116
+ from diffusers import DiffusionPipeline
117
+
118
+ pipe = DiffusionPipeline.from_pretrained(
119
+ "stabilityai/sdxl-turbo", revision="refs/pr/4"
120
+ ).to("cuda")
121
+ results = pipe(
122
+ prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe",
123
+ num_inference_steps=1,
124
+ guidance_scale=0.0,
125
+ )
126
+ imga = results.images[0]
127
+ imga.save("image.png")
128
+ ```
129
+ """
130
+ )
131
+
132
+ inputs = [prompt, steps, seed]
133
+ generate_bt.click(fn=predict, inputs=inputs, outputs=image, show_progress=False)
134
+ prompt.input(fn=predict, inputs=inputs, outputs=image, show_progress=False)
135
+ steps.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
136
+ seed.change(fn=predict, inputs=inputs, outputs=image, show_progress=False)
137
+
138
+ demo.queue()
139
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers==0.23.1
2
+ transformers
3
+ gradio==4.7.1
4
+ --extra-index-url https://download.pytorch.org/whl/cu121
5
+ torch==2.1.0
6
+ fastapi==0.104.0
7
+ uvicorn==0.23.2
8
+ Pillow==10.1.0
9
+ accelerate==0.24.0
10
+ compel==2.0.2
11
+ controlnet-aux==0.0.7
12
+ peft==0.6.0
13
+ xformers
14
+ hf_transfer