File size: 2,211 Bytes
93b9482 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import logging
import re
import torch.utils.data
from fairseq.data import FairseqDataset
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class OFADataset(FairseqDataset):
    """Dataset base holding a split, a backing dataset, a BPE encoder and dictionaries.

    Provides text helpers: ``encode_text`` turns a string into a tensor of
    token ids, and ``pre_question`` / ``pre_caption`` normalise raw text
    before it is encoded.
    """

    def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
        """Store the collaborators and cache the special-token ids.

        Args:
            split: Identifier of the data split; stored unchanged.
            dataset: Backing dataset; ``len(self)`` delegates to it.
            bpe: Object exposing ``encode(text)``, used by ``encode_text``.
            src_dict: Dictionary supplying the ``bos``/``eos``/``pad`` ids.
            tgt_dict: Dictionary whose ``encode_line`` is used by
                ``encode_text``.
        """
        self.split = split
        self.dataset = dataset
        self.bpe = bpe
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

        # NOTE(review): special-token ids are read from src_dict while
        # encode_text() encodes with tgt_dict — presumably both dictionaries
        # share these ids; confirm against the callers.
        self.bos = src_dict.bos()
        self.eos = src_dict.eos()
        self.pad = src_dict.pad()

        # One-element tensors reused when prepending BOS / appending EOS.
        self.bos_item = torch.LongTensor([self.bos])
        self.eos_item = torch.LongTensor([self.eos])

    def __len__(self):
        """Return the number of items in the backing dataset."""
        return len(self.dataset)

    def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
        """Encode ``text`` into a tensor of token ids using ``tgt_dict``.

        Args:
            text: The string to encode.
            length: If not ``None``, keep at most this many ids. Truncation
                happens before BOS/EOS are attached, so the result can be up
                to two ids longer than ``length``.
            append_bos: Prepend ``self.bos_item`` to the result.
            append_eos: Append ``self.eos_item`` to the result.
            use_bpe: When true, run ``self.bpe.encode`` on ``text`` first;
                otherwise pass ``text`` to the dictionary unchanged.

        Returns:
            The encoded ids converted with ``.long()``.
        """
        line = self.bpe.encode(text) if use_bpe else text
        # EOS is attached explicitly below, so the dictionary must not add
        # one itself; the dictionary is also not extended with new symbols.
        tokens = self.tgt_dict.encode_line(
            line=line,
            add_if_not_exist=False,
            append_eos=False
        ).long()

        if length is not None:
            tokens = tokens[:length]
        if append_bos:
            tokens = torch.cat([self.bos_item, tokens])
        if append_eos:
            tokens = torch.cat([tokens, self.eos_item])
        return tokens

    @staticmethod
    def _normalize_and_truncate(text, max_words=None, extra_replacements=()):
        """Shared normalisation behind ``pre_question`` and ``pre_caption``.

        Lower-cases ``text``, strips the leading punctuation characters
        ``,.!?*#:;~``, turns ``-`` and ``/`` into spaces, applies each
        ``(old, new)`` pair in ``extra_replacements``, collapses every run of
        two or more whitespace characters into one space, removes trailing
        newlines and surrounding spaces, and finally keeps at most
        ``max_words`` space-separated words (no limit when ``None``).
        """
        text = text.lower().lstrip(",.!?*#:;~").replace('-', ' ').replace('/', ' ')
        for old, new in extra_replacements:
            text = text.replace(old, new)

        text = re.sub(
            r"\s{2,}",
            ' ',
            text,
        )
        text = text.rstrip('\n')
        text = text.strip(' ')

        # A limit of None means "do not truncate", matching encode_text().
        if max_words is not None:
            words = text.split(' ')
            if len(words) > max_words:
                text = ' '.join(words[:max_words])
        return text

    def pre_question(self, question, max_ques_words=None):
        """Normalise a question string and optionally truncate it.

        Args:
            question: Raw question text.
            max_ques_words: Maximum number of space-separated words to keep;
                ``None`` (the default) disables truncation.

        Returns:
            The cleaned, possibly truncated question.
        """
        return self._normalize_and_truncate(question, max_ques_words)

    def pre_caption(self, caption, max_words=None):
        """Normalise a caption string and optionally truncate it.

        Identical to ``pre_question`` except that the literal ``<person>``
        is additionally replaced by ``person``.

        Args:
            caption: Raw caption text.
            max_words: Maximum number of space-separated words to keep;
                ``None`` (the default) disables truncation.

        Returns:
            The cleaned, possibly truncated caption.
        """
        return self._normalize_and_truncate(
            caption,
            max_words,
            extra_replacements=(('<person>', 'person'),),
        )
|