File size: 751 Bytes
79eeb88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import logging
from typing import Any

from transformers import pipeline

from constants import SAFETY_CHECKER_MODEL
class SafetyChecker:
    """A class to check if an image is NSFW or not.

    Wraps a ``transformers`` image-classification pipeline and compares the
    scores it assigns to the ``"normal"`` and ``"nsfw"`` labels.
    """

    # Score diagnostics go through logging rather than stdout so that callers
    # control verbosity via standard logging configuration.
    _logger = logging.getLogger(__name__)

    # The two label names this checker reads from the model's predictions.
    _EXPECTED_LABELS = frozenset({"nsfw", "normal"})

    def __init__(
        self,
        mode_id: str = SAFETY_CHECKER_MODEL,
    ):
        """Load the image-classification pipeline.

        Args:
            mode_id: Model identifier forwarded to ``transformers.pipeline``
                as ``model=``. (Spelled ``mode_id`` rather than ``model_id``;
                the name is kept because callers may pass it by keyword.)
        """
        self.classifier = pipeline(
            "image-classification",
            model=mode_id,
        )

    def is_safe(
        self,
        image: Any,
    ) -> bool:
        """Return whether the model scores *image* as more "normal" than "nsfw".

        Args:
            image: Passed unchanged to the pipeline. Presumably a single image
                (e.g. a PIL image, path or URL); a batch would produce a nested
                result that the label lookup below does not handle.
                NOTE(review): confirm callers only pass single images.

        Returns:
            ``True`` only when the ``"normal"`` score strictly exceeds the
            ``"nsfw"`` score. Ties count as unsafe, and a label missing from
            the prediction is scored as 0, so the check fails closed.
        """
        predictions = self.classifier(image)
        # Each prediction entry is a dict carrying "label" and "score" keys.
        scores = {entry["label"]: entry["score"] for entry in predictions}

        missing = self._EXPECTED_LABELS - scores.keys()
        if missing:
            # A label-name mismatch with the configured model would otherwise
            # make every image look unsafe without any indication why.
            self._logger.warning(
                "Safety checker prediction is missing label(s) %s; got %s",
                sorted(missing),
                sorted(scores),
            )

        nsfw_score = scores.get("nsfw", 0)
        normal_score = scores.get("normal", 0)
        self._logger.debug(
            "NSFW score: %s, Normal score: %s", nsfw_score, normal_score
        )
        return normal_score > nsfw_score
|