""" COCO dataset which returns image_id for evaluation. Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py """ import torch import json from PIL import Image, ImageDraw from .modulated_coco import ConvertCocoPolysToMask from .tsv import ODTSVDataset from pycocotools.coco import COCO from maskrcnn_benchmark.structures.bounding_box import BoxList import random from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing, od_to_grounding_optimized_streamlined from ._od_to_description import DescriptionConverter import pdb from collections import defaultdict class CocoDetectionTSV(ODTSVDataset): def __init__( self, name, yaml_file, transforms, return_tokens, tokenizer, extra_fields, random_sample_negative=-1, add_detection_prompt=False, add_detection_prompt_advanced=False, use_od_data_aug=False, control_probabilities={}, disable_shuffle=False, prompt_engineer_version="v2", prompt_limit_negative=-1, positive_question_probability=0.6, negative_question_probability=0.8, full_question_probability=0.5, disable_clip_to_image=False, separation_tokens=" ", no_mask_for_od=False, max_num_labels=-1, max_query_len=256, od_to_grounding_version="legacy", description_file = None, similarity_file = None, **kwargs ): super(CocoDetectionTSV, self).__init__(yaml_file, extra_fields, **kwargs) self._transforms = transforms self.name = name self.max_query_len = max_query_len self.prepare = ConvertCocoPolysToMask( return_masks=False, return_tokens=return_tokens, tokenizer=tokenizer, max_query_len=max_query_len ) self.tokenizer = tokenizer self.control_probabilities = control_probabilities self.random_sample_negative = random_sample_negative self.add_detection_prompt = add_detection_prompt self.add_detection_prompt_advanced = add_detection_prompt_advanced self.use_od_data_aug = use_od_data_aug self.prompt_engineer_version = prompt_engineer_version self.prompt_limit_negative = prompt_limit_negative self.positive_question_probability = positive_question_probability self.negative_question_probability = negative_question_probability self.full_question_probability = full_question_probability self.separation_tokens = separation_tokens self.disable_clip_to_image = disable_clip_to_image self.disable_shuffle = disable_shuffle self.no_mask_for_od = no_mask_for_od self.max_num_labels = max_num_labels self.od_to_grounding_version = od_to_grounding_version self.description_file = description_file self.similarity_file = similarity_file if "description" in self.od_to_grounding_version: self.od_grounding_converter = DescriptionConverter( self.description_file, self.od_to_grounding_version, [], self.ind_to_class, self.similarity_file,) ### stat self.pos_rate = defaultdict(list) def __len__(self): return super(CocoDetectionTSV, self).__len__() def categories(self, no_background=True): categories = self.coco.dataset["categories"] label_list = {} for index, i in enumerate(categories): # assert(index + 1 == i["id"]) if not no_background or (i["name"] != "__background__" and i["id"] != 0): label_list[i["id"]] = i["name"] return label_list def __getitem__(self, idx): # tgt is a BoxList img, target, _, scale = super(CocoDetectionTSV, self).__getitem__(idx) image_id = self.get_img_id(idx) restricted_negative_list = None if not self.disable_clip_to_image: target = target.clip_to_image(remove_empty=True) original_box_num = len(target) target, positive_caption_length = check_for_positive_overflow( target, self.ind_to_class, self.tokenizer, self.max_query_len - 2 ) # leave some space for the special tokens if len(target) < original_box_num: print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target))) if "mixed" in self.od_to_grounding_version: # 70% v.s. 30% if random.random() < 0.7: annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions, target = self.od_grounding_converter.train_od_to_grounding( target=target, image_id=image_id, ind_to_class=self.ind_to_class, tokenizer=self.tokenizer, random_sample_negative=self.random_sample_negative, ) else: annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( target=target, image_id=image_id, ind_to_class=self.ind_to_class, disable_shuffle=self.disable_shuffle, add_detection_prompt=self.add_detection_prompt, add_detection_prompt_advanced=self.add_detection_prompt_advanced, random_sample_negative=self.random_sample_negative, control_probabilities=self.control_probabilities, restricted_negative_list=restricted_negative_list, separation_tokens=self.separation_tokens, max_num_labels=self.max_num_labels, positive_caption_length=positive_caption_length, tokenizer=self.tokenizer, max_seq_length=self.max_query_len - 2, ) elif "description" in self.od_to_grounding_version: annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions, target = self.od_grounding_converter.train_od_to_grounding( target=target, image_id=image_id, ind_to_class=self.ind_to_class, tokenizer=self.tokenizer, random_sample_negative=self.random_sample_negative, ) elif self.od_to_grounding_version != "legacy": annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions, target = od_to_grounding_optimized_streamlined( target=target, image_id=image_id, ind_to_class=self.ind_to_class, tokenizer=self.tokenizer, od_to_grounding_version=self.od_to_grounding_version, ) else: annotations, caption, greenlight_span_for_masked_lm_objective, label_to_positions = convert_object_detection_to_grounding_optimized_for_od( target=target, image_id=image_id, ind_to_class=self.ind_to_class, disable_shuffle=self.disable_shuffle, add_detection_prompt=self.add_detection_prompt, add_detection_prompt_advanced=self.add_detection_prompt_advanced, random_sample_negative=self.random_sample_negative, control_probabilities=self.control_probabilities, restricted_negative_list=restricted_negative_list, separation_tokens=self.separation_tokens, max_num_labels=self.max_num_labels, positive_caption_length=positive_caption_length, tokenizer=self.tokenizer, max_seq_length=self.max_query_len - 2, ) # assert(len(self.tokenizer.tokenize(caption)) <= self.max_query_len-2) anno = { "image_id": image_id, "annotations": annotations, "caption": caption, "label_to_positions": label_to_positions, } if "spans" in target.extra_fields: anno["spans"] = target.extra_fields["spans"] if not isinstance(anno["spans"], list): anno["spans"] = anno["spans"].tolist() anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective if self.no_mask_for_od: anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1)) img, anno = self.prepare(img, anno, box_format="xyxy") if self._transforms is not None: img, target = self._transforms(img, target) # add additional property for ann in anno: target.add_field(ann, anno[ann]) # sanity_check_target_after_processing(target) return img, target, idx def get_raw_image(self, idx): image, *_ = super(CocoDetectionTSV, self).__getitem__(idx) return image def get_img_id(self, idx): line_no = self.get_line_no(idx) if self.label_tsv is not None: row = self.label_tsv.seek(line_no) img_id = row[0] try: return int(img_id) except: return idx