import os
import json
import random

import torch
import ijson
import numpy as np
from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.ops import box_convert, clip_boxes_to_image

from re_classifier import REClassifier
from utils import progressbar


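# Collate function for torch.utils.data.DataLoader: merges RECDataset samples
# into batched tensors.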
def collate_fn(batch):
    image = torch.stack([s['image'] for s in batch], dim=0)
    image_size = torch.FloatTensor([s['image_size'] for s in batch])

    bbox = torch.cat([s['bbox'] for s in batch], dim=0)
    bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)

    expr = [s['expr'] for s in batch]

    tok = None
    if batch[0]['tok'] is not None:
        tok = {
            'input_ids': torch.cat([s['tok']['input_ids'] for s in batch], dim=0),
            'attention_mask': torch.cat([s['tok']['attention_mask'] for s in batch], dim=0)
        }

        # trim the fixed-size padding down to the longest expression in the batch
        max_length = max([s['tok']['length'] for s in batch])
        tok = {
            'input_ids': tok['input_ids'][:, :max_length],
            'attention_mask': tok['attention_mask'][:, :max_length],
        }

    mask = None
    if batch[0]['mask'] is not None:
        mask = torch.stack([s['mask'] for s in batch], dim=0)

    mask_bbox = None
    if batch[0]['mask_bbox'] is not None:
        mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)

    tr_param = [s['tr_param'] for s in batch]

    return {
        'image': image,
        'image_size': image_size,
        'bbox': bbox,
        'bbox_raw': bbox_raw,
        'expr': expr,
        'tok': tok,
        'tr_param': tr_param,
        'mask': mask,
        'mask_bbox': mask_bbox,
    }


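# Generic REC dataset: self.samples holds (file_name, expression, bbox) triples
# that are filled in by the subclasses below.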
class RECDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__()
        self.samples = []
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.with_mask_bbox = bool(with_mask_bbox)

    def tokenize(self, inp, max_length):
        return self.tokenizer(
            inp,
            return_tensors='pt',
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length
        )

    def print_stats(self):
        print(f'{len(self.samples)} samples')
        lens = [len(expr.split()) for _, expr, _ in self.samples]
        print('expression lengths stats: '
              f'min={np.min(lens):.1f}, '
              f'mean={np.mean(lens):.1f}, '
              f'median={np.median(lens):.1f}, '
              f'max={np.max(lens):.1f}, '
              f'99.9P={np.percentile(lens, 99.9):.1f}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name, expr, bbox = self.samples[idx]

        if not os.path.exists(file_name):
            raise FileNotFoundError(f'{file_name} not found')
        img = Image.open(file_name).convert('RGB')

        W0, H0 = img.size

        sample = {
            'image': img,
            'image_size': (H0, W0),
            'bbox': bbox.clone(),
            'bbox_raw': bbox.clone(),
            'expr': expr,
            'tok': None,
            # collate_fn reads this key; transforms may overwrite it with their parameters
            'tr_param': None,
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),
            'mask_bbox': None,
        }

        if self.transform is None:
            sample['image'] = ToTensor()(sample['image'])
        else:
            sample = self.transform(sample)

        if self.tokenizer is not None:
            sample['tok'] = self.tokenize(sample['expr'], self.max_length)
            sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()

        if self.with_mask_bbox:
            _, H, W = sample['image'].size()

            # bbox is assumed to be in normalized xyxy coordinates at this point;
            # scale back to pixels and clip to the image
            bbox = sample['bbox'].clone()
            bbox[:, (0, 2)] *= W
            bbox[:, (1, 3)] *= H
            bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))

            # binary mask with ones inside every ground-truth box
            sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
            for x1, y1, x2, y2 in bbox.tolist():
                sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0

        return sample


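# Region descriptions from Visual Genome; images whose COCO id is listed in
# refcoco_valtest_ids.txt are excluded.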
class RegionDescriptionsVisualGnome(RECDataset):
    def __init__(self, data_root, transform=None, tokenizer=None,
                 max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # COCO ids belonging to the RefCOCO val/test splits, used to filter images below
        try:
            with open('./refcoco_valtest_ids.txt', 'r') as fh:
                refcoco_ids = {int(line.strip()) for line in fh}
        except FileNotFoundError:
            refcoco_ids = set()

        def path_from_url(fname):
            return os.path.join(data_root, fname[fname.index('VG_100K'):])

        with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
            image_data = {
                data['image_id']: path_from_url(data['url'])
                for data in json.load(f)
                if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
            }
        print(f'{len(image_data)} images')

        self.samples = []

        # region_descriptions.json is large, so it is streamed with ijson
        with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
            for record in progressbar(ijson.items(f, 'item.regions.item'), desc='loading data'):
                if record['image_id'] not in image_data:
                    continue
                file_name = image_data[record['image_id']]

                expr = record['phrase']

                bbox = [record['x'], record['y'], record['width'], record['height']]
                bbox = torch.atleast_2d(torch.FloatTensor(bbox))
                bbox = box_convert(bbox, 'xywh', 'xyxy')

                self.samples.append((file_name, expr, bbox))

        self.print_stats()


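# Referring-expression datasets loaded through the REFER API
# (https://github.com/lichengunc/refer).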
class ReferDataset(RECDataset):
    def __init__(self, data_root, dataset, split_by, split, transform=None,
                 tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        try:
            import sys
            sys.path.append('refer')
            from refer import REFER
        except ImportError:
            raise RuntimeError('create a symlink to a compiled checkout of refer '
                               '(see https://github.com/lichengunc/refer)')

        refer = REFER(data_root, dataset, split_by)
        ref_ids = sorted(refer.getRefIds(split=split))

        self.samples = []

        for rid in progressbar(ref_ids, desc='loading data'):
            ref = refer.Refs[rid]
            ann = refer.refToAnn[rid]

            file_name = refer.Imgs[ref['image_id']]['file_name']
            if dataset == 'refclef':
                file_name = os.path.join(
                    'refer', 'data', 'images', 'saiapr_tc-12', file_name
                )
            else:
                # e.g. COCO_train2014_000000123456.jpg -> train2014
                coco_set = file_name.split('_')[1]
                file_name = os.path.join(
                    'refer', 'data', 'images', 'mscoco', coco_set, file_name
                )

            bbox = ann['bbox']
            bbox = torch.atleast_2d(torch.FloatTensor(bbox))
            bbox = box_convert(bbox, 'xywh', 'xyxy')

            # every referring expression becomes its own sample; training splits
            # drop duplicate sentences (sorted for determinism)
            sentences = [s['sent'] for s in ref['sentences']]
            if 'train' in split:
                sentences = sorted(set(sentences))

            self.samples += [(file_name, expr, bbox) for expr in sentences]

        self.print_stats()


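# Thin wrappers that bind the data root, dataset name, and split-by convention
# of each benchmark.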
class RefCLEF(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)


class RefCOCO(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)


class RefCOCOp(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)


class RefCOCOg(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
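

# Minimal usage sketch, not part of the original pipeline: it wires RefCOCO and
# collate_fn into a DataLoader. Assumptions: a HuggingFace-style tokenizer
# ('bert-base-uncased' is only an example) and the refer data/symlink already in
# place. Without a resizing transform the images keep their native sizes, so
# batch_size must stay at 1 for torch.stack to succeed.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = RefCOCO('val', tokenizer=tokenizer, max_length=32)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

    batch = next(iter(loader))
    print(batch['image'].shape, batch['tok']['input_ids'].shape, batch['expr'][0])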