import json
import os
from typing import Dict, List, Optional, Tuple, Union

import torch
from sentencepiece import SentencePieceProcessor
from transformers import BatchEncoding

_PATH = os.path.dirname(os.path.realpath(__file__))


class IndicTransTokenizer:
    def __init__(
        self,
        src_vocab_fp=None,
        tgt_vocab_fp=None,
        src_spm_fp=None,
        tgt_spm_fp=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        direction="indic-en",
        model_max_length=256,
    ):
        self.model_max_length = model_max_length

        self.supported_langs = [
            "asm_Beng",
            "ben_Beng",
            "brx_Deva",
            "doi_Deva",
            "eng_Latn",
            "gom_Deva",
            "guj_Gujr",
            "hin_Deva",
            "kan_Knda",
            "kas_Arab",
            "kas_Deva",
            "mai_Deva",
            "mal_Mlym",
            "mar_Deva",
            "mni_Beng",
            "mni_Mtei",
            "npi_Deva",
            "ory_Orya",
            "pan_Guru",
            "san_Deva",
            "sat_Olck",
            "snd_Arab",
            "snd_Deva",
            "tam_Taml",
            "tel_Telu",
            "urd_Arab",
        ]

        self.src_vocab_fp = (
            src_vocab_fp
            if src_vocab_fp is not None
            else os.path.join(_PATH, direction, "dict.SRC.json")
        )
        self.tgt_vocab_fp = (
            tgt_vocab_fp
            if tgt_vocab_fp is not None
            else os.path.join(_PATH, direction, "dict.TGT.json")
        )
        self.src_spm_fp = (
            src_spm_fp
            if src_spm_fp is not None
            else os.path.join(_PATH, direction, "model.SRC")
        )
        self.tgt_spm_fp = (
            tgt_spm_fp
            if tgt_spm_fp is not None
            else os.path.join(_PATH, direction, "model.TGT")
        )

        self.unk_token = unk_token
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.bos_token = bos_token

        # Both vocab files are flat {token: id} JSON mappings; the reverse
        # maps are built for id -> token lookups during decoding.
        self.encoder = self._load_json(self.src_vocab_fp)
        if self.unk_token not in self.encoder:
            raise KeyError(f"{self.unk_token} token must be in the source vocab")
        assert self.pad_token in self.encoder
        self.encoder_rev = {v: k for k, v in self.encoder.items()}

        self.decoder = self._load_json(self.tgt_vocab_fp)
        if self.unk_token not in self.decoder:
            raise KeyError(f"{self.unk_token} token must be in the target vocab")
        assert self.pad_token in self.decoder
        self.decoder_rev = {v: k for k, v in self.decoder.items()}

        self.src_spm = self._load_spm(self.src_spm_fp)
        self.tgt_spm = self._load_spm(self.tgt_spm_fp)

    def is_special_token(self, x: str) -> bool:
        return x in (self.pad_token, self.bos_token, self.eos_token)

    def get_vocab_size(self, src: bool) -> int:
        """Returns the size of the source (src=True) or target vocabulary."""
        return len(self.encoder) if src else len(self.decoder)

    def _load_spm(self, path: str) -> SentencePieceProcessor:
        return SentencePieceProcessor(model_file=path)

    def _save_json(self, data, path: str) -> None:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def _load_json(self, path: str) -> Union[Dict, List]:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _convert_token_to_id(self, token: str, src: bool) -> int:
        """Converts a token (str) into an index (integer) using the source/target vocabulary map."""
        return (
            self.encoder.get(token, self.encoder[self.unk_token])
            if src
            else self.decoder.get(token, self.decoder[self.unk_token])
        )

    def _convert_id_to_token(self, index: int, src: bool) -> str:
        """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
        return (
            self.encoder_rev.get(index, self.unk_token)
            if src
            else self.decoder_rev.get(index, self.unk_token)
        )

    def _convert_tokens_to_string(self, tokens: List[str], src: bool) -> str:
        """Joins tokens back into a string, dropping the leading language tags from source sequences."""
        if (
            src
            and len(tokens) >= 2
            and tokens[0] in self.supported_langs
            and tokens[1] in self.supported_langs
        ):
            tokens = tokens[2:]
        return " ".join(tokens)

    def _remove_translation_tags(self, text: str) -> Tuple[List[str], str]:
        """Splits off the leading translation tags before text normalization and tokenization."""
        tokens = text.split(" ")
        return tokens[:2], " ".join(tokens[2:])

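    # Hypothetical illustration of the expected input format: the first two
    # whitespace-separated tokens are language tags from `supported_langs`,
    #   "hin_Deva eng_Latn यह एक वाक्य है"
    # -> (["hin_Deva", "eng_Latn"], "यह एक वाक्य है")
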
    def _tokenize_src_line(self, line: str) -> List[str]:
        """Tokenizes a source line: the language tags are kept intact and only the text is sentencepiece-encoded."""
        tags, text = self._remove_translation_tags(line)
        tokens = self.src_spm.encode(text, out_type=str)
        return tags + tokens

    def _tokenize_tgt_line(self, line: str) -> List[str]:
        """Tokenizes a target line."""
        return self.tgt_spm.encode(line, out_type=str)

    def tokenize(self, text: str, src: bool) -> List[str]:
        """Tokenizes a string into tokens using the source/target vocabulary."""
        return self._tokenize_src_line(text) if src else self._tokenize_tgt_line(text)

    def batch_tokenize(self, batch: List[str], src: bool) -> List[List[str]]:
        """Tokenizes a list of strings into tokens using the source/target vocabulary."""
        return [self.tokenize(line, src) for line in batch]

    def _create_attention_mask(self, tokens: List[str], max_seq_len: int) -> List[int]:
        """Creates an attention mask for a left-padded sequence (the trailing +1 covers the EOS token)."""
        return ([0] * (max_seq_len - len(tokens))) + ([1] * (len(tokens) + 1))

    def _pad_batch(self, tokens: List[str], max_seq_len: int) -> List[str]:
        """Left-pads a token sequence to `max_seq_len` and appends the EOS token."""
        return (
            ([self.pad_token] * (max_seq_len - len(tokens))) + tokens + [self.eos_token]
        )

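    # An illustrative trace of the two helpers above, assuming max_seq_len=5
    # and tokens=["▁नम", "स्ते"] (hypothetical pieces):
    #   _pad_batch             -> ["<pad>", "<pad>", "<pad>", "▁नम", "स्ते", "</s>"]
    #   _create_attention_mask -> [0, 0, 0, 1, 1, 1]
    # Sequences are left-padded; the mask has max_seq_len + 1 entries to cover
    # the EOS token appended by _pad_batch.
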
    def _decode_line(self, ids: List[int], src: bool) -> List[str]:
        return [self._convert_id_to_token(_id, src) for _id in ids]

    def _encode_line(self, tokens: List[str], src: bool) -> List[int]:
        return [self._convert_token_to_id(token, src) for token in tokens]

    def _strip_special_tokens(self, tokens: List[str]) -> List[str]:
        return [token for token in tokens if not self.is_special_token(token)]

    def _single_input_preprocessing(
        self, tokens: List[str], src: bool, max_seq_len: int
    ) -> Tuple[List[int], List[int]]:
        """Pads a tokenized sequence, builds its attention mask, and converts the tokens to ids using the source/target vocabulary map."""
        attention_mask = self._create_attention_mask(tokens, max_seq_len)
        padded_tokens = self._pad_batch(tokens, max_seq_len)
        input_ids = self._encode_line(padded_tokens, src)
        return input_ids, attention_mask

    def _single_output_postprocessing(self, ids: List[int], src: bool) -> str:
        """Detokenizes a list of integer ids into a string using the source/target vocabulary."""
        tokens = self._decode_line(ids, src)
        tokens = self._strip_special_tokens(tokens)
        return self._convert_tokens_to_string(tokens, src)

    def __call__(
        self,
        batch: List[str],
        src: bool,
        truncation: bool = False,
        padding: str = "longest",
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
        return_attention_mask: bool = True,
        return_length: bool = False,
    ) -> BatchEncoding:
        """Tokenizes a batch of strings and converts the tokens into integer ids using the source/target vocabulary map."""
        assert padding in (
            "longest",
            "max_length",
        ), "padding should be either 'longest' or 'max_length'"

        if not isinstance(batch, list):
            raise TypeError(
                f"batch must be a list, but current batch is of type {type(batch)}"
            )

        if padding == "max_length" and max_length is None:
            raise ValueError(
                "max_length must be specified when padding is 'max_length'"
            )

        batch = self.batch_tokenize(batch, src)

        if truncation and max_length is not None:
            batch = [tokens[:max_length] for tokens in batch]

        lengths = [len(tokens) for tokens in batch]
        max_seq_len = max(lengths) if padding == "longest" else max_length

        input_ids, attention_mask = zip(
            *[
                self._single_input_preprocessing(
                    tokens=tokens, src=src, max_seq_len=max_seq_len
                )
                for tokens in batch
            ]
        )

        _data = {"input_ids": input_ids}

        if return_attention_mask:
            _data["attention_mask"] = attention_mask

        if return_length:
            _data["lengths"] = lengths

        return BatchEncoding(_data, tensor_type=return_tensors)

    def batch_decode(
        self, batch: Union[List[List[int]], torch.Tensor], src: bool
    ) -> List[str]:
        """Detokenizes a batch of integer ids (list or tensor) into strings using the source/target vocabulary."""
        if isinstance(batch, torch.Tensor):
            batch = batch.detach().cpu().tolist()
        return [self._single_output_postprocessing(ids=ids, src=src) for ids in batch]


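# A minimal usage sketch, assuming the "indic-en" vocab/SPM files ship
# alongside this module and inputs carry the "src_lang tgt_lang" tag prefix;
# the Hindi sentence is a hypothetical example.
if __name__ == "__main__":
    tokenizer = IndicTransTokenizer(direction="indic-en")
    enc = tokenizer(
        ["hin_Deva eng_Latn यह एक परीक्षण वाक्य है"],
        src=True,
        return_tensors="pt",
    )
    print(enc["input_ids"].shape)  # (batch_size, seq_len)
    # Round-trip through the source vocab as a quick sanity check; the result
    # is the space-joined sentencepiece pieces with specials and tags stripped.
    print(tokenizer.batch_decode(enc["input_ids"], src=True))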