from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
import json
import numpy as np
import os.path as osp
import os
from prettytable import PrettyTable

import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import maskrcnn_benchmark.utils.mdetr_dist as dist

#### The following loading utilities are imported from
#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py
# Changelog:
#    - Added typing information
#    - Completed docstrings


def get_sentence_data(filename) -> List[Dict[str, Any]]:
    """

    Parses a sentence file from the Flickr30K Entities dataset



    input:

      filename - full file path to the sentence file to parse



    output:

      a list of dictionaries for each sentence with the following fields:

          sentence - the original sentence

          phrases - a list of dictionaries for each phrase with the

                    following fields:

                      phrase - the text of the annotated phrase

                      first_word_index - the position of the first word of

                                         the phrase in the sentence

                      phrase_id - an identifier for this phrase

                      phrase_type - a list of the coarse categories this

                                    phrase belongs to



    """
    with open(filename, "r") as f:
        sentences = f.read().split("\n")

    annotations = []
    for sentence in sentences:
        if not sentence:
            continue

        first_word = []
        phrases = []
        phrase_id = []
        phrase_type = []
        words = []
        current_phrase = []
        add_to_phrase = False
        for token in sentence.split():
            if add_to_phrase:
                if token[-1] == "]":
                    add_to_phrase = False
                    token = token[:-1]
                    current_phrase.append(token)
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
                else:
                    current_phrase.append(token)

                words.append(token)
            else:
                if token[0] == "[":
                    add_to_phrase = True
                    first_word.append(len(words))
                    parts = token.split("/")
                    phrase_id.append(parts[1][3:])
                    phrase_type.append(parts[2:])
                else:
                    words.append(token)

        sentence_data = {"sentence": " ".join(words), "phrases": []}
        for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type):
            sentence_data["phrases"].append(
                {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type}
            )

        annotations.append(sentence_data)

    return annotations
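
# Example (hedged sketch; the sentence and ids are illustrative, not from the
# dataset): a Sentences-file line such as
#   [/EN#283585/people A woman] is walking [/EN#283586/animals her dog] .
# parses to:
#   {"sentence": "A woman is walking her dog .",
#    "phrases": [{"first_word_index": 0, "phrase": "A woman",
#                 "phrase_id": "283585", "phrase_type": ["people"]},
#                {"first_word_index": 4, "phrase": "her dog",
#                 "phrase_id": "283586", "phrase_type": ["animals"]}]}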


def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]:
    """

    Parses the xml files in the Flickr30K Entities dataset



    input:

      filename - full file path to the annotations file to parse



    output:

      dictionary with the following fields:

          scene - list of identifiers which were annotated as

                  pertaining to the whole scene

          nobox - list of identifiers which were annotated as

                  not being visible in the image

          boxes - a dictionary where the fields are identifiers

                  and the values are its list of boxes in the

                  [xmin ymin xmax ymax] format

          height - int representing the height of the image

          width - int representing the width of the image

          depth - int representing the depth of the image

    """
    tree = ET.parse(filename)
    root = tree.getroot()
    size_container = root.findall("size")[0]
    anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {}
    all_boxes: Dict[str, List[List[int]]] = {}
    all_noboxes: List[str] = []
    all_scenes: List[str] = []
    for size_element in size_container:
        assert size_element.text
        anno_info[size_element.tag] = int(size_element.text)

    for object_container in root.findall("object"):
        for names in object_container.findall("name"):
            box_id = names.text
            assert box_id
            box_container = object_container.findall("bndbox")
            if len(box_container) > 0:
                if box_id not in all_boxes:
                    all_boxes[box_id] = []
                xmin = int(box_container[0].findall("xmin")[0].text)
                ymin = int(box_container[0].findall("ymin")[0].text)
                xmax = int(box_container[0].findall("xmax")[0].text)
                ymax = int(box_container[0].findall("ymax")[0].text)
                all_boxes[box_id].append([xmin, ymin, xmax, ymax])
            else:
                nobndbox = int(object_container.findall("nobndbox")[0].text)
                if nobndbox > 0:
                    all_noboxes.append(box_id)

                scene = int(object_container.findall("scene")[0].text)
                if scene > 0:
                    all_scenes.append(box_id)
    anno_info["boxes"] = all_boxes
    anno_info["nobox"] = all_noboxes
    anno_info["scene"] = all_scenes

    return anno_info
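
# Example (hedged sketch; sizes, ids, and coordinates are illustrative): for an
# Annotations xml file describing a 500x375 RGB image with a single box
# annotated for phrase id "283585", get_annotations returns something like:
#   {"width": 500, "height": 375, "depth": 3,
#    "boxes": {"283585": [[120, 30, 260, 350]]},
#    "nobox": [], "scene": []}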


#### END of import from flickr30k_entities


#### Bounding box utilities imported from torchvision and converted to numpy
def box_area(boxes: np.ndarray) -> np.ndarray:
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (ndarray[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        area (ndarray[N]): area for each box
    """
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
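
# Worked example (illustrative): a single box [0, 0, 2, 3] has width 2 and
# height 3, so box_area(np.array([[0, 0, 2, 3]])) returns array([6]).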


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: np.ndarray, boxes2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clip(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union


def box_iou(boxes1: np.ndarray, boxes2: np.ndarray) -> np.ndarray:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (ndarray[N, 4])
        boxes2 (ndarray[M, 4])

    Returns:
        iou (ndarray[N, M]): the NxM matrix containing the pairwise IoU values
            for every element in boxes1 and boxes2
    """
    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union
    return iou
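
# Worked example (illustrative): boxes [0, 0, 2, 2] and [1, 1, 3, 3] overlap in
# a 1x1 region, so intersection = 1 and union = 4 + 4 - 1 = 7, and
# box_iou(np.array([[0, 0, 2, 2]]), np.array([[1, 1, 3, 3]])) returns
# array([[0.14285714]]) (about 0.143).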


#### End of import of box utilities


def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]:
    """

    Return the boxes corresponding to the smallest enclosing box containing all the provided boxes

    The boxes are expected in [x1, y1, x2, y2] format

    """
    if len(boxes) == 1:
        return boxes

    np_boxes = np.asarray(boxes)

    return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]]
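
# Worked example (illustrative): _merge_boxes([[0, 0, 2, 2], [1, 1, 4, 3]])
# returns [[0, 0, 4, 3]], the tightest box covering both inputs.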


class RecallTracker:
    """Utility class to track recall@k for various k, split by categories"""

    def __init__(self, topk: Sequence[int]):
        """

        Parameters:

           - topk : tuple of ints corresponding to the recalls being tracked (eg, recall@1, recall@10, ...)

        """

        self.total_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}
        self.positives_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}

    def add_positive(self, k: int, category: str):
        """Log a positive hit @k for given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1
        self.positives_byk_bycat[k][category] += 1

    def add_negative(self, k: int, category: str):
        """Log a negative hit @k for given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1

    def report(self) -> Dict[int, Dict[str, float]]:
        """Return a condensed report of the results as a dict of dict.

        report[k][cat] is the recall@k for the given category

        """
        report: Dict[int, Dict[str, float]] = {}
        for k in self.total_byk_bycat:
            assert k in self.positives_byk_bycat
            report[k] = {
                cat: self.positives_byk_bycat[k][cat] / self.total_byk_bycat[k][cat] for cat in self.total_byk_bycat[k]
            }
        return report
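

# Usage sketch (hedged; values are illustrative): one hit and one miss at k=1
# give a recall@1 of 0.5 for the "all" category:
#   tracker = RecallTracker(topk=(1,))
#   tracker.add_positive(1, "all")
#   tracker.add_negative(1, "all")
#   tracker.report()  # -> {1: {"all": 0.5}}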


class Flickr30kEntitiesRecallEvaluator:
    def __init__(
        self,
        flickr_path: str,
        subset: str = "test",
        topk: Sequence[int] = (1, 5, 10, -1),
        iou_thresh: float = 0.5,
        merge_boxes: bool = False,
        verbose: bool = True,
    ):
        assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}"

        self.topk = topk
        self.iou_thresh = iou_thresh

        flickr_path = Path(flickr_path)

        # We load the image ids corresponding to the current subset
        with open(flickr_path / f"{subset}.txt") as file_d:
            self.img_ids = [line.strip() for line in file_d]

        if verbose:
            print(f"Flickr subset contains {len(self.img_ids)} images")

        # Read the box annotations for all the images
        self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {}

        if verbose:
            print("Loading annotations...")

        for img_id in self.img_ids:
            anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"]
            if merge_boxes:
                merged = {}
                for phrase_id, boxes in anno_info.items():
                    merged[phrase_id] = _merge_boxes(boxes)
                anno_info = merged
            self.imgid2boxes[img_id] = anno_info

        # Read the sentences annotations
        self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {}

        if verbose:
            print("Loading sentences...")

        self.all_ids: List[str] = []
        tot_phrases = 0
        for img_id in self.img_ids:
            sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt")
            self.imgid2sentences[img_id] = [None for _ in range(len(sentence_info))]

            # Some phrases don't have boxes; we filter them out.
            for sent_id, sentence in enumerate(sentence_info):
                phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in self.imgid2boxes[img_id]]
                if len(phrases) > 0:
                    self.imgid2sentences[img_id][sent_id] = phrases
                tot_phrases += len(phrases)

            self.all_ids += [
                f"{img_id}_{k}" for k in range(len(sentence_info)) if self.imgid2sentences[img_id][k] is not None
            ]

        if verbose:
            print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate")

    def evaluate(self, predictions: List[Dict]):
        evaluated_ids = set()

        recall_tracker = RecallTracker(self.topk)

        for pred in predictions:
            cur_id = f"{pred['image_id']}_{pred['sentence_id']}"
            if cur_id in evaluated_ids:
                print(
                    "Warning, multiple predictions found for sentence "
                    f"{pred['sentence_id']} in image {pred['image_id']}"
                )
                continue

            # Skip the sentences with no valid phrase
            if cur_id not in self.all_ids:
                if len(pred["boxes"]) != 0:
                    print(
                        f"Warning, in image {pred['image_id']} we were not expecting predictions "
                        f"for sentence {pred['sentence_id']}. Ignoring them."
                    )
                continue

            evaluated_ids.add(cur_id)

            pred_boxes = pred["boxes"]
            if str(pred["image_id"]) not in self.imgid2sentences:
                raise RuntimeError(f"Unknown image id {pred['image_id']}")
            if not 0 <= int(pred["sentence_id"]) < len(self.imgid2sentences[str(pred["image_id"])]):
                raise RuntimeError(f"Unknown sentence id {pred['sentence_id']}" f" in image {pred['image_id']}")
            target_sentence = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]

            phrases = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]
            if len(pred_boxes) != len(phrases):
                raise RuntimeError(
                    f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} "
                    f"for sentence {pred['sentence_id']} in image {pred['image_id']}"
                )

            for cur_boxes, phrase in zip(pred_boxes, phrases):
                target_boxes = self.imgid2boxes[str(pred["image_id"])][phrase["phrase_id"]]

                ious = box_iou(np.asarray(cur_boxes), np.asarray(target_boxes))
                for k in self.topk:
                    maxi = 0
                    if k == -1:
                        maxi = ious.max()
                    else:
                        assert k > 0
                        maxi = ious[:k].max()
                    if maxi >= self.iou_thresh:
                        recall_tracker.add_positive(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_positive(k, phrase_type)
                    else:
                        recall_tracker.add_negative(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_negative(k, phrase_type)

        if len(evaluated_ids) != len(self.all_ids):
            print("ERROR, the number of evaluated sentence doesn't match. Missing predictions:")
            un_processed = set(self.all_ids) - evaluated_ids
            for missing in un_processed:
                img_id, sent_id = missing.split("_")
                print(f"\t sentence {sent_id} in image {img_id}")
            raise RuntimeError("Missing predictions")

        return recall_tracker.report()


class FlickrEvaluator(object):
    def __init__(
        self,
        flickr_path,
        subset,
        top_k=(1, 5, 10, -1),
        iou_thresh=0.5,
        merge_boxes=False,
    ):
        assert isinstance(top_k, (list, tuple))

        self.evaluator = Flickr30kEntitiesRecallEvaluator(
            flickr_path, subset=subset, topk=top_k, iou_thresh=iou_thresh, merge_boxes=merge_boxes, verbose=False
        )
        self.predictions = []
        self.results = None

    def accumulate(self):
        pass

    def update(self, predictions):
        self.predictions += predictions

    def synchronize_between_processes(self):
        all_predictions = dist.all_gather(self.predictions)
        self.predictions = sum(all_predictions, [])

    def summarize(self):
        if dist.is_main_process():
            self.results = self.evaluator.evaluate(self.predictions)
            table = PrettyTable()
            all_cat = sorted(list(self.results.values())[0].keys())
            table.field_names = ["Recall@k"] + all_cat

            score = {}
            for k, v in self.results.items():
                cur_results = [v[cat] for cat in all_cat]
                header = "Upper_bound" if k == -1 else f"Recall@{k}"

                for cat in all_cat:
                    score[f"{header}_{cat}"] = v[cat]
                table.add_row([header] + cur_results)

            print(table)

            return score

        return None
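

# End-to-end usage sketch (hedged; the dataset path is hypothetical, and
# synchronize_between_processes assumes a distributed context is initialized):
#   evaluator = FlickrEvaluator("/path/to/flickr30k_entities", subset="val")
#   evaluator.update(predictions)  # list of per-sentence prediction dicts
#   evaluator.synchronize_between_processes()
#   score = evaluator.summarize()  # prints a PrettyTable, returns recall scores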