bram-w commited on
Commit
93d448d
1 Parent(s): 47a05f1

safety check

Browse files
Files changed (1) hide show
  1. edict_functions.py +29 -1
edict_functions.py CHANGED
@@ -17,6 +17,8 @@ import os
17
  from torchvision import datasets
18
  import pickle
19
 
 
 
20
  # StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
21
  use_half_prec = True
22
  if use_half_prec:
@@ -66,7 +68,30 @@ else:
66
  clip.double().to(device)
67
  print("Loaded all models")
68
 
69
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  def EDICT_editing(im_path,
@@ -597,6 +622,9 @@ def baseline_stablediffusion(prompt="",
597
 
598
  image = (image / 2 + 0.5).clamp(0, 1)
599
  image = image.cpu().permute(0, 2, 3, 1).numpy()
 
 
 
600
  image = (image[0] * 255).round().astype("uint8")
601
  return Image.fromarray(image)
602
  ####################################
 
17
  from torchvision import datasets
18
  import pickle
19
 
20
+
21
+
22
  # StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
23
  use_half_prec = True
24
  if use_half_prec:
 
68
  clip.double().to(device)
69
  print("Loaded all models")
70
 
71
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
72
+ from transformers import AutoFeatureExtractor
73
+ # load safety model
74
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
75
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
76
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
77
+ def load_replacement(x):
78
+ try:
79
+ hwc = x.shape
80
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
81
+ y = (np.array(y)/255.0).astype(x.dtype)
82
+ assert y.shape == x.shape
83
+ return y
84
+ except Exception:
85
+ return x
86
+ def check_safety(x_image):
87
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
88
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
89
+ assert x_checked_image.shape[0] == len(has_nsfw_concept)
90
+ for i in range(len(has_nsfw_concept)):
91
+ if has_nsfw_concept[i]:
92
+ # x_checked_image[i] = load_replacement(x_checked_image[i])
93
+ x_checked_image[i] *= 0 # load_replacement(x_checked_image[i])
94
+ return x_checked_image, has_nsfw_concept
95
 
96
 
97
  def EDICT_editing(im_path,
 
622
 
623
  image = (image / 2 + 0.5).clamp(0, 1)
624
  image = image.cpu().permute(0, 2, 3, 1).numpy()
625
+
626
+ image, _ = check_safety(image)
627
+
628
  image = (image[0] * 255).round().astype("uint8")
629
  return Image.fromarray(image)
630
  ####################################