kadirnar commited on
Commit
c19bd70
1 Parent(s): 7cbd2bd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import random
4
+ import uuid
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ from PIL import Image
9
+ import spaces
10
+ import torch
11
+ from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler, AutoencoderKL
12
+
13
+ DESCRIPTION = """# Stable Diffusion 3"""
14
+ if not torch.cuda.is_available():
15
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
16
+
17
+ MAX_SEED = np.iinfo(np.int32).max
18
+ CACHE_EXAMPLES = False
19
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
20
+ USE_TORCH_COMPILE = False
21
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
22
+
23
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
25
+
26
+ if torch.cuda.is_available():
27
+ pipe = StableDiffusionXLPipeline.from_single_file(
28
+ "https://huggingface.co/kadirnar/Black-Hole/blob/main/tachyon.safetensors",
29
+ torch_dtype=torch.float16,
30
+ use_safetensors=True,
31
+ add_watermarker=False,
32
+ variant="fp16",
33
+ vae=vae,
34
+ )
35
+ if ENABLE_CPU_OFFLOAD:
36
+ pipe.enable_model_cpu_offload()
37
+ else:
38
+ pipe.to(device)
39
+ print("Loaded on Device!")
40
+
41
+ if USE_TORCH_COMPILE:
42
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
43
+ print("Model Compiled!")
44
+
45
+
46
+ def save_image(img):
47
+ unique_name = str(uuid.uuid4()) + ".png"
48
+ img.save(unique_name)
49
+ return unique_name
50
+
51
+
52
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
53
+ if randomize_seed:
54
+ seed = random.randint(0, MAX_SEED)
55
+ return seed
56
+
57
+
58
+ @spaces.GPU(enable_queue=True)
59
+ def generate(
60
+ prompt: str,
61
+ negative_prompt: str = "",
62
+ use_negative_prompt: bool = False,
63
+ seed: int = 0,
64
+ width: int = 1024,
65
+ height: int = 1024,
66
+ guidance_scale: float = 3,
67
+ randomize_seed: bool = False,
68
+ num_inference_steps=5,
69
+ NUM_IMAGES_PER_PROMPT=1,
70
+ use_resolution_binning: bool = True,
71
+ progress=gr.Progress(track_tqdm=True),
72
+ ):
73
+ pipe.to(device)
74
+ seed = int(randomize_seed_fn(seed, randomize_seed))
75
+ generator = torch.Generator().manual_seed(seed)
76
+ sampling_schedule = [999, 845, 730, 587, 443, 310, 193, 116, 53, 13, 0]
77
+ #pipe.scheduler = DPMSolverSinglestepScheduler(use_karras_sigmas=True).from_config(pipe.scheduler.config)
78
+ #pipe.scheduler = DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++").from_config(pipe.scheduler.config)
79
+
80
+ if not use_negative_prompt:
81
+ negative_prompt = None # type: ignore
82
+
83
+ output = pipe(
84
+ prompt=prompt,
85
+ negative_prompt=negative_prompt,
86
+ width=width,
87
+ height=height,
88
+ guidance_scale=guidance_scale,
89
+ num_inference_steps=num_inference_steps,
90
+ generator=generator,
91
+ num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
92
+ use_resolution_binning=use_resolution_binning,
93
+ output_type="pil",
94
+ ).images
95
+
96
+ return output
97
+
98
+
99
+ examples = [
100
+ "neon holography crystal cat",
101
+ "a cat eating a piece of cheese",
102
+ "an astronaut riding a horse in space",
103
+ "a cartoon of a boy playing with a tiger",
104
+ "a cute robot artist painting on an easel, concept art",
105
+ "a close up of a woman wearing a transparent, prismatic, elaborate nemeses headdress, over the should pose, brown skin-tone"
106
+ ]
107
+
108
+ css = '''
109
+ .gradio-container{max-width: 1000px !important}
110
+ h1{text-align:center}
111
+ '''
112
+ with gr.Blocks(css=css) as demo:
113
+ with gr.Row():
114
+ with gr.Column():
115
+ gr.HTML(
116
+ """
117
+ <h1 style='text-align: center'>
118
+ Fast Tachyon SDXL
119
+ </h1>
120
+ """
121
+ )
122
+ gr.HTML(
123
+ """
124
+ <h3 style='text-align: center'>
125
+ Follow me for more!
126
+ <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>
127
+ </h3>
128
+ """
129
+ )
130
+ with gr.Group():
131
+ with gr.Row():
132
+ prompt = gr.Text(
133
+ label="Prompt",
134
+ show_label=False,
135
+ max_lines=1,
136
+ placeholder="Enter your prompt",
137
+ container=False,
138
+ )
139
+ run_button = gr.Button("Run", scale=0)
140
+ result = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
141
+ with gr.Accordion("Advanced options", open=False):
142
+ with gr.Row():
143
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
144
+ negative_prompt = gr.Text(
145
+ label="Negative prompt",
146
+ max_lines=1,
147
+ value = "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
148
+ visible=True,
149
+ )
150
+ seed = gr.Slider(
151
+ label="Seed",
152
+ minimum=0,
153
+ maximum=MAX_SEED,
154
+ step=1,
155
+ value=0,
156
+ )
157
+
158
+ steps = gr.Slider(
159
+ label="Steps",
160
+ minimum=0,
161
+ maximum=15,
162
+ step=1,
163
+ value=4,
164
+ )
165
+ number_image = gr.Slider(
166
+ label="Number of Image",
167
+ minimum=1,
168
+ maximum=4,
169
+ step=1,
170
+ value=1,
171
+ )
172
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
173
+ with gr.Row(visible=True):
174
+ width = gr.Slider(
175
+ label="Width",
176
+ minimum=256,
177
+ maximum=MAX_IMAGE_SIZE,
178
+ step=32,
179
+ value=1024,
180
+ )
181
+ height = gr.Slider(
182
+ label="Height",
183
+ minimum=256,
184
+ maximum=MAX_IMAGE_SIZE,
185
+ step=32,
186
+ value=1024,
187
+ )
188
+ with gr.Row():
189
+ guidance_scale = gr.Slider(
190
+ label="Guidance Scale",
191
+ minimum=0.1,
192
+ maximum=10,
193
+ step=0.1,
194
+ value=2.0,
195
+ )
196
+
197
+ gr.Examples(
198
+ examples=examples,
199
+ inputs=prompt,
200
+ outputs=[result],
201
+ fn=generate,
202
+ cache_examples=CACHE_EXAMPLES,
203
+ )
204
+
205
+ use_negative_prompt.change(
206
+ fn=lambda x: gr.update(visible=x),
207
+ inputs=use_negative_prompt,
208
+ outputs=negative_prompt,
209
+ api_name=False,
210
+ )
211
+
212
+ gr.on(
213
+ triggers=[
214
+ prompt.submit,
215
+ negative_prompt.submit,
216
+ run_button.click,
217
+ ],
218
+ fn=generate,
219
+ inputs=[
220
+ prompt,
221
+ negative_prompt,
222
+ use_negative_prompt,
223
+ seed,
224
+ width,
225
+ height,
226
+ guidance_scale,
227
+ randomize_seed,
228
+ steps,
229
+ number_image,
230
+ ],
231
+ outputs=[result],
232
+ api_name="run",
233
+ )
234
+
235
+ if __name__ == "__main__":
236
+ demo.queue().launch()