# Source: eP-ALM / dataset / .ipynb_checkpoints / caption-checkpoint.py (author: mshukor, commit "init" 3eb682b, 19.9 kB)
from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
from collections import defaultdict
import json
import random
from multiprocessing import Pool
import h5py
import pickle
import math
from tqdm import tqdm
import torch
import numpy as np
from copy import deepcopy
from PIL import Image
from torch.utils.data.distributed import DistributedSampler
# from transformers import T5TokenizerFast, BartTokenizer
# from tokenization import VLT5TokenizerFast
# from vis_encoder import _transform
# from vqa_raw_data import augmentation_transform
from dataset.randaugment import RandomAugment
import torch
from torch import nn
from torchvision import transforms
import os
import re
# project_dir = Path(__file__).resolve().parent.parent # VLT5
# workspace_dir = project_dir.parent
# dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
# coco_img_dir = coco_dir.joinpath('images/')
# coco_feature_dir = coco_dir.joinpath('features')
class COCOCaptionFineTuneDataset(Dataset):
    """COCO captioning dataset driven by ``COCO/dataset_coco.json``.

    The annotation file is read from ``<data_dir>/COCO/dataset_coco.json`` and
    images from ``<data_dir>/COCO/images/{train2014,val2014}``.

    For the training split one sample is produced per caption (keys
    ``img_id``, ``sent``, ``targets``, ``is_train``); for the other splits one
    sample is produced per image (no ``sent`` key).

    Args:
        split: Split name; only the part after the last ``'_'`` is compared
            against the (renamed) split stored in the annotation file.
        raw_dataset: Stored on the instance, not otherwise used here.
        rank: Accepted for interface compatibility; unused.
        topk: A float in ``(0, 1]`` keeps a random fraction of the samples;
            a positive number keeps the first ``int(topk)`` samples;
            anything else keeps everything.
        verbose: Print progress information when true.
        args: Required configuration object. ``image_size`` and
            ``use_data_augmentation`` are read from it and ``BUTD100`` is
            forced to ``False`` on it.
        mode: ``'train'`` enables the augmenting transform when
            ``args.use_data_augmentation`` is set.
        data_dir: Root directory containing the ``COCO`` folder.
    """

    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data'):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.args.BUTD100 = False
        self.mode = mode

        dataset_dir = Path(data_dir)
        coco_img_dir = dataset_dir.joinpath('COCO').joinpath('images/')

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        data_info_path = dataset_dir.joinpath('COCO/dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        # 'restval' images are folded into the training split.
        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }
        # Loop-invariant: the split this dataset instance is asked to serve.
        wanted_split = self.source.split('_')[-1]

        n_images = 0
        data = []
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]
            if re_split != wanted_split:
                continue

            img_id = datum['filename'].split('.')[0]
            # Computed once per image instead of once per caption.
            targets = [sentence['raw'].strip() for sentence in datum['sentences']]

            if re_split == 'train':
                # One training sample per caption; each sample owns its own
                # copy of the reference list.
                for sent in targets:
                    data.append({
                        'img_id': img_id,
                        'sent': sent,
                        'targets': list(targets),
                        'is_train': True,
                    })
            else:
                data.append({
                    'img_id': img_id,
                    'targets': targets,
                    'is_train': False,
                })

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        # Maps an image source name to the directory holding its JPEG files.
        self.source_to_h5 = {
            'train2014': coco_img_dir.joinpath('train2014'),
            'val2014': coco_img_dir.joinpath('val2014'),
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load and transform the image of sample ``idx``.

        Returns:
            dict with ``args``, ``img_id``, ``image`` (transformed image),
            ``sent`` (training samples only) and ``targets``.

        Raises:
            ValueError: if the image id contains neither ``'train'`` nor
                ``'val'`` so its image directory cannot be determined.
        """
        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        ###### Image ######
        img_id = datum['img_id']
        out_dict['img_id'] = img_id

        if self.args.BUTD100:
            source = self.source
        elif 'train' in img_id:
            source = 'train2014'
        elif 'val' in img_id:
            source = 'val2014'
        else:
            # Previously this fell through and failed with an
            # UnboundLocalError on `source`.
            raise ValueError(
                f"Cannot determine image directory for img_id {img_id!r}: "
                "expected it to contain 'train' or 'val'")

        path = self.source_to_h5[source].joinpath(f"{img_id}.jpg")
        image = Image.open(path).convert('RGB')
        out_dict["image"] = self.transform(image)

        ###### Text #####
        if datum['is_train']:
            out_dict['sent'] = datum['sent'].strip()

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        """Merge a list of samples into one batch dictionary.

        Returns:
            dict with ``images`` (stacked tensor), ``img_id`` (list),
            ``img_paths`` (always an empty list), ``sent`` (list, present only
            when at least one sample carries a ``sent``), ``targets`` (list)
            and ``task`` (``'caption'``).
        """
        images = []
        img_ids = []
        targets = []
        sents = []

        for entry in batch:
            images.append(entry['image'])
            img_ids.append(entry['img_id'])
            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        batch_entry = {}
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        # Kept for interface compatibility; paths are not collected.
        batch_entry['img_paths'] = []

        # Decide from the collected sentences rather than from the loop
        # variable leaking out of the loop above.
        if sents:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets
        batch_entry['task'] = 'caption'

        return batch_entry
def pre_caption(caption,max_words):
    """Normalise a caption and cap it at ``max_words`` space-separated words.

    The text is lower-cased, the punctuation characters ``,.'!?"()*#:;~`` are
    removed, ``-`` and ``/`` become spaces, the literal ``<person>`` becomes
    ``person``, runs of two or more whitespace characters collapse to a single
    space, and trailing newlines plus surrounding spaces are stripped.
    """
    lowered = caption.lower()

    # Drop punctuation first, then turn separators into spaces.
    cleaned = re.sub(r"([,.'!?\"()*#:;~])", '', lowered)
    cleaned = cleaned.replace('-', ' ')
    cleaned = cleaned.replace('/', ' ')
    cleaned = cleaned.replace('<person>', 'person')

    # Collapse whitespace runs and trim the ends.
    cleaned = re.sub(r"\s{2,}", ' ', cleaned)
    cleaned = cleaned.rstrip('\n').strip(' ')

    # Truncate to the first max_words words if the caption is longer.
    words = cleaned.split(' ')
    if len(words) > max_words:
        return ' '.join(words[:max_words])
    return cleaned
class CCDataset(Dataset):
    """Image-caption dataset loaded from a JSON annotation file.

    The annotation file is a JSON list whose items carry at least ``image``
    (a path) and ``caption`` (a string or a list of strings).

    Args:
        split: Name stored as ``self.source`` and printed when verbose.
        raw_dataset: Stored on the instance, not otherwise used here.
        rank: Accepted for interface compatibility; unused.
        topk: A float in ``(0, 1]`` keeps a random fraction of the samples;
            a positive number keeps the first ``int(topk)`` samples;
            anything else keeps everything.
        verbose: Print progress information when true.
        args: Required configuration object. ``data_json_dir``,
            ``image_size`` and ``use_data_augmentation`` are read from it.
        mode: ``'train'`` enables the augmenting transform when
            ``args.use_data_augmentation`` is set.
        data_dir: Directory under which image paths are re-rooted.
        config_dir: Annotation file, joined onto ``args.data_json_dir``
            (an absolute path therefore replaces it).
        max_words: Maximum number of words kept per caption.
    """

    def __init__(self, split='CC', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data', config_dir='/data/mshukor/data/cc3m.json', max_words=30):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.mode = mode

        data = []
        ann_paths = [os.path.join(args.data_json_dir, name) for name in [config_dir]]
        for ann_path in ann_paths:
            # Context manager so the annotation file is closed promptly.
            with open(ann_path, 'r') as f:
                tmp = json.load(f)
            data += tmp
            print('size of', ann_path, len(tmp))
        print(len(data))

        self.max_words = max_words

        # Drop the first four '/'-separated components of the stored path and
        # re-root the remainder under data_dir.
        for e in data:
            e['image'] = os.path.join(data_dir, '/'.join(e['image'].split('/')[4:]))

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True, augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                                  'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        When the stored caption is a list, one entry is picked at random.

        Returns:
            dict with ``args``, ``img_id`` (file name without directory and
            extension), ``image`` (transformed image), and ``sent`` /
            ``targets`` (both the same pre-processed caption string).
        """
        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        # The caption is chosen before the image is transformed so the order
        # in which the random generators are consumed stays unchanged.
        raw_caption = datum['caption']
        if isinstance(raw_caption, list):
            raw_caption = random.choice(raw_caption)
        caption = pre_caption(raw_caption, self.max_words)

        ###### Image ######
        image = Image.open(datum['image']).convert('RGB')
        img_id = datum['image'].split('/')[-1].split('.')[0]

        out_dict['img_id'] = img_id
        out_dict["image"] = self.transform(image)
        out_dict['sent'] = caption
        out_dict['targets'] = caption

        return out_dict

    def collate_fn(self, batch):
        """Merge a list of samples into one batch dictionary.

        Returns:
            dict with ``images`` (stacked tensor), ``img_id`` (list),
            ``img_paths`` (always an empty list), ``sent`` (list, present only
            when at least one sample carries a ``sent``), ``targets`` (list)
            and ``task`` (``'caption'``).
        """
        images = []
        img_ids = []
        targets = []
        sents = []

        for entry in batch:
            images.append(entry['image'])
            img_ids.append(entry['img_id'])
            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        batch_entry = {}
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        # Kept for interface compatibility; paths are not collected.
        batch_entry['img_paths'] = []

        # Decide from the collected sentences rather than from the loop
        # variable leaking out of the loop above.
        if sents:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets
        batch_entry['task'] = 'caption'

        return batch_entry
def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None, verbose=False, config_dir=None):
    """Build the captioning dataset for ``split`` and wrap it in a DataLoader.

    A ``CCDataset`` is created when ``'CC'`` occurs in ``split``; otherwise a
    ``COCOCaptionFineTuneDataset`` is created.

    Args:
        args: Configuration object forwarded to the dataset.
        split: Split name; selects the dataset class (see above).
        mode: ``'train'`` shuffles (or uses a ``DistributedSampler`` when
            ``distributed`` is true); any other value yields a sequential,
            non-dropping loader.
        batch_size: Batch size of the loader.
        workers: Number of loader worker processes.
        distributed: Use a ``DistributedSampler`` in train mode.
        gpu: Forwarded to the dataset as ``rank``.
        topk: Forwarded to the dataset.
        data_dir: Forwarded to the dataset.
        local_rank: ``rank`` of the ``DistributedSampler``.
        world_size: ``num_replicas`` of the ``DistributedSampler``.
        verbose: Forwarded to the dataset; also attaches a
            ``COCOCaptionEvaluator`` to the loader as ``loader.evaluator``.
        config_dir: Annotation file for ``CCDataset``. ``None`` lets the
            dataset fall back to its own default.

    Returns:
        The ``DataLoader``, with ``task`` set to ``'caption'``.
    """
    if 'CC' in split:
        cc_kwargs = dict(data_dir=data_dir, mode=mode, topk=topk, args=args, verbose=verbose, rank=gpu)
        # Forward config_dir only when given: passing None would override the
        # dataset's default and make its os.path.join call raise TypeError.
        if config_dir is not None:
            cc_kwargs['config_dir'] = config_dir
        dataset = CCDataset(split, **cc_kwargs)
    else:
        dataset = COCOCaptionFineTuneDataset(
            split,
            rank=gpu,
            topk=topk,
            verbose=verbose,
            args=args,
            mode=mode, data_dir=data_dir)

    if mode == 'train':
        if distributed:
            train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
        else:
            train_sampler = None
        # shuffle and sampler are mutually exclusive in DataLoader.
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, sampler=train_sampler,
            collate_fn=dataset.collate_fn)
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size, shuffle=False,
            num_workers=workers, pin_memory=True,
            sampler=None,
            collate_fn=dataset.collate_fn,
            drop_last=False)

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'

    return loader
class COCOCaptionEvaluator:
    """Thin wrapper around ``language_evaluation.CocoEvaluator``."""

    def __init__(self):
        # Imported lazily so the dependency is only needed when an evaluator
        # is actually constructed.
        import language_evaluation as lang_eval

        self.evaluator = lang_eval.CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        """Return the wrapped evaluator's result for ``predicts`` vs ``answers``."""
        return self.evaluator.run_evaluation(predicts, answers)