import os import pickle import random from pathlib import Path from typing import Dict from typing import List import torch from loguru import logger from PIL import Image from torch.utils.data import Dataset from torchvision import transforms from configs.mode import FaceSwapMode from configs.train_config import TrainConfig class ManyToManyTrainDataset(Dataset): def __init__(self, dataset_root: str, dataset_index: str, same_rate=0.5): """ Many-to-many 训练数据集构建 Parameters: ----------- dataset_root: str, 数据集根目录 dataset_index: str, 数据集index文件路径 same_rate: float, 每个batch里面相同人脸所占的比例 """ super(ManyToManyTrainDataset, self).__init__() self.transform = transforms.Compose( [ transforms.Resize((256, 256)), transforms.CenterCrop((256, 256)), transforms.ToTensor(), ] ) self.data_root = Path(dataset_root) with open(dataset_index, "rb") as f: self.file_index = pickle.load(f, encoding="bytes") self.same_rate = same_rate self.id_list: List[str] = list(self.file_index.keys()) # 所有id都遍历一遍,视为一个epoch self.length = len(self.id_list) self.image_num = sum([len(v) for v in self.file_index.values()]) self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth" logger.info(f"dataset contains {self.length} ids and {self.image_num} images") logger.info(f"will use mask mode: {self.mask_dir}") def __len__(self): return self.length def __getitem__(self, index): source_id_index = index source_file = random.choice(self.file_index[self.id_list[source_id_index]]) if random.random() < self.same_rate: # 在相同id的文件列表中选择 target_file = random.choice(self.file_index[self.id_list[source_id_index]]) same = torch.ones(1) else: # 在不同id的文件列表中选择 target_id_index = random.choice(list(set(range(self.length)) - set([source_id_index]))) target_file = random.choice(self.file_index[self.id_list[target_id_index]]) same = torch.zeros(1) source_file = self.data_root / Path(source_file) target_file = self.data_root / Path(target_file) target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name target_img = Image.open(target_file.as_posix()).convert("RGB") source_img = Image.open(source_file.as_posix()).convert("RGB") target_mask = Image.open(target_mask_file.as_posix()).convert("RGB") source_img = self.transform(source_img) target_img = self.transform(target_img) target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0) return { "source_image": source_img, "target_image": target_img, "target_mask": target_mask, "same": same, # "source_img_name": source_file.as_posix(), # "target_img_name": target_file.as_posix(), # "target_mask_name": target_mask_file.as_posix(), } class OneToManyTrainDataset(Dataset): def __init__(self, dataset_root: str, dataset_index: str, source_name: str, same_rate=0.5): """ One-to-many 训练数据集构建 Parameters: ----------- dataset_root: str, 数据集根目录 dataset_index: str, 数据集index文件路径 source_name: str, source face id的名称, one-to-many里面的one same_rate: float, 每个batch里面相同人脸所占的比例 """ super(OneToManyTrainDataset, self).__init__() self.transform = transforms.Compose( [ transforms.Resize((256, 256)), transforms.CenterCrop((256, 256)), transforms.ToTensor(), ] ) self.data_root = Path(dataset_root) with open(dataset_index, "rb") as f: self.file_index = pickle.load(f, encoding="bytes") self.same_rate = same_rate self.source_name = source_name self.id_list: List[str] = list(self.file_index.keys()) try: self.source_id_index: int = self.id_list.index(self.source_name) except Exception: raise Exception(f"{self.source_name} not in dataset dir") # 所有id都遍历一遍,视为一个epoch self.length = len(self.id_list) self.image_num = sum([len(v) for v in self.file_index.values()]) self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth" logger.info(f"dataset contains {self.length} ids and {self.image_num} images") logger.info(f"will use mask mode: {self.mask_dir}") def __len__(self): return self.length def __getitem__(self, index): target_id_index = index target_file = random.choice(self.file_index[self.id_list[target_id_index]]) if random.random() < self.same_rate: # 在相同id的文件列表中选择 source_file = random.choice(self.file_index[self.id_list[target_id_index]]) same = torch.ones(1) else: # 直接选择source name中的图片 source_file = random.choice(self.file_index[self.source_name]) # 如果和target同个id if self.source_id_index == target_id_index: same = torch.ones(1) else: same = torch.zeros(1) source_file = self.data_root / Path(source_file) target_file = self.data_root / Path(target_file) target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name target_img = Image.open(target_file.as_posix()).convert("RGB") source_img = Image.open(source_file.as_posix()).convert("RGB") target_mask = Image.open(target_mask_file.as_posix()).convert("RGB") source_img = self.transform(source_img) target_img = self.transform(target_img) target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0) return { "source_image": source_img, "target_image": target_img, "target_mask": target_mask, "same": same, # "source_img_name": source_file.as_posix(), # "target_img_name": target_file.as_posix(), # "target_mask_name": target_mask_file.as_posix(), } class TrainDatasetDataLoader: """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self): """Initialize this class""" opt = TrainConfig() if opt.mode is FaceSwapMode.MANY_TO_MANY: self.dataset = ManyToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.same_rate) elif opt.mode is FaceSwapMode.ONE_TO_MANY: logger.info(f"In one-to-many mode, source face is {opt.source_name}") self.dataset = OneToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.source_name, opt.same_rate) else: raise NotImplementedError logger.info(f"dataset {type(self.dataset).__name__} created") if opt.use_ddp: self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, shuffle=True) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, num_workers=int(opt.num_threads), drop_last=True, sampler=self.train_sampler, pin_memory=True, ) else: self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.num_threads), drop_last=True, pin_memory=True, ) def load_data(self): return self def __len__(self): """Return the number of data in the dataset""" return len(self.dataset) def __iter__(self): """Return a batch of data""" for data in self.dataloader: yield data if __name__ == "__main__": dataloader = TrainDatasetDataLoader() for idx, data in enumerate(dataloader): # print(data["source_img_name"]) # print(data["target_img_name"]) # print(data["target_mask_name"]) print(data["same"])