import os
import json
import random

import torch
import ijson
import numpy as np

from PIL import Image
from torchvision.transforms import ToTensor
from torchvision.ops import box_convert, clip_boxes_to_image

from re_classifier import REClassifier

from utils import progressbar

def collate_fn(batch):
    # stack images into a single (B, C, H, W) tensor
    image = torch.stack([s['image'] for s in batch], dim=0)

    # original (H, W) sizes, one row per sample
    image_size = torch.FloatTensor([s['image_size'] for s in batch])

    # one box per sample, concatenated along the first dimension
    bbox = torch.cat([s['bbox'] for s in batch], dim=0)
    bbox_raw = torch.cat([s['bbox_raw'] for s in batch], dim=0)

    expr = [s['expr'] for s in batch]

    tok = None
    if batch[0]['tok'] is not None:
        tok = {
            'input_ids': torch.cat([s['tok']['input_ids'] for s in batch], dim=0),
            'attention_mask': torch.cat([s['tok']['attention_mask'] for s in batch], dim=0)
        }

        # trim padding down to the longest expression in the batch
        max_length = max([s['tok']['length'] for s in batch])
        tok = {
            'input_ids': tok['input_ids'][:, :max_length],
            'attention_mask': tok['attention_mask'][:, :max_length],
        }

    mask = None
    if batch[0]['mask'] is not None:
        mask = torch.stack([s['mask'] for s in batch], dim=0)

    mask_bbox = None
    if batch[0]['mask_bbox'] is not None:
        mask_bbox = torch.stack([s['mask_bbox'] for s in batch], dim=0)

    tr_param = [s['tr_param'] for s in batch]

    return {
        'image': image,
        'image_size': image_size,
        'bbox': bbox,
        'bbox_raw': bbox_raw,
        'expr': expr,
        'tok': tok,
        'tr_param': tr_param,
        'mask': mask,
        'mask_bbox': mask_bbox,
    }

class RECDataset(torch.utils.data.Dataset):
    def __init__(self, transform=None, tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__()
        self.samples = []
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
        self.with_mask_bbox = bool(with_mask_bbox)

    def tokenize(self, inp, max_length):
        return self.tokenizer(
            inp,
            return_tensors='pt',
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_length
        )

    def print_stats(self):
        print(f'{len(self.samples)} samples')
        lens = [len(expr.split()) for _, expr, _ in self.samples]
        print('expression lengths stats: '
              f'min={np.min(lens):.1f}, '
              f'mean={np.mean(lens):.1f}, '
              f'median={np.median(lens):.1f}, '
              f'max={np.max(lens):.1f}, '
              f'99.9P={np.percentile(lens, 99.9):.1f}')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name, expr, bbox = self.samples[idx]

        if not os.path.exists(file_name):
            raise FileNotFoundError(f'{file_name} not found')
        img = Image.open(file_name).convert('RGB')

        W0, H0 = img.size  # PIL reports (W, H)

        sample = {
            'image': img,
            'image_size': (H0, W0),
            'bbox': bbox.clone(),      # rescaled by self.transform, if any
            'bbox_raw': bbox.clone(),  # kept in original pixel coordinates
            'expr': expr,
            'tok': None,
            'tr_param': [],  # transforms may record their parameters here; read by collate_fn
            'mask': torch.ones((1, H0, W0), dtype=torch.float32),
            'mask_bbox': None,
        }

        if self.transform is None:
            sample['image'] = ToTensor()(sample['image'])
        else:
            sample = self.transform(sample)

        if self.tokenizer is not None:
            sample['tok'] = self.tokenize(sample['expr'], self.max_length)
            # number of non-padding tokens, used by collate_fn to trim padding
            sample['tok']['length'] = sample['tok']['attention_mask'].sum(1).item()

        if self.with_mask_bbox:
            _, H, W = sample['image'].size()

            # boxes are normalized to [0, 1] at this point; scale back to pixels
            bbox = sample['bbox'].clone()
            bbox[:, (0, 2)] *= W
            bbox[:, (1, 3)] *= H
            bbox = clip_boxes_to_image((bbox + 0.5).long(), (H, W))

            # binary mask with ones inside the box
            sample['mask_bbox'] = torch.zeros((1, H, W), dtype=torch.float32)
            for x1, y1, x2, y2 in bbox.tolist():
                sample['mask_bbox'][:, y1:y2+1, x1:x2+1] = 1.0

        return sample

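# --- Illustrative only: a minimal sample-level transform sketch. ---
# RECDataset passes the whole sample dict through self.transform, and the
# mask_bbox branch above assumes boxes come back normalized to [0, 1]. This
# class shows one way to satisfy that contract; the name, the fixed square
# resize, and the tr_param payload are assumptions, not the project's actual
# transform.
class ResizeAndNormalizeBoxes:
    def __init__(self, size=512):
        self.size = int(size)

    def __call__(self, sample):
        import torch.nn.functional as F

        img = ToTensor()(sample['image'])  # (3, H0, W0), values in [0, 1]
        _, H0, W0 = img.shape

        # resize image and mask to a fixed square so collate_fn can stack them
        img = F.interpolate(img.unsqueeze(0), size=(self.size, self.size),
                            mode='bilinear', align_corners=False).squeeze(0)
        mask = F.interpolate(sample['mask'].unsqueeze(0),
                             size=(self.size, self.size),
                             mode='nearest').squeeze(0)

        # normalize xyxy boxes by the original image size
        bbox = sample['bbox'].clone()
        bbox[:, (0, 2)] /= W0
        bbox[:, (1, 3)] /= H0

        sample.update(image=img, bbox=bbox, mask=mask,
                      tr_param=[{'resize': (self.size, self.size)}])
        return sample
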
					
						
class RegionDescriptionsVisualGnome(RECDataset):
    def __init__(self, data_root, transform=None, tokenizer=None,
                 max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        # exclude Visual Genome images whose COCO ids appear in the RefCOCO
        # val/test splits, to avoid train/test contamination
        try:
            with open('./refcoco_valtest_ids.txt', 'r') as fh:
                refcoco_ids = [int(lin.strip()) for lin in fh.readlines()]
        except FileNotFoundError:
            refcoco_ids = []

        def path_from_url(fname):
            return os.path.join(data_root, fname[fname.index('VG_100K'):])

        with open(os.path.join(data_root, 'image_data.json'), 'r') as f:
            image_data = {
                data['image_id']: path_from_url(data['url'])
                for data in json.load(f)
                if data['coco_id'] is None or data['coco_id'] not in refcoco_ids
            }
        print(f'{len(image_data)} images')

        self.samples = []

        # stream the (large) region descriptions file instead of loading it whole
        with open(os.path.join(data_root, 'region_descriptions.json'), 'r') as f:
            for record in progressbar(ijson.items(f, 'item.regions.item'), desc='loading data'):
                if record['image_id'] not in image_data:
                    continue
                file_name = image_data[record['image_id']]

                expr = record['phrase']

                # regions come as (x, y, w, h); convert to (x1, y1, x2, y2)
                bbox = [record['x'], record['y'], record['width'], record['height']]
                bbox = torch.atleast_2d(torch.FloatTensor(bbox))
                bbox = box_convert(bbox, 'xywh', 'xyxy')

                self.samples.append((file_name, expr, bbox))

        self.print_stats()

class ReferDataset(RECDataset):
    def __init__(self, data_root, dataset, split_by, split, transform=None,
                 tokenizer=None, max_length=32, with_mask_bbox=False):
        super().__init__(transform=transform, tokenizer=tokenizer,
                         max_length=max_length, with_mask_bbox=with_mask_bbox)

        try:
            import sys
            sys.path.append('refer')
            from refer import REFER
        except ImportError:
            raise RuntimeError('create a symlink to a valid refer build '
                               '(see https://github.com/lichengunc/refer)')

        refer = REFER(data_root, dataset, split_by)
        ref_ids = sorted(refer.getRefIds(split=split))

        self.samples = []

        for rid in progressbar(ref_ids, desc='loading data'):
            ref = refer.Refs[rid]
            ann = refer.refToAnn[rid]

            file_name = refer.Imgs[ref['image_id']]['file_name']
            if dataset == 'refclef':
                file_name = os.path.join(
                    'refer', 'data', 'images', 'saiapr_tc-12', file_name
                )
            else:
                # COCO file names embed the split, e.g. COCO_train2014_...
                coco_set = file_name.split('_')[1]
                file_name = os.path.join(
                    'refer', 'data', 'images', 'mscoco', coco_set, file_name
                )

            bbox = ann['bbox']  # (x, y, w, h)
            bbox = torch.atleast_2d(torch.FloatTensor(bbox))
            bbox = box_convert(bbox, 'xywh', 'xyxy')

            sentences = [s['sent'] for s in ref['sentences']]
            if 'train' in split:
                # deduplicate training expressions; sort for reproducibility
                sentences = sorted(set(sentences))

            self.samples += [(file_name, expr, bbox) for expr in sentences]

        self.print_stats()

class RefCLEF(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refclef', 'berkeley', *args, **kwargs)


class RefCOCO(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco', 'unc', *args, **kwargs)


class RefCOCOp(ReferDataset):  # RefCOCO+ ('+' is not valid in an identifier)
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'trainval', 'testA', 'testB')
        super().__init__('refer/data', 'refcoco+', 'unc', *args, **kwargs)


class RefCOCOg(ReferDataset):
    def __init__(self, *args, **kwargs):
        assert args[0] in ('train', 'val', 'test')
        super().__init__('refer/data', 'refcocog', 'umd', *args, **kwargs)
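

# --- Illustrative only: a minimal smoke-test sketch. ---
# Assumes the HuggingFace `transformers` package and the refer/ data layout
# the classes above expect; the 'bert-base-uncased' checkpoint is just an
# example. Without a resizing transform, images keep their native sizes, so
# batch_size must stay at 1 for torch.stack in collate_fn to succeed.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset = RefCOCO('val', tokenizer=tokenizer, max_length=32)

    loader = DataLoader(dataset, batch_size=1, shuffle=False,
                        collate_fn=collate_fn)
    batch = next(iter(loader))
    print(batch['image'].shape)             # (1, 3, H, W)
    print(batch['tok']['input_ids'].shape)  # (1, L), padding trimmed per batch
    print(batch['expr'][0])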