import logging
import re

import torch.utils.data
from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class OFADataset(FairseqDataset):
    """Base dataset bundling a raw dataset with BPE and dictionary helpers.

    Provides text encoding (``encode_text``) and light text normalization
    (``pre_question`` / ``pre_caption``) for use by concrete datasets.
    """

    def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
        """Store the collaborators and cache the special-token ids.

        Args:
            split: Name of the data split; stored as-is.
            dataset: Underlying dataset; only ``len()`` is used here.
            bpe: Object exposing ``encode(text)``; used by ``encode_text``.
            src_dict: Dictionary providing ``bos()``, ``eos()`` and ``pad()``.
            tgt_dict: Dictionary providing ``encode_line``; used by
                ``encode_text``.
        """
        self.split = split
        self.dataset = dataset
        self.bpe = bpe
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

        # Special-token ids are read from the *source* dictionary.
        self.bos = src_dict.bos()
        self.eos = src_dict.eos()
        self.pad = src_dict.pad()

        # One-element tensors, ready to be concatenated onto encoded text.
        self.bos_item = torch.LongTensor([self.bos])
        self.eos_item = torch.LongTensor([self.eos])

    def __len__(self):
        """Return the number of items in the underlying dataset."""
        return len(self.dataset)

    def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
        """Encode ``text`` into a tensor of token ids using ``tgt_dict``.

        Args:
            text: The string to encode.
            length: If not ``None``, keep at most this many encoded tokens.
            append_bos: If true, prepend the BOS id.
            append_eos: If true, append the EOS id.
            use_bpe: If true, run ``self.bpe.encode`` on ``text`` first;
                otherwise ``text`` is passed to the dictionary unchanged.

        Returns:
            The encoded ids converted with ``.long()``. Truncation happens
            before BOS/EOS are attached, so the result can be up to
            ``length + 2`` items long.
        """
        line = self.bpe.encode(text) if use_bpe else text
        s = self.tgt_dict.encode_line(
            line=line,
            add_if_not_exist=False,
            append_eos=False,
        ).long()
        if length is not None:
            s = s[:length]
        if append_bos:
            s = torch.cat([self.bos_item, s])
        if append_eos:
            s = torch.cat([s, self.eos_item])
        return s

    @staticmethod
    def _squeeze_and_truncate(text, max_words):
        """Collapse whitespace runs, trim the ends and keep ``max_words`` words."""
        # Any run of two or more whitespace characters becomes one space.
        text = re.sub(r"\s{2,}", ' ', text)
        text = text.rstrip('\n')
        text = text.strip(' ')

        words = text.split(' ')
        if len(words) > max_words:
            text = ' '.join(words[:max_words])
        return text

    def pre_question(self, question, max_ques_words):
        """Normalize a question string and cap its number of words.

        Lower-cases the text, strips leading ``,.!?*#:;~`` characters,
        replaces ``-`` and ``/`` with spaces, collapses whitespace runs,
        trims the ends and keeps at most ``max_ques_words`` space-separated
        words.

        Args:
            question: Raw question text.
            max_ques_words: Maximum number of words to keep.

        Returns:
            The normalized, possibly truncated question.
        """
        question = (
            question.lower()
            .lstrip(",.!?*#:;~")
            .replace('-', ' ')
            .replace('/', ' ')
        )
        return self._squeeze_and_truncate(question, max_ques_words)

    def pre_caption(self, caption, max_words):
        """Normalize a caption string and cap its number of words.

        Applies the same normalization as ``pre_question`` and additionally
        rewrites the ``<person>`` placeholder to the plain word ``person``.

        Args:
            caption: Raw caption text.
            max_words: Maximum number of words to keep.

        Returns:
            The normalized, possibly truncated caption.
        """
        caption = (
            caption.lower()
            .lstrip(",.!?*#:;~")
            .replace('-', ' ')
            .replace('/', ' ')
            # The pattern must be the literal '<person>' tag. An empty
            # pattern makes str.replace insert 'person' before every
            # character, which corrupts the whole caption.
            .replace('<person>', 'person')
        )
        return self._squeeze_and_truncate(caption, max_words)