File size: 2,278 Bytes
b56c828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2dafec
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch

from torch.nn.utils.rnn import pad_sequence


def load_config(path):
    """Return the ``'config'`` entry of a torch checkpoint saved at *path*.

    The checkpoint is always materialized on CPU so loading works on
    machines without the device the checkpoint was saved from.
    NOTE(review): ``torch.load`` may unpickle arbitrary objects — only load
    trusted checkpoint files.
    """
    checkpoint = torch.load(path, map_location='cpu')
    return checkpoint['config']


class Tokenizer:
    """Character-level tokenizer for a seq2seq model.

    Built from a config carrying two vocab lists (``src_vocab`` /
    ``trg_vocab``); each vocab must contain the special tokens
    ``'<NULL>'``, ``'<PAD>'`` and ``'<UNK>'``.
    """

    def __init__(self, config) -> None:
        self.src_vocab = config['src_vocab']
        self.trg_vocab = config['trg_vocab']

        # char -> integer id lookup tables (position in the vocab list).
        self.src_char_index = {ch: idx for idx, ch in enumerate(self.src_vocab)}
        self.trg_char_index = {ch: idx for idx, ch in enumerate(self.trg_vocab)}

        # Cache the ids of the special tokens on both sides.
        self.src_null_idx = self.src_char_index['<NULL>']
        self.src_pad_idx = self.src_char_index['<PAD>']
        self.src_unk_idx = self.src_char_index['<UNK>']
        self.trg_null_idx = self.trg_char_index['<NULL>']
        self.trg_pad_idx = self.trg_char_index['<PAD>']
        self.trg_unk_idx = self.trg_char_index['<UNK>']

    def encode_src(self, text: str):
        """Map each character of *text* to its source-vocab id (``<UNK>`` for unknowns)."""
        ids = [self.src_char_index.get(ch, self.src_unk_idx) for ch in text]
        return torch.tensor(ids, dtype=torch.long)

    def decode_src(self, src: torch.Tensor):
        """Inverse of :meth:`encode_src`: look each id of *src* up in the source vocab."""
        chars = []
        for index in src:
            chars.append(self.src_vocab[index])
        return chars

    def decode_trg(self, trg: torch.Tensor):
        """Flatten *trg*, drop ``<NULL>`` ids, and map the rest to target-vocab chars.

        NOTE(review): only ``<NULL>`` is filtered here — ``<PAD>`` ids pass
        through; confirm that is intended for padded batches.
        """
        ids = trg.flatten().tolist()
        return [self.trg_vocab[i] for i in ids if i != self.trg_null_idx]

    def collate_fn(self, batch):
        """Pad a batch of ``(src, trg)`` tensor pairs into two ``[B, T]`` tensors."""
        sources, targets = zip(*batch)
        src_batch = pad_sequence(list(sources), batch_first=True,
                                 padding_value=self.src_pad_idx)
        trg_batch = pad_sequence(list(targets), batch_first=True,
                                 padding_value=self.trg_pad_idx)
        return src_batch, trg_batch


def language_detect(text, tokenizer_tj_fa: "Tokenizer", tokenizer_fa_tj: "Tokenizer"):
    """Guess the language of *text* by vocabulary coverage.

    Returns ``'tj'`` when a strictly larger fraction of the characters is
    covered by the tj->fa tokenizer's source vocab than by the fa->tj one,
    else ``'fa'`` (ties and empty input fall through to ``'fa'``).
    """
    # Guard: the original coverage ratio would divide by zero on "".
    if not text:
        return 'fa'

    total = len(text)
    # Membership via the char->index dicts (O(1) per char) instead of
    # scanning the src_vocab lists; both carry exactly the vocab's chars.
    coverage_tj = sum(ch in tokenizer_tj_fa.src_char_index for ch in text) / total
    coverage_fa = sum(ch in tokenizer_fa_tj.src_char_index for ch in text) / total

    # Strict '>' keeps the original tie-breaking behavior ('fa' on ties).
    return 'tj' if coverage_tj > coverage_fa else 'fa'