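"""Inference script for Bangla-to-English machine translation.

Loads a trained BanglaTransformer checkpoint together with SentencePiece
tokenizers and pickled vocabularies, then greedily decodes an English
translation for a Bangla input sentence.
"""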
import torch
import sentencepiece as spm
from utils import utils_cls
from model import BanglaTransformer
from config import config as cfg

# Fix the seed for reproducible decoding.
torch.manual_seed(0)
# Inference runs on CPU by default; swap in the commented line to use a GPU when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
uobj = utils_cls(device=device)

__MODULE__       = "Bangla Language Translation"
__MAIL__         = "saifulbrur79@gmail.com"
__MODIFICATION__ = "28/03/2023"
__LICENSE__      = "MIT"



class Bn2EnTranslation:
    def __init__(self):
        # Paths to the SentencePiece models, pickled vocabularies, and checkpoint.
        self.bn_tokenizer = './model/bn_model.model'
        self.en_tokenizer = './model/en_model.model'
        self.bn_vocab = './model/bn_vocab.pkl'
        self.en_vocab = './model/en_vocab.pkl'
        self.model = './model/pytorch_model.pt'

    def read_data(self, data_path):
        # Read a tab-separated parallel corpus: one "<source>\t<target>" pair per line.
        # Explicit UTF-8 so Bangla text decodes correctly regardless of locale.
        with open(data_path, "r", encoding="utf-8") as f:
            data = f.readlines()
        data = [line.rstrip("\n").split("\t")[:2] for line in data]
        return data

    def load_tokenizer(self, tokenizer_path: str = "") -> spm.SentencePieceProcessor:
        # Load a trained SentencePiece model from disk.
        _tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
        return _tokenizer

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        # Load the pickled source (Bangla) and target (English) vocabularies.
        bn_vocal, en_vocal = uobj.load_bn_vocal(BN_VOCAL_PATH), uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocal, en_vocal

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        # Build the transformer with hyperparameters from config and restore trained weights.
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        # map_location keeps a CUDA-trained checkpoint loadable on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        # Greedy decoding: encode the source once, then repeatedly pick the
        # highest-probability next token until <eos> or max_len is reached.
        src = src.to(device)
        src_mask = src_mask.to(device)
        memory = model.encode(src, src_mask).to(device)
        ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
        for _ in range(max_len - 1):
            # Causal mask: each target position may only attend to earlier positions.
            tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()
            # Append the chosen token along the sequence (time) dimension.
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
            if next_word == eos_index:
                break
        return ys

    def get_bntoen_model(self):
        # Load the tokenizer, both vocabularies, and the trained model in one call.
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ......     : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ......     : ", end="", flush=True)
        model = self.load_model(model_path=self.model, SRC_VOCAB_SIZE=len(bn_vocab), TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")

        models = {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model
        }
        return models

    def translate(self, text, models):
        # Translate a single Bangla sentence into English.
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]
        BOS_IDX, EOS_IDX = src_vocab['<bos>'], src_vocab['<eos>']
        # Tokenize with SentencePiece, map pieces to vocabulary ids, and wrap with <bos>/<eos>.
        tokens = [BOS_IDX] + [src_vocab.get_stoi()[tok] for tok in src_tokenizer.encode(text, out_type=str)] + [EOS_IDX]
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(model, src, src_mask, max_len=num_tokens + 5,
                                        start_symbol=BOS_IDX, eos_index=EOS_IDX).flatten()
        # Map ids back to pieces, drop the special tokens, then merge SentencePiece
        # pieces into words ("▁" marks a word boundary).
        p_text = " ".join([tgt_vocab.get_itos()[tok] for tok in tgt_tokens]).replace("<bos>", "").replace("<eos>", "")
        pts = " ".join(p_text.replace(" ", "").split("▁"))
        return pts.strip()




if __name__ == "__main__":
    # Guard the CUDA query so the script also runs on CPU-only machines.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))
    text = "এই উপজেলায় ১টি সরকারি কলেজ রয়েছে"  # "This upazila has one government college."

    obj = Bn2EnTranslation()
    models = obj.get_bntoen_model()
    pre = obj.translate(text, models)
    print("=" * 20)
    print(f"input : {text}")
    print(f"prediction: {pre}")