import torch
from torch.nn.utils.rnn import pad_sequence


def load_config(path):
    # Load a checkpoint on CPU and return only its stored config dictionary.
    d = torch.load(path, map_location='cpu')
    return d['config']


class Tokenizer:
    def __init__(self, config) -> None:
        # Character-level vocabularies for source and target languages.
        self.src_vocab = config['src_vocab']
        self.trg_vocab = config['trg_vocab']
        self.src_char_index = {char: i for i, char in enumerate(self.src_vocab)}
        self.trg_char_index = {char: i for i, char in enumerate(self.trg_vocab)}
        # The empty-string entry in each vocabulary doubles as the null, pad, and unk symbol.
        self.trg_null_idx = self.trg_char_index['']
        self.src_null_idx = self.src_char_index['']
        self.src_pad_idx = self.src_char_index['']
        self.trg_pad_idx = self.trg_char_index['']
        self.trg_unk_idx = self.trg_char_index['']
        self.src_unk_idx = self.src_char_index['']

    def encode_src(self, text: str):
        # Map each character to its index, falling back to the unk index for unknown characters.
        src = [self.src_char_index.get(src_char, self.src_unk_idx) for src_char in text]
        return torch.tensor(src, dtype=torch.long)

    def decode_src(self, src: torch.Tensor):
        return [self.src_vocab[i] for i in src.tolist()]

    def decode_trg(self, trg: torch.Tensor):
        # Drop null positions, then map the remaining indices back to target characters.
        trg = trg.flatten().tolist()
        trg = [r for r in trg if r != self.trg_null_idx]
        return [self.trg_vocab[i] for i in trg]

    def collate_fn(self, batch):
        # Pad variable-length (src, trg) pairs into batch-first tensors.
        src = [x for x, _ in batch]
        trg = [y for _, y in batch]
        src_padded = pad_sequence(src, batch_first=True, padding_value=self.src_pad_idx)
        trg_padded = pad_sequence(trg, batch_first=True, padding_value=self.trg_pad_idx)
        return src_padded, trg_padded


def language_detect(text, tokenizer_tj_fa: "Tokenizer", tokenizer_fa_tj: "Tokenizer"):
    # Fraction of characters in text covered by the Tajik->Persian source vocabulary.
    percentage_tj_fa = sum(char in tokenizer_tj_fa.src_vocab for char in text) / len(text)
    # Fraction of characters in text covered by the Persian->Tajik source vocabulary.
    percentage_fa_tj = sum(char in tokenizer_fa_tj.src_vocab for char in text) / len(text)
    # Return the language code whose tokenizer covers more of the text.
    if percentage_tj_fa > percentage_fa_tj:
        return 'tj'
    else:
        return 'fa'
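

# --- Usage sketch (illustrative only) ---
# A minimal example of how these pieces fit together, assuming two checkpoint files
# exist whose 'config' entries hold 'src_vocab' / 'trg_vocab' character lists. The
# paths 'tj_fa.pt' and 'fa_tj.pt' are hypothetical placeholders, not real filenames.
if __name__ == "__main__":
    tokenizer_tj_fa = Tokenizer(load_config('tj_fa.pt'))  # hypothetical Tajik->Persian checkpoint
    tokenizer_fa_tj = Tokenizer(load_config('fa_tj.pt'))  # hypothetical Persian->Tajik checkpoint

    text = "салом"  # Cyrillic input; expected to resolve to 'tj' given a Cyrillic source vocab
    lang = language_detect(text, tokenizer_tj_fa, tokenizer_fa_tj)

    # Pick the matching tokenizer and encode the text into a LongTensor of character indices.
    tokenizer = tokenizer_tj_fa if lang == 'tj' else tokenizer_fa_tj
    ids = tokenizer.encode_src(text)
    print(lang, ids)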