|
import torch |
|
from torchvision.models.detection import fasterrcnn_resnet50_fpn |
|
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor |
|
from torchvision.transforms import functional as F |
|
from PIL import Image, ImageDraw |
|
import gradio as gr |
|
|
|
|
|
# Maps the detector's integer label ids to the captions drawn on the image.
# Despite the name, these are this app's face-mask classes (id 0 is the
# background class), not the COCO categories.
COCO_CLASSES = {
    0: "Background",
    1: "Without Mask",
    2: "With Mask",
    3: "Incorrect Mask",
}
|
|
|
|
|
def get_model(num_classes=4):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a resized box head.

    Args:
        num_classes: Number of classes the replacement box predictor
            outputs. The default of 4 matches the four entries in
            ``COCO_CLASSES``.

    Returns:
        The torchvision detection model with its box predictor swapped for
        a fresh ``FastRCNNPredictor``. No detection checkpoint is loaded
        here (``weights=None``).
    """
    detector = fasterrcnn_resnet50_fpn(weights=None)
    roi_heads = detector.roi_heads
    # Reuse the stock classifier's input width so the new head plugs
    # straight into the existing ROI feature extractor.
    head_width = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(head_width, num_classes)
    return detector
|
|
|
|
|
# Inference is pinned to the CPU; the checkpoint tensors are remapped to it.
device = torch.device("cpu")

model = get_model()
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source.
model.load_state_dict(
    torch.load("fasterrcnn_resnet50_epoch_4.pth", map_location=device)
)
# Move to the target device and switch to evaluation mode for serving.
model.to(device).eval()
|
|
|
|
|
def predict(image, threshold=0.5):
    """Run the mask detector on *image* and return an annotated copy.

    Args:
        image: PIL image to analyse, or ``None`` (for example when the
            form is submitted without an upload).
        threshold: Confidence score a detection must exceed to be drawn.

    Returns:
        A new RGB PIL image with a red box and a ``"<class> (<score>)"``
        caption for every kept detection, or ``None`` when *image* is
        ``None``. The input image itself is never modified.
    """
    if image is None:
        return None

    # convert() returns a new image, so drawing never touches the caller's
    # object, and it gives the three-channel input the detector expects
    # even for RGBA or greyscale uploads.
    annotated = image.convert("RGB")
    image_tensor = F.to_tensor(annotated).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Convert once to plain Python numbers so PIL and the f-string receive
    # floats/ints rather than zero-dimensional tensors.
    detections = zip(
        prediction["boxes"].tolist(),
        prediction["labels"].tolist(),
        prediction["scores"].tolist(),
    )

    draw = ImageDraw.Draw(annotated)
    for box, label, score in detections:
        if score <= threshold:
            continue
        x1, y1, x2, y2 = box
        class_name = COCO_CLASSES.get(label, "Unknown")
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        draw.text((x1, y1), f"{class_name} ({score:.2f})", fill="red")

    return annotated
|
|
|
|
|
# Wrap predict() in a single-image-in, single-image-out web UI and start
# serving it as soon as the script runs.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Face Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Face Mask Detection - Faster R-CNN",
    description="Detects faces with mask, without mask, or incorrectly worn mask.",
)
demo.launch()
|
|