|
|
|
|
|
from torch.utils.data import DataLoader, Dataset |
|
from pathlib import Path |
|
import json |
|
import random |
|
import torch |
|
import numpy as np |
|
|
|
from torchvision import transforms |
|
|
|
|
|
from torch.utils.data.distributed import DistributedSampler |
|
|
|
|
|
from PIL import Image |
|
import re |
|
|
|
from dataset.video_utils import VIDEO_READER_FUNCS |
|
|
|
|
|
# Filesystem anchors derived from this module's symlink-resolved location:
# project_dir is the directory two levels above this file, and workspace_dir is
# the directory containing project_dir. (Neither name is referenced elsewhere in
# the visible part of this module.)
project_dir = Path(__file__).resolve().parent.parent

workspace_dir = project_dir.parent
|
|
|
|
|
|
|
class MSRVTTVQAFineTuneDataset(Dataset):
    """Map-style dataset of MSRVTT-QA (video question answering) samples.

    Each item decodes frames from one video, applies the image transform and
    returns the question, its id and, when the annotation has one, the answer
    label(s). Annotations are read from ``<data_dir>/annotation/<split>.json``
    (one file per comma-separated split name) and videos from
    ``<data_dir>/videos/all``.
    """

    def __init__(self, split='train,valid', raw_dataset=None, rank=-1, topk=-1, verbose=True, args=None, mode='train', data_dir=None):
        """
        Args:
            split: comma-separated split names; every ``<name>.json`` annotation
                file is loaded and the entries are concatenated.
            raw_dataset: companion ``MSRVTTVQADataset``; only stored here.
            rank: unused; kept for interface compatibility.
            topk: a float in (0, 1] keeps a random fraction of the samples;
                otherwise a value > 0 keeps the first ``int(topk)`` samples.
            verbose: print loading statistics.
            args: config namespace; reads ``num_frames``, ``as_images``,
                ``num_tries``, ``sample_type``, ``image_size`` and
                ``use_data_augmentation``.
            mode: ``'train'`` selects the augmenting transform when
                ``args.use_data_augmentation`` is set.
            data_dir: dataset root directory.
        """
        super().__init__()

        self.raw_dataset = raw_dataset
        self.topk = topk
        self.verbose = verbose
        self.args = args
        self.mode = mode

        data_dir = Path(data_dir)
        dataset_dir = data_dir.joinpath('annotation')
        video_dir = data_dir.joinpath('videos/all')

        self.num_frames = args.num_frames
        self.video_reader = VIDEO_READER_FUNCS['decord']
        self.as_images = args.as_images
        self.num_tries = args.num_tries
        self.sample_type = args.sample_type

        # Per-channel mean/std normalisation (these look like CLIP's
        # preprocessing statistics -- confirm against the visual backbone).
        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

        # Frames are presumably uint8 tensors from the video reader; scale to [0, 1]
        # after the geometric/colour transforms.
        type_transform = transforms.Lambda(lambda x: x.float().div(255.0))

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size, scale=(0.5, 1.0), interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.RandAugment(),
            type_transform,
            normalize,
        ])
        self.test_transform = transforms.Compose([
            transforms.Resize((args.image_size, args.image_size), interpolation=Image.BICUBIC),
            type_transform,
            normalize,
        ])

        self.sources = split.split(',')
        if self.verbose:
            print('Data sources: ', self.sources)

        # One annotation file per split name, concatenated in order (mirrors
        # MSRVTTVQADataset); a multi-split string is never a file name itself.
        data = []
        for source in self.sources:
            with open(dataset_dir.joinpath(source + '.json')) as f:
                data.extend(json.load(f))

        if isinstance(self.topk, float) and (0 < self.topk <= 1):
            used_samples = int(self.topk * len(data))
            data = random.sample(data, used_samples)
            if self.verbose:
                print(f"Use only {len(data)} data")
        elif self.topk > 0:
            data = data[:int(self.topk)]
            if self.verbose:
                print(f"Use only {len(data)} data")

        self.data = data

        if self.verbose:
            print("# all sentences:", len(self.data))

        self.image_size = self.args.image_size

        if mode == "train" and self.args.use_data_augmentation:
            self.transform = self.train_transform
        else:
            self.transform = self.test_transform

        # Single video root; __getitem__ resolves video files under the 'all' key.
        self.source_to_h5 = {'all': video_dir}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict.

        Keys: ``'args'``, ``'img_id'``, ``'image'``, ``'question_id'``, ``'sent'``
        and, when the annotation has a ``'label'`` dict, ``'label'``,
        ``'answer'``, ``'score'`` and ``'all_answers'``.

        If decoding a video raises, a random other index is used instead, for at
        most ``self.num_tries`` attempts in total.

        Raises:
            RuntimeError: if no attempt decoded a video.
        """
        out_dict = {'args': self.args}

        for attempt in range(self.num_tries):
            path = None  # defined even if the failure happens before it is built
            try:
                datum = self.data[idx]
                video = datum['video']
                out_dict['img_id'] = video.split('.')[0]
                path = str(self.source_to_h5['all'].joinpath(f"{video}"))

                max_num_frames = getattr(self, "max_num_frames", -1)
                frames, _, _ = self.video_reader(
                    path, self.num_frames, self.sample_type, max_num_frames=max_num_frames
                )
            except Exception as e:
                print(attempt, path)
                idx = random.randint(0, len(self) - 1)
                print(
                    f"Caught exception {e} when loading video {path}, "
                    f"randomly sample a new video as replacement"
                )
                continue
            # Decoded successfully: stop, otherwise the same video would be
            # decoded again on every remaining attempt.
            break
        else:
            raise RuntimeError(
                f"Failed to load a video after {self.num_tries} attempts"
            )

        out_dict["image"] = self.transform(frames)
        if not self.as_images:
            # Swap the first two axes, e.g. (T, C, H, W) -> (C, T, H, W); the exact
            # layout depends on the video reader's output -- confirm.
            out_dict["image"] = out_dict["image"].permute(1, 0, 2, 3)

        # Annotations store the question under 'sent' or 'question'.
        sent = datum['sent'] if 'sent' in datum else datum['question']

        out_dict['question_id'] = datum['question_id']
        out_dict['sent'] = sent

        if 'label' in datum:
            label = datum['label']
            out_dict['label'] = label

            answers = list(label.keys())
            scores = list(label.values())
            score_sum = sum(scores)

            if score_sum == 0:
                answer = ''
                score = 0.
            else:
                # Sample one answer with probability proportional to its score.
                prob = [s / score_sum for s in scores]
                choice = np.random.multinomial(1, prob).argmax()
                answer = answers[choice]
                score = scores[choice]
                assert len(answer) > 0, (sent, label, choice, answer)

            out_dict['answer'] = answer
            out_dict['score'] = score
            out_dict['all_answers'] = answers

        return out_dict

    def collate_fn(self, batch):
        """Merge a list of ``__getitem__`` dicts into one batch dict.

        ``'images'`` is the stacked image tensor and ``'scores'`` a float tensor;
        the other fields are Python lists. Optional fields collect only the
        entries that have them.
        """
        sentences = []
        question_ids = []
        answers = []
        all_answers = []
        labels = []
        scores = []
        images = []

        for entry in batch:
            images.append(entry["image"])
            sentences.append(entry['sent'])
            question_ids.append(entry['question_id'])
            if 'answer' in entry:
                answers.append(entry['answer'])
            if 'all_answers' in entry:
                all_answers.append(entry['all_answers'])
            if 'score' in entry:
                scores.append(entry['score'])
            if 'label' in entry:
                labels.append(entry['label'])

        batch_entry = {
            'images': torch.stack(images),
            'sent': sentences,
            'question_ids': question_ids,
            'answers': answers,
            'all_answers': all_answers,
            'scores': torch.FloatTensor(scores),
            'labels': labels,
            # NOTE(review): tagged 'gqa' although get_loader sets
            # loader.task = 'msrvttvqa'; left unchanged in case downstream code
            # dispatches on this value.
            'task': 'gqa',
        }
        return batch_entry
|
|
|
|
|
def get_loader(args, split='train', mode='train',
               batch_size=32, workers=4, distributed=False, gpu=0,
               topk=-1, verbose=None, data_dir='/data/mshukor/data', local_rank=None, world_size=None):
    """Build the MSRVTT-QA DataLoader for ``split``.

    The returned loader carries two extra attributes: ``evaluator`` (an
    ``MSRVTTVQAQAEvaluator`` over the raw annotations) and ``task``
    (``'msrvttvqa'``). With ``distributed=True`` a ``DistributedSampler`` built
    from ``world_size``/``local_rank`` is used. In ``'train'`` mode the loader
    shuffles only when there is no sampler; otherwise order is left to the
    sampler and the last partial batch is kept.
    """
    raw_dataset = MSRVTTVQADataset(split, verbose, data_dir=data_dir)

    dataset = MSRVTTVQAFineTuneDataset(
        split,
        raw_dataset=raw_dataset,
        rank=gpu,
        topk=topk,
        verbose=verbose,
        args=args,
        mode=mode,
        data_dir=data_dir,
    )

    sampler = (
        DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
        if distributed
        else None
    )

    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': workers,
        'pin_memory': True,
        'sampler': sampler,
        'collate_fn': dataset.collate_fn,
    }
    if mode == 'train':
        loader_kwargs['shuffle'] = sampler is None
    else:
        loader_kwargs['shuffle'] = None if sampler is not None else False
        loader_kwargs['drop_last'] = False

    loader = DataLoader(dataset, **loader_kwargs)

    loader.evaluator = MSRVTTVQAQAEvaluator(raw_dataset)
    loader.task = 'msrvttvqa'
    return loader
|
|
|
|
|
class MSRVTTVQADataset:
    """Raw MSRVTT-QA annotations, indexed by question id.

    Loads ``<data_dir>/annotation/<split>.json`` for every comma-separated split
    name and concatenates the entries. Example entry (illustrative):

        {
            "video": "2375429.mp4",
            "label": {
                "pipe": 1.0
            },
            "question_id": "07333408",
            "sent": "What is on the white wall?"
        }
    """

    def __init__(self, splits: str, verbose=True, data_dir='/data/mshukor/data'):
        """
        Args:
            splits: comma-separated split names, e.g. ``'train,valid'``.
            verbose: print how many entries were loaded.
            data_dir: dataset root containing the ``annotation`` directory.
        """
        self.name = splits
        self.splits = splits.split(',')

        data_dir = Path(data_dir)
        dataset_dir = data_dir.joinpath('annotation')

        self.data = []
        for split in self.splits:
            # Context manager so each annotation file handle is closed promptly.
            with open(dataset_dir.joinpath("%s.json" % split)) as f:
                self.data.extend(json.load(f))
        if verbose:
            print("Load %d data from split(s) %s." %
                  (len(self.data), self.name))

        # Lookup used by the evaluator; a duplicated question_id keeps the last entry.
        self.id2datum = {
            datum['question_id']: datum
            for datum in self.data
        }

    def __len__(self):
        return len(self.data)
|
|
|
|
|
class MSRVTTVQAQAEvaluator:
    """Soft-accuracy evaluator for MSRVTT-QA predictions.

    The optional answer normalisation (contractions, number words, articles,
    punctuation) follows the official VQA evaluation script:
    https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
    """

    def __init__(self, dataset: "MSRVTTVQADataset"):
        """
        Args:
            dataset: object exposing ``id2datum``, a mapping from question id to
                an annotation dict whose ``'label'`` maps answer -> score.
        """
        self.dataset = dataset

        # Normalisation tables as in the official VQA evaluation script
        # (linked in the class docstring). Keys with capitals never match
        # because words are lower-cased before lookup.
        self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't",
                             "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't",
                             "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've",
                             "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've",
                             "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's",
                             "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've",
                             "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't",
                             "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've",
                             "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've",
                             "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll",
                             "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've",
                             "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've",
                             "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've",
                             "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've",
                             "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't",
                             "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're",
                             "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've",
                             "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll",
                             "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've",
                             "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've",
                             "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've",
                             "youll": "you'll", "youre": "you're", "youve": "you've"}

        # Number words -> digits.
        self.manualMap = {'none': '0',
                          'zero': '0',
                          'one': '1',
                          'two': '2',
                          'three': '3',
                          'four': '4',
                          'five': '5',
                          'six': '6',
                          'seven': '7',
                          'eight': '8',
                          'nine': '9',
                          'ten': '10'
                          }

        self.articles = ['a',
                         'an',
                         'the'
                         ]

        # NOTE: ``(?!<=\d)`` is a negative *lookahead* for the literal "<=" plus a
        # digit; the official script most likely intended the lookbehind
        # ``(?<!\d)``. Kept verbatim so scores stay comparable with it.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # Digit-comma-digit, e.g. "3,000".
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [';', r"/", '[', ']', '"', '{', '}',
                      '(', ')', '=', '+', '\\', '_', '-',
                      '>', '<', '@', '`', ',', '?', '!']

    def evaluate(self, quesid2ans: dict, normalize_answer=False):
        """Return the mean soft accuracy of ``quesid2ans``.

        Each prediction scores the label weight of the predicted answer (0 if it
        is not a labelled answer); the result is the average over all
        predictions, or 0.0 when there are none.

        Args:
            quesid2ans: question id -> predicted answer string.
            normalize_answer: normalise both prediction and label answers first.

        Raises:
            KeyError: if a question id is not in the dataset.
        """
        if not quesid2ans:
            return 0.0

        score = 0.
        for quesid, ans in quesid2ans.items():
            label = self.dataset.id2datum[quesid]['label']
            if normalize_answer:
                ans = self.normalize_answer(ans)
                label = {self.normalize_answer(l): s for l, s in label.items()}
            score += label.get(ans, 0.)
        return score / len(quesid2ans)

    def normalize_answer(self, resAns):
        """Normalise an answer string the way the VQA evaluation does."""
        resAns = resAns.replace('\n', ' ')
        resAns = resAns.replace('\t', ' ')
        resAns = resAns.strip()
        resAns = self.processPunctuation(resAns)
        resAns = self.processDigitArticle(resAns)
        resAns = resAns.replace(',', '')
        return resAns

    def processPunctuation(self, inText):
        """Remove or space out punctuation, then strip non-decimal periods.

        A punctuation mark is deleted when it is adjacent to a space in the
        input or the input contains a digit-comma-digit group; otherwise it is
        replaced by a space.
        """
        outText = inText
        # Depends only on the original input, so evaluate it once.
        has_digit_comma = self.commaStrip.search(inText) is not None
        for p in self.punct:
            if (p + ' ' in inText or ' ' + p in inText) or has_digit_comma:
                outText = outText.replace(p, '')
            else:
                outText = outText.replace(p, ' ')
        # No count argument: every matching period is removed (passing re.UNICODE
        # positionally would have meant count=32).
        return self.periodStrip.sub("", outText)

    def processDigitArticle(self, inText):
        """Lower-case, map number words to digits, drop articles, expand contractions."""
        # .get (not setdefault) so the lookup tables are never mutated.
        words = [self.manualMap.get(word, word) for word in inText.lower().split()]
        words = [self.contractions.get(word, word)
                 for word in words if word not in self.articles]
        return ' '.join(words)

    def dump_result(self, quesid2ans: dict, path):
        """Write predictions to ``path`` as a JSON list of records.

        Each record is ``{"questionId": <question id>, "prediction": <answer>}``
        (the GQA-challenge submission layout). Question ids are written as
        given; no str conversion is applied.

        :param quesid2ans: A dict mapping question id to its predicted answer.
        :param path: The file path to save the json file.
        """
        result = [
            {'questionId': ques_id, 'prediction': ans}
            for ques_id, ans in quesid2ans.items()
        ]
        with open(path, 'w') as f:
            json.dump(result, f, indent=4, sort_keys=True)
|
|