njgroene's picture
Update app.py
bcfaff6
raw history blame
No virus
3.39 kB
import gradio as gr
from transformers import pipeline
from PIL import Image
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from torch_mtcnn import detect_faces
from torch_mtcnn import show_bboxes
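# NOTE: the local pipeline() defined below rebinds (shadows) the `pipeline` name
# imported from transformers. It replaced this hosted-model call:
# pipeline = pipeline(task="image-classification", model="njgroene/fairface")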
# Human-readable labels, indexed by the argmax of the corresponding logit slice.
_GENDER_LABELS = ('Male', 'Female')
_AGE_LABELS = ('0-2', '3-9', '10-19', '20-29', '30-39',
               '40-49', '50-59', '60-69', '70+')
_FAIRFACE_WEIGHTS = 'res34_fair_align_multi_7_20190809.pt'

# Holds the (model, device) pair so the checkpoint is read from disk only once
# per process instead of on every request.
_MODEL_CACHE = {}


def _get_fairface_model():
    """Return the cached ``(model, device)`` pair, loading it on first use."""
    if 'fair_7' not in _MODEL_CACHE:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # No ImageNet weights are requested: load_state_dict() is strict by
        # default, so the checkpoint below replaces every parameter anyway and
        # the pretrained download would be wasted work.
        model = torchvision.models.resnet34()
        model.fc = nn.Linear(model.fc.in_features, 18)
        model.load_state_dict(
            torch.load(_FAIRFACE_WEIGHTS, map_location=torch.device('cpu'))
        )
        model = model.to(device)
        model.eval()
        _MODEL_CACHE['fair_7'] = (model, device)
    return _MODEL_CACHE['fair_7']


def pipeline(img):
    """Estimate the gender and age bracket of the first face found in ``img``.

    The image is cropped to the first bounding box returned by
    ``detect_faces``, resized to 224x224, normalized, and passed through the
    18-output ResNet-34. Outputs 7-8 are read as gender logits and outputs
    9-17 as age-bracket logits; outputs 0-6 are not used here (presumably the
    7 race classes of the "fair 7" model -- confirm against the checkpoint).

    Args:
        img: A PIL image containing a face.

    Returns:
        A two-element list ``[gender, age]`` where ``gender`` is ``'Male'`` or
        ``'Female'`` and ``age`` is a bracket string such as ``'20-29'``.

    Raises:
        ValueError: If ``img`` is ``None`` or no face is detected in it.
    """
    if img is None:
        raise ValueError("No image was provided")
    # The normalization below assumes exactly three channels, so grayscale or
    # RGBA uploads are converted first.
    if img.mode != "RGB":
        img = img.convert("RGB")

    bounding_boxes, _ = detect_faces(img)
    if len(bounding_boxes) == 0:
        raise ValueError("No face could be detected in the image")
    # Only the first detection is used; its first four values are the crop box.
    face = img.crop(tuple(bounding_boxes[0][:4]))

    model, device = _get_fairface_model()
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add the leading batch dimension the model expects: (1, 3, 224, 224).
    batch = preprocess(face).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch).cpu().numpy()[0]

    # Softmax is monotonic, so taking the argmax of the raw logits picks the
    # same class without risking overflow in exp().
    gender_idx = int(np.argmax(logits[7:9]))
    age_idx = int(np.argmax(logits[9:18]))
    return [_GENDER_LABELS[gender_idx], _AGE_LABELS[age_idx]]
def predict(image):
    """Return a one-sentence gender and age-range description for ``image``.

    Delegates to ``pipeline`` and formats its two-element result
    (gender label, age bracket) into the text shown to the user.
    """
    gender, age_range = pipeline(image)
    return f"A {gender} in the age range of {age_range}"
# Build the single-image -> text demo around predict() and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(
        label="Upload a profile picture of a single person",
        type="pil",
    ),
    outputs="text",
    title="Estimate age and gender from profile picture",
    examples=["ex0.jpg", "ex1.jpg", "ex2.jpg", "ex3.jpg"],
)
demo.launch()