import logging
from typing import Any, List, Optional, Set

import cv2
import numpy as np
import torch
from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
from PIL import Image
from torchvision import transforms

_logger = logging.getLogger("AgeGenderDataset")


class AgeGenderDataset(torch.utils.data.Dataset):
    """Map-style dataset of face (and optionally person) crops with age/gender targets.

    Samples are read through ``ReaderAgeGender``. Each item is ``(img, target)``
    where ``target`` is ``[age, gender]`` as produced by ``parse_target``:
    the age is normalized with ``_norm_age`` (or ``-1`` when unknown) and the
    gender is ``0.0`` / ``1.0`` (or ``-1`` when unknown).
    """

    def __init__(
        self,
        images_path,
        annotations_path,
        name=None,
        split="train",
        load_bytes=False,
        img_mode="RGB",
        transform=None,
        is_training=False,
        seed=1234,
        target_size=224,
        min_age=None,
        max_age=None,
        model_with_persons=False,
        use_persons=False,
        disable_faces=False,
        only_age=False,
    ):
        reader = ReaderAgeGender(
            images_path,
            annotations_path,
            split=split,
            seed=seed,
            target_size=target_size,
            with_persons=use_persons,
            disable_faces=disable_faces,
            only_age=only_age,
        )

        self.name = name
        self.model_with_persons = model_with_persons
        self.reader = reader
        self.load_bytes = load_bytes
        self.img_mode = img_mode
        # Must exist before the ``transform`` setter runs: the setter ignores a
        # falsy transform, and the getter would otherwise raise AttributeError.
        self._transform = None
        self.transform = transform
        self._consecutive_errors = 0
        self.is_training = is_training
        self.random_flip = 0.0

        # Setting up classes.
        # If min and max ages are passed, use them so that validation gets the
        # same age normalization (preprocessing) as training.
        self.max_age: Optional[float] = None
        self.min_age: Optional[float] = None
        self.avg_age: Optional[float] = None
        self.set_ages_min_max(min_age, max_age)

        self.genders = ["M", "F"]
        self.num_classes_gender = len(self.genders)

        self.age_classes: Optional[List[str]] = self.set_age_classes()

        # Regression: a single age output; otherwise one output per age class.
        self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
        self.num_classes: int = self.num_classes_age + self.num_classes_gender
        self.target_dtype = torch.float32

    def set_age_classes(self) -> Optional[List[str]]:
        """Return age class names, or ``None`` for this regression dataset."""
        return None  # for regression dataset

    def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
        """Set ``min_age``, ``max_age`` and ``avg_age`` used by ``_norm_age``.

        If both bounds are given they are used as-is (e.g. to reuse training
        statistics on validation). Otherwise they are computed from the rounded
        ages of all annotated samples in ``self.reader``, skipping ages of "-1".

        Raises:
            AssertionError: if exactly one of ``min_age`` / ``max_age`` is given.
            ValueError: if bounds must be computed but no sample has a known age.
        """
        assert all(age is None for age in [min_age, max_age]) or all(
            age is not None for age in [min_age, max_age]
        ), "Both min and max age must be passed or none of them"

        if max_age is not None and min_age is not None:
            _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
            self.max_age = max_age
            self.min_age = min_age
        else:
            # collect statistics from the loaded annotations
            all_ages_set: Set[int] = set()
            for image_samples in self.reader._ann.values():
                for image_sample_info in image_samples:
                    if image_sample_info.age == "-1":
                        continue
                    all_ages_set.add(round(float(image_sample_info.age)))
            if not all_ages_set:
                raise ValueError(
                    "Cannot compute min/max age: no annotated sample has a known age; "
                    "pass min_age and max_age explicitly"
                )
            self.max_age = max(all_ages_set)
            self.min_age = min(all_ages_set)

        self.avg_age = (self.max_age + self.min_age) / 2.0

    def _norm_age(self, age):
        """Linearly map ``age`` so that [min_age, max_age] becomes [-0.5, 0.5].

        Raises ZeroDivisionError if ``max_age == min_age``.
        """
        return (age - self.avg_age) / (self.max_age - self.min_age)

    def parse_gender(self, _gender: str) -> float:
        """Encode a gender string: "M"/"0" -> 0.0, "-1" (unknown) -> -1, anything else -> 1.0."""
        if _gender != "-1":
            gender = float(0 if _gender == "M" or _gender == "0" else 1)
        else:
            gender = -1
        return gender

    def parse_target(self, _age: str, gender: str) -> List[Any]:
        """Build ``[normalized_age, encoded_gender]``; an age of "-1" (unknown) stays -1."""
        if _age != "-1":
            age = round(float(_age))
            age = self._norm_age(float(age))
        else:
            age = -1

        target: List[float] = [age, self.parse_gender(gender)]
        return target

    @property
    def transform(self):
        return self._transform

    @transform.setter
    def transform(self, transform):
        """Store ``transform`` with its Resize/Crop steps removed.

        Presumably these are dropped because the reader already returns
        resized and padded crops (see ``__getitem__``). A falsy
        ``transform`` is ignored and leaves the current value unchanged.
        """
        # Disable pretrained monkey-patched transforms
        if not transform:
            return
        kept = [
            trans
            for trans in transform.transforms
            if "Resize" not in str(trans) and "Crop" not in str(trans)
        ]
        self._transform = transforms.Compose(kept)

    def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
        """Apply ``self.transform`` step by step to a BGR image.

        Returns ``None`` for ``None`` input, and the image unchanged when no
        transform is configured. Otherwise the image is converted to PIL first.
        """
        if image is None:
            return None

        if self.transform is None:
            return image

        image = convert_to_pil(image, self.img_mode)
        for trans in self.transform.transforms:
            image = trans(image)
        return image

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        With ``model_with_persons`` and a non-None person crop, the transformed
        face and person crops are concatenated along axis 0. This is presumably
        the channel axis if the transforms emit CHW arrays; confirm.
        """
        # get preprocessed face and person crops (np.ndarray)
        # resize + pad, for person crops: cut off other bboxes
        images, target = self.reader[index]

        target = self.parse_target(*target)

        if self.model_with_persons:
            face_image, person_image = images
            person_image = self.apply_tranforms(person_image)
        else:
            face_image = images[0]
            person_image = None

        face_image = self.apply_tranforms(face_image)

        if person_image is not None:
            img = np.concatenate([face_image, person_image], axis=0)
        else:
            img = face_image

        return img, target

    def __len__(self):
        return len(self.reader)

    def filename(self, index, basename=False, absolute=False):
        """Delegate to ``self.reader.filename``."""
        return self.reader.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``self.reader.filenames``."""
        return self.reader.filenames(basename, absolute)


def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> Optional[Image.Image]:
    """Convert an OpenCV BGR image to a PIL image in ``img_mode``.

    Returns ``None`` for ``None`` input. Only ``"RGB"`` is supported.

    Raises:
        ValueError: if ``img_mode`` is not ``"RGB"``.
    """
    if cv_im is None:
        return None

    if img_mode != "RGB":
        raise ValueError("Incorrect image mode has been passed!")
    cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)

    # PIL requires a C-contiguous buffer.
    cv_im = np.ascontiguousarray(cv_im)
    return Image.fromarray(cv_im)