import json
import os
import re

import torch
from torch.utils.data import Dataset
from PIL import Image


def pre_question(question, max_ques_words):
    # Lower-case, strip punctuation, and replace dashes/slashes with spaces.
    question = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        question.lower(),
    ).replace('-', ' ').replace('/', ' ')
    question = question.rstrip(' ')

    # Truncate the question to at most max_ques_words words.
    question_words = question.split(' ')
    if len(question_words) > max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question


class vqa_dataset(Dataset):
    def __init__(self, ann_file, transform, vqa_root, vg_root, eos='[SEP]',
                 split="train", max_ques_words=30, answer_list=''):
        self.split = split

        # Load and concatenate the annotations from every file in ann_file.
        self.ann = []
        for f in ann_file:
            with open(f, 'r') as fp:
                tmp = json.load(fp)
            self.ann += tmp
            print(f, len(self.ann), len(tmp))

        self.transform = transform
        self.vqa_root = vqa_root
        self.vg_root = vg_root
        self.max_ques_words = max_ques_words
        self.eos = eos

        if split == 'test':
            self.max_ques_words = 50  # do not limit question length during test
            with open(answer_list, 'r') as fp:
                self.answer_list = json.load(fp)

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        ann = self.ann[index]

        # Resolve the image path depending on which dataset the sample comes from.
        if ann['dataset'] == 'vqa':
            image_path = os.path.join(self.vqa_root, ann['image'])
        elif ann['dataset'] == 'vg':
            image_path = os.path.join(self.vg_root, ann['image'])

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        if self.split == 'test':
            question = pre_question(ann['question'], self.max_ques_words)
            question_id = ann['question_id']
            return image, question, question_id

        elif self.split == 'train':
            question = pre_question(ann['question'], self.max_ques_words)

            if ann['dataset'] == 'vqa':
                # Each VQA sample has multiple annotated answers; weight each
                # unique answer by its relative frequency among the annotations.
                answer_weight = {}
                for answer in ann['answer']:
                    if answer in answer_weight.keys():
                        answer_weight[answer] += 1 / len(ann['answer'])
                    else:
                        answer_weight[answer] = 1 / len(ann['answer'])

                answers = list(answer_weight.keys())
                weights = list(answer_weight.values())

            elif ann['dataset'] == 'vg':
                # Visual Genome samples have a single answer with a fixed weight.
                answers = [ann['answer']]
                weights = [0.5]

            # Append the end-of-sequence token to every answer.
            answers = [answer + self.eos for answer in answers]

            return image, question, answers, weights


def vqa_collate_fn(batch):
    # Flatten the per-sample answer/weight lists and record how many answers
    # each question has, so questions can later be expanded to match answers.
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list, dim=0), question_list, answer_list, torch.Tensor(weight_list), n
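

# ---------------------------------------------------------------------------
# Minimal usage sketch: wiring the dataset and collate function above into a
# DataLoader. The annotation file, image roots, and input resolution used here
# are hypothetical placeholders, not values defined by this module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    train_transform = transforms.Compose([
        transforms.Resize((384, 384)),  # assumed input resolution
        transforms.ToTensor(),
    ])

    dataset = vqa_dataset(
        ann_file=['annotations/vqa_train.json'],  # hypothetical annotation file
        transform=train_transform,
        vqa_root='images/coco/',                  # hypothetical VQA image root
        vg_root='images/visual_genome/',          # hypothetical VG image root
        split='train',
    )

    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=vqa_collate_fn)

    # Each batch yields stacked images, raw question strings, a flat list of
    # answer strings, their weights, and the per-question answer counts.
    images, questions, answers, weights, n = next(iter(loader))
    print(images.shape, len(questions), len(answers), weights.shape, n)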