Spaces:
Running
on
Zero
Running
on
Zero
Commit
•
07822f8
1
Parent(s):
d1c0879
Native safety checker
Browse files
app.py
CHANGED
@@ -23,7 +23,7 @@ import user_history
|
|
23 |
from illusion_style import css
|
24 |
import os
|
25 |
from transformers import CLIPImageProcessor
|
26 |
-
from safety_checker import StableDiffusionSafetyChecker
|
27 |
|
28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
29 |
|
@@ -49,16 +49,16 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
|
49 |
).to("cuda")
|
50 |
|
51 |
# Function to check NSFW images
|
52 |
-
def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
53 |
-
if SAFETY_CHECKER_ENABLED:
|
54 |
-
safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
55 |
-
has_nsfw_concepts = safety_checker(
|
56 |
-
images=[images],
|
57 |
-
clip_input=safety_checker_input.pixel_values.to("cuda")
|
58 |
-
)
|
59 |
-
return images, has_nsfw_concepts
|
60 |
-
else:
|
61 |
-
return images, [False] * len(images)
|
62 |
|
63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
64 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
@@ -284,4 +284,4 @@ with gr.Blocks(css=css) as app_with_history:
|
|
284 |
app_with_history.queue(max_size=20,api_open=False )
|
285 |
|
286 |
if __name__ == "__main__":
|
287 |
-
app_with_history.launch(max_threads=400)
|
|
|
23 |
from illusion_style import css
|
24 |
import os
|
25 |
from transformers import CLIPImageProcessor
|
26 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
27 |
|
28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
29 |
|
|
|
49 |
).to("cuda")
|
50 |
|
51 |
# Function to check NSFW images
|
52 |
+
#def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
|
53 |
+
# if SAFETY_CHECKER_ENABLED:
|
54 |
+
# safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
55 |
+
# has_nsfw_concepts = safety_checker(
|
56 |
+
# images=[images],
|
57 |
+
# clip_input=safety_checker_input.pixel_values.to("cuda")
|
58 |
+
# )
|
59 |
+
# return images, has_nsfw_concepts
|
60 |
+
# else:
|
61 |
+
# return images, [False] * len(images)
|
62 |
|
63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
64 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
|
|
284 |
app_with_history.queue(max_size=20,api_open=False )
|
285 |
|
286 |
if __name__ == "__main__":
|
287 |
+
app_with_history.launch(max_threads=400)
|