import json
import random
import re
from pathlib import Path

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from dataset.randaugment import RandomAugment


class COCOCaptionFineTuneDataset(Dataset):
    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1,
                 verbose=True, args=None, mode='train',
                 data_dir='/data/mshukor/data', black_image=False):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.args.BUTD100 = False

        self.mode = mode

        dataset_dir = Path(data_dir)
        coco_dir = dataset_dir.joinpath('COCO')
        vg_dir = dataset_dir.joinpath('VG')
        coco_img_dir = coco_dir.joinpath('images/')
        coco_feature_dir = coco_dir.joinpath('features')

        self.black_image = black_image

        # Loading datasets to data
        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # CLIP image normalization statistics
        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                         (0.26862954, 0.26130258, 0.27577711))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0),
                                         interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(2, 7, isPIL=True,
                          augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness',
                                'Sharpness', 'ShearX', 'ShearY', 'TranslateX',
                                'TranslateY', 'Rotate']),
            transforms.ToTensor(),
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size),
                              interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
        ])

        # Karpathy split annotations
        data_info_path = dataset_dir.joinpath('COCO/dataset_coco.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        split_rename = {
            'train': 'train',
            'restval': 'train',
            'val': 'val',
            'test': 'test',
        }

        n_images = 0

        data = []
        for datum in karpathy_data['images']:
            re_split = split_rename[datum['split']]
            if re_split != self.source.split('_')[-1]:
                continue

            if re_split == 'train':
                # One training example per reference sentence
                for d in datum['sentences']:
                    img_id = datum['filename'].split('.')[0]
                    new_datum = {
                        'img_id': img_id,
                        'sent': d['raw'].strip(),
                        'targets': [s['raw'].strip() for s in datum['sentences']],
                        'is_train': True,
                    }
                    data.append(new_datum)
            else:
                # One eval example per image; all references kept as targets
                img_id = datum['filename'].split('.')[0]
                new_datum = {
                    'img_id': img_id,
                    'targets': [s['raw'].strip() for s in datum['sentences']],
                    'is_train': False,
                }
                data.append(new_datum)

            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        # topk in (0, 1] is treated as a fraction, topk > 1 as a sample count
        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == 'train' and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        self.source_to_h5 = {
            'train2014': coco_img_dir.joinpath('train2014'),
            'val2014': coco_img_dir.joinpath('val2014'),
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_dict = {}
        out_dict['args'] = self.args

        datum = self.data[idx]

        ###### Image ######
        img_id = datum['img_id']
        out_dict['img_id'] = img_id

        if self.args.BUTD100:
            source = self.source
        else:
            if 'train' in img_id:
                source = 'train2014'
            elif 'val' in img_id:
                source = 'val2014'

        path = self.source_to_h5[source].joinpath(f"{img_id}.jpg")
        image = Image.open(path).convert('RGB')

        out_dict['image'] = self.transform(image)
        if self.black_image:
            out_dict['image'] = torch.zeros_like(out_dict['image'])

        if datum['is_train']:
            sent = datum['sent'].strip()
            out_dict['sent'] = sent

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # Only taken when a tokenizer-based pipeline populates 'target_ids';
        # __getitem__ above never does, so this branch is inactive here.
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            images.append(entry['image'])
            img_ids.append(entry['img_id'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        # Checks the last entry of the loop; captions are present either for
        # every example (train) or none (eval)
        if 'sent' in entry:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets
        batch_entry['task'] = 'caption'

        return batch_entry


def pre_caption(caption, max_words):
    # Lower-case, strip punctuation, and normalize separators
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # Truncate caption to max_words
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption


def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None,
               world_size=None, verbose=False, config_dir=None, black_image=False):

    dataset = COCOCaptionFineTuneDataset(
        split,
        # raw_dataset=_dset,
        rank=gpu,
        topk=topk,
        verbose=verbose,
        args=args,
        mode=mode,
        data_dir=data_dir,
        black_image=black_image)

    if distributed and mode == 'train':
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    else:
        train_sampler = None

    if mode == 'train':
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, sampler=train_sampler,
            collate_fn=dataset.collate_fn)
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers,
            pin_memory=True,
            sampler=None,
            collate_fn=dataset.collate_fn,
            drop_last=False)

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'

    return loader


class COCOCaptionEvaluator:
    def __init__(self):
        import language_evaluation
        self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        results = self.evaluator.run_evaluation(predicts, answers)
        return results
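

if __name__ == '__main__':
    # Minimal usage sketch, assuming the COCO images and the Karpathy
    # dataset_coco.json are laid out under data_dir as expected above.
    # The args namespace below is an assumption for illustration: this
    # dataset only reads args.image_size and args.use_data_augmentation
    # (BUTD100 is forced to False in __init__).
    from types import SimpleNamespace

    args = SimpleNamespace(image_size=224, use_data_augmentation=True)
    loader = get_loader(args, split='karpathy_train', mode='train',
                        batch_size=8, workers=0,
                        data_dir='/data/mshukor/data')
    batch = next(iter(loader))
    print(batch['images'].shape)  # torch.Size([8, 3, 224, 224])
    print(batch['sent'][0])       # a raw reference caption string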