yujincheng08 commited on
Commit
f375adb
1 Parent(s): 9484cfb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import random
7
+ import uuid
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+
14
+ from diffusers import AutoencoderKL, PixArtAlphaPipeline
15
+
16
+ DESCRIPTION = """# PixArt-Alpha 1024
17
+ #### [PixArt-Alpha 1024](https://github.com/PixArt-alpha/PixArt-alpha) is a transformer-based text-to-image diffusion system trained on text embeddings from T5. This demo uses the [PixArt-alpha/PixArt-XL-2-1024-MS](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) checkpoint.
18
+ """
19
+ if not torch.cuda.is_available():
20
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
21
+
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
24
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
25
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "1") == "1"
26
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
27
+
28
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29
+
30
+ style_list = [
31
+ {
32
+ "name": "(No style)",
33
+ "prompt": "{prompt}",
34
+ "negative_prompt": "",
35
+ },
36
+ {
37
+ "name": "Cinematic",
38
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
39
+ "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
40
+ },
41
+ {
42
+ "name": "Photographic",
43
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
44
+ "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
45
+ },
46
+ {
47
+ "name": "Anime",
48
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
49
+ "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
50
+ },
51
+ {
52
+ "name": "Manga",
53
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
54
+ "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
55
+ },
56
+ {
57
+ "name": "Digital Art",
58
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
59
+ "negative_prompt": "photo, photorealistic, realism, ugly",
60
+ },
61
+ {
62
+ "name": "Pixel art",
63
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
64
+ "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
65
+ },
66
+ {
67
+ "name": "Fantasy art",
68
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
69
+ "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
70
+ },
71
+ {
72
+ "name": "Neonpunk",
73
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
74
+ "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
75
+ },
76
+ {
77
+ "name": "3D Model",
78
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
79
+ "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
80
+ },
81
+ ]
82
+
83
+ styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
84
+ STYLE_NAMES = list(styles.keys())
85
+ DEFAULT_STYLE_NAME = "Cinematic"
86
+
87
+
88
+ def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
89
+ p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
90
+ if not negative:
91
+ negative = ""
92
+ return p.replace("{prompt}", positive), n + negative
93
+
94
+
95
+ if torch.cuda.is_available():
96
+ pipe = PixArtAlphaPipeline.from_pretrained(
97
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
98
+ torch_dtype=torch.float16,
99
+ use_safetensors=True,
100
+ )
101
+
102
+ if ENABLE_CPU_OFFLOAD:
103
+ pipe.enable_model_cpu_offload()
104
+ else:
105
+ pipe.to(device)
106
+ print("Loaded on Device!")
107
+
108
+ if USE_TORCH_COMPILE:
109
+ pipe.transformer = torch.compile(
110
+ pipe.transformer, mode="reduce-overhead", fullgraph=True
111
+ )
112
+ print("Model Compiled!")
113
+
114
+
115
+ def save_image(img):
116
+ unique_name = str(uuid.uuid4()) + ".png"
117
+ img.save(unique_name)
118
+ return unique_name
119
+
120
+
121
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
122
+ if randomize_seed:
123
+ seed = random.randint(0, MAX_SEED)
124
+ return seed
125
+
126
+
127
+ def generate(
128
+ prompt: str,
129
+ negative_prompt: str = "",
130
+ style: str = DEFAULT_STYLE_NAME,
131
+ use_negative_prompt: bool = False,
132
+ seed: int = 0,
133
+ width: int = 1024,
134
+ height: int = 1024,
135
+ guidance_scale: float = 4.5,
136
+ num_inference_steps: int = 20,
137
+ randomize_seed: bool = False,
138
+ progress=gr.Progress(track_tqdm=True),
139
+ ):
140
+ seed = randomize_seed_fn(seed, randomize_seed)
141
+ generator = torch.Generator().manual_seed(seed)
142
+
143
+ if not use_negative_prompt:
144
+ negative_prompt = None # type: ignore
145
+ prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
146
+ image = pipe(
147
+ prompt=prompt,
148
+ negative_prompt=negative_prompt,
149
+ width=width,
150
+ height=height,
151
+ guidance_scale=guidance_scale,
152
+ num_inference_steps=num_inference_steps,
153
+ generator=generator,
154
+ output_type="pil",
155
+ ).images[0]
156
+
157
+ image_path = save_image(image)
158
+ print(image_path)
159
+ return [image_path], seed
160
+
161
+
162
+ examples = [
163
+ "3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background",
164
+ "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background",
165
+ "professional portrait photo of an anthropomorphic cat wearing fancy gentleman hat and jacket walking in autumn forest.",
166
+ "an astronaut sitting in a diner, eating fries, cinematic, analog film",
167
+ "Albert Einstein in a surrealist Cyberpunk 2077 world, hyperrealistic",
168
+ "cinematic film still of Futuristic hero with golden dark armour with machine gun, muscular body",
169
+ ]
170
+
171
+ with gr.Blocks(css="style.css") as demo:
172
+ gr.Markdown(DESCRIPTION)
173
+ gr.DuplicateButton(
174
+ value="Duplicate Space for private use",
175
+ elem_id="duplicate-button",
176
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
177
+ )
178
+ with gr.Group():
179
+ with gr.Row():
180
+ prompt = gr.Text(
181
+ label="Prompt",
182
+ show_label=False,
183
+ max_lines=1,
184
+ placeholder="Enter your prompt",
185
+ container=False,
186
+ )
187
+ run_button = gr.Button("Run", scale=0)
188
+ result = gr.Gallery(label="Result", columns=1, show_label=False)
189
+ with gr.Accordion("Advanced options", open=False):
190
+ with gr.Row():
191
+ use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
192
+ style_selection = gr.Radio(
193
+ show_label=True,
194
+ container=True,
195
+ interactive=True,
196
+ choices=STYLE_NAMES,
197
+ value=DEFAULT_STYLE_NAME,
198
+ label="Image Style",
199
+ )
200
+ negative_prompt = gr.Text(
201
+ label="Negative prompt",
202
+ max_lines=1,
203
+ placeholder="Enter a negative prompt",
204
+ visible=False,
205
+ )
206
+ seed = gr.Slider(
207
+ label="Seed",
208
+ minimum=0,
209
+ maximum=MAX_SEED,
210
+ step=1,
211
+ value=0,
212
+ )
213
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
214
+ with gr.Row(visible=False):
215
+ width = gr.Slider(
216
+ label="Width",
217
+ minimum=256,
218
+ maximum=MAX_IMAGE_SIZE,
219
+ step=32,
220
+ value=1024,
221
+ )
222
+ height = gr.Slider(
223
+ label="Height",
224
+ minimum=256,
225
+ maximum=MAX_IMAGE_SIZE,
226
+ step=32,
227
+ value=1024,
228
+ )
229
+ with gr.Row():
230
+ guidance_scale = gr.Slider(
231
+ label="Guidance scale",
232
+ minimum=1,
233
+ maximum=20,
234
+ step=0.1,
235
+ value=4.5,
236
+ )
237
+ num_inference_steps = gr.Slider(
238
+ label="Number of inference steps",
239
+ minimum=10,
240
+ maximum=100,
241
+ step=1,
242
+ value=20,
243
+ )
244
+
245
+ gr.Examples(
246
+ examples=examples,
247
+ inputs=prompt,
248
+ outputs=[result, seed],
249
+ fn=generate,
250
+ cache_examples=CACHE_EXAMPLES,
251
+ )
252
+
253
+ use_negative_prompt.change(
254
+ fn=lambda x: gr.update(visible=x),
255
+ inputs=use_negative_prompt,
256
+ outputs=negative_prompt,
257
+ queue=False,
258
+ api_name=False,
259
+ )
260
+
261
+ gr.on(
262
+ triggers=[
263
+ prompt.submit,
264
+ negative_prompt.submit,
265
+ run_button.click,
266
+ ],
267
+ fn=generate,
268
+ inputs=[
269
+ prompt,
270
+ negative_prompt,
271
+ style_selection,
272
+ use_negative_prompt,
273
+ seed,
274
+ width,
275
+ height,
276
+ guidance_scale,
277
+ num_inference_steps,
278
+ randomize_seed,
279
+ ],
280
+ outputs=[result, seed],
281
+ api_name="run",
282
+ )
283
+
284
+ if __name__ == "__main__":
285
+ demo.queue(max_size=20).launch()