import copy
import os

import numpy as np
import pkg_resources
import torch

from spiga.inference.config import ModelConfig
from spiga.models.spiga import SPIGA
import spiga.inference.pretreatment as pretreat

# Paths
weights_path_dft = pkg_resources.resource_filename('spiga', 'models/weights')


class SPIGAFramework:
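    """Inference wrapper around the SPIGA face landmark and headpose model.

    Handles weight loading, per-bbox crop pretreatment, batched forward
    passes, and mapping of predicted landmarks back to image coordinates.
    """
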
    def __init__(self, model_cfg: ModelConfig, gpus=[0], load3DM=True):
        # Parameters
        self.model_cfg = model_cfg
        self.gpus = gpus

        # Pretreatment initialization
        self.transforms = pretreat.get_transformers(self.model_cfg)

        # SPIGA model
        self.model_inputs = ['image', 'model3d', 'cam_matrix']
        self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
                           num_edges=model_cfg.dataset.num_edges)

        # Load weights and set model
        weights_path = self.model_cfg.model_weights_path
        if weights_path is None:
            weights_path = weights_path_dft
        # Load to CPU first so CPU-only machines still work; the model is
        # moved to GPU below when CUDA is available.
        if self.model_cfg.load_model_url:
            model_state_dict = torch.hub.load_state_dict_from_url(self.model_cfg.model_weights_url,
                                                                  model_dir=weights_path,
                                                                  file_name=self.model_cfg.model_weights,
                                                                  map_location='cpu')
        else:
            weights_file = os.path.join(weights_path, self.model_cfg.model_weights)
            model_state_dict = torch.load(weights_file, map_location='cpu')
        self.model.load_state_dict(model_state_dict)
        if torch.cuda.is_available():
            self.model = self.model.cuda(gpus[0])
        self.model.eval()
        print('SPIGA model loaded!')
        # Load 3D model and camera intrinsic matrix
        if load3DM:
            loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
                                             ftmap_size=model_cfg.ftmap_size,
                                             focal_ratio=model_cfg.focal_ratio,
                                             totensor=True)
            params_3DM = self._data2device(loader_3DM())
            self.model3d = params_3DM['model3d']
            self.cam_matrix = params_3DM['cam_matrix']

    def inference(self, image, bboxes):
""" | |
@param self: | |
@param image: Raw image | |
@param bboxes: List of bounding box founded on the image [[x,y,w,h],...] | |
@return: features dict {'landmarks': list with shape (num_bbox, num_landmarks, 2) and x,y referred to image size | |
'headpose': list with shape (num_bbox, 6) euler->[:3], trl->[3:] | |
""" | |
        batch_crops, crop_bboxes = self.pretreat(image, bboxes)
        outputs = self.net_forward(batch_crops)
        features = self.postreatment(outputs, crop_bboxes, bboxes)
        return features

    def pretreat(self, image, bboxes):
        crop_bboxes = []
        crop_images = []
        for bbox in bboxes:
            sample = {'image': copy.deepcopy(image),
                      'bbox': copy.deepcopy(bbox)}
            sample_crop = self.transforms(sample)
            crop_bboxes.append(sample_crop['bbox'])
            crop_images.append(sample_crop['image'])

        # Images to tensor and device
        batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
        batch_images = self._data2device(batch_images)

        # Batch 3D model and camera intrinsic matrix
        batch_model3D = self.model3d.unsqueeze(0).repeat(len(bboxes), 1, 1)
        batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(len(bboxes), 1, 1)

        # SPIGA inputs
        model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
        return model_inputs, crop_bboxes

    def net_forward(self, inputs):
        outputs = self.model(inputs)
        return outputs

    def postreatment(self, output, crop_bboxes, bboxes):
        features = {}
        crop_bboxes = np.array(crop_bboxes)
        bboxes = np.array(bboxes)

        # Landmarks output: map crop coordinates back to image coordinates
        if 'Landmarks' in output:
            landmarks = output['Landmarks'][-1].cpu().detach().numpy()
            landmarks = landmarks.transpose((1, 0, 2))
            landmarks = landmarks * self.model_cfg.image_size
            landmarks_norm = (landmarks - crop_bboxes[:, 0:2]) / crop_bboxes[:, 2:4]
            landmarks_out = (landmarks_norm * bboxes[:, 2:4]) + bboxes[:, 0:2]
            landmarks_out = landmarks_out.transpose((1, 0, 2))
            features['landmarks'] = landmarks_out.tolist()

        # Pose output
        if 'Pose' in output:
            pose = output['Pose'].cpu().detach().numpy()
            features['headpose'] = pose.tolist()
        return features

    def select_inputs(self, batch):
        inputs = []
        for ft_name in self.model_inputs:
            data = batch[ft_name]
            inputs.append(self._data2device(data.type(torch.float)))
        return inputs

    def _data2device(self, data):
        # Recurse into containers; move tensors to the configured GPU if any.
        if isinstance(data, list):
            data_var = data
            for data_id, v_data in enumerate(data):
                data_var[data_id] = self._data2device(v_data)
        elif isinstance(data, dict):
            data_var = data
            for k, v in data.items():
                data_var[k] = self._data2device(v)
        else:
            with torch.no_grad():
                if torch.cuda.is_available():
                    data_var = data.cuda(device=self.gpus[0], non_blocking=True)
                else:
                    data_var = data
        return data_var
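

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the framework. Assumptions are flagged
# inline: it assumes opencv-python is installed, that 'face.jpg' exists, and
# that 'wflw' is a dataset name accepted by ModelConfig. Face detection is
# out of scope here, so the bounding box below is hypothetical and hard-coded.
if __name__ == '__main__':
    import cv2  # assumption: opencv-python is installed

    image = cv2.imread('face.jpg')   # hypothetical input image
    bbox = [50, 50, 200, 200]        # hypothetical face box [x, y, w, h]

    processor = SPIGAFramework(ModelConfig('wflw'))
    features = processor.inference(image, [bbox])

    # One entry per input bbox; see the inference() docstring for shapes.
    landmarks = np.array(features['landmarks'][0])  # (num_landmarks, 2)
    headpose = np.array(features['headpose'][0])    # (6,): euler[:3], trl[3:]
    print('landmarks:', landmarks.shape, 'headpose:', headpose.shape)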