# StyleGene: models/stylegene/fair_face_model.py
# (wmpscc, commit 7d1312d: "add")
import torch
import numpy as np
from PIL import Image
from torch import nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from configs import path_ckpt_fairface
# code adapted from https://github.com/dchen236/FairFace
def init_fair_model(device, path_ckpt=None):
    """Build the FairFace ResNet-34 classifier and load its checkpoint.

    Args:
        device: target device for the model (anything ``nn.Module.to`` accepts).
        path_ckpt: path to the state-dict checkpoint; defaults to
            ``configs.path_ckpt_fairface`` when ``None``.

    Returns:
        A torchvision ResNet-34 with an 18-way linear head, on ``device`` and
        in eval mode.
    """
    if path_ckpt is None:
        path_ckpt = path_ckpt_fairface
    model_fair_7 = torchvision.models.resnet34(pretrained=False)
    # 18 outputs: predict_race slices them as 7 race + 2 gender + 9 age logits.
    model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
    # Load onto CPU first so a checkpoint saved from a GPU also loads on
    # CPU-only machines; the model is moved to `device` below.
    state_dict = torch.load(path_ckpt, map_location="cpu")
    model_fair_7.load_state_dict(state_dict)
    model_fair_7 = model_fair_7.to(device)
    model_fair_7.eval()
    return model_fair_7
# Race class names, indexed by the argmax over the first 7 model logits.
_RACE_LABELS = (
    'White', 'Black', 'Latino_Hispanic', 'East Asian',
    'Southeast Asian', 'Indian', 'Middle Eastern',
)
# ImageNet per-channel statistics used for input normalization.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _imagenet_normalize(image):
    """Normalize a ``(..., 3, H, W)`` float tensor with ImageNet statistics.

    Same arithmetic as ``transforms.Normalize(mean, std)`` (not in place).
    """
    mean = torch.as_tensor(_IMAGENET_MEAN, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std = torch.as_tensor(_IMAGENET_STD, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    return (image - mean) / std


def predict_race(model_fair_7, path_img, device):
    """Predict race, gender and age classes for a single face image.

    Args:
        model_fair_7: model mapping a ``(1, 3, 224, 224)`` tensor to 18 logits
            (7 race, 2 gender, 9 age), e.g. the one from ``init_fair_model``.
        path_img: either a path to an image file (``str`` / ``os.PathLike``),
            or a ``(1, 3, H, W)`` float tensor. The tensor is mapped with
            ``x * 0.5 + 0.5``, so it is presumably expected in ``[-1, 1]``
            (generator output) -- confirm against callers.
        device: device the preprocessed image is moved to before inference.

    Returns:
        ``(race_label, race_pred, gender_pred, age_pred)``: the three
        predictions are argmax class indices; ``race_label`` is the name
        for ``race_pred``.

    Raises:
        TypeError: if ``path_img`` is neither a path nor a ``torch.Tensor``.
    """
    import os

    if isinstance(path_img, (str, os.PathLike)):
        trans = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # convert("RGB") guarantees 3 channels for grayscale/RGBA/palette
        # files; the context manager closes the underlying file handle.
        with Image.open(path_img) as raw_img:
            rgb_img = raw_img.convert("RGB")
        image = _imagenet_normalize(trans(rgb_img))
    elif isinstance(path_img, torch.Tensor):
        # F.interpolate defaults to nearest-neighbour resizing.
        image = F.interpolate(path_img, (224, 224))
        image = _imagenet_normalize(image * 0.5 + 0.5)
    else:
        raise TypeError(
            f"path_img must be a path or torch.Tensor, got {type(path_img).__name__}"
        )
    image = image.view(1, 3, 224, 224)  # batch of one
    image = image.to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = model_fair_7(image)
    outputs = np.squeeze(outputs.cpu().detach().numpy())

    # Softmax is monotonic, so the argmax of the raw logits is the predicted
    # class; this avoids the overflow (inf/inf -> nan) of an unstabilized exp.
    race_pred = np.argmax(outputs[:7])
    gender_pred = np.argmax(outputs[7:9])
    age_pred = np.argmax(outputs[9:18])
    return _RACE_LABELS[race_pred], race_pred, gender_pred, age_pred