File size: 538 Bytes
0ba900e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import numpy as np
class TokenWeighter:
def __init__(self, tokenizer):
self.tokenizer_ = tokenizer
self.proba = self.get_token_proba()
def get_token_proba(self):
valid_token_mask = self._filter_short_partial(self.tokenizer_.vocab)
return valid_token_mask
def _filter_short_partial(self, vocab):
valid_token_ids = [v for k, v in vocab.items() if len(k) > 1 and "#" not in k]
is_valid = np.zeros(len(vocab.keys()))
is_valid[valid_token_ids] = 1
return is_valid |