Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torchvision.models as models | |
from torchvision import transforms | |
from torch import nn | |
from PIL import Image | |
# Preprocessing applied to every uploaded image before it reaches the model.
# NOTE(review): the 128x128 size and ImageNet mean/std presumably match the
# pipeline used to train cnn_model.pth -- confirm before changing them.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Build the MobileNetV3-Large architecture and swap its last classifier layer
# for a two-class head. No ImageNet weights are requested: the strict
# load_state_dict call below replaces every parameter and buffer anyway, so
# downloading them only cost start-up time and network access. `weights=`
# is the replacement for the deprecated `pretrained=` argument.
model = models.mobilenet_v3_large(weights=None)
model.classifier[3] = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(model.classifier[3].in_features, 2),
)
model = model.to("cpu")
# weights_only=True limits unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load("cnn_model.pth", weights_only=True, map_location="cpu")
)
# Evaluation mode: disables dropout and uses stored batch-norm statistics.
model.eval()

# Class names, indexed by the model's output position.
label = ["nsfw", "safe"]
def inference(image):
    """Classify an image as NSFW or safe and format the result for display.

    Args:
        image: A PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the button is clicked before anything is uploaded.

    Returns:
        A string such as ``"NSFW [nsfw:0.9876]"`` or ``"SAFE [safe:0.9876]"``.
        The number is the softmax probability of the predicted class.
        Returns ``"No image provided"`` when ``image`` is ``None``.
    """
    if image is None:
        return "No image provided"

    # The Normalize step expects exactly three channels; RGBA, palette or
    # grayscale images would otherwise fail with a shape mismatch.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)

    predicted_class = int(torch.argmax(probs, dim=1).item())
    name = label[predicted_class]
    # .item() yields a Python float, so the message shows a plain number
    # rather than the tensor repr "tensor(...)".
    score = probs[0, predicted_class].item()
    return f"{name.upper()} [{name}:{score:.4f}]"
# Two-column UI: image upload on the left; "Cek" (check) button and the
# prediction text on the right. Clicking the button runs `inference`.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inputs = gr.Image(type="pil")
        with gr.Column():
            btn = gr.Button("Cek")
            pred = gr.Text(label="Prediction")
    btn.click(fn=inference, inputs=inputs, outputs=pred)

# Launch only when run as a script, so importing this module (tests, the
# `gradio` reload CLI) does not start a server. `demo` stays module-level.
if __name__ == "__main__":
    demo.queue().launch()