multimodalart (HF staff) committed
Commit cc54eed
1 Parent(s): 9d80ec5

Create app.py

Files changed (1)
  1. app.py +276 -0
app.py ADDED
@@ -0,0 +1,276 @@
import os
import random
import gradio as gr
import numpy as np
import PIL.Image
import torch
from typing import List
from diffusers.utils import numpy_to_pil
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

import user_history
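
# Runtime configuration, mostly read from environment variables
# (CACHE_EXAMPLES, MAX_IMAGE_SIZE, ENABLE_CPU_OFFLOAD).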
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

DESCRIPTION = "# Stable Cascade"
#DESCRIPTION += "\n<p style=\"text-align: center\"><a href='https://huggingface.co/warp-ai/wuerstchen' target='_blank'>Würstchen</a> is a new fast and efficient high resolution text-to-image architecture and model</p>"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"

MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
USE_TORCH_COMPILE = False
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
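
# Load the Stable Cascade prior (Stage C) and decoder (Stages B and A)
# pipelines; generation is only enabled when a GPU is available.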
dtype = torch.float16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    prior_pipeline = StableCascadePriorPipeline.from_pretrained("diffusers/StableCascade-prior", torch_dtype=torch.bfloat16).to("cuda")
    decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16).to("cuda")

    if ENABLE_CPU_OFFLOAD:
        prior_pipeline.enable_model_cpu_offload()
        decoder_pipeline.enable_model_cpu_offload()
    else:
        prior_pipeline.to(device)
        decoder_pipeline.to(device)

    if USE_TORCH_COMPILE:
        prior_pipeline.prior = torch.compile(prior_pipeline.prior, mode="reduce-overhead", fullgraph=True)
        decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="reduce-overhead", fullgraph=True)

    # Latent previewing is disabled, so previewer and callback_prior stay None.
    #if PREVIEW_IMAGES:
    #    previewer = Previewer()
    #    previewer.load_state_dict(torch.load("previewer/text2img_wurstchen_b_v1_previewer_100k.pt")["state_dict"])
    #    previewer.eval().requires_grad_(False).to(device).to(dtype)
    #
    #    def callback_prior(i, t, latents):
    #        output = previewer(latents)
    #        output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
    #        return output

    previewer = None
    callback_prior = None
else:
    prior_pipeline = None
    decoder_pipeline = None
61
+
62
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
63
+ if randomize_seed:
64
+ seed = random.randint(0, MAX_SEED)
65
+ return seed
66
+
67
+
68
+ def generate(
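# Two-stage generation: the prior turns the prompt into image embeddings,
# then the decoder turns those embeddings into the final PIL images, which
# are also saved to the user's history.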
69
+ prompt: str,
70
+ negative_prompt: str = "",
71
+ seed: int = 0,
72
+ width: int = 1024,
73
+ height: int = 1024,
74
+ prior_num_inference_steps: int = 60,
75
+ # prior_timesteps: List[float] = None,
76
+ prior_guidance_scale: float = 4.0,
77
+ decoder_num_inference_steps: int = 12,
78
+ # decoder_timesteps: List[float] = None,
79
+ decoder_guidance_scale: float = 0.0,
80
+ num_images_per_prompt: int = 2,
81
+ profile: gr.OAuthProfile | None = None,
82
+ ) -> PIL.Image.Image:
83
+ generator = torch.Generator().manual_seed(seed)
84
+
85
+ prior_output = prior_pipeline(
86
+ prompt=prompt,
87
+ height=height,
88
+ width=width,
89
+ timesteps=DEFAULT_STAGE_C_TIMESTEPS,
90
+ negative_prompt=negative_prompt,
91
+ guidance_scale=prior_guidance_scale,
92
+ num_images_per_prompt=num_images_per_prompt,
93
+ generator=generator,
94
+ callback=callback_prior,
95
+ )
96
+
97
+ #if PREVIEW_IMAGES:
98
+ # for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
99
+ # r = next(prior_output)
100
+ # if isinstance(r, list):
101
+ # yield r
102
+ # prior_output = r
103
+
104
+ decoder_output = decoder_pipeline(
105
+ image_embeddings=prior_output.image_embeddings,
106
+ prompt=prompt,
107
+ num_inference_steps=decoder_num_inference_steps,
108
+ # timesteps=decoder_timesteps,
109
+ guidance_scale=decoder_guidance_scale,
110
+ negative_prompt=negative_prompt,
111
+ generator=generator,
112
+ output_type="pil",
113
+ ).images
114
+
115
+ # Save images
116
+ for image in decoder_output:
117
+ user_history.save_image(
118
+ profile=profile,
119
+ image=image,
120
+ label=prompt,
121
+ metadata={
122
+ "negative_prompt": negative_prompt,
123
+ "seed": seed,
124
+ "width": width,
125
+ "height": height,
126
+ "prior_guidance_scale": prior_guidance_scale,
127
+ "decoder_num_inference_steps": decoder_num_inference_steps,
128
+ "decoder_guidance_scale": decoder_guidance_scale,
129
+ "num_images_per_prompt": num_images_per_prompt,
130
+ },
131
+ )
132
+
133
+ yield decoder_output
134
+
135
+
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
]

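# Gradio UI: prompt box and result gallery, with generation settings in an
# "Advanced options" accordion.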
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", show_label=False)
    with gr.Accordion("Advanced options", open=False):
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a Negative Prompt",
        )

        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        with gr.Row():
            width = gr.Slider(
                label="Width",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            height = gr.Slider(
                label="Height",
                minimum=1024,
                maximum=MAX_IMAGE_SIZE,
                step=512,
                value=1024,
            )
            num_images_per_prompt = gr.Slider(
                label="Number of Images",
                minimum=1,
                maximum=2,
                step=1,
                value=2,
            )
        with gr.Row():
            prior_guidance_scale = gr.Slider(
                label="Prior Guidance Scale",
                minimum=0,
                maximum=20,
                step=0.1,
                value=4.0,
            )
            prior_num_inference_steps = gr.Slider(
                label="Prior Inference Steps",
                minimum=30,
                maximum=30,
                step=1,
                value=30,
            )

            decoder_guidance_scale = gr.Slider(
                label="Decoder Guidance Scale",
                minimum=0,
                maximum=0,
                step=0.1,
                value=0.0,
            )
            decoder_num_inference_steps = gr.Slider(
                label="Decoder Inference Steps",
                minimum=4,
                maximum=12,
                step=1,
                value=12,
            )

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=result,
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

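    # Event wiring: on prompt submit or Run click, optionally randomize the
    # seed, then call generate() and display the results in the gallery.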
    inputs = [
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        prior_num_inference_steps,
        # prior_timesteps,
        prior_guidance_scale,
        decoder_num_inference_steps,
        # decoder_timesteps,
        decoder_guidance_scale,
        num_images_per_prompt,
    ]
    gr.on(
        [prompt.submit, negative_prompt.submit, run_button.click],
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=inputs,
        outputs=result,
        api_name="run",
    )

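# Wrap the demo with a "Past generations" tab backed by user_history.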
with gr.Blocks(css="style.css") as demo_with_history:
    with gr.Tab("App"):
        demo.render()
    with gr.Tab("Past generations"):
        user_history.render()

if __name__ == "__main__":
    demo_with_history.queue(max_size=20).launch()

# Minimal two-stage usage sketch (not executed by the app above): run the
# prior, then feed its image embeddings to the decoder to get PIL images.
# prior_output = prior_pipeline(prompt)
# images = decoder_pipeline(prompt=prompt,
#                           image_embeddings=prior_output.image_embeddings).images
# images[0]