Hei-Ha committed on
Commit
2496eea
1 Parent(s): d4fa071
Files changed (4) hide show
  1. app.py +5 -102
  2. requirements.txt +1 -4
  3. safety_checker.py +0 -137
  4. style.css +0 -12
app.py CHANGED
@@ -1,107 +1,10 @@
1
  import gradio as gr
2
- import torch
3
- import spaces
4
- from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler, DiffusionPipeline
5
- from huggingface_hub import hf_hub_download
6
- from safetensors.torch import load_file
7
- import os
8
- from PIL import Image
9
 
10
- SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
11
 
12
- # Constants
13
- base = "stabilityai/stable-diffusion-xl-base-1.0"
14
- repo = "ByteDance/SDXL-Lightning"
15
- checkpoints = {
16
- "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
17
- "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
18
- "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
19
- "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
20
- }
21
 
 
 
22
 
23
-
24
-
25
- # Ensure model and scheduler are initialized in GPU-enabled function
26
- if torch.cuda.is_available():
27
- pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
28
- else:
29
- pipe = DiffusionPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16")
30
- pipe.to('cuda')
31
- print('------------------------------')
32
-
33
- if SAFETY_CHECKER:
34
- from safety_checker import StableDiffusionSafetyChecker
35
- from transformers import CLIPFeatureExtractor
36
-
37
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(
38
- "CompVis/stable-diffusion-safety-checker"
39
- ).to("cuda")
40
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
41
- "openai/clip-vit-base-patch32"
42
- )
43
-
44
- def check_nsfw_images(
45
- images: list[Image.Image],
46
- ) -> tuple[list[Image.Image], list[bool]]:
47
- safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
48
- has_nsfw_concepts = safety_checker(
49
- images=[images],
50
- clip_input=safety_checker_input.pixel_values.to("cuda")
51
- )
52
-
53
- return images, has_nsfw_concepts
54
-
55
- # Function
56
- @spaces.GPU(enable_queue=True)
57
- def generate_image(prompt, ckpt):
58
-
59
- checkpoint = checkpoints[ckpt][0]
60
- num_inference_steps = checkpoints[ckpt][1]
61
-
62
- if num_inference_steps==1:
63
- # Ensure sampler uses "trailing" timesteps and "sample" prediction type for 1-step inference.
64
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample")
65
- else:
66
- # Ensure sampler uses "trailing" timesteps.
67
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
68
-
69
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
70
- results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
71
-
72
- if SAFETY_CHECKER:
73
- images, has_nsfw_concepts = check_nsfw_images(results.images)
74
- if any(has_nsfw_concepts):
75
- gr.Warning("NSFW content detected.")
76
- return Image.new("RGB", (512, 512))
77
- return images[0]
78
- return results.images[0]
79
-
80
-
81
-
82
- # Gradio Interface
83
- description = """
84
- This demo utilizes the SDXL-Lightning model by ByteDance, which is a lightning-fast text-to-image generative model capable of producing high-quality images in 4 steps.
85
- As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
86
- """
87
-
88
- with gr.Blocks(css="style.css") as demo:
89
- gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
90
- gr.Markdown(description)
91
- with gr.Group():
92
- with gr.Row():
93
- prompt = gr.Textbox(label='Enter you image prompt:', scale=8)
94
- ckpt = gr.Dropdown(label='Select inference steps',choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
95
- submit = gr.Button(scale=1, variant='primary')
96
- img = gr.Image(label='SDXL-Lightning Generated Image')
97
-
98
- prompt.submit(fn=generate_image,
99
- inputs=[prompt, ckpt],
100
- outputs=img,
101
- )
102
- submit.click(fn=generate_image,
103
- inputs=[prompt, ckpt],
104
- outputs=img,
105
- )
106
-
107
- demo.queue().launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
 
 
 
 
 
 
3
 
 
4
 
5
+ def helloName(name):
6
+ return "hello" + name
 
 
 
 
 
 
 
7
 
8
+ demo = gr.Interface(fn=helloName, input="textbox", outputs="textbox")
9
+ demo.launch()
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,6 +1,3 @@
1
  transformers
2
- diffusers
3
  torch
4
- accelerate
5
- gradio
6
- spaces
 
1
  transformers
 
2
  torch
3
+ gradio
 
 
safety_checker.py DELETED
@@ -1,137 +0,0 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn as nn
18
- from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
19
-
20
-
21
- def cosine_distance(image_embeds, text_embeds):
22
- normalized_image_embeds = nn.functional.normalize(image_embeds)
23
- normalized_text_embeds = nn.functional.normalize(text_embeds)
24
- return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
25
-
26
-
27
- class StableDiffusionSafetyChecker(PreTrainedModel):
28
- config_class = CLIPConfig
29
-
30
- _no_split_modules = ["CLIPEncoderLayer"]
31
-
32
- def __init__(self, config: CLIPConfig):
33
- super().__init__(config)
34
-
35
- self.vision_model = CLIPVisionModel(config.vision_config)
36
- self.visual_projection = nn.Linear(
37
- config.vision_config.hidden_size, config.projection_dim, bias=False
38
- )
39
-
40
- self.concept_embeds = nn.Parameter(
41
- torch.ones(17, config.projection_dim), requires_grad=False
42
- )
43
- self.special_care_embeds = nn.Parameter(
44
- torch.ones(3, config.projection_dim), requires_grad=False
45
- )
46
-
47
- self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
48
- self.special_care_embeds_weights = nn.Parameter(
49
- torch.ones(3), requires_grad=False
50
- )
51
-
52
- @torch.no_grad()
53
- def forward(self, clip_input, images):
54
- pooled_output = self.vision_model(clip_input)[1] # pooled_output
55
- image_embeds = self.visual_projection(pooled_output)
56
-
57
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
58
- special_cos_dist = (
59
- cosine_distance(image_embeds, self.special_care_embeds)
60
- .cpu()
61
- .float()
62
- .numpy()
63
- )
64
- cos_dist = (
65
- cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
66
- )
67
-
68
- result = []
69
- batch_size = image_embeds.shape[0]
70
- for i in range(batch_size):
71
- result_img = {
72
- "special_scores": {},
73
- "special_care": [],
74
- "concept_scores": {},
75
- "bad_concepts": [],
76
- }
77
-
78
- # increase this value to create a stronger `nfsw` filter
79
- # at the cost of increasing the possibility of filtering benign images
80
- adjustment = 0.0
81
-
82
- for concept_idx in range(len(special_cos_dist[0])):
83
- concept_cos = special_cos_dist[i][concept_idx]
84
- concept_threshold = self.special_care_embeds_weights[concept_idx].item()
85
- result_img["special_scores"][concept_idx] = round(
86
- concept_cos - concept_threshold + adjustment, 3
87
- )
88
- if result_img["special_scores"][concept_idx] > 0:
89
- result_img["special_care"].append(
90
- {concept_idx, result_img["special_scores"][concept_idx]}
91
- )
92
- adjustment = 0.01
93
-
94
- for concept_idx in range(len(cos_dist[0])):
95
- concept_cos = cos_dist[i][concept_idx]
96
- concept_threshold = self.concept_embeds_weights[concept_idx].item()
97
- result_img["concept_scores"][concept_idx] = round(
98
- concept_cos - concept_threshold + adjustment, 3
99
- )
100
- if result_img["concept_scores"][concept_idx] > 0:
101
- result_img["bad_concepts"].append(concept_idx)
102
-
103
- result.append(result_img)
104
-
105
- has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
106
-
107
- return has_nsfw_concepts
108
-
109
- @torch.no_grad()
110
- def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
111
- pooled_output = self.vision_model(clip_input)[1] # pooled_output
112
- image_embeds = self.visual_projection(pooled_output)
113
-
114
- special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
115
- cos_dist = cosine_distance(image_embeds, self.concept_embeds)
116
-
117
- # increase this value to create a stronger `nsfw` filter
118
- # at the cost of increasing the possibility of filtering benign images
119
- adjustment = 0.0
120
-
121
- special_scores = (
122
- special_cos_dist - self.special_care_embeds_weights + adjustment
123
- )
124
- # special_scores = special_scores.round(decimals=3)
125
- special_care = torch.any(special_scores > 0, dim=1)
126
- special_adjustment = special_care * 0.01
127
- special_adjustment = special_adjustment.unsqueeze(1).expand(
128
- -1, cos_dist.shape[1]
129
- )
130
-
131
- concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
132
- # concept_scores = concept_scores.round(decimals=3)
133
- has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
134
-
135
- images[has_nsfw_concepts] = 0.0 # black image
136
-
137
- return images, has_nsfw_concepts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
style.css DELETED
@@ -1,12 +0,0 @@
1
- .gradio-container {
2
- max-width: 690px! important;
3
- }
4
-
5
- #share-btn-container{padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 13rem; margin-left: auto;margin-top: 0.35em;}
6
- div#share-btn-container > div {flex-direction: row;background: black;align-items: center}
7
- #share-btn-container:hover {background-color: #060606}
8
- #share-btn {all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important;right:0;font-size: 15px;}
9
- #share-btn * {all: unset}
10
- #share-btn-container div:nth-child(-n+2){width: auto !important;min-height: 0px !important;}
11
- #share-btn-container .wrap {display: none !important}
12
- #share-btn-container.hidden {display: none!important}