""" COCO dataset which returns image_id for evaluation. Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py """ import torch import json from PIL import Image, ImageDraw from .modulated_coco import ConvertCocoPolysToMask from .tsv import ODTSVDataset from pycocotools.coco import COCO from maskrcnn_benchmark.structures.bounding_box import BoxList import random from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing class CocoDetectionTSV(ODTSVDataset): def __init__(self, name, yaml_file, transforms, return_tokens, tokenizer, extra_fields, random_sample_negative=-1, add_detection_prompt=False, add_detection_prompt_advanced=False, use_od_data_aug=False, control_probabilities={}, disable_shuffle=False, prompt_engineer_version="v2", prompt_limit_negative=-1, positive_question_probability=0.6, negative_question_probability=0.8, full_question_probability=0.5, disable_clip_to_image=False, separation_tokens=" ", no_mask_for_od=False, max_num_labels=-1, max_query_len=256, **kwargs ): super(CocoDetectionTSV, self).__init__(yaml_file, extra_fields, **kwargs) self._transforms = transforms self.name = name self.max_query_len = max_query_len self.prepare = ConvertCocoPolysToMask( return_masks=False, return_tokens=return_tokens, tokenizer=tokenizer, max_query_len=max_query_len ) self.tokenizer = tokenizer self.control_probabilities = control_probabilities self.random_sample_negative = random_sample_negative self.add_detection_prompt = add_detection_prompt self.add_detection_prompt_advanced = add_detection_prompt_advanced self.use_od_data_aug = use_od_data_aug self.prompt_engineer_version = prompt_engineer_version self.prompt_limit_negative = prompt_limit_negative self.positive_question_probability = positive_question_probability self.negative_question_probability = negative_question_probability self.full_question_probability = full_question_probability self.separation_tokens = separation_tokens self.disable_clip_to_image = disable_clip_to_image self.disable_shuffle = disable_shuffle self.no_mask_for_od = no_mask_for_od self.max_num_labels = max_num_labels def __len__(self): return super(CocoDetectionTSV, self).__len__() def categories(self, no_background=True): categories = self.coco.dataset["categories"] label_list = {} for index, i in enumerate(categories): # assert(index + 1 == i["id"]) if not no_background or (i["name"] != "__background__" and i['id'] != 0): label_list[i["id"]] = i["name"] return label_list def __getitem__(self, idx): # tgt is a BoxList img, target, _, scale = super(CocoDetectionTSV, self).__getitem__(idx) image_id = self.get_img_id(idx) restricted_negative_list = None if not self.disable_clip_to_image: target = target.clip_to_image(remove_empty=True) original_box_num = len(target) target, positive_caption_length = check_for_positive_overflow(target, self.ind_to_class, self.tokenizer, self.max_query_len-2) # leave some space for the special tokens if len(target) < original_box_num: print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target))) annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( target=target, image_id=image_id, ind_to_class=self.ind_to_class, disable_shuffle=self.disable_shuffle, add_detection_prompt=self.add_detection_prompt, add_detection_prompt_advanced=self.add_detection_prompt_advanced, random_sample_negative=self.random_sample_negative, control_probabilities=self.control_probabilities, restricted_negative_list=restricted_negative_list, separation_tokens=self.separation_tokens, max_num_labels=self.max_num_labels, positive_caption_length=positive_caption_length, tokenizer=self.tokenizer, max_seq_length=self.max_query_len-2 ) # assert(len(self.tokenizer.tokenize(caption)) <= self.max_query_len-2) # print(caption) anno = {"image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions": label_to_positions} anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective if self.no_mask_for_od: anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) img, anno = self.prepare(img, anno, box_format="xyxy") if self._transforms is not None: img, target = self._transforms(img, target) # add additional property for ann in anno: target.add_field(ann, anno[ann]) sanity_check_target_after_processing(target) return img, target, idx def get_raw_image(self, idx): image, *_ = super(CocoDetectionTSV, self).__getitem__(idx) return image def get_img_id(self, idx): line_no = self.get_line_no(idx) if self.label_tsv is not None: row = self.label_tsv.seek(line_no) img_id = row[0] try: return int(img_id) except: return idx