|
import os |
|
import sys |
|
import cv2 |
|
import math |
|
import glob |
|
import json |
|
import random |
|
import pickle |
|
import numpy as np |
|
import pandas as pd |
|
|
|
from PIL import Image, ImageDraw, ImageFilter |
|
from bert.tokenization_bert import BertTokenizer |
|
|
|
import albumentations as A |
|
from albumentations.pytorch import ToTensorV2 |
|
|
|
import torch, gc |
|
import torch.utils.data as data |
|
|
|
import lmdb |
|
import pyarrow as pa |
|
import warnings |
|
from .utils import get_warmup_value |
|
|
|
# Import-time setup: these statements run once, when the module is first loaded.

# Ignore every FutureWarning for the rest of the process.
warnings.simplefilter('ignore', FutureWarning)

# Run a garbage-collection pass, then ask torch to release its cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()
|
|
|
|
|
|
|
def loads_pyarrow(buf):
    """Deserialize ``buf`` with ``pyarrow.deserialize`` and return the result.

    NOTE(review): ``pyarrow.deserialize`` was deprecated in pyarrow 2.0 —
    confirm the pinned pyarrow version still provides it.
    """
    decoded = pa.deserialize(buf)
    return decoded
|
|
|
|
|
class ReferDataset(data.Dataset):
    """Referring-expression dataset that optionally builds 2x2 mosaic samples.

    Each item pairs an image (one resized image, or four images stitched
    around a crossing point) with a binary mask for one referring expression
    and that expression's BERT token ids.

    Args:
        args: Configuration namespace. The attributes read here are ``aug``,
            ``img_size``, ``dataset``, ``split``, ``splitBy`` and
            ``refer_data_root``.
        split: Split name passed to ``REFER.getRefIds``. Mosaic augmentation
            is only sampled when it equals ``'train'``.
        eval_mode: When true, every sentence of the target reference is
            returned instead of a single sampled one.

    Raises:
        ValueError: If no retrieval LMDB is known for the requested
            ``dataset`` / ``splitBy`` combination.
    """

    # Values of ``aug.tgt_selection`` understood by ``_select_target``.
    _TARGET_MODES = ('random', 'longest', 'fixed')

    def __init__(self,
                 args,
                 split='train',
                 eval_mode=False):
        self.classes = []
        self.args = args
        self.split = split
        self.aug = args.aug
        self.img_sz = args.img_size

        each_img_sz = int(args.img_size / math.sqrt(self.aug.num_bgs))
        # Per-channel normalisation constants (the commonly used ImageNet statistics).
        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        self.resize_bg1 = A.Compose([
            A.Resize(args.img_size, args.img_size, always_apply=True)])

        # Not used inside this class; kept because it is a public attribute.
        self.resize_bg4 = A.Compose([
            A.Resize(each_img_sz, each_img_sz, always_apply=True)],
            additional_targets={'image1': 'image', 'image2': 'image', 'image3': 'image',
                                'mask1': 'mask', 'mask2': 'mask', 'mask3': 'mask'})

        self.transforms = A.Compose([
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

        # NOTE(review): this branch tests ``args.split`` while the rest of the
        # class uses the ``split`` argument — confirm that is intended.
        if args.dataset == 'refcocog' and args.split in ['testA', 'testB']:
            print(f"Easy & Hard Example Experiments - dataset : {args.dataset}, split : {args.split}")
            from refer.refer_test import REFER
        else:
            from refer.refer import REFER
        self.refer = REFER(args.refer_data_root, args.dataset, args.splitBy)

        ref_ids = self.refer.getRefIds(split=self.split)
        img_ids = self.refer.getImgIds(ref_ids)
        all_imgs = self.refer.Imgs
        self.imgs = [all_imgs[i] for i in img_ids]
        self.ref_ids = ref_ids
        self.ref_id2idx = dict(zip(ref_ids, range(len(ref_ids))))
        self.ref_idx2id = dict(zip(range(len(ref_ids)), ref_ids))

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.max_tokens = 20
        self.eval_mode = eval_mode
        # input_ids[i][j] / attention_masks[i][j]: zero-padded token ids and
        # 0/1 mask (both of length ``max_tokens``) for sentence j of ref i.
        self.input_ids = []
        self.attention_masks = []
        for r in ref_ids:
            ref = self.refer.Refs[r]
            sentences_for_ref = []
            attentions_for_ref = []
            # zip() keeps the original pairing with ``sent_ids`` (shortest wins).
            for el, _ in zip(ref['sentences'], ref['sent_ids']):
                token_ids = self.tokenizer.encode(
                    text=el['raw'], add_special_tokens=True,
                    max_length=self.max_tokens, truncation=True)
                pad = self.max_tokens - len(token_ids)
                sentences_for_ref.append(list(token_ids) + [0] * pad)
                attentions_for_ref.append([1] * len(token_ids) + [0] * pad)
            self.input_ids.append(sentences_for_ref)
            self.attention_masks.append(attentions_for_ref)

        if self.aug.blur:
            self.blur = ImageFilter.GaussianBlur(100)

        lmdb_path = self._lmdb_path_for(args.dataset, args.splitBy)
        self.lmdb_env = lmdb.open(
            lmdb_path, subdir=False, max_readers=32,
            readonly=True, lock=False,
            readahead=False, meminit=False)
        with self.lmdb_env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

        # Compared against ``aug.warmup_epoch`` in ``_sample_mode``; nothing in
        # this class advances it.
        self.epoch = 0
        # Re-seeds NumPy's global RNG (no explicit seed is passed).
        np.random.seed()

    @staticmethod
    def _lmdb_path_for(dataset, split_by):
        """Return the retrieval LMDB path for ``dataset`` / ``split_by``.

        Raises:
            ValueError: If the combination has no known database. Previously
                this surfaced as an unbound-local error at ``lmdb.open``.
        """
        if dataset == 'refcoco':
            return '/data2/dataset/RefCOCO/logit_db/refcoco/refcoco_logit.lmdb'
        if dataset == 'refcoco+':
            return '/data2/dataset/RefCOCO/logit_db/refcoco+/refcocop_logit.lmdb'
        if dataset == 'refcocog' and split_by == 'umd':
            return '/data2/dataset/RefCOCO/logit_db/refcocog_u/refcocog_u_logit.lmdb'
        raise ValueError(
            f"No retrieval LMDB is configured for dataset={dataset!r}, splitBy={split_by!r}")

    def get_classes(self):
        """Return the ``classes`` list (initialised empty in ``__init__``)."""
        return self.classes

    def __len__(self):
        """Return the number of references in the selected split."""
        return len(self.ref_ids)

    def _sample_mode(self):
        """Pick ``'one'`` (single image), ``'random'`` or ``'retrieval'`` (mosaics).

        Mosaics are only sampled for the ``'train'`` split with
        ``aug.num_bgs == 4``. While ``epoch < aug.warmup_epoch`` the two mosaic
        probabilities come from ``get_warmup_value``.
        """
        if self.split != 'train' or self.aug.num_bgs != 4:
            return 'one'

        aug_prob = self.aug.aug_prob
        retr_prob = self.aug.retr_prob
        rand_prob = aug_prob - retr_prob
        if self.epoch < self.aug.warmup_epoch:
            rand_prob = get_warmup_value(aug_prob, rand_prob, self.epoch, self.aug.warmup_epoch)
            retr_prob = get_warmup_value(0, retr_prob, self.epoch, self.aug.warmup_epoch)
        return str(np.random.choice(['one', 'random', 'retrieval'],
                                    p=[1 - aug_prob, rand_prob, retr_prob]))

    def _load_retrieval_entry(self, ref_id):
        """Read and deserialize the LMDB record stored for ``ref_id``."""
        with self.lmdb_env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[ref_id])
        if byteflow is None:
            raise KeyError(f"No retrieval record found in LMDB for ref_id={ref_id!r}")
        return loads_pyarrow(byteflow)

    def _sample_retrieved_images(self, ref_id, sent_idx, insert_idx, num_bgs):
        """Draw ``num_bgs - 1`` retrieved image ids and insert the target image id.

        Presumably the LMDB record maps sentence ids (in sentence order) to
        ranked image ids — TODO confirm against the code that builds the DB.
        """
        entry = self._load_retrieval_entry(ref_id)
        sent_id = list(entry.keys())[sent_idx]
        pool = entry[sent_id][:self.aug.top_k]
        img_ids = list(np.random.choice(pool, size=num_bgs - 1, replace=True))
        return np.insert(img_ids, insert_idx, self.refer.Refs[ref_id]['image_id'])

    def _select_target(self, choice, insert_idx, ref_ids, sent_idxs, sents, target_sent_idx):
        """Choose which tile is the segmentation target.

        Returns:
            ``(target_idx, target_ref_id, target_sent_idx)``. The ref id and
            sentence index always belong to the tile at ``target_idx``, so the
            mask and the expression describe the same object.

        Raises:
            ValueError: If ``aug.tgt_selection`` is not a supported mode.
        """
        mode = self.aug.tgt_selection
        if mode not in self._TARGET_MODES:
            raise ValueError(
                f"Unsupported aug.tgt_selection={mode!r}; expected one of {self._TARGET_MODES}")

        if mode == 'fixed' or choice == 'retrieval':
            # Retrieved distractor tiles have no expression of their own, so
            # only the inserted tile can be the target in retrieval mode.
            target_idx = insert_idx
        elif mode == 'random':
            target_idx = np.random.choice(range(len(ref_ids)))
            target_ref_idx = self.ref_id2idx[ref_ids[target_idx]]
            target_sent_idx = np.random.choice(len(self.input_ids[target_ref_idx]))
        else:  # 'longest'
            target_idx = np.argmax([len(s) for s in sents])
            target_sent_idx = sent_idxs[target_idx]

        target_idx = int(target_idx)
        return target_idx, ref_ids[target_idx], int(target_sent_idx)

    def _load_image(self, img_id):
        """Load image ``img_id`` from ``refer.IMAGE_DIR`` as an RGB array."""
        info = self.refer.Imgs[img_id]
        path = os.path.join(self.refer.IMAGE_DIR, info['file_name'])
        with Image.open(path) as handle:
            return np.array(handle.convert("RGB"))

    def _crossing_point(self):
        """Return ``(crs_y, crs_x)``, the point where the four tiles meet.

        Random in ``[0, img_sz]`` when ``aug.move_crs_pnt`` is set, otherwise
        the canvas centre (derived from ``img_sz`` rather than a fixed 480).
        """
        if self.aug.move_crs_pnt:
            crs_y = np.random.randint(0, self.img_sz + 1)
            crs_x = np.random.randint(0, self.img_sz + 1)
            return crs_y, crs_x
        half = self.img_sz // 2
        return half, half

    @staticmethod
    def _stitch(tiles, grid):
        """Concatenate ``grid * grid`` row-major tiles into one array."""
        rows = [np.concatenate(tiles[r * grid:(r + 1) * grid], axis=1) for r in range(grid)]
        return np.concatenate(rows, axis=0)

    def _compose_mosaic(self, tiles, target_mask, target_idx, sharp_idx):
        """Resize four images around a crossing point and stitch them together.

        Only the tile at ``target_idx`` contributes to the returned mask; the
        other tiles are zero. Every tile except ``sharp_idx`` is blurred when
        ``aug.blur`` is set.

        Returns:
            ``(image, mask)`` of size ``img_sz x img_sz``.
        """
        size = self.img_sz
        crs_y, crs_x = self._crossing_point()
        # Row-major: top-left, top-right, bottom-left, bottom-right.
        tile_shapes = [
            (crs_y, crs_x), (crs_y, size - crs_x),
            (size - crs_y, crs_x), (size - crs_y, size - crs_x),
        ]

        image_tiles, mask_tiles = [], []
        for pos, (tile, (height, width)) in enumerate(zip(tiles, tile_shapes)):
            mask_tile = np.zeros((height, width), dtype=np.uint8)
            if height == 0 or width == 0:
                # Crossing point on the border: the tile collapses. It is
                # neither resized nor blurred.
                image_tile = np.zeros((height, width, 3), dtype=np.uint8)
            else:
                resize = A.Compose([A.Resize(height, width, always_apply=True)])
                if pos == target_idx:
                    resized = resize(image=tile, mask=target_mask)
                    mask_tile = resized['mask']
                else:
                    resized = resize(image=tile)
                image_tile = resized['image']
                if self.aug.blur and pos != sharp_idx:
                    image_tile = np.asarray(Image.fromarray(image_tile).filter(self.blur))
            image_tiles.append(image_tile)
            mask_tiles.append(mask_tile)

        grid = int(math.sqrt(len(tiles)))
        image = self._stitch(image_tiles, grid).astype(np.uint8)
        mask = self._stitch(mask_tiles, grid).astype(np.int32)
        return image, mask

    def _encode_sentences(self, ref_idx, sent_idx):
        """Return ``(token_ids, attention_mask)`` tensors for one reference.

        In eval mode all sentences are stacked on the last axis, giving shape
        ``(1, max_tokens, num_sentences)``; otherwise only sentence
        ``sent_idx`` is returned with shape ``(1, max_tokens)``.
        """
        if self.eval_mode:
            ids = torch.stack([torch.tensor(s) for s in self.input_ids[ref_idx]], dim=-1)
            att = torch.stack([torch.tensor(m) for m in self.attention_masks[ref_idx]], dim=-1)
            return ids.unsqueeze(0), att.unsqueeze(0)
        ids = torch.tensor(self.input_ids[ref_idx][sent_idx]).unsqueeze(0)
        att = torch.tensor(self.attention_masks[ref_idx][sent_idx]).unsqueeze(0)
        return ids, att

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            dict with ``'image'`` (normalised tensor), ``'seg_target'`` (long
            tensor, values 0/1), ``'sentence'`` (token ids) and
            ``'attn_mask'``.
        """
        choice = self._sample_mode()
        num_bgs = 1 if choice == 'one' else 4

        anchor_ref_id = self.ref_idx2id[index]
        target_sent_idx = np.random.choice(len(self.input_ids[index]))
        # Position of the indexed reference inside the mosaic (0 for a single image).
        insert_idx = int(np.random.choice(range(num_bgs)))

        ref_ids, sent_idxs, sents = [], [], []
        img_ids = None
        if num_bgs > 1:
            if choice == 'retrieval':
                # The LMDB record is only needed here, so it is read lazily.
                img_ids = self._sample_retrieved_images(
                    anchor_ref_id, target_sent_idx, insert_idx, num_bgs)
            ref_ids = np.random.choice(self.ref_ids, size=num_bgs - 1, replace=False).tolist()
            sent_idxs = [int(np.random.choice(len(self.refer.Refs[r]['sentences'])))
                         for r in ref_ids]
            sents = [self.refer.Refs[r]['sentences'][s]['raw']
                     for r, s in zip(ref_ids, sent_idxs)]

        # Plain lists: np.insert would cast the inserted sentence to the
        # array's fixed string width and truncate it.
        ref_ids.insert(insert_idx, anchor_ref_id)
        sents.insert(insert_idx,
                     self.refer.Refs[anchor_ref_id]['sentences'][target_sent_idx]['raw'])
        sent_idxs.insert(insert_idx, int(target_sent_idx))

        target_idx, target_ref_id, target_sent_idx = self._select_target(
            choice, insert_idx, ref_ids, sent_idxs, sents, target_sent_idx)

        if choice == 'retrieval':
            tile_img_ids = img_ids
        else:
            tile_img_ids = [self.refer.getImgIds([r])[0] for r in ref_ids]
        tiles = [self._load_image(i) for i in tile_img_ids]
        # Only the target's mask reaches the output, so only it is decoded.
        target_mask = np.array(self.refer.getMask(self.refer.Refs[target_ref_id])['mask'])

        if num_bgs == 1:
            resized = self.resize_bg1(image=tiles[0], mask=target_mask)
            img, mask = resized['image'], resized['mask']
        else:
            img, mask = self._compose_mosaic(tiles, target_mask, target_idx, insert_idx)

        mask = mask.astype(np.uint8)
        mask[mask > 0] = 1

        transformed = self.transforms(image=img, mask=mask)
        img_tensor = transformed['image']
        target = transformed['mask'].long()

        target_ref_idx = self.ref_id2idx[target_ref_id]
        tensor_embeddings, attention_mask = self._encode_sentences(
            target_ref_idx, target_sent_idx)

        item = {
            'image': img_tensor,
            'seg_target': target,
            'sentence': tensor_embeddings,
            'attn_mask': attention_mask
        }
        return item
|
|