File size: 1,775 Bytes
44b4267
 
 
 
 
 
c056859
44b4267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a907c1
 
44b4267
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import gradio as gr
import os
import torch.nn.functional as F
import torch
from torchvision import transforms

# Load the whole pickled classifier onto the CPU.
# NOTE(review): assumes model.pth holds a complete nn.Module (it is called
# directly in predict_image) — confirm. torch.load unpickles arbitrary Python
# objects, so only ever point this at a trusted checkpoint. On PyTorch >= 2.6
# torch.load defaults to weights_only=True, which rejects a fully pickled
# model; if loading fails there, pass weights_only=False (acceptable only
# because this file is trusted).
model = torch.load("./model.pth", map_location=torch.device("cpu"))
# Switch to inference mode: a module pickled while training keeps its
# `training` flag, which would leave dropout active and make predictions noisy.
model.eval()
# Side length (pixels) of the square every input image is resized to.
IMG_SIZE = 224
# Class names indexed by model output position (logit i -> MASK_LABEL[i]).
# NOTE(review): this order must match the label encoding used at training
# time — confirm against the training notebook.
MASK_LABEL = [
    "Mask worn properly.",
    "Mask not worn properly: nose out",
    "Mask not worn properly: chin and nose out",
    "Didn't wear mask.",
]

# Test-time preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a float
# tensor, then normalize each channel with the standard ImageNet mean/std
# (presumably the same statistics used when the model was trained).
transforms_test = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

def predict_image(image):
    """Classify how a face mask is worn in a single image.

    Args:
        image: PIL image from the Gradio ``Image`` input (``type="pil"``),
            or ``None`` if the user submitted without providing an image.

    Returns:
        Dict mapping each label in ``MASK_LABEL`` to its softmax probability,
        ordered from most to least likely, or ``None`` when no image was
        given (which leaves the ``gr.Label`` output empty instead of raising).
    """
    if image is None:
        return None
    # Normalize() expects exactly 3 channels; uploaded/captured images can be
    # RGBA or grayscale, so force RGB first.
    image = image.convert("RGB")
    # Add a leading batch dimension of size 1 before calling the model.
    batch = torch.unsqueeze(transforms_test(image), 0)
    # Pure inference: don't build an autograd graph for every request.
    with torch.no_grad():
        logits = model(batch)
    probability = torch.flatten(F.softmax(logits, dim=1)).cpu().numpy()
    print(probability)
    # NOTE(review): zip() silently truncates if the model emits a different
    # number of classes than MASK_LABEL has entries.
    labels = {label: prob.item() for label, prob in zip(MASK_LABEL, probability)}
    sorted_labels = dict(sorted(labels.items(), key=lambda item: item[1], reverse=True))
    print(sorted_labels)
    return sorted_labels

# Text shown around the Gradio interface: a page heading, a centered HTML
# blurb above the inputs, and a centered HTML footer with the notebook link.
title = "ViT Mask Detection"
description = (
    "<p style='text-align: center'>"
    "Gradio demo for ViT-16 Mask Image Classification created by "
    "<a href='https://github.com/stevenlimcorn'>Steven Limcorn</a>"
    "</p>"
)
article = (
    "<p style='text-align: center'>"
    "An Application made by stevenlimcorn. Notebook access at: "
    "<a href='https://github.com/stevenlimcorn/Mask-Classification'>Mask Classification</a>"
    "</p>"
)

# Web UI: a webcam-captured PIL image goes in, class probabilities come out
# and are rendered by gr.Label.
# NOTE(review): `source="webcam"` is the Gradio 3.x keyword; Gradio 4+
# renamed it to `sources=["webcam"]` — update this call if Gradio is upgraded.
demo = gr.Interface(
    predict_image,
    inputs=gr.Image(label="Input Image", type="pil", source="webcam"),
    outputs=gr.Label(),
    title=title,
    description=description,
    article=article,
)

# Start the server only when executed as a script, so merely importing this
# module does not launch (and block on) the web app.
if __name__ == "__main__":
    demo.launch()