import os
import itertools
import pickle
import torch
from torchvision import models
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
import dependecies.segroot.paired_transforms_pt04 as p_tr
# Joint image/mask training transform: random 256x256 crop, fixed rotations,
# horizontal/vertical flips, then conversion of both images to tensors.
train_transform = p_tr.Compose([
    p_tr.RandomCrop(256),
    p_tr.RandomRotation((90, 90)),
    p_tr.RandomRotation((180, 180)),
    p_tr.RandomRotation((270, 270)),
    p_tr.RandomHorizontalFlip(),
    p_tr.RandomVerticalFlip(),
    p_tr.ToTensor()
])

# Per-channel dataset statistics kept for reference; the active transform
# maps [0, 1] inputs to [-1, 1].
# normalize = p_tr.Normalize([0.35042979, 0.44016893, 0.2340332],
#                            [0.20999724, 0.25972678, 0.13885915])
normalize = p_tr.Normalize([0.5, 0.5, 0.5],
                           [0.5, 0.5, 0.5])

def pad_pair_256(image, gt):
    """Pad an image/ground-truth pair onto a canvas whose sides are the next
    multiples of 256, centring the originals on the new canvas."""
    w, h = image.size
    new_w = ((w - 1) // 256 + 1) * 256
    new_h = ((h - 1) // 256 + 1) * 256
    new_image = Image.new("RGB", (new_w, new_h))
    new_image.paste(image, ((new_w - w) // 2, (new_h - h) // 2))
    new_gt = Image.new("L", (new_w, new_h))
    new_gt.paste(gt, ((new_w - w) // 2, (new_h - h) // 2))
    return new_image, new_gt

def convert_png(image, gt):
    """Paste a sub-image/mask pair onto fresh 256x256 RGB/L canvases so both
    end up with the expected size and mode."""
    new_image = Image.new('RGB', (256, 256))
    new_image.paste(image)
    new_gt = Image.new('L', (256, 256))
    new_gt.paste(gt)
    return new_image, new_gt

def get_paths(root_dir, im_ids):
    """Collect the sub-image paths belonging to the given image ids.

    The startswith filter avoids prefix collisions (e.g. id 1 matching 11-*).
    """
    imgs = []
    for i in im_ids:
        tmp = Path(root_dir).glob('*{}-*.png'.format(i))
        tmp = [p for p in tmp if p.parts[-1].startswith(str(i) + '-')]
        imgs += tmp
    return imgs

class LoopSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return itertools.cycle(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
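
# For example, LoopSampler over a 3-item dataset yields indices
# 0, 1, 2, 0, 1, 2, ... indefinitely, so a DataLoader built on it never
# raises StopIteration between "epochs"; batches are drawn on demand.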

class TrainDataset(Dataset):
    """Full raw images and masks, padded to multiples of 256 and augmented
    on the fly with the joint training transform."""

    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        self.mask_dir = '../data/mask'
        self.im_ids = im_ids
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = pad_pair_256(image, gt)
        image, gt = train_transform(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        return len(self.im_ids)

class StaticTrainDataset(Dataset):
    """Pre-cut 256x256 sub-images and sub-masks loaded from disk instead of
    cropping the raw images at run time."""

    def __init__(self, im_ids):
        self.subimgs = sorted(get_paths('../data/subimg', im_ids))
        self.submasks = sorted(get_paths('../data/submask', im_ids))
        self.im_ids = im_ids

    def __getitem__(self, index):
        im_name = self.subimgs[index]
        gt_name = self.submasks[index]
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = convert_png(image, gt)
        image, gt = train_transform(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        # The pipeline assumes 90 sub-images per raw image id.
        return len(self.im_ids) * 90

class TrainDataLoader():
    """Thin wrapper around DataLoader that uses LoopSampler, so batches can
    be drawn indefinitely via next_batch() without restarting an epoch."""

    def __init__(self, dataset, batch_size, num_workers=0):
        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size,
                                     num_workers=num_workers,
                                     sampler=LoopSampler(self.dataset))
        self.dl = iter(self.dataloader)

    def next_batch(self):
        image, gt = next(self.dl)
        return image, gt

class TestDataset(Dataset):
    """Like TrainDataset but without augmentation: images are only padded,
    converted to tensors and normalized."""

    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        self.mask_dir = '../data/masks'
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        self.im_ids = im_ids
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = pad_pair_256(image, gt)
        image, gt = p_tr.ToTensor()(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        return len(self.fns)
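

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the training scripts): it
# shows how the datasets and the looping loader might be wired together. The
# image ids below are hypothetical, and the relative data layout referenced
# above ('../data/subimg', '../data/submask', '../data/data_raw',
# '../data/masks', '../data/info.pkl') must exist for it to run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    train_ids = [1, 2, 3]  # hypothetical split
    test_ids = [4]         # hypothetical split

    # Draw one augmented batch of pre-cut sub-images.
    train_set = StaticTrainDataset(train_ids)
    train_loader = TrainDataLoader(train_set, batch_size=4)
    images, gts = train_loader.next_batch()
    print(images.shape, gts.shape)  # expected: (4, 3, 256, 256), (4, 1, 256, 256)

    # Iterate full padded test images one at a time.
    test_set = TestDataset(test_ids)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    image, gt = next(iter(test_loader))
    print(image.shape, gt.shape)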