# Spaces: Running
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
from transformers import CLIPFeatureExtractor | |
import numpy as np | |
import torch | |
from PIL import Image | |
from typing import Optional, Tuple, Union | |
# Module-level state shared by load_model() and check(); everything starts
# unset and is filled in by load_model().
device = torch_device = None  # torch.device; load_model() makes both names alias it
torch_dtype = None  # dtype selected in load_model()
safety_checker = feature_extractor = None  # model objects loaded in load_model()
def load_model():
    """Populate the module-level model state used by ``check``.

    Picks CUDA when ``torch.cuda.is_available()`` reports it, otherwise CPU,
    and stores that device in both ``device`` and ``torch_device``. Sets
    ``torch_dtype`` to ``torch.float16``, loads the CompVis Stable Diffusion
    safety checker and moves it to the chosen device, then loads the CLIP
    feature extractor used to preprocess images.
    """
    global device, torch_device, torch_dtype, safety_checker, feature_extractor
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)
    torch_device = device
    # NOTE(review): torch_dtype is not passed to from_pretrained() below, so
    # it does not affect the checker's precision here; confirm intended use.
    torch_dtype = torch.float16
    checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    )
    safety_checker = checker.to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )
def check(image):
    """Run the Stable Diffusion safety checker on a single image.

    ``load_model()`` must have been called first: it populates the
    module-level ``feature_extractor``, ``safety_checker`` and
    ``torch_device`` globals this function reads.

    Args:
        image: The image to classify, or ``None``. It is fed to the CLIP
            feature extractor and converted with ``np.array``, so a PIL image
            or a NumPy array is presumably expected — TODO confirm callers.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise the first entry of the
        safety checker's ``has_nsfw_concepts`` result for this image.
    """
    # Explicit None check: `not image` raises ValueError ("truth value of an
    # array ... is ambiguous") when image is a multi-element NumPy array.
    if image is None:
        return None
    images = [image]
    # Only pixel_values is consumed, so move just that tensor to the device.
    clip_input = feature_extractor(images, return_tensors="pt").pixel_values.to(
        torch_device
    )
    images_np = [np.array(img) for img in images]
    _, has_nsfw_concepts = safety_checker(images=images_np, clip_input=clip_input)
    return has_nsfw_concepts[0]