safetychecker / fn.py
aka7774's picture
Update fn.py
2018475 verified
raw
history blame
1.19 kB
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor
import numpy as np
import torch
from PIL import Image
from typing import Optional, Tuple, Union
# Model state shared by load_model() and check(). Everything starts as None
# and is populated by load_model():
#   device / torch_device - the torch.device the checker runs on (same object)
#   torch_dtype           - dtype recorded by load_model()
#   safety_checker        - the loaded StableDiffusionSafetyChecker
#   feature_extractor     - the CLIP preprocessor that feeds the checker
device = torch_device = torch_dtype = safety_checker = feature_extractor = None
def load_model():
    """Load the safety checker model and its CLIP feature extractor.

    Populates the module-level ``device``, ``torch_device``, ``torch_dtype``,
    ``safety_checker`` and ``feature_extractor`` globals. It must be called
    before :func:`check`, which reads ``feature_extractor``,
    ``safety_checker`` and ``torch_device``.
    """
    global device, torch_device, torch_dtype, safety_checker, feature_extractor

    # Run on the GPU whenever PyTorch can see one.
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    device = torch_device = target

    # NOTE(review): recorded only; it is not passed to from_pretrained below,
    # so the checker is loaded without an explicit dtype.
    torch_dtype = torch.float16

    checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    safety_checker = checker.to(device)

    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
def _is_empty_image(image):
    """Return True when ``image`` is missing or carries no pixels to check."""
    if image is None:
        return True
    if isinstance(image, np.ndarray):
        # bool(ndarray) raises ValueError for arrays with more than one element
        # (and for empty arrays on recent NumPy), so test the size explicitly.
        return image.size == 0
    # Non-array inputs keep the original truthiness semantics.
    return not image


def check(image):
    """Run the safety checker on a single image and return its NSFW flag.

    Args:
        image: One image accepted by both ``feature_extractor`` and
            ``np.array``, for example a PIL image or a numpy pixel array.

    Returns:
        ``None`` if ``image`` is missing or empty. Otherwise, the safety
        checker's ``has_nsfw_concepts`` entry for this image.

    Raises:
        RuntimeError: If :func:`load_model` has not populated the model yet.
    """
    if _is_empty_image(image):
        return None
    if safety_checker is None or feature_extractor is None:
        raise RuntimeError("load_model() must be called before check()")

    images = [image]
    # Preprocess into CLIP pixel tensors and move them to the model's device
    # once; no separate BatchFeature.to() is needed.
    clip_input = feature_extractor(images, return_tensors="pt").pixel_values.to(
        torch_device
    )
    images_np = [np.array(img) for img in images]
    # Only the per-image NSFW flags are used; the first return value is ignored.
    _, has_nsfw_concepts = safety_checker(images=images_np, clip_input=clip_input)
    return has_nsfw_concepts[0]