"""Gradio demo that classifies an uploaded image with Marqo/nsfw-image-detection-384."""

import time
from urllib.request import urlopen

import gradio as gr
import timm
import torch
from PIL import Image

# Load the pretrained classifier once at import time and switch it to eval mode.
model = timm.create_model("hf_hub:Marqo/nsfw-image-detection-384", pretrained=True)
model = model.eval()

# Build the inference-time preprocessing pipeline from the model's data config.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)


def predict(image):
    """Classify a single image and report how long inference took.

    Args:
        image: A PIL image (the Gradio input component below is configured
            with ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(scores, timing)`` tuple: ``scores`` maps each label name from
        ``model.pretrained_cfg["label_names"]`` to its softmax probability as
        a Python float, and ``timing`` is a string of the form
        ``"Inference time: X.XX seconds"``.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the preprocessing transforms.
        raise gr.Error("Please upload an image before submitting.")

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    # Defensive: make sure the image has three channels before preprocessing,
    # in case predict() is called directly with an RGBA or grayscale image.
    rgb_image = image.convert("RGB")

    with torch.no_grad():
        # unsqueeze(0) adds the leading batch dimension for a single image.
        input_tensor = transforms(rgb_image).unsqueeze(0)
        output = model(input_tensor).softmax(dim=-1).cpu()

    class_names = model.pretrained_cfg["label_names"]
    # output[0] holds the per-class probabilities for the only batch item.
    result = {
        name: float(probability)
        for name, probability in zip(class_names, output[0])
    }

    inference_time = time.perf_counter() - start_time
    return result, f"Inference time: {inference_time:.2f} seconds"


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", height=512),
    outputs=[
        gr.Label(num_top_classes=2),
        gr.Textbox(label="Inference Time"),
    ],
    title="NSFW Image Detection",
    description=(
        "Upload an image to detect if it is **NSFW (Not Safe For Work)** or **Safe For Work (SFW)**.\n\n"
        "This app uses the [Marqo/nsfw-image-detection-384](https://huggingface.co/Marqo/nsfw-image-detection-384) "
        "image classification model from Hugging Face's `timm` library."
    ),
)

if __name__ == "__main__":
    demo.launch()