# eP-ALM/dataset/video_caption.py
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
import json
import random
import torch
from PIL import Image
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
import re
from dataset.video_utils import VIDEO_READER_FUNCS
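# MSR-VTT video-captioning dataset for eP-ALM fine-tuning: each item decodes a
# fixed number of frames from one video with decord and pairs them with captions.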
class MSRVTTCaptionFineTuneDataset(Dataset):
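    """MSR-VTT captioning dataset.

    Reads annotations from <data_dir>/annotation/<split>.json and videos from
    <data_dir>/videos/all. During training, a video with several reference
    captions yields one item per caption; at evaluation time each video yields
    a single item whose 'targets' holds every reference caption.
    """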
def __init__(self, split='karpathy_train', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train',
data_dir='/data/mshukor/data', black_image=False):
super().__init__()
self.raw_dataset = raw_dataset
self.topk = topk
self.verbose = verbose
self.args = args
self.mode = mode
data_dir = Path(data_dir)
dataset_dir = data_dir.joinpath('annotation')
        video_dir = data_dir.joinpath('videos/all')  # directory with the raw MSR-VTT videos
self.black_image = black_image
self.source = split
if self.verbose:
print('Data source: ', self.source)
        # video decoding configuration
        self.num_frames = args.num_frames      # frames sampled per clip, e.g. 4
        self.video_reader = VIDEO_READER_FUNCS['decord']
        self.as_images = args.as_images        # True: keep frames as (T, C, H, W); False: permute to (C, T, H, W)
        self.num_tries = args.num_tries        # retries when a video fails to decode, e.g. 2
        self.sample_type = args.sample_type    # frame sampling strategy, e.g. 'rand'
        # CLIP normalization statistics; decord yields uint8 frames, so rescale to [0, 1] first
        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        type_transform = transforms.Lambda(lambda x: x.float().div(255.0))
self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.RandAugment(),
type_transform,
normalize,
])
self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
type_transform,
normalize,
])
data_info_path = dataset_dir.joinpath(split+'.json')
with open(data_info_path) as f:
karpathy_data = json.load(f)
n_images = 0
data = []
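        # One training sample per (video, caption) pair; evaluation keeps one sample
        # per video with all reference captions as targets. 'img_id' is the video
        # filename with its extension stripped.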
for datum in karpathy_data:
            if 'train' in split:
caption = datum['caption']
if isinstance(caption, list):
for d in caption:
img_id = ".".join(datum['video'].split('.')[:-1])
new_datum = {
'img_id': img_id,
'sent': d.strip(),
'targets': [k.strip() for k in caption],
'is_train': True,
'video': datum['video'],
}
data.append(new_datum)
else:
img_id = ".".join(datum['video'].split('.')[:-1])
new_datum = {
'img_id': img_id,
'sent': caption.strip(),
'targets': caption.strip(),
'is_train': True,
'video': datum['video'],
}
data.append(new_datum)
else:
caption = datum['caption']
if not isinstance(caption, list):
caption = [caption]
img_id = ".".join(datum['video'].split('.')[:-1])
new_datum = {
'img_id': img_id,
'targets': [d.strip() for d in caption],
'is_train': False,
'video': datum['video'],
}
data.append(new_datum)
n_images += 1
if self.verbose:
            print(f"{self.source} has {n_images} videos")
            print(f"Loaded {len(data)} examples from", split)
if isinstance(self.topk, float) and (0 < self.topk <= 1):
used_samples = int(self.topk * len(data))
data = random.sample(data, used_samples)
if self.verbose:
                print(f"Use only {len(data)} examples")
elif self.topk > 0:
data = data[:int(self.topk)]
if self.verbose:
                print(f"Use only {len(data)} examples")
self.data = data
if self.verbose:
print("# all sentences:", len(self.data))
self.image_size = self.args.image_size
if mode == "train" and self.args.use_data_augmentation:
self.transform = self.train_transform
else:
self.transform = self.test_transform
        # map data source name -> directory containing the raw videos
        self.source_to_h5 = {}
        self.source_to_h5.update({
            'all': video_dir,
        })
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
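        """Decode one clip (retrying with a randomly chosen replacement video on
        failure), apply the train/test transform, and attach the caption target(s)."""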
out_dict = {}
out_dict['args'] = self.args
        path = None
        for i in range(self.num_tries):
            try:
                datum = self.data[idx]

                ###### Video ######
                img_id = datum['img_id']
                out_dict['img_id'] = img_id

                video = datum['video']
                path = str(self.source_to_h5['all'].joinpath(f"{video}"))

                max_num_frames = self.max_num_frames if hasattr(self, "max_num_frames") else -1
                frames, frame_indices, video_duration = self.video_reader(
                    path, self.num_frames, self.sample_type, max_num_frames=max_num_frames
                )
                break  # decoded successfully, stop retrying
            except Exception as e:
                print(
                    f"Caught exception {e} when loading video {path} (try {i}), "
                    f"randomly sampling a new video as replacement"
                )
                idx = random.randint(0, len(self) - 1)
        else:
            # every retry failed: fail loudly instead of hitting an unbound `frames` below
            raise RuntimeError(f"Could not decode a video after {self.num_tries} tries (last path: {path})")
out_dict["image"] = self.transform(frames)
if self.black_image:
out_dict["image"] = torch.zeros_like(out_dict["image"])
        if not self.as_images:
            # treat the clip as a single video tensor: (T, C, H, W) -> (C, T, H, W)
            out_dict["image"] = out_dict["image"].permute(1, 0, 2, 3)
if datum['is_train']:
sent = datum['sent'].strip()
out_dict['sent'] = sent
if 'targets' in datum:
out_dict['targets'] = datum['targets']
return out_dict
def collate_fn(self, batch):
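        """Stack per-item frame tensors into a batch and gather ids, captions and targets."""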
batch_entry = {}
B = len(batch)
        if 'target_ids' in batch[0]:
            # only taken when items carry pre-tokenized targets (requires self.tokenizer to be set)
            T_W_L = max(entry['target_length'] for entry in batch)
            target_ids = torch.ones(B, T_W_L, dtype=torch.long) * self.tokenizer.pad_token_id
targets = []
img_ids = []
img_paths = []
input_text = []
images = []
sents = []
for i, entry in enumerate(batch):
images.append(entry['image'])
img_ids.append(entry['img_id'])
if 'target_ids' in entry:
target_ids[i, :entry['target_length']] = entry['target_ids']
if 'targets' in entry:
targets.append(entry['targets'])
if 'sent' in entry:
sents.append(entry['sent'])
# if self.args.use_vision:
batch_entry['images'] = torch.stack(images)
batch_entry['img_id'] = img_ids
batch_entry['img_paths'] = img_paths
        if sents:
            batch_entry['sent'] = sents
batch_entry['targets'] = targets
batch_entry['task'] = 'caption'
return batch_entry
def pre_caption(caption, max_words):
    """Lower-case a caption, strip punctuation and extra whitespace, and truncate it to `max_words` words."""
caption = re.sub(
r"([,.'!?\"()*#:;~])",
'',
caption.lower(),
).replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
caption = re.sub(
r"\s{2,}",
' ',
caption,
)
caption = caption.rstrip('\n')
caption = caption.strip(' ')
    # truncate caption
    caption_words = caption.split(' ')
    if len(caption_words) > max_words:
        caption = ' '.join(caption_words[:max_words])
return caption
def get_loader(args, split='train', mode='train',
batch_size=32, workers=4, distributed=False, gpu=0,
topk=-1, data_dir='/data/mshukor/data', local_rank=None, world_size=None, verbose=False,
config_dir=None, black_image=False):
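    """Build a DataLoader over MSRVTTCaptionFineTuneDataset for the given split.

    Training uses a DistributedSampler when `distributed` is set (plain shuffling
    otherwise); evaluation loaders are unshuffled and keep incomplete batches.
    """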
dataset = MSRVTTCaptionFineTuneDataset(
split,
rank=gpu,
topk=topk,
verbose=verbose,
args=args,
mode=mode, data_dir=data_dir, black_image=black_image)
if distributed and mode == 'train':
train_sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
else:
train_sampler = None
if mode == 'train':
loader = DataLoader(
dataset, batch_size=batch_size, shuffle=(train_sampler is None),
num_workers=workers, pin_memory=True, sampler=train_sampler,
collate_fn=dataset.collate_fn)
else:
loader = DataLoader(
dataset,
batch_size=batch_size, shuffle=False,
num_workers=workers, pin_memory=True,
sampler=None,
collate_fn=dataset.collate_fn,
drop_last=False)
if verbose:
loader.evaluator = COCOCaptionEvaluator()
loader.task = 'caption'
return loader
class COCOCaptionEvaluator:
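    """Thin wrapper around the COCO caption metrics from the `language_evaluation` package."""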
def __init__(self):
import language_evaluation
self.evaluator = language_evaluation.CocoEvaluator(verbose=False)
def evaluate(self, predicts, answers):
results = self.evaluator.run_evaluation(predicts, answers)
return results
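
# Minimal usage sketch (assumptions: an `args` namespace exposing image_size,
# num_frames, as_images, num_tries, sample_type and use_data_augmentation, and
# MSR-VTT data laid out under <data_dir>/annotation and <data_dir>/videos/all;
# the split name and paths below are illustrative):
#
#   loader = get_loader(args, split='karpathy_train', mode='train', batch_size=8,
#                       workers=4, data_dir='/data/mshukor/data')
#   for batch in loader:
#       frames, captions = batch['images'], batch['sent']  # frames: (B, T, C, H, W) when as_images=True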