anasazasaa's picture
Update app.py
128514a verified
raw
history blame contribute delete
No virus
4.85 kB
#!/usr/bin/env python
import functools
import os
import pathlib
import cv2
import dlib
import gradio as gr
import huggingface_hub
import numpy as np
import pretrainedmodels
import torch
import torch.nn as nn
import torch.nn.functional as F
DESCRIPTION = '# [Age Estimation](https://github.com/yu4u/age-estimation-pytorch)'

# Start-up diagnostics: the detector files below are opened by relative path,
# so logging the working directory and its listing helps when they are missing.
print("Current directory:", os.getcwd())
print("Files in the current directory:", os.listdir('.'))

# Caffe face-detector definition and weights, resolved relative to the
# current working directory.
_SSD_PROTOTXT_PATH = 'deploy.prototxt'
_SSD_WEIGHTS_PATH = 'res10_300x300_ssd_iter_140000.caffemodel'
ssd_net = cv2.dnn.readNetFromCaffe(_SSD_PROTOTXT_PATH, _SSD_WEIGHTS_PATH)
def get_model(model_name='se_resnext50_32x4d',
              num_classes=101,
              pretrained='imagenet'):
    """Build a ``pretrainedmodels`` backbone with a new classification head.

    Args:
        model_name: Name looked up in the ``pretrainedmodels`` module namespace.
        num_classes: Number of outputs of the replacement final linear layer.
        pretrained: Forwarded unchanged to the constructor's ``pretrained``
            argument.

    Returns:
        The network with ``last_linear`` replaced by a fresh
        ``nn.Linear(in_features, num_classes)`` and ``avg_pool`` replaced by
        ``nn.AdaptiveAvgPool2d(1)``.
    """
    constructor = pretrainedmodels.__dict__[model_name]
    backbone = constructor(pretrained=pretrained)
    in_features = backbone.last_linear.in_features
    # New head: one logit per class.
    backbone.last_linear = nn.Linear(in_features, num_classes)
    # Pool to 1x1 regardless of the spatial size of the final feature map.
    backbone.avg_pool = nn.AdaptiveAvgPool2d(1)
    return backbone
def load_model(device):
    """Download the pretrained age-estimation weights and return the model.

    Args:
        device: ``torch.device`` the model is moved to.

    Returns:
        An ``se_resnext50_32x4d`` network built by ``get_model`` (default
        number of classes), with the downloaded state dict loaded, moved to
        ``device`` and switched to eval mode.
    """
    model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
    path = huggingface_hub.hf_hub_download(
        'public-data/yu4u-age-estimation-pytorch', 'pretrained.pth')
    # Deserialize onto the CPU first. Without map_location, tensors that were
    # saved from a GPU are restored onto CUDA, which fails on CPU-only hosts;
    # the model is moved to the requested device below.
    # NOTE(review): torch.load unpickles the file. It comes from a fixed HF
    # repo, but consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model
def load_image(path):
    """Read the image at ``path`` with OpenCV and return the result unchanged.

    ``cv2.imread`` returns ``None`` (rather than raising) when the file
    cannot be read; callers must check for that.
    """
    return cv2.imread(path)
def draw_label(image,
               point,
               label,
               font=cv2.FONT_HERSHEY_SIMPLEX,
               font_scale=0.8,
               thickness=1):
    """Draw ``label`` on a filled background box, modifying ``image`` in place.

    ``point`` is the text origin passed to ``cv2.putText``. The background
    rectangle of color ``(255, 0, 0)`` spans the text width and extends the
    text height upward from ``point``. The label is drawn in white with
    anti-aliasing on top of it.
    """
    (text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale,
                                                  thickness)
    left, bottom = point
    top_left = (left, bottom - text_h)
    bottom_right = (left + text_w, bottom)
    cv2.rectangle(image, top_left, bottom_right, (255, 0, 0), cv2.FILLED)
    cv2.putText(image, label, point, font, font_scale, (255, 255, 255),
                thickness, lineType=cv2.LINE_AA)
def detect_faces_ssd(image, confidence_threshold=0.5):
    """Detect faces using the module-level OpenCV DNN network ``ssd_net``.

    Args:
        image: Image array as returned by ``cv2.imread`` (height, width first).
        confidence_threshold: Detections whose score is not strictly greater
            than this value are discarded.

    Returns:
        List of integer ``[startX, startY, endX, endY]`` arrays in pixel
        coordinates, clipped to the image bounds. Boxes that have no area
        after clipping are dropped.
    """
    img_h, img_w = image.shape[:2]
    # The network is fed a 300x300 blob with per-channel mean subtraction.
    blob = cv2.dnn.blobFromImage(cv2.resize(image, (300, 300)), 1.0,
                                 (300, 300), (104.0, 177.0, 123.0))
    ssd_net.setInput(blob)
    detections = ssd_net.forward()

    # Each row detections[0, 0, i] carries the score at index 2 and box corners
    # at 3:7 in coordinates relative to the image size. Those may fall outside
    # [0, 1], so the scaled box is clipped to valid pixel indices.
    scale = np.array([img_w, img_h, img_w, img_h])
    faces = []
    for det in detections[0, 0]:
        if det[2] <= confidence_threshold:
            continue
        box = (det[3:7] * scale).astype("int")
        box[[0, 2]] = np.clip(box[[0, 2]], 0, img_w - 1)
        box[[1, 3]] = np.clip(box[[1, 3]], 0, img_h - 1)
        if box[2] <= box[0] or box[3] <= box[1]:
            # Inverted or fully out-of-frame detection: nothing to crop.
            continue
        faces.append(box)
    return faces
@torch.inference_mode()
def predict(image_path, model, device, margin=0.4, input_size=224):
    """Estimate the age of every face detected in the image at ``image_path``.

    Faces are located with ``detect_faces_ssd``. Each box is enlarged by
    ``margin`` times its width/height on every side (clamped to the image),
    resized to ``input_size`` x ``input_size`` and passed through ``model``.
    The predicted age is the expectation of the softmax distribution over the
    model outputs, where output index ``k`` stands for age ``k``.

    Args:
        image_path: Path of the image file to analyse.
        model: Network returning one logit per age class for a 1xCxHxW batch.
        device: Device the input tensor is moved to (should match ``model``).
        margin: Fraction of the box size added on each side before cropping.
        input_size: Side length of the square crop fed to the model.

    Returns:
        A JSON-serializable list with one dict per face:
        ``{'age': int, 'text': str, 'face_coordinates': (x, y)}``. The
        coordinates are the top-left corner of the detected (un-enlarged) box.

    Raises:
        ValueError: If ``image_path`` is ``None`` or cannot be read as an image.
    """
    if image_path is None:
        raise ValueError('No image was provided.')
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f'Could not read an image from {image_path!r}.')
    image_h, image_w = image.shape[:2]

    age_data = []
    for start_x, start_y, end_x, end_y in detect_faces_ssd(image):
        box_w = end_x - start_x
        box_h = end_y - start_y
        # Enlarge the detection by `margin` on each side, clamped to the image.
        x1 = max(int(start_x - margin * box_w), 0)
        y1 = max(int(start_y - margin * box_h), 0)
        x2 = min(int(end_x + margin * box_w), image_w - 1)
        y2 = min(int(end_y + margin * box_h), image_h - 1)
        crop = image[y1:y2 + 1, x1:x2 + 1]
        if crop.size == 0:
            # Out-of-frame or inverted box; cv2.resize would raise on it.
            continue
        face = cv2.resize(crop, (input_size, input_size))

        # HWC -> 1xCxHxW float32. Pixel values are not rescaled or reordered.
        input_blob = torch.from_numpy(
            np.transpose(face.astype(np.float32), (2, 0, 1))
        ).unsqueeze(0).to(device)
        probs = F.softmax(model(input_blob), dim=-1).cpu().numpy()
        # Expected age: weight each class index by its probability. The range
        # follows the model's output width instead of a hard-coded 101.
        ages = np.arange(probs.shape[-1])
        predicted_age = (probs * ages).sum(axis=-1).item()
        age = int(predicted_age)  # native int so the result is JSON-serializable
        age_data.append({
            'age': age,
            'text': f'{age}',
            'face_coordinates': (int(start_x), int(start_y)),
        })
    return age_data
def main():
    """Load the age model and launch the Gradio demo.

    Uses ``cuda:0`` when CUDA is available, otherwise the CPU. ``*.jpg`` files
    in ``sample_images/`` (relative to the working directory) are offered as
    examples. Examples are cached when the ``CACHE_EXAMPLES`` environment
    variable equals ``'1'``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = load_model(device)
    fn = functools.partial(predict, model=model, device=device)

    image_dir = pathlib.Path('sample_images')
    examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]

    demo = gr.Interface(
        fn=fn,
        # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4.
        # gr.Image is the supported component in both versions.
        inputs=gr.Image(type="filepath"),
        outputs="json",
        examples=examples,
        title="Age Estimation",
        description=DESCRIPTION,
        cache_examples=os.getenv('CACHE_EXAMPLES') == '1',
    )
    demo.launch()


if __name__ == '__main__':
    main()