from spiga.inference.config import ModelConfig
from spiga.models.spiga import SPIGA
import spiga.inference.pretreatment as pretreat
import os
import pkg_resources
import copy
import torch
import numpy as np

# Default directory for the model weights, resolved inside the installed
# `spiga` package. Used when the configuration does not provide its own path.
weights_path_dft = pkg_resources.resource_filename('spiga', 'models/weights')


class SPIGAFramework:
    """Inference wrapper around the SPIGA network.

    Builds the pretreatment transforms, loads the network weights, moves the
    model to the first configured GPU when CUDA is available (CPU otherwise)
    and exposes `inference` to run the full pipeline on a raw image.
    """

    def __init__(self, model_cfg: ModelConfig, gpus=None, load3DM=True):
        """
        @param model_cfg: Model configuration (dataset description, weights
                          location, image / feature map sizes, ...).
        @param gpus: Sequence of GPU ids; only the first one is used.
                     Defaults to [0].
        @param load3DM: When True, load the 3D model and the camera intrinsic
                        matrix. `pretreat` (and therefore `inference`) reads
                        them, so they are required for the default pipeline.
        """
        # Parameters
        self.model_cfg = model_cfg
        # Copy into a fresh list so no mutable default is shared by instances.
        self.gpus = [0] if gpus is None else list(gpus)

        # Pretreatment initialization
        self.transforms = pretreat.get_transformers(self.model_cfg)

        # SPIGA model
        self.model_inputs = ['image', "model3d", "cam_matrix"]
        self.model = SPIGA(num_landmarks=model_cfg.dataset.num_landmarks,
                           num_edges=model_cfg.dataset.num_edges)

        # Load weights and set model
        weights_path = self.model_cfg.model_weights_path
        if weights_path is None:
            weights_path = weights_path_dft

        # The state dict is always deserialized on CPU so that checkpoints
        # saved from GPU tensors also load on CPU-only hosts; the model is
        # moved to the GPU afterwards when one is available.
        if self.model_cfg.load_model_url:
            model_state_dict = torch.hub.load_state_dict_from_url(
                self.model_cfg.model_weights_url,
                model_dir=weights_path,
                map_location='cpu',
                file_name=self.model_cfg.model_weights)
        else:
            weights_file = os.path.join(
                weights_path, self.model_cfg.model_weights)
            model_state_dict = torch.load(weights_file, map_location='cpu')

        self.model.load_state_dict(model_state_dict)
        if torch.cuda.is_available():
            self.model = self.model.cuda(self.gpus[0])
        self.model.eval()
        print('SPIGA model loaded!')

        # Load 3D model and camera intrinsic matrix
        if load3DM:
            loader_3DM = pretreat.AddModel3D(model_cfg.dataset.ldm_ids,
                                             ftmap_size=model_cfg.ftmap_size,
                                             focal_ratio=model_cfg.focal_ratio,
                                             totensor=True)
            params_3DM = self._data2device(loader_3DM())
            self.model3d = params_3DM['model3d']
            self.cam_matrix = params_3DM['cam_matrix']

    def inference(self, image, bboxes):
        """
        Run pretreatment, network forward pass and postreatment on an image.

        @param image: Raw image
        @param bboxes: List of bounding boxes found on the image [[x,y,w,h],...]
        @return: features dict {
            'landmarks': list with shape (num_bbox, num_landmarks, 2) and x,y
                         referred to image size
            'headpose': list with shape (num_bbox, 6) euler->[:3], trl->[3:]
            }
            Both entries are empty lists when `bboxes` is empty.
        """
        # Nothing to crop: an empty batch cannot be fed through the pipeline
        # (postreatment indexes the boxes as a 2-D array), so answer directly.
        if len(bboxes) == 0:
            return {'landmarks': [], 'headpose': []}

        batch_crops, crop_bboxes = self.pretreat(image, bboxes)
        outputs = self.net_forward(batch_crops)
        features = self.postreatment(outputs, crop_bboxes, bboxes)
        return features

    def pretreat(self, image, bboxes):
        """
        Crop the image around every bounding box and build the network inputs.

        Requires the 3D model and camera matrix loaded with `load3DM=True`.

        @param image: Raw image
        @param bboxes: List of bounding boxes [[x,y,w,h],...]
        @return: tuple (model_inputs, crop_bboxes) where model_inputs is the
                 list [batch_images, batch_model3D, batch_cam_matrix] and
                 crop_bboxes holds the 'bbox' entry returned by the transforms
                 for each input box.
        """
        crop_bboxes = []
        crop_images = []
        for bbox in bboxes:
            # Deep copies keep the caller's image and box untouched, whatever
            # the transforms do with the sample.
            sample = {'image': copy.deepcopy(image),
                      'bbox': copy.deepcopy(bbox)}
            sample_crop = self.transforms(sample)
            crop_bboxes.append(sample_crop['bbox'])
            crop_images.append(sample_crop['image'])

        # Images to tensor and device
        batch_images = torch.tensor(np.array(crop_images), dtype=torch.float)
        batch_images = self._data2device(batch_images)

        # Batch 3D model and camera intrinsic matrix: one copy per bounding box
        num_bboxes = len(bboxes)
        batch_model3D = self.model3d.unsqueeze(0).repeat(num_bboxes, 1, 1)
        batch_cam_matrix = self.cam_matrix.unsqueeze(0).repeat(num_bboxes, 1, 1)

        # SPIGA inputs
        model_inputs = [batch_images, batch_model3D, batch_cam_matrix]
        return model_inputs, crop_bboxes

    def net_forward(self, inputs):
        """
        @param inputs: List of network inputs as built by `pretreat`.
        @return: Raw network outputs.
        """
        outputs = self.model(inputs)
        return outputs

    def postreatment(self, output, crop_bboxes, bboxes):
        """
        Map the network outputs back to the coordinates of the raw image.

        @param output: Network outputs; the 'Landmarks' and 'Pose' entries are
                       processed when present.
        @param crop_bboxes: Bounding boxes returned by the transforms, one per
                            input box ([x,y,w,h]).
        @param bboxes: Original bounding boxes on the raw image ([x,y,w,h]).
        @return: features dict with 'landmarks' and/or 'headpose' as lists.
        """
        features = {}
        crop_bboxes = np.array(crop_bboxes)
        bboxes = np.array(bboxes)

        # Landmarks output
        if 'Landmarks' in output.keys():
            # Only the last entry of the 'Landmarks' output is used.
            landmarks = output['Landmarks'][-1].cpu().detach().numpy()
            # Swap the two leading axes so the per-box arrays below broadcast
            # against the landmarks (assumes the network returns the batch on
            # the first axis -- TODO confirm against the model).
            landmarks = landmarks.transpose((1, 0, 2))
            landmarks = landmarks*self.model_cfg.image_size
            # Normalize by the crop box, then rescale into the original box.
            landmarks_norm = (
                landmarks - crop_bboxes[:, 0:2]) / crop_bboxes[:, 2:4]
            landmarks_out = (landmarks_norm * bboxes[:, 2:4]) + bboxes[:, 0:2]
            landmarks_out = landmarks_out.transpose((1, 0, 2))
            features['landmarks'] = landmarks_out.tolist()

        # Pose output
        if 'Pose' in output.keys():
            pose = output['Pose'].cpu().detach().numpy()
            features['headpose'] = pose.tolist()

        return features

    def select_inputs(self, batch):
        """
        Pick the network inputs out of a batch dict.

        @param batch: Mapping holding one entry per name in `self.model_inputs`.
        @return: List of float tensors, in the order of `self.model_inputs`,
                 moved to the GPU when CUDA is available.
        """
        inputs = []
        for ft_name in self.model_inputs:
            data = batch[ft_name]
            inputs.append(self._data2device(data.type(torch.float)))
        return inputs

    def _data2device(self, data):
        """
        Move data to the first configured GPU when CUDA is available.

        Lists and dicts are walked recursively and updated in place; any other
        value is expected to offer a tensor-like `.cuda()` method.

        @param data: Tensor, or (nested) list / dict of tensors.
        @return: The same container with its leaves moved, or the moved tensor.
                 Without CUDA the data is returned unchanged.
        """
        if isinstance(data, list):
            data_var = data
            for data_id, v_data in enumerate(data):
                data_var[data_id] = self._data2device(v_data)
        # `elif` is essential: with a plain `if`, a list would fall through to
        # the tensor branch below and `.cuda()` would be called on the list.
        elif isinstance(data, dict):
            data_var = data
            for k, v in data.items():
                data_var[k] = self._data2device(v)
        else:
            with torch.no_grad():
                if torch.cuda.is_available():
                    data_var = data.cuda(
                        device=self.gpus[0], non_blocking=True)
                else:
                    data_var = data
        return data_var