Spaces:
Runtime error
Runtime error
import numpy as np | |
from torch.utils.data import DataLoader | |
from torch.utils.data.sampler import SubsetRandomSampler | |
import torchvision.transforms as transforms | |
from data_aug.gaussian_blur import GaussianBlur | |
from torchvision import datasets | |
import pandas as pd | |
from PIL import Image | |
from skimage import io, img_as_ubyte | |
np.random.seed(0) | |
class Dataset():
    """Map-style dataset of images whose paths are listed one per line in a file.

    The object exposes ``__len__`` and ``__getitem__`` so it can be handed to
    ``torch.utils.data.DataLoader``.
    """

    def __init__(self, csv_file, transform=None):
        """Read the list of image paths.

        Args:
            csv_file: Path to a text file containing one image path per line.
            transform: Optional callable applied to the image tensor in
                ``__getitem__``. When omitted, the raw tensor is returned.
        """
        with open(csv_file) as f:
            # Strip surrounding whitespace and skip blank lines so that a
            # trailing newline does not turn into an empty, unopenable path.
            self.files_list = [
                path for path in (line.strip() for line in f) if path
            ]
        self.transform = transform

    def __len__(self):
        """Return the number of image paths that were read."""
        return len(self.files_list)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        Args:
            idx: Index into the list of image paths.

        Returns:
            The result of ``transform`` applied to the image tensor, or the
            image tensor itself when no transform was supplied.
        """
        temp_path = self.files_list[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the tensor.
        with Image.open(temp_path) as pil_img:
            img = transforms.functional.to_tensor(pil_img)
        if not self.transform:
            # Previously this path fell through to an unbound ``sample``.
            return img
        return self.transform(img)
class ToPIL(object):
    """Callable transform that converts its input into a PIL image."""

    def __call__(self, sample):
        """Return ``sample`` converted by ``transforms.functional.to_pil_image``."""
        return transforms.functional.to_pil_image(sample)
class DataSetWrapper(object):
    """Build SimCLR training and validation data loaders over a list of images."""

    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        """Store the loader and augmentation settings.

        Args:
            batch_size: Batch size passed to both data loaders.
            num_workers: Worker count passed to both data loaders.
            valid_size: Fraction of the dataset reserved for validation.
            input_shape: Shape whose first two entries are the resize target
                and whose first entry sizes the blur kernel. Either a
                sequence or its string form such as ``"(224, 224, 3)"``.
            s: Multiplier applied to the color-jitter strengths.
        """
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        if isinstance(input_shape, str):
            # NOTE(security): eval() executes arbitrary code, so the string
            # must come from a trusted configuration. Callers can avoid it
            # entirely by passing an already-parsed tuple or list.
            input_shape = eval(input_shape)
        self.input_shape = input_shape

    def get_data_loaders(self, csv_file='all_patches.csv'):
        """Create the dataset and return ``(train_loader, valid_loader)``.

        Args:
            csv_file: Path to the file listing one image path per line.
                Defaults to ``'all_patches.csv'`` in the working directory.
        """
        data_augment = self._get_simclr_pipeline_transform()
        train_dataset = Dataset(csv_file=csv_file,
                                transform=SimCLRDataTransform(data_augment))
        return self.get_train_validation_data_loaders(train_dataset)

    def _get_simclr_pipeline_transform(self):
        """Return the composed augmentation pipeline applied to each image."""
        # get a set of data augmentation transformations as described in the SimCLR paper.
        # Positional ColorJitter arguments: brightness, contrast, saturation, hue.
        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
        data_transforms = transforms.Compose([
            ToPIL(),
            # RandomResizedCrop is disabled here; a plain Resize is used instead.
            # transforms.RandomResizedCrop(size=self.input_shape[0]),
            transforms.Resize((self.input_shape[0], self.input_shape[1])),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            # Blur kernel size is 6% of the first input_shape entry.
            GaussianBlur(kernel_size=int(0.06 * self.input_shape[0])),
            transforms.ToTensor(),
        ])
        return data_transforms

    def get_train_validation_data_loaders(self, train_dataset):
        """Split ``train_dataset`` by index and wrap both parts in data loaders.

        Args:
            train_dataset: Object supporting ``len()`` and integer indexing.

        Returns:
            ``(train_loader, valid_loader)``. Both loaders read from the same
            dataset object and differ only in the indices they sample; both
            drop the last incomplete batch.
        """
        # obtain training indices that will be used for validation
        num_train = len(train_dataset)
        indices = list(range(num_train))
        # Shuffles in place using NumPy's global random state.
        np.random.shuffle(indices)
        split = int(np.floor(self.valid_size * num_train))
        train_idx, valid_idx = indices[split:], indices[:split]

        # define samplers for obtaining training and validation batches
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=train_sampler,
                                  num_workers=self.num_workers, drop_last=True, shuffle=False)
        valid_loader = DataLoader(train_dataset, batch_size=self.batch_size, sampler=valid_sampler,
                                  num_workers=self.num_workers, drop_last=True)
        return train_loader, valid_loader
class SimCLRDataTransform(object):
    """Wrap a transform so that each sample yields a pair of transformed views."""

    def __init__(self, transform):
        # Callable that is invoked twice per sample in __call__.
        self.transform = transform

    def __call__(self, sample):
        """Call the wrapped transform on ``sample`` twice and return both results.

        Returns:
            A 2-tuple holding the result of the first call followed by the
            result of the second call.
        """
        return tuple(self.transform(sample) for _ in range(2))