import gradio as gr
import torch
from nsfw_image_detector import NSFWDetector
from PIL import Image

# Load the model once at startup so every request reuses the same weights;
# bfloat16 keeps memory low, and device="cuda" works if a GPU is available.
classifier_nsfw = NSFWDetector(dtype=torch.bfloat16, device="cpu")


def classify_image(image: Image.Image, confidence_level: str) -> tuple[str, str]:
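    """Return formatted probability scores and a red/green NSFW verdict."""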
    # One mapping of NSFW levels to probability scores per input image.
    result_nsfw_proba = classifier_nsfw.predict_proba(image)
    # Flag the image when the score at the chosen level reaches 0.5.
    is_nsfw = result_nsfw_proba[0][confidence_level] >= 0.5

    # Format the score for each level (keys are enum levels, hence level.value).
    proba_dict = result_nsfw_proba[0]
    nsfw_proba_str = "NSFW Probability Scores:\n"
    for level, score in proba_dict.items():
        nsfw_proba_str += f"{level.value.title()}: {score:.4f}\n"

    is_nsfw_str = f"NSFW Classification ({confidence_level.title()}):\n"
    is_nsfw_str += "🔴 True" if is_nsfw else "🟢 False"

    return nsfw_proba_str, is_nsfw_str
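
# Quick sanity check outside the UI (hypothetical local file path):
#   scores, verdict = classify_image(Image.open("example.jpg"), "medium")
#   print(scores)
#   print(verdict)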


demo = gr.Interface(
    fn=classify_image,
    inputs=[
        gr.Image(type="pil", label="Upload an image"),
        gr.Dropdown(
            choices=["low", "medium", "high"],
            value="medium",
            label="Confidence level",
            info="Low is the most restrictive; high is the least restrictive.",
        ),
    ],
    outputs=[
        gr.Textbox(label="NSFW Probability Scores", lines=3),
        gr.Textbox(label="NSFW Classification", lines=2),
    ],
    title="NSFW Image Classifier",
    description=(
        "Upload an image and select a confidence level to get a prediction "
        "using the Freepik/nsfw_image_detector model."
    ),
)


demo.launch()
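
# Gradio serves on http://127.0.0.1:7860 by default; demo.launch(share=True)
# would create a temporary public link.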