from pathlib import Path
import json
import random
import re

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from PIL import Image

from dataset.video_utils import VIDEO_READER_FUNCS


class MSRVTTCaptionFineTuneDataset(Dataset):
    def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1, verbose=True,
                 args=None, mode='train', data_dir='/data/mshukor/data', black_image=False):
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.mode = mode

        data_dir = Path(data_dir)
        dataset_dir = data_dir.joinpath('annotation')
        coco_img_dir = data_dir.joinpath('videos/all')

        self.black_image = black_image

        self.source = split
        if self.verbose:
            print('Data source: ', self.source)

        # video decoding settings
        self.num_frames = args.num_frames  # e.g. 4
        self.video_reader = VIDEO_READER_FUNCS['decord']
        self.as_images = args.as_images  # e.g. True
        self.num_tries = args.num_tries  # e.g. 2
        self.sample_type = args.sample_type  # e.g. 'rand'

        # CLIP normalization statistics
        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                         (0.26862954, 0.26130258, 0.27577711))
        type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(),
            type_transform,
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            type_transform,
            normalize,
        ])

        # load the caption annotations for this split
        data_info_path = dataset_dir.joinpath(split + '.json')
        with open(data_info_path) as f:
            karpathy_data = json.load(f)

        n_images = 0
        data = []
        for datum in karpathy_data:
            if 'train' in split:
                caption = datum['caption']
                if isinstance(caption, list):
                    # one training example per caption, all captions kept as targets
                    for d in caption:
                        img_id = ".".join(datum['video'].split('.')[:-1])
                        new_datum = {
                            'img_id': img_id,
                            'sent': d.strip(),
                            'targets': [k.strip() for k in caption],
                            'is_train': True,
                            'video': datum['video'],
                        }
                        data.append(new_datum)
                else:
                    img_id = ".".join(datum['video'].split('.')[:-1])
                    new_datum = {
                        'img_id': img_id,
                        'sent': caption.strip(),
                        'targets': caption.strip(),
                        'is_train': True,
                        'video': datum['video'],
                    }
                    data.append(new_datum)
            else:
                caption = datum['caption']
                if not isinstance(caption, list):
                    caption = [caption]
                img_id = ".".join(datum['video'].split('.')[:-1])
                new_datum = {
                    'img_id': img_id,
                    'targets': [d.strip() for d in caption],
                    'is_train': False,
                    'video': datum['video'],
                }
                data.append(new_datum)
            n_images += 1

        if self.verbose:
            print(f"{self.source} has {n_images} images")
            print(f"Loaded {len(data)} data from", split)

        # optionally subsample the data: a float in (0, 1] is a fraction, a positive int an absolute count
        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        self.source_to_h5 = {}
        self.source_to_h5.update({
            'all': coco_img_dir,
        })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        out_dict = {}
        out_dict['args'] = self.args

        # retry with a random replacement video if decoding fails
        for i in range(self.num_tries):
            try:
                datum = self.data[idx]

                ###### Video ######
                img_id = datum['img_id']
                out_dict['img_id'] = img_id

                video = datum['video']
                path = str(self.source_to_h5['all'].joinpath(f"{video}"))

                max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
                frames, frame_indices, video_duration = self.video_reader(
                    path, self.num_frames, self.sample_type, max_num_frames=max_num_frames
                )
            except Exception as e:
                print(i, path)
                idx = random.randint(0, len(self) - 1)
                print(
                    f"Caught exception {e} when loading video {path}, "
                    f"randomly sample a new video as replacement"
                )
                continue
            break  # decoding succeeded, stop retrying
        else:
            raise RuntimeError(f"Failed to load a video after {self.num_tries} tries")

        out_dict["image"] = self.transform(frames)
        if self.black_image:
            out_dict["image"] = torch.zeros_like(out_dict["image"])

        if not self.as_images:
            out_dict["image"] = out_dict["image"].permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)

        if datum['is_train']:
            sent = datum['sent'].strip()
            out_dict['sent'] = sent

        if 'targets' in datum:
            out_dict['targets'] = datum['targets']

        return out_dict

    def collate_fn(self, batch):
        batch_entry = {}

        B = len(batch)

        # only used when entries carry pre-tokenized targets; requires self.tokenizer to be set
        if 'target_ids' in batch[0]:
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id

        targets = []
        img_ids = []
        img_paths = []
        input_text = []
        images = []
        sents = []

        for i, entry in enumerate(batch):
            images.append(entry['image'])
            img_ids.append(entry['img_id'])

            if 'target_ids' in entry:
                target_ids[i, :entry['target_length']] = entry['target_ids']

            if 'targets' in entry:
                targets.append(entry['targets'])
            if 'sent' in entry:
                sents.append(entry['sent'])

        # if self.args.use_vision:
        batch_entry['images'] = torch.stack(images)
        batch_entry['img_id'] = img_ids
        batch_entry['img_paths'] = img_paths

        if 'sent' in entry:
            batch_entry['sent'] = sents

        batch_entry['targets'] = targets

        batch_entry['task'] = 'caption'

        return batch_entry


def pre_caption(caption, max_words):
    # lowercase, strip punctuation, replace '<person>' tags, and truncate to max_words words
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption


def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None,
               verbose=False, config_dir=None, black_image=False):

    dataset = MSRVTTCaptionFineTuneDataset(
        split,
        rank=gpu,
        topk=topk,
        verbose=verbose,
        args=args,
        mode=mode,
        data_dir=data_dir,
        black_image=black_image)

    if distributed and mode == 'train':
        train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    else:
        train_sampler = None

    if mode == 'train':
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=(train_sampler is None),
            num_workers=workers, pin_memory=True, sampler=train_sampler,
            collate_fn=dataset.collate_fn)
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=workers, pin_memory=True,
            sampler=None,
            collate_fn=dataset.collate_fn,
            drop_last=False)

    if verbose:
        loader.evaluator = COCOCaptionEvaluator()

    loader.task = 'caption'

    return loader


class COCOCaptionEvaluator:
    def __init__(self):
        # optional dependency, only needed for evaluation
        import language_evaluation
        self.evaluator = language_evaluation.CocoEvaluator(verbose=False)

    def evaluate(self, predicts, answers):
        results = self.evaluator.run_evaluation(predicts, answers)
        return results
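
# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline). It only fills in
# the `args` attributes this file actually reads (num_frames, as_images,
# num_tries, sample_type, image_size, use_data_augmentation); the values
# below (image_size=224, split='train') and the annotation/video layout under
# data_dir (`annotation/train.json`, `videos/all/`) are assumptions, not part
# of this module. Real runs pass the full experiment config as `args`.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        num_frames=4,
        as_images=True,
        num_tries=2,
        sample_type='rand',
        image_size=224,
        use_data_augmentation=True,
    )

    # verbose=False avoids constructing COCOCaptionEvaluator (language_evaluation dependency);
    # workers=0 keeps the sketch single-process.
    loader = get_loader(
        args,
        split='train',
        mode='train',
        batch_size=2,
        workers=0,
        distributed=False,
        data_dir='/data/mshukor/data',
        verbose=False,
    )

    batch = next(iter(loader))
    # images: (B, T, C, H, W) when args.as_images is True, (B, C, T, H, W) otherwise
    print(batch['images'].shape, batch['sent'][:2])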