"""Gradio demo that classifies uploaded images with the Marqo NSFW detection model."""

from urllib.request import urlopen

import gradio as gr
import timm
import torch
from PIL import Image

# Load the classifier from the Hugging Face Hub once at import time and put it
# in eval mode, since it is only used for inference here.
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384",
    pretrained=True,
).eval()

# Preprocessing pipeline derived from the model's own data configuration,
# built in non-training mode.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)


def predict(image):
    """Classify an image with the module-level model.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an upload.

    Returns:
        A dict mapping each name in ``model.pretrained_cfg["label_names"]``
        to its softmax probability as a float. An empty dict is returned
        when ``image`` is ``None``.
    """
    # An empty image input arrives as None; return an empty result instead of
    # letting the transform pipeline raise on it.
    if image is None:
        return {}

    class_names = model.pretrained_cfg["label_names"]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension for a single image.
        input_tensor = transforms(image).unsqueeze(0)
        output = model(input_tensor).softmax(dim=-1).cpu()

    # Row 0 is the only item in the batch; pair each label with its score.
    return dict(zip(class_names, output[0].tolist()))


# Web UI: a single image input wired to predict(), showing the top three labels.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="NSFW Image Detection",
    description="Upload an image to detect if it is NSFW or Safe for Work.",
)


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    interface.launch()