J committed
Commit d609346
1 parent: dbab601

demo app.py

Files changed (1): app.py +98 -2
app.py CHANGED
@@ -1,7 +1,103 @@
  import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image
+
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
+
+ # Constants
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],  # assumed filename for the x0-prediction 1-step checkpoint; needed because the dropdown below offers "1-Step"
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+ }
+
 
  def greet(name):
      return "Hello " + name + "!!"
 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ # Initialize the model up front when a GPU is available (e.g. on Spaces GPU hardware)
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
+     )
+
+     def check_nsfw_images(
+         images: list[Image.Image],
+     ) -> tuple[list[Image.Image], list[bool]]:
+         safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
+         has_nsfw_concepts = safety_checker(
+             images=[images],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
+         )
+
+         return images, has_nsfw_concepts
+
+ # Generation function
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt):
+
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+
+     if num_inference_steps == 1:
+         # The 1-step checkpoint needs "trailing" timesteps and "sample" (x0) prediction.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+     else:
+         # The multi-step checkpoints only need "trailing" timesteps.
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
+
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             return Image.new("RGB", (512, 512))
+         return images[0]
+     return results.images[0]
+
+
+
+ # Gradio interface
+ description = """
+ This demo generates images with ByteDance's SDXL-Lightning model in 1, 2, 4, or 8 inference steps.
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             submit = gr.Button(scale=1, variant='primary')
+         img = gr.Image(label='SDXL-Lightning Generated Image')
+
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt],
+                  outputs=img,
+                  )
+
+ demo.queue().launch()
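
For local testing outside Spaces, the core checkpoint-swap pattern in generate_image can be exercised on its own. Below is a minimal sketch, assuming a CUDA GPU and the same diffusers/safetensors/huggingface_hub stack as app.py; the spaces import and @spaces.GPU decorator are only meaningful on Hugging Face Spaces GPU hardware and are omitted here.

# Standalone sketch of the SDXL-Lightning loading pattern used in app.py above.
# Assumes a CUDA GPU; the model and repo IDs are the same ones app.py defines.
import torch
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"

pipe = StableDiffusionXLPipeline.from_pretrained(
    base, torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# SDXL-Lightning checkpoints are distilled for "trailing" timestep spacing,
# and guidance_scale=0 disables classifier-free guidance, which they do not use.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.unet.load_state_dict(
    load_file(hf_hub_download(repo, "sdxl_lightning_4step_unet.safetensors"), device="cuda")
)

image = pipe("a cinematic photo of a lighthouse at dusk",
             num_inference_steps=4, guidance_scale=0).images[0]
image.save("lightning_4step.png")

Note that the prediction_type="sample" override in app.py applies only to the 1-step checkpoint, which predicts x0 directly; the 2/4/8-step checkpoints use the scheduler's default prediction type.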