Shuang59 committed on
Commit 0773c5a • 1 Parent(s): 04c4c30

Upload safety_checker.py

Files changed (1): safety_checker.py +80 -0
safety_checker.py ADDED
import numpy as np
import torch
import torch.nn as nn

from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel

from diffusers.utils import logging


logger = logging.get_logger(__name__)


def cosine_distance(image_embeds, text_embeds):
    # Despite the name, this returns cosine *similarity*: higher values mean
    # the image embedding is closer to the concept embedding.
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.T)


class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        # Frozen embeddings for 17 NSFW concepts and 3 "special care" concepts;
        # the actual values are loaded from the pretrained checkpoint.
        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        # Per-concept similarity thresholds, also loaded from the checkpoint.
        self.register_buffer("concept_embeds_weights", torch.ones(17))
        self.register_buffer("special_care_embeds_weights", torch.ones(3))

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx: result_img["special_scores"][concept_idx]})
                    # Once any special-care concept fires, all remaining checks
                    # run with a slightly tightened threshold.
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            logger.warning(
                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
                " Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts
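
For context, here is a minimal sketch of how this checker is typically invoked; it is not part of the commit. It assumes the file above is saved as safety_checker.py, that the CLIP image preprocessor from openai/clip-vit-large-patch14 is used, and that the weights come from the CompVis/stable-diffusion-safety-checker checkpoint; the random image stands in for a pipeline's decoded output.

import numpy as np
from PIL import Image
from transformers import CLIPImageProcessor  # CLIPFeatureExtractor in older transformers releases

from safety_checker import StableDiffusionSafetyChecker

feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")

# Stand-in for pipeline output: a list of HxWx3 float arrays in [0, 1].
images = [np.random.rand(512, 512, 3).astype(np.float32)]
pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images]

# The preprocessor resizes and normalizes to CLIP's expected input format.
safety_checker_input = feature_extractor(pil_images, return_tensors="pt")
checked_images, has_nsfw_concepts = safety_checker(
    clip_input=safety_checker_input.pixel_values, images=images
)
print(has_nsfw_concepts)  # e.g. [False]

Because forward mutates the images list in place, flagged entries come back as all-zero (black) arrays while everything else passes through unchanged.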