import torch
from transformers import AutoTokenizer


class DaedalusTokenizer:
    """Thin wrapper around a Hugging Face tokenizer with the project's encoding defaults."""

    def __init__(self, config):
        self.config = config
        # AutoTokenizer is a factory and cannot be subclassed directly; wrap an instance instead.
        # The config is assumed to expose the checkpoint name/path and max_seq_length.
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)

    def encode(self, text):
        # Pad and truncate every sequence to the configured maximum length.
        return self.tokenizer.encode_plus(
            text, max_length=self.config.max_seq_length, padding='max_length', truncation=True
        )

    def decode(self, ids):
        # Delegate to the wrapped tokenizer rather than recursing into self.decode.
        return self.tokenizer.decode(ids, skip_special_tokens=True)
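

# --- Minimal usage sketch. Assumptions: the config attribute names
# (model_name_or_path, max_seq_length) and the "bert-base-uncased"
# checkpoint are illustrative, not defined elsewhere in this file. ---
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(model_name_or_path="bert-base-uncased", max_seq_length=32)
    tok = DaedalusTokenizer(config)

    batch = tok.encode("Daedalus built the labyrinth.")
    print(batch["input_ids"])              # padded/truncated to max_seq_length
    print(tok.decode(batch["input_ids"]))  # special tokens stripped on the way back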