cutycat2000x committed
Commit 7de5b3a
1 Parent(s): 0fd3868

Update app.py

Files changed (1):
  1. app.py +316 -1
app.py CHANGED
@@ -1,3 +1,318 @@
+#!/usr/bin/env python
+
+from __future__ import annotations
+
+import os
+import random
+
 import gradio as gr
+import numpy as np
+import PIL.Image
+import spaces
+import torch
+from diffusers import AutoencoderKL, DiffusionPipeline
+
+DESCRIPTION = """
+# InterDiffusion-3.5
+
+**Demo by [cutycat2000x](https://huggingface.co/cutycat2000x) on Hugging Face**
+
+This is a demo of <a href="https://huggingface.co/cutycat2000x/InterDiffusion-3.5">InterDiffusion-3.5</a> by @cutycat2000x.
+
+**The code for this demo is based on [@hysts's SD-XL demo](https://huggingface.co/spaces/hysts/SD-XL) running on an A10G GPU.**
+"""
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<h1>Running on CPU 🥶 This demo does not work on CPU.</h1>"
+
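+# Runtime knobs, each overridable via environment variables on the Space.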
+MAX_SEED = np.iinfo(np.int32).max
+CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
+MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2024"))
+USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
+ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
+ENABLE_REFINER = os.getenv("ENABLE_REFINER", "0") == "1"
+
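+# Weights are loaded only when CUDA is available; the fp16-fix VAE avoids
+# NaN artifacts when decoding SDXL latents in float16.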
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+    pipe = DiffusionPipeline.from_pretrained(
+        "cutycat2000x/InterDiffusion-3.5",
+        vae=vae,
+        torch_dtype=torch.float16,
+        # variant="fp16",
+    )
+    if ENABLE_REFINER:
+        refiner = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-refiner-1.0",
+            vae=vae,
+            torch_dtype=torch.float16,
+            # variant="fp16",
+        )
+
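+    # CPU offload streams submodules to the GPU on demand to save VRAM;
+    # otherwise both pipelines stay fully resident on the device.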
+    if ENABLE_CPU_OFFLOAD:
+        pipe.enable_model_cpu_offload()
+        if ENABLE_REFINER:
+            refiner.enable_model_cpu_offload()
+    else:
+        pipe.to(device)
+        if ENABLE_REFINER:
+            refiner.to(device)
+
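+    # torch.compile trades a one-time warmup compile for faster UNet steps.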
+    if USE_TORCH_COMPILE:
+        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+        if ENABLE_REFINER:
+            refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
+
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
+
+
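+# On ZeroGPU Spaces, @spaces.GPU reserves a GPU for each call; it is a no-op
+# on dedicated GPU hardware.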
+@spaces.GPU(enable_queue=True)
+def generate(
+    prompt: str,
+    negative_prompt: str = "",
+    prompt_2: str = "",
+    negative_prompt_2: str = "",
+    use_negative_prompt: bool = False,
+    use_prompt_2: bool = False,
+    use_negative_prompt_2: bool = False,
+    seed: int = 0,
+    width: int = 1024,
+    height: int = 1024,
+    guidance_scale_base: float = 5.0,
+    guidance_scale_refiner: float = 7.0,
+    num_inference_steps_base: int = 60,
+    num_inference_steps_refiner: int = 35,
+    apply_refiner: bool = False,
+    progress=gr.Progress(track_tqdm=True),
+) -> PIL.Image.Image:
+    print(f"** Generating image for: \"{prompt}\" **")
+    generator = torch.Generator().manual_seed(seed)
+
+    if not use_negative_prompt:
+        negative_prompt = None  # type: ignore
+    if not use_prompt_2:
+        prompt_2 = None  # type: ignore
+    if not use_negative_prompt_2:
+        negative_prompt_2 = None  # type: ignore
+
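+    # Single-stage path decodes straight to a PIL image; the two-stage path
+    # keeps base-model latents and hands them to the SDXL refiner.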
+    if not apply_refiner:
+        return pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_2=prompt_2,
+            negative_prompt_2=negative_prompt_2,
+            width=width,
+            height=height,
+            guidance_scale=guidance_scale_base,
+            num_inference_steps=num_inference_steps_base,
+            generator=generator,
+            output_type="pil",
+        ).images[0]
+    else:
+        latents = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_2=prompt_2,
+            negative_prompt_2=negative_prompt_2,
+            width=width,
+            height=height,
+            guidance_scale=guidance_scale_base,
+            num_inference_steps=num_inference_steps_base,
+            generator=generator,
+            output_type="latent",
+        ).images
+        image = refiner(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            prompt_2=prompt_2,
+            negative_prompt_2=negative_prompt_2,
+            guidance_scale=guidance_scale_refiner,
+            num_inference_steps=num_inference_steps_refiner,
+            image=latents,
+            generator=generator,
+        ).images[0]
+        return image
+
+
+examples = [
+    "A realistic photograph of an astronaut in a jungle, cold color palette, detailed, 8k",
+    "An astronaut riding a green horse",
+]
+
+theme = gr.themes.Base(
+    font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
+)
+with gr.Blocks(css="footer{display:none !important}", theme=theme) as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(
+        value="Duplicate Space for private use",
+        elem_id="duplicate-button",
+        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+    )
+    with gr.Group():
+        prompt = gr.Text(
+            label="Prompt",
+            show_label=False,
+            max_lines=1,
+            container=False,
+            placeholder="Enter your prompt",
+        )
+        run_button = gr.Button("Generate")
+    result = gr.Image(label="Result", show_label=False)
+    with gr.Accordion("Advanced options", open=False):
+        with gr.Row():
+            use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
+            use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
+            use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
+        negative_prompt = gr.Text(
+            label="Negative prompt",
+            max_lines=1,
+            placeholder="Enter a negative prompt",
+            visible=False,
+        )
+        prompt_2 = gr.Text(
+            label="Prompt 2",
+            max_lines=1,
+            placeholder="Enter your prompt",
+            visible=False,
+        )
+        negative_prompt_2 = gr.Text(
+            label="Negative prompt 2",
+            max_lines=1,
+            placeholder="Enter a negative prompt",
+            visible=False,
+        )
+
+        seed = gr.Slider(
+            label="Seed",
+            minimum=0,
+            maximum=MAX_SEED,
+            step=1,
+            value=0,
+        )
+        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+        with gr.Row():
+            width = gr.Slider(
+                label="Width",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=1024,
+            )
+            height = gr.Slider(
+                label="Height",
+                minimum=256,
+                maximum=MAX_IMAGE_SIZE,
+                step=32,
+                value=1024,
+            )
+        apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
+        with gr.Row():
+            guidance_scale_base = gr.Slider(
+                label="Guidance scale for base",
+                minimum=1,
+                maximum=20,
+                step=0.1,
+                value=7.5,
+            )
+            num_inference_steps_base = gr.Slider(
+                label="Number of inference steps for base",
+                minimum=10,
+                maximum=100,
+                step=1,
+                value=60,
+            )
+        with gr.Row(visible=False) as refiner_params:
+            guidance_scale_refiner = gr.Slider(
+                label="Guidance scale for refiner",
+                minimum=1,
+                maximum=20,
+                step=0.1,
+                value=7.5,
+            )
+            num_inference_steps_refiner = gr.Slider(
+                label="Number of inference steps for refiner",
+                minimum=10,
+                maximum=100,
+                step=1,
+                value=30,
+            )
+
+    gr.Examples(
+        examples=examples,
+        inputs=prompt,
+        outputs=result,
+        fn=generate,
+        cache_examples=CACHE_EXAMPLES,
+    )
+
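+    # The checkboxes only toggle widget visibility, so they bypass the queue
+    # and are hidden from the API.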
+    use_negative_prompt.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_negative_prompt,
+        outputs=negative_prompt,
+        queue=False,
+        api_name=False,
+    )
+    use_prompt_2.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_prompt_2,
+        outputs=prompt_2,
+        queue=False,
+        api_name=False,
+    )
+    use_negative_prompt_2.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=use_negative_prompt_2,
+        outputs=negative_prompt_2,
+        queue=False,
+        api_name=False,
+    )
+    apply_refiner.change(
+        fn=lambda x: gr.update(visible=x),
+        inputs=apply_refiner,
+        outputs=refiner_params,
+        queue=False,
+        api_name=False,
+    )
+
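+    # Any submit or click first resolves the seed, then chains into generate().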
+    gr.on(
+        triggers=[
+            prompt.submit,
+            negative_prompt.submit,
+            prompt_2.submit,
+            negative_prompt_2.submit,
+            run_button.click,
+        ],
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=seed,
+        queue=False,
+        api_name=False,
+    ).then(
+        fn=generate,
+        inputs=[
+            prompt,
+            negative_prompt,
+            prompt_2,
+            negative_prompt_2,
+            use_negative_prompt,
+            use_prompt_2,
+            use_negative_prompt_2,
+            seed,
+            width,
+            height,
+            guidance_scale_base,
+            guidance_scale_refiner,
+            num_inference_steps_base,
+            num_inference_steps_refiner,
+            apply_refiner,
+        ],
+        outputs=result,
+        api_name="run",
+    )
 
-gr.load("models/cutycat2000x/InterDiffusion-3.5").launch()
+if __name__ == "__main__":
+    demo.queue(max_size=20, api_open=False).launch(show_api=False)