radames committed
Commit de9d198
1 Parent(s): 3a38380
Files changed (3)
  1. .gitignore +2 -0
  2. app.py +141 -0
  3. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ venv
app.py ADDED
@@ -0,0 +1,141 @@
+ from diffusers import DiffusionPipeline, LCMScheduler, AutoencoderTiny
+ from compel import Compel, ReturnedEmbeddingsType
+ import torch
+ import os
+
+ try:
+     import intel_extension_for_pytorch as ipex
+ except ImportError:
+     pass
+
+ from PIL import Image
+ import numpy as np
+ import gradio as gr
+ import psutil
+
+
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ # Check whether MPS is available (macOS only, Apple Silicon M1/M2/M3 chips)
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
+ device = torch.device(
+     "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
+ )
+ torch_device = device
+ torch_dtype = torch.float16
+
+ print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
+ print(f"TORCH_COMPILE: {TORCH_COMPILE}")
+ print(f"device: {device}")
+
+ if mps_available:
+     device = torch.device("mps")
+     torch_device = "cpu"
+     torch_dtype = torch.float32
+
+ model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+
+ if SAFETY_CHECKER == "True":
+     pipe = DiffusionPipeline.from_pretrained(model_id)
+ else:
+     pipe = DiffusionPipeline.from_pretrained(model_id, safety_checker=None)
+
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ pipe.to(device=torch_device, dtype=torch_dtype).to(device)
+ pipe.unet.to(memory_format=torch.channels_last)
+
+ # Enable attention slicing when the machine has less than 64GB of RAM
+ if psutil.virtual_memory().total < 64 * 1024**3:
+     pipe.enable_attention_slicing()
+
+ if TORCH_COMPILE:
+     pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+     pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
+
+     pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
+
+ # Load LCM LoRA
+ pipe.load_lora_weights(
+     "lcm-sd/lcm-sdxl-lora",
+     weight_name="lcm_sdxl_lora.safetensors",
+     adapter_name="lcm",
+     token=HF_TOKEN,
+ )
+
+ compel_proc = Compel(
+     tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+     text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
+     returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
+     requires_pooled=[False, True],
+ )
+
+
+ def predict(
+     prompt, guidance, steps, seed=1231231, progress=gr.Progress(track_tqdm=True)
+ ):
+     generator = torch.manual_seed(seed)
+     prompt_embeds, pooled_prompt_embeds = compel_proc(prompt)
+
+     results = pipe(
+         prompt_embeds=prompt_embeds,
+         pooled_prompt_embeds=pooled_prompt_embeds,
+         generator=generator,
+         num_inference_steps=steps,
+         guidance_scale=guidance,
+         width=1024,
+         height=1024,
+         # original_inference_steps=params.lcm_steps,
+         output_type="pil",
+     )
+     nsfw_content_detected = (
+         results.nsfw_content_detected[0]
+         if "nsfw_content_detected" in results
+         else False
+     )
+     if nsfw_content_detected:
+         raise gr.Error("NSFW content detected.")
+     return results.images[0]
+
+
+ css = """
+ #container{
+     margin: 0 auto;
+     max-width: 50rem;
+ }
+ #intro{
+     max-width: 32rem;
+     text-align: center;
+     margin: 0 auto;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="container"):
+         gr.Markdown(
+             """# Ultra-Fast SDXL with LoRAs borrowed from Latent Consistency Models
+
+             """,
+             elem_id="intro",
+         )
+         with gr.Row():
+             with gr.Row():
+                 prompt = gr.Textbox(
+                     placeholder="Insert your prompt here", scale=5, container=False
+                 )
+                 generate_bt = gr.Button("Generate", scale=1)
+         with gr.Accordion("Advanced options", open=False):
+             guidance = gr.Slider(
+                 label="Guidance", minimum=0.0, maximum=5, value=0.3, step=0.001
+             )
+             steps = gr.Slider(label="Steps", value=4, minimum=2, maximum=10, step=1)
+             seed = gr.Slider(
+                 randomize=True, minimum=0, maximum=12013012031030, label="Seed"
+             )
+         image = gr.Image(type="filepath")
+
+         inputs = [prompt, guidance, steps, seed]
+         generate_bt.click(fn=predict, inputs=inputs, outputs=image)
+
+ demo.queue()
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ # diffusers==0.22.2
+ git+https://github.com/huggingface/diffusers.git@6110d7c95f630479cf01340cc8a8141c1e359f09
+ transformers==4.34.1
+ gradio==4.1.2
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ torch==2.1.0
+ fastapi==0.104.0
+ uvicorn==0.23.2
+ Pillow==10.1.0
+ accelerate==0.24.0
+ compel==2.0.2
+ controlnet-aux==0.0.7
+ peft==0.6.0
+ bitsandbytes
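
Usage note: with the dependencies above installed and app.py running, the Gradio endpoint added in this commit can also be called programmatically. A minimal sketch using gradio_client; the local URL and the "/predict" api_name are assumptions, since the commit does not set an explicit api_name on the click handler:

    # Hypothetical client call, assuming the demo serves at the default local
    # address and the click handler is exposed under api_name "/predict".
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")
    image_path = client.predict(
        "a photo of an astronaut riding a horse",  # prompt
        0.3,       # guidance
        4,         # steps
        1231231,   # seed
        api_name="/predict",
    )
    print(image_path)  # gr.Image(type="filepath") returns a path to the generated image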