# Text2Human / Text2Human / data / segm_attr_dataset.py
# Last commit: bde71cb ("update") by yumingj
import os
import os.path
import random
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
class DeepFashionAttrSegmDataset(data.Dataset):
    """Dataset of images with densepose, segmentation and texture attributes.

    Three annotation files are read from ``ann_dir``: ``upper_fused.txt``,
    ``lower_fused.txt`` and ``outer_fused.txt``. Every non-blank row is
    ``<image file name> <integer attribute>``. ``upper_fused.txt`` defines
    the list (and order) of samples; the other two files must list exactly
    the same file names in the same order.

    Each item is a dict with the keys ``image``, ``densepose``, ``segm``,
    ``texture_mask`` and ``img_name`` (see ``__getitem__``).

    Args:
        img_dir: Directory containing the image files.
        segm_dir: Directory containing ``<stem>_segm.png`` files.
        pose_dir: Directory containing ``<stem>_densepose.png`` files.
        ann_dir: Directory containing the three annotation files.
        downsample_factor: Integer divisor applied to width and height of
            every loaded image; ``1`` disables resizing.
        xflip: If true, each item is mirrored along its last axis with
            probability 0.5.

    Raises:
        AssertionError: If an annotation file is missing, or if the
            lower/outer file does not list the same file names as the
            upper file.
    """

    # Attribute value for which ``__getitem__`` leaves the corresponding
    # segmentation classes at texture-mask value 0 (the common codebook).
    # NOTE(review): presumably the "not applicable" label of the annotation
    # files -- confirm against the dataset documentation.
    _COMMON_CODEBOOK_ATTR = 17

    def __init__(self,
                 img_dir,
                 segm_dir,
                 pose_dir,
                 ann_dir,
                 downsample_factor=2,
                 xflip=False):
        self._img_path = img_dir
        self._densepose_path = pose_dir
        self._segm_path = segm_dir
        self.downsample_factor = downsample_factor
        self.xflip = xflip

        # load attributes; the upper file is the reference sample list
        self._image_fnames, self.upper_fused_attrs = self._read_attr_file(
            ann_dir, 'upper_fused.txt')
        self.lower_fused_attrs = self._read_aligned_attrs(
            ann_dir, 'lower_fused.txt')
        self.outer_fused_attrs = self._read_aligned_attrs(
            ann_dir, 'outer_fused.txt')

        # remove the overlapping item between upper cls and lower cls
        # cls 21 can appear with upper clothes
        # cls 4 can appear with lower clothes
        self.upper_cls = [1., 4.]
        self.lower_cls = [3., 5., 21.]
        self.outer_cls = [2.]
        self.other_cls = [
            11., 18., 7., 8., 9., 10., 12., 16., 17., 19., 20., 22., 23., 15.,
            14., 13., 0., 6.
        ]

    @staticmethod
    def _read_attr_file(ann_dir, fname):
        """Parse one annotation file into (file names, integer attributes)."""
        path = f'{ann_dir}/{fname}'
        assert os.path.exists(path), f'annotation file not found: {path}'
        names = []
        attrs = []
        # ``with`` guarantees the handle is closed once parsing is done.
        with open(path, 'r') as ann_file:
            for row in ann_file:
                fields = row.split()
                if not fields:
                    # tolerate blank / trailing empty lines
                    continue
                names.append(fields[0])
                attrs.append(int(fields[1]))
        return names, attrs

    def _read_aligned_attrs(self, ann_dir, fname):
        """Parse an annotation file and check it matches the sample list."""
        names, attrs = self._read_attr_file(ann_dir, fname)
        assert names == self._image_fnames, (
            f'{fname} does not list the same images, in the same order, '
            'as upper_fused.txt')
        return attrs

    def _open_file(self, path_prefix, fname):
        """Open ``fname`` inside ``path_prefix`` for binary reading."""
        return open(os.path.join(path_prefix, fname), 'rb')

    def _downsample(self, pil_image, resample):
        """Shrink a PIL image by ``downsample_factor`` (no-op when 1)."""
        if self.downsample_factor == 1:
            return pil_image
        width, height = pil_image.size
        new_size = (width // self.downsample_factor,
                    height // self.downsample_factor)
        return pil_image.resize(size=new_size, resample=resample)

    def _load_raw_image(self, raw_idx):
        """Load sample ``raw_idx`` from the image directory as a CHW array."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(self._img_path, fname) as f:
            image = self._downsample(Image.open(f), Image.LANCZOS)
            # Convert while the file is still open: PIL decodes lazily.
            image = np.array(image)
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        image = image.transpose(2, 0, 1)  # HWC => CHW
        return image

    def _load_densepose(self, raw_idx):
        """Load ``<stem>_densepose.png`` for sample ``raw_idx`` as float32."""
        stem = os.path.splitext(self._image_fnames[raw_idx])[0]
        fname = f'{stem}_densepose.png'
        with self._open_file(self._densepose_path, fname) as f:
            # NEAREST keeps the stored values unmixed when shrinking.
            densepose = self._downsample(Image.open(f), Image.NEAREST)
            # Keep the channels from index 2 onward and move them first:
            # (H, W, C) => (C - 2, H, W). For a 3-channel PNG this yields a
            # single channel, not 3.
            densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1)
        return densepose.astype(np.float32)

    def _load_segm(self, raw_idx):
        """Load ``<stem>_segm.png`` for sample ``raw_idx`` as float32."""
        stem = os.path.splitext(self._image_fnames[raw_idx])[0]
        fname = f'{stem}_segm.png'
        with self._open_file(self._segm_path, fname) as f:
            # NEAREST keeps the class labels unmixed when shrinking.
            segm = self._downsample(Image.open(f), Image.NEAREST)
            segm = np.array(segm)
        segm = segm[:, :, np.newaxis].transpose(2, 0, 1)  # HW => 1HW
        return segm.astype(np.float32)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Keys:
            image: torch tensor, raw image scaled by ``/ 127.5 - 1``.
            densepose: NumPy float32 array scaled by ``/ 12. - 1``.
            segm: torch float32 tensor of segmentation class labels.
            texture_mask: torch tensor shaped like ``segm``; 0 selects the
                common codebook, ``attr + 1`` the texture-specific one.
            img_name: file name of the sample as listed in the annotations.
        """
        image = self._load_raw_image(index)
        pose = self._load_densepose(index)
        segm = self._load_segm(index)

        if self.xflip and random.random() > 0.5:
            assert image.ndim == 3  # CHW
            # Mirror the last axis; copy() because torch.from_numpy cannot
            # wrap negative-stride views.
            image = image[:, :, ::-1].copy()
            pose = pose[:, :, ::-1].copy()
            segm = segm[:, :, ::-1].copy()

        image = torch.from_numpy(image)
        segm = torch.from_numpy(segm)

        # mask 0: denotes the common codebook,
        # mask (attr + 1): denotes the texture-specific codebook
        mask = torch.zeros_like(segm)
        attr_and_classes = (
            (self.upper_fused_attrs[index], self.upper_cls),
            (self.lower_fused_attrs[index], self.lower_cls),
            (self.outer_fused_attrs[index], self.outer_cls),
        )
        for fused_attr, seg_classes in attr_and_classes:
            if fused_attr == self._COMMON_CODEBOOK_ATTR:
                continue
            for seg_cls in seg_classes:
                mask[segm == seg_cls] = fused_attr + 1

        # NOTE(review): presumably maps densepose values in [0, 24] to
        # [-1, 1] -- confirm the value range of the densepose PNGs.
        pose = pose / 12. - 1
        image = image / 127.5 - 1

        return {
            'image': image,
            'densepose': pose,
            'segm': segm,
            'texture_mask': mask,
            'img_name': self._image_fnames[index]
        }

    def __len__(self):
        """Return the number of samples listed in ``upper_fused.txt``."""
        return len(self._image_fnames)