import sys

sys.path.append("..")  # make the sibling `dataset` package importable

from dataset.vocab import Vocab
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer


class TokenAligner:
    def __init__(self, tokenizer: AutoTokenizer, vocab: Vocab):
        self.tokenizer = tokenizer
        self.vocab = vocab
    def _char_tokenize(self, text):
        """
        params:
            text ---- str
        """
        characters = list(text)
        # "@@" marks a character that is followed by another non-space
        # character (a within-word continuation marker); the space tokens
        # themselves are then dropped.
        tokens = [
            token + "@@" if i < len(characters) - 1 and characters[i + 1] != " " else token
            for i, token in enumerate(characters)
        ]
        tokens = [token for token in tokens if token not in (" @@", " ")]
        # The token list is passed to the tokenizer as pre-tokenized input.
        encoded = self.tokenizer.encode_plus(tokens, return_tensors="pt")
        token_ids = encoded["input_ids"].squeeze(0)
        attn_mask = encoded["attention_mask"].squeeze(0)
        return tokens, token_ids, attn_mask
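
    # Worked example of the character tokenization above:
    #   _char_tokenize("ab c") builds tokens ["a@@", "b", "c"]
    #   (the space is removed; "@@" signals the word continues).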
    def char_tokenize(self, batch_texts):
        """Character-tokenize every string in `batch_texts`."""
        doc = {"tokens": [], "token_ids": [], "attention_mask": []}
        for text in batch_texts:
            tokens, token_ids, attn_mask = self._char_tokenize(text)
            doc["tokens"].append(tokens)
            doc["token_ids"].append(token_ids)
            doc["attention_mask"].append(attn_mask)
        return doc
    def tokenize_for_transformer_with_tokenization(self, batch_noised_text, batch_label_texts=None):
        docs = self.char_tokenize(batch_noised_text)
        batch_srcs = docs["token_ids"]
        batch_attention_masks = docs["attention_mask"]
        # Pad the per-sample tensors to the length of the longest sample.
        batch_attention_masks = pad_sequence(batch_attention_masks, batch_first=True, padding_value=0)
        batch_srcs = pad_sequence(batch_srcs, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        if batch_label_texts is not None:
            # Target lengths in subword tokens, before padding/special tokens.
            batch_lengths = [len(self.tokenizer.tokenize(text)) for text in batch_label_texts]
            batch_tgts = self.tokenizer.batch_encode_plus(
                batch_label_texts, max_length=512, truncation=True, padding=True, return_tensors="pt"
            )["input_ids"]
            return batch_srcs, batch_tgts, batch_lengths, batch_attention_masks
        return batch_srcs, batch_attention_masks
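

# Minimal usage sketch (assumptions: "bert-base-cased" stands in for whatever
# checkpoint the project actually uses, and `vocab` is not touched by the
# methods above, so None is passed purely for illustration).
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    aligner = TokenAligner(tokenizer, vocab=None)

    srcs, tgts, lengths, masks = aligner.tokenize_for_transformer_with_tokenization(
        ["helo wrld"], batch_label_texts=["hello world"]
    )
    print(srcs.shape, tgts.shape, lengths, masks.shape)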