import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
import imageio
from tqdm import tqdm
class pix2pixDataset(Dataset):
    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/",
                 split="train", normalize=True, transforms=None, preload=False,
                 image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            print(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        self.direction = direction
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # convert HWC numpy array to CHW torch tensor
                Resize((image_size, image_size), antialias=False),  # resize to image_size x image_size
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Load every pair into memory once so __getitem__ is a cheap index lookup.
            x_list, y_list = [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                x_list.append(x)
                y_list.append(y)
            self.x_list = torch.stack(x_list, 0)
            self.y_list = torch.stack(y_list, 0)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        # Each file stores a pair side by side: the left half is one domain, the right half the other.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_W = img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)  # to tensor, then resize
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Map uint8 pixel values in [0, 255] to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values; clamp before the cast so
        # out-of-range values saturate instead of wrapping.
        return ((x / 2 + 0.5) * 255).clamp(0, 255).int()

    def __getitem__(self, index):
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            name = self.image_name_list[index]
            x_img, y_img = self.load_every(name)
        # "BtoA" keeps the on-disk order; "AtoB" swaps the pair.
        if self.direction == "AtoB":
            x_img, y_img = y_img, x_img
        batch = {
            "image1": x_img,
            "image2": y_img,
        }
        return batch

    def __len__(self):
        return len(self.image_name_list)
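
# A minimal usage sketch for pix2pixDataset, assuming the default "maps" layout
# above actually exists on disk; the DataLoader wrapping is ordinary PyTorch,
# not something this file prescribes.
def _demo_pix2pix_loader():
    from torch.utils.data import DataLoader

    train_set = pix2pixDataset(dataset="maps", split="train", preload=False)
    loader = DataLoader(train_set, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    # Each entry is a (B, C, H, W) float tensor, in [-1, 1] when normalize=True.
    print(batch["image1"].shape, batch["image2"].shape)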

class FishDataset(Dataset):
    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train",
                 normalize=True, transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)
        if not os.path.exists(self.datadir):
            print(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        if transforms is None:
            # No resizing here: images are used at their native resolution,
            # so image_size is unused with the default transforms.
            self.transforms = Compose([
                ToTensor(),  # convert HWC numpy array to CHW torch tensor
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Cache every pair and its class id in memory up front.
            x_list, y_list, class_ids = [], [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                cls_id = int(name.split("_")[-1][:-4])  # strip the 4-char extension, e.g. ".png"
                x_list.append(x)
                y_list.append(y)
                class_ids.append(cls_id)
            self.x_list = torch.stack(x_list, 0)
            self.y_list = torch.stack(y_list, 0)
            self.class_id = torch.tensor(class_ids)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        # Each file stores a pair side by side: the left half is one image, the right half the other.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_W = img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Map uint8 pixel values in [0, 255] to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values; clamp before the cast so
        # out-of-range values saturate instead of wrapping.
        return ((x / 2 + 0.5) * 255).clamp(0, 255).int()

    def __getitem__(self, index):
        # Returns (x0, x1, y), where y is the class label for class-conditional generation.
        if self.preload:
            x_img, y_img, class_id = self.x_list[index], self.y_list[index], self.class_id[index]
        else:
            name = self.image_name_list[index]
            class_id = torch.tensor(int(name.split("_")[-1][:-4]))  # class id parsed from the filename
            x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        return len(self.image_name_list)
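
# A minimal usage sketch for FishDataset, assuming the default data directory
# above exists, filenames end in "_<class_id>" plus a 4-character extension so
# the label parsing succeeds, and all images share one resolution so default
# DataLoader collation works.
def _demo_fish_loader():
    from torch.utils.data import DataLoader

    train_set = FishDataset(split="train", preload=False)
    loader = DataLoader(train_set, batch_size=4, shuffle=True)
    x0, x1, labels = next(iter(loader))
    # x0 and x1 are (B, C, H, W) float tensors; labels is a (B,) int64 tensor of class ids.
    print(x0.shape, x1.shape, labels.shape)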