pseudotheos commited on
Commit
ce39840
·
1 Parent(s): 5de4ba0

hopefully this works lol

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -18,10 +18,11 @@ from diffusers import (
18
  StableDiffusionLatentUpscalePipeline,
19
  StableDiffusionImg2ImgPipeline,
20
  StableDiffusionControlNetImg2ImgPipeline,
 
21
  DPMSolverMultistepScheduler,
22
  EulerDiscreteScheduler
23
  )
24
- from transformers import AutoFeatureExtractor
25
  import random
26
  import time
27
  import tempfile
@@ -50,6 +51,8 @@ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
50
  BASE_MODEL,
51
  controlnet=controlnet,
52
  vae=vae,
 
 
53
  torch_dtype=torch.float16,
54
  ).to("cuda")
55
  image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
@@ -180,7 +183,10 @@ def inference(
180
  logger.debug("Output Types: generated_image=%s", type(None))
181
  logger.debug("Content of out_image: %s", out_image)
182
  logger.debug("Structure of out_image: %s", dir(out_image))
183
- return out_image["images"][0]
 
 
 
184
 
185
 
186
  except Exception as e:
 
18
  StableDiffusionLatentUpscalePipeline,
19
  StableDiffusionImg2ImgPipeline,
20
  StableDiffusionControlNetImg2ImgPipeline,
21
+ StableDiffusionSafetyChecker,
22
  DPMSolverMultistepScheduler,
23
  EulerDiscreteScheduler
24
  )
25
+ from transformers import AutoFeatureExtractor, CLIPFeatureExtractor
26
  import random
27
  import time
28
  import tempfile
 
51
  BASE_MODEL,
52
  controlnet=controlnet,
53
  vae=vae,
54
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
55
+ feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
56
  torch_dtype=torch.float16,
57
  ).to("cuda")
58
  image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
 
183
  logger.debug("Output Types: generated_image=%s", type(None))
184
  logger.debug("Content of out_image: %s", out_image)
185
  logger.debug("Structure of out_image: %s", dir(out_image))
186
+ if not out_image.nsfw_content_detected[0]:
187
+ return out_image["images"][0]
188
+ else:
189
+ print("NSFW detected. Nice try.")
190
 
191
 
192
  except Exception as e: