PeterL1n committed
Commit 187715d
1 Parent(s): 95995fe

Update app.py

Files changed (1):
  1. app.py +80 -99
app.py CHANGED
@@ -1,120 +1,101 @@
  import gradio as gr
  import torch
- import spaces
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
- from diffusers.image_processor import VaeImageProcessor
- from transformers import CLIPImageProcessor
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
  from huggingface_hub import hf_hub_download
  from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image

- device = "cuda"
- dtype = torch.float16
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

+ # Constants
  base = "stabilityai/stable-diffusion-xl-base-1.0"
  repo = "ByteDance/SDXL-Lightning"
- opts = {
-     "1 Step": ("sdxl_lightning_1step_unet_x0.safetensors", 1),
-     "2 Steps": ("sdxl_lightning_2step_unet.safetensors", 2),
-     "4 Steps": ("sdxl_lightning_4step_unet.safetensors", 4),
-     "8 Steps": ("sdxl_lightning_8step_unet.safetensors", 8),
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
  }

- # Inference function.
- @spaces.GPU()
- def generate(prompt, option, progress=gr.Progress()):
-     print(prompt, option)
-     ckpt, step = opts[option]
-
-     progress(0, desc="Initializing the model")
-
-     # Main pipeline.
-     unet = UNet2DConditionModel.from_config(base, subfolder="unet")
-     pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
-     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
-     pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
-
-     # Safety checker.
-     safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
-     feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
-     image_processor = VaeImageProcessor(vae_scale_factor=8)
-
-     def inference_callback(p, i, t, kwargs):
-         progress((i+1, step))
-         return kwargs
-
-     # Inference loop.
-     progress((0, step))
-     results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
-
-     # Safety check.
-     feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
-     safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
-     pixel_values = safety_checker_input.pixel_values.to(device, dtype)
-     images, has_nsfw_concept = safety_checker(
-         images=results.images, clip_input=pixel_values
+ # Ensure model and scheduler are initialized in GPU-enabled function
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
      )
-     if has_nsfw_concept[0]:
-         print(f"Safety checker triggered on prompt: {prompt}")
-     return images[0]
-
- with gr.Blocks(css="style.css") as demo:
-     gr.HTML(
-         "<h1><center>SDXL-Lightning</center></h1>" +
-         "<p><center>Lightning-fast text-to-image generation</center></p>" +
-         "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
-     )
-
-     with gr.Row():
-         prompt = gr.Textbox(
-             label="Text prompt",
-             scale=8
-         )
-         option = gr.Dropdown(
-             label="Inference steps",
-             choices=["1 Step", "2 Steps", "4 Steps", "8 Steps"],
-             value="4 Steps",
-             interactive=True
-         )
-         submit = gr.Button(
-             scale=1,
-             variant="primary"
+
+     def check_nsfw_images(
+         images: list[Image.Image],
+     ) -> tuple[list[Image.Image], list[bool]]:
+         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+         has_nsfw_concepts = safety_checker(
+             images=[images],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
          )
-
-     img = gr.Image(label="SDXL-Lightning Generated Image")
-
-     prompt.submit(
-         fn=generate,
-         inputs=[prompt, option],
-         outputs=img,
-     )
-     submit.click(
-         fn=generate,
-         inputs=[prompt, option],
-         outputs=img,
-     )
-
-     gr.Examples(
-         fn=generate,
-         examples=[
-             ["An owl perches quietly on a twisted branch deep within an ancient forest.", "1 Step"],
-             ["A lion in the galaxy, octane render", "2 Steps"],
-             ["A dolphin leaps through the waves, set against a backdrop of bright blues and teal hues.", "2 Steps"],
-             ["A girl smiling", "4 Steps"],
-             ["An astronaut riding a horse", "4 Steps"],
-             ["A fish on a bicycle, colorful art", "4 Steps"],
-             ["A close-up of an Asian lady with sunglasses.", "4 Steps"],
-             ["Rabbit portrait in a forest, fantasy", "4 Steps"],
-             ["A panda swimming", "4 Steps"],
-             ["Man portrait, ethereal", "8 Steps"],
-         ],
-         inputs=[prompt, option],
-         outputs=img,
-         cache_examples=False,
-     )
-
-     gr.HTML(
-         "<p><small><center>This demo is built together by the community</center></small></p>"
-     )
+
+         return images, has_nsfw_concepts
+
+ # Function
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt):
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+
+     if num_inference_steps == 1:
+         # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+     else:
+         # Ensure sampler uses "trailing" timesteps.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             return Image.new("RGB", (512, 512))
+         return images[0]
+     return results.images[0]
+
+ # Gradio Interface
+ description = """
+ This demo utilizes the SDXL-Lightning model by ByteDance, a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
+ As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             submit = gr.Button(scale=1, variant='primary')
+         img = gr.Image(label='SDXL-Lightning Generated Image')
+
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )

  demo.queue().launch()
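
For reference, the distilled checkpoints the updated app cycles through can also be exercised outside Gradio. Below is a minimal standalone sketch, adapted from the ByteDance/SDXL-Lightning model card, assuming a CUDA GPU and the 4-step checkpoint; the 1-step x0 checkpoint additionally needs prediction_type="sample", exactly as generate_image handles above.

import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_4step_unet.safetensors"  # 4-step distilled UNet (assumed choice)

# Load the distilled UNet weights into the SDXL base pipeline.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")

# SDXL-Lightning is trained for "trailing" timestep spacing and no classifier-free guidance.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

pipe("An astronaut riding a horse", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")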