# Web file-viewer residue (not part of the module): zdou0830 / desco, revision 749745d, 10.9 kB.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import json
import os.path as op
import numpy as np
import torch
from .tsv import TSVYamlDataset, find_file_path_in_yaml
from .box_label_loader import BoxLabelLoader
from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV
class VGDetectionTSV(CocoDetectionTSV):
    """Visual Genome detection dataset; adds nothing on top of CocoDetectionTSV."""
def sort_key_by_val(dic):
    """Return the keys of ``dic`` ordered by ascending value.

    Keys with equal values keep their insertion order (the sort is stable).
    """
    return sorted(dic, key=lambda name: dic[name])
def bbox_overlaps(anchors, gt_boxes):
    """Compute the pairwise IoU between two sets of boxes.

    Boxes are ``(x1, y1, x2, y2)`` rows with inclusive pixel coordinates,
    hence the ``+ 1`` when measuring widths and heights.

    anchors: (N, 4) torch tensor
    gt_boxes: (K, 4) torch tensor
    returns: (N, K) tensor, entry [n, k] is the IoU of anchors[n] and gt_boxes[k]
    """
    num_anchors = anchors.size(0)
    num_gt = gt_boxes.size(0)

    # Per-box areas, shaped so they broadcast to (N, K).
    area_a = (anchors[:, 2] - anchors[:, 0] + 1) * (anchors[:, 3] - anchors[:, 1] + 1)
    area_g = (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
    area_a = area_a.view(num_anchors, 1)
    area_g = area_g.view(1, num_gt)

    # (N, 1, 4) against (1, K, 4): element-wise min/max broadcast to (N, K).
    lhs = anchors.view(num_anchors, 1, 4)
    rhs = gt_boxes.view(1, num_gt, 4)

    # Intersection extents; non-overlapping pairs are clipped to zero.
    inter_w = (torch.min(lhs[:, :, 2], rhs[:, :, 2]) - torch.max(lhs[:, :, 0], rhs[:, :, 0]) + 1).clamp(min=0)
    inter_h = (torch.min(lhs[:, :, 3], rhs[:, :, 3]) - torch.max(lhs[:, :, 1], rhs[:, :, 1]) + 1).clamp(min=0)
    intersection = inter_w * inter_h

    union = area_a + area_g - intersection
    return intersection / union
# VG data loader for Danfei Xu's scene-graph-focused format.
# TODO: if the ordering of classes, attributes, or relations changes,
# make sure to rewrite the obj_classes.txt / rel_classes.txt files.
def _box_filter(boxes, must_overlap=False):
    """Return candidate (first, second) box-index pairs as possible relations.

    boxes: (N, 4) torch tensor, passed twice to ``bbox_overlaps``.
    must_overlap: if True, keep only pairs of distinct boxes whose IoU is
        positive; when no such pair exists, fall back to all pairs.
        If False, every ordered pair of distinct boxes is returned.
    returns: (P, 2) integer ndarray of index pairs; a box is never paired
        with itself.
    """
    # Positive IoU between boxes; the diagonal (self-pairs) is cleared.
    overlaps = bbox_overlaps(boxes, boxes).numpy() > 0
    np.fill_diagonal(overlaps, 0)

    if must_overlap:
        possible_boxes = np.column_stack(np.where(overlaps))
        if possible_boxes.size:
            return possible_boxes
        # No overlapping pair at all: fall through and use every pair.

    # The `np.bool` alias was removed in NumPy 1.24; builtin `bool` is the
    # same dtype and works on every NumPy version.
    all_possib = np.ones_like(overlaps, dtype=bool)
    np.fill_diagonal(all_possib, 0)
    return np.column_stack(np.where(all_possib))
class VGTSVDataset(TSVYamlDataset):
    """TSV-backed Visual Genome dataset with object, attribute and relation labels.

    The vocabularies are read from the JSON file named by the ``jsondict``
    entry of the dataset YAML config. Index 0 of each vocabulary is reserved
    for ``__background__`` / ``__no_attribute__`` / ``__no_relation__``.

    Constructing the dataset has file-system side effects next to
    ``self.label_file``: ``.obj_classes.txt`` and ``.rel_classes.txt`` are
    written if missing, and, when ``relation_on`` is set on a train split,
    ``.freq_prior.npy`` is computed and saved if missing.
    """

    def __init__(
        self,
        yaml_file,
        extra_fields=None,
        transforms=None,
        is_load_label=True,
        filter_duplicate_rels=True,
        relation_on=False,
        cv2_output=False,
        **kwargs
    ):
        """Build the dataset from a YAML description.

        yaml_file: dataset YAML file, forwarded to ``TSVYamlDataset``.
        extra_fields: extra label fields forwarded to ``BoxLabelLoader``
            (defaults to an empty list).
        transforms: optional callable ``(img, target) -> (img, target)``.
        is_load_label: when True, a ``BoxLabelLoader`` is created.
        filter_duplicate_rels: keep one randomly chosen predicate per
            (subject, object) pair; only honoured on the train split.
        relation_on: load relation labels and prepare the frequency prior.
        cv2_output: forwarded to ``TSVYamlDataset``.
        **kwargs: accepted and ignored.

        Raises ValueError when the line-list file name contains none of
        "train", "test" or "val".
        """
        if extra_fields is None:
            extra_fields = []
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.relation_on = relation_on
        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)

        ignore_attrs = self.cfg.get("ignore_attrs", None)

        # Load the vocabulary dictionaries; `with` closes the handle promptly.
        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        with open(jsondict_file, "r") as f:
            jsondict = json.load(f)

        # Infer the split from the line-list file name. "train" is checked
        # first; "val" also matches names containing "valid".
        linelist_name = op.basename(self.linelist_file)
        if "train" in linelist_name:
            self.split = "train"
        elif "test" in linelist_name or "val" in linelist_name:
            self.split = "test"
        else:
            raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file))
        self.filter_duplicate_rels = filter_duplicate_rels and self.split == "train"

        label_stem = op.splitext(self.label_file)[0]

        # Object classes.
        self.class_to_ind = jsondict["label_to_idx"]
        self.ind_to_class = jsondict["idx_to_label"]
        self.classes = self._build_vocab(self.class_to_ind, self.ind_to_class, "__background__")
        # writing obj classes to disk for Neural Motif model building.
        self._write_names_once(label_stem + ".obj_classes.txt", self.classes)

        # Attributes.
        self.attribute_to_ind = jsondict["attribute_to_idx"]
        self.ind_to_attribute = jsondict["idx_to_attribute"]
        self.attributes = self._build_vocab(self.attribute_to_ind, self.ind_to_attribute, "__no_attribute__")

        # Relations (predicates).
        self.relation_to_ind = jsondict["predicate_to_idx"]
        self.ind_to_relation = jsondict["idx_to_predicate"]
        self.relations = self._build_vocab(self.relation_to_ind, self.ind_to_relation, "__no_relation__")
        # writing rel classes to disk for Neural Motif Model building.
        self._write_names_once(label_stem + ".rel_classes.txt", self.relations)

        # label map: minus one because we will add one in BoxLabelLoader
        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
        # NOTE(review): the result is unused, but the lookup is kept because
        # find_file_path_in_yaml may validate the path - confirm before removing.
        labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)
        # self.labelmap_dec = load_labelmap_file(labelmap_file)
        if self.is_load_label:
            self.label_loader = BoxLabelLoader(
                labelmap=self.labelmap, extra_fields=extra_fields, ignore_attrs=ignore_attrs
            )

        # get frequency prior for relations
        if self.relation_on:
            self.freq_prior_file = label_stem + ".freq_prior.npy"
            if self.split == "train" and not op.exists(self.freq_prior_file):
                print("Computing frequency prior matrix...")
                fg_matrix, bg_matrix = self._get_freq_prior()
                prob_matrix = fg_matrix.astype(np.float32)
                # Slot 0 of the relation axis takes the candidate-pair counts
                # plus one, then each (class, class) row is normalised.
                prob_matrix[:, :, 0] = bg_matrix
                prob_matrix[:, :, 0] += 1
                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
                np.save(self.freq_prior_file, prob_matrix)

    @staticmethod
    def _build_vocab(to_ind, to_name, null_name):
        """Reserve index 0 for ``null_name`` (in place) and return names ordered by index."""
        to_ind[null_name] = 0
        to_name["0"] = null_name
        names = sort_key_by_val(to_ind)
        # Both directions of the mapping must agree on every index.
        assert all(names[i] == to_name[str(i)] for i in range(len(names)))
        return names

    @staticmethod
    def _write_names_once(path, names):
        """Write one name per line to ``path`` unless the file already exists."""
        if not op.isfile(path):
            with open(path, "w") as f:
                f.writelines("%s\n" % name for name in names)

    def _get_freq_prior(self, must_overlap=False):
        """Count relation statistics over the whole dataset.

        must_overlap: forwarded to ``_box_filter`` when enumerating candidate
            box pairs.
        returns: ``(fg_matrix, bg_matrix)`` int64 arrays.
            ``fg_matrix[o1, o2, r]`` counts annotated relations ``r`` between
            classes ``o1`` and ``o2``; ``bg_matrix[o1, o2]`` counts candidate
            box pairs with those classes.
        """
        num_classes = len(self.classes)
        fg_matrix = np.zeros((num_classes, num_classes, len(self.relations)), dtype=np.int64)
        bg_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

        num_examples = len(self)
        for ex_ind in range(num_examples):
            target = self.get_groundtruth(ex_ind)
            gt_classes = target.get_field("labels").numpy()
            gt_relations = target.get_field("relation_labels").numpy()
            gt_boxes = target.bbox

            # For the foreground, we'll just look at everything
            try:
                o1o2 = gt_classes[gt_relations[:, :2]]
                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
                    fg_matrix[o1, o2, gtr] += 1
                # For the background, get all of the things that overlap.
                o1o2_total = gt_classes[np.array(_box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
                for (o1, o2) in o1o2_total:
                    bg_matrix[o1, o2] += 1
            except IndexError:
                # An image without relations yields a 1-D empty array, so the
                # 2-D slicing above fails; anything else is a real error.
                assert len(gt_relations) == 0

            if ex_ind % 20 == 0:
                print("processing {}/{}".format(ex_ind, num_examples))

        return fg_matrix, bg_matrix

    def relation_loader(self, relation_triplets, target):
        """Attach relation labels to ``target`` and return it.

        relation_triplets: list of ``(subject_idx, object_idx, predicate)``
            triples (M x 3).
        target: BoxList from the label loader.

        Adds two fields: ``relation_labels`` (the triples as a tensor) and
        ``pred_labels`` (a ``len(target) x len(target)`` int64 matrix holding
        the predicate for each subject/object pair, 0 elsewhere).
        """
        if self.filter_duplicate_rels:
            # Filter out dupes!
            assert self.split == "train"
            all_rel_sets = collections.defaultdict(list)
            for (o0, o1, r) in relation_triplets:
                all_rel_sets[(o0, o1)].append(r)
            # Keep a single, randomly chosen predicate per (subject, object) pair.
            relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]

        # Dense pairwise predicate matrix.
        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
        for triplet in relation_triplets:
            relations[triplet[0], triplet[1]] = int(triplet[2])

        relation_triplets = torch.tensor(relation_triplets)
        target.add_field("relation_labels", relation_triplets)
        target.add_field("pred_labels", relations)
        return target

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Convert raw annotations into a target.

        annotations: dict with an ``"objects"`` entry and, when relations are
            enabled, a ``"relations"`` entry.
        img_size: image size passed to the label loader.
        idx: example index (unused here).
        returns: the target built by the label loader, or ``None`` when label
            loading is disabled or ``annotations`` is empty.
        """
        if not (self.is_load_label and annotations):
            # Nothing to load; apply_transforms already accepts target=None.
            return None

        target = self.label_loader(annotations["objects"], img_size)
        # make sure no boxes are removed
        assert len(annotations["objects"]) == len(target)
        if self.split in ["val", "test"]:
            # add the difficult field
            target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
        # load relations
        if self.relation_on:
            target = self.relation_loader(annotations["relations"], target)
        return target

    def get_groundtruth(self, idx, call=False):
        """Load the target for example ``idx`` without applying transforms.

        returns: ``target``, or ``(img, target, annotations)`` when ``call``
            is True.
        """
        # similar to __getitem__ but without transform
        img = self.get_image(idx)
        if self.cv2_output:
            img_size = img.shape[:2][::-1]  # h, w -> w, h
        else:
            img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        if call:
            return img, target, annotations
        return target

    def apply_transforms(self, img, target=None):
        """Run ``self.transforms`` on ``(img, target)`` if configured; return the pair."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def map_class_id_to_class_name(self, class_id):
        """Return the object class name stored at index ``class_id``."""
        return self.classes[class_id]

    def map_attribute_id_to_attribute_name(self, attribute_id):
        """Return the attribute name stored at index ``attribute_id``."""
        return self.attributes[attribute_id]

    def map_relation_id_to_relation_name(self, relation_id):
        """Return the relation (predicate) name stored at index ``relation_id``."""
        return self.relations[relation_id]