# Authors: Hui Ren (rhfeiyang.github.io)
import os
import random

import torch
import torch.utils.data as data
from PIL import Image


class ImageSet(data.Dataset):
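    """Dataset over a folder of .png/.jpg images with optional per-image caption .txt files.

    `folder` may also be a list of preloaded PIL images, and `caption` a list of caption lists.
    """
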
    def __init__(self, folder, transform=None, keep_in_mem=True, caption=None):
        self.path = folder
        self.transform = transform
        self.caption_path = None
        self.images = []
        self.captions = []
        self.keep_in_mem = keep_in_mem
        if not isinstance(folder, list):
            # collect all the png/jpg image files in the folder
            self.image_files = [file for file in os.listdir(folder) if file.endswith((".png", ".jpg"))]
            self.image_files.sort()
        else:
            # folder is already a list of preloaded images
            self.images = folder
        if not isinstance(caption, list):
            if caption not in [None, "", "None"]:
                # caption is a folder of .txt files named after the images
                self.caption_path = caption
                self.caption_files = [os.path.join(caption, file.replace(".png", ".txt").replace(".jpg", ".txt"))
                                      for file in self.image_files]
                self.caption_files.sort()
        else:
            # caption is already a list of caption lists
            self.caption_path = True
            self.captions = caption
        if keep_in_mem:
            if len(self.images) == 0:
                for file in self.image_files:
                    img = self.load_image(os.path.join(self.path, file))
                    self.images.append(img)
            if len(self.captions) == 0 and self.caption_path is not None:
                self.captions = [self.load_caption(file) for file in self.caption_files]
        # when not kept in memory, self.images stays an empty list and samples load lazily

    def limit_num(self, n):
        """Truncate the dataset to its first n samples."""
        assert n <= len(self), f"n should be at most the length of the dataset {len(self)}"
        self.image_files = self.image_files[:n]
        if getattr(self, "caption_files", None) is not None:
            self.caption_files = self.caption_files[:n]
        if self.keep_in_mem:
            self.images = self.images[:n]
            self.captions = self.captions[:n]
        print(f"Dataset limited to {n}")

    def __len__(self):
        if self.images:
            return len(self.images)
        return len(self.image_files)

    def load_image(self, path):
        with open(path, "rb") as f:
            img = Image.open(f).convert("RGB")
        return img

    def load_caption(self, path):
        """Read a caption file and return its non-empty lines as a list."""
        with open(path, "r") as f:
            lines = f.readlines()
        return [line.strip() for line in lines if len(line.strip()) > 0]

    def __getitem__(self, index):
        if self.images:
            img = self.images[index]
        else:
            img = self.load_image(os.path.join(self.path, self.image_files[index]))
        if self.caption_path is not None or len(self.captions) != 0:
            if len(self.captions) != 0:
                caption = self.captions[index]
            else:
                caption = self.load_caption(self.caption_files[index])
            ret = {"image": img, "caption": caption, "id": index}
        else:
            ret = {"image": img, "id": index}
        # the transform is applied to the whole sample dict, not the raw image
        if self.transform is not None:
            ret = self.transform(ret)
        return ret

    def subsample(self, n: int = 10):
        """Keep n samples taken at equal intervals; n=None or n=-1 keeps everything."""
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len, f"n should be at most the length of the dataset {ori_len}"
        stride = ori_len // n
        self.image_files = self.image_files[::stride][:n]
        # keep captions aligned with the subsampled images
        if getattr(self, "caption_files", None) is not None:
            self.caption_files = self.caption_files[::stride][:n]
        if self.keep_in_mem:
            self.images = self.images[::stride][:n]
            self.captions = self.captions[::stride][:n]
        print(f"Dataset subsampled from {ori_len} to {len(self)}")
        return self

    def with_transform(self, transform):
        self.transform = transform
        return self

    @staticmethod
    def collate_fn(examples):
        images = [example["image"] for example in examples]
        ids = [example["id"] for example in examples]
        if "caption" in examples[0]:
            # each sample may carry several caption lines; pick one at random
            captions = [random.choice(example["caption"]) for example in examples]
            return {"images": images, "captions": captions, "id": ids}
        return {"images": images, "id": ids}
class ImagePair(ImageSet):
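    """Paired dataset: two folders holding corresponding images under identical file names
    (e.g. photo/sketch pairs)."""
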
    def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
        self.path1 = folder1
        self.path2 = folder2
        self.transform = transform
        # collect all the png/jpg image files (names must match across both folders)
        self.image_files = [file for file in os.listdir(folder1) if file.endswith((".png", ".jpg"))]
        self.image_files.sort()
        self.keep_in_mem = keep_in_mem
        if keep_in_mem:
            self.images = []
            for file in self.image_files:
                img1 = self.load_image(os.path.join(self.path1, file))
                img2 = self.load_image(os.path.join(self.path2, file))
                self.images.append((img1, img2))
        else:
            self.images = None

    def __getitem__(self, index):
        if self.keep_in_mem:
            img1, img2 = self.images[index]
        else:
            img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
            img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))
        # unlike ImageSet, the transform is applied to each image, not to the sample dict
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return {"image1": img1, "image2": img2, "id": index}

    @staticmethod
    def collate_fn(examples):
        images1 = [example["image1"] for example in examples]
        images2 = [example["image2"] for example in examples]
        # images stay as lists here; torch.stack them once the transform yields tensors
        ids = [example["id"] for example in examples]
        return {"image1": images1, "image2": images2, "id": ids}

    def push_to_huggingface(self, hug_folder):
        """Upload the photo/sketch pairs as a private Hugging Face dataset."""
        from datasets import Dataset
        from datasets import Image as HugImage
        photo_path = [os.path.join(self.path1, file) for file in self.image_files]
        sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
        dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
        dataset = dataset.cast_column("photo", HugImage())
        dataset = dataset.cast_column("sketch", HugImage())
        dataset.push_to_hub(hug_folder, private=True)
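
# Usage sketch for push_to_huggingface (hypothetical repo id; requires a prior `huggingface-cli login`):
#   pairs = ImagePair("data/photos", "data/sketches", keep_in_mem=False)
#   pairs.push_to_huggingface("your-username/photo-sketch-pairs")

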
class ImageClass(ImageSet):
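    """Classification dataset: one folder per class, with the folder's index used as the label."""
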
    def __init__(self, folders: list, transform=None, keep_in_mem=True):
        self.paths = folders
        self.transform = transform
        self.keep_in_mem = keep_in_mem
        # collect (path, class_index) pairs for all the png/jpg image files
        self.image_files = []
        for i, folder in enumerate(folders):
            self.image_files += [(os.path.join(folder, file), i)
                                 for file in os.listdir(folder) if file.endswith((".png", ".jpg"))]
        if keep_in_mem:
            self.images = []
            print("Loading images to memory")
            for file, label in self.image_files:
                img = self.load_image(file)
                self.images.append((img, label))
            print("Loading images to memory done")
        else:
            self.images = None

    def __getitem__(self, index):
        if self.keep_in_mem:
            img, label = self.images[index]
        else:
            img_path, label = self.image_files[index]
            img = self.load_image(img_path)
        # like ImagePair, the transform is applied to the raw image
        if self.transform is not None:
            img = self.transform(img)
        return {"image": img, "label": label, "id": index}

    @staticmethod
    def collate_fn(examples):
        images = [example["image"] for example in examples]
        labels = [example["label"] for example in examples]
        ids = [example["id"] for example in examples]
        return {"images": images, "labels": labels, "id": ids}
if __name__ == "__main__":
    # dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_50",
    #                     "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_50", keep_in_mem=False)
    # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-50")
    dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500",
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500",
                        keep_in_mem=True)
    # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-500")
    from torchvision import transforms

    train_transforms = transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    dataset = dataset.with_transform(train_transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4,
                                             collate_fn=ImagePair.collate_fn)
    batch = next(iter(dataloader))