# Hosting-page residue (not part of the module): zdou0830 / desco @ 749745d, 9.63 kB.
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
import torch
import json
from PIL import Image, ImageDraw
from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset
from pycocotools.coco import COCO
from maskrcnn_benchmark.structures.bounding_box import BoxList
import random
from .od_to_grounding import convert_object_detection_to_grounding_optimized_for_od, check_for_positive_overflow, sanity_check_target_after_processing, od_to_grounding_optimized_streamlined
from ._od_to_description import DescriptionConverter
import pdb
from collections import defaultdict
class CocoDetectionTSV(ODTSVDataset):
    """TSV-backed object-detection dataset that emits grounding-style samples.

    Each item is loaded through ``ODTSVDataset.__getitem__`` and its detection
    target is rewritten into a caption plus box annotations by one of several
    "OD to grounding" converters, selected by ``od_to_grounding_version``:

    * contains ``"mixed"``: per sample, with probability 0.7 the description
      converter is used, otherwise the legacy converter;
    * contains ``"description"``: the ``DescriptionConverter`` instance;
    * any other value except ``"legacy"``:
      ``od_to_grounding_optimized_streamlined``;
    * ``"legacy"``: ``convert_object_detection_to_grounding_optimized_for_od``.

    ``self.ind_to_class``, ``self.label_tsv`` and ``self.get_line_no`` are not
    defined here; presumably they are provided by ``ODTSVDataset``.
    """

    def __init__(
        self,
        name,
        yaml_file,
        transforms,
        return_tokens,
        tokenizer,
        extra_fields,
        random_sample_negative=-1,
        add_detection_prompt=False,
        add_detection_prompt_advanced=False,
        use_od_data_aug=False,
        control_probabilities=None,
        disable_shuffle=False,
        prompt_engineer_version="v2",
        prompt_limit_negative=-1,
        positive_question_probability=0.6,
        negative_question_probability=0.8,
        full_question_probability=0.5,
        disable_clip_to_image=False,
        separation_tokens=" ",
        no_mask_for_od=False,
        max_num_labels=-1,
        max_query_len=256,
        od_to_grounding_version="legacy",
        description_file=None,
        similarity_file=None,
        **kwargs
    ):
        """Store the configuration and build the per-sample helpers.

        Args:
            name: Dataset name, stored as ``self.name``.
            yaml_file: Forwarded to ``ODTSVDataset.__init__``.
            transforms: Optional callable ``(img, target) -> (img, target)``
                applied at the end of ``__getitem__``; may be ``None``.
            return_tokens: Forwarded to ``ConvertCocoPolysToMask``.
            tokenizer: Tokenizer handed to the converters and to
                ``ConvertCocoPolysToMask``.
            extra_fields: Forwarded to ``ODTSVDataset.__init__``.
            random_sample_negative: Forwarded to the legacy and description
                converters.
            add_detection_prompt: Forwarded to the legacy converter.
            add_detection_prompt_advanced: Forwarded to the legacy converter.
            use_od_data_aug: Stored only; not read inside this class.
            control_probabilities: Mapping forwarded to the legacy converter.
                ``None`` (the default) is replaced by a fresh empty dict so
                that instances never share one mutable default object.
            disable_shuffle: Forwarded to the legacy converter.
            prompt_engineer_version: Stored only; not read inside this class.
            prompt_limit_negative: Stored only; not read inside this class.
            positive_question_probability: Stored only; not read here.
            negative_question_probability: Stored only; not read here.
            full_question_probability: Stored only; not read here.
            disable_clip_to_image: When true, boxes are not clipped to the
                image in ``__getitem__``.
            separation_tokens: Forwarded to the legacy converter.
            no_mask_for_od: When true, ``(-1, -1, -1)`` is appended to the
                greenlight span list of every sample.
            max_num_labels: Forwarded to the legacy converter.
            max_query_len: Query length budget; ``max_query_len - 2`` is the
                budget given to the caption to leave room for special tokens.
            od_to_grounding_version: Converter selector (see class docstring).
            description_file: First argument of ``DescriptionConverter``.
            similarity_file: Last argument of ``DescriptionConverter``.
            **kwargs: Forwarded to ``ODTSVDataset.__init__``.
        """
        super().__init__(yaml_file, extra_fields, **kwargs)

        self._transforms = transforms
        self.name = name
        self.tokenizer = tokenizer
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks=False,
            return_tokens=return_tokens,
            tokenizer=tokenizer,
            max_query_len=max_query_len,
        )

        # A literal ``{}`` default would be one dict shared by every instance.
        self.control_probabilities = {} if control_probabilities is None else control_probabilities
        self.random_sample_negative = random_sample_negative
        self.add_detection_prompt = add_detection_prompt
        self.add_detection_prompt_advanced = add_detection_prompt_advanced
        self.use_od_data_aug = use_od_data_aug
        self.prompt_engineer_version = prompt_engineer_version
        self.prompt_limit_negative = prompt_limit_negative
        self.positive_question_probability = positive_question_probability
        self.negative_question_probability = negative_question_probability
        self.full_question_probability = full_question_probability
        self.separation_tokens = separation_tokens
        self.disable_clip_to_image = disable_clip_to_image
        self.disable_shuffle = disable_shuffle
        self.no_mask_for_od = no_mask_for_od
        self.max_num_labels = max_num_labels
        self.od_to_grounding_version = od_to_grounding_version
        self.description_file = description_file
        self.similarity_file = similarity_file

        # NOTE(review): the "mixed" path in __getitem__ also needs this
        # converter, but it is only created when the version string contains
        # "description" -- confirm every "mixed" version string does too.
        if "description" in self.od_to_grounding_version:
            self.od_grounding_converter = DescriptionConverter(
                self.description_file,
                self.od_to_grounding_version,
                [],
                self.ind_to_class,
                self.similarity_file,
            )

        # Statistics container; nothing in this class writes to it.
        self.pos_rate = defaultdict(list)

    def __len__(self):
        """Return the number of samples reported by the base dataset."""
        return super().__len__()

    def categories(self, no_background=True):
        """Return a ``{category_id: category_name}`` dict.

        Args:
            no_background: When true (default), entries named
                ``"__background__"`` or with id ``0`` are left out.

        NOTE(review): reads ``self.coco``, which this class never sets --
        confirm that the base class or the caller provides it.
        """
        return {
            cat["id"]: cat["name"]
            for cat in self.coco.dataset["categories"]
            if not no_background or (cat["name"] != "__background__" and cat["id"] != 0)
        }

    def _legacy_od_to_grounding(self, target, image_id, positive_caption_length):
        """Run the legacy converter; returns its 4-tuple unchanged."""
        return convert_object_detection_to_grounding_optimized_for_od(
            target=target,
            image_id=image_id,
            ind_to_class=self.ind_to_class,
            disable_shuffle=self.disable_shuffle,
            add_detection_prompt=self.add_detection_prompt,
            add_detection_prompt_advanced=self.add_detection_prompt_advanced,
            random_sample_negative=self.random_sample_negative,
            control_probabilities=self.control_probabilities,
            restricted_negative_list=None,
            separation_tokens=self.separation_tokens,
            max_num_labels=self.max_num_labels,
            positive_caption_length=positive_caption_length,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_query_len - 2,
        )

    def _description_od_to_grounding(self, target, image_id):
        """Run the description converter; returns its 5-tuple unchanged."""
        return self.od_grounding_converter.train_od_to_grounding(
            target=target,
            image_id=image_id,
            ind_to_class=self.ind_to_class,
            tokenizer=self.tokenizer,
            random_sample_negative=self.random_sample_negative,
        )

    def __getitem__(self, idx):
        """Load sample ``idx`` and convert it into a grounding sample.

        Returns:
            ``(img, target, idx)``. ``target`` is the (possibly transformed)
            detection target -- a BoxList according to the original comment --
            with every key/value of the dict returned by ``self.prepare``
            attached through ``add_field``.
        """
        img, target, _, _ = super().__getitem__(idx)
        image_id = self.get_img_id(idx)

        if not self.disable_clip_to_image:
            target = target.clip_to_image(remove_empty=True)

        # Drop boxes whose positive caption would not fit; two positions are
        # reserved for the special tokens.
        original_box_num = len(target)
        target, positive_caption_length = check_for_positive_overflow(
            target, self.ind_to_class, self.tokenizer, self.max_query_len - 2
        )
        if len(target) < original_box_num:
            print("WARNING: removed {} boxes due to positive caption overflow".format(original_box_num - len(target)))

        version = self.od_to_grounding_version
        if "mixed" in version:
            # 70% description-based captions vs. 30% legacy captions.
            if random.random() < 0.7:
                (
                    annotations,
                    caption,
                    greenlight_span_for_masked_lm_objective,
                    label_to_positions,
                    target,
                ) = self._description_od_to_grounding(target, image_id)
            else:
                (
                    annotations,
                    caption,
                    greenlight_span_for_masked_lm_objective,
                    label_to_positions,
                ) = self._legacy_od_to_grounding(target, image_id, positive_caption_length)
        elif "description" in version:
            (
                annotations,
                caption,
                greenlight_span_for_masked_lm_objective,
                label_to_positions,
                target,
            ) = self._description_od_to_grounding(target, image_id)
        elif version != "legacy":
            (
                annotations,
                caption,
                greenlight_span_for_masked_lm_objective,
                label_to_positions,
                target,
            ) = od_to_grounding_optimized_streamlined(
                target=target,
                image_id=image_id,
                ind_to_class=self.ind_to_class,
                tokenizer=self.tokenizer,
                od_to_grounding_version=version,
            )
        else:
            (
                annotations,
                caption,
                greenlight_span_for_masked_lm_objective,
                label_to_positions,
            ) = self._legacy_od_to_grounding(target, image_id, positive_caption_length)

        anno = {
            "image_id": image_id,
            "annotations": annotations,
            "caption": caption,
            "label_to_positions": label_to_positions,
        }

        if "spans" in target.extra_fields:
            spans = target.extra_fields["spans"]
            # Anything that is not already a list is converted via .tolist().
            anno["spans"] = spans if isinstance(spans, list) else spans.tolist()

        anno["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        if self.no_mask_for_od:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno, box_format="xyxy")

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Attach everything produced by self.prepare as extra target fields.
        for key in anno:
            target.add_field(key, anno[key])

        return img, target, idx

    def get_raw_image(self, idx):
        """Return only the image (first element) of the base dataset item."""
        image, *_ = super().__getitem__(idx)
        return image

    def get_img_id(self, idx):
        """Return the image id stored in the label TSV row for ``idx``.

        Returns:
            ``int(row[0])`` of the label row when it parses as an integer,
            ``idx`` when it does not, and ``None`` when there is no label TSV.
        """
        line_no = self.get_line_no(idx)
        if self.label_tsv is None:
            return None

        img_id = self.label_tsv.seek(line_no)[0]
        try:
            return int(img_id)
        except (TypeError, ValueError):
            # Non-numeric id: fall back to the dataset index.
            return idx