from torch.utils.data import DataLoader, Dataset, Sampler
from pathlib import Path
from collections import defaultdict
import json
import random
from multiprocessing import Pool
import h5py
import pickle
import math
from tqdm import tqdm
import torch
import numpy as np
from copy import deepcopy
from PIL import Image
from torch.utils.data.distributed import DistributedSampler

# from transformers import T5TokenizerFast, BartTokenizer
# from tokenization import VLT5TokenizerFast
# from vis_encoder import _transform
# from vqa_raw_data import augmentation_transform

from dataset.randaugment import RandomAugment

from torch import nn
from torchvision import transforms

import os
import re

# project_dir = Path(__file__).resolve().parent.parent  # VLT5
# workspace_dir = project_dir.parent
# dataset_dir = workspace_dir.joinpath('datasets/').resolve()
# coco_dir = dataset_dir.joinpath('COCO')
# vg_dir = dataset_dir.joinpath('VG')
# coco_img_dir = coco_dir.joinpath('images/')
# coco_feature_dir = coco_dir.joinpath('features')


class COCOCaptionFineTuneDataset(Dataset):
    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1,
                 verbose=True, args=None, mode='train', data_dir='/data/mshukor/data'):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.args.BUTD100 = False

        self.mode = mode

        dataset_dir = Path(data_dir)
        coco_dir = dataset_dir.joinpath('COCO')
        vg_dir = dataset_dir.joinpath('VG')
        coco_img_dir = coco_dir.joinpath('images/')
        coco_feature_dir = coco_dir.joinpath('features')

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                         (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True,
                          augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        # if self.args.tokenizer is None:
        #     self.args.tokenizer = self.args.backbone

        # if 't5' in self.args.tokenizer:
        #     if self.args.use_vision:
        #         self.tokenizer = VLT5TokenizerFast.from_pretrained(
        #             args.backbone,
        #             # max_length=self.args.max_text_length,
        #             do_lower_case=self.args.do_lower_case)
        #     else:
        #         self.tokenizer = T5TokenizerFast.from_pretrained(
        #             args.backbone,
        #             # max_length=self.args.max_text_length,
        #             do_lower_case=self.args.do_lower_case)
        # elif 'bart' in self.args.tokenizer:
        #     self.tokenizer = BartTokenizer.from_pretrained(
        #         args.backbone,
        #         # max_length=self.args.max_text_length,
        #         do_lower_case=self.args.do_lower_case)
        #     additional_special_tokens = [f'<extra_id_{i}>' for i in range(100-1, -1, -1)] + \
        #         [f'<vis_extra_id_{i}>' for i in range(100-1, -1, -1)]
        #     special_tokens_dict = {'additional_special_tokens': additional_special_tokens}
        #     num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)

        # if self.args.oscar_tags:
        #     # Load VG Classes
        #     vg_classes = []
        #     with open(vg_dir.joinpath('objects_vocab.txt')) as f:
        #         for obj in f.readlines():
        #             vg_classes.append(obj.split(',')[0].lower().strip())
        #     self.vg_classes = vg_classes

        data_info_path = dataset_dir.joinpath('COCO/dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test'
        }

        n_images = 0

        data = []
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]
            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                for d in datum['sentences']:
                    # if self.args.BUTD100:
                    #     img_id = str(int(datum['filename'].split('.')[0].split('_')[-1]))
                    # else:
                    img_id = datum['filename'].split('.')[0]

                    new_datum = {
                        'img_id': img_id,
                        'sent': d['raw'].strip(),
                        'targets': [d['raw'].strip() for d in datum['sentences']],
                        'is_train': True,
                    }
                    data.append(new_datum)
            else:
                # if self.args.BUTD100:
                #     img_id = str(int(datum['filename'].split('.')[0].split('_')[-1]))
                # else:
                img_id = datum['filename'].split('.')[0]

                new_datum = {
                    'img_id': img_id,
                    # 'sent': d['raw'],
                    'targets': [d['raw'].strip() for d in datum['sentences']],
                    'is_train': False,
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        # self.n_gpus = torch.cuda.device_count()
        # self.rank = rank

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        # self.n_boxes = args.n_boxes
        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        self.source_to_h5 = {}
        # if self.args.max_n_boxes == 36:
        self.source_to_h5.update({
            'train2014': coco_img_dir.joinpath('train2014'),
            'val2014': coco_img_dir.joinpath('val2014'),
        })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        ###### Image ######
        # if self.args.use_vision:
        img_id = datum['img_id']
        out_dict['img_id'] = img_id

        if self.args.BUTD100:
            source = self.source
        else:
            if 'train' in img_id:
                source = 'train2014'
            elif 'val' in img_id:
                source = 'val2014'

        path = self.source_to_h5[source].joinpath(f"{img_id}.jpg")
        image = Image.open(path).convert('RGB')

        out_dict["image"] = self.transform(image)
        # out_dict['n_boxes'] = self.args.n_boxes

        ###### Text #####
        # if self.args.no_prefix:
        #     input_text = ''
        #     input_ids = []
        # else:
        #     if self.args.prefix is None:
        #         prefix = f'{self.args.prompt}'
        #     elif self.args.prefix == 'span':
        #         prefix = "span prediction:"
        #     elif self.args.prefix == 'denoise':
        #         prefix = "denoise text: "
        #     elif self.args.prefix == 'mask':
        #         if 'bart' in self.args.tokenizer:
        #             prefix = "<mask>"
        #     input_tokens = [prefix]
        #     # if self.args.oscar_tags:
        #     #     prefix = f'describe image with tags:'
        #     #     input_tokens = [prefix]
        #     #     obj_ids = f[f'{img_id}/obj_id'][()]
        #     #     for obj_id in obj_ids:
        #     #         obj = self.vg_classes[obj_id]
        #     #         if obj not in input_tokens:
        #     #             input_tokens.append(obj)
        #     input_text = ' '.join(input_tokens)
        #     # if 't5' in self.args.tokenizer:
        #     #     input_ids = self.tokenizer.encode(
        #     #         input_text,
        #     #         max_length=self.args.max_text_length, truncation=True)
        #     # elif 'bart' in self.args.tokenizer:
        #     #     input_ids = self.tokenizer.encode(
        #     #         input_text,
        #     #         max_length=self.args.max_text_length, truncation=True)
        #     # else:
        #     #     input_ids = self.tokenizer.convert_tokens_to_ids(
        #     #         self.tokenizer.tokenize(input_text)[:self.args.max_text_length - 1]
        #     #         + ['[SEP]'])

        # out_dict['input_text'] = input_text
        # out_dict['input_ids'] = torch.LongTensor(input_ids)
        # out_dict['input_length'] = len(input_ids)

        if datum['is_train']:
            sent = datum['sent'].strip()
            # if 't5' in self.args.tokenizer:
            #     target_ids = self.tokenizer.encode(sent, max_length=self.args.gen_max_length, truncation=True)
            # elif 'bart' in self.args.tokenizer:
            #     target_ids = self.tokenizer.encode(sent, max_length=self.args.gen_max_length, truncation=True)
            # assert len(target_ids) <= self.args.gen_max_length, len(target_ids)
            out_dict['sent'] = sent
            # out_dict['target_ids'] = torch.LongTensor(target_ids)
            # out_dict['target_length'] = len(target_ids)

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # S_W_L = max(entry['input_length'] for entry in batch)
        # input_ids = torch.ones(B, S_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        # if self.args.no_prefix:
        #     assert input_ids.size() == (B, 0)

        # if self.args.use_vision:
        #     pass

        # Dead branch kept from the tokenizer-based pipeline: __getitem__ above never
        # emits 'target_ids', and self.tokenizer is not initialized in this version.
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        # sentences = []

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            # input_ids[i, :entry['input_length']] = entry['input_ids']

            # if self.args.use_vision:
            #     n_boxes = entry['n_boxes']
            images.append(entry['image'])
            img_ids.append(entry['img_id'])
            # img_paths.append(entry['img_path'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            # if 'input_text' in entry:
            #     input_text.append(entry['input_text'])

            # sentences.append(entry['sent'])

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        # batch_entry['input_ids'] = input_ids
        # if 'target_ids' in batch[0]:
        #     word_mask = target_ids != self.tokenizer.pad_token_id
        #     target_ids[~word_mask] = -100
        #     batch_entry['target_ids'] = target_ids

        # if self.args.use_vision:
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        if 'sent' in entry:
            batch_entry['sent'] = sents

        # batch_entry['sent'] = sentences
        # batch_entry['input_text'] = input_text

        batch_entry['targets'] = targets

        batch_entry['task'] = 'caption'

        return batch_entry


def pre_caption(caption, max_words):
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption


class CCDataset(Dataset):
    def __init__(self, split='CC', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data', config_dir='/data/mshukor/data/cc3m.json', max_words=30):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args

        self.mode = mode

        data = []
        ann_files = [config_dir]
        ann_file = []
        for p in ann_files:
            ann_file.append(os.path.join(args.data_json_dir, p))

        for f in ann_file:
            tmp = json.load(open(f, 'r'))
            data += tmp
            print('size of', f, len(tmp))
        print(len(data))

        self.max_words = max_words

        for e in data:
            e['image'] = os.path.join(data_dir, '/'.join(e['image'].split('/')[4:]))

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                         (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True,
                          augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                                'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        if isinstance(datum['caption'], list):
            caption = pre_caption(random.choice(datum['caption']), self.max_words)
        else:
            caption = pre_caption(datum['caption'], self.max_words)

        ###### Image ######
        image = Image.open(datum['image']).convert('RGB')
        img_id = datum['image'].split('/')[-1].split('.')[0]
        out_dict['img_id'] = img_id

        out_dict["image"] = self.transform(image)

        out_dict['sent'] = caption
        out_dict['targets'] = caption

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # Dead branch kept from the tokenizer-based pipeline: __getitem__ above never
        # emits 'target_ids', and self.tokenizer is not initialized in this version.
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            images.append(entry['image'])
            img_ids.append(entry['img_id'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        # if self.args.use_vision:
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        if 'sent' in entry:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets

        batch_entry['task'] = 'caption'

        return batch_entry


def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None,
               verbose=False, config_dir=None):

    # if 'mscoco' in split:
    #     verbose = (gpu == 0)

    if 'CC' in split:
        dataset = CCDataset(split,
                            data_dir=data_dir,
                            mode=mode,
                            topk=topk,
                            args=args,
                            verbose=verbose,
                            rank=gpu,
                            config_dir=config_dir)
    else:
        dataset = COCOCaptionFineTuneDataset(
            split,
            # raw_dataset=_dset,
            rank=gpu,
            topk=topk,
            verbose=verbose,
            args=args,
            mode=mode,
            data_dir=data_dir)

    if distributed and mode == 'train':
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
        # train_sampler = DistributedSampler(dataset)
        # train_sampler = RandomNonreplacmentSampler(dataset, dataset.n_iter)
    else:
        train_sampler = None

    if mode == 'train':
        loader = DataLoader(
            dataset, batch_size=batch_size,
            shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, sampler=train_sampler,
            collate_fn=dataset.collate_fn)
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers, pin_memory=True,
            sampler=None,
            collate_fn=dataset.collate_fn,
            drop_last=False)

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'

    return loader


class COCOCaptionEvaluator:
    def __init__(self):
        import language_evaluation
        self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        results = self.evaluator.run_evaluation(predicts, answers)

        return results
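
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the original
# training pipeline). It shows how get_loader is typically called. The
# argparse.Namespace below carries only the fields this module actually reads
# (image_size, use_data_augmentation, data_json_dir); the paths are
# placeholders, and the COCO branch expects COCO/dataset_coco.json plus
# COCO/images/{train2014,val2014} under data_dir. verbose is left False so the
# optional language_evaluation package is not required.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(
        image_size=224,
        use_data_augmentation=True,
        data_json_dir='/path/to/data_json',  # only used by CCDataset
    )

    loader = get_loader(
        args,
        split='karpathy_train',
        mode='train',
        batch_size=8,
        workers=0,
        distributed=False,
        topk=100,                # keep only 100 samples for a quick smoke test
        data_dir='/path/to/data',
        verbose=False)

    batch = next(iter(loader))
    print(batch['images'].shape)  # expected: torch.Size([8, 3, 224, 224])
    print(batch['sent'][:2])      # first two target captions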