# ELITE / datasets.py
# csyxwei's picture
# readme and new requirements
from packaging import version
from PIL import Image
from torchvision import transforms
import os
import PIL
from torch.utils.data import Dataset
import torchvision
import numpy as np
import torch
import random
import albumentations as A
import copy
import cv2
import pandas as pd
# Prompt templates. Each "{}" slot is filled through str.format with a placeholder
# token or an object category (see OpenImagesDataset.obtain_text).
# NOTE(review): the closing "]" of this list is not present in this copy of the file.
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
# NOTE(review): what follows reads like two alternative dict literals mapping a
# resampling name to a PIL filter: `PIL.Image.Resampling.*` when Pillow >= 9.1.0,
# the legacy `PIL.Image.*` constants otherwise. The `PIL_INTERPOLATION = {` openers
# (the name is used by the dataset classes below), the closing braces and the
# `else:` appear to be missing — confirm against the upstream file.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
"linear": PIL.Image.Resampling.BILINEAR,
"bilinear": PIL.Image.Resampling.BILINEAR,
"bicubic": PIL.Image.Resampling.BICUBIC,
"lanczos": PIL.Image.Resampling.LANCZOS,
"nearest": PIL.Image.Resampling.NEAREST,
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
def is_image(file):
    """Return True when *file* has a .jpg, .jpeg or .png extension (case-insensitive).

    The check is anchored to the end of the name, so a file that merely
    contains "jpg" or "png" somewhere in its name (e.g. "jpg_notes.txt")
    is not treated as an image.
    """
    return file.lower().endswith(('.jpg', '.jpeg', '.png'))
# Dataset over the image files of one folder; __getitem__ pairs every image with a
# "<name>_bg.png" mask file and a prompt built from `template`.
class CustomDatasetWithBG(Dataset):
# NOTE(review): the parameter list is truncated in this copy. The body reads
# data_root, tokenizer, size, placeholder_token and template; `self` and the closing
# "):" are missing as well.
def __init__(
template="a photo of a {}",
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.placeholder_token = placeholder_token
self.image_paths = []
# Keep files accepted by is_image whose name does not contain "bg"
# (the "*_bg.png" masks are opened separately in __getitem__).
self.image_paths += [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) if is_image(file_path) and not 'bg' in file_path]
self.image_paths = sorted(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
# NOTE(review): this dict literal is never closed here. __getitem__ passes
# self.interpolation straight to Image.resize(resample=...), so the original
# presumably ends with `}[interpolation]` to pick a single filter — confirm.
self.interpolation = {
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
self.template = template
def __len__(self):
return self._length
def get_tensor_clip(self, normalize=True, toTensor=True):
    """Build a torchvision transform pipeline.

    Args:
        normalize: append a ``Normalize`` step with the fixed per-channel
            mean/std below.
        toTensor: prepend a ``ToTensor`` step.

    Returns:
        A ``torchvision.transforms.Compose`` holding the selected steps,
        in the order ToTensor -> Normalize.
    """
    # NOTE(review): these look like the CLIP image-preprocessing statistics
    # (the method name suggests so) — confirm against the encoder in use.
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]
    if normalize:
        transform_list += [torchvision.transforms.Normalize(mean, std)]
    return torchvision.transforms.Compose(transform_list)
def process(self, image):
    """Resize *image* to ``self.size`` x ``self.size`` and convert it to a tensor.

    Args:
        image: array-like image with the channel axis last (the final
            ``permute(2, 0, 1)`` requires three dimensions).

    Returns:
        A float32 tensor with the channel axis first, scaled as
        ``x / 127.5 - 1.0`` (maps 0..255 onto -1..1).
    """
    # np.asarray is a no-op for ndarrays and lets PIL images through as well,
    # matching OpenImagesDataset.process.
    img = np.asarray(image)
    img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
    img = np.array(img).astype(np.float32)
    img = img / 127.5 - 1.0
    return torch.from_numpy(img).permute(2, 0, 1)
# Build one example: prompt text, placeholder index, the full image, the
# mask-multiplied object image and the mask itself.
def __getitem__(self, i):
example = {}
placeholder_string = self.placeholder_token
text = self.template.format(placeholder_string)
example["text"] = text
# Position of the placeholder among the space-separated words, plus one
# (presumably to account for a start-of-text token — confirm). Stays 0 when
# the placeholder is not a standalone word.
placeholder_index = 0
words = text.strip().split(' ')
for idx, word in enumerate(words):
if word == placeholder_string:
placeholder_index = idx + 1
example["index"] = torch.tensor(placeholder_index)
# NOTE(review): the tokenizer call is cut off — its arguments and the closing
# ")" are missing in this copy.
example["input_ids"] = self.tokenizer(
# The index wraps around modulo the number of images.
image = Image.open(self.image_paths[i % self.num_images])
# Mask path: rewrite the extension to ".png", drop the last four characters,
# then append "_bg.png".
mask_path = self.image_paths[i % self.num_images].replace('.jpeg', '.png').replace('.jpg', '.png').replace('.JPEG', '.png')[:-4] + '_bg.png'
mask = np.array(Image.open(mask_path))
# Binarise: every non-zero mask value becomes 1.
mask = np.where(mask > 0, 1, 0)
if not image.mode == "RGB":
image = image.convert("RGB")
image_np = np.array(image)
# Element-wise product keeps only the pixels where the mask is 1.
# assumes mask broadcasts against the H x W x 3 image — TODO confirm mask layout
object_tensor = image_np * mask
example["pixel_values"] = self.process(image_np)
ref_object_tensor = Image.fromarray(object_tensor.astype('uint8')).resize((224, 224), resample=self.interpolation)
ref_image_tenser = Image.fromarray(image_np.astype('uint8')).resize((224, 224), resample=self.interpolation)
example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tenser)
# The mask is converted without normalisation and resized to 128 x 128
# with nearest-neighbour interpolation.
ref_seg_tensor = Image.fromarray(mask.astype('uint8') * 255)
ref_seg_tensor = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
example["pixel_values_seg"] = torch.nn.functional.interpolate(ref_seg_tensor.unsqueeze(0), size=(128, 128), mode='nearest').squeeze(0)
return example
# Dataset built from a bounding-box annotation CSV; __getitem__ crops each box
# out of its image.
class OpenImagesDataset(Dataset):
# NOTE(review): the whole parameter list is missing in this copy. The body reads
# data_root, tokenizer, size, placeholder_token and set.
def __init__(
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.placeholder_token = placeholder_token
self.set_type = set
# NOTE(review): the transform list is not closed here; further albumentations
# steps and the closing "])" may be missing.
self.random_trans = A.Compose([
A.Resize(height=224, width=224),
self.bbox_path_list = []
if set == "train":
bboxs_path = os.path.join(data_root, 'annotations', f'oidv6-train-annotations-bbox.csv')
elif set == "validation":
bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-bbox.csv')
# NOTE(review): an `else:` presumably precedes the next line; as written it
# would overwrite the path chosen above.
bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-bbox.csv')
df_val_bbox = pd.read_csv(bboxs_path)
bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)
bbox_full = []
for label_name in df_val_bbox['LabelName'].unique():
bboxs = bbox_groups.get_group(label_name)[
['XMin', 'XMax', 'YMin', 'YMax', 'LabelName', 'ImageID',
'IsOccluded', 'IsTruncated', 'IsGroupOf', 'IsInside']].values.tolist()
bboxs_new = []
for bbox in bboxs:
# Keep boxes whose area (XMax - XMin) * (YMax - YMin) lies in [0.02, 0.8];
# presumably the coordinates are normalised to 0..1 — confirm. Only the
# first six columns (through ImageID) are kept.
if not ((bbox[1] - bbox[0]) * (bbox[3] - bbox[2]) > 0.8 or (bbox[1] - bbox[0]) * (
bbox[3] - bbox[2]) < 0.02):
bboxs_new.append([bbox[0], bbox[1], bbox[2], bbox[3], bbox[4], bbox[5]])
# NOTE(review): nothing visible adds bboxs_new to bbox_full, so bbox_full stays
# empty as written; a line merging the two is presumably missing.
self.bboxs_full = bbox_full
self.num_images = len(bbox_full)
print('{}: total {} images ...'.format(set, self.num_images))
self._length = self.num_images
# NOTE(review): dict literal is never closed here; presumably it ends with
# `}[interpolation]` — confirm.
self.interpolation = {
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
self.templates = imagenet_templates_small
def __len__(self):
return self._length
def get_tensor_clip(self, normalize=True, toTensor=True):
    """Build a torchvision transform pipeline.

    Args:
        normalize: append a ``Normalize`` step with the fixed per-channel
            mean/std below.
        toTensor: prepend a ``ToTensor`` step.

    Returns:
        A ``torchvision.transforms.Compose`` holding the selected steps,
        in the order ToTensor -> Normalize.
    """
    # NOTE(review): these look like the CLIP image-preprocessing statistics
    # (the method name suggests so) — confirm against the encoder in use.
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]
    if normalize:
        transform_list += [torchvision.transforms.Normalize(mean, std)]
    return torchvision.transforms.Compose(transform_list)
def process(self, image):
    """Resize *image* to ``self.size`` x ``self.size`` and convert it to a tensor.

    Args:
        image: array-like image with the channel axis last (the final
            ``permute(2, 0, 1)`` requires three dimensions).

    Returns:
        A float32 tensor with the channel axis first, scaled as
        ``x / 127.5 - 1.0`` (maps 0..255 onto -1..1).
    """
    img = np.array(image)
    img = cv2.resize(img, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
    img = np.array(img).astype(np.float32)
    img = img / 127.5 - 1.0
    return torch.from_numpy(img).permute(2, 0, 1)
# Pick a random prompt template, fill in the placeholder (or the given object
# category) and return (input_ids, placeholder index tensor, prompt text).
def obtain_text(self, add_caption, object_category=None):
if object_category is None:
placeholder_string = self.placeholder_token
# NOTE(review): an `else:` presumably precedes the next line; as written it
# always overwrites the placeholder with object_category.
placeholder_string = object_category
text = random.choice(self.templates).format(placeholder_string)
# Replaces the first character of the template with add_caption.
# NOTE(review): this works for templates starting with "a ", but the list also
# contains "the photo of a {}", which becomes add_caption + "he photo of a ...".
text = add_caption + text[1:]
# Position of the placeholder among the space-separated words, plus one
# (presumably to account for a start-of-text token — confirm); 0 when absent.
placeholder_index = 0
words = text.strip().split(' ')
for idx, word in enumerate(words):
if word == placeholder_string:
placeholder_index = idx + 1
index = torch.tensor(placeholder_index)
# NOTE(review): the tokenizer call is cut off — its arguments and the closing
# ")" are missing in this copy.
input_ids = self.tokenizer(
return input_ids, index, text
# Build one example: prompt tokens plus the padded bounding-box crop, once resized
# by process() and once passed through random_trans and get_tensor_clip().
def __getitem__(self, i):
example = {}
input_ids, index, text = self.obtain_text('a')
example["input_ids"] = input_ids
example["index"] = index
example["text"] = text
# NOTE(review): the `except` further down has no matching `try:` in this copy;
# presumably it opened somewhere around here.
bbox_sample = self.bboxs_full[i % self.num_images]
bbox_sample = copy.copy(bbox_sample)
# The last entry of a stored box is the ImageID column (see __init__).
file_name = bbox_sample[-1] + '.jpg'
img_path = os.path.join(self.data_root, 'images', self.set_type, file_name)
img_p = Image.open(img_path).convert("RGB")
img_p_np = np.array(img_p)
# Scale the X entries by the image width and the Y entries by the height.
bbox_sample[0] *= int(img_p_np.shape[1])
bbox_sample[1] *= int(img_p_np.shape[1])
bbox_sample[2] *= int(img_p_np.shape[0])
bbox_sample[3] *= int(img_p_np.shape[0])
# Grow the box by up to 10 px per side, limited by the distance to the border.
bbox_pad = copy.copy(bbox_sample)
bbox_pad[0] = int(bbox_sample[0] - min(10, bbox_sample[0] - 0))
bbox_pad[1] = int(bbox_sample[1] + min(10, img_p.size[0] - bbox_sample[1]))
bbox_pad[2] = int(bbox_sample[2] - min(10, bbox_sample[2] - 0))
bbox_pad[3] = int(bbox_sample[3] + min(10, img_p.size[1] - bbox_sample[3]))
# Rows are indexed with the Y range, columns with the X range.
image_tensor = img_p_np[bbox_pad[2]:bbox_pad[3], bbox_pad[0]:bbox_pad[1], :]
example["pixel_values"] = self.process(image_tensor)
ref_image_tensor = self.random_trans(image=image_tensor)
ref_image_tensor = Image.fromarray(ref_image_tensor["image"])
example["pixel_values_clip"] = self.get_tensor_clip()(ref_image_tensor)
# On any failure: return all-zero tensors and append the error text to
# error.txt in the working directory.
# NOTE(review): the fallback shape is fixed at 512 x 512 regardless of self.size.
except Exception as e:
example["pixel_values"] = torch.zeros((3, 512, 512))
example["pixel_values_clip"] = torch.zeros((3, 224, 224))
with open('error.txt', 'a+') as f:
f.write(str(e) + '\n')
return example
# Variant of OpenImagesDataset driven by the object-segmentation annotation CSV and
# a pre-computed box dictionary loaded from "segs/<set>_bbox_dict.npy".
class OpenImagesDatasetWithMask(OpenImagesDataset):
# NOTE(review): the parameter list is truncated in this copy. The body reads
# data_root, tokenizer, size, placeholder_token and set.
def __init__(self,
# super().__init__(data_root, tokenizer, size, interpolation, set, placeholder_token)
self.data_root = data_root
self.tokenizer = tokenizer
self.size = size
self.placeholder_token = placeholder_token
self.set = set
class_anno_path = os.path.join(data_root, 'annotations', f'oidv6-class-descriptions.csv')
anno_files = pd.read_csv(class_anno_path)
class_groups = anno_files.groupby(anno_files.LabelName)
if set == "train":
bboxs_path = os.path.join(data_root, 'annotations', f'train-annotations-object-segmentation.csv')
dict_path = os.path.join(data_root, 'segs', f'train_bbox_dict.npy')
elif set == "validation":
bboxs_path = os.path.join(data_root, 'annotations', f'validation-annotations-object-segmentation.csv')
dict_path = os.path.join(data_root, 'segs', f'validation_bbox_dict.npy')
# NOTE(review): an `else:` presumably precedes the next two lines; as written
# they would overwrite the paths chosen above.
bboxs_path = os.path.join(data_root, 'annotations', f'test-annotations-object-segmentation.csv')
dict_path = os.path.join(data_root, 'segs', f'test_bbox_dict.npy')
# NOTE(review): allow_pickle=True unpickles the file's contents; only load
# dictionaries from a trusted source.
bbox_dict = np.load(dict_path, allow_pickle=True).item()
df_val_bbox = pd.read_csv(bboxs_path)
bbox_groups = df_val_bbox.groupby(df_val_bbox.LabelName)
bboxes_full = []
for label_name in df_val_bbox['LabelName'].unique():
bboxs = bbox_groups.get_group(label_name)[
['BoxXMin', 'BoxXMax', 'BoxYMin', 'BoxYMax', 'LabelName', 'MaskPath']].values.tolist()
bboxes_new = []
for box in bboxs:
# NOTE(review): the bodies of the next two `if` statements (presumably
# `continue`) are missing. They read as "skip masks absent from bbox_dict" and
# "skip boxes with either extent under 100" — confirm.
if not box[-1] in bbox_dict:
bbox_data = bbox_dict[box[-1]]
if (bbox_data[2] - bbox_data[1]) < 100 or (bbox_data[4] - bbox_data[3]) < 100:
# Keep boxes whose two extents differ by at most a factor of two.
# NOTE(review): the meaning of bbox_data[1..4] is not visible here.
if not ((bbox_data[2] - bbox_data[1]) / (bbox_data[4] - bbox_data[3]) < 0.5 or (
bbox_data[4] - bbox_data[3]) / ( bbox_data[2] - bbox_data[1]) < 0.5):
class_name = class_groups.get_group(box[4])[['DisplayName']].values.tolist()[0][0]
# Stored entry: [MaskPath, four box values from bbox_dict, DisplayName].
bboxes_new.append([box[-1], bbox_data[1], bbox_data[2], bbox_data[3], bbox_data[4], class_name])
# NOTE(review): nothing visible adds bboxes_new to bboxes_full, so it stays empty
# as written; a line merging the two is presumably missing.
self.bboxes_full = bboxes_full
self.num_images = len(bboxes_full)
print('{}: total {} images ...'.format(set, self.num_images))
self._length = self.num_images
# NOTE(review): dict literal is never closed here; presumably it ends with
# `}[interpolation]` — confirm.
self.interpolation = {
"linear": PIL_INTERPOLATION["linear"],
"bilinear": PIL_INTERPOLATION["bilinear"],
"bicubic": PIL_INTERPOLATION["bicubic"],
"lanczos": PIL_INTERPOLATION["lanczos"],
self.templates = imagenet_templates_small
def __len__(self):
return self._length
## borrowed from custom diffusion
# Randomly rescale the crop and return (image tensor, mask tensor, caption prefix).
# The mask is built on a grid of self.size // 8 cells per side.
def custom_aug(self, instance_image):
instance_image = Image.fromarray(instance_image)
#### apply augmentation and create a valid image regions mask ####
# NOTE(review): an `else:` presumably separates the next two assignments: a scale
# in [size // 3, size] for two draws out of three, otherwise one in
# [1.2 * size, 1.4 * size). As written the second always overwrites the first.
if np.random.randint(0, 3) < 2:
random_scale = np.random.randint(self.size // 3, self.size + 1)
random_scale = np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
# Force an even scale so that random_scale // 2 halves it exactly.
if random_scale % 2 == 1:
random_scale += 1
# Small scale: paste the shrunken image at a random position on a zero canvas
# and mark that region (shrunk by one cell per side) in the mask.
if random_scale < 0.6 * self.size:
add_to_caption = np.random.choice(["a far away", "very small"])
cx = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)
cy = np.random.randint(random_scale // 2, self.size - random_scale // 2 + 1)
instance_image1 = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
instance_image1 = np.array(instance_image1).astype(np.uint8)
instance_image1 = (instance_image1 / 127.5 - 1.0).astype(np.float32)
instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
instance_image[cx - random_scale // 2: cx + random_scale // 2,
cy - random_scale // 2: cy + random_scale // 2, :] = instance_image1
mask = np.zeros((self.size // 8, self.size // 8))
mask[(cx - random_scale // 2) // 8 + 1: (cx + random_scale // 2) // 8 - 1,
(cy - random_scale // 2) // 8 + 1: (cy + random_scale // 2) // 8 - 1] = 1.
# Large scale: enlarge the image and take a random size x size crop; the
# whole mask is valid.
elif random_scale > self.size:
add_to_caption = np.random.choice(["zoomed in", "close up"])
cx = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)
cy = np.random.randint(self.size // 2, random_scale - self.size // 2 + 1)
instance_image = instance_image.resize((random_scale, random_scale), resample=self.interpolation)
instance_image = np.array(instance_image).astype(np.uint8)
instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
instance_image = instance_image[cx - self.size // 2: cx + self.size // 2,
cy - self.size // 2: cy + self.size // 2, :]
mask = np.ones((self.size // 8, self.size // 8))
# NOTE(review): an `else:` presumably precedes the following lines (plain resize
# to size x size with the neutral caption "a"); as written they would run after
# the branches above and fail on the already-converted array.
add_to_caption = "a"
if self.size is not None:
instance_image = instance_image.resize((self.size, self.size), resample=self.interpolation)
instance_image = np.array(instance_image).astype(np.uint8)
instance_image = (instance_image / 127.5 - 1.0).astype(np.float32)
mask = np.ones((self.size // 8, self.size // 8))
# Image goes channel-first; the mask gets a leading singleton channel axis.
return torch.from_numpy(instance_image).permute(2, 0, 1), torch.from_numpy(mask[:, :, None]).permute(2, 0, 1), add_to_caption
def aug_cv2(self, img, seg):
    """Apply the same random geometric augmentation to an image and its mask.

    Three independent steps, each applied with probability 1/2:
    resize to a random side in [224, 256] followed by a random 224 x 224
    crop, rotation by a random angle in [-30, 30] degrees, and translation
    by up to 60 px along each axis. The image is interpolated with
    ``INTER_CUBIC``, the mask with ``INTER_NEAREST``.

    Args:
        img: array-like image, channel axis last.
        seg: array-like mask; must be three-dimensional as well, because the
            crop step indexes it with a trailing channel slice.

    Returns:
        Tuple ``(img_auged, seg_auged)`` of PIL images.
    """
    # NOTE(review): the crop size (224) and rotation centre (112, 112) are
    # hard-coded; the visible caller passes 224 x 224 inputs.
    img_auged = np.array(img).copy()
    seg_auged = np.array(seg).copy()

    # resize and crop
    if random.choice([0, 1]) == 0:
        new_size = random.randint(224, 256)
        img_auged = cv2.resize(img_auged, (new_size, new_size), interpolation=cv2.INTER_CUBIC)
        seg_auged = cv2.resize(seg_auged, (new_size, new_size), interpolation=cv2.INTER_NEAREST)
        start_x, start_y = random.randint(0, new_size - 224), random.randint(0, new_size - 224)
        img_auged = img_auged[start_x:start_x + 224, start_y:start_y + 224, :]
        seg_auged = seg_auged[start_x:start_x + 224, start_y:start_y + 224, :]

    # Output size for the warps below; taken after the optional crop.
    h, w = img_auged.shape[:2]

    # rotate
    if random.choice([0, 1]) == 0:
        angle = random.randint(-30, 30)
        M = cv2.getRotationMatrix2D((112, 112), angle, 1)
        img_auged = cv2.warpAffine(img_auged, M, (w, h), flags=cv2.INTER_CUBIC)
        seg_auged = cv2.warpAffine(seg_auged, M, (w, h), flags=cv2.INTER_NEAREST)

    # translation
    if random.choice([0, 1]) == 0:
        trans_x = random.randint(-60, 60)
        trans_y = random.randint(-60, 60)
        H = np.float32([[1, 0, trans_x],
                        [0, 1, trans_y]])
        img_auged = cv2.warpAffine(img_auged, H, (w, h), flags=cv2.INTER_CUBIC)
        seg_auged = cv2.warpAffine(seg_auged, H, (w, h), flags=cv2.INTER_NEAREST)

    img_auged = Image.fromarray(img_auged)
    seg_auged = Image.fromarray(seg_auged)
    return img_auged, seg_auged
# Build one example: augmented crop + validity mask + prompt, plus 224 x 224
# reference tensors for the full crop, the masked object and the mask.
def __getitem__(self, i):
example = {}
# NOTE(review): the `except` further down has no matching `try:` in this copy;
# presumably it opened somewhere around here.
seg_name = self.bboxes_full[i % self.num_images][0]
# The image file name is the part of the mask name before the first "_".
file_name = seg_name.split('_')[0] + '.jpg'
img_path = os.path.join(self.data_root, 'images', self.set, file_name)
seg_path = os.path.join(self.data_root, 'segs', self.set, seg_name)
# crop image and mask
bbox_sample = self.bboxes_full[i % self.num_images][1:]
img_p_np = cv2.imread(img_path)
img_p_np = cv2.cvtColor(img_p_np, cv2.COLOR_BGR2RGB)
seg_p_np = cv2.imread(seg_path).astype('float')
# Bring the mask to the image's (width, height).
seg_p_np = cv2.resize(seg_p_np, img_p_np.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
bbox_pad = copy.copy(bbox_sample)
# Random padding of 10..19 px; the lower bounds are limited to 0, the upper
# bounds are not limited (slicing past the array end truncates).
pad_size = random.choice(list(range(10, 20)))
bbox_pad[0] = int(bbox_pad[0] - min(pad_size, bbox_pad[0] - 0))
bbox_pad[1] = int(bbox_pad[1] + pad_size)
bbox_pad[2] = int(bbox_pad[2] - min(pad_size, bbox_pad[2] - 0))
bbox_pad[3] = int(bbox_pad[3] + pad_size)
# NOTE(review): entries 0/1 index rows and 2/3 index columns here, the opposite
# of OpenImagesDataset.__getitem__; whether that is right depends on the layout
# of bbox_dict, which is not visible — confirm.
image_tensor = img_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
seg_tensor = seg_p_np[bbox_pad[0]:bbox_pad[1], bbox_pad[2]:bbox_pad[3], :]
# augmentation for input image
augged_image, augged_mask, add_caption = self.custom_aug(image_tensor)
input_ids, index, text = self.obtain_text(add_caption)
example["pixel_values"] = augged_image
example["mask_values"] = augged_mask
example["input_ids"] = input_ids
example["index"] = index
example["text"] = text
# Dividing the mask by 255 presumably turns it into 0..1 weights — confirm.
object_tensor = image_tensor * (seg_tensor / 255)
ref_object_tensor = cv2.resize(object_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
ref_image_tenser = cv2.resize(image_tensor, (224, 224), interpolation=cv2.INTER_CUBIC)
ref_seg_tensor = cv2.resize(seg_tensor, (224, 224), interpolation=cv2.INTER_NEAREST)
# Only the object image and its mask are geometrically augmented.
ref_object_tensor, ref_seg_tensor = self.aug_cv2(ref_object_tensor.astype('uint8'), ref_seg_tensor.astype('uint8'))
example["pixel_values_clip"] = self.get_tensor_clip()(Image.fromarray(ref_image_tenser))
example["pixel_values_obj"] = self.get_tensor_clip()(ref_object_tensor)
example["pixel_values_seg"] = self.get_tensor_clip(normalize=False)(ref_seg_tensor)
# On any failure: return all-zero tensors with a default prompt and append the
# error text to error.txt in the working directory.
# NOTE(review): the fallback does not set "mask_values" (set on the success
# path), and its 512 x 512 shape ignores self.size.
except Exception as e:
example["pixel_values"] = torch.zeros((3, 512, 512))
example["pixel_values_obj"] = torch.zeros((3, 224, 224))
example["pixel_values_clip"] = torch.zeros((3, 224, 224))
example["pixel_values_seg"] = torch.zeros((3, 224, 224))
input_ids, index, text = self.obtain_text("a")
example["input_ids"] = input_ids
example["index"] = index
example["text"] = text
with open('error.txt', 'a+') as f:
f.write(str(e) + '\n')
return example