Spaces:
Build error
Build error
import csv | |
import os.path | |
import time | |
import cv2 | |
import gdown | |
import numpy as np | |
import streamlit as st | |
import torch | |
from PIL import Image | |
def load_classes(csv_reader):
    """
    Load a class-name -> class-id mapping from CSV rows.

    Every row must contain exactly two fields, ``class_name,class_id``, and
    ``class_id`` must parse as an integer.

    :param csv_reader: iterable of rows (lists of strings), e.g. ``csv.reader(f)``
    :return: dict mapping class name (str) to class id (int)
    :raises ValueError: on a row without exactly two fields, a non-integer id,
        or a duplicate class name; the message carries the 1-based row number.
    """
    result = {}
    for line, row in enumerate(csv_reader, start=1):
        try:
            class_name, class_id = row
        except ValueError:
            # The unpacking error ("too many values...") adds nothing to our message.
            raise ValueError(
                'line {}: format should be \'class_name,class_id\''.format(line)
            ) from None
        try:
            class_id = int(class_id)
        except ValueError as err:
            # Report which row is broken instead of a bare int() parse error.
            raise ValueError(
                'line {}: malformed class id: {!r}'.format(line, class_id)
            ) from err
        if class_name in result:
            raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
        result[class_name] = class_id
    return result
def draw_caption(image, box, caption):
    """
    Write ``caption`` onto ``image`` in place, just above the box's top-left corner.

    The text is rendered twice -- a thick black pass followed by a thin white
    pass -- so it stays readable on light and dark backgrounds. Only the first
    two coordinates of ``box`` are used; the rectangle itself is not drawn here.

    :param image: image array that is drawn on in place
    :param box: bounding box; ``box[0]``/``box[1]`` (truncated to int) give the anchor x, y
    :param caption: text to render
    :return: None
    """
    x, y = np.array(box).astype(int)[:2]
    origin = (x, y - 10)  # 10 px above the top edge of the box
    for colour, weight in (((0, 0, 0), 2), ((255, 255, 255), 1)):
        cv2.putText(image, caption, origin, cv2.FONT_HERSHEY_PLAIN, 1, colour, weight)
def load_labels():
    """
    Read ``data/labels.csv`` and return the inverse of its mapping.

    :return: dict mapping class id (int) to class name (str)
    """
    with open("data/labels.csv", 'r') as f:
        name_to_id = load_classes(csv.reader(f, delimiter=','))
    return {class_id: name for name, class_id in name_to_id.items()}
def download_models(ids):
    """
    Download every model checkpoint that is not already present under ``model/``.

    :param ids: mapping of model name -> Google Drive file id; each model is
        stored at ``model/<name>.pt``
    :return: None
    :raises RuntimeError: if a download finishes without producing the expected file
    """
    # Make sure the target directory exists on a fresh checkout before writing into it.
    os.makedirs("model", exist_ok=True)
    with st.spinner('Downloading models, this may take a minute...'):
        for key, file_id in ids.items():
            output = f"model/{key}.pt"
            if os.path.isfile(output):
                continue  # already downloaded on a previous run
            url = f"https://drive.google.com/uc?id={file_id}"
            gdown.download(url=url, output=output)
            # Fail here with a clear message rather than later inside torch.load.
            if not os.path.isfile(output):
                raise RuntimeError(f"Failed to download model '{key}' to {output}")
def load_model(model_path, prefix: str = 'model/'):
    """
    Load a pickled model checkpoint and put it in inference mode.

    The file ``<prefix><model_path>.pt`` must contain a whole pickled
    ``torch.nn.Module``, optionally wrapped in ``nn.DataParallel``; the
    ``.module`` wrapper is removed when present.

    Security: this unpickles the file, which can execute arbitrary code.
    Only load checkpoints from a trusted source.

    :param model_path: model name, without the ``.pt`` extension
    :param prefix: path prefix prepended to ``model_path``
    :return: the model on CUDA if available, otherwise on CPU, in eval mode
    """
    import inspect

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    load_kwargs = {'map_location': device}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full
    # nn.Module; older releases do not accept the argument at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(f"{prefix}{model_path}.pt", **load_kwargs)
    # Checkpoints saved from a DataParallel wrapper expose the real network as .module.
    model = getattr(model, 'module', model)
    model = model.to(device)
    model.eval()  # sets .training = False on the module and all submodules
    return model
def _prepare_input(image, min_side=608, max_side=1024):
    """
    Resize, zero-pad and normalise an image into a model input batch.

    :param image: HxWxC image array with values in 0..255
    :param min_side: target length of the shorter side after resizing
    :param max_side: upper bound on the longer side after resizing
    :return: ``(batch, scale)`` -- a float32 tensor of shape (1, C, H', W') and
        the factor the original image was scaled by
    """
    rows, cols, cns = image.shape
    # Scale the shorter side to min_side, unless that pushes the longer side past max_side.
    scale = min_side / min(rows, cols)
    largest_side = max(rows, cols)
    if largest_side * scale > max_side:
        scale = max_side / largest_side
    image = cv2.resize(image, (int(round(cols * scale)), int(round(rows * scale))))
    rows, cols, cns = image.shape
    # Zero-pad bottom/right to a multiple of 32. This adds a full 32 px when a side
    # is already aligned; kept as-is so the model input is unchanged.
    pad_rows = 32 - rows % 32
    pad_cols = 32 - cols % 32
    padded = np.zeros((rows + pad_rows, cols + pad_cols, cns), dtype=np.float32)
    padded[:rows, :cols, :] = image.astype(np.float32)
    # Scale to [0, 1], then per-channel mean/std normalisation
    # (these are the widely used ImageNet statistics).
    padded /= 255
    padded -= [0.485, 0.456, 0.406]
    padded /= [0.229, 0.224, 0.225]
    # HWC -> NCHW with a batch dimension of 1.
    batch = np.transpose(np.expand_dims(padded, 0), (0, 3, 1, 2))
    return torch.from_numpy(batch), scale


def process_img(model, image, labels, caption: bool = True, thickness=None, score_threshold: float = 0.5):
    """
    Run ``model`` on ``image`` and return a copy with the detections drawn on it.

    :param model: detector called as ``model(batch)`` and returning
        ``(scores, classification, transformed_anchors)``
    :param image: HxWx3 image array (the app passes RGB; box colours assume RGB)
    :param labels: mapping class id -> label name
    :param caption: whether to write the label name above each box
    :param thickness: box line thickness in pixels; ``None`` (default) scales it
        with the image's shorter side (1 px per 100 px, at least 1)
    :param score_threshold: detections scoring at or below this are dropped
    :return: copy of ``image`` with boxes (and optional captions) drawn
    """
    colors = {
        'with_mask': (0, 255, 0),
        'without_mask': (255, 0, 0),
        'mask_weared_incorrect': (190, 100, 20)
    }
    fallback_color = (255, 255, 255)  # for labels missing from the table above

    image_orig = image.copy()  # draw on a copy; the caller's array is left untouched
    if thickness is None:
        smallest_side = min(image.shape[0], image.shape[1])
        # Clamp so small images never request a thickness of 0.
        thickness = max(1, int(smallest_side / 100))

    batch, scale = _prepare_input(image)
    with torch.no_grad():
        if torch.cuda.is_available():
            batch = batch.cuda()
        scores, classification, transformed_anchors = model(batch.float())

    keep = np.where(scores.cpu() > score_threshold)[0]
    for idx in keep:
        # Boxes are predicted in resized-image coordinates; map back to the original.
        x1, y1, x2, y2 = (int(coord / scale) for coord in transformed_anchors[idx, :4])
        label_name = labels[int(classification[idx])]
        cap = '{}'.format(label_name) if caption else ''
        draw_caption(image_orig, (x1, y1, x2, y2), cap)
        cv2.rectangle(image_orig, (x1, y1), (x2, y2),
                      color=colors.get(label_name, fallback_color),
                      thickness=thickness)
    return image_orig
# Page config
st.set_page_config(layout="centered")
st.title("Face Mask Detection")
st.write('Face Mask Detection on images, videos and webcam feed with ResNet[18~152] models. ')
st.markdown("__Labels:__ with_mask, without_mask, mask_weared_incorrect")
# Google Drive file ids of the models, read from Streamlit secrets
ids = {
    'resnet50_20': st.secrets['resnet50'],
    'resnet152_20': st.secrets['resnet152'],
}
# Download all models from drive (files already on disk are skipped)
download_models(ids)
# Split page into columns
left, right = st.columns([5, 3])
# Model selection
labels = load_labels()
model_path = right.selectbox('Choose a model', options=list(ids), index=0)
# NOTE(review): Streamlit re-runs this whole script on every interaction, so the
# model is reloaded from disk each time; consider caching it if the installed
# Streamlit version provides a resource cache.
model = load_model(model_path=model_path) if model_path != '' else None
# Display example selection
index = left.number_input('', min_value=0, max_value=852, value=495, help='Choose an image. ')
# Uploader
uploaded = st.file_uploader("Try it out with your own image!", type=['.jpg', '.png', '.jfif'])
if uploaded is not None:
    # Force 3-channel RGB: PNG uploads may be RGBA, palette or greyscale, which
    # would break the per-channel normalisation in process_img.
    image = np.array(Image.open(uploaded).convert('RGB'))
else:
    # Load the chosen validation image (OpenCV reads BGR) and convert it to RGB
    example_path = f'data/validation/image/maksssksksss{str(index)}.jpg'
    image = cv2.imread(example_path)
    if image is None:
        # cv2.imread returns None instead of raising for a missing/unreadable file.
        st.error(f'Could not read example image {example_path}')
        st.stop()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Process img
with st.spinner('Please wait while the image is being processed... This may take a while. '):
    image = process_img(model, image, labels, caption=False)
left.image(cv2.resize(image, (450, 300)))
# Write labels dict and device on right
right.write({
    'green': 'with_mask',
    'orange': 'mask_weared_incorrect',
    'red': 'without_mask'
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
right.write(device)
# Example gallery: list the directory once, in a stable order, and decode each
# image fully so its file handle is closed right away.
images = []
for example_name in sorted(os.listdir('data/examples/')):
    with Image.open(f'data/examples/{example_name}') as example:
        images.append(example.copy())
# Display examples
st.image(images, width=350)