Spaces:
Running
Running
import os | |
import itertools | |
import pickle | |
import torch | |
from torchvision import models | |
from pathlib import Path | |
from PIL import Image | |
from torch.utils.data import Dataset, DataLoader, Sampler | |
import dependecies.segroot.paired_transforms_pt04 as p_tr | |
# Training-time augmentation applied jointly to an (image, mask) pair:
# ``train_transform(image, gt)`` returns the transformed pair.
_CROP_SIZE = 256

# NOTE(review): if p_tr.RandomRotation samples its angle uniformly from the
# given (min, max) range and is always applied (as torchvision's does), the
# three fixed-angle rotations below always compose to 90 + 180 + 270 = 540
# degrees, i.e. a constant 180-degree turn, rather than a random choice among
# 90/180/270. Confirm against paired_transforms_pt04 before relying on it.
train_transform = p_tr.Compose(
    [p_tr.RandomCrop(_CROP_SIZE)]
    + [p_tr.RandomRotation((angle, angle)) for angle in (90, 180, 270)]
    + [
        p_tr.RandomHorizontalFlip(),
        p_tr.RandomVerticalFlip(),
        p_tr.ToTensor(),
    ]
)

# Per-channel normalization with mean 0.5 and std 0.5 on all three channels
# (assuming ToTensor scales to [0, 1] like torchvision, this maps to [-1, 1]).
# Dataset-specific statistics that are available but not used:
#   mean = (0.35042979, 0.44016893, 0.2340332)
#   std  = (0.20999724, 0.25972678, 0.13885915)
normalize = p_tr.Normalize([0.5] * 3, [0.5] * 3)
def pad_pair_256(image, gt):
    """Centre an image and its mask on black canvases sized up to multiples of 256.

    Width and height are each rounded up to the next multiple of 256 (a side
    that is already divisible by 256 keeps its size). Both inputs are pasted
    at the same offset, so the padding is split between opposite edges; when
    the padding is odd, the extra pixel ends up on the right/bottom.

    Args:
        image: PIL image, pasted onto a new ``"RGB"`` canvas.
        gt: PIL mask, pasted onto a new ``"L"`` canvas. The paste offset is
            computed from ``image.size`` only, so ``gt`` is expected to have
            the same size as ``image``.

    Returns:
        ``(padded_image, padded_gt)`` as newly created PIL images.
    """
    width, height = image.size
    # Ceiling division expressed as negated floor division.
    padded_size = (-(-width // 256) * 256, -(-height // 256) * 256)
    offset = ((padded_size[0] - width) // 2, (padded_size[1] - height) // 2)

    padded_image = Image.new("RGB", padded_size)
    padded_image.paste(image, offset)
    padded_gt = Image.new("L", padded_size)
    padded_gt.paste(gt, offset)
    return padded_image, padded_gt
def convert_png(image, gt):
    """Copy an image/mask pair onto fresh 256x256 canvases.

    The image is pasted onto an ``'RGB'`` canvas and the mask onto an ``'L'``
    canvas, both at the top-left corner; any area the input does not cover
    stays black. Despite the name, no file-format conversion happens here:
    the result is just a pair of new in-memory PIL images.

    Returns:
        ``(image_canvas, gt_canvas)``.
    """
    canvases = []
    for source, mode in ((image, 'RGB'), (gt, 'L')):
        canvas = Image.new(mode, (256, 256))
        canvas.paste(source)
        canvases.append(canvas)
    return tuple(canvases)
def get_paths(root_dir, im_ids):
    """Collect ``.png`` files in *root_dir* whose names start with ``"<id>-"``.

    Args:
        root_dir: directory to search (``str`` or ``Path``); not recursive.
        im_ids: iterable of ids; each is converted with ``str()``.

    Returns:
        A list of ``Path`` objects grouped in ``im_ids`` order. Within one id
        the order is whatever ``Path.glob`` yields (not sorted), and an id
        that appears twice in ``im_ids`` contributes its files twice.
    """
    root = Path(root_dir)
    paths = []
    for im_id in im_ids:
        prefix = str(im_id) + '-'
        # The glob alone also matches names that merely *contain* "<id>-"
        # (e.g. "11-x.png" for id 1), so keep only exact-prefix matches.
        paths.extend(p for p in root.glob('*{}-*.png'.format(im_id))
                     if p.name.startswith(prefix))
    return paths
class LoopSampler(Sampler):
    """Sampler that yields the indices ``0 .. N-1`` over and over.

    ``N`` is ``len(data_source)``, read each time iteration starts. For a
    non-empty source the iterator never ends; for an empty source it is
    empty. ``len()`` still reports ``N``, the size of a single pass.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        one_pass = range(len(self.data_source))
        return itertools.cycle(one_pass)
class TrainDataset(Dataset):
    """Full-size training images paired with their masks.

    Each item is read from disk, zero-padded to multiples of 256 with
    ``pad_pair_256``, augmented jointly by the module-level
    ``train_transform``, and the image is passed through ``normalize``.

    Args:
        im_ids: ids looked up in the pickled info mapping; the mapped values
            are image file names relative to ``root_dir``.
        root_dir: directory holding the raw images.
        mask_dir: directory holding the ``<name>-mask.jpg`` masks.
        info_path: pickle file holding the id -> file-name mapping.

    The default paths are relative to the current working directory.

    Raises:
        KeyError: if an id in ``im_ids`` is missing from the info mapping.
    """

    def __init__(self, im_ids, root_dir='../data/data_raw',
                 mask_dir='../data/mask', info_path='../data/info.pkl'):
        self.root_dir = root_dir
        # NOTE(review): TestDataset in this module reads masks from
        # '../data/masks' (plural); one of the two defaults is probably
        # wrong — confirm which directory exists on disk.
        self.mask_dir = mask_dir
        self.im_ids = im_ids
        # Unpickling can execute arbitrary code: info_path must point to a
        # trusted file.
        with open(info_path, 'rb') as handle:
            self.info = pickle.load(handle)
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        # Mask file: everything before the first '.jpg' in the image name,
        # followed by '-mask.jpg'.
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        # pad_pair_256 pastes the pixels into brand-new images, so the
        # source files can be closed as soon as it returns.
        with Image.open(im_name) as image, Image.open(gt_name) as gt:
            image, gt = pad_pair_256(image, gt)
        image, gt = train_transform(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        return len(self.im_ids)
class StaticTrainDataset(Dataset):
    """Training dataset over pre-cut sub-images (tiles) and their masks.

    Tiles come from ``../data/subimg`` and masks from ``../data/submask``
    (relative to the current working directory). Only ``.png`` files whose
    names start with ``"<id>-"`` for an id in ``im_ids`` are used. Images and
    masks are paired by position after sorting each list, so the two
    directories must sort into corresponding order (e.g. identical names).

    Each item is placed on a 256x256 canvas by ``convert_png``, augmented by
    ``train_transform``, and the image is passed through ``normalize``.

    Args:
        im_ids: ids of the source images whose tiles should be included.

    Raises:
        ValueError: if the number of tiles and masks found differs, because
            position-based pairing would then silently misalign them.
    """

    def __init__(self, im_ids):
        self.subimgs = sorted(get_paths('../data/subimg', im_ids))
        self.submasks = sorted(get_paths('../data/submask', im_ids))
        self.im_ids = im_ids
        if len(self.subimgs) != len(self.submasks):
            raise ValueError(
                'found {} sub-images but {} sub-masks; cannot pair them'
                .format(len(self.subimgs), len(self.submasks)))

    def __getitem__(self, index):
        im_name = self.subimgs[index]
        gt_name = self.submasks[index]
        # convert_png pastes into fresh canvases, so the source files can be
        # closed as soon as it returns.
        with Image.open(im_name) as image, Image.open(gt_name) as gt:
            image, gt = convert_png(image, gt)
        image, gt = train_transform(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        # Count the tiles actually found instead of assuming 90 per id:
        # a fixed count raises IndexError when fewer tiles exist, and
        # ``ids * 90`` on a NumPy array multiplies the values rather than
        # repeating the sequence, under-reporting the dataset size.
        return len(self.subimgs)
class TrainDataLoader:
    """Batch source that keeps producing batches from a ``LoopSampler``.

    The wrapped ``DataLoader`` draws indices from a ``LoopSampler``, which
    cycles over the dataset indefinitely when the dataset is non-empty, so
    ``next_batch()`` can be called as often as needed.

    Args:
        dataset: dataset whose items unpack into ``(image, gt)`` pairs.
        batch_size: number of samples per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).
    """

    def __init__(self, dataset, batch_size, num_workers=0):
        self.dataset = dataset
        endless_sampler = LoopSampler(dataset)
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=endless_sampler,
        )
        # One long-lived iterator, so consecutive calls continue where the
        # previous batch left off.
        self.dl = iter(self.dataloader)

    def next_batch(self):
        """Return the next ``(image, gt)`` batch from the shared iterator."""
        batch = next(self.dl)
        image_batch, gt_batch = batch
        return image_batch, gt_batch
class TestDataset(Dataset):
    """Full-size evaluation images paired with their masks, without augmentation.

    Each item is read from disk, zero-padded to multiples of 256 with
    ``pad_pair_256``, converted by ``p_tr.ToTensor()``, and the image is
    passed through the module-level ``normalize``.

    Args:
        im_ids: ids looked up in the pickled info mapping; the mapped values
            are image file names relative to ``root_dir``.
        root_dir: directory holding the raw images.
        mask_dir: directory holding the ``<name>-mask.jpg`` masks.
        info_path: pickle file holding the id -> file-name mapping.

    The default paths are relative to the current working directory.

    Raises:
        KeyError: if an id in ``im_ids`` is missing from the info mapping.
    """

    def __init__(self, im_ids, root_dir='../data/data_raw',
                 mask_dir='../data/masks', info_path='../data/info.pkl'):
        self.root_dir = root_dir
        # NOTE(review): TrainDataset in this module reads masks from
        # '../data/mask' (singular); one of the two defaults is probably
        # wrong — confirm which directory exists on disk.
        self.mask_dir = mask_dir
        # Unpickling can execute arbitrary code: info_path must point to a
        # trusted file.
        with open(info_path, 'rb') as handle:
            self.info = pickle.load(handle)
        self.im_ids = im_ids
        self.fns = [self.info[im_id] for im_id in im_ids]

    def __getitem__(self, index):
        im_fn = self.fns[index]
        im_name = os.path.join(self.root_dir, im_fn)
        # Mask file: everything before the first '.jpg' in the image name,
        # followed by '-mask.jpg'.
        gt_name = os.path.join(
            self.mask_dir, im_fn.split('.jpg')[0] + '-mask.jpg')
        # pad_pair_256 pastes the pixels into brand-new images, so the
        # source files can be closed as soon as it returns.
        with Image.open(im_name) as image, Image.open(gt_name) as gt:
            image, gt = pad_pair_256(image, gt)
        image, gt = p_tr.ToTensor()(image, gt)
        image = normalize(image)
        return image, gt

    def __len__(self):
        return len(self.fns)