apolinario commited on
Commit
6624621
1 Parent(s): 8dfb33f

Better NSFW filter

Browse files
Files changed (1) hide show
  1. app.py +58 -8
app.py CHANGED
@@ -42,15 +42,62 @@ def load_model_from_config(config, ckpt, verbose=False):
42
  model.eval()
43
  return model
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
46
  model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
47
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
48
  model = model.to(device)
 
49
  #NSFW CLIP Filter
50
- clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='laion400m_e32')
51
- text = tokenizer.tokenize(["NSFW", "adult content", "porn", "naked people","genitalia","penis","vagina"])
52
- with torch.no_grad():
53
- text_features = clip_model.encode_text(text)
54
 
55
  def run(prompt, steps, width, height, images, scale):
56
  opt = argparse.Namespace(
@@ -108,10 +155,13 @@ def run(prompt, steps, width, height, images, scale):
108
  for x_sample in x_samples_ddim:
109
  x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
110
  image_vector = Image.fromarray(x_sample.astype(np.uint8))
111
- image = preprocess(image_vector).unsqueeze(0)
112
- image_features = clip_model.encode_image(image)
113
- sims = image_features @ text_features.T
114
- if(sims.max()<18):
 
 
 
115
  all_samples_images.append(image_vector)
116
  else:
117
  return(None,None,"Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")
 
42
  model.eval()
43
  return model
44
 
45
+ def load_safety_model(clip_model):
46
+ """load the safety model"""
47
+ import autokeras as ak # pylint: disable=import-outside-toplevel
48
+ from tensorflow.keras.models import load_model # pylint: disable=import-outside-toplevel
49
+ from os.path import expanduser # pylint: disable=import-outside-toplevel
50
+
51
+ home = expanduser("~")
52
+
53
+ cache_folder = home + "/.cache/clip_retrieval/" + clip_model.replace("/", "_")
54
+ if clip_model == "ViT-L/14":
55
+ model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
56
+ dim = 768
57
+ elif clip_model == "ViT-B/32":
58
+ model_dir = cache_folder + "/clip_autokeras_nsfw_b32"
59
+ dim = 512
60
+ else:
61
+ raise ValueError("Unknown clip model")
62
+ if not os.path.exists(model_dir):
63
+ os.makedirs(cache_folder, exist_ok=True)
64
+
65
+ from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel
66
+
67
+ path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
68
+ if clip_model == "ViT-L/14":
69
+ url_model = "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
70
+ elif clip_model == "ViT-B/32":
71
+ url_model = (
72
+ "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip"
73
+ )
74
+ else:
75
+ raise ValueError("Unknown model {}".format(clip_model))
76
+ urlretrieve(url_model, path_to_zip_file)
77
+ import zipfile # pylint: disable=import-outside-toplevel
78
+
79
+ with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
80
+ zip_ref.extractall(cache_folder)
81
+
82
+ loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
83
+ loaded_model.predict(np.random.rand(10 ** 3, dim).astype("float32"), batch_size=10 ** 3)
84
+
85
+ return loaded_model
86
+
87
+ def is_unsafe(safety_model, embeddings, threshold=0.5):
88
+ """find unsafe embeddings"""
89
+ nsfw_values = safety_model.predict(embeddings, batch_size=embeddings.shape[0])
90
+ x = np.array([e[0] for e in nsfw_values])
91
+ return True if x > threshold else False
92
+
93
  config = OmegaConf.load("latent-diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml")
94
  model = load_model_from_config(config, f"txt2img-f8-large.ckpt")
95
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
96
  model = model.to(device)
97
+
98
  #NSFW CLIP Filter
99
+ safety_model = load_safety_model("ViT-B/32")
100
+ clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')
 
 
101
 
102
  def run(prompt, steps, width, height, images, scale):
103
  opt = argparse.Namespace(
 
155
  for x_sample in x_samples_ddim:
156
  x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
157
  image_vector = Image.fromarray(x_sample.astype(np.uint8))
158
+ image_preprocess = preprocess(image_vector).unsqueeze(0)
159
+ with torch.no_grad():
160
+ image_features = clip_model.encode_image(image_preprocess)
161
+ image_features /= image_features.norm(dim=-1, keepdim=True)
162
+ query = image_features.cpu().detach().numpy().astype("float32")
163
+ unsafe = is_unsafe(safety_model,query,0.5)
164
+ if(not unsafe):
165
  all_samples_images.append(image_vector)
166
  else:
167
  return(None,None,"Sorry, potential NSFW content was detected on your outputs by our NSFW detection model. Try again with different prompts. If you feel your prompt was not supposed to give NSFW outputs, this may be due to a bias in the model. Read more about biases in the Biases Acknowledgment section below.")