File size: 1,608 Bytes
5dbd83c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87f1a0c
5dbd83c
c1cda26
5dbd83c
 
 
 
4a7a631
 
 
 
5dbd83c
 
 
 
 
 
4a7a631
5dbd83c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch.nn.functional as F
import gradio as gr

# Load model and processor from Hugging Face Hub
# NOTE: runs at import time; downloads the checkpoint on first use (needs network).
model_path = "Ateeqq/nsfw-image-detection"  # SigLIP-based NSFW/safe classifier
processor = AutoImageProcessor.from_pretrained(model_path)
model = SiglipForImageClassification.from_pretrained(model_path)
model.eval()  # inference mode: disables dropout etc.

def predict(image):
    """Classify an image with the NSFW detector.

    Args:
        image: H x W x C numpy array, as supplied by ``gr.Image(type="numpy")``.

    Returns:
        Dict mapping each class label to its softmax confidence rounded to
        8 decimal places — the plain ``{label: score}`` shape ``gr.Label``
        expects (no extra keys).
    """
    # Convert the numpy array to an RGB PIL image and preprocess.
    pil_image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt")

    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = F.softmax(logits, dim=1)[0].tolist()

    # Map class indices to human-readable labels.
    labels = [model.config.id2label[i] for i in range(len(probs))]
    # round() replaces the original float(f"{p:.8f}") string round-trip —
    # same result, without formatting and re-parsing the float.
    return {label: round(p, 8) for label, p in zip(labels, probs)}

# Gradio Interface
def main():
    """Build and launch the Gradio UI for the NSFW image classifier."""
    header = "NSFW Image Detection using SigLIP2 Safety Classifier"

    # Markdown links shown beneath the interface, separated by <br>.
    footer = "<br>".join(
        (
            "[🧠 View Model on Hugging Face](https://huggingface.co/Ateeqq/nsfw-image-detection)",
            "[πŸ“– Read Training Article](https://exnrt.com/blog/ai/fine-tuning-siglip2/)",
        )
    )

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="numpy", label="Upload Image"),
        outputs=gr.Label(num_top_classes=3, label="Predictions"),
        title="NSFW Image Detector",
        description=header,
        article=footer,
        allow_flagging="never",
    )
    demo.launch()

# Launch the app only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()