"""Tokenizer for TAIM-GAN captions, built from a preprocessed captions pickle."""

import pickle
import re
from typing import List


class TAIMGANTokenizer:
    """Encode and decode captions using the vocabulary stored in a captions pickle."""

    def __init__(self, captions_path: str):
        # The captions pickle stores the vocabulary mappings at indices 2 and 3.
        with open(captions_path, "rb") as ckpt_file:
            captions = pickle.load(ckpt_file)
        self.ix_to_word = captions[2]
        self.word_to_ix = captions[3]
        # Tokenize text into alphanumeric word runs.
        self.token_regex = r"\w+"
        # Id of the padding token; out-of-vocabulary words are also mapped to it.
        self.pad_token_id = self.word_to_ix["<end>"]
        self.pad_repr = "[PAD]"

    def encode(self, text: str) -> List[int]:
        """Convert text to a list of token ids; unknown words become the pad id."""
        return [
            self.word_to_ix.get(word, self.pad_token_id)
            for word in re.findall(self.token_regex, text.lower())
        ]

    def decode(self, tokens: List[int]) -> str:
        """Convert token ids back to a space-separated string, rendering pad ids as [PAD]."""
        return " ".join(
            self.ix_to_word[token] if token != self.pad_token_id else self.pad_repr
            for token in tokens
        )
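

# Usage sketch (the captions path below is hypothetical): the pickle is expected
# to contain the index-to-word and word-to-index vocabularies at positions 2 and 3,
# as loaded in __init__ above. Shown here only to illustrate the encode/decode API.
if __name__ == "__main__":
    tokenizer = TAIMGANTokenizer(captions_path="data/captions.pickle")  # hypothetical path
    token_ids = tokenizer.encode("A small bird with a red head")
    print(token_ids)                    # word ids; unknown words map to pad_token_id
    print(tokenizer.decode(token_ids))  # known words round-trip; pad ids render as "[PAD]"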