# Source: phylo-diffusion/ldm/data/i2sb_dataloader.py
# Uploaded by mridulk, commit 17191f4 ("added data")
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
import imageio
from tqdm import tqdm
class pix2pixDataset(Dataset):
    """Paired-image dataset stored in the pix2pix side-by-side layout.

    Every file in ``<data_dir>/<dataset>/<split>`` is read as one image whose
    left half and right half form a pair. Items are returned as a dict
    ``{"image1": left_half, "image2": right_half}`` of float tensors.

    Args:
        dataset: Sub-directory of ``data_dir`` that holds the dataset.
        data_dir: Root directory that contains the datasets.
        split: Sub-directory of the dataset to read file names from.
        normalize: When true, ``normalize_fn`` is applied to the raw pixel
            array before the transforms run.
        transforms: Callable applied to each half. When ``None``, a
            ``ToTensor`` + ``Resize((image_size, image_size))`` pipeline is
            used.
        preload: When true, every pair is loaded and stacked into tensors at
            construction time instead of being read lazily per item.
        image_size: Side length used by the default ``Resize``; ignored when
            ``transforms`` is supplied.
        direction: Stored on the instance. NOTE(review): it is not applied
            when building items; "image1" is always the left half.
    """

    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/", split="train", normalize=True, transforms=None, preload=False, image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exists')
        self.normalize = normalize
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_name_list = sorted(os.listdir(self.img_name_list_path))
        self.preload = preload
        self.direction = direction
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor
                Resize((image_size, image_size), antialias=False),  # Resize to image_size x image_size
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Collect in lists (amortised O(1) append) and stack once at the end.
            x_items, y_items = [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                x_items.append(x)
                y_items.append(y)
            self.x_list = torch.stack(x_items, 0)
            self.y_list = torch.stack(y_items, 0)
            print(f"{split} dataset preloaded!")

    def load_every(self, name):
        """Read file ``name`` from the split directory and return its two halves.

        Returns:
            ``(left, right)`` float tensors obtained by splitting the image at
            half its width and running ``self.transforms`` on each half.
        """
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        half_w = img_array.shape[1] // 2
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        # Slicing only the first two axes works with or without a channel axis.
        x_img, y_img = img_array[:, :half_w], img_array[:, half_w:]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        """Map values linearly so that 0 -> -1 and 255 -> 1."""
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        """Invert ``normalize_fn`` and return integer pixel values in [0, 255].

        Values are rounded (not truncated) before the integer cast so that
        floating-point error cannot turn e.g. 0.999996 into 0, and clamped
        before the cast so out-of-range values saturate instead of wrapping.
        """
        return ((x / 2 + 0.5) * 255).round().clamp(0, 255).int()

    def __getitem__(self, index):
        """Return the pair at ``index`` as ``{"image1": ..., "image2": ...}``."""
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            x_img, y_img = self.load_every(self.image_name_list[index])
        # NOTE: self.direction is intentionally not consulted here; the halves
        # are always returned in file order (left, right).
        return {
            "image1": x_img,
            "image2": y_img,
        }

    def __len__(self):
        """Number of files found in the split directory."""
        return len(self.image_name_list)
class FishDataset(Dataset):
    """Paired-image dataset with a class id encoded in each file name.

    Every file in ``<data_dir>/<split>`` is read as one image whose left half
    and right half form a pair. The class id is the integer that follows the
    last underscore of the file name (extension removed), e.g.
    ``"abc_12.png"`` -> ``12``. Items are returned as
    ``(left_half, right_half, class_id)``.

    Args:
        data_dir: Root directory of the dataset.
        split: Sub-directory of ``data_dir`` to read file names from.
        normalize: When true, ``normalize_fn`` is applied to the raw pixel
            array before the transforms run.
        transforms: Callable applied to each half. When ``None``, only
            ``ToTensor`` is applied (no resizing).
        preload: When true, every pair and class id is loaded and stacked
            into tensors at construction time.
        image_size: Accepted for interface compatibility. NOTE(review): it is
            not used by the default transform pipeline.
    """

    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train", normalize=True, transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)
        if not os.path.exists(self.datadir):
            print(f'Dataset directory {self.datadir} does not exists')
        self.normalize = normalize
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_name_list = sorted(os.listdir(self.img_name_list_path))
        self.preload = preload
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # Convert to torch tensor
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Collect in lists (amortised O(1) append) and stack once at the end.
            x_items, y_items, class_ids = [], [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                x_items.append(x)
                y_items.append(y)
                class_ids.append(self._class_id_from_name(name))
            self.x_list = torch.stack(x_items, 0)
            self.y_list = torch.stack(y_items, 0)
            self.class_id = torch.tensor(class_ids)
            print(f"{split} dataset preloaded!")

    @staticmethod
    def _class_id_from_name(name):
        """Parse the integer after the last underscore of ``name``.

        The extension is removed with ``os.path.splitext`` so its length does
        not matter (``.png`` and ``.jpeg`` both work).
        """
        stem = os.path.splitext(name)[0]
        return int(stem.split("_")[-1])

    def load_every(self, name):
        """Read file ``name`` from the split directory and return its two halves.

        Returns:
            ``(left, right)`` float tensors obtained by splitting the image at
            half its width and running ``self.transforms`` on each half.
        """
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        half_w = img_array.shape[1] // 2
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        # Slicing only the first two axes works with or without a channel axis.
        x_img, y_img = img_array[:, :half_w], img_array[:, half_w:]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        """Map values linearly so that 0 -> -1 and 255 -> 1."""
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        """Invert ``normalize_fn`` and return integer pixel values in [0, 255].

        Values are rounded (not truncated) before the integer cast so that
        floating-point error cannot turn e.g. 0.999996 into 0, and clamped
        before the cast so out-of-range values saturate instead of wrapping.
        """
        return ((x / 2 + 0.5) * 255).round().clamp(0, 255).int()

    def __getitem__(self, index):
        """Return ``(left_half, right_half, class_id)`` for ``index``."""
        if self.preload:
            return self.x_list[index], self.y_list[index], self.class_id[index]
        name = self.image_name_list[index]
        class_id = torch.tensor(self._class_id_from_name(name))
        x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        """Number of files found in the split directory."""
        return len(self.image_name_list)