import os
import pickle
import random
from pathlib import Path
from typing import Dict
from typing import List
import torch
from loguru import logger
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from configs.mode import FaceSwapMode
from configs.train_config import TrainConfig
class ManyToManyTrainDataset(Dataset):
def __init__(self, dataset_root: str, dataset_index: str, same_rate=0.5):
"""
Many-to-many 训练数据集构建
Parameters:
-----------
dataset_root: str, 数据集根目录
dataset_index: str, 数据集index文件路径
same_rate: float, 每个batch里面相同人脸所占的比例
"""
super(ManyToManyTrainDataset, self).__init__()
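        # Preprocessing: resize and center-crop to 256x256, then convert to a float tensor in [0, 1].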
self.transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.CenterCrop((256, 256)),
transforms.ToTensor(),
]
)
self.data_root = Path(dataset_root)
with open(dataset_index, "rb") as f:
self.file_index = pickle.load(f, encoding="bytes")
self.same_rate = same_rate
self.id_list: List[str] = list(self.file_index.keys())
        # One full pass over all ids is treated as one epoch
self.length = len(self.id_list)
self.image_num = sum([len(v) for v in self.file_index.values()])
self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
logger.info(f"will use mask mode: {self.mask_dir}")
def __len__(self):
return self.length
def __getitem__(self, index):
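        # The dataset index is an identity index: use it as the source id and sample one of its images.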
source_id_index = index
source_file = random.choice(self.file_index[self.id_list[source_id_index]])
if random.random() < self.same_rate:
            # Pick the target from the same id's file list
target_file = random.choice(self.file_index[self.id_list[source_id_index]])
same = torch.ones(1)
else:
            # Pick the target from a different id's file list
target_id_index = random.choice(list(set(range(self.length)) - set([source_id_index])))
target_file = random.choice(self.file_index[self.id_list[target_id_index]])
same = torch.zeros(1)
source_file = self.data_root / Path(source_file)
target_file = self.data_root / Path(target_file)
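        # The mask path mirrors the image layout: go three levels up from the target image,
        # then into <mask_dir>/<same parent folder name>/<same file name>.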
target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
target_img = Image.open(target_file.as_posix()).convert("RGB")
source_img = Image.open(source_file.as_posix()).convert("RGB")
target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
source_img = self.transform(source_img)
target_img = self.transform(target_img)
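        # The mask is loaded as RGB; keep a single channel and restore the channel dim -> (1, H, W).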
target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
return {
"source_image": source_img,
"target_image": target_img,
"target_mask": target_mask,
"same": same,
# "source_img_name": source_file.as_posix(),
# "target_img_name": target_file.as_posix(),
# "target_mask_name": target_mask_file.as_posix(),
}
class OneToManyTrainDataset(Dataset):
def __init__(self, dataset_root: str, dataset_index: str, source_name: str, same_rate=0.5):
"""
One-to-many 训练数据集构建
Parameters:
-----------
dataset_root: str, 数据集根目录
dataset_index: str, 数据集index文件路径
source_name: str, source face id的名称, one-to-many里面的one
same_rate: float, 每个batch里面相同人脸所占的比例
"""
super(OneToManyTrainDataset, self).__init__()
self.transform = transforms.Compose(
[
transforms.Resize((256, 256)),
transforms.CenterCrop((256, 256)),
transforms.ToTensor(),
]
)
self.data_root = Path(dataset_root)
with open(dataset_index, "rb") as f:
self.file_index = pickle.load(f, encoding="bytes")
self.same_rate = same_rate
self.source_name = source_name
self.id_list: List[str] = list(self.file_index.keys())
        try:
            self.source_id_index: int = self.id_list.index(self.source_name)
        except ValueError as e:
            raise ValueError(f"{self.source_name} not found in the dataset index") from e
        # One full pass over all ids is treated as one epoch
self.length = len(self.id_list)
self.image_num = sum([len(v) for v in self.file_index.values()])
self.mask_dir = "mask" if TrainConfig().mouth_mask else "mask_no_mouth"
logger.info(f"dataset contains {self.length} ids and {self.image_num} images")
logger.info(f"will use mask mode: {self.mask_dir}")
def __len__(self):
return self.length
def __getitem__(self, index):
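        # Here the dataset index selects the *target* identity; the source is either an image of the
        # same identity (with probability same_rate) or an image of the fixed source_name identity.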
target_id_index = index
target_file = random.choice(self.file_index[self.id_list[target_id_index]])
if random.random() < self.same_rate:
            # Pick the source from the same id's file list (same identity as the target)
source_file = random.choice(self.file_index[self.id_list[target_id_index]])
same = torch.ones(1)
else:
            # Pick the source directly from the source_name id's images
source_file = random.choice(self.file_index[self.source_name])
            # If the source id happens to be the same as the target id
if self.source_id_index == target_id_index:
same = torch.ones(1)
else:
same = torch.zeros(1)
source_file = self.data_root / Path(source_file)
target_file = self.data_root / Path(target_file)
target_mask_file = target_file.parent.parent.parent / self.mask_dir / target_file.parent.stem / target_file.name
target_img = Image.open(target_file.as_posix()).convert("RGB")
source_img = Image.open(source_file.as_posix()).convert("RGB")
target_mask = Image.open(target_mask_file.as_posix()).convert("RGB")
source_img = self.transform(source_img)
target_img = self.transform(target_img)
target_mask = self.transform(target_mask)[0, :, :].unsqueeze(0)
return {
"source_image": source_img,
"target_image": target_img,
"target_mask": target_mask,
"same": same,
# "source_img_name": source_file.as_posix(),
# "target_img_name": target_file.as_posix(),
# "target_mask_name": target_mask_file.as_posix(),
}
class TrainDatasetDataLoader:
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
def __init__(self):
"""Initialize this class"""
opt = TrainConfig()
if opt.mode is FaceSwapMode.MANY_TO_MANY:
self.dataset = ManyToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.same_rate)
elif opt.mode is FaceSwapMode.ONE_TO_MANY:
logger.info(f"In one-to-many mode, source face is {opt.source_name}")
self.dataset = OneToManyTrainDataset(opt.dataset_root, opt.dataset_index, opt.source_name, opt.same_rate)
else:
raise NotImplementedError
logger.info(f"dataset {type(self.dataset).__name__} created")
if opt.use_ddp:
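            # Under DDP, sharding and shuffling are handled by the DistributedSampler,
            # so no shuffle flag is passed to the DataLoader.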
self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.dataset, shuffle=True)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
num_workers=int(opt.num_threads),
drop_last=True,
sampler=self.train_sampler,
pin_memory=True,
)
else:
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=True,
num_workers=int(opt.num_threads),
drop_last=True,
pin_memory=True,
)
def load_data(self):
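        # The wrapper itself is iterable (see __iter__ below), so just return self.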
return self
def __len__(self):
"""Return the number of data in the dataset"""
return len(self.dataset)
def __iter__(self):
"""Return a batch of data"""
for data in self.dataloader:
yield data
if __name__ == "__main__":
dataloader = TrainDatasetDataLoader()
for idx, data in enumerate(dataloader):
# print(data["source_img_name"])
# print(data["target_img_name"])
# print(data["target_mask_name"])
print(data["same"])