turn-the-cam-anonymous's picture
adding CLIP taming
1ed7deb
raw
history blame
No virus
5.38 kB
import os
import numpy as np
import cv2
import albumentations
from PIL import Image
from torch.utils.data import Dataset
from taming.data.sflckr import SegmentationBase # for examples included in repo
class Examples(SegmentationBase):
def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
super().__init__(data_csv="data/ade20k_examples.txt",
data_root="data/ade20k_images",
segmentation_root="data/ade20k_segmentations",
size=size, random_crop=random_crop,
interpolation=interpolation,
n_labels=151, shift_segmentation=False)
# With semantic map and scene label
class ADE20kBase(Dataset):
def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
self.split = self.get_split()
self.n_labels = 151 # unknown + 150
self.data_csv = {"train": "data/ade20k_train.txt",
"validation": "data/ade20k_test.txt"}[self.split]
self.data_root = "data/ade20k_root"
with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
self.scene_categories = f.read().splitlines()
self.scene_categories = dict(line.split() for line in self.scene_categories)
with open(self.data_csv, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, "images", l)
for l in self.image_paths],
"relative_segmentation_path_": [l.replace(".jpg", ".png")
for l in self.image_paths],
"segmentation_path_": [os.path.join(self.data_root, "annotations",
l.replace(".jpg", ".png"))
for l in self.image_paths],
"scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
for l in self.image_paths],
}
size = None if size is not None and size<=0 else size
self.size = size
if crop_size is None:
self.crop_size = size if size is not None else None
else:
self.crop_size = crop_size
if self.size is not None:
self.interpolation = interpolation
self.interpolation = {
"nearest": cv2.INTER_NEAREST,
"bilinear": cv2.INTER_LINEAR,
"bicubic": cv2.INTER_CUBIC,
"area": cv2.INTER_AREA,
"lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=self.interpolation)
self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
interpolation=cv2.INTER_NEAREST)
if crop_size is not None:
self.center_crop = not random_crop
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
else:
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
self.preprocessor = self.cropper
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
if self.size is not None:
image = self.image_rescaler(image=image)["image"]
segmentation = Image.open(example["segmentation_path_"])
segmentation = np.array(segmentation).astype(np.uint8)
if self.size is not None:
segmentation = self.segmentation_rescaler(image=segmentation)["image"]
if self.size is not None:
processed = self.preprocessor(image=image, mask=segmentation)
else:
processed = {"image": image, "mask": segmentation}
example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
segmentation = processed["mask"]
onehot = np.eye(self.n_labels)[segmentation]
example["segmentation"] = onehot
return example
class ADE20kTrain(ADE20kBase):
# default to random_crop=True
def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
super().__init__(config=config, size=size, random_crop=random_crop,
interpolation=interpolation, crop_size=crop_size)
def get_split(self):
return "train"
class ADE20kValidation(ADE20kBase):
def get_split(self):
return "validation"
if __name__ == "__main__":
dset = ADE20kValidation()
ex = dset[0]
for k in ["image", "scene_category", "segmentation"]:
print(type(ex[k]))
try:
print(ex[k].shape)
except:
print(ex[k])