Upload 6 files
- app.py +80 -0
- examples/im1.png +0 -0
- examples/im2.png +0 -0
- model.py +44 -0
- requirements.txt +5 -0
- sex_tiny_vgg_defualt_weights.pth +3 -0
app.py
ADDED
@@ -0,0 +1,80 @@
from model import TinyVGG
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import os


def predict(img):
    """Transforms img, detects faces, and returns a sex prediction for each detected face."""
    # Create the TinyVGG model
    model = TinyVGG(input_shape=3,  # number of color channels (3 for RGB)
                    hidden_units=10,
                    output_shape=2)

    # Load saved weights
    model.load_state_dict(torch.load(f="sex_tiny_vgg_defualt_weights.pth", map_location=torch.device("cpu")))
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    class_names = ['female', 'male']
    input_image = cv2.imread(img)  # img is a file path (see gr.Image(type="filepath") below)

    # Detect faces in the input image using OpenCV's Haar cascade, which expects grayscale
    input_image_gray = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    faces = face_cascade.detectMultiScale(input_image_gray, scaleFactor=1.1, minNeighbors=5, minSize=(64, 64))

    if len(faces) == 0:
        return "No faces detected in the image."

    # Put model into evaluation mode
    model.eval()
    results = []
    # Process each detected face
    for i, (x, y, w, h) in enumerate(faces):
        face_image = input_image[y:y+h, x:x+w]  # Extract face (BGR)
        face_image_pil = Image.fromarray(cv2.cvtColor(face_image, cv2.COLOR_BGR2RGB))  # Convert to PIL format (RGB)
        face_image_tensor = transform(face_image_pil).unsqueeze(0)  # Preprocess face for classification
        # Turn on inference mode, pass the transformed face through the model,
        # and turn the prediction logits into prediction probabilities
        with torch.inference_mode():
            pred_probs = torch.softmax(model(face_image_tensor), dim=1)
        # Create a prediction label and prediction probability dictionary for each prediction class
        pred_labels_and_probs = {class_names[j]: float(pred_probs[0][j]) for j in range(len(class_names))}
        if pred_labels_and_probs['female'] >= pred_labels_and_probs['male']:
            results.append(f"Face {i+1}: (Female: {pred_labels_and_probs['female']:.3f})")
        else:
            results.append(f"Face {i+1}: (Male: {pred_labels_and_probs['male']:.3f})")
    return "\n".join(results)


# Create title and description strings
title = "Sex Prediction"
description = "A TinyVGG feature-extractor computer vision model that classifies human face images as male or female."

# Create examples list from the "examples/" directory
example_list = [["examples/" + example] for example in os.listdir("examples")]

# Create Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="filepath"),  # pass a file path so cv2.imread can load the image
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    examples=example_list,
    title=title,
    description=description
)

# Launch the app!
if __name__ == "__main__":
    demo.launch()
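With the launch call guarded by `__main__`, `predict` can also be imported and exercised directly as a quick smoke test, without starting the Gradio server. A minimal sketch, assuming the repository files (model.py, the .pth weights, and the examples/ directory) sit in the working directory:

# Quick smoke test of predict() without starting the Gradio server.
# Assumes app.py, model.py, the .pth weights, and examples/ are in the cwd.
from app import predict

# predict() takes a file path, matching gr.Image(type="filepath") above
print(predict("examples/im1.png"))
print(predict("examples/im2.png"))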
examples/im1.png
ADDED
examples/im2.png
ADDED
model.py
ADDED
@@ -0,0 +1,44 @@
import torch
from torch import nn

class TinyVGG(nn.Module):
    """
    Model architecture copying TinyVGG from:
    https://poloclub.github.io/cnn-explainer/
    https://www.learnpytorch.io/04_pytorch_custom_datasets/#:~:text=class%20TinyVGG(,device)%0Amodel_0
    """
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
        super().__init__()
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3,  # how big is the square that's going over the image?
                      stride=1,  # default
                      padding=1),  # options = "valid" (no padding), "same" (output has same shape as input), or an int for a specific amount
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,
                         stride=2)  # default stride value is the same as kernel_size
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Where did this in_features shape come from?
            # Each layer of the network compresses and changes the shape of the input data:
            # a 64x64 input is halved by each of the two max-pool layers, leaving 16x16.
            nn.Linear(in_features=hidden_units*16*16,
                      out_features=output_shape)
        )

    def forward(self, x: torch.Tensor):
        return self.classifier(self.conv_block_2(self.conv_block_1(x)))
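The classifier's `in_features=hidden_units*16*16` can be sanity-checked with a dummy forward pass: the app feeds 64x64 crops, and each of the two MaxPool2d(2) layers halves the spatial size, 64 to 32 to 16. A minimal sketch:

import torch
from model import TinyVGG

# Two MaxPool2d(2) layers halve a 64x64 input twice -> hidden_units x 16 x 16
model = TinyVGG(input_shape=3, hidden_units=10, output_shape=2)
dummy = torch.randn(1, 3, 64, 64)  # (batch, channels, height, width)
features = model.conv_block_2(model.conv_block_1(dummy))
print(features.shape)      # torch.Size([1, 10, 16, 16])
print(model(dummy).shape)  # torch.Size([1, 2]) -- one logit per class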
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch==2.0.1
torchvision==0.15.2
gradio==4.27.0
opencv-python
numpy
sex_tiny_vgg_defualt_weights.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1284bdb8b1796e719c9b74ab2e794918acc7c72a3f905eba858349c184ec8dc6
size 36419