import torch
import torch.distributed as dist
import time
from torchvision.ops import nms
import random
import numpy as np
from PIL import Image, ImageDraw
import pdb
from maskrcnn_benchmark.structures.bounding_box import BoxList
from .modulated_coco import ConvertCocoPolysToMask
from .tsv import ODTSVDataset, TSVYamlDataset
from .od_to_grounding import sanity_check_target_after_processing
from copy import deepcopy


class PseudoData(TSVYamlDataset):
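    """
    TSV/YAML-backed dataset of pseudo-grounding (caption + pseudo box) data.

    Each sample carries an image, a caption, and pseudo boxes whose
    "tokens_positive" spans point into the caption. __getitem__ optionally
    filters boxes by confidence and NMS, re-points spans to a mask token for
    QA-style data (diver_box_for_vqa), tokenizes the caption via
    ConvertCocoPolysToMask, and returns (image, BoxList target, index).
    """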

    def __init__(
        self,
        yaml_file,
        transforms,
        return_tokens,
        return_masks,
        tokenizer,
        caption_min_box=1,
        replace_clean_label=False,
        further_screen=False,
        caption_conf=0.5,
        caption_nms=-1,
        pack_random_caption_number=0,
        inference_caption=False,
        sample_negative_for_grounding_data=-1,
        random_pack_prob=-1.0,
        no_random_pack_probability=0.0,
        safeguard_positive_caption=True,
        mlm_obj_for_only_positive=False,
        caption_format_version="v1",
        local_debug=False,
        max_query_len=256,
        diver_box_for_vqa=False,
        **kwargs
    ):
        super(PseudoData, self).__init__(yaml_file, None, replace_clean_label)
        self.yaml_file = yaml_file
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks=return_masks, return_tokens=return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        self.diver_box_for_vqa = diver_box_for_vqa
        if "qa" in self.yaml_file:
            assert self.diver_box_for_vqa  # QA data must divert boxes to the mask token
        self.tokenizer = tokenizer
        self.caption_min_box = caption_min_box
        self.replace_clean_label = replace_clean_label
        self.further_screen = further_screen
        self.pack_random_caption_number = pack_random_caption_number
        self.caption_format_version = caption_format_version
        self.caption_conf = caption_conf
        self.caption_nms = caption_nms
        self.inference_caption = inference_caption
        self.sample_negative_for_grounding_data = sample_negative_for_grounding_data
        self.random_pack_prob = random_pack_prob
        self.no_random_pack_probability = no_random_pack_probability
        self.safeguard_positive_caption = safeguard_positive_caption
        self.mlm_obj_for_only_positive = mlm_obj_for_only_positive
        self.local_debug = local_debug
        try:
            self.rank = dist.get_rank()
        except:
            # torch.distributed is not initialized (e.g. single-process runs)
            self.rank = 0

    def __len__(self):
        return super(PseudoData, self).__len__()
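
    # check_for_overlap treats its arguments as inclusive [start, end] character
    # spans; e.g. (hypothetical values) [0, 5] and [5, 9] overlap, while
    # [0, 5] and [6, 9] do not.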
    def check_for_overlap(self, range1, range2):
        if range1[0] > range2[1] or range2[0] > range1[1]:
            return False
        return True
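
    # divert_boxes rewrites a QA-style caption of the form "<question> <answer>"
    # into "<question> <mask token>" and re-points every box whose token span
    # overlapped the answer at the mask span. Illustrative (hypothetical) example
    # with a BERT-style tokenizer whose mask token is "[MASK]":
    #   anno["text"]    = "what color is the car?"
    #   anno["caption"] = "what color is the car? red"
    # becomes caption "what color is the car? [MASK]", and boxes grounded on
    # "red" get tokens_positive = [[mask_start, mask_end]].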
    def divert_boxes(self, anno):
        # first get the answer start and end
        answer_start = len(anno["text"]) + 1  # +1 for the space
        answer_end = len(anno["caption"])
        question = anno["caption"][:answer_start]  # get the question
        mask_start = len(question)

        # add the mask token
        mask_token = self.tokenizer.mask_token
        if mask_token is None:
            mask_token = "answer"
        question += mask_token
        mask_end = len(question)

        # divert the boxes
        for i in range(len(anno["bboxes"])):
            # check overlap
            for j in range(len(anno["tokens_positive"][i])):
                if self.check_for_overlap(anno["tokens_positive"][i][j], [answer_start, answer_end]):
                    # if it overlaps the answer, divert the box to the mask token
                    anno["tokens_positive"][i][j] = [mask_start, mask_end]

        anno["caption"] = question
        return question, anno

    def __getitem__(self, idx):
        img, anno, _, scale = super(PseudoData, self).__getitem__(idx)
        if self.inference_caption:
            caption = None
            if isinstance(anno, list):
                caption = anno[0]["caption"]  # inference mode for bing
                anno = []
            elif len(anno) == 1:
                caption = anno["caption"]  # inference mode for googlecc
                anno = []
            else:
                caption = " ".join(anno["captions"])
                anno = []
        else:
            if self.caption_format_version == "v2":
                anno = self.convert_anno_from_yiling_to_ours(anno)

            if self.further_screen:
                # drop low-confidence pseudo boxes and (optionally) run NMS
                conf = self.caption_conf
                nms_thre = self.caption_nms

                bboxes = torch.as_tensor(anno["bboxes"]).float()
                scores = torch.as_tensor(anno["scores"])
                tokens_positive = anno["tokens_positive"]

                keep = scores > conf
                scores = scores[keep]
                bboxes = bboxes[keep]
                tokens_positive = [i for index, i in enumerate(tokens_positive) if keep[index]]

                assert len(tokens_positive) == len(bboxes) == len(scores)

                if len(bboxes) < self.caption_min_box:  # Retry triggered!
                    return self[np.random.choice(len(self))]

                if nms_thre > 0:
                    keep = nms(boxes=bboxes, scores=scores, iou_threshold=nms_thre)
                    scores = scores[keep]
                    bboxes = bboxes[keep]
                    tokens_positive = [tokens_positive[i] for i in keep]
                    assert len(tokens_positive) == len(bboxes) == len(scores)

                # Write back
                anno["bboxes"] = bboxes.tolist()
                anno["scores"] = scores.tolist()
                anno["tokens_positive"] = tokens_positive

            boxes = torch.as_tensor(anno["bboxes"])

            if len(boxes) < self.caption_min_box:  # Retry triggered!
                return self[np.random.choice(len(self))]

            target = BoxList(boxes, (anno["img_w"], anno["img_h"]), mode="xyxy")
            target = target.clip_to_image(remove_empty=True)

            if self.diver_box_for_vqa:
                caption, anno = self.divert_boxes(anno=anno)  # changes the caption and "tokens_positive"

            caption = anno["caption"]
            greenlight_span_for_masked_lm_objective = [(0, len(caption))]

            # re-package the boxes into per-instance COCO-style annotations
            new_anno = []
            areas = target.area()
            for i in range(len(target)):
                new_anno_i = {}
                new_anno_i["area"] = areas[i]
                new_anno_i["iscrowd"] = 0
                new_anno_i["image_id"] = idx
                new_anno_i["category_id"] = 1  # following vg and others
                new_anno_i["id"] = None
                new_anno_i["bbox"] = target.bbox[i].numpy().tolist()
                new_anno_i["tokens_positive"] = anno["tokens_positive"][i]
                new_anno.append(new_anno_i)

            anno = new_anno

        annotations = {"image_id": idx, "annotations": anno, "caption": caption}
        annotations["greenlight_span_for_masked_lm_objective"] = greenlight_span_for_masked_lm_objective
        img, annotations = self.prepare(img, annotations, box_format="xyxy")

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional properties
        for ann in annotations:
            target.add_field(ann, annotations[ann])

        # This is the real image_id
        image_id = self.get_img_id(idx)

        # Can insert additional fields into target if needed
        sanity_check_target_after_processing(target)

        return img, target, idx
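
    # Assumed "v2" annotation layout (inferred from the indexing below, shown
    # with hypothetical values): per-entity nested lists
    #   bboxes=[[[0, 0, 10, 10], [5, 5, 20, 20]]], scores=[[0.9, 0.8]],
    #   tokens_positive=[[[0, 3]]]
    # are flattened into one entry per box, each box keeping all token spans of
    # its entity: bboxes=[[0,0,10,10],[5,5,20,20]], tokens_positive=[[[0,3]],[[0,3]]].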
    def convert_anno_from_yiling_to_ours(self, anno):
        flattened_bboxes = []
        flattened_tokens_positive = []
        flattened_bboxes_scores = []
        for i in range(len(anno["bboxes"])):
            # i is the index of the entity
            for j in range(len(anno["bboxes"][i])):
                # j is the index of each box
                flattened_bboxes.append(anno["bboxes"][i][j])
                flattened_tokens_positive.append(
                    anno["tokens_positive"][i]
                )  # Assume this box corresponds to all the token spans of this entity
                flattened_bboxes_scores.append(anno["scores"][i][j])
        anno["bboxes"] = flattened_bboxes
        anno["tokens_positive"] = flattened_tokens_positive
        anno["scores"] = flattened_bboxes_scores
        return anno

    def get_raw_image(self, idx):
        image, *_ = super(PseudoData, self).__getitem__(idx)
        return image

    def get_img_id(self, idx):
        line_no = self.get_line_no(idx)
        if self.label_tsv is not None:
            row = self.label_tsv.seek(line_no)
            img_id = row[0]
            return img_id
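

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the yaml path and tokenizer below are
# hypothetical placeholders, not files or settings shipped with this module.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#   dataset = PseudoData(
#       yaml_file="DATASET/pseudo_grounding/train.yaml",  # hypothetical path
#       transforms=None,
#       return_tokens=True,
#       return_masks=False,
#       tokenizer=tokenizer,
#       further_screen=True,
#       caption_conf=0.5,   # drop pseudo boxes scored below this threshold
#       caption_nms=0.9,    # NMS IoU threshold; -1 disables NMS
#   )
#   img, target, idx = dataset[0]
# ---------------------------------------------------------------------------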