# NOTE(review): removed non-Python web-scrape artifacts that preceded this file
# (page header "Spaces: Running on Zero", file size, commit hash, and a
# line-number gutter). They would otherwise make the module unparseable.
# Authors: Hui Ren (rhfeiyang.github.io)
import os.path
import sys
from typing import Any, Callable, List, Optional, Tuple
import tqdm
from PIL import Image
from torch.utils.data import Dataset
import pickle
from torchvision import transforms
# import torch
# import torchvision
# import re
class SamDataset(Dataset):
    """Segment Anything (SAM) image/caption dataset.

    Images are stored as ``sa_{id}.jpg`` — either directly under the image
    folder or, when ``id_dict_file`` maps ids to subfolders, under
    ``{folder}/{subfolder}/sa_{id}.jpg``.  Captions are plain-text files
    ``sa_{id}.txt``.  When ``resolution`` is set, resized copies are cached
    on disk under ``{image_folder_path}_{resolution}``.
    """

    def __init__(self, image_folder_path: str, caption_folder_path: str,
                 id_file="data/sam/clip_filtered_ids.pickle",
                 id_dict_file: Optional[str] = None,
                 transforms: Optional[Callable] = None,
                 resolution=None,
                 get_img=True,
                 get_cap=True,):
        """
        Args:
            image_folder_path: root folder holding the original images.
            caption_folder_path: folder holding ``sa_{id}.txt`` caption files.
            id_file: pickled list of sample ids, or an already-loaded list.
            id_dict_file: optional pickle mapping id -> subfolder name.
            transforms: optional callable applied to the whole sample dict.
            resolution: if set, images are resized (short side) and cached.
            get_img: whether __getitem__ loads the image.
            get_cap: whether __getitem__ loads the caption.
        """
        if id_dict_file is not None:
            with open(id_dict_file, 'rb') as f:
                print(f"Loading id_dict from {id_dict_file}", flush=True)
                self.id_dict = pickle.load(f)
                print(f"Loaded id_dict from {id_dict_file}", flush=True)
        else:
            self.id_dict = None
        # id_file may be a ready-made list of ids or a path to a pickle.
        if isinstance(id_file, list):
            self.ids = id_file
        elif isinstance(id_file, str):
            with open(id_file, 'rb') as f:
                print(f"Loading ids from {id_file}", flush=True)
                self.ids = pickle.load(f)
                print(f"Loaded ids from {id_file}", flush=True)
        self.resolution = resolution
        self.ori_image_folder_path = image_folder_path
        if self.resolution is not None:
            # Resized copies live beside the originals, suffixed by resolution.
            self.image_folder_path = f"{image_folder_path}_{resolution}"
            # Only pre-create the cache dir outside the cluster-specific mount
            # (original code skipped makedirs when /var/jomat was present).
            if not os.path.exists("/var/jomat/datasets/"):
                os.makedirs(self.image_folder_path, exist_ok=True)
        else:
            self.image_folder_path = image_folder_path
        self.caption_folder_path = caption_folder_path
        self.transforms = transforms
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index: int):
        """Return a sample dict with keys ``id`` and, per flags, ``image``/``text``.

        A sample whose files are corrupt/unreadable is replaced by the sample
        at index 0 (best effort, so one bad file does not kill a long run).
        """
        sample_id = self.ids[index]
        ret = {"id": sample_id}
        try:
            if self.get_img:
                ret["image"] = self._load_image(sample_id)
            if self.get_cap:
                ret["text"] = [self._load_caption(sample_id)]
            if self.transforms is not None:
                ret = self.transforms(ret)
            return ret
        except Exception as e:
            # Fix: the original had `raise e` here, which made the fallback
            # below unreachable dead code.  Guard against infinite recursion
            # when index 0 itself is broken.
            if index == 0:
                raise
            print(f"Error loading image and caption for id {sample_id}, error: {e}, redirecting to index 0", flush=True)
            return self[0]

    def define_resolution(self, resolution: int):
        """Switch to (and cache under) a new image resolution folder."""
        self.resolution = resolution
        if os.path.exists("/var/jomat/datasets/"):
            # Cluster-specific pre-resized dataset location.
            self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
        else:
            self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
        print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")

    def _load_image(self, sample_id: int) -> "Image.Image":
        """Load image for ``sample_id``, resizing+caching from the original on miss."""
        if self.id_dict is not None:
            subfolder = self.id_dict[sample_id]
            image_path = f"{self.image_folder_path}/{subfolder}/sa_{sample_id}.jpg"
        else:
            image_path = f"{self.image_folder_path}/sa_{sample_id}.jpg"
        try:
            with open(image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
        except Exception:  # cache miss or unreadable cached file
            # Fall back to the original image and (re)build the cached copy.
            if self.id_dict is not None:
                subfolder = self.id_dict[sample_id]
                ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{sample_id}.jpg"
            else:
                ori_image_path = f"{self.ori_image_folder_path}/sa_{sample_id}.jpg"
            if not os.path.exists(ori_image_path):
                # Raise explicitly instead of `assert` (asserts vanish under -O).
                raise FileNotFoundError(f"Original image not found: {ori_image_path}")
            with open(ori_image_path, 'rb') as f:
                img = Image.open(f).convert("RGB")
            # Resize keeping aspect ratio (short side == resolution).
            if self.resolution is not None:
                img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
            # Write the resized copy back so the next epoch hits the cache.
            os.makedirs(os.path.dirname(image_path), exist_ok=True)
            img.save(image_path)
        return img

    def _load_caption(self, sample_id: int) -> Optional[str]:
        """Return the cleaned caption for ``sample_id``, or None if unavailable.

        Sentences containing "black and white" are dropped (too many false
        predictions from the captioner); remaining sentences are re-joined
        with ". " and terminated with a period.
        """
        caption_path = f"{self.caption_folder_path}/sa_{sample_id}.txt"
        if not os.path.exists(caption_path):
            return None
        try:
            with open(caption_path, 'r', encoding="utf-8") as f:
                content = f.read()
        except OSError as e:
            # Fix: the original had `raise e` here, which made this intended
            # best-effort fallback unreachable dead code.
            print(f"Error reading caption file {caption_path}, error: {e}")
            return None
        sentences = content.split('.')
        # Drop empty sentences and "black and white" sentences.
        sentences = [sentence.strip() for sentence in sentences if sentence.strip() and "black and white" not in sentence]
        sentences = ". ".join(sentences)
        if len(sentences) > 0 and sentences[-1] != '.':
            sentences += '.'
        return sentences

    def with_transform(self, transform):
        """Set the sample transform and return self (fluent style)."""
        self.transforms = transform
        return self

    def subsample(self, n: int = 10000):
        """Equal-interval subsample down to ``n`` ids; -1/None is a no-op."""
        if n is None or n == -1:
            return self
        ori_len = len(self)
        if n > ori_len:
            raise ValueError(f"Cannot subsample {ori_len} samples down to {n}")
        self.ids = self.ids[::ori_len // n][:n]
        print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
        return self
if __name__ == "__main__":
    # Smoke test: iterate captions only (images disabled) over the filtered
    # training split, reporting progress with tqdm.
    from custom_datasets.sam_caption.mypath import MyPath

    dataset = SamDataset(
        image_folder_path=MyPath.db_root_dir("sam_images"),
        caption_folder_path=MyPath.db_root_dir("sam_captions"),
        id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"),
        id_dict_file=MyPath.db_root_dir("sam_id_dict"),
    )
    dataset.get_img = False
    for sample in tqdm.tqdm(dataset):
        caption = sample['text']