eP-ALM / dataset /caption.py
mshukor
init
3eb682b
raw
history blame
No virus
8.83 kB
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import json
import random
from multiprocessing import Pool
import torch
from PIL import Image
from torch.utils.data.distributed import DistributedSampler
from dataset.randaugment import RandomAugment
import torch
from torchvision import transforms
import os
import re
class COCOCaptionFineTuneDataset(Dataset):
    """COCO captioning dataset driven by the Karpathy split file.

    Annotations are read from ``<data_dir>/COCO/dataset_coco.json`` and images
    from ``<data_dir>/COCO/images/{train2014,val2014}/<img_id>.jpg``.

    For the training split one entry is created per caption (``sent``), each
    carrying all captions of the image as ``targets``.  For the other splits
    one entry is created per image, with ``targets`` only.

    Args:
        split: Split name; only the part after the last ``_`` is compared with
            the (renamed) Karpathy split, e.g. ``'karpathy_train'`` -> ``'train'``.
        raw_dataset: Stored on the instance, not otherwise used here.
        rank: Accepted for interface compatibility; unused.
        topk: A float in ``(0, 1]`` keeps a random fraction of the entries;
            any other positive value keeps the first ``int(topk)`` entries;
            non-positive values keep everything.
        verbose: Print progress information while loading.
        args: Configuration namespace.  ``image_size`` and
            ``use_data_augmentation`` are read from it; ``BUTD100`` is set to
            ``False`` on it (the passed object is mutated).
        mode: ``'train'`` selects the augmenting transform when
            ``args.use_data_augmentation`` is truthy; anything else selects
            the deterministic resize transform.
        data_dir: Root directory containing the ``COCO`` folder.
        black_image: If true, every returned image tensor is replaced by zeros.

    Raises:
        ValueError: If ``args`` is ``None``.
    """

    # Karpathy's 'restval' images are folded into the training split.
    _SPLIT_RENAME = {
        'train': 'train',
        'restval': 'train',
        'val': 'val',
        'test': 'test',
    }

    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data', black_image=False):
        super().__init__()

        if args is None:
            # Fail early with a readable message instead of an AttributeError
            # on None a few lines below.
            raise ValueError(
                "COCOCaptionFineTuneDataset requires `args` providing "
                "`image_size` and `use_data_augmentation`."
            )

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        # NOTE: mutates the caller's args object; __getitem__ reads this flag.
        self.args.BUTD100 = False
        self.mode = mode
        self.black_image = black_image

        dataset_dir = Path(data_dir)
        coco_dir = dataset_dir.joinpath('COCO')
        coco_img_dir = coco_dir.joinpath('images/')

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # NOTE(review): these look like CLIP's mean/std statistics -- confirm.
        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        data_info_path = dataset_dir.joinpath('COCO/dataset_coco.json')
        with open(data_info_path, encoding='utf-8') as f:
            karpathy_data = json.load(f)

        # Only the suffix of the split name ('train' / 'val' / 'test') matters.
        wanted_split = self.source.split('_')[-1]
        data, n_images = self._build_entries(karpathy_data['images'], wanted_split)

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            # Fractional topk: keep a random subset of that proportion.
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            # Absolute topk: keep the first entries in file order.
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        # Maps the COCO image folder name to its directory on disk.
        self.source_to_h5 = {
            'train2014': coco_img_dir.joinpath('train2014'),
            'val2014': coco_img_dir.joinpath('val2014'),
        }

    @classmethod
    def _build_entries(cls, images, wanted_split):
        """Turn Karpathy image records into dataset entries for one split.

        Returns a ``(entries, n_images)`` tuple where ``n_images`` counts the
        image records that belong to ``wanted_split``.
        """
        entries = []
        n_images = 0
        for datum in images:
            re_split = cls._SPLIT_RENAME[datum['split']]
            if re_split != wanted_split:
                continue

            img_id = datum['filename'].split('.')[0]
            # Strip every caption once per image instead of once per caption.
            targets = [sentence['raw'].strip() for sentence in datum['sentences']]

            if re_split == 'train':
                for sent in targets:
                    entries.append({
                        'img_id': img_id,
                        'sent': sent,
                        # Each entry owns its list so entries stay independent.
                        'targets': list(targets),
                        'is_train': True,
                    })
            else:
                entries.append({
                    'img_id': img_id,
                    'targets': targets,
                    'is_train': False,
                })

            n_images += 1
        return entries, n_images

    def __len__(self):
        """Return the number of entries (captions for train, images otherwise)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns a dict with ``args``, ``img_id``, ``image`` (transformed
        tensor, or zeros when ``black_image`` is set), ``sent`` for training
        entries, and ``targets`` when the entry has them.

        Raises:
            ValueError: If the image folder cannot be derived from ``img_id``.
        """
        datum = self.data[idx]

        ###### Image ######
        img_id = datum['img_id']
        out_dict = {'args': self.args, 'img_id': img_id}

        if self.args.BUTD100:
            source = self.source
        elif 'train' in img_id:
            source = 'train2014'
        elif 'val' in img_id:
            source = 'val2014'
        else:
            raise ValueError(
                f"Cannot infer COCO image folder from img_id {img_id!r}: "
                "expected it to contain 'train' or 'val'."
            )

        path = self.source_to_h5[source].joinpath(f"{img_id}.jpg")

        # Close the file handle deterministically once pixels are decoded.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        out_dict["image"] = self.transform(image)

        if self.black_image:
            out_dict["image"] = torch.zeros_like(out_dict["image"])

        if datum['is_train']:
            out_dict['sent'] = datum['sent'].strip()

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        """Merge a list of ``__getitem__`` dicts into one batch dict.

        Returns a dict with ``images`` (stacked tensor), ``img_id`` (list),
        ``img_paths`` (always an empty list), ``sent`` (list, only when at
        least one entry provides a sentence), ``targets`` (list of caption
        lists) and ``task`` (``'caption'``).
        """
        images = []
        img_ids = []
        targets = []
        sents = []

        for entry in batch:
            images.append(entry['image'])
            img_ids.append(entry['img_id'])

            if 'targets' in entry:
                targets.append(entry['targets'])

            if 'sent' in entry:
                sents.append(entry['sent'])

        batch_entry = {
            'images': torch.stack(images),
            'img_id': img_ids,
            'img_paths': [],
        }

        # Decide from the collected sentences rather than from whichever
        # entry the loop happened to visit last.
        if sents:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets
        batch_entry['task'] = 'caption'

        return batch_entry
def pre_caption(caption,max_words):
    """Normalise a caption string and cap it at ``max_words`` words.

    The caption is lower-cased, the punctuation characters
    ``, . ' ! ? " ( ) * # : ; ~`` are removed, ``-`` and ``/`` become spaces,
    ``<person>`` becomes ``person``, runs of two or more whitespace characters
    collapse to one space, and trailing newlines / surrounding spaces are
    stripped.  If more than ``max_words`` space-separated words remain, only
    the first ``max_words`` are kept.
    """
    text = re.sub(r"([,.'!?\"()*#:;~])", '', caption.lower())
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        text = text.replace(old, new)

    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    # Truncate on single-space boundaries, as the words were split.
    words = text.split(' ')
    if len(words) > max_words:
        return ' '.join(words[:max_words])
    return text
def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None, verbose=False,
               config_dir=None, black_image=False):
    """Build a ``DataLoader`` over ``COCOCaptionFineTuneDataset``.

    In ``'train'`` mode the loader shuffles, unless ``distributed`` is set, in
    which case a ``DistributedSampler`` (``num_replicas=world_size``,
    ``rank=local_rank``) does the sharding instead.  In any other mode the
    loader iterates in order without a sampler.

    ``gpu`` is forwarded to the dataset as ``rank``; ``config_dir`` is accepted
    but not used.  When ``verbose`` is true a ``COCOCaptionEvaluator`` is
    attached as ``loader.evaluator``.  ``loader.task`` is always ``'caption'``.
    """
    caption_dataset = COCOCaptionFineTuneDataset(
        split,
        # raw_dataset=_dset,
        rank=gpu,
        topk=topk,
        verbose=verbose,
        args=args,
        mode=mode,
        data_dir=data_dir,
        black_image=black_image,
    )

    training = mode == 'train'

    # Only distributed training shards the data; evaluation sees everything.
    sampler = None
    if distributed and training:
        sampler = DistributedSampler(caption_dataset, num_replicas=world_size, rank=local_rank)

    loader = DataLoader(
        caption_dataset,
        batch_size=batch_size,
        # shuffle and sampler are mutually exclusive in DataLoader.
        shuffle=training and sampler is None,
        num_workers=workers,
        pin_memory=True,
        sampler=sampler,
        collate_fn=caption_dataset.collate_fn,
        drop_last=False,
    )

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'
    return loader
class COCOCaptionEvaluator:
    """Thin wrapper around ``language_evaluation.CocoEvaluator``."""

    def __init__(self):
        # Imported here so the module can be loaded without the package;
        # it is only required once an evaluator is actually constructed.
        from language_evaluation import CocoEvaluator

        self.evaluator = CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        """Return whatever ``run_evaluation(predicts, answers)`` produces."""
        return self.evaluator.run_evaluation(predicts, answers)