Spaces:
Running
Running
File size: 4,482 Bytes
bc97962 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import itertools
import pickle
import torch
from torchvision import models
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
import dependecies.segroot.paired_transforms_pt04 as p_tr
# Paired training augmentation, applied jointly to an (image, mask) pair:
# RandomCrop(256), three rotation steps, random horizontal and vertical
# flips, and finally ToTensor.
# NOTE(review): every RandomRotation below receives a degenerate range
# (min == max). If p_tr mirrors torchvision semantics, each step always
# rotates by its fixed angle instead of picking one angle at random --
# confirm that this is intended.
train_transform = p_tr.Compose(
    [p_tr.RandomCrop(256)]
    + [p_tr.RandomRotation((angle, angle)) for angle in (90, 180, 270)]
    + [
        p_tr.RandomHorizontalFlip(),
        p_tr.RandomVerticalFlip(),
        p_tr.ToTensor(),
    ]
)
# Normalisation applied to the image tensor only, with 0.5 for every entry of
# both arguments (presumably per-channel mean and std, as in torchvision --
# confirm against p_tr.Normalize).
# Alternative statistics kept for reference (currently unused):
#   [0.35042979, 0.44016893, 0.2340332]
#   [0.20999724, 0.25972678, 0.13885915]
normalize = p_tr.Normalize([0.5] * 3, [0.5] * 3)
def pad_pair_256(image, gt):
    """Pad an image and its mask so both sides are multiples of 256.

    Two new canvases are created whose width and height are the image's
    dimensions rounded up to the next multiple of 256; the inputs are pasted
    onto them at the centring offset. Size and offset are derived from
    ``image`` only, and the same values are used for ``gt``.

    Args:
        image: PIL image; pasted onto a new "RGB" canvas.
        gt: PIL image of the ground-truth mask; pasted onto a new "L" canvas.

    Returns:
        Tuple ``(padded_image, padded_gt)``.
    """
    width, height = image.size
    # Ceiling division: round each side up to a whole number of 256 blocks.
    padded_size = (-(-width // 256) * 256, -(-height // 256) * 256)
    # Split the added border evenly (floor) between the two sides.
    offset = ((padded_size[0] - width) // 2, (padded_size[1] - height) // 2)

    padded_image = Image.new("RGB", padded_size)
    padded_image.paste(image, offset)
    padded_gt = Image.new("L", padded_size)
    padded_gt.paste(gt, offset)
    return padded_image, padded_gt
def convert_png(image, gt):
    """Copy an image/mask pair onto fresh 256x256 canvases.

    Args:
        image: PIL image; pasted onto a new 256x256 "RGB" canvas.
        gt: PIL image of the mask; pasted onto a new 256x256 "L" canvas.

    Returns:
        Tuple ``(rgb_canvas, mask_canvas)``.
    """
    tile_size = (256, 256)
    rgb_canvas = Image.new('RGB', tile_size)
    mask_canvas = Image.new('L', tile_size)
    # paste() without a box argument, exactly as in the original call.
    rgb_canvas.paste(image)
    mask_canvas.paste(gt)
    return rgb_canvas, mask_canvas
def get_paths(root_dir, im_ids):
    """Collect the PNG files in *root_dir* that belong to the given ids.

    A file belongs to id ``i`` when its name starts with ``"<i>-"`` and ends
    with ``.png``. The search is not recursive.

    Args:
        root_dir: Directory to search (string or path-like).
        im_ids: Iterable of ids; each one is converted with ``str``.

    Returns:
        A list of ``pathlib.Path`` objects, grouped by id in the order of
        ``im_ids``. Within one id the order is whatever the glob yields, so
        callers that need a stable order must sort the result.
    """
    root = Path(root_dir)
    imgs = []
    for i in im_ids:
        prefix = str(i) + '-'
        # The glob alone also matches names that merely contain the id
        # (e.g. "112-0.png" for id 12), hence the explicit prefix check.
        imgs.extend(
            p for p in root.glob('*{}-*.png'.format(i))
            if p.name.startswith(prefix)
        )
    return imgs
class LoopSampler(Sampler):
    """Sampler that yields the indices 0..len-1 in order, repeating forever."""

    def __init__(self, data_source):
        # Only len() of the data source is ever used.
        self.data_source = data_source

    def __iter__(self):
        # The index range is fixed when iteration starts; cycle() replays it
        # endlessly (and yields nothing if the range is empty).
        sequential_indices = range(len(self.data_source))
        return itertools.cycle(sequential_indices)

    def __len__(self):
        # Reports the size of one pass, although iteration itself never ends.
        size = len(self.data_source)
        return size
class TrainDataset(Dataset):
    """Training dataset of raw images and masks, augmented on access.

    File names are looked up by image id in the mapping unpickled from
    '../data/info.pkl'. Each item is padded with pad_pair_256, passed through
    train_transform, and the image is then normalised.
    """

    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        # NOTE(review): TestDataset in this file reads masks from
        # '../data/masks' (plural) -- confirm which directory is intended.
        self.mask_dir = '../data/mask'
        self.im_ids = im_ids
        # NOTE(review): pickle.load can execute arbitrary code; info.pkl must
        # come from a trusted source.
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        filenames = []
        for im_id in im_ids:
            filenames.append(self.info[im_id])
        self.fns = filenames

    def __getitem__(self, index):
        """Return the augmented ``(image, gt)`` pair for position *index*."""
        im_fn = self.fns[index]
        # Mask file name: text before the first '.jpg' plus '-mask.jpg'.
        stem = im_fn.partition('.jpg')[0]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(self.mask_dir, '{}-mask.jpg'.format(stem))

        padded = pad_pair_256(Image.open(im_name), Image.open(gt_name))
        image, gt = train_transform(*padded)
        # Only the image is normalised; the mask is returned as transformed.
        return normalize(image), gt

    def __len__(self):
        """Return the number of image ids given at construction."""
        id_count = len(self.im_ids)
        return id_count
class StaticTrainDataset(Dataset):
    """Training dataset over sub-image / sub-mask PNG files.

    Files for the given ids are collected with get_paths from
    '../data/subimg' and '../data/submask'. Both lists are sorted, and item
    ``index`` pairs the entries found at the same position in each list.
    """

    # Multiplier used for the reported dataset length.
    # NOTE(review): presumably the number of sub-images stored per image id --
    # confirm against the code that generates the files.
    _ITEMS_PER_ID = 90

    def __init__(self, im_ids):
        """Index the sub-image and sub-mask files belonging to *im_ids*."""
        self.subimgs = sorted(get_paths('../data/subimg', im_ids))
        self.submasks = sorted(get_paths('../data/submask', im_ids))
        self.im_ids = im_ids

    def __getitem__(self, index):
        """Return the augmented ``(image, gt)`` pair for position *index*.

        Raises:
            IndexError: If *index* is beyond the files actually found.
        """
        im_name = self.subimgs[index]
        gt_name = self.submasks[index]
        image = Image.open(im_name)
        gt = Image.open(gt_name)
        image, gt = convert_png(image, gt)
        image, gt = train_transform(image, gt)
        # Only the image is normalised; the mask is returned as transformed.
        image = normalize(image)
        return image, gt

    def __len__(self):
        """Return ``len(im_ids) * 90``.

        The length comes from the id count, not from the files found on disk.
        Multiplying the count (rather than the container, as before) avoids
        building a 90x copy and gives the same answer for any sized
        container, including ones where ``*`` is not repetition.
        """
        return len(self.im_ids) * self._ITEMS_PER_ID
class TrainDataLoader():
    """Wrap a DataLoader driven by LoopSampler and hand out batches on demand."""

    def __init__(self, dataset, batch_size, num_workers=0):
        """Build the loader and start iterating it immediately.

        Args:
            dataset: Dataset passed to the DataLoader and to LoopSampler.
            batch_size: Forwarded to the DataLoader.
            num_workers: Forwarded to the DataLoader (default 0).
        """
        self.dataset = dataset
        looping_sampler = LoopSampler(self.dataset)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            sampler=looping_sampler,
        )
        # A single iterator is kept so successive calls continue the stream.
        self.dl = iter(self.dataloader)

    def next_batch(self):
        """Return the next ``(image, gt)`` batch from the running iterator."""
        batch = next(self.dl)
        image, gt = batch
        return image, gt
class TestDataset(Dataset):
    """Evaluation dataset of raw images and masks without augmentation.

    File names are looked up by image id in the mapping unpickled from
    '../data/info.pkl'. Each item is padded with pad_pair_256, converted with
    p_tr.ToTensor, and the image is then normalised.
    """

    def __init__(self, im_ids):
        self.root_dir = '../data/data_raw'
        # NOTE(review): TrainDataset in this file reads masks from
        # '../data/mask' (singular) -- confirm which directory is intended.
        self.mask_dir = '../data/masks'
        # NOTE(review): pickle.load can execute arbitrary code; info.pkl must
        # come from a trusted source.
        with open('../data/info.pkl', 'rb') as handle:
            self.info = pickle.load(handle)
        self.im_ids = im_ids
        filenames = []
        for im_id in im_ids:
            filenames.append(self.info[im_id])
        self.fns = filenames

    def __getitem__(self, index):
        """Return the padded, tensor-converted ``(image, gt)`` pair."""
        im_fn = self.fns[index]
        # Mask file name: text before the first '.jpg' plus '-mask.jpg'.
        stem = im_fn.partition('.jpg')[0]
        im_name = os.path.join(self.root_dir, im_fn)
        gt_name = os.path.join(self.mask_dir, '{}-mask.jpg'.format(stem))

        padded = pad_pair_256(Image.open(im_name), Image.open(gt_name))
        to_tensor = p_tr.ToTensor()
        image, gt = to_tensor(*padded)
        # Only the image is normalised; the mask is returned as converted.
        return normalize(image), gt

    def __len__(self):
        """Return the number of file names resolved at construction."""
        file_count = len(self.fns)
        return file_count
|