Spaces:
Build error
Build error
File size: 11,011 Bytes
708dec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import collections
import json
import os.path as op
import numpy as np
import torch
from .tsv import TSVYamlDataset, find_file_path_in_yaml
from .box_label_loader import BoxLabelLoader
from maskrcnn_benchmark.data.datasets.coco_dt import CocoDetectionTSV
class VGDetectionTSV(CocoDetectionTSV):
    """Visual Genome detection dataset.

    Adds nothing of its own: every attribute and method is inherited
    unchanged from ``CocoDetectionTSV``.
    """
def sort_key_by_val(dic):
    """Return the keys of ``dic`` ordered by their associated values.

    The sort is stable, so keys with equal values keep the dict's
    iteration order.
    """
    return sorted(dic, key=dic.get)
def bbox_overlaps(anchors, gt_boxes):
    """Compute the pairwise IoU between two sets of boxes.

    Boxes are ``(x1, y1, x2, y2)`` with inclusive pixel coordinates, so a
    box's width is ``x2 - x1 + 1`` (and likewise for its height).

    Args:
        anchors: (N, 4) torch tensor, or any array-like accepted by
            ``torch.as_tensor`` (e.g. a NumPy array).
        gt_boxes: (K, 4) torch tensor or array-like.

    Returns:
        (N, K) torch tensor whose entry ``[n, k]`` is the IoU of
        ``anchors[n]`` and ``gt_boxes[k]``.
    """
    # Accept NumPy arrays as well as tensors; as_tensor returns a tensor
    # input unchanged, so existing tensor callers see no difference.
    anchors = torch.as_tensor(anchors)
    gt_boxes = torch.as_tensor(gt_boxes)

    num_anchors = anchors.size(0)
    num_gt = gt_boxes.size(0)

    gt_area = ((gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
               (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)).view(1, num_gt)
    anchor_area = ((anchors[:, 2] - anchors[:, 0] + 1) *
                   (anchors[:, 3] - anchors[:, 1] + 1)).view(num_anchors, 1)

    # Broadcast both sets to (N, K, 4) so every pair is compared at once.
    boxes = anchors.view(num_anchors, 1, 4).expand(num_anchors, num_gt, 4)
    query = gt_boxes.view(1, num_gt, 4).expand(num_anchors, num_gt, 4)

    # Intersection extents; a negative extent means no overlap, so clamp to 0.
    iw = (torch.min(boxes[:, :, 2], query[:, :, 2]) -
          torch.max(boxes[:, :, 0], query[:, :, 0]) + 1).clamp(min=0)
    ih = (torch.min(boxes[:, :, 3], query[:, :, 3]) -
          torch.max(boxes[:, :, 1], query[:, :, 1]) + 1).clamp(min=0)

    inter = iw * ih
    union = anchor_area + gt_area - inter
    return inter / union
# VG data loader for Danfei Xu's scene-graph-focused format.
# TODO: if the ordering of classes, attributes, or relations changes,
# make sure to re-write the obj_classes.txt / rel_classes.txt files.
def _box_filter(boxes, must_overlap=False):
""" Only include boxes that overlap as possible relations.
If no overlapping boxes, use all of them."""
overlaps = bbox_overlaps(boxes, boxes).numpy() > 0
np.fill_diagonal(overlaps, 0)
all_possib = np.ones_like(overlaps, dtype=np.bool)
np.fill_diagonal(all_possib, 0)
if must_overlap:
possible_boxes = np.column_stack(np.where(overlaps))
if possible_boxes.size == 0:
possible_boxes = np.column_stack(np.where(all_possib))
else:
possible_boxes = np.column_stack(np.where(all_possib))
return possible_boxes
class VGTSVDataset(TSVYamlDataset):
    """TSV/YAML dataset for Visual Genome scene-graph data.

    Object boxes are loaded through ``BoxLabelLoader``. The class, attribute
    and predicate vocabularies come from the ``jsondict`` file named in the
    YAML config. Relation labels can optionally be attached to every target
    (``relation_on=True``).
    """

    def __init__(self, yaml_file, extra_fields=None, transforms=None,
                 is_load_label=True, filter_duplicate_rels=True,
                 relation_on=False, cv2_output=False, **kwargs):
        """
        Args:
            yaml_file: dataset YAML config, forwarded to ``TSVYamlDataset``.
            extra_fields: extra label fields for ``BoxLabelLoader``
                (default: none).
            transforms: optional callable ``(img, target) -> (img, target)``
                used by ``apply_transforms``.
            is_load_label: build a label loader and produce targets.
            filter_duplicate_rels: on the train split, keep a single randomly
                chosen predicate per (subject, object) pair.
            relation_on: attach relation fields to targets. On the train
                split, compute the frequency-prior file if it is missing.
            cv2_output: forwarded to ``TSVYamlDataset``.
            **kwargs: accepted for interface compatibility and ignored.

        Raises:
            ValueError: if the split cannot be inferred from the linelist
                file name.
        """
        if extra_fields is None:
            extra_fields = []
        self.transforms = transforms
        self.is_load_label = is_load_label
        self.relation_on = relation_on
        super(VGTSVDataset, self).__init__(yaml_file, cv2_output=cv2_output)

        ignore_attrs = self.cfg.get("ignore_attrs", None)
        # construct those maps
        jsondict_file = find_file_path_in_yaml(self.cfg.get("jsondict", None), self.root)
        # Context manager so the file handle is closed deterministically.
        with open(jsondict_file, 'r') as f:
            jsondict = json.load(f)

        # Infer the split from the linelist file name ('val' also matches 'valid').
        linelist_name = op.basename(self.linelist_file)
        if 'train' in linelist_name:
            self.split = "train"
        elif 'test' in linelist_name or 'val' in linelist_name:
            self.split = "test"
        else:
            raise ValueError("Split must be one of [train, test], but get {}!".format(self.linelist_file))
        # Duplicate filtering picks predicates at random, so it is train-only.
        self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train'

        # Object classes; index 0 is reserved for the background class.
        self.class_to_ind = jsondict['label_to_idx']
        self.ind_to_class = jsondict['idx_to_label']
        self.classes = self._reserve_index_zero(
            self.class_to_ind, self.ind_to_class, '__background__')
        # writing obj classes to disk for Neural Motif model building.
        self._write_lines_if_missing(
            op.splitext(self.label_file)[0] + ".obj_classes.txt", self.classes)

        # Attributes; index 0 is reserved for "no attribute".
        self.attribute_to_ind = jsondict['attribute_to_idx']
        self.ind_to_attribute = jsondict['idx_to_attribute']
        self.attributes = self._reserve_index_zero(
            self.attribute_to_ind, self.ind_to_attribute, '__no_attribute__')

        # Predicates; index 0 is reserved for "no relation".
        self.relation_to_ind = jsondict['predicate_to_idx']
        self.ind_to_relation = jsondict['idx_to_predicate']
        self.relations = self._reserve_index_zero(
            self.relation_to_ind, self.ind_to_relation, '__no_relation__')
        # writing rel classes to disk for Neural Motif Model building.
        self._write_lines_if_missing(
            op.splitext(self.label_file)[0] + '.rel_classes.txt', self.relations)

        # label map: minus one because we will add one in BoxLabelLoader
        self.labelmap = {key: val - 1 for key, val in self.class_to_ind.items()}
        # NOTE(review): the resolved path is currently unused (its consumer
        # load_labelmap_file was disabled); the lookup is kept unchanged.
        labelmap_file = find_file_path_in_yaml(self.cfg.get("labelmap_dec"), self.root)
        if self.is_load_label:
            self.label_loader = BoxLabelLoader(
                labelmap=self.labelmap,
                extra_fields=extra_fields,
                ignore_attrs=ignore_attrs
            )

        # get frequency prior for relations
        if self.relation_on:
            self.freq_prior_file = op.splitext(self.label_file)[0] + ".freq_prior.npy"
            if self.split == 'train' and not op.exists(self.freq_prior_file):
                print("Computing frequency prior matrix...")
                fg_matrix, bg_matrix = self._get_freq_prior()
                prob_matrix = fg_matrix.astype(np.float32)
                # Channel 0 ("no relation") holds the background pair counts.
                prob_matrix[:, :, 0] = bg_matrix
                # Add-one smoothing on channel 0; also keeps every sum > 0.
                prob_matrix[:, :, 0] += 1
                prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]
                np.save(self.freq_prior_file, prob_matrix)

    @staticmethod
    def _reserve_index_zero(name_to_ind, ind_to_name, reserved_name):
        """Assign index 0 to ``reserved_name`` and return names sorted by index.

        ``name_to_ind`` maps name -> int index; ``ind_to_name`` maps the
        string form of an index -> name. Both are updated in place. Fails
        (AssertionError or KeyError) if the two mappings disagree.
        """
        name_to_ind[reserved_name] = 0
        ind_to_name['0'] = reserved_name
        names = sort_key_by_val(name_to_ind)
        assert all(names[i] == ind_to_name[str(i)] for i in range(len(names)))
        return names

    @staticmethod
    def _write_lines_if_missing(path, items):
        """Write ``items`` to ``path`` one per line unless the file already exists."""
        if not op.isfile(path):
            with open(path, 'w') as f:
                for item in items:
                    f.write("%s\n" % item)

    def _get_freq_prior(self, must_overlap=False):
        """Count class-pair / predicate statistics over the whole dataset.

        Args:
            must_overlap: forwarded to ``_box_filter``; restricts background
                pairs to overlapping boxes.

        Returns:
            ``(fg_matrix, bg_matrix)``: int64 arrays of shape
            (num_classes, num_classes, num_relations) counting annotated
            relations, and (num_classes, num_classes) counting candidate
            box pairs.
        """
        num_classes = len(self.classes)
        fg_matrix = np.zeros((num_classes, num_classes, len(self.relations)), dtype=np.int64)
        bg_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
        num_examples = len(self)

        for ex_ind in range(num_examples):
            target = self.get_groundtruth(ex_ind)
            gt_classes = target.get_field('labels').numpy()
            gt_relations = target.get_field('relation_labels').numpy()
            gt_boxes = target.bbox

            try:
                # For the foreground, we'll just look at everything
                o1o2 = gt_classes[gt_relations[:, :2]]
                for (o1, o2), gtr in zip(o1o2, gt_relations[:, 2]):
                    fg_matrix[o1, o2, gtr] += 1
                # For the background, count the candidate pairs from _box_filter.
                o1o2_total = gt_classes[np.array(
                    _box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
                for (o1, o2) in o1o2_total:
                    bg_matrix[o1, o2] += 1
            except IndexError:
                # With no relations, relation_loader stores torch.tensor([])
                # (1-D), so the 2-D slicing above raises. Such images
                # contribute to neither matrix.
                assert len(gt_relations) == 0

            if ex_ind % 20 == 0:
                print("processing {}/{}".format(ex_ind, num_examples))
        return fg_matrix, bg_matrix

    def relation_loader(self, relation_triplets, target):
        """Attach relation annotations to ``target``.

        Args:
            relation_triplets: M (subject_idx, object_idx, predicate)
                triplets; the indices refer to boxes in ``target``.
            target: BoxList from ``label_loader``.

        Returns:
            ``target`` with two added fields:
            - ``relation_labels``: the triplets as a tensor.
            - ``pred_labels``: a ``len(target) x len(target)`` int64 matrix
              holding each pair's predicate, 0 where there is none.
        """
        if self.filter_duplicate_rels:
            # Filter out dupes: keep one random predicate per (subj, obj) pair.
            assert self.split == 'train'
            all_rel_sets = collections.defaultdict(list)
            for (o0, o1, r) in relation_triplets:
                all_rel_sets[(o0, o1)].append(r)
            relation_triplets = [(k[0], k[1], np.random.choice(v)) for k, v in all_rel_sets.items()]

        # get M*M pred_labels; if a pair repeats, the last triplet wins.
        relations = torch.zeros([len(target), len(target)], dtype=torch.int64)
        for triplet in relation_triplets:
            relations[triplet[0], triplet[1]] = int(triplet[2])

        relation_triplets = torch.tensor(relation_triplets)
        target.add_field("relation_labels", relation_triplets)
        target.add_field("pred_labels", relations)
        return target

    def get_target_from_annotations(self, annotations, img_size, idx):
        """Build the target for one image from its raw annotation dict.

        Returns None when labels are not loaded or ``annotations`` is empty.
        Without this explicit return, that case reached ``return target``
        with ``target`` unbound. ``idx`` is not used here.
        """
        if not (self.is_load_label and annotations):
            return None
        target = self.label_loader(annotations['objects'], img_size)
        # make sure no boxes are removed
        assert (len(annotations['objects']) == len(target))
        if self.split in ["val", "test"]:
            # add the difficult field
            target.add_field("difficult", torch.zeros(len(target), dtype=torch.int32))
        # load relations
        if self.relation_on:
            target = self.relation_loader(annotations["relations"], target)
        return target

    def get_groundtruth(self, idx, call=False):
        """Load image ``idx`` and build its target without applying transforms.

        Returns the target, or ``(img, target, annotations)`` when ``call``
        is true.
        """
        img = self.get_image(idx)
        if self.cv2_output:
            img_size = img.shape[:2][::-1]  # h, w -> w, h
        else:
            img_size = img.size  # w, h
        annotations = self.get_annotations(idx)
        target = self.get_target_from_annotations(annotations, img_size, idx)
        if call:
            return img, target, annotations
        return target

    def apply_transforms(self, img, target=None):
        """Apply ``self.transforms`` to ``(img, target)`` if transforms were given."""
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def map_class_id_to_class_name(self, class_id):
        """Return the object class name for ``class_id``."""
        return self.classes[class_id]

    def map_attribute_id_to_attribute_name(self, attribute_id):
        """Return the attribute name for ``attribute_id``."""
        return self.attributes[attribute_id]

    def map_relation_id_to_relation_name(self, relation_id):
        """Return the predicate name for ``relation_id``."""
        return self.relations[relation_id]
|