import numpy as np class TokenWeighter: def __init__(self, tokenizer): self.tokenizer_ = tokenizer self.proba = self.get_token_proba() def get_token_proba(self): valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab) return valid_token_mask def _filter_short_partial(self, vocab): valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k] is_valid = np.zeros(len(vocab.keys())) is_valid[valid_token_ids] = 1 return is_valid