import pickle
from pathlib import Path

import numpy as np

from utils.train_util import pad_sequence


class DictTokenizer:

    def __init__(self, tokenizer_path: str = None, max_length: int = 20) -> None:
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        # The special-token strings were lost in the source (likely stripped as
        # HTML tags); "<pad>", "<bos>", "<eos>" and "<unk>" are assumed names.
        self.add_word("<pad>")
        self.add_word("<bos>")
        self.add_word("<eos>")
        self.add_word("<unk>")
        if tokenizer_path is not None and Path(tokenizer_path).exists():
            state_dict = pickle.load(open(tokenizer_path, "rb"))
            self.load_state_dict(state_dict)
            self.loaded = True
        else:
            self.loaded = False
        self.bos, self.eos = self.word2idx["<bos>"], self.word2idx["<eos>"]
        self.pad = self.word2idx["<pad>"]
        self.max_length = max_length

    def add_word(self, word):
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def encode_word(self, word):
        if word in self.word2idx:
            return self.word2idx[word]
        else:
            return self.word2idx["<unk>"]

    def __call__(self, texts):
        assert isinstance(texts, list), "the input must be List[str]"
        batch_tokens = []
        for text in texts:
            tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
            tokens = [self.bos] + tokens + [self.eos]
            tokens = np.array(tokens)
            batch_tokens.append(tokens)
        caps, cap_lens = pad_sequence(batch_tokens, self.pad)
        return {
            "cap": caps,
            "cap_len": cap_lens
        }

    def decode(self, batch_token_ids):
        output = []
        for token_ids in batch_token_ids:
            tokens = []
            for token_id in token_ids:
                if token_id == self.eos:
                    break
                elif token_id == self.bos:
                    continue
                tokens.append(self.idx2word[token_id])
            output.append(" ".join(tokens))
        return output

    def __len__(self):
        return len(self.word2idx)

    def state_dict(self):
        return self.word2idx

    def load_state_dict(self, state_dict):
        self.word2idx = state_dict
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.idx = len(self.word2idx)


class HuggingfaceTokenizer:

    def __init__(self, model_name_or_path, max_length) -> None:
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.max_length = max_length
        self.bos, self.eos = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id
        self.pad = self.tokenizer.pad_token_id
        self.loaded = True

    def __call__(self, texts):
        assert isinstance(texts, list), "the input must be List[str]"
        batch_token_dict = self.tokenizer(texts,
                                          padding=True,
                                          truncation=True,
                                          max_length=self.max_length,
                                          return_tensors="pt")
        batch_token_dict["cap"] = batch_token_dict["input_ids"]
        cap_lens = batch_token_dict["attention_mask"].sum(dim=1)
        cap_lens = cap_lens.numpy().astype(np.int32)
        batch_token_dict["cap_len"] = cap_lens
        return batch_token_dict

    def decode(self, batch_token_ids):
        return self.tokenizer.batch_decode(batch_token_ids,
                                           skip_special_tokens=True)
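

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of how DictTokenizer is typically driven: the
# vocabulary must be populated with add_word before encoding, otherwise every
# word falls back to the unknown token. The captions below are made up, and
# the printed fields follow the "cap"/"cap_len" keys returned by __call__.
# HuggingfaceTokenizer works the same way but needs a pretrained checkpoint
# name (e.g. a BERT-style model), so it is not exercised here.
if __name__ == "__main__":
    captions = ["a dog barks in the distance", "rain falls on a tin roof"]

    tokenizer = DictTokenizer(max_length=20)
    # Build the vocabulary from the corpus before tokenizing.
    for caption in captions:
        for word in caption.split():
            tokenizer.add_word(word)

    batch = tokenizer(captions)
    print(batch["cap"], batch["cap_len"])   # padded id matrix and lengths
    print(tokenizer.decode(batch["cap"]))   # round-trip back to text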