mask-detection / app.py
bnsapa's picture
Update app.py
5ee4460
raw
history blame
1.61 kB
import gradio as gr
from torchvision import models
import torch.nn as nn
import torch
import os
from PIL import Image
from torchvision.transforms import transforms
from dotenv import load_dotenv
# Pull variables from a local .env file (if any) into os.environ.
load_dotenv()

# SHARE is a boolean flag read from the environment. os.getenv returns a
# string, and any non-empty string (including "false" or "0") is truthy,
# so compare against explicit "on" spellings instead of using the raw value.
share = os.getenv("SHARE", "").strip().lower() in {"1", "true", "yes", "on"}
# VGG19 backbone wrapped by NeuralNet. ImageNet weights are intentionally NOT
# downloaded: the whole model, backbone included, is restored from the
# fine-tuned checkpoint with a strict load_state_dict() before use, so the
# download only cost startup time and bandwidth. (`pretrained=` is also
# deprecated since torchvision 0.13 in favour of `weights=`.)
pretrained_model = models.vgg19()
class NeuralNet(nn.Module):
    """Binary image classifier: shared backbone -> 1-unit linear head -> sigmoid.

    The backbone is the module-level ``pretrained_model`` and is expected to
    emit 1000 features per sample (the input width of the linear head).
    Layer order inside ``self.model`` must not change: ``nn.Sequential`` names
    parameters by position (e.g. ``model.2.weight``), and saved state dicts
    depend on those names.
    """

    def __init__(self):
        super().__init__()
        head_layers = [
            nn.Flatten(),
            nn.Linear(1000, 1),
            nn.Sigmoid(),
        ]
        # Position 0 is the backbone; the head occupies positions 1-3.
        self.model = nn.Sequential(pretrained_model, *head_layers)

    def forward(self, x):
        """Run ``x`` through backbone and head; outputs lie in (0, 1)."""
        scores = self.model(x)
        return scores
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NeuralNet()
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust. Passing weights_only=True would harden this on torch >= 1.13.
model.load_state_dict(torch.load("mask_detection.pth", map_location=device))
model = model.to(device)
# Switch to inference mode. torchvision's VGG classifier contains Dropout
# layers, which otherwise stay active and make predictions non-deterministic.
model.eval()
# Inference-time preprocessing. It is deliberately deterministic: random
# augmentations such as RandomHorizontalFlip belong in the training pipeline
# only, because here they made the same image yield different model inputs.
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] pixel range to [-1, 1].
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def maskDetection(image):
    """Classify whether the person in ``image`` is wearing a mask.

    Args:
        image: Value from Gradio's ``"image"`` input. This is presumably an
            (H, W, 3) numpy array (TODO confirm for the Gradio version in
            use); it may be ``None`` when nothing was submitted.

    Returns:
        A human-readable verdict string for the ``"text"`` output.
    """
    if image is None:
        return "No image provided"

    # Convert in memory. The previous round-trip through a fixed "input.png"
    # on disk could race if requests are handled concurrently, and it left
    # the file behind. convert("RGB") also copes with grayscale or RGBA arrays.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    # Prediction needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        probability = model(batch).item()

    # A sigmoid score below 0.5 is reported as "has mask".
    if probability < 0.5:
        return "Person in the pic has mask"
    return "Person in the pic does not have mask"
# Gradio UI: one image in, one text verdict out.
iface = gr.Interface(
    fn=maskDetection,
    inputs="image",
    outputs="text",
    title="Mask Detection",
)

if __name__ == "__main__":
    # When `share` is truthy, listen on all interfaces; otherwise loopback only.
    server = "0.0.0.0" if share else "127.0.0.1"
    iface.launch(server_name=server)