# Author: sobir-hf
# Commit f2dafec: "Added fa-tj model and language detection"
import torch
from torch.nn.utils.rnn import pad_sequence
def load_config(path):
    """Return the ``'config'`` entry stored in a ``torch.save`` checkpoint.

    The checkpoint is deserialized with ``map_location='cpu'``, so tensors
    saved from a GPU run are loaded onto the CPU.

    NOTE(review): ``torch.load`` uses pickle under the hood (how much it will
    unpickle depends on the torch version's ``weights_only`` default); only
    pass paths to trusted files.

    Args:
        path: filesystem path (or file-like object) accepted by ``torch.load``.

    Returns:
        Whatever object is stored under the ``'config'`` key.

    Raises:
        KeyError: if the checkpoint has no ``'config'`` entry.
    """
    checkpoint = torch.load(path, map_location='cpu')
    config = checkpoint['config']
    return config
class Tokenizer:
    """Character-level tokenizer for a source/target vocabulary pair.

    Each vocabulary is an indexable sequence of tokens; a token's id is its
    position in that sequence. Both vocabularies must contain the special
    tokens ``'<NULL>'``, ``'<PAD>'`` and ``'<UNK>'``, otherwise ``__init__``
    raises ``KeyError``.
    """

    def __init__(self, config) -> None:
        """Build the token -> id lookup tables.

        Args:
            config: mapping with ``'src_vocab'`` and ``'trg_vocab'`` entries
                (for example the dict returned by ``load_config``).

        Raises:
            KeyError: if a vocabulary key or a special token is missing.
        """
        self.src_vocab = config['src_vocab']
        self.trg_vocab = config['trg_vocab']
        # If a token occurs more than once, its last position wins.
        self.src_char_index = {char: i for i, char in enumerate(self.src_vocab)}
        self.trg_char_index = {char: i for i, char in enumerate(self.trg_vocab)}
        self.trg_null_idx = self.trg_char_index['<NULL>']
        self.src_null_idx = self.src_char_index['<NULL>']
        self.src_pad_idx = self.src_char_index['<PAD>']
        self.trg_pad_idx = self.trg_char_index['<PAD>']
        self.trg_unk_idx = self.trg_char_index['<UNK>']
        self.src_unk_idx = self.src_char_index['<UNK>']

    @staticmethod
    def _to_index_list(ids):
        """Return ``ids`` as a flat Python list of ids.

        Tensors (and other array-likes with ``flatten``/``tolist``) are
        flattened and converted in one call, which also accepts batched
        or 0-d input. Any other iterable is copied into a list unchanged.
        """
        if hasattr(ids, 'flatten'):
            return ids.flatten().tolist()
        return list(ids)

    def encode_src(self, text: str):
        """Encode ``text`` character by character into a 1-D ``torch.long`` tensor.

        Characters missing from the source vocabulary map to ``'<UNK>'``.
        """
        src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
        return torch.tensor(src, dtype=torch.long)

    def decode_src(self, src: torch.Tensor):
        """Map source ids back to their vocabulary tokens.

        Args:
            src: tensor of any shape (flattened in row-major order) or a
                plain sequence of ids.

        Returns:
            List of tokens. Special tokens such as ``'<PAD>'`` are kept.
        """
        return [self.src_vocab[i] for i in self._to_index_list(src)]

    def decode_trg(self, trg: torch.Tensor):
        """Map target ids back to tokens, dropping every ``'<NULL>'`` id.

        Args:
            trg: tensor of any shape (flattened in row-major order) or a
                plain sequence of ids.

        Returns:
            List of target tokens without ``'<NULL>'`` entries.
        """
        ids = self._to_index_list(trg)
        return [self.trg_vocab[i] for i in ids if i != self.trg_null_idx]

    def collate_fn(self, batch):
        """Pad a batch of ``(src, trg)`` tensor pairs into two batch-first tensors.

        Each side is right-padded with its own ``'<PAD>'`` id up to the
        longest sequence on that side.

        Returns:
            Tuple ``(src_padded, trg_padded)``.
        """
        src = [x for x, _ in batch]
        trg = [y for _, y in batch]
        src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
        trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
        return src_padded, trg_padded
def language_detect(text, tokenizer_tj_fa: "Tokenizer", tokenizer_fa_tj: "Tokenizer"):
    """Guess whether ``text`` should be handled as ``'tj'`` or ``'fa'``.

    Counts how many characters of ``text`` occur in each tokenizer's source
    vocabulary and picks the tokenizer that covers more of them. Comparing
    raw counts is equivalent to comparing coverage percentages, since both
    percentages would be divided by the same ``len(text)``.

    Args:
        text: string to classify.
        tokenizer_tj_fa: tokenizer whose ``src_vocab`` serves as the 'tj'
            alphabet (presumably Tajik script; confirm against the model).
        tokenizer_fa_tj: tokenizer whose ``src_vocab`` serves as the 'fa'
            alphabet.

    Returns:
        ``'tj'`` if strictly more characters are in ``tokenizer_tj_fa``'s
        source vocabulary, otherwise ``'fa'``. Ties, including empty
        ``text``, resolve to ``'fa'``.
    """
    # Build sets once so each per-character membership test is O(1)
    # instead of a linear scan of the vocabulary sequence.
    tj_alphabet = set(tokenizer_tj_fa.src_vocab)
    fa_alphabet = set(tokenizer_fa_tj.src_vocab)
    tj_hits = sum(char in tj_alphabet for char in text)
    fa_hits = sum(char in fa_alphabet for char in text)
    return 'tj' if tj_hits > fa_hits else 'fa'