File size: 5,211 Bytes
94bea44
f52cf42
 
8fdaf65
f52cf42
 
 
 
 
234e9f7
 
f52cf42
 
 
 
 
 
 
b4b95df
 
 
4e70c1c
7d77c34
f52cf42
 
 
94bea44
 
 
 
 
f52cf42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fdaf65
f52cf42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fdaf65
 
 
 
 
 
 
f52cf42
8fdaf65
 
f52cf42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import torch
import argparse
import streamlit as st
import sentencepiece as spm
from utils import utils_cls
from model import BanglaTransformer
from config import config as cfg
# Seed PyTorch's global RNG with a fixed value.
torch.manual_seed(0)
# Inference is pinned to the CPU; the commented line is the CUDA-aware alternative, left disabled.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# Shared helper object: the class below uses it to load the vocabularies and to
# build the decoder's square subsequent mask.
uobj = utils_cls(device=device)

# Module metadata.
__MODULE__       = "Bangla Language Translation"
__MAIL__         = "saifulbrur79@gmail.com"
# NOTE(review): the name looks like a misspelling of "MODIFICATION"; left unchanged
# in case something outside this file refers to it.
__MODIFICAIOTN__ = "28/03/2023"
__LICENSE__      = "MIT"

# Page heading written to the Streamlit app.
st.write(""" Bangla to English Translation """)


# Directory that holds the tokenizer models, vocabulary pickles and model checkpoint.
BASE_URL = "./model"


class Bn2EnTranslation:
    """Bangla-to-English translation pipeline.

    Bundles the file locations of the SentencePiece tokenizers, the vocabulary
    pickles and the ``BanglaTransformer`` checkpoint, and provides helpers to
    load them and to translate a sentence with greedy decoding.
    """

    def __init__(self, base_dir=None):
        """Record the paths of all model artifacts.

        Args:
            base_dir: Directory containing the artifacts. Defaults to the
                module-level ``BASE_URL`` when omitted.
        """
        root = BASE_URL if base_dir is None else base_dir
        self.bn_tokenizer = os.path.join(root, "bn_model.model")
        self.en_tokenizer = os.path.join(root, 'en_model.model')
        self.bn_vocab = os.path.join(root, 'bn_vocab.pkl')
        self.en_vocab = os.path.join(root, 'en_vocab.pkl')
        self.model = os.path.join(root, 'pytorch_model.pt')

    def read_data(self, data_path):
        """Read a tab-separated parallel corpus.

        Each non-empty line must hold at least two tab-separated fields; only
        the first two are kept. Empty lines are skipped.

        Args:
            data_path: Path of the text file to read.

        Returns:
            A list of ``[first_field, second_field]`` pairs, one per line.

        Raises:
            IndexError: If a non-empty line contains no tab.
        """
        pairs = []
        # Decode explicitly as UTF-8: the platform default encoding may not
        # be able to decode Bangla text.
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                if not line:
                    continue
                fields = line.split("\t")
                pairs.append([fields[0], fields[1]])
        return pairs

    def load_tokenizer(self, tokenizer_path:str = "")->object:
        """Load a SentencePiece model from ``tokenizer_path`` and return the processor."""
        return spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_vocab(self, BN_VOCAL_PATH:str="", EN_VOCAL_PATH:str=""):
        """Load the Bangla and English vocabularies through the shared ``uobj`` helper.

        Returns:
            A ``(bn_vocab, en_vocab)`` tuple.
        """
        bn_vocab = uobj.load_bn_vocal(BN_VOCAL_PATH)
        en_vocab = uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocab, en_vocab

    def load_model(self, model_path:str = "", SRC_VOCAB_SIZE:int=0, TGT_VOCAB_SIZE:int=0):
        """Build a ``BanglaTransformer`` and restore its weights from a checkpoint.

        Args:
            model_path: Path of the checkpoint file; it must contain a
                ``'model_state_dict'`` entry.
            SRC_VOCAB_SIZE: Size of the source (Bangla) vocabulary.
            TGT_VOCAB_SIZE: Size of the target (English) vocabulary.

        Returns:
            The model, moved to the module-level ``device`` and set to eval mode.
        """
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        # map_location remaps the stored tensors onto `device`, so a checkpoint
        # saved from a GPU can still be loaded on a CPU-only machine.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Generate target token ids one at a time, always picking the best-scoring token.

        Args:
            model: Object exposing ``encode``, ``decode`` and ``generator``.
            src: Source token ids; ``translate`` passes a ``(num_tokens, 1)`` tensor.
            src_mask: Source attention mask passed through to ``model.encode``.
            max_len: Upper bound on the output length, start symbol included.
            start_symbol: Token id used to seed the output sequence.
            eos_index: Token id that ends decoding once it is produced.

        Returns:
            A ``(length, 1)`` long tensor that starts with ``start_symbol``.
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        # Inference only: no autograd graph needs to be recorded.
        with torch.no_grad():
            # The encoder output does not change between decoding steps.
            memory = model.encode(src, src_mask).to(device)
            ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)
            for _ in range(max_len - 1):
                tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                            .type(torch.bool)).to(device)
                out = model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                # Score only the most recent position.
                prob = model.generator(out[:, -1])
                _, best = torch.max(prob, dim=1)
                next_word = best.item()
                step = torch.full((1, 1), next_word, dtype=ys.dtype, device=ys.device)
                ys = torch.cat([ys, step], dim=0)
                if next_word == eos_index:
                    break
        return ys

    def get_bntoen_model(self):
        """Load the Bangla tokenizer, both vocabularies and the model.

        Progress is printed to stdout.

        Returns:
            A dict with the keys ``"bn_tokenizer"``, ``"bn_vocab"``,
            ``"en_vocab"`` and ``"model"``, as expected by ``translate``.
        """
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ......     : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ......     : ", end="", flush=True)
        model = self.load_model(model_path=self.model, SRC_VOCAB_SIZE=len(bn_vocab), TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")

        return {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model,
        }

    def translate(self, text, models):
        """Translate one Bangla string to English.

        Args:
            text: Bangla input text.
            models: Dict as returned by ``get_bntoen_model``.

        Returns:
            The decoded English text, with the ``<bos>``/``<eos>`` markers
            removed and the SentencePiece pieces merged back into words.

        Raises:
            KeyError: If a tokenizer piece is missing from the source
                vocabulary and the vocabulary has no ``'<unk>'`` entry.
        """
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]

        bos_idx, eos_idx = src_vocab['<bos>'], src_vocab['<eos>']
        # Fetch the lookup tables once instead of once per token.
        stoi = src_vocab.get_stoi()
        itos = tgt_vocab.get_itos()
        # NOTE(review): assumes '<unk>' is the unknown-token entry when present;
        # confirm against how the vocabulary was built.
        unk_idx = stoi.get('<unk>')

        token_ids = [bos_idx]
        for piece in src_tokenizer.encode(text, out_type=str):
            idx = stoi.get(piece, unk_idx)
            if idx is None:
                raise KeyError(piece)
            token_ids.append(idx)
        token_ids.append(eos_idx)

        num_tokens = len(token_ids)
        src = torch.LongTensor(token_ids).reshape(num_tokens, 1)
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5,
            start_symbol=bos_idx, eos_index=eos_idx).flatten()

        p_text = " ".join(itos[tok] for tok in tgt_tokens.tolist())
        p_text = p_text.replace("<bos>", "").replace("<eos>", "")
        # Glue the pieces together, then split on the SentencePiece word-boundary
        # marker "▁" to recover whole words.
        pts = " ".join(p_text.replace(" ", "").split("▁"))
        return pts.strip()




# Streamlit app body: load the translation pipeline, read the user's Bangla
# text and show its English translation.
# NOTE(review): Streamlit re-runs this script on user interaction, so the tokenizer,
# vocabularies and checkpoint appear to be reloaded each time; consider caching
# the loaded models — TODO confirm against the deployed Streamlit version.
obj = Bn2EnTranslation()
models = obj.get_bntoen_model()
text = st.text_area("Enter some text:এই উপজেলায় ১টি সরকারি কলেজ রয়েছে")
if text:
    pre = obj.translate(text, models)
    # Keep the console trace for server-side debugging.
    print(f"Input      : {text}")
    print(f"Prediction : {pre}")
    # Render the translation in the page; printing alone only reaches the
    # server's stdout, so the user would never see the result.
    st.write(pre)