evo-1-131k-base / tokenizer.py
Zymrael's picture
Update tokenizer.py
bc8a8a8 verified
raw
history blame
4.25 kB
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
from abc import ABC
import json
import pathlib
import torch
import tqdm
from tokenizers import Tokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from abc import abstractmethod
from typing import Any, List, Union
import numpy as np
class HFAutoTokenizer:
    """Wrapper around a Hugging Face ``tokenizers.Tokenizer`` loaded from a file.

    Provides tokenize/detokenize helpers that return ``torch`` tensors and
    strings, plus a JSON-lines file tokenizer.
    """

    def __init__(self, vocab_file):
        """Load the tokenizer from ``vocab_file`` via ``Tokenizer.from_file``.

        Args:
            vocab_file: Path to a serialized ``tokenizers`` tokenizer.
        """
        self.tokenizer = Tokenizer.from_file(vocab_file)
        self.eos = "</s>"
        self.bos = "<s>"
        # NOTE: these hold the 1-D tensors returned by ``tokenize``, not scalar
        # ids. Whether the tokenizer's post-processor adds extra tokens around
        # them depends on ``vocab_file`` — TODO confirm.
        self.eos_id = self.tokenize(self.eos)
        self.bos_id = self.tokenize(self.bos)
        # Hard-coded vocabulary size; not read from the loaded tokenizer.
        self.vsize = 32000

    def encode_to_list(self, text):
        """Encode ``text`` without adding special tokens.

        NOTE(review): despite the name, this returns the ``Encoding`` object
        produced by ``Tokenizer.encode`` (its ``.ids`` is the id list). It is
        kept this way for backward compatibility with existing callers.
        """
        return self.tokenizer.encode(text, add_special_tokens=False)

    def tokenize_file(self, input_file, output_file, verbose=False):
        """Tokenize a JSON-lines file whose records carry a ``"text"`` field.

        Each output line is ``{"tokens": [id, ...]}``. If ``output_file``
        already exists, a message is printed and nothing is written.

        Args:
            input_file: Path to the JSON-lines input.
            output_file: Path of the JSON-lines output to create.
            verbose: If true, print the file name and (the tail of) each line.
        """
        if verbose:
            print(f"Tokenizing file: {input_file}")
        if pathlib.Path(output_file).exists():
            print(f"Output file {output_file} already exists, skipping")
            return
        with open(input_file, "r", encoding="utf-8") as fin, open(
            output_file, "w", encoding="utf-8"
        ) as fout:
            for line in tqdm.tqdm(fin):
                if verbose:
                    print(f"Tokenizing line: {line[-200:]}")
                stripped = line.strip()
                if not stripped:
                    # Blank lines (e.g. a trailing newline) are not valid JSON.
                    continue
                data = json.loads(stripped)
                if "text" not in data:
                    # Stops at the first record without "text"; later lines are
                    # not processed. NOTE(review): confirm this is intended
                    # rather than skipping just that record.
                    break
                # tokenize() returns a torch tensor, which json cannot
                # serialize, so convert it to a plain list of ints.
                tokenized_data = self.tokenize(data["text"]).tolist()
                fout.write(json.dumps({"tokens": tokenized_data}) + "\n")

    def tokenize(self, text: str, *args, **kwargs):
        """Encode ``text`` and return its token ids as a 1-D long tensor.

        Extra positional and keyword arguments are accepted and ignored.
        """
        ids = self.tokenizer.encode(text)
        if isinstance(ids, list):
            return torch.tensor(ids, dtype=torch.long)
        # ``Tokenizer.encode`` returns an ``Encoding``; ``.ids`` is the id list.
        # An explicit dtype keeps empty input as a long (not float) tensor.
        return torch.tensor(ids.ids, dtype=torch.long)

    def tokenize_batch(self, text_batch):
        """Encode a batch of strings; returns ``Tokenizer.encode_batch``'s result unchanged."""
        return self.tokenizer.encode_batch(text_batch)

    def detokenize(self, token_ids, skip_special_tokens=False):
        """Decode a sequence of token ids back into a string.

        Args:
            token_ids: A list of ints or a ``torch`` tensor (converted via ``tolist``).
            skip_special_tokens: Forwarded to ``Tokenizer.decode``.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def detokenize_batch(self, token_ids_batch, skip_special_tokens=False):
        """Decode each id sequence in ``token_ids_batch`` and return the strings as a list.

        Each sequence may be a tensor, or an iterable of ints and/or objects
        exposing ``.item()`` (such as 0-d tensors).
        """
        return [
            self.detokenize(
                self._as_int_list(token_ids),
                skip_special_tokens=skip_special_tokens,
            )
            for token_ids in token_ids_batch
        ]

    @staticmethod
    def _as_int_list(token_ids):
        """Return ``token_ids`` as a list of plain ints (tensor, ints, or ``.item()`` objects)."""
        if isinstance(token_ids, torch.Tensor):
            return token_ids.tolist()
        return [t.item() if hasattr(t, "item") else t for t in token_ids]

    @property
    def eod(self):
        """End-of-document id.

        Returns ``eod_id`` if one has been assigned; otherwise falls back to
        ``eos_id``.
        """
        return getattr(self, "eod_id", self.eos_id)

    @property
    def vocab_size(self):
        """Vocabulary size stored in ``self.vsize`` (32000 by default)."""
        return self.vsize
class ByteTokenizer(PreTrainedTokenizer):
    """UTF-8 Encoder.

    Byte-level tokenizer: the token ids of a string are the bytes of its UTF-8
    encoding (0-255). Decoding maps each id to ``chr(id)`` after clamping it
    into ``[32, vocab_size]``. As a result, ids below 32 (including the
    special-token ids 0-3) decode as a space, and multi-byte UTF-8 sequences
    decode one character per byte rather than being reassembled.
    """

    def __init__(self):
        # decode_token clamps ids < 32 up to 32, so each special token below
        # is the single character " ".
        super().__init__(
            bos_token=self.decode_token(2),
            eos_token=self.decode_token(0),
            unk_token=self.decode_token(0),
            pad_token=self.decode_token(1),
            mask_token=self.decode_token(3),
        )

    @property
    def vocab_size(self) -> int:
        """Fixed vocabulary size of 512 ids."""
        return 512

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Return a fresh instance; there is nothing to load, so all arguments are ignored."""
        return cls()

    def get_vocab(self):
        """Map the string form of every id in ``range(512)`` to that id."""
        return {str(i): i for i in range(512)}

    def clamp(self, n):
        """Clamp ``n`` into the inclusive range ``[32, vocab_size]``."""
        return max(32, min(n, self.vocab_size))

    def decode_token(self, token: int):
        """Return the single character for ``token`` after clamping."""
        return str(chr(self.clamp(token)))

    def __call__(self, text: str, return_tensors: bool = False, *args, **kwargs):
        """Tokenize ``text`` into a ``(1, num_bytes)`` long tensor.

        Returns ``{"input_ids": ids}`` when ``return_tensors`` compares equal
        to ``False`` (e.g. ``False`` or ``0``). Any other value, including
        ``None`` and ``"pt"``, returns the bare tensor.
        """
        ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
        # Deliberately an equality test, not truthiness: None keeps returning
        # the bare tensor.
        if return_tensors == False:  # noqa: E712
            return {"input_ids": ids}
        return ids

    def _tokenize(self, text: str, **kwargs):
        """Return the UTF-8 bytes of ``text`` as a ``uint8`` numpy array.

        Extra keyword arguments are ignored.
        """
        return np.frombuffer(text.encode("utf-8"), dtype=np.uint8)

    def tokenize(self, text: str, **kwargs):
        """Return the UTF-8 byte values of ``text`` as a list of ints.

        Extra keyword arguments (as passed by ``PreTrainedTokenizer`` helpers)
        are accepted and ignored.
        """
        return self._tokenize(text).tolist()

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Tokenize a list of strings (returns a list of id lists) or a single string."""
        if isinstance(text_batch, list):
            return [self.tokenize(s) for s in text_batch]
        return self.tokenize(text_batch)

    def decode(self, token_ids, **kwargs):
        """Decode token ids into a string, one character per id.

        Args:
            token_ids: An iterable of ints, or a 1-D ``torch`` tensor
                (converted via ``tolist``).
            **kwargs: Accepted for compatibility with
                ``PreTrainedTokenizer.batch_decode`` (e.g.
                ``skip_special_tokens``); ignored.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return "".join(map(self.decode_token, token_ids))

    def decode_batch(self, token_ids: Union[List[List[int]], torch.Tensor]):
        """Decode a batch of id sequences into a list of strings.

        A list, or a tensor with one row per sequence, is decoded row by row.
        Any other input is passed straight to ``decode``, and its single string
        is returned.
        """
        if isinstance(token_ids, list):
            return [self.decode(s) for s in token_ids]
        if isinstance(token_ids, torch.Tensor):
            return [self.decode(s) for s in token_ids.tolist()]
        return self.decode(token_ids)