| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
|
|
| |
# Serialized weights (a state_dict) for the 2-class ResNet-18 classifier.
model_path = "nsfw_classifier.pkl"

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
| |
# ResNet-18 backbone with its final fully-connected layer swapped for a
# 2-way head (index 0 = SFW, index 1 = NSFW). weights=None builds the
# architecture without downloading ImageNet weights; the fine-tuned
# parameters come from model_path instead.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered .pkl cannot run arbitrary code at load time (requires torch>=1.13).
# strict=True makes any missing/unexpected key in the checkpoint an error.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True),
    strict=True,
)
# map_location only places the *loaded* tensors; load_state_dict copies them
# into the model's existing (CPU) parameters. Move the model explicitly so it
# lives on the same device the inputs are sent to during inference.
model.to(device)
# Inference mode for layers such as BatchNorm (use running statistics).
model.eval()
|
|
| |
# Preprocessing applied to every uploaded image before inference:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each RGB channel with the standard ImageNet mean/std.
# NOTE(review): these steps presumably mirror the preprocessing used when
# nsfw_classifier.pkl was trained -- confirm before changing any of them.
_RESIZE_TO = (224, 224)
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose(
    [
        transforms.Resize(_RESIZE_TO),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)
|
|
| |
| def classify_image(image_path): |
| try: |
| |
| image = Image.open(image_path).convert("RGB") |
| image = transform(image).unsqueeze(0) |
|
|
| |
| with torch.no_grad(): |
| outputs = model(image.to(device)) |
| _, predicted = outputs.max(1) |
|
|
| |
| classes = ["Safe for Work", "Not Safe for Work (NSFW)"] |
| return classes[predicted.item()] |
| except Exception as e: |
| return f"Error processing image: {e}" |
|
|
| |
# Web UI: a single image upload (handed to classify_image as a file path)
# mapped to a plain-text prediction, with two clickable sample images.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="NSFW Image Classifier",
    description="Upload an image to classify whether it is Safe for Work (SFW) or Not Safe for Work (NSFW).",
    # One inner list per example, holding the value for the single input.
    examples=[[sample] for sample in ("example1.jpg", "example2.jpg")],
)

# Start the Gradio server for the interface.
interface.launch()
|
|