Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import collections | |
import json | |
import os.path as op | |
import numpy as np | |
import torch | |
from .tsv import TSVYamlDataset, find_file_path_in_yaml | |
from .box_label_loader import BoxLabelLoader | |
from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV | |
class VGDetectionTSV(CocoDetectionTSV):
    """Visual Genome detection dataset; reuses ``CocoDetectionTSV`` without changes."""
def sort_key_by_val(dic):
    """Return the keys of ``dic`` ordered by their values, ascending.

    The sort is stable, so keys with equal values keep their insertion order.
    """
    return sorted(dic, key=dic.__getitem__)
def bbox_overlaps(anchors, gt_boxes):
    """Compute the pairwise IoU between two sets of boxes.

    Boxes are ``(x1, y1, x2, y2)`` with inclusive coordinates, which is why
    every width and height gets a ``+ 1``.

    Args:
        anchors: ``(N, 4)`` torch tensor of boxes.
        gt_boxes: ``(K, 4)`` torch tensor of boxes.

    Returns:
        ``(N, K)`` tensor whose ``[n, k]`` entry is the IoU of ``anchors[n]``
        and ``gt_boxes[k]``; disjoint boxes get 0.
    """
    a = anchors[:, None, :]  # (N, 1, 4)
    g = gt_boxes[None, :, :]  # (1, K, 4)

    area_a = (a[..., 2] - a[..., 0] + 1) * (a[..., 3] - a[..., 1] + 1)  # (N, 1)
    area_g = (g[..., 2] - g[..., 0] + 1) * (g[..., 3] - g[..., 1] + 1)  # (1, K)

    # Broadcasting (N, 1) against (1, K) yields the full (N, K) grid.
    inter_w = (torch.min(a[..., 2], g[..., 2]) - torch.max(a[..., 0], g[..., 0]) + 1).clamp(min=0)
    inter_h = (torch.min(a[..., 3], g[..., 3]) - torch.max(a[..., 1], g[..., 1]) + 1).clamp(min=0)
    inter = inter_w * inter_h

    union = area_a + area_g - inter
    return inter / union
# VG data loader for Danfei Xu's scene-graph-focused format.
# TODO: if the ordering of classes, attributes, or relations changes, make sure
# to regenerate the obj_classes.txt / rel_classes.txt files; they are only
# written when they do not already exist, so stale copies are never refreshed.
def _box_filter(boxes, must_overlap=False): | |
"""Only include boxes that overlap as possible relations. | |
If no overlapping boxes, use all of them.""" | |
overlaps = bbox_overlaps(boxes, boxes).numpy() > 0 | |
np.fill_diagonal(overlaps, 0) | |
all_possib = np.ones_like(overlaps, dtype=np.bool) | |
np.fill_diagonal(all_possib, 0) | |
if must_overlap: | |
possible_boxes = np.column_stack(np.where(overlaps)) | |
if possible_boxes.size == 0: | |
possible_boxes = np.column_stack(np.where(all_possib)) | |
else: | |
possible_boxes = np.column_stack(np.where(all_possib)) | |
return possible_boxes | |
class VGTSVDataset(TSVYamlDataset):
    """
    Generic TSV dataset format for Object Detection.

    Visual Genome variant of ``TSVYamlDataset``. Besides boxes, it exposes the
    object / attribute / predicate vocabularies read from the ``jsondict``
    file named in the YAML config. When ``relation_on`` is set, it also
    attaches relation annotations to each target.
    """

    def __init__(
        self,
        yaml_file,
        extra_fields=None,
        transforms=None,
        is_load_label=True,
        filter_duplicate_rels=True,
        relation_on=False,
        cv2_output=False,
        **kwargs
    ):
        """
        Args:
            yaml_file: dataset YAML config, forwarded to ``TSVYamlDataset``.
            extra_fields: extra annotation fields for ``BoxLabelLoader``.
            transforms: optional callable ``(img, target) -> (img, target)``.
            is_load_label: build a ``BoxLabelLoader`` and load box targets.
            filter_duplicate_rels: on the train split, keep one randomly
                sampled predicate per (subject, object) pair.
            relation_on: attach relation fields to targets and, on the train
                split, compute and cache the predicate frequency prior.
            cv2_output: forwarded to ``TSVYamlDataset``.
            **kwargs: accepted for interface compatibility; not used here.

        Raises:
            ValueError: if the split cannot be inferred from the linelist name.
        """
        if extra_fields is None:
            extra_fields = []
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.relation_on = relation_on
        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)

        ignore_attrs = self.cfg.get("ignore_attrs", None)

        # construct those maps
        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        with open(jsondict_file, "r") as f:
            jsondict = json.load(f)

        self.split = self._infer_split(self.linelist_file)
        # Duplicate-relation filtering is only applied on the train split.
        self.filter_duplicate_rels = filter_duplicate_rels and self.split == "train"

        label_stem = op.splitext(self.label_file)[0]

        self.class_to_ind = jsondict["label_to_idx"]
        self.ind_to_class = jsondict["idx_to_label"]
        self.classes = self._add_null_entry_and_sort(self.class_to_ind, self.ind_to_class, "__background__")
        # writing obj classes to disk for Neural Motif model building.
        self._write_names_if_missing(label_stem + ".obj_classes.txt", self.classes)

        self.attribute_to_ind = jsondict["attribute_to_idx"]
        self.ind_to_attribute = jsondict["idx_to_attribute"]
        self.attributes = self._add_null_entry_and_sort(
            self.attribute_to_ind, self.ind_to_attribute, "__no_attribute__"
        )

        self.relation_to_ind = jsondict["predicate_to_idx"]
        self.ind_to_relation = jsondict["idx_to_predicate"]
        self.relations = self._add_null_entry_and_sort(self.relation_to_ind, self.ind_to_relation, "__no_relation__")
        # writing rel classes to disk for Neural Motif Model building.
        self._write_names_if_missing(label_stem + ".rel_classes.txt", self.relations)

        # label map: minus one because we will add one in BoxLabelLoader
        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
        # NOTE(review): the resolved "labelmap_dec" path is not used anywhere;
        # the call is kept in case find_file_path_in_yaml validates the path.
        # Confirm and drop it if not needed.
        find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)

        if self.is_load_label:
            self.label_loader = BoxLabelLoader(
                labelmap=self.labelmap, extra_fields=extra_fields, ignore_attrs=ignore_attrs
            )

        # get frequency prior for relations
        if self.relation_on:
            self.freq_prior_file = label_stem + ".freq_prior.npy"
            if self.split == "train" and not op.exists(self.freq_prior_file):
                print("Computing frequency prior matrix...")
                fg_matrix, bg_matrix = self._get_freq_prior()
                prob_matrix = fg_matrix.astype(np.float32)
                # Predicate slot 0 ("__no_relation__") holds the background
                # pair counts plus one. Every (subject, object) row therefore
                # sums to at least 1, so the normalization never divides by 0.
                prob_matrix[:, :, 0] = bg_matrix
                prob_matrix[:, :, 0] += 1
                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
                np.save(self.freq_prior_file, prob_matrix)

    @staticmethod
    def _infer_split(linelist_file):
        """Map the linelist file name to ``"train"`` or ``"test"``.

        Raises:
            ValueError: if the basename contains none of train/test/val.
        """
        basename = op.basename(linelist_file)
        if "train" in basename:
            return "train"
        # "valid" contains "val", so it is covered by the second test.
        if "test" in basename or "val" in basename:
            return "test"
        raise ValueError("Split must be one of [train, test], but get {}!".format(linelist_file))

    @staticmethod
    def _add_null_entry_and_sort(name_to_idx, idx_to_name, null_name):
        """Register ``null_name`` at index 0 in both maps and return the names ordered by index.

        Both dicts are mutated in place. ``idx_to_name`` is keyed by
        stringified indices, as loaded from JSON. The assertion checks that
        the two maps agree for every index ``0..len(names) - 1``.
        """
        name_to_idx[null_name] = 0
        idx_to_name["0"] = null_name
        names = sort_key_by_val(name_to_idx)
        assert all(names[i] == idx_to_name[str(i)] for i in range(len(names)))
        return names

    @staticmethod
    def _write_names_if_missing(path, names):
        """Write one name per line to ``path`` unless the file already exists."""
        if not op.isfile(path):
            with open(path, "w") as f:
                f.writelines("%s\n" % name for name in names)

    def _get_freq_prior(self, must_overlap=False):
        """Count predicate and background co-occurrences over the whole dataset.

        Args:
            must_overlap: forwarded to ``_box_filter`` when collecting
                background pairs.

        Returns:
            ``(fg_matrix, bg_matrix)``: int64 arrays of shape ``(C, C, R)``
            and ``(C, C)``, where ``C = len(self.classes)`` and
            ``R = len(self.relations)``. ``fg_matrix[s, o, r]`` counts
            annotated relations with subject class ``s``, object class ``o``
            and predicate ``r``. ``bg_matrix[s, o]`` counts candidate box
            pairs returned by ``_box_filter``.
        """
        num_classes = len(self.classes)
        fg_matrix = np.zeros((num_classes, num_classes, len(self.relations)), dtype=np.int64)
        bg_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
        num_images = len(self)
        for ex_ind in range(num_images):
            target = self.get_groundtruth(ex_ind)
            gt_classes = target.get_field("labels").numpy()
            gt_relations = target.get_field("relation_labels").numpy()
            gt_boxes = target.bbox
            try:
                # For the foreground, we'll just look at everything
                o1o2 = gt_classes[gt_relations[:, :2]]
                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
                    fg_matrix[o1, o2, gtr] += 1
                # For the background, get all of the things that overlap.
                o1o2_total = gt_classes[np.array(_box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
                for o1, o2 in o1o2_total:
                    bg_matrix[o1, o2] += 1
            except IndexError:
                # relation_loader stores torch.tensor([]) (1-D) for an image
                # with no relations, so the 2-D slicing above raises. Such
                # images contribute no counts.
                assert len(gt_relations) == 0
            if ex_ind % 20 == 0:
                print("processing {}/{}".format(ex_ind, num_images))
        return fg_matrix, bg_matrix

    def relation_loader(self, relation_triplets, target):
        """Attach relation fields to ``target``.

        Args:
            relation_triplets: list of ``(subject_idx, object_idx, predicate)``
                tuples (M x 3); the indices refer to boxes in ``target``.
            target: BoxList produced by ``self.label_loader``.

        Returns:
            ``target`` with two added fields. ``relation_labels`` holds the
            (possibly de-duplicated) triplets as a tensor. ``pred_labels`` is
            a ``len(target) x len(target)`` int64 matrix of predicate ids,
            0 where no relation is given.
        """
        if self.filter_duplicate_rels:
            # Filter out dupes: keep one randomly sampled predicate per pair.
            assert self.split == "train"
            all_rel_sets = collections.defaultdict(list)
            for (o0, o1, r) in relation_triplets:
                all_rel_sets[(o0, o1)].append(r)
            relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]

        # get M*M pred_labels; a later triplet for the same pair overwrites an earlier one
        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
        for triplet in relation_triplets:
            relations[triplet[0], triplet[1]] = int(triplet[2])

        target.add_field("relation_labels", torch.tensor(relation_triplets))
        target.add_field("pred_labels", relations)
        return target

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Build the BoxList target for one image.

        Args:
            annotations: annotation dict with an ``"objects"`` entry, plus
                ``"relations"`` when ``relation_on`` is set.
            img_size: ``(w, h)`` of the image.
            idx: image index; not used by this implementation.

        Returns:
            The target, or ``None`` when labels are not loaded
            (``is_load_label`` is False) or ``annotations`` is empty.
            Previously these cases raised ``UnboundLocalError``.
        """
        if not (self.is_load_label and annotations):
            return None
        target = self.label_loader(annotations["objects"], img_size)
        # make sure no boxes are removed
        assert len(annotations["objects"]) == len(target)
        if self.split in ["val", "test"]:
            # add the difficult field (every box marked 0)
            target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
        # load relations
        if self.relation_on:
            target = self.relation_loader(annotations["relations"], target)
        return target

    def get_groundtruth(self, idx, call=False):
        """Load the target for ``idx`` without applying transforms.

        Returns:
            The target, or ``(img, target, annotations)`` when ``call`` is True.
        """
        img = self.get_image(idx)
        if self.cv2_output:
            img_size = img.shape[:2][::-1]  # h, w -> w, h
        else:
            img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        if call:
            return img, target, annotations
        return target

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` if set; otherwise return the inputs unchanged."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def map_class_id_to_class_name(self, class_id):
        """Return the object class name for index ``class_id``."""
        return self.classes[class_id]

    def map_attribute_id_to_attribute_name(self, attribute_id):
        """Return the attribute name for index ``attribute_id``."""
        return self.attributes[attribute_id]

    def map_relation_id_to_relation_name(self, relation_id):
        """Return the predicate name for index ``relation_id``."""
        return self.relations[relation_id]