import json
import os
import random
import re

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset
from torchvision import transforms


def pre_question(question, max_ques_words):
    """Lowercase a question, strip punctuation, and truncate it to max_ques_words words."""
    question = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        question.lower(),
    ).replace('-', ' ').replace('/', ' ')
    question = question.rstrip(' ')

    # truncate question
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question
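
# Illustrative example (not in the original file), traced from the function above:
#   pre_question("What's the dog doing?", 30) -> "whats the dog doing"
# (punctuation, including apostrophes, is removed and the text is lowercased).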


class vqa_dataset(Dataset):
    def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]',
                 split="train", max_ques_words=30, answer_list=''):
        self.split = split
        self.ann = []
        for f in ann_file:
            tmp = json.load(open(f, 'r'))
            self.ann += tmp
            print(f, len(self.ann), len(tmp))

        self.transform = transform
        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.max_ques_words = max_ques_words
        self.eos = eos

        if split == 'test':
            self.max_ques_words = 50  # do not limit question length during test
            self.answer_list = json.load(open(answer_list, 'r'))

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]

        # images come from either the VQA root or the Visual Genome root
        if ann['dataset'] == 'vqa':
            image_path = os.path.join(self.vqa_root, ann['image'])
        elif ann['dataset'] == 'vg':
            image_path = os.path.join(self.vg_root, ann['image'])

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        if self.split == 'test':
            question = pre_question(ann['question'], self.max_ques_words)
            question_id = ann['question_id']
            return image, question, question_id

        elif self.split == 'train':
            question = pre_question(ann['question'], self.max_ques_words)

            if ann['dataset'] == 'vqa':
                # soft labels: each distinct answer is weighted by its frequency
                # among the annotated answers for this question
                answer_weight = {}
                for answer in ann['answer']:
                    if answer in answer_weight:
                        answer_weight[answer] += 1 / len(ann['answer'])
                    else:
                        answer_weight[answer] = 1 / len(ann['answer'])

                answers = list(answer_weight.keys())
                weights = list(answer_weight.values())

            elif ann['dataset'] == 'vg':
                # Visual Genome provides a single answer; use a fixed weight of 0.5
                answers = [ann['answer']]
                weights = [0.5]

            answers = [answer + self.eos for answer in answers]

            return image, question, answers, weights
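

# Worked example of the training-time answer weighting above (illustrative, not
# taken from the repository): ann['answer'] = ['red', 'red', 'blue'] yields
# answers = ['red[SEP]', 'blue[SEP]'] and weights = [2/3, 1/3].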


def vqa_collate_fn(batch):
    """Collate (image, question, answers, weights) samples: images are stacked,
    answers and weights are flattened across the batch, and n records the number
    of answers per sample."""
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list, dim=0), question_list, answer_list, torch.Tensor(weight_list), n
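

# --- Usage sketch (illustrative addition, not part of the original file) ---
# Shows how the dataset and collate function could be wired into a DataLoader.
# The annotation file name, image roots, batch size, and image resolution below
# are assumptions, not values taken from the eP-ALM repository.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # assumed input resolution
        transforms.ToTensor(),
    ])

    # hypothetical paths; replace with the real annotation files and image roots
    dataset = vqa_dataset(
        ann_file=['vqa_train.json'],
        transform=transform,
        vqa_root='images/coco',
        vg_root='images/visual_genome',
        split='train',
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=vqa_collate_fn)

    # each batch: stacked images, list of questions, flattened answers/weights,
    # and per-sample answer counts
    images, questions, answers, weights, n = next(iter(loader))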