Niansuh committed on
Commit 51a92d7
1 Parent(s): ce2d932

Update app.py

Files changed (1): app.py (+59, -39)
app.py CHANGED
@@ -10,27 +10,32 @@ import spaces
 import torch
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
-if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo may not work on CPU.</p>"
-
-MAX_SEED = np.iinfo(np.int32).max
-CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
+# Use environment variables for flexibility
+MODEL_ID = os.getenv("MODEL_ID", "sd-community/sdxl-flash")
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
+BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # Allow generating multiple images at once
 
+# Determine the device and load the model once at import time for efficiency
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-if torch.cuda.is_available():
-    pipe = StableDiffusionXLPipeline.from_pretrained(
-        "sd-community/sdxl-flash",
-        torch_dtype=torch.float16,
-        use_safetensors=True,
-        add_watermarker=False
-    )
-    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
-    pipe.to("cuda")
-
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    MODEL_ID,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    use_safetensors=True,
+    add_watermarker=False,
+).to(device)
+pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
+
+# torch.compile for a potential speedup (experimental)
+if USE_TORCH_COMPILE:
+    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+# CPU offloading to reduce VRAM usage at some speed cost (experimental)
+if ENABLE_CPU_OFFLOAD:
+    pipe.enable_model_cpu_offload()
+
+MAX_SEED = np.iinfo(np.int32).max
 
 def save_image(img):
     unique_name = str(uuid.uuid4()) + ".png"
@@ -53,51 +58,60 @@ def generate(
     guidance_scale: float = 3,
     num_inference_steps: int = 30,
     randomize_seed: bool = False,
     use_resolution_binning: bool = True,
+    num_images: int = 1,  # Number of images to generate
     progress=gr.Progress(track_tqdm=True),
 ):
-    pipe.to(device)
     seed = int(randomize_seed_fn(seed, randomize_seed))
-    generator = torch.Generator().manual_seed(seed)
+    generator = torch.Generator(device=device).manual_seed(seed)
 
     options = {
-        "prompt": prompt,
-        "negative_prompt": negative_prompt,
-        "width": width,
-        "height": height,
-        "guidance_scale": guidance_scale,
-        "num_inference_steps": num_inference_steps,
-        "generator": generator,
-        "use_resolution_binning": use_resolution_binning,
-        "output_type": "pil",
+        "prompt": [prompt] * num_images,
+        "negative_prompt": [negative_prompt] * num_images if use_negative_prompt else None,
+        "width": width,
+        "height": height,
+        "guidance_scale": guidance_scale,
+        "num_inference_steps": num_inference_steps,
+        "generator": generator,
+        "output_type": "pil",
     }
-
-    images = pipe(**options).images
+
+    # Resolution binning speeds up generation and reduces VRAM usage
+    if use_resolution_binning:
+        options["use_resolution_binning"] = True
+
+    # Generate the requested images, in batches of at most BATCH_SIZE
+    images = []
+    for i in range(0, num_images, BATCH_SIZE):
+        batch_options = options.copy()
+        batch_options["prompt"] = options["prompt"][i:i + BATCH_SIZE]
+        if batch_options["negative_prompt"] is not None:
+            batch_options["negative_prompt"] = options["negative_prompt"][i:i + BATCH_SIZE]
+        images.extend(pipe(**batch_options).images)
 
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-
 examples = [
     "a cat eating a piece of cheese",
     "a ROBOT riding a BLUE horse on Mars, photorealistic, 4k",
     "Ironman VS Hulk, ultrarealistic",
     "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",
-    "An alien holding sign board contain word 'Flash', futuristic, neonpunk",
+    "An alien holding a sign board containing the word 'Flash', futuristic, neonpunk",
     "Kids going to school, Anime style"
 ]
 
 css = '''
-.gradio-container{max-width: 560px !important}
+.gradio-container{max-width: 700px !important}
 h1{text-align:center}
 footer {
     visibility: hidden
 }
 '''
+
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("""# SDXL Flash
-    ### First Image processing takes time then images generate faster.""")
+    gr.Markdown("""# SDXL Flash""")
     with gr.Group():
         with gr.Row():
             prompt = gr.Text(
@@ -108,8 +122,15 @@ with gr.Blocks(css=css) as demo:
                 container=False,
             )
             run_button = gr.Button("Run", scale=0)
-    result = gr.Gallery(label="Result", columns=1)
+    result = gr.Gallery(label="Result", columns=1, show_label=False)
     with gr.Accordion("Advanced options", open=False):
+        num_images = gr.Slider(
+            label="Number of Images",
+            minimum=1,
+            maximum=4,
+            step=1,
+            value=1,
+        )
         with gr.Row():
             use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=True)
             negative_prompt = gr.Text(
@@ -162,9 +183,7 @@ with gr.Blocks(css=css) as demo:
     gr.Examples(
        examples=examples,
        inputs=prompt,
-       outputs=[result, seed],
-       fn=generate,
-       cache_examples=CACHE_EXAMPLES,
+       cache_examples=False
    )
 
    use_negative_prompt.change(
@@ -191,6 +210,7 @@ with gr.Blocks(css=css) as demo:
             guidance_scale,
             num_inference_steps,
             randomize_seed,
+            num_images
         ],
         outputs=[result, seed],
         api_name="run",
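The new loop in generate() chunks a request for num_images into pipeline calls of at most BATCH_SIZE prompts each. A minimal sketch of that chunking pattern in isolation, with a hypothetical fake_pipe() standing in for the real pipe(**batch_options) call:

# Minimal sketch of the chunked-generation pattern used in generate().
# `fake_pipe` is a hypothetical stand-in for pipe(**batch_options).images.
BATCH_SIZE = 2

def fake_pipe(prompts, negative_prompts=None):
    # One placeholder "image" per prompt, mirroring pipeline output.
    return [f"image<{p}>" for p in prompts]

def generate_batched(prompt, num_images, negative_prompt=None):
    prompts = [prompt] * num_images
    negatives = [negative_prompt] * num_images if negative_prompt else None
    images = []
    for i in range(0, num_images, BATCH_SIZE):
        batch_negatives = negatives[i:i + BATCH_SIZE] if negatives else None
        images.extend(fake_pipe(prompts[i:i + BATCH_SIZE], batch_negatives))
    return images

print(generate_batched("a cat eating a piece of cheese", num_images=3))

Running the last line prints three placeholder strings: with BATCH_SIZE=2, the three images come from one batch of two followed by a batch of one, which is exactly how generate() trades VRAM headroom against the number of pipeline calls.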
 
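Because the click handler is registered with api_name="run", the updated Space can also be driven programmatically. A hedged sketch using gradio_client: the Space ID below is a placeholder, and the positional argument order is an assumption, since only the tail of the inputs list (guidance_scale, num_inference_steps, randomize_seed, num_images) is visible in this diff; check view_api() for the authoritative signature first.

# Hedged sketch: calling the "run" endpoint via gradio_client.
# "YOUR-USERNAME/sdxl-flash" is a placeholder Space ID, and every
# position marked "assumed" below is a guess, not confirmed by the diff.
from gradio_client import Client

client = Client("YOUR-USERNAME/sdxl-flash")  # placeholder
print(client.view_api())  # shows the actual input signature of /run

images, seed = client.predict(
    "Astronaut in a jungle, cold color palette, oil pastel, detailed, 8k",  # prompt
    "",     # negative_prompt (assumed position)
    True,   # use_negative_prompt (assumed position)
    0,      # seed (assumed position)
    1024,   # width (assumed position)
    1024,   # height (assumed position)
    3,      # guidance_scale
    30,     # num_inference_steps
    True,   # randomize_seed
    2,      # num_images
    api_name="/run",
)

The call returns the two outputs wired up above: the gallery result (a list of generated image files) and the seed that was actually used, which is useful for reproducing a run when randomize_seed is enabled.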