File size: 2,117 Bytes
43b7e92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModelWithProjection, PreTrainedModel
from ...utils import logging
# Module-level logger named after this module; used to warn when images are blacked out.
logger = logging.get_logger(__name__)
class IFSafetyChecker(PreTrainedModel):
    """Flag images as NSFW and/or watermarked and replace flagged ones with black images.

    A CLIP vision model with projection produces one embedding per input, and
    two independent single-output linear heads score that embedding:
    ``p_head`` drives the NSFW decision and ``w_head`` the watermark decision.
    """

    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        """Build the vision backbone and the two scoring heads from ``config.vision_config``."""
        super().__init__(config)

        self.vision_model = CLIPVisionModelWithProjection(config.vision_config)

        # Both heads map a projected embedding to a single score.
        self.p_head = nn.Linear(config.vision_config.projection_dim, 1)
        self.w_head = nn.Linear(config.vision_config.projection_dim, 1)

    @staticmethod
    def _flag_and_black_out(scores, threshold, images, message):
        """Threshold ``scores``, warn once if anything is flagged, and zero the flagged images.

        Args:
            scores: Tensor of head outputs; it is flattened so that entry ``i``
                is the score for ``images[i]``.
            threshold: A score strictly greater than this value flags the image.
            images: Indexable, mutable collection of images; modified in place.
            message: Warning text logged once when at least one image is flagged.

        Returns:
            A list of booleans, one per score, ``True`` where the image was flagged.
        """
        flags = (scores.flatten() > threshold).tolist()

        if any(flags):
            logger.warning(message)

        for idx, flagged in enumerate(flags):
            if not flagged:
                continue
            if torch.is_tensor(images[idx]):
                # Keep tensor images as tensors (same shape, dtype and device)
                # instead of swapping in a NumPy array.
                images[idx] = torch.zeros_like(images[idx])
            else:
                images[idx] = np.zeros(images[idx].shape)

        return flags

    @torch.no_grad()
    def forward(self, clip_input, images, p_threshold=0.5, w_threshold=0.5):
        """Run both detectors and black out every flagged image.

        Args:
            clip_input: Input passed straight to the CLIP vision model
                (presumably preprocessed pixel values -- confirm against the caller).
            images: Images aligned index-by-index with ``clip_input``; flagged
                entries are overwritten in place with all-zero images.
            p_threshold: NSFW head scores strictly above this value are flagged.
            w_threshold: Watermark head scores strictly above this value are flagged.

        Returns:
            A tuple ``(images, nsfw_detected, watermark_detected)`` where the
            last two items are lists of booleans, one entry per image.
        """
        # The first element of the vision model output is used as the image embedding.
        image_embeds = self.vision_model(clip_input)[0]

        nsfw_detected = self._flag_and_black_out(
            self.p_head(image_embeds),
            p_threshold,
            images,
            "Potential NSFW content was detected in one or more images. A black image will be returned instead."
            " Try again with a different prompt and/or seed.",
        )

        watermark_detected = self._flag_and_black_out(
            self.w_head(image_embeds),
            w_threshold,
            images,
            "Potential watermarked content was detected in one or more images. A black image will be returned instead."
            " Try again with a different prompt and/or seed.",
        )

        return images, nsfw_detected, watermark_detected
|