Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -22,7 +22,7 @@ from share_btn import community_icon_html, loading_icon_html, share_js
|
|
22 |
import user_history
|
23 |
from illusion_style import css
|
24 |
import os
|
25 |
-
from transformers import
|
26 |
from safety_checker import StableDiffusionSafetyChecker
|
27 |
|
28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
@@ -34,15 +34,17 @@ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrco
|
|
34 |
# Initialize the safety checker conditionally
|
35 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
36 |
safety_checker = None
|
|
|
37 |
if SAFETY_CHECKER_ENABLED:
|
38 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
|
39 |
-
feature_extractor =
|
40 |
|
41 |
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
42 |
BASE_MODEL,
|
43 |
controlnet=controlnet,
|
44 |
vae=vae,
|
45 |
safety_checker=safety_checker,
|
|
|
46 |
torch_dtype=torch.float16,
|
47 |
).to("cuda")
|
48 |
|
@@ -57,7 +59,7 @@ def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], lis
|
|
57 |
return images, has_nsfw_concepts
|
58 |
else:
|
59 |
return images, [False] * len(images)
|
60 |
-
|
61 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
62 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
|
|
22 |
import user_history
|
23 |
from illusion_style import css
|
24 |
import os
|
25 |
+
from transformers import CLIPImageProcessor
|
26 |
from safety_checker import StableDiffusionSafetyChecker
|
27 |
|
28 |
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
|
|
|
34 |
# Initialize the safety checker conditionally
|
35 |
SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
36 |
safety_checker = None
|
37 |
+
feature_extractor = None
|
38 |
if SAFETY_CHECKER_ENABLED:
|
39 |
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
|
40 |
+
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
41 |
|
42 |
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
43 |
BASE_MODEL,
|
44 |
controlnet=controlnet,
|
45 |
vae=vae,
|
46 |
safety_checker=safety_checker,
|
47 |
+
feature_extractor=feature_extractor,
|
48 |
torch_dtype=torch.float16,
|
49 |
).to("cuda")
|
50 |
|
|
|
59 |
return images, has_nsfw_concepts
|
60 |
else:
|
61 |
return images, [False] * len(images)
|
62 |
+
|
63 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|
64 |
#main_pipe.unet.to(memory_format=torch.channels_last)
|
65 |
#main_pipe.unet = torch.compile(main_pipe.unet, mode="reduce-overhead", fullgraph=True)
|