import pickle
from pathlib import Path

import numpy as np

from utils.train_util import pad_sequence


class DictTokenizer:
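    """Simple word-level tokenizer backed by a word-to-index dictionary.

    The special tokens <pad>, <start>, <end> and <unk> are registered first,
    so their ids are fixed regardless of the rest of the vocabulary.
    """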

    def __init__(self,
                 tokenizer_path: str = None,
                 max_length: int = 20) -> None:
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        # Register the special tokens first so their ids are deterministic.
        self.add_word("<pad>")
        self.add_word("<start>")
        self.add_word("<end>")
        self.add_word("<unk>")
        if tokenizer_path is not None and Path(tokenizer_path).exists():
            # Restore a previously saved vocabulary.
            with open(tokenizer_path, "rb") as f:
                state_dict = pickle.load(f)
            self.load_state_dict(state_dict)
            self.loaded = True
        else:
            self.loaded = False
        self.bos, self.eos = self.word2idx["<start>"], self.word2idx["<end>"]
        self.pad = self.word2idx["<pad>"]
        self.max_length = max_length

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def encode_word(self, word):
        if word in self.word2idx:
            return self.word2idx[word]
        else:
            return self.word2idx["<unk>"]

    def __call__(self, texts):
        assert isinstance(texts, list), "the input must be List[str]"
        batch_tokens = []
        for text in texts:
            # Truncate to max_length words, then wrap with <start> / <end>.
            tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
            tokens = [self.bos] + tokens + [self.eos]
            batch_tokens.append(np.array(tokens))
        caps, cap_lens = pad_sequence(batch_tokens, self.pad)
        return {
            "cap": caps,
            "cap_len": cap_lens
        }

    def decode(self, batch_token_ids):
        output = []
        for token_ids in batch_token_ids:
            tokens = []
            for token_id in token_ids:
                # Stop at <end>; skip the leading <start> token.
                if token_id == self.eos:
                    break
                elif token_id == self.bos:
                    continue
                tokens.append(self.idx2word[token_id])
            output.append(" ".join(tokens))
        return output

    def __len__(self):
        return len(self.word2idx)

    def state_dict(self):
        return self.word2idx

    def load_state_dict(self, state_dict):
        self.word2idx = state_dict
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.idx = len(self.word2idx)


class HuggingfaceTokenizer:
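    """Thin wrapper around a pretrained `transformers` tokenizer.

    Produces the same "cap" / "cap_len" batch fields as DictTokenizer so the
    two tokenizers can be used interchangeably downstream.
    """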

    def __init__(self,
                 model_name_or_path,
                 max_length) -> None:
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.max_length = max_length
        # Note: some pretrained tokenizers define no BOS/EOS/PAD tokens, in
        # which case these ids are None.
        self.bos, self.eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        self.pad = self.tokenizer.pad_token_id
        self.loaded = True

    def __call__(self, texts):
        assert isinstance(texts, list), "the input must be List[str]"
        batch_token_dict = self.tokenizer(texts,
                                          padding=True,
                                          truncation=True,
                                          max_length=self.max_length,
                                          return_tensors="pt")
        batch_token_dict["cap"] = batch_token_dict["input_ids"]
        # The attention mask sums to the unpadded length of each caption.
        cap_lens = batch_token_dict["attention_mask"].sum(dim=1)
        batch_token_dict["cap_len"] = cap_lens.numpy().astype(np.int32)
        return batch_token_dict

    def decode(self, batch_token_ids):
        return self.tokenizer.batch_decode(batch_token_ids, skip_special_tokens=True)
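

# Minimal usage sketch (illustrative only): exercises DictTokenizer end to end.
# It assumes `utils.train_util.pad_sequence` returns a (padded_batch, lengths)
# pair, as relied on in DictTokenizer.__call__ above; the toy captions below
# are placeholders, not data from this project.
if __name__ == "__main__":
    tokenizer = DictTokenizer(max_length=20)
    # Build a tiny vocabulary before encoding.
    for caption in ["a dog barks", "a cat meows"]:
        for word in caption.split():
            tokenizer.add_word(word)
    batch = tokenizer(["a dog barks", "a cat meows loudly"])  # "loudly" maps to <unk>
    print(batch["cap"], batch["cap_len"])
    # Decoding skips <start> and stops at <end>.
    print(tokenizer.decode([[tokenizer.bos,
                             tokenizer.encode_word("a"),
                             tokenizer.encode_word("dog"),
                             tokenizer.eos]]))
    # A HuggingfaceTokenizer would be used the same way, e.g.
    # HuggingfaceTokenizer("bert-base-uncased", max_length=20)(["a dog barks"]),
    # but it downloads a pretrained tokenizer, so it is not run here.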