# Scraped Hugging Face file-view header:
#   Author: OwenElliott
#   Commit: ec7d005 (verified) - "Create app.py"
#   Size:   1.16 kB
import gradio as gr
from urllib.request import urlopen
from PIL import Image
import timm
import torch
# Fetch the pretrained NSFW classifier from the Hugging Face Hub.
# `.eval()` returns the module itself, so it can be chained directly.
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384",
    pretrained=True,
).eval()

# Derive the preprocessing settings from the model's own pretrained config,
# then build the matching inference-time (non-augmenting) transform pipeline.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(
    **data_config,
    is_training=False,
)
# Prediction function
def predict(image):
    """Classify *image* with the NSFW model and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        dict mapping each label in ``model.pretrained_cfg["label_names"]`` to
        its softmax probability as a Python float (the format ``gr.Label``
        consumes).

    Raises:
        gr.Error: if no image was provided.
    """
    # Gradio calls the function with None for an empty input; report that
    # clearly instead of failing deep inside the transform pipeline.
    if image is None:
        raise gr.Error("Please upload an image.")

    # Defensive: the pipeline presumably expects 3-channel RGB input (typical
    # for timm image models - TODO confirm); converting avoids channel-count
    # mismatches for RGBA, palette or greyscale images.
    image = image.convert("RGB")

    class_names = model.pretrained_cfg["label_names"]
    with torch.no_grad():
        batch = transforms(image).unsqueeze(0)  # add a batch dimension of 1
        # Take the single row of the batch and convert to Python floats.
        probs = model(batch).softmax(dim=-1)[0].cpu().tolist()

    return dict(zip(class_names, probs))
# Gradio interface: a single PIL image in, a ranked label/confidence
# display (up to three classes) out.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=3)

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="NSFW Image Detection",
    description="Upload an image to detect if it is NSFW or Safe for Work.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()