Spaces: Running on Zero
# Authors: Hui Ren (rhfeiyang.github.io)
import os.path
import sys
from typing import Any, Callable, List, Optional, Tuple
import tqdm
from PIL import Image
from torch.utils.data import Dataset
import pickle
from torchvision import transforms
# import torch
# import torchvision
# import re
class SamDataset(Dataset):
    """Image-caption dataset over a SAM-style image dump.

    Images are stored as ``sa_<id>.jpg`` under ``image_folder_path``
    (optionally sharded into subfolders via an id->subfolder dict), captions
    as one ``sa_<id>.txt`` file per image under ``caption_folder_path``.
    Each sample is a dict ``{"id": ..., "image": ..., "text": [...]}``.
    """

    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file: str = "data/sam/clip_filtered_ids.pickle",
                 id_dict_file: str = None,
                 transforms: Optional[Callable] = None,
                 resolution=None,
                 get_img=True,
                 get_cap=True,):
        """
        Args:
            image_folder_path: root folder holding ``sa_<id>.jpg`` files.
            caption_folder_path: folder holding ``sa_<id>.txt`` captions.
            id_file: pickle file containing the sample ids, or an
                already-loaded list of ids.
            id_dict_file: optional pickle mapping id -> subfolder, for
                sharded image layouts.
            transforms: optional callable applied to the whole sample dict.
            resolution: if given, images are read from (and cached into) a
                ``<image_folder_path>_<resolution>`` sibling folder.
            get_img: whether __getitem__ loads the image.
            get_cap: whether __getitem__ loads the caption.
        """
        if id_dict_file is not None:
            with open(id_dict_file, 'rb') as f:
                print(f"Loading id_dict from {id_dict_file}", flush=True)
                # SECURITY NOTE: pickle.load is only safe on trusted files.
                self.id_dict = pickle.load(f)
                print(f"Loaded id_dict from {id_dict_file}", flush=True)
        else:
            self.id_dict = None
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, str):
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        else:
            # Fail fast: the original silently left self.ids unset here.
            raise TypeError(f"id_file must be a list or a path string, got {type(id_file)}")
        self.resolution = resolution
        self.ori_image_folder_path = image_folder_path
        if self.resolution is not None:
            # Per-resolution cache folder next to the original images.
            self.image_folder_path = f"{image_folder_path}_{resolution}"
            if not os.path.exists("/var/jomat/datasets/"):
                # Off the shared-storage host we may need to create the
                # cache folder ourselves.
                os.makedirs(self.image_folder_path, exist_ok=True)
        else:
            self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index: int):
        """Return the sample dict for ``index``; load errors propagate.

        NOTE(review): an earlier revision fell back to sample 0 on load
        errors, but an unconditional re-raise made that path unreachable;
        the dead fallback has been removed and errors now propagate as
        they effectively always did.
        """
        sample_id = self.ids[index]
        ret = {"id": sample_id}
        if self.get_img:
            ret["image"] = self._load_image(sample_id)
        if self.get_cap:
            ret["text"] = [self._load_caption(sample_id)]
        if self.transforms is not None:
            ret = self.transforms(ret)
        return ret

    def define_resolution(self, resolution: int):
        """Switch the image folder to the cache for ``resolution``."""
        self.resolution = resolution
        if os.path.exists("/var/jomat/datasets/"):
            self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
        else:
            self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
        print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")

    def _load_image(self, sample_id: int) -> "Image.Image":
        """Load image ``sample_id`` as RGB, filling the resolution cache on miss."""
        if self.id_dict is not None:
            subfolder = self.id_dict[sample_id]
            image_path = f"{self.image_folder_path}/{subfolder}/sa_{sample_id}.jpg"
        else:
            image_path = f"{self.image_folder_path}/sa_{sample_id}.jpg"
        try:
            with open(image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
        except Exception:
            # Cache miss (or unreadable cached file): fall back to the
            # original-resolution image.  Narrowed from a bare ``except:``
            # so KeyboardInterrupt/SystemExit are no longer swallowed.
            if self.id_dict is not None:
                subfolder = self.id_dict[sample_id]
                ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{sample_id}.jpg"
            else:
                ori_image_path = f"{self.ori_image_folder_path}/sa_{sample_id}.jpg"
            if not os.path.exists(ori_image_path):
                # Explicit raise instead of assert (asserts vanish under -O).
                raise FileNotFoundError(f"Image not found: {ori_image_path}")
            with open(ori_image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
            if self.resolution is not None:
                # Resize the shorter side to ``resolution``, keeping aspect ratio.
                img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
            # Populate the cache for the next access.
            os.makedirs(os.path.dirname(image_path), exist_ok=True)
            img.save(image_path)
        return img

    def _load_caption(self, sample_id: int):
        """Return the cleaned caption for ``sample_id``, or None if absent."""
        caption_path = f"{self.caption_folder_path}/sa_{sample_id}.txt"
        if not os.path.exists(caption_path):
            return None
        with open(caption_path, 'r', encoding="utf-8") as f:
            content = f.read()
        # Drop empty sentences and sentences containing "black and white"
        # (too many false predictions from the captioner).
        sentences = [sentence.strip() for sentence in content.split('.')
                     if sentence.strip() and "black and white" not in sentence]
        text = ". ".join(sentences)
        if len(text) > 0 and text[-1] != '.':
            text += '.'
        return text

    def with_transform(self, transform):
        """Set the sample-dict transform; returns self for chaining."""
        self.transforms = transform
        return self

    def subsample(self, n: int = 10000):
        """Keep ``n`` ids at equal intervals; ``None``/-1 is a no-op."""
        if n is None or n == -1:
            return self
        ori_len = len(self)
        if n > ori_len:
            # Explicit raise instead of assert (asserts vanish under -O).
            raise ValueError(f"Cannot subsample {n} items from {ori_len}")
        # Equal-interval subsample, truncated to exactly n items.
        self.ids = self.ids[::ori_len // n][:n]
        print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
        return self
if __name__ == "__main__":
    # Smoke test: walk every caption in the filtered training split
    # without touching the image files.
    from custom_datasets.sam_caption.mypath import MyPath

    dataset = SamDataset(
        image_folder_path=MyPath.db_root_dir("sam_images"),
        caption_folder_path=MyPath.db_root_dir("sam_captions"),
        id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"),
        id_dict_file=MyPath.db_root_dir("sam_id_dict"),
    )
    dataset.get_img = False
    for sample in tqdm.tqdm(dataset):
        _ = sample["text"]