"""Wrapper around HuggingFace pretrained tokenizers for CLIP-style text encoders."""
from typing import Union, List
from transformers import AutoTokenizer
import torch
class HFPTTokenizer(object):
    """Thin wrapper around a HuggingFace ``AutoTokenizer`` that produces
    fixed-length ``LongTensor`` batches for CLIP-style text encoders.

    Falls back to ``bert-base-cased`` when no pretrained name is given.
    GPT-style tokenizers (which lack pad/sep tokens) are special-cased:
    the EOS token is reused as both pad and sep token.
    """

    def __init__(self, pt_name=None):
        """Load the pretrained tokenizer.

        Args:
            pt_name: HuggingFace model name (e.g. ``"gpt2"``); ``None``
                loads ``bert-base-cased``.
        """
        self.pt_name = pt_name
        self.added_sep_token = 0
        self.added_cls_token = 0
        # Adding tokens to GPT causes NaN training loss.
        # Disabled until further investigation.
        self.enable_add_tokens = False
        # BUG FIX: the original evaluated `"gpt" in self.pt_name` even when
        # pt_name is None, which raises TypeError. Guard explicitly so the
        # default constructor works.
        self.gpt_special_case = (
            (not self.enable_add_tokens)
            and (pt_name is not None)
            and ("gpt" in pt_name)
        )

        if pt_name is None:
            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(pt_name)

        if self.enable_add_tokens:
            # Ensure SEP/CLS exist so explicit start/end-of-text markers
            # can be prepended/appended in tokenize().
            if self.tokenizer.sep_token is None:
                self.tokenizer.add_special_tokens({"sep_token": "<SEP>"})
                self.added_sep_token = 1
            if self.tokenizer.cls_token is None:
                self.tokenizer.add_special_tokens({"cls_token": "<CLS>"})
                self.added_cls_token = 1

        if self.gpt_special_case:
            # GPT tokenizers ship without pad/sep tokens; reuse EOS for both.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.sep_token = self.tokenizer.eos_token

    def get_eot_token(self):
        """Return the first token id of the end-of-text (SEP) token."""
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)[0]

    def get_sot_token(self):
        """Return the first token id of the start-of-text (CLS) token."""
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)[0]

    def get_eot_token_list(self):
        """Return the end-of-text (SEP) token id(s) as a list."""
        return self.tokenizer.encode(self.tokenizer.sep_token, add_special_tokens=False)

    def get_sot_token_list(self):
        """Return the start-of-text (CLS) token id(s) as a list."""
        return self.tokenizer.encode(self.tokenizer.cls_token, add_special_tokens=False)

    def get_tokenizer_obj(self):
        """Return the underlying HuggingFace tokenizer object."""
        return self.tokenizer

    def check_added_tokens(self):
        """Return how many special tokens were added to the vocabulary (0-2).

        The language model needs to know if new tokens were added to the
        dictionary (e.g. to resize its embedding table).
        """
        return self.added_sep_token + self.added_cls_token

    def tokenize(self, texts: Union[str, List[str]], context_length: int = 77):
        """Tokenize text(s) into a padded ``(batch, context_length)`` tensor.

        Args:
            texts: A single string or a list of strings.
            context_length: Fixed output sequence length (default 77).

        Returns:
            ``torch.LongTensor`` of shape ``(len(texts), context_length)``.
        """
        if isinstance(texts, str):
            texts = [texts]

        seqstart = []
        seqend = []
        # Reserve room for manually prepended/appended special tokens.
        max_length = context_length
        if self.added_cls_token > 0:
            seqstart = self.get_sot_token_list()
            max_length = max_length - 1
        if self.added_sep_token > 0:
            seqend = self.get_eot_token_list()
            max_length = max_length - 1

        tokens = self.tokenizer(
            texts, padding="max_length", truncation=True, max_length=max_length
        )["input_ids"]
        tokens = [seqstart + seq + seqend for seq in tokens]

        if self.gpt_special_case:
            # Force the final position of every sequence to be the EOT token.
            eot = self.get_eot_token()  # hoisted: loop-invariant
            for seq in tokens:
                seq[-1] = eot

        return torch.tensor(tokens, dtype=torch.long)

    def get_vocab_size(self):
        """Return the tokenizer's base vocabulary size."""
        return self.tokenizer.vocab_size

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77):
        """Alias for :meth:`tokenize`."""
        return self.tokenize(texts, context_length)