from typing import Tuple

import torch
from mivolo.model.mi_volo import MiVOLO

from .age_gender_dataset import AgeGenderDataset
from .age_gender_loader import create_loader
from .classification_dataset import AdienceDataset, FairFaceDataset

# Maps each dataset name accepted by `build` to the dataset class used to load it.
DATASET_CLASS_MAP = {
    "utk": AgeGenderDataset,
    "lagenda": AgeGenderDataset,
    "imdb": AgeGenderDataset,
    "adience": AdienceDataset,
    "fairface": FairFaceDataset,
}


def build(
    name: str,
    images_path: str,
    annotations_path: str,
    split: str,
    mivolo_model: MiVOLO,
    workers: int,
    batch_size: int,
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
    """Build a dataset and a matching data loader configured from ``mivolo_model``.

    The dataset class is looked up in ``DATASET_CLASS_MAP`` by ``name``. It is
    constructed with the model's input size and with the age-range, person,
    face and age-only options read from ``mivolo_model.meta``. The loader uses
    the model's ``data_config`` for normalization and cropping.

    Args:
        name: Dataset key; must be one of the keys of ``DATASET_CLASS_MAP``.
        images_path: Path passed to the dataset class as ``images_path``.
        annotations_path: Path passed to the dataset class as ``annotations_path``.
        split: Split identifier passed through to the dataset class.
        mivolo_model: Model whose ``input_size``, ``meta``, ``data_config`` and
            ``device`` configure the dataset and loader.
        workers: Number of loader worker processes (``num_workers``).
        batch_size: Loader batch size.

    Returns:
        A ``(dataset, dataset_loader)`` tuple.

    Raises:
        KeyError: If ``name`` is not a key of ``DATASET_CLASS_MAP``. The message
            lists the supported names.
    """
    try:
        dataset_class = DATASET_CLASS_MAP[name]
    except KeyError:
        # Keep the KeyError type callers may already catch, but say what is valid.
        supported = ", ".join(sorted(DATASET_CLASS_MAP))
        raise KeyError(
            f"Unknown dataset name {name!r}; expected one of: {supported}"
        ) from None

    meta = mivolo_model.meta
    dataset: torch.utils.data.Dataset = dataset_class(
        images_path=images_path,
        annotations_path=annotations_path,
        name=name,
        split=split,
        target_size=mivolo_model.input_size,
        max_age=meta.max_age,
        min_age=meta.min_age,
        model_with_persons=meta.with_persons_model,
        use_persons=meta.use_persons,
        disable_faces=meta.disable_faces,
        only_age=meta.only_age,
    )

    data_config = mivolo_model.data_config
    # Models with a persons branch take 6 input channels instead of 3
    # (presumably face and person crops stacked; confirm against the model).
    in_chans = 6 if meta.with_persons_model else 3
    input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)

    dataset_loader: torch.utils.data.DataLoader = create_loader(
        dataset,
        input_size=input_size,
        batch_size=batch_size,
        mean=data_config["mean"],
        std=data_config["std"],
        num_workers=workers,
        crop_pct=data_config["crop_pct"],
        crop_mode=data_config["crop_mode"],
        pin_memory=False,
        device=mivolo_model.device,
        # Every dataset class in DATASET_CLASS_MAP must expose `target_dtype`.
        target_type=dataset.target_dtype,
    )

    return dataset, dataset_loader