# ic_gan/data_utils/cocostuff_dataset.py
# (provenance: ic_gan repository, author ArantxaCasanova, commit a00ee36
#  "First model version" — web-page residue converted to a comment so the
#  file is valid Python)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict

import numpy as np
import PIL
import PIL.ImageOps
import torchvision.transforms as T
from torch.utils.data import Dataset
# Spatial-relation predicate names (scene-graph style).
# NOTE(review): not referenced anywhere in this file's visible code —
# presumably consumed by other data_utils modules; confirm before removing.
PREDICATES_VALUES = ["left of", "right of", "above", "below", "inside", "surrounding"]
# Per-channel normalization constants used by imagenet_preprocess().
# These are 0.5/0.5 (mapping pixels from [0, 1] to [-1, 1]), NOT the
# ImageNet statistics despite that function's name.
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]
def imagenet_preprocess():
    """Return a torchvision Normalize transform built from MEAN/STD.

    NOTE(review): despite the name, MEAN/STD are [0.5, 0.5, 0.5], which maps
    inputs from [0, 1] to [-1, 1] rather than applying ImageNet statistics.
    """
    return T.Normalize(MEAN, STD)
class Resize(object):
    """Resize a PIL image to a fixed size.

    `size` may be a single int (square output) or an (H, W) tuple; it is
    stored internally in PIL's (width, height) order.
    """

    def __init__(self, size, interp=PIL.Image.BILINEAR):
        # Swap (H, W) -> (W, H) for PIL; an int means a square target.
        self.size = (size[1], size[0]) if isinstance(size, tuple) else (size, size)
        self.interp = interp

    def __call__(self, img):
        return img.resize(self.size, self.interp)
class CocoStuff(Dataset):
    """
    A PyTorch Dataset for loading Coco and Coco-Stuff annotations.
    Parameters
    ----------
    image_dir: str
        Path to a directory where images are held.
    instances_json: str
        Path to a JSON file giving COCO annotations.
    stuff_json: str, optional
        Path to a JSON file giving COCO_Stuff annotations.
    stuff_only: bool, optional
        If True then only iterate over images which appear in stuff_json;
        if False then iterate over all images in instances_json.
    image_size: tuple, optional
        Size (H, W) at which to load images. Default (64, 64).
    normalize_images: bool, optional
        If True then normalize images with the module-level MEAN/STD
        (0.5/0.5, i.e. mapping pixel values to [-1, 1]).
    max_samples: int, optional
        If None use all images. Otherwise only use images in the range [0, max_samples).
        Default None.
    min_object_size: float, optional
        Ignore objects whose bounding box takes up less than this fraction of the image.
    min_objects_per_image: int, optional
        Ignore images which have fewer than this many object annotations.
    max_objects_per_image: int, optional
        Ignore images which have more than this many object annotations.
    instance_whitelist: list, optional
        None means use all instance categories. Otherwise a list giving a whitelist of
        instance category names to use.
    stuff_whitelist: list, optional
        None means use all stuff categories. Otherwise a list giving a whitelist of stuff
        category names to use.
    """

    def __init__(
        self,
        image_dir,
        instances_json,
        stuff_json=None,
        stuff_only=True,
        image_size=64,
        normalize_images=True,
        max_samples=None,
        min_object_size=0.02,
        min_objects_per_image=3,
        max_objects_per_image=8,
        instance_whitelist=None,
        stuff_whitelist=None,
        no__img__=False,
        test_part=False,
        split="train",
        iscrowd=True,
        mode="train",
        **kwargs
    ):
        # Fixed: was super(Dataset, self).__init__(), which starts the MRO
        # lookup *after* Dataset and therefore skips Dataset entirely.
        super(CocoStuff, self).__init__()
        if stuff_only and stuff_json is None:
            print("WARNING: Got stuff_only=True but stuff_json=None.")
            print("Falling back to stuff_only=False.")
            # Fixed: the fallback was announced but never applied, so the
            # constructor later crashed on populate_categories(None, ...).
            stuff_only = False
        self.image_dir = image_dir
        self.max_samples = max_samples
        self.normalize_images = normalize_images
        # iscrowd=True means "filter OUT crowd annotations" (see
        # add_object_instances), not "keep crowds".
        self.iscrowd = iscrowd
        self.left_right_flip = False  # True if split == 'train' else False
        self.max_objects_per_image = max_objects_per_image
        self.mode = mode
        if image_size is not None:
            self.set_image_size(image_size)
            print(self.transform)
        self.no__img__ = no__img__
        with open(instances_json, "r") as f:
            instances_data = json.load(f)
        self.image_id_to_sentences = {}
        stuff_data = None
        if stuff_json is not None and stuff_json != "":
            with open(stuff_json, "r") as f:
                stuff_data = json.load(f)
        # Index every image listed in the instances file.
        self.image_ids = []
        self.image_id_to_filename = {}
        self.image_id_to_size = {}
        for image_data in instances_data["images"]:
            image_id = image_data["id"]
            self.image_ids.append(image_id)
            self.image_id_to_filename[image_id] = image_data["file_name"]
            self.image_id_to_size[image_id] = (
                image_data["width"],
                image_data["height"],
            )
        # Collect category names and ids (instances plus, when given, stuff).
        # populate_categories tolerates stuff_data=None.
        object_idx_to_name = {}
        all_instance_categories = self.populate_categories(
            instances_data, object_idx_to_name
        )
        all_stuff_categories = self.populate_categories(stuff_data, object_idx_to_name)
        if instance_whitelist is None:
            instance_whitelist = all_instance_categories
        if stuff_whitelist is None:
            stuff_whitelist = all_stuff_categories
        category_whitelist = set(instance_whitelist) | set(stuff_whitelist)
        # Add object data from instances and stuff.
        self.image_id_to_objects = defaultdict(list)
        self.add_object_instances(
            instances_data, min_object_size, object_idx_to_name, category_whitelist
        )
        image_ids_with_stuff = self.add_object_instances(
            stuff_data, min_object_size, object_idx_to_name, category_whitelist
        )
        if stuff_only:
            # Keep only images that carry at least one stuff annotation, and
            # drop the bookkeeping for the removed ones.
            self.image_ids = [
                image_id
                for image_id in self.image_ids
                if image_id in image_ids_with_stuff
            ]
            all_image_ids = set(self.image_id_to_filename.keys())
            for image_id in all_image_ids - image_ids_with_stuff:
                self.image_id_to_filename.pop(image_id, None)
                self.image_id_to_size.pop(image_id, None)
                self.image_id_to_objects.pop(image_id, None)
        # Prune images that have too few or too many objects.
        new_image_ids = []
        total_objs = 0
        for image_id in self.image_ids:
            num_objs = len(self.image_id_to_objects[image_id])
            total_objs += num_objs
            if min_objects_per_image <= num_objs <= max_objects_per_image:
                new_image_ids.append(image_id)
        self.image_ids = new_image_ids
        # The "val" split is carved as a fixed 1024-image head; test_part
        # selects the remaining tail instead.
        if split == "val":
            if test_part:
                self.image_ids = self.image_ids[1024:]
            else:
                print("Entering in val part")
                self.image_ids = self.image_ids[:1024]

    def populate_categories(self, data, object_idx_to_name):
        """Collect category names from a COCO annotation dict.

        Fills `object_idx_to_name` (category id -> name) in place and returns
        the list of category names in file order. Returns [] when `data` is
        None (e.g. no stuff_json was supplied) instead of crashing.
        """
        all_categories = []
        if data is None:
            return all_categories
        for category_data in data["categories"]:
            category_id = category_data["id"]
            category_name = category_data["name"]
            all_categories.append(category_name)
            object_idx_to_name[category_id] = category_name
        return all_categories

    def add_object_instances(
        self, data, min_object_size, object_idx_to_name, category_whitelist
    ):
        """Filter annotations from `data` into self.image_id_to_objects.

        An annotation is kept when its box covers more than `min_object_size`
        of the image area, its category name is whitelisted and is not
        "other", and (when self.iscrowd is True) it is not a crowd annotation.

        Returns the set of image ids that had ANY annotation in `data`,
        kept or not; an empty set when `data` is None.
        """
        image_ids_present = set()
        if data is None:
            return image_ids_present
        for object_data in data["annotations"]:
            image_id = object_data["image_id"]
            _, _, w, h = object_data["bbox"]
            image_ids_present.add(image_id)
            W, H = self.image_id_to_size[image_id]
            # Fraction of the image covered by the bounding box.
            box_area = (w * h) / (W * H)
            box_ok = box_area > min_object_size
            object_name = object_idx_to_name[object_data["category_id"]]
            category_ok = object_name in category_whitelist
            other_ok = object_name != "other"
            condition = box_ok and category_ok and other_ok
            if self.iscrowd:
                # Exclude crowd annotations when crowd filtering is enabled.
                condition = condition and (object_data["iscrowd"] != 1)
            if condition:
                self.image_id_to_objects[image_id].append(object_data)
        return image_ids_present

    def set_image_size(self, image_size):
        """Build self.transform: resize -> tensor -> optional normalization."""
        print("called set_image_size", image_size)
        transform = [Resize(image_size), T.ToTensor()]
        if self.normalize_images:
            transform.append(imagenet_preprocess())
        self.transform = T.Compose(transform)
        self.image_size = image_size

    def total_objects(self):
        """Return the number of kept annotations over the first
        max_samples images (all images when max_samples is None)."""
        total_objs = 0
        for i, image_id in enumerate(self.image_ids):
            if self.max_samples and i >= self.max_samples:
                break
            num_objs = len(self.image_id_to_objects[image_id])
            total_objs += num_objs
        return total_objs

    def __len__(self):
        # Doubled length when horizontal flipping is enabled: indices past
        # len(image_ids) map to mirrored copies in __getitem__.
        if self.max_samples is None:
            if self.left_right_flip:
                return len(self.image_ids) * 2
            return len(self.image_ids)
        return min(len(self.image_ids), self.max_samples)

    def __getitem__(self, index):
        """Get an image, a void label and the image index.
        Returns a tuple of: image (FloatTensor of shape (C, H, W)),
        void label 0 and image index.
        """
        flip = False
        if self.mode == "train":
            # Indices in [len(image_ids), 2*len(image_ids)) are the flipped
            # halves (only reachable when left_right_flip doubled __len__).
            if index >= len(self.image_ids):
                index = index - len(self.image_ids)
                flip = True
        image_id = self.image_ids[index]
        filename = self.image_id_to_filename[image_id]
        image_path = os.path.join(self.image_dir, filename)
        with open(image_path, "rb") as f:
            with PIL.Image.open(f) as image:
                if flip and self.mode == "train":
                    image = PIL.ImageOps.mirror(image)
                image = self.transform(image.convert("RGB"))
        return image, int(0), image_id