Spaces:
Runtime error
Runtime error
import logging | |
import torch.utils.data | |
from fairseq.data import FairseqDataset | |
# Module-level logger, named after this module so that its records can be
# filtered and configured through the standard logging hierarchy.
logger = logging.getLogger(__name__)
class OFADataset(FairseqDataset): | |
def __len__(self): | |
return len(self.dataset) | |
def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
    """Encode ``text`` into a tensor of target-dictionary token ids.

    Args:
        text: The input string. When ``use_bpe`` is true it is first passed
            through ``self.bpe.encode``. Otherwise it goes to
            ``self.tgt_dict.encode_line`` unchanged, which is useful when the
            text is already BPE-encoded or consists of dictionary symbols.
        length: If not ``None``, keep only the first ``length`` ids of the
            encoded text. Truncation happens *before* BOS/EOS are attached,
            so those items do not count toward ``length``.
        append_bos: If true, prepend ``self.bos_item`` via ``torch.cat``.
        append_eos: If true, append ``self.eos_item`` via ``torch.cat``.
        use_bpe: Whether to apply ``self.bpe.encode`` before the dictionary
            lookup. Defaults to ``True``, which preserves the original
            behaviour for existing callers.

    Returns:
        The encoded ids as a ``.long()`` tensor.
    """
    line = self.bpe.encode(text) if use_bpe else text
    tokens = self.tgt_dict.encode_line(
        line=line,
        add_if_not_exist=False,  # never grow the dictionary with unseen symbols
        append_eos=False,  # EOS is attached explicitly below, after truncation
    ).long()
    if length is not None:
        tokens = tokens[:length]
    if append_bos:
        tokens = torch.cat([self.bos_item, tokens])
    if append_eos:
        tokens = torch.cat([tokens, self.eos_item])
    return tokens