# mask-detection / app.py
# Author: StevenLimcorn — commit 4a907c1 ("test")
import gradio as gr
import os
import torch.nn.functional as F
import torch
from torchvision import transforms
# Load the whole pickled model object (architecture + weights) onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. On torch >= 2.6, loading a full pickled module also needs
# weights_only=False; confirm against the deployed torch version.
model = torch.load("./model.pth", map_location=torch.device("cpu"))
# Put the model in inference mode. Any dropout / batch-norm layers it may
# contain then behave deterministically; the pickled object keeps whatever
# train/eval flag it had when it was saved.
model.eval()

# Side length in pixels that every input image is resized to.
IMG_SIZE = 224

# Class names, in the same order as the model's output logits
# (predict_image zips them positionally with the softmax output).
MASK_LABEL = ["Mask worn properly.", "Mask not worn properly: nose out", "Mask not worn properly: chin and nose out", "Didn't wear mask."]

# Test-time preprocessing:
# 1. Resize to IMG_SIZE x IMG_SIZE.
# 2. Convert the PIL image to a float tensor.
# 3. Normalize each channel with the standard ImageNet mean/std values.
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
def predict_image(image):
    """Classify how a face mask is worn in a single image.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input.
            It is ``None`` when the user submits without capturing a frame.

    Returns:
        dict mapping each label in ``MASK_LABEL`` to its softmax probability
        (a Python float), ordered from most to least likely. This is the
        format ``gr.Label`` consumes. An empty dict is returned when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing to classify; the transform pipeline would fail on None.
        return {}
    # Normalize in transforms_test uses three channel means/stds, so force a
    # 3-channel RGB image (RGBA or grayscale inputs would otherwise fail).
    image = image.convert("RGB")
    # ToTensor yields (C, H, W); add the leading batch dimension -> (1, C, H, W).
    transformed_tensor = torch.unsqueeze(transforms_test(image), 0)
    # Inference only: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(transformed_tensor)
    probability = torch.flatten(F.softmax(logits, dim=1)).cpu().numpy()
    print(probability)
    labels = {label: float(prob) for label, prob in zip(MASK_LABEL, probability)}
    sorted_labels = dict(sorted(labels.items(), key=lambda item: item[1], reverse=True))
    print(sorted_labels)
    return sorted_labels
# Text shown above (title, description) and below (article) the interface.
title = "ViT Mask Detection"
description = "<p style='text-align: center'>Gradio demo for ViT-16 Mask Image Classification created by <a href='https://github.com/stevenlimcorn'>Steven Limcorn</a></p>"
article = "<p style='text-align: center'>An Application made by stevenlimcorn. Notebook access at: <a href='https://github.com/stevenlimcorn/Mask-Classification'>Mask Classification</a></p>"

# Webcam image in (as a PIL image), per-class confidences out via gr.Label.
# NOTE(review): `source="webcam"` is the Gradio 3.x keyword. Gradio 4 replaced
# it with `sources=["webcam"]`; confirm against the pinned gradio version.
demo = gr.Interface(
    predict_image,
    inputs=gr.Image(label="Input Image", type="pil", source="webcam"),
    outputs=gr.Label(),
    title=title,
    description=description,
    article=article,
)

# Start the server only when executed as a script (e.g. `python app.py`),
# so importing this module does not block on launch().
if __name__ == "__main__":
    demo.launch()