Ubuntu committed on
Commit ae4c73e · 1 parent: 2b96898

Original version

Files changed (5)
  1. README.md +5 -7
  2. app.py +169 -0
  3. requirements.txt +7 -0
  4. safety_checker.py +137 -0
  5. style.css +12 -0
README.md CHANGED
@@ -1,13 +1,11 @@
  ---
  title: SDXL Lightning
- emoji: 👁
- colorFrom: indigo
- colorTo: blue
+ emoji:
+ colorFrom: yellow
+ colorTo: gray
  sdk: gradio
- sdk_version: 4.19.2
+ sdk_version: 4.19.1
  app_file: app.py
  pinned: false
- license: openrail++
+ license: openrail
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

app.py ADDED
@@ -0,0 +1,169 @@
+ import gradio as gr
+ import torch
+ from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
+ from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ import spaces
+ import os
+ from PIL import Image, ImageFilter
+ from typing import List, Tuple
+
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
+
+ # Constants
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
+ repo = "ByteDance/SDXL-Lightning"
+ checkpoints = {
+     "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
+     "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
+     "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
+     "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
+ }
+ aspect_ratios = {
+     "21:9": (21, 9),
+     "2:1": (2, 1),
+     "16:9": (16, 9),
+     "5:4": (5, 4),
+     "4:3": (4, 3),
+     "3:2": (3, 2),
+     "1:1": (1, 1),
+ }
+
+ # Calculate a resolution that matches the aspect ratio, stays within the
+ # pixel budget, and keeps both sides divisible by `divisibility`
+ def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
+     if aspect_ratio not in aspect_ratios:
+         raise ValueError(f"Invalid aspect ratio: {aspect_ratio}")
+
+     width_multiplier, height_multiplier = aspect_ratios[aspect_ratio]
+     ratio = width_multiplier / height_multiplier
+     if mode == 'portrait':
+         # Swap the ratio for portrait mode
+         ratio = 1 / ratio
+
+     height = int((total_pixels / ratio) ** 0.5)
+     height -= height % divisibility
+
+     width = int(height * ratio)
+     width -= width % divisibility
+
+     while width * height > total_pixels:
+         height -= divisibility
+         width = int(height * ratio)
+         width -= width % divisibility
+
+     return width, height
+
+
+ # Example prompts with ckpt, aspect, and mode
+ examples = [
+     {"prompt": "A futuristic cityscape at sunset", "ckpt": "4-Step", "aspect": "16:9", "mode": "landscape"},
+     {"prompt": "pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
+     {"prompt": "A portrait of a robot in the style of Renaissance art", "ckpt": "2-Step", "aspect": "1:1", "mode": "portrait"},
+     {"prompt": "full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy", "ckpt": "4-Step", "aspect": "1:1", "mode": "portrait"},
+     {"prompt": "A serene landscape with mountains and a river", "ckpt": "8-Step", "aspect": "3:2", "mode": "landscape"},
+     {"prompt": "post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic", "ckpt": "8-Step", "aspect": "16:9", "mode": "landscape"}
+ ]
+
+ # Fill the input widgets from the example that matches the selected prompt
+ def set_example(selected_prompt):
+     for example in examples:
+         if example["prompt"] == selected_prompt:
+             return example["prompt"], example["ckpt"], example["aspect"], example["mode"]
+     return None, None, None, None  # Default values if not found
+
+ # Ensure model and scheduler are initialized when a GPU is available
+ if torch.cuda.is_available():
+     pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
+
+ if SAFETY_CHECKER:
+     from safety_checker import StableDiffusionSafetyChecker
+     from transformers import CLIPFeatureExtractor
+
+     safety_checker = StableDiffusionSafetyChecker.from_pretrained(
+         "CompVis/stable-diffusion-safety-checker"
+     ).to("cuda")
+     feature_extractor = CLIPFeatureExtractor.from_pretrained(
+         "openai/clip-vit-base-patch32"
+     )
+
+     def check_nsfw_images(
+         images: List[Image.Image]
+     ) -> Tuple[List[Image.Image], List[bool]]:
+         # Convert each PIL image to the tensor format the feature extractor expects
+         safety_checker_inputs = [feature_extractor(image, return_tensors="pt").to("cuda") for image in images]
+
+         # Get NSFW flags for each image
+         has_nsfw_concepts = [safety_checker(
+             images=[image],
+             clip_input=safety_checker_input.pixel_values.to("cuda")
+         ) for image, safety_checker_input in zip(images, safety_checker_inputs)]
+
+         # Flatten has_nsfw_concepts, since each checker call returns a list
+         has_nsfw_concepts = [item for sublist in has_nsfw_concepts for item in sublist]
+
+         return images, has_nsfw_concepts
+
+ # Generation function
+ @spaces.GPU(enable_queue=True)
+ def generate_image(prompt, ckpt, aspect_ratio, mode):
+     width, height = calculate_resolution(aspect_ratio, mode)  # Calculate resolution from the aspect ratio
+     checkpoint = checkpoints[ckpt][0]
+     num_inference_steps = checkpoints[ckpt][1]
+
+     if num_inference_steps == 1:
+         # Ensure the sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
+     else:
+         # Ensure the sampler uses "trailing" timesteps
+         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+     pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
+     results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0, width=width, height=height)
+
+     if SAFETY_CHECKER:
+         images, has_nsfw_concepts = check_nsfw_images(results.images)
+         if any(has_nsfw_concepts):
+             gr.Warning("NSFW content detected.")
+             # Apply a blur filter to the first image in the results
+             blurred_image = images[0].filter(ImageFilter.GaussianBlur(16))  # Adjust the radius as needed
+             return blurred_image
+         return images[0]
+     return results.images[0]
+
+
+ # Gradio interface
+ description = """
+ SDXL-Lightning ByteDance model demo. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
+ """
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
+     gr.Markdown(description)
+     with gr.Group():
+         with gr.Row():
+             prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
+         with gr.Row():
+             ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
+             aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
+             mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Mode as a dropdown
+             submit = gr.Button(scale=1, variant='primary')
+
+     img = gr.Image(label='SDXL-Lightning Generated Image')
+
+     prompt.submit(fn=generate_image,
+                   inputs=[prompt, ckpt, aspect, mode],
+                   outputs=img,
+                   )
+     submit.click(fn=generate_image,
+                  inputs=[prompt, ckpt, aspect, mode],
+                  outputs=img,
+                  )
+     # Dropdown for selecting examples
+     example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
+     example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, ckpt, aspect, mode])
+
+ demo.queue().launch()
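
The resolution helper above targets the default ~1 MP budget and snaps both sides to multiples of 8. A minimal sketch of the values it produces, assuming app.py is on the import path:

from app import calculate_resolution

# 16:9 landscape: the height lands exactly on 768, and the width rounds
# down from 1365 to 1360 to stay divisible by 8
print(calculate_resolution("16:9", mode="landscape"))  # (1360, 768)

# 1:1 uses the full one-megapixel budget
print(calculate_resolution("1:1"))                     # (1024, 1024)

# Portrait mode inverts the ratio before sizing
print(calculate_resolution("16:9", mode="portrait"))   # (760, 1360)
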
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers
+ diffusers
+ torch
+ accelerate
+ gradio
+ pillow
+ spaces
safety_checker.py ADDED
@@ -0,0 +1,137 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+
+ def cosine_distance(image_embeds, text_embeds):
+     normalized_image_embeds = nn.functional.normalize(image_embeds)
+     normalized_text_embeds = nn.functional.normalize(text_embeds)
+     return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+ class StableDiffusionSafetyChecker(PreTrainedModel):
+     config_class = CLIPConfig
+
+     _no_split_modules = ["CLIPEncoderLayer"]
+
+     def __init__(self, config: CLIPConfig):
+         super().__init__(config)
+
+         self.vision_model = CLIPVisionModel(config.vision_config)
+         self.visual_projection = nn.Linear(
+             config.vision_config.hidden_size, config.projection_dim, bias=False
+         )
+
+         self.concept_embeds = nn.Parameter(
+             torch.ones(17, config.projection_dim), requires_grad=False
+         )
+         self.special_care_embeds = nn.Parameter(
+             torch.ones(3, config.projection_dim), requires_grad=False
+         )
+
+         self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
+         self.special_care_embeds_weights = nn.Parameter(
+             torch.ones(3), requires_grad=False
+         )
+
+     @torch.no_grad()
+     def forward(self, clip_input, images):
+         pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+         image_embeds = self.visual_projection(pooled_output)
+
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         special_cos_dist = (
+             cosine_distance(image_embeds, self.special_care_embeds)
+             .cpu()
+             .float()
+             .numpy()
+         )
+         cos_dist = (
+             cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+         )
+
+         result = []
+         batch_size = image_embeds.shape[0]
+         for i in range(batch_size):
+             result_img = {
+                 "special_scores": {},
+                 "special_care": [],
+                 "concept_scores": {},
+                 "bad_concepts": [],
+             }
+
+             # increase this value to create a stronger `nsfw` filter
+             # at the cost of increasing the possibility of filtering benign images
+             adjustment = 0.0
+
+             for concept_idx in range(len(special_cos_dist[0])):
+                 concept_cos = special_cos_dist[i][concept_idx]
+                 concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                 result_img["special_scores"][concept_idx] = round(
+                     concept_cos - concept_threshold + adjustment, 3
+                 )
+                 if result_img["special_scores"][concept_idx] > 0:
+                     result_img["special_care"].append(
+                         {concept_idx, result_img["special_scores"][concept_idx]}
+                     )
+                     adjustment = 0.01
+
+             for concept_idx in range(len(cos_dist[0])):
+                 concept_cos = cos_dist[i][concept_idx]
+                 concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                 result_img["concept_scores"][concept_idx] = round(
+                     concept_cos - concept_threshold + adjustment, 3
+                 )
+                 if result_img["concept_scores"][concept_idx] > 0:
+                     result_img["bad_concepts"].append(concept_idx)
+
+             result.append(result_img)
+
+         has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+         return has_nsfw_concepts
+
+     @torch.no_grad()
+     def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+         pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+         image_embeds = self.visual_projection(pooled_output)
+
+         special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+         cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+         # increase this value to create a stronger `nsfw` filter
+         # at the cost of increasing the possibility of filtering benign images
+         adjustment = 0.0
+
+         special_scores = (
+             special_cos_dist - self.special_care_embeds_weights + adjustment
+         )
+         # special_scores = special_scores.round(decimals=3)
+         special_care = torch.any(special_scores > 0, dim=1)
+         special_adjustment = special_care * 0.01
+         special_adjustment = special_adjustment.unsqueeze(1).expand(
+             -1, cos_dist.shape[1]
+         )
+
+         concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+         # concept_scores = concept_scores.round(decimals=3)
+         has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+         images[has_nsfw_concepts] = 0.0  # black image
+
+         return images, has_nsfw_concepts
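
Note that forward() here returns only the per-image NSFW flags, which is why check_nsfw_images in app.py flattens the per-call results. A minimal sketch of calling the checker directly, mirroring the wiring in app.py and assuming a CUDA device is available (the blank image is purely a placeholder):

from PIL import Image
from transformers import CLIPFeatureExtractor
from safety_checker import StableDiffusionSafetyChecker

safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
).to("cuda")
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (1024, 1024))  # placeholder image for illustration
clip_input = feature_extractor(image, return_tensors="pt").pixel_values.to("cuda")

# forward() scores the CLIP image embedding against the concept thresholds
# and returns one boolean per image in the batch
has_nsfw = safety_checker(images=[image], clip_input=clip_input)
print(has_nsfw)  # e.g. [False]
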
style.css ADDED
@@ -0,0 +1,12 @@
+ .gradio-container {
+     max-width: 690px !important;
+ }
+
+ #share-btn-container {padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto; margin-top: 0.35em;}
+ div#share-btn-container > div {flex-direction: row; background: black; align-items: center}
+ #share-btn-container:hover {background-color: #060606}
+ #share-btn {all: initial; color: #ffffff; font-weight: 600; cursor: pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important; right: 0; font-size: 15px;}
+ #share-btn * {all: unset}
+ #share-btn-container div:nth-child(-n+2) {width: auto !important; min-height: 0px !important;}
+ #share-btn-container .wrap {display: none !important}
+ #share-btn-container.hidden {display: none !important}