# efficient_audio_captioning / text_tokenizer.py
# Provenance: HuggingFace repo file (author: wsntxxn,
# commit 6065472 "Add AudioCaps checkpoint", ~3.7 kB).
import pickle
from pathlib import Path
import numpy as np
from utils.train_util import pad_sequence
class DictTokenizer:
    """Word-level tokenizer backed by a plain ``word -> index`` dictionary.

    Four special tokens are reserved at fixed indices:
    ``<pad>`` (0), ``<start>`` (1), ``<end>`` (2) and ``<unk>`` (3).
    """

    def __init__(self,
                 tokenizer_path: str = None,
                 max_length: int = 20) -> None:
        """
        Args:
            tokenizer_path: optional path to a pickle file holding a saved
                ``word2idx`` mapping (as produced by :meth:`state_dict`).
                Ignored when ``None`` or when the file does not exist.
            max_length: maximum number of word tokens kept per caption,
                not counting the <start>/<end> markers added around them.
        """
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0
        # Special tokens must be added first so their indices are fixed.
        self.add_word("<pad>")
        self.add_word("<start>")
        self.add_word("<end>")
        self.add_word("<unk>")
        if tokenizer_path is not None and Path(tokenizer_path).exists():
            # Context manager closes the handle promptly; the original
            # `pickle.load(open(...))` left it to the garbage collector.
            # NOTE(review): pickle.load on an untrusted file can execute
            # arbitrary code — only load checkpoints you trust.
            with open(tokenizer_path, "rb") as f:
                state_dict = pickle.load(f)
            self.load_state_dict(state_dict)
            self.loaded = True
        else:
            self.loaded = False
        self.bos, self.eos = self.word2idx["<start>"], self.word2idx["<end>"]
        self.pad = self.word2idx["<pad>"]
        self.max_length = max_length

    def add_word(self, word):
        """Register *word* under the next free index (no-op if already known)."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def encode_word(self, word):
        """Return the index of *word*, falling back to the <unk> index."""
        return self.word2idx.get(word, self.word2idx["<unk>"])

    def __call__(self, texts):
        """Encode a batch of whitespace-separated captions.

        Args:
            texts: list of caption strings.

        Returns:
            dict with "cap" (padded index matrix) and "cap_len"
            (per-caption lengths), as produced by ``pad_sequence``.
        """
        assert isinstance(texts, list), "the input must be List[str]"
        batch_tokens = []
        for text in texts:
            tokens = [self.encode_word(token) for token in text.split()][:self.max_length]
            # Wrap with begin/end-of-sentence markers after truncation.
            tokens = [self.bos] + tokens + [self.eos]
            batch_tokens.append(np.array(tokens))
        caps, cap_lens = pad_sequence(batch_tokens, self.pad)
        return {
            "cap": caps,
            "cap_len": cap_lens
        }

    def decode(self, batch_token_ids):
        """Convert index sequences back to strings.

        Stops at the first <end> token, skips <start> tokens, and joins the
        remaining words with single spaces.
        """
        output = []
        for token_ids in batch_token_ids:
            tokens = []
            for token_id in token_ids:
                if token_id == self.eos:
                    break
                elif token_id == self.bos:
                    continue
                tokens.append(self.idx2word[token_id])
            output.append(" ".join(tokens))
        return output

    def __len__(self):
        """Vocabulary size, including the four special tokens."""
        return len(self.word2idx)

    def state_dict(self):
        """Serializable state: the ``word2idx`` mapping itself."""
        return self.word2idx

    def load_state_dict(self, state_dict):
        """Restore the vocabulary from a saved ``word2idx`` mapping."""
        self.word2idx = state_dict
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.idx = len(self.word2idx)
class HuggingfaceTokenizer:
    """Wrapper around a HuggingFace ``AutoTokenizer`` exposing the same
    batch interface ("cap" / "cap_len" keys) as ``DictTokenizer``."""

    def __init__(self,
                 model_name_or_path,
                 max_length) -> None:
        # Imported lazily so `transformers` is only required when this
        # tokenizer class is actually used.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.max_length = max_length
        self.bos = self.tokenizer.bos_token_id
        self.eos = self.tokenizer.eos_token_id
        self.pad = self.tokenizer.pad_token_id
        self.loaded = True

    def __call__(self, texts):
        """Tokenize a batch of caption strings into padded id tensors.

        Returns the HuggingFace batch dict augmented with "cap" (alias of
        "input_ids") and "cap_len" (int32 lengths from the attention mask).
        """
        assert isinstance(texts, list), "the input must be List[str]"
        encoded = self.tokenizer(texts,
                                 padding=True,
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_tensors="pt")
        # Mirror DictTokenizer's output keys alongside the HF ones.
        encoded["cap"] = encoded["input_ids"]
        lengths = encoded["attention_mask"].sum(dim=1).numpy().astype(np.int32)
        encoded["cap_len"] = lengths
        return encoded

    def decode(self, batch_token_ids):
        """Detokenize id sequences back to strings, dropping special tokens."""
        return self.tokenizer.batch_decode(batch_token_ids,
                                           skip_special_tokens=True)