multimodalart HF staff commited on
Commit
e6d1b54
1 Parent(s): f33c43f

Disable SC

Browse files
Files changed (1) hide show
  1. patch_sdxl.py +4 -30
patch_sdxl.py CHANGED
@@ -1,6 +1,3 @@
1
-
2
-
3
-
4
  import inspect
5
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
6
 
@@ -29,7 +26,6 @@ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOut
29
 
30
 
31
 
32
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
33
  from transformers import CLIPFeatureExtractor
34
  import numpy as np
35
  import torch
@@ -40,27 +36,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  torch_device = device
41
  torch_dtype = torch.float16
42
 
43
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(
44
- "CompVis/stable-diffusion-safety-checker"
45
- ).to(device)
46
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
47
- "openai/clip-vit-base-patch32"
48
- )
49
-
50
- def check_nsfw_images(
51
- images: list[Image.Image],
52
- ) -> list[bool]:
53
- safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
54
- images_np = [np.array(img) for img in images]
55
-
56
- _, has_nsfw_concepts = safety_checker(
57
- images=images_np,
58
- clip_input=safety_checker_input.pixel_values.to(torch_device),
59
- )
60
- return has_nsfw_concepts
61
-
62
-
63
-
64
 
65
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
66
 
@@ -569,12 +544,11 @@ class SDEmb(StableDiffusionXLPipeline):
569
  # apply watermark if available
570
  if self.watermark is not None:
571
  image = self.watermark.apply_watermark(image)
572
-
573
  image = self.image_processor.postprocess(image, output_type=output_type)
574
- maybe_nsfw = any(check_nsfw_images(image))
575
- if maybe_nsfw:
576
- print('This image could be NSFW so we return a blank image.')
577
- return StableDiffusionXLPipelineOutput(images=[Image.new('RGB', (1024, 1024))])
578
 
579
  # Offload all models
580
  self.maybe_free_model_hooks()
 
 
 
 
1
  import inspect
2
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
3
 
 
26
 
27
 
28
 
 
29
  from transformers import CLIPFeatureExtractor
30
  import numpy as np
31
  import torch
 
36
  torch_device = device
37
  torch_dtype = torch.float16
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
 
544
  # apply watermark if available
545
  if self.watermark is not None:
546
  image = self.watermark.apply_watermark(image)
 
547
  image = self.image_processor.postprocess(image, output_type=output_type)
548
+ #maybe_nsfw = any(check_nsfw_images(image))
549
+ #if maybe_nsfw:
550
+ # print('This image could be NSFW so we return a blank image.')
551
+ # return StableDiffusionXLPipelineOutput(images=[Image.new('RGB', (1024, 1024))])
552
 
553
  # Offload all models
554
  self.maybe_free_model_hooks()