from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
import json
import numpy as np
import os.path as osp
import os
from prettytable import PrettyTable
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import maskrcnn_benchmark.utils.mdetr_dist as dist

#### The following loading utilities are imported from
#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py
# Changelog:
# - Added typing information
# - Completed docstrings


def get_sentence_data(filename) -> List[Dict[str, Any]]:
    """
    Parses a sentence file from the Flickr30K Entities dataset

    input:
      filename - full file path to the sentence file to parse

    output:
      a list of dictionaries for each sentence with the following fields:
          sentence - the original sentence
          phrases - a list of dictionaries for each phrase with the
                    following fields:
                      phrase - the text of the annotated phrase
                      first_word_index - the position of the first word of
                                         the phrase in the sentence
                      phrase_id - an identifier for this phrase
                      phrase_type - a list of the coarse categories this
                                    phrase belongs to
    """
    with open(filename, "r") as f:
        sentences = f.read().split("\n")

    annotations = []
    for sentence in sentences:
        if not sentence:
            continue

        first_word = []
        phrases = []
        phrase_id = []
        phrase_type = []
        words = []
        current_phrase = []
        add_to_phrase = False
        for token in sentence.split():
            if add_to_phrase:
                # Inside an annotated phrase: a trailing "]" closes it.
                if token[-1] == "]":
                    add_to_phrase = False
                    token = token[:-1]
                    current_phrase.append(token)
                    phrases.append(" ".join(current_phrase))
                    current_phrase = []
                else:
                    current_phrase.append(token)

                words.append(token)
            else:
                # A leading "[" opens a phrase; the marker token encodes
                # "/EN#<id>/<type1>/<type2>..." metadata and is not a word.
                if token[0] == "[":
                    add_to_phrase = True
                    first_word.append(len(words))
                    parts = token.split("/")
                    phrase_id.append(parts[1][3:])
                    phrase_type.append(parts[2:])
                else:
                    words.append(token)

        sentence_data = {"sentence": " ".join(words), "phrases": []}
        for index, phrase, p_id, p_type in zip(first_word, phrases, phrase_id, phrase_type):
            sentence_data["phrases"].append(
                {"first_word_index": index, "phrase": phrase, "phrase_id": p_id, "phrase_type": p_type}
            )

        annotations.append(sentence_data)

    return annotations


def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]:
    """
    Parses the xml files in the Flickr30K Entities dataset

    input:
      filename - full file path to the annotations file to parse

    output:
      dictionary with the following fields:
          scene - list of identifiers which were annotated as
                  pertaining to the whole scene
          nobox - list of identifiers which were annotated as
                  not being visible in the image
          boxes - a dictionary where the fields are identifiers
                  and the values are its list of boxes in the
                  [xmin ymin xmax ymax] format
          height - int representing the height of the image
          width - int representing the width of the image
          depth - int representing the depth of the image
    """
    tree = ET.parse(filename)
    root = tree.getroot()
    size_container = root.findall("size")[0]
    anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {}
    all_boxes: Dict[str, List[List[int]]] = {}
    all_noboxes: List[str] = []
    all_scenes: List[str] = []
    # <size> children are <width>, <height>, <depth>.
    for size_element in size_container:
        assert size_element.text
        anno_info[size_element.tag] = int(size_element.text)

    for object_container in root.findall("object"):
        for names in object_container.findall("name"):
            box_id = names.text
            assert box_id
            box_container = object_container.findall("bndbox")
            if len(box_container) > 0:
                if box_id not in all_boxes:
                    all_boxes[box_id] = []
                xmin = int(box_container[0].findall("xmin")[0].text)
                ymin = int(box_container[0].findall("ymin")[0].text)
                xmax = int(box_container[0].findall("xmax")[0].text)
                ymax = int(box_container[0].findall("ymax")[0].text)
                all_boxes[box_id].append([xmin, ymin, xmax, ymax])
            else:
                # No bounding box: the entity is either flagged as not
                # visible (<nobndbox>) or as describing the whole scene.
                nobndbox = int(object_container.findall("nobndbox")[0].text)
                if nobndbox > 0:
                    all_noboxes.append(box_id)

                scene = int(object_container.findall("scene")[0].text)
                if scene > 0:
                    all_scenes.append(box_id)
    anno_info["boxes"] = all_boxes
    anno_info["nobox"] = all_noboxes
    anno_info["scene"] = all_scenes

    return anno_info


#### END of import from flickr30k_entities


#### Bounding box utilities imported from torchvision and converted to numpy
def box_area(boxes: np.array) -> np.array:
    """
    Computes the area of a set of bounding boxes, which are specified by its
    (x1, y1, x2, y2) coordinates.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        area (Tensor[N]): area for each box
    """
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]:
    """Return the pairwise intersection and union areas of two box sets."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clip(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    return inter, union


def box_iou(boxes1: np.array, boxes2: np.array) -> np.array:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1 (Tensor[N, 4])
        boxes2 (Tensor[M, 4])

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values
            for every element in boxes1 and boxes2
    """
    inter, union = _box_inter_union(boxes1, boxes2)
    iou = inter / union
    return iou


#### End of import of box utilities


def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]:
    """
    Return the boxes corresponding to the smallest enclosing box containing all the provided boxes
    The boxes are expected in [x1, y1, x2, y2] format
    """
    if len(boxes) == 1:
        return boxes

    np_boxes = np.asarray(boxes)

    return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]]


class RecallTracker:
    """Utility class to track recall@k for various k, split by categories"""

    def __init__(self, topk: Sequence[int]):
        """
        Parameters:
           - topk : tuple of ints corresponding to the recalls being tracked (eg, recall@1, recall@10, ...)
        """
        # For each k, count of evaluated phrases per category...
        self.total_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}
        # ...and count of hits per category; recall@k = positives / total.
        self.positives_byk_bycat: Dict[int, Dict[str, int]] = {k: defaultdict(int) for k in topk}

    def add_positive(self, k: int, category: str):
        """Log a positive hit @k for given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1
        self.positives_byk_bycat[k][category] += 1

    def add_negative(self, k: int, category: str):
        """Log a negative hit @k for given category"""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1

    def report(self) -> Dict[int, Dict[str, float]]:
        """Return a condensed report of the results as a dict of dict.
        report[k][cat] is the recall@k for the given category
        """
        report: Dict[int, Dict[str, float]] = {}
        for k in self.total_byk_bycat:
            assert k in self.positives_byk_bycat
            report[k] = {
                cat: self.positives_byk_bycat[k][cat] / self.total_byk_bycat[k][cat]
                for cat in self.total_byk_bycat[k]
            }
        return report


class Flickr30kEntitiesRecallEvaluator:
    """Computes phrase-grounding recall@k on the Flickr30k Entities dataset.

    Loads the box and sentence annotations for the requested subset at
    construction time; ``evaluate`` then scores a list of predictions.
    """

    def __init__(
        self,
        flickr_path: str,
        subset: str = "test",
        topk: Sequence[int] = (1, 5, 10, -1),
        iou_thresh: float = 0.5,
        merge_boxes: bool = False,
        verbose: bool = True,
    ):
        assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}"

        self.topk = topk
        self.iou_thresh = iou_thresh

        flickr_path = Path(flickr_path)

        # We load the image ids corresponding to the current subset
        with open(flickr_path / f"{subset}.txt") as file_d:
            self.img_ids = [line.strip() for line in file_d]

        if verbose:
            print(f"Flickr subset contains {len(self.img_ids)} images")

        # Read the box annotations for all the images
        self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {}

        if verbose:
            print("Loading annotations...")

        for img_id in self.img_ids:
            anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"]
            if merge_boxes:
                # Replace each phrase's box list by its single enclosing box.
                merged = {}
                for phrase_id, boxes in anno_info.items():
                    merged[phrase_id] = _merge_boxes(boxes)
                anno_info = merged
            self.imgid2boxes[img_id] = anno_info

        # Read the sentences annotations
        self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {}

        if verbose:
            print("Loading annotations...")

        self.all_ids: List[str] = []
        tot_phrases = 0
        for img_id in self.img_ids:
            sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt")
            self.imgid2sentences[img_id] = [None for _ in range(len(sentence_info))]

            # Some phrases don't have boxes, we filter them.
            for sent_id, sentence in enumerate(sentence_info):
                phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in self.imgid2boxes[img_id]]
                if len(phrases) > 0:
                    self.imgid2sentences[img_id][sent_id] = phrases
                tot_phrases += len(phrases)

            # Sentences that kept at least one phrase are the ones we expect
            # predictions for, identified as "<img_id>_<sent_id>".
            self.all_ids += [
                f"{img_id}_{k}" for k in range(len(sentence_info)) if self.imgid2sentences[img_id][k] is not None
            ]

        if verbose:
            print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate")

    def evaluate(self, predictions: List[Dict]):
        """Score predictions and return a recall report (see RecallTracker.report).

        Each prediction is a dict with keys "image_id", "sentence_id" and
        "boxes" — one list of candidate boxes per annotated phrase of the
        sentence, ranked by confidence.

        Raises RuntimeError on unknown ids, mismatched phrase counts, or
        missing predictions.
        """
        evaluated_ids = set()

        recall_tracker = RecallTracker(self.topk)

        for pred in predictions:
            cur_id = f"{pred['image_id']}_{pred['sentence_id']}"
            if cur_id in evaluated_ids:
                # Fixed: added a trailing space so the message doesn't read
                # "for sentence12345".
                print(
                    "Warning, multiple predictions found for sentence "
                    f"{pred['sentence_id']} in image {pred['image_id']}"
                )
                continue

            # Skip the sentences with no valid phrase
            if cur_id not in self.all_ids:
                if len(pred["boxes"]) != 0:
                    print(
                        f"Warning, in image {pred['image_id']} we were not expecting predictions "
                        f"for sentence {pred['sentence_id']}. Ignoring them."
                    )
                continue

            evaluated_ids.add(cur_id)

            pred_boxes = pred["boxes"]
            if str(pred["image_id"]) not in self.imgid2sentences:
                raise RuntimeError(f"Unknown image id {pred['image_id']}")
            if not 0 <= int(pred["sentence_id"]) < len(self.imgid2sentences[str(pred["image_id"])]):
                raise RuntimeError(f"Unknown sentence id {pred['sentence_id']}" f" in image {pred['image_id']}")

            # Fixed: single lookup (the original also bound an unused
            # duplicate `target_sentence`).
            phrases = self.imgid2sentences[str(pred["image_id"])][int(pred["sentence_id"])]
            if len(pred_boxes) != len(phrases):
                raise RuntimeError(
                    f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} "
                    f"for sentence {pred['sentence_id']} in image {pred['image_id']}"
                )

            for cur_boxes, phrase in zip(pred_boxes, phrases):
                target_boxes = self.imgid2boxes[str(pred["image_id"])][phrase["phrase_id"]]

                ious = box_iou(np.asarray(cur_boxes), np.asarray(target_boxes))
                for k in self.topk:
                    # k == -1 means "upper bound": consider every candidate;
                    # otherwise only the k highest-ranked candidates.
                    if k == -1:
                        maxi = ious.max()
                    else:
                        assert k > 0
                        maxi = ious[:k].max()
                    if maxi >= self.iou_thresh:
                        recall_tracker.add_positive(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_positive(k, phrase_type)
                    else:
                        recall_tracker.add_negative(k, "all")
                        for phrase_type in phrase["phrase_type"]:
                            recall_tracker.add_negative(k, phrase_type)

        if len(evaluated_ids) != len(self.all_ids):
            print("ERROR, the number of evaluated sentence doesn't match. Missing predictions:")
            un_processed = set(self.all_ids) - evaluated_ids
            for missing in un_processed:
                img_id, sent_id = missing.split("_")
                print(f"\t sentence {sent_id} in image {img_id}")
            raise RuntimeError("Missing predictions")

        return recall_tracker.report()


class FlickrEvaluator(object):
    """Distributed-evaluation wrapper around Flickr30kEntitiesRecallEvaluator.

    Accumulates predictions across processes, then computes and pretty-prints
    the recall table on the main process.
    """

    def __init__(
        self,
        flickr_path,
        subset,
        top_k=(1, 5, 10, -1),
        iou_thresh=0.5,
        merge_boxes=False,
    ):
        assert isinstance(top_k, (list, tuple))

        self.evaluator = Flickr30kEntitiesRecallEvaluator(
            flickr_path, subset=subset, topk=top_k, iou_thresh=iou_thresh, merge_boxes=merge_boxes, verbose=False
        )
        self.predictions = []
        self.results = None

    def accumulate(self):
        # Nothing to pre-aggregate; kept for evaluator-interface compatibility.
        pass

    def update(self, predictions):
        self.predictions += predictions

    def synchronize_between_processes(self):
        # Gather every process's predictions and flatten into a single list.
        all_predictions = dist.all_gather(self.predictions)
        self.predictions = sum(all_predictions, [])

    def summarize(self):
        """Evaluate on the main process; returns the score dict there,
        (None, None) on other processes."""
        if dist.is_main_process():
            self.results = self.evaluator.evaluate(self.predictions)
            table = PrettyTable()
            all_cat = sorted(list(self.results.values())[0].keys())
            table.field_names = ["Recall@k"] + all_cat

            score = {}
            for k, v in self.results.items():
                cur_results = [v[cat] for cat in all_cat]
                header = "Upper_bound" if k == -1 else f"Recall@{k}"
                for cat in all_cat:
                    score[f"{header}_{cat}"] = v[cat]
                table.add_row([header] + cur_results)
            print(table)

            return score

        return None, None