|
import gradio as gr |
|
from urllib.request import urlopen |
|
from PIL import Image |
|
import timm |
|
import torch |
|
import time |
|
|
|
# Fetch the pretrained NSFW/SFW classifier from the Hugging Face Hub through
# timm. nn.Module.eval() returns the module itself, so it can be chained to
# put layers such as dropout / batch-norm into inference behaviour.
model = timm.create_model(
    "hf_hub:Marqo/nsfw-image-detection-384",
    pretrained=True,
).eval()

# Derive the evaluation-time preprocessing pipeline from the model's own
# pretrained data config (e.g. input size and normalization statistics), so
# inputs are prepared exactly as the checkpoint expects.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
|
|
|
def predict(image):
    """Classify an image as NSFW or SFW with the Marqo timm model.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when no image has been uploaded.

    Returns:
        A ``(scores, timing)`` tuple. ``scores`` maps every label in
        ``model.pretrained_cfg["label_names"]`` to its softmax probability
        (``None`` when no image was supplied, which clears the Label output).
        ``timing`` is a human-readable elapsed-time string.
    """
    if image is None:
        # Gradio can hand us None when the form is submitted without an
        # image; the transform pipeline would raise on it.
        return None, "No image provided."

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start_time = time.perf_counter()

    # The normalization/model presumably expect 3-channel input; RGBA,
    # palette or grayscale images would otherwise yield a tensor with the
    # wrong number of channels. convert() is a cheap copy for RGB images.
    rgb_image = image.convert("RGB")

    with torch.no_grad():
        # unsqueeze(0) prepends a leading dimension of size 1 (a batch of one).
        input_tensor = transforms(rgb_image).unsqueeze(0)
        probabilities = model(input_tensor).softmax(dim=-1).cpu()

    class_names = model.pretrained_cfg["label_names"]
    # Pair each label with the probability in the matching column of the
    # single-row output.
    result = {
        name: float(prob)
        for name, prob in zip(class_names, probabilities[0])
    }

    inference_time = time.perf_counter() - start_time
    return result, f"Inference time: {inference_time:.2f} seconds"
|
|
|
|
|
# Web UI: a single PIL image input feeding predict(), whose two return values
# populate a Label (top two classes shown) and a timing Textbox.
# Components are still instantiated in the original order: Image, Label, Textbox.
demo = gr.Interface(
    title="NSFW Image Detection",
    description=(
        "Upload an image to detect if it is **NSFW (Not Safe For Work)** or **Safe For Work (SFW)**.\n\n"
        "This app uses the [Marqo/nsfw-image-detection-384](https://huggingface.co/Marqo/nsfw-image-detection-384) "
        "image classification model from Hugging Face's `timm` library."
    ),
    fn=predict,
    inputs=gr.Image(type="pil", height=512),
    outputs=[gr.Label(num_top_classes=2), gr.Textbox(label="Inference Time")],
)


if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()
|
|