Spaces:
Sleeping
Sleeping
import os | |
import cv2 | |
import torch | |
import argparse | |
import numpy as np | |
import onnxruntime as ort | |
import torchvision.transforms as transforms | |
from typing import Union, List, Tuple | |
from mmengine.registry import Registry | |
# mmengine Registry instance created under the name "human_attrs".
# NOTE(review): nothing in the visible part of this file registers a class
# into it or builds from it -- confirm whether other modules rely on it.
HUMAN_ATTRS = Registry("human_attrs")
class ONNX_Base():
    """Thin wrapper around an ONNX Runtime inference session.

    Besides creating the session, the class offers a running per-image
    timing benchmark and a helper for cropping boxes out of an image.
    """

    def __init__(self,
                 input_shape,
                 model_path: str,
                 device: str = '0'):
        """Create the session for ``model_path``.

        Args:
            input_shape: Shape of the model input; only stored on the instance.
            model_path (str): Path of the ONNX model file.
            device (str): Device request forwarded to :meth:`select_device`.
        """
        self.input_shape = input_shape
        self.model_path = model_path
        self.device = self.select_device(device)
        self.session = self.create_session(model_path)
        # Running per-image means. Every entry (including 'n_imgs') starts at
        # 1e-6 instead of 0 so the divisions in the benchmark helpers never
        # divide by zero.
        self.time_dict = {'preprocess': 1e-6, 'inference': 1e-6,
                          'postprocess': 1e-6, 'total': 1e-6, 'n_imgs': 1e-6}

    def create_session(self, model_path: str) -> "ort.InferenceSession":
        """Build an ONNX Runtime session for ``model_path``.

        CUDA is preferred over CPU, but only providers that the installed
        onnxruntime build actually reports as available are requested, so a
        CPU-only installation is not asked for a provider it cannot load.

        Args:
            model_path (str): Path of the ONNX model file.

        Returns:
            ort.InferenceSession: The created session.
        """
        available = ort.get_available_providers()
        preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        # Fall back to whatever is available if neither preferred provider is.
        providers = [name for name in preferred if name in available] or available
        return ort.InferenceSession(model_path, providers=providers)

    def select_device(self, device: str) -> torch.device:
        """Select the device used for inference.

        NOTE: GPU selection is disabled in this version. ``device`` is kept
        in the signature for interface compatibility but is ignored: CUDA
        devices are hidden from the process and the CPU device is returned.

        Args:
            device (str): 'cpu' or '0' or '0,1,2,3' (currently unused).

        Returns:
            torch.device: Always the CPU device.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        return torch.device("cpu")

    def infer_batch(self, image_batch: np.ndarray) -> List[np.ndarray]:
        """Run the ONNX model on one batch.

        The batch is fed to the first input of the session as-is, so its
        layout and dtype must already match what the model expects.

        Args:
            image_batch (np.ndarray): Batched model input.

        Returns:
            List[np.ndarray]: All model outputs, in the model's output order.
        """
        input_name = self.session.get_inputs()[0].name
        return self.session.run(None, {input_name: image_batch})

    def benchmark_time(self, t_pre, t_infer, t_post, t_end, n_imgs):
        """Fold the timings of one batch into the running per-image means.

        The four timestamps mark, in order, the start of preprocessing, the
        start of inference, the start of postprocessing and the end of the
        batch. Batches with ``n_imgs <= 0`` are ignored.

        Args:
            t_pre: Timestamp taken before preprocessing.
            t_infer: Timestamp taken before inference.
            t_post: Timestamp taken before postprocessing.
            t_end: Timestamp taken after postprocessing.
            n_imgs: Number of images processed in the batch.
        """
        if n_imgs <= 0:
            return
        self.time_dict['n_imgs'] += n_imgs
        # Weight of this batch in the incremental mean update:
        #   mean += weight * (batch_per_image_time - mean)
        weight = n_imgs / self.time_dict['n_imgs']
        elapsed_by_stage = {
            'preprocess': t_infer - t_pre,
            'inference': t_post - t_infer,
            'postprocess': t_end - t_post,
            'total': t_end - t_pre,
        }
        for stage, elapsed in elapsed_by_stage.items():
            self.time_dict[stage] += weight * \
                (elapsed / n_imgs - self.time_dict[stage])

    def print_time_benchmark(self):
        """Print the per-image means and each stage's share of the total."""
        print('-------------Benchmark time for a single image-------------')
        for k, v in self.time_dict.items():
            print(f'{k}: {v:.4f}')
        print("-------------Percentage of each step-----------------------")
        for k, v in self.time_dict.items():
            # 'total' is the denominator and 'n_imgs' is a count, not a stage.
            if k in ('total', 'n_imgs'):
                continue
            print(f'{k}: {int(v/self.time_dict["total"]*100)}%')

    def crop_objects(self, image: np.ndarray, bounding_boxes: np.ndarray):
        """Crop every bounding box out of ``image``.

        Only the first four values of each box (x1, y1, x2, y2) are used, so
        both Nx4 arrays and arrays with extra trailing columns (for example
        an Nx5 array carrying a score) are accepted. Coordinates are truncated
        to int and clamped to the image borders.

        Args:
            image (np.ndarray): Input image with shape (H, W, C).
            bounding_boxes (np.ndarray): Array with N rows of at least 4 values.

        Returns:
            list: One cropped view of ``image`` per box, in input order.
        """
        max_h, max_w = image.shape[:2]
        cropped_images = []
        for box in bounding_boxes:
            x_top, y_top, x_bottom, y_bottom = \
                np.asarray(box)[:4].astype(int).tolist()
            x_top = max(0, x_top)
            y_top = max(0, y_top)
            x_bottom = min(x_bottom, max_w)
            y_bottom = min(y_bottom, max_h)
            cropped_images.append(image[y_top:y_bottom, x_top:x_bottom])
        return cropped_images
class Attr_model():
    """Pre- and post-processing for a human attribute (age / gender) model.

    The class turns images into normalized tensors and turns model outputs
    into gender and age labels. It does not run the model itself.
    """

    def __init__(self,
                 preprocess_cfg=None,
                 AGE_CLASSES=None,
                 use_torch=False,
                 ):
        """
        Age Gender class for inference
        Args:
            preprocess_cfg (Dict, optional): Keys that are missing fall back
                to the defaults below.
                - mean/std (List[float, float, float]): per-channel values
                  used to normalize images
                  (defaults [0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]).
                - threshold (float): score above which a sample is labelled
                  'female' (default 0.5).
            AGE_CLASSES (List[str], optional): names of the 4 age categories;
                defaults to ['child', 'young', 'middle', 'senior'].
            use_torch (bool): use torch tensor or numpy array in preprocess and
                postprocess function.

        The raw outputs classes are:
            Hat
            Glasses
            HandBag
            ShoulderBag
            Backpack
            HoldObjectsInFront
            ShortSleeve
            LongSleeve
            UpperStride
            UpperLogo
            UpperPlaid
            UpperSplice
            LowerStripe
            LowerPattern
            LongCoat
            Trousers
            Shorts
            Skirt&Dress
            boots
            AgeOver60
            Age18-60
            AgeLess18
            Female
            Front
            Side
            Back
        """
        # Defaults are built per call so instances never share mutable state.
        cfg = dict(mean=[0.485, 0.456, 0.406],
                   std=[0.229, 0.224, 0.225],
                   threshold=0.5)
        if preprocess_cfg is not None:
            cfg.update(preprocess_cfg)
        if AGE_CLASSES is None:
            AGE_CLASSES = ['child', 'young', 'middle', 'senior']

        self.preprocess_cfg = cfg
        self.mean = np.array(cfg['mean'], dtype=np.float32)
        self.std = np.array(cfg['std'], dtype=np.float32)
        self.AGE_CLASSES = list(AGE_CLASSES)
        self.threshold = cfg['threshold']
        self.use_torch = use_torch
        # Columns picked from the raw output list above, counted from the end:
        # (Female, AgeLess18, Age18-60, AgeOver60)
        self.class_selects = [-4, -5, -6, -7]
        self.transforms = transforms.Compose([
            transforms.ToTensor(),  # already convert to range [0,1]
            # Normalize with the configured statistics (same values as the
            # previous hard-coded ones when the default config is used).
            transforms.Normalize(mean=list(cfg['mean']),
                                 std=list(cfg['std'])),
        ])

    def preprocess_one_img(self,
                           input_img: Union[np.ndarray, str],
                           height: int = 256,
                           width: int = 128):
        """
        Preprocess a single image.

        NOTE(review): the channel order is not changed here; an image loaded
        with cv2.imread is normalized in the order OpenCV returns it --
        confirm that this matches what the model was trained with.

        Args:
            input_img (str/np.ndarray): image path or image array (H, W, C).
            height (int): target height of the resized image.
            width (int): target width of the resized image.
        Returns:
            Tensor produced by ``self.transforms`` for the resized image.
        Raises:
            FileNotFoundError: if ``input_img`` is a path that cannot be read.
        """
        if isinstance(input_img, str):
            img = cv2.imread(input_img)
            # cv2.imread signals failure by returning None, not by raising.
            if img is None:
                raise FileNotFoundError(
                    f'Cannot read image file: {input_img}')
        else:
            img = input_img
        out_img = cv2.resize(img, (width, height))
        return self.transforms(out_img)

    def preprocess_list_img(self, batch_imgs: List):
        """
        Preprocess a list of images into one batch.

        When the instance has a ``batch_size`` attribute (set by subclasses)
        and fewer images are given, the batch is padded with all-zero images
        up to ``batch_size``.

        Args:
            batch_imgs (List): image paths or image arrays.
        Returns:
            Stacked batch: a torch tensor moved to ``self.device`` when
            ``use_torch`` is set, otherwise a numpy array.
        """
        data = torch.stack([self.preprocess_one_img(img) for img in batch_imgs])
        batch_size = getattr(self, 'batch_size', None)
        if batch_size is not None and len(data) < batch_size:
            # new_zeros keeps the dtype/device of the real data.
            padding = data.new_zeros((batch_size - len(data), *data.shape[1:]))
            data = torch.cat([data, padding])
        if self.use_torch:
            return data.to(self.device)
        return data.numpy()

    def postprocess(self, attr_logits):
        """
        Convert full attribute outputs into gender / age labels.

        No activation is applied here; the values are compared as given.
        NOTE(review): presumably the model already outputs probabilities --
        confirm, since the gender score is compared against ``self.threshold``.

        Args:
            attr_logits (np.ndarray): array of shape (N, num_attributes) whose
                last columns follow the raw output class order.
        Returns:
            dict: {'genders': List[str], 'ages': List[str]}, one entry per row.
        """
        # [ Female, AgeLess18, Age18-60, AgeOver60,]
        attr_probs = attr_logits[:, self.class_selects]
        gender_outputs = []
        age_outputs = []
        for attr_prob in attr_probs:
            gender = 'female' if attr_prob[0] > self.threshold else 'male'
            age_prob_list = attr_prob[1:]
            age_prob_indx = int(np.argmax(age_prob_list, axis=0))
            if age_prob_indx == 0:
                age = self.AGE_CLASSES[0]
            elif age_prob_indx == 1:
                # The 18-60 bucket is split in two: lean towards the younger
                # label when "under 18" outscores "over 60".
                if age_prob_list[0] > age_prob_list[2]:
                    age = self.AGE_CLASSES[1]
                else:
                    age = self.AGE_CLASSES[2]
            else:
                age = self.AGE_CLASSES[3]
            gender_outputs.append(gender)
            age_outputs.append(age)
        return {'genders': gender_outputs, 'ages': age_outputs}

    def postprocess_multihead(self, attr_logits):
        """
        Convert multi-head outputs into gender / age labels.

        Each row holds the age class scores followed by one gender score in
        the last position; the age label is the argmax over the age scores.

        Args:
            attr_logits (np.ndarray): array of shape (N, num_age_classes + 1).
        Returns:
            dict: {'genders': List[str], 'ages': List[str]}, one entry per row.
        """
        gender_outputs = []
        age_outputs = []
        for attr_prob in attr_logits:
            gender = 'female' if attr_prob[-1] > self.threshold else 'male'
            age_prob_indx = int(np.argmax(attr_prob[:-1], axis=0))
            gender_outputs.append(gender)
            age_outputs.append(self.AGE_CLASSES[age_prob_indx])
        return {'genders': gender_outputs, 'ages': age_outputs}
class HumanAG_ONNX(ONNX_Base, Attr_model):
    """Age / gender ONNX model: ONNX_Base session + Attr_model processing."""

    def __init__(self,
                 model_path: str,
                 img_shape: Tuple[int, int, int] = (3, 256, 128),
                 batch_size: int = 32,
                 device: str = '0'):
        """Model Age gender onnx for inference, which base on ONNX Base and Age_Gender class

        Args:
            model_path (str): Path of the ONNX model file.
            img_shape (Tuple[int, int, int]): Shape of one model input image.
            batch_size (int): Number of images sent to the model per run;
                smaller batches are zero-padded up to this size.
            device (str): Device request forwarded to ONNX_Base.
        """
        self.img_shape = img_shape
        self.batch_size = batch_size
        input_shape = (self.batch_size, *self.img_shape)
        # Both bases are initialised explicitly because their constructors
        # take unrelated arguments and do not chain through super().
        ONNX_Base.__init__(self, input_shape, model_path, device)
        Attr_model.__init__(self,
                            use_torch=False)

    def infer_batch(self, batch_images: np.ndarray) -> dict:
        """
        Predict gender and age for every image in ``batch_images``.

        Images are processed in chunks of at most ``self.batch_size``. Each
        chunk is zero-padded to the full batch size by the preprocessing step,
        and the predictions for the padding rows are discarded, so the result
        always has exactly one entry per input image.

        Args:
            batch_images (np.ndarray): images (arrays or paths) to classify.
        Returns:
            dict: {'genders': List[str], 'ages': List[str]}
        """
        genders = []
        ages = []
        for start in range(0, len(batch_images), self.batch_size):
            chunk = batch_images[start:start + self.batch_size]
            batch_array_data = self.preprocess_list_img(chunk)
            attr_logits = ONNX_Base.infer_batch(self, batch_array_data)[0]
            # Drop the rows that belong to zero padding.
            parsed = self.postprocess_multihead(attr_logits[:len(chunk)])
            genders.extend(parsed['genders'])
            ages.extend(parsed['ages'])
        return {'genders': genders, 'ages': ages}

    def infer(self, batch_images: np.ndarray) -> dict:
        """Alias of :meth:`infer_batch`."""
        return self.infer_batch(batch_images)