from model import TinyVGG
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import os
import numpy
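
# Note: TinyVGG (defined in model.py) is assumed here to be the small
# CNN-Explainer-style architecture (conv blocks + a linear classifier head);
# only its constructor signature (input_shape, hidden_units, output_shape)
# matters for this app.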
def predict(img):
    """Detects faces in img, classifies each face as male or female,
    and returns the image annotated with bounding boxes and labels.
    """
    # Create a TinyVGG model instance
    model = TinyVGG(input_shape=3,  # number of color channels (3 for RGB)
                    hidden_units=10,
                    output_shape=2)  # two classes: female, male
    # Load the saved weights onto the CPU
    model.load_state_dict(torch.load(f="sex_tiny_vgg_defualt_weights.pth",
                                     map_location=torch.device("cpu")))
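    # Note: the model is rebuilt and its weights reloaded on every call; for a
    # busier Space, both could be hoisted to module scope and loaded once at startup.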
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
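    # The resize matches the 64x64 spatial size the classifier head expects,
    # and the mean/std values are the standard ImageNet normalization statistics.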
    class_names = ['female', 'male']
    input_image = numpy.array(img)
    # OpenCV expects BGR channel order, so convert the RGB array for detection
    input_image_bgr = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    # Detect faces in the input image using OpenCV's pretrained Haar cascade
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    faces = face_cascade.detectMultiScale(input_image_bgr, scaleFactor=1.1, minNeighbors=5, minSize=(64, 64))
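    # scaleFactor=1.1 steps the image pyramid down by 10% per scale,
    # minNeighbors=5 requires several overlapping hits before a face is kept,
    # and minSize=(64, 64) discards detections smaller than the model input.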
if len(faces) == 0:
return input_image
else:
        model.eval()  # switch to evaluation mode (affects dropout/batch-norm layers)
# Process each detected face
        for (x, y, w, h) in faces:
            face_image = input_image[y:y+h, x:x+w]  # extract the face region (already RGB)
            face_image_pil = Image.fromarray(face_image)  # convert to PIL format
            face_image_tensor = transform(face_image_pil).unsqueeze(0)  # preprocess and add a batch dimension
            # Run inference without tracking gradients
            with torch.inference_mode():
                # Turn the two raw logits into class probabilities with softmax
                pred_probs = torch.softmax(model(face_image_tensor), dim=1)
            # Determine the predicted class label and probability
            pred_class = int(pred_probs.argmax(dim=1))
            sex_label = class_names[pred_class].capitalize()
            prob = float(pred_probs[0][pred_class]) * 100
            # Draw a bounding box around the face
            cv2.rectangle(input_image, (x, y), (x+w, y+h), (0, 255, 0), 2)
            # Write the label just above the box (clamped so it stays on-image)
            label = f'{sex_label}, P={prob:.2f}%'
            cv2.putText(input_image, label, (x, max(y - 10, 15)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
return input_image
# Create title and description strings
title = "Sex Prediction"
description = "A TinyVGG feature-extractor computer vision model that classifies human face images as male or female."
# Create examples list from "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]
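# (Assumes everything in the examples/ directory is an image file Gradio can display.)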
# Create Gradio interface
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil"),
    outputs=gr.Image(label="Annotated image"),
examples=example_list,
title=title,
description=description
)
# Launch the app!
demo.launch()