|
|
|
from model import TinyVGG |
|
import cv2 |
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import os |
|
import numpy |
|
from pathlib import Path |
|
|
|
# Holds the classifier after its first load so the weights file is read from
# disk once rather than on every prediction request.
_MODEL_CACHE = {}


def _get_model():
    """Return the TinyVGG classifier in eval mode, loading it on first use."""
    model = _MODEL_CACHE.get("model")
    if model is None:
        model = TinyVGG(input_shape=3,
                        hidden_units=10,
                        output_shape=2)
        model.load_state_dict(
            torch.load(f="sex_tiny_vgg_defualt_weights.pth",
                       map_location=torch.device("cpu"))
        )
        model.eval()
        _MODEL_CACHE["model"] = model
    return model


def predict(img):
    """Detect faces in ``img`` and annotate each one with a predicted sex.

    Faces are located with OpenCV's frontal-face Haar cascade. Each face crop
    is resized to 64x64, normalized, and passed through the TinyVGG model; a
    green box and a ``"<label>, P=<percent>%"`` caption are drawn per face.

    Args:
        img: The input image in RGB channel order. Either a PIL image (any
            mode; it is converted to RGB) or something ``numpy.array`` turns
            into an H x W x 3 uint8 array. ``None`` means no image supplied.

    Returns:
        A numpy RGB array: a copy of the input with the annotations drawn on
        it, or an unannotated copy when no face is detected. ``None`` when
        ``img`` is ``None``.
    """
    if img is None:
        return None

    # Force three channels so grayscale/RGBA uploads neither break the colour
    # conversion below nor the 3-channel Normalize step.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")

    # numpy.array copies, so drawing never touches the caller's image.
    input_image = numpy.array(img)
    input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)

    # NOTE(review): the cascade is still built per call; it could be cached
    # like the model if sharing one classifier across requests is confirmed
    # to be safe.
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    faces = face_cascade.detectMultiScale(input_image_bgr, scaleFactor=1.1, minNeighbors=5, minSize=(64, 64))

    # `faces` may be a numpy array, whose truthiness is ambiguous, so test
    # its length explicitly.
    if len(faces) == 0:
        return input_image

    model = _get_model()
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    for face in faces:
        x, y, w, h = (int(v) for v in face)

        # Crop from the untouched BGR copy, not from `input_image`:
        #  * the BGR->RGB conversion below is then correct (cropping the RGB
        #    array and converting it again swapped red and blue), and
        #  * boxes/labels already drawn for earlier faces cannot leak into
        #    the pixels fed to the model when detections overlap.
        face_bgr = input_image_bgr[y:y + h, x:x + w]
        face_image_pil = Image.fromarray(cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB))
        face_image_tensor = transform(face_image_pil).unsqueeze(0)

        with torch.inference_mode():
            # NOTE(review): sigmoid scores each output independently, so the
            # two values need not sum to 1. If the model was trained with a
            # softmax/cross-entropy objective, softmax would be the matching
            # choice here -- confirm against the training code.
            pred_probs = torch.sigmoid(model(face_image_tensor))

        pred_prob_female = float(pred_probs[0][0])
        pred_prob_male = float(pred_probs[0][1])
        sex_label = 'Male' if pred_prob_male > pred_prob_female else 'Female'
        prob = max(pred_prob_female, pred_prob_male) * 100

        cv2.rectangle(input_image, (x, y), (x + w, y + h), (0, 255, 0), 2)

        label = f'{sex_label}, P={prob:.2f}%'
        # Keep the caption baseline inside the image for faces that touch the
        # top edge; otherwise it sits 10 px above the box as before.
        label_y = max(y - 10, 15)
        cv2.putText(input_image, label, (x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    return input_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text shown at the top of the Gradio page.
title = "Sex Prediction "
description = "An tiny VGG feature extractor computer vision model to classify Human Face images into male or female."

# Build the example gallery from the "examples" directory (resolved against
# the current working directory, as before). Entries are sorted so the gallery
# order is stable, hidden files and subdirectories are skipped because they
# are not images, and a missing directory yields no examples instead of
# raising at import time.
_examples_dir = Path("examples")
if _examples_dir.is_dir():
    example_list = [
        [entry.as_posix()]
        for entry in sorted(_examples_dir.iterdir())
        if entry.is_file() and not entry.name.startswith(".")
    ]
else:
    example_list = []

# Single image in, single annotated image out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image()
    ],
    # Pass None rather than an empty list when there is nothing to show.
    examples=example_list or None,
    title=title,
    description=description,
)

demo.launch()