# elia/data/dataset_refer_bert_mlm.py
# Last commit a166479 by yxchng: "add files" (file size 8.76 kB)
import os
import sys
import torch.utils.data as data
import torch
from torchvision import transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
import random
from bert.tokenization_bert import BertTokenizer
import h5py
from refer.refer import REFER
from args import get_parser
# Dataset configuration initialization.
# This runs at import time, so use parse_known_args(): parse_args() would abort
# the import ("unrecognized arguments") whenever the importing entry point was
# launched with CLI flags that get_parser() does not define.
parser = get_parser()
args = parser.parse_known_args()[0]
#from hfai.datasets import CocoDetection
from PIL import Image
import numpy as np
#from ffrecord.torch import DataLoader,Dataset
#import ffrecord
from copy import deepcopy
_EXIF_ORIENT = 274
def _apply_exif_orientation(image):
"""
Applies the exif orientation correctly.
This code exists per the bug:
https://github.com/python-pillow/Pillow/issues/3973
with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
various methods, especially `tobytes`
Function based on:
https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
Args:
image (PIL.Image): a PIL image
Returns:
(PIL.Image): the PIL image with exif orientation applied, if applicable
"""
if not hasattr(image, "getexif"):
return image
try:
exif = image.getexif()
except Exception: # https://github.com/facebookresearch/detectron2/issues/1885
exif = None
if exif is None:
return image
orientation = exif.get(_EXIF_ORIENT)
method = {
2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}.get(orientation)
if method is not None:
return image.transpose(method)
return image
# RGB -> YUV matrix for SDTV (ITU-R BT.601); rows produce Y, U and V.
# Used by the "YUV-BT.601" branch of convert_PIL_to_numpy.
_M_RGB2YUV = [
    [0.299, 0.587, 0.114],
    [-0.14713, -0.28886, 0.436],
    [0.615, -0.51499, -0.10001],
]


def convert_PIL_to_numpy(image, format):
    """Convert a PIL image to a numpy array in the requested format.

    Args:
        image (PIL.Image): a PIL image (only ``convert`` and ``np.asarray``
            support are required).
        format (str or None): target format. PIL modes such as "RGB" or "L"
            are produced via ``image.convert``; "BGR" and "YUV-BT.601" are
            derived from an RGB conversion. ``None`` skips conversion.

    Returns:
        np.ndarray: the image as an array. "L" gets a trailing channel axis
        (HWC); "BGR" has its channels reversed; "YUV-BT.601" is a float array
        computed from RGB scaled to [0, 1].
    """
    if format is not None:
        # PIL has no BGR/YUV modes: convert to RGB and post-process below.
        conversion_format = "RGB" if format in ("BGR", "YUV-BT.601") else format
        image = image.convert(conversion_format)
    image = np.asarray(image)

    if format == "L":
        # PIL squeezes out the channel dimension for "L", so make it HWC.
        image = np.expand_dims(image, -1)
    # Formats not supported natively by PIL.
    elif format == "BGR":
        image = image[:, :, ::-1]
    elif format == "YUV-BT.601":
        image = image / 255.0
        image = np.dot(image, np.array(_M_RGB2YUV).T)
    return image
class ReferDataset(data.Dataset):
    """Referring-segmentation dataset that also yields BERT masked-LM inputs.

    One item per referred object (``ref_id``) of the REFER annotations for
    ``split``. Every expression of an object is tokenized once in ``__init__``
    and cached as a ``(1, max_tokens)`` id tensor plus a matching attention
    mask.

    Args:
        args: namespace providing ``refer_data_root``, ``dataset`` and
            ``splitBy`` (passed to ``REFER``) and ``bert_tokenizer`` (passed to
            ``BertTokenizer.from_pretrained``).
        image_transforms: optional callable ``(img, annot) -> (img, target)``
            applied jointly to the PIL image and its PIL mask.
        target_transforms: stored as ``self.target_transform``; not used by
            this class.
        split: REFER split name passed to ``getRefIds``.
        eval_mode: if True, ``__getitem__`` returns all expressions of the
            object; otherwise one random expression with MLM corruption.
        mlm_prob: per-position probability of selecting a token for MLM.
        mlm_prob_mask: among selected positions, fraction replaced by the
            ``[MASK]`` id.
        mlm_prob_noise: among selected positions, fraction replaced by a random
            vocabulary id; the remaining selected positions keep their id.
    """

    def __init__(self,
                 args,
                 image_transforms=None,
                 target_transforms=None,
                 split='train',
                 eval_mode=False,
                 mlm_prob=0.15,
                 mlm_prob_mask=0.9,
                 mlm_prob_noise=0.0):
        self.classes = []
        self.image_transforms = image_transforms
        self.target_transform = target_transforms
        self.split = split
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        self.max_tokens = 20

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids

        self.input_ids = []
        self.attention_masks = []
        self.tokenizer = BertTokenizer.from_pretrained(args.bert_tokenizer)

        # eval_mode=True: return every expression of an object (testing);
        # otherwise sample one expression per call (training / quick validation).
        self.eval_mode = eval_mode
        self.mlm_prob = mlm_prob
        self.mlm_prob_mask = mlm_prob_mask
        self.mlm_prob_noise = mlm_prob_noise

        for r in ref_ids:
            ref = self.refer.Refs[r]
            sentences_for_ref = []
            attentions_for_ref = []
            # zip() keeps the pairing with sent_ids (and its truncation to the
            # shorter list); the sentence id itself is not needed.
            for el, _sent_id in zip(ref['sentences'], ref['sent_ids']):
                attention_mask = [0] * self.max_tokens
                padded_input_ids = [0] * self.max_tokens

                input_ids = self.tokenizer.encode(text=el['raw'], add_special_tokens=True)
                # Hard truncation: long expressions lose their tail, which may
                # include the closing special token added by the tokenizer.
                input_ids = input_ids[:self.max_tokens]

                padded_input_ids[:len(input_ids)] = input_ids
                attention_mask[:len(input_ids)] = [1] * len(input_ids)

                sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0))
                attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0))

            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

    def get_classes(self):
        return self.classes

    def __len__(self):
        return len(self.ref_ids)

    def __getitem__(self, index):
        """Load item ``index``.

        ``img``/``target`` are the outputs of ``image_transforms`` or, without
        transforms, the RGB PIL image and the PIL "P"-mode mask.

        Returns:
            In eval mode: ``(img, target, tensor_embeddings, attention_mask)``
            with the last two shaped ``(1, max_tokens, S)`` for the S
            expressions of the object.
            Otherwise: ``(img, target, tensor_embeddings, attention_mask,
            target_embeddings, mlm_mask, position)`` for one random expression:
            MLM-corrupted ids, its attention mask, the uncorrupted ids and a
            0/1 mask of selected positions (all ``(1, max_tokens)``), plus the
            float ``[x, y]`` centroid of the target mask.
        """
        img, annot = self._load_image_and_mask(index)

        # Without transforms, fall back to the PIL mask so ``target`` is always bound.
        target = annot
        if self.image_transforms is not None:
            # resize, from PIL to tensor, and mean and std normalization
            img, target = self.image_transforms(img, annot)

        sentences = self.input_ids[index]
        attentions = self.attention_masks[index]

        if self.eval_mode:
            # Stack all expressions on a new last axis: S x (1, T) -> (1, T, S).
            tensor_embeddings = torch.stack(sentences, dim=-1)
            attention_mask = torch.stack(attentions, dim=-1)
            return img, target, tensor_embeddings, attention_mask

        centroid_x, centroid_y = self._mask_centroid(target)
        position = torch.tensor([centroid_x, centroid_y]).float()

        choice_sent = np.random.choice(len(sentences))
        attention_mask = attentions[choice_sent]
        # Work on copies: the cached ids are the MLM targets, and corrupting
        # them in place would permanently poison the cache across epochs.
        target_embeddings = sentences[choice_sent].clone()
        tensor_embeddings, mlm_mask = self._mask_tokens(sentences[choice_sent])

        return img, target, tensor_embeddings, attention_mask, target_embeddings, mlm_mask, position

    def _load_image_and_mask(self, index):
        """Return the RGB PIL image and the binary "P"-mode PIL mask for item ``index``."""
        this_ref_id = self.ref_ids[index]
        this_img_id = self.refer.getImgIds(this_ref_id)
        this_img = self.refer.Imgs[this_img_id[0]]

        path = os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = _apply_exif_orientation(img)
        img = Image.fromarray(convert_PIL_to_numpy(img, 'RGB'))

        ref = self.refer.loadRefs(this_ref_id)
        ref_mask = np.array(self.refer.getMask(ref[0])['mask'])
        annot = np.zeros(ref_mask.shape)
        annot[ref_mask == 1] = 1
        annot = Image.fromarray(annot.astype(np.uint8), mode="P")
        return img, annot

    @staticmethod
    def _mask_centroid(target):
        """Return the integer ``(x, y)`` mean position of the non-zero pixels of ``target``.

        Accepts a torch tensor or anything ``np.asarray`` understands (e.g. a
        PIL mask). NOTE(review): assumes a 2-D (H, W) mask, so argwhere columns
        are (row, col) — confirm against ``image_transforms``.
        """
        if torch.is_tensor(target):
            mask = target.detach().cpu().numpy()
        else:
            mask = np.asarray(target)
        coords = np.argwhere(mask)
        if coords.size == 0:
            # Empty mask: the mean would be NaN and int(NaN) raises.
            return 0, 0
        centroid = coords.mean(0)
        return int(centroid[1]), int(centroid[0])

    def _mask_tokens(self, input_ids):
        """Apply BERT-style MLM corruption to a copy of ``input_ids``.

        Args:
            input_ids: ``(1, max_tokens)`` id tensor; it is not modified.

        Returns:
            ``(masked_ids, mlm_mask)``: the corrupted copy and a
            ``(1, max_tokens)`` 0/1 tensor marking the selected positions.
        """
        masked_ids = input_ids.clone()
        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        mlm_mask = []
        # NOTE(review): every position is eligible, including special tokens
        # and padding (attention_mask == 0); kept to preserve training behavior.
        for j in range(masked_ids.shape[1]):
            prob = random.random()
            if prob >= self.mlm_prob:
                mlm_mask.append(0)
                continue
            mlm_mask.append(1)
            # Rescale to [0, 1) so the mask/noise split is relative to selected positions.
            prob /= self.mlm_prob
            if prob < self.mlm_prob_mask:
                masked_ids[0][j] = mask_token_id
            elif prob < self.mlm_prob_mask + self.mlm_prob_noise:
                masked_ids[0][j] = np.random.randint(len(self.tokenizer))
        return masked_ids, torch.tensor(mlm_mask).unsqueeze(0)