# Provenance: Hugging Face Space file by saiful9379, commit 234e9f7 ("change device to cpu"), 5.21 kB.
import os
import torch
import argparse
import streamlit as st
import sentencepiece as spm
from utils import utils_cls
from model import BanglaTransformer
from config import config as cfg
torch.manual_seed(0)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
uobj = utils_cls(device=device)
__MODULE__ = "Bangla Language Translation"
__MAIL__ = "saifulbrur79@gmail.com"
__MODIFICAIOTN__ = "28/03/2023"
__LICENSE__ = "MIT"
st.write(""" Bangla to English Translation """)
BASE_URL = "./model"
class Bn2EnTranslation:
    """Bangla-to-English translator.

    Loads SentencePiece tokenizers, pickled vocabularies, and a
    BanglaTransformer checkpoint from ``BASE_URL``, and translates text
    via greedy autoregressive decoding on ``device``.
    """

    def __init__(self):
        # Artifact paths under BASE_URL (tokenizers, vocabs, checkpoint).
        self.bn_tokenizer = os.path.join(BASE_URL, "bn_model.model")
        self.en_tokenizer = os.path.join(BASE_URL, "en_model.model")
        self.bn_vocab = os.path.join(BASE_URL, "bn_vocab.pkl")
        self.en_vocab = os.path.join(BASE_URL, "en_vocab.pkl")
        self.model = os.path.join(BASE_URL, "pytorch_model.pt")

    def read_data(self, data_path):
        """Read a tab-separated parallel corpus file.

        Each line is expected to be ``<src>\\t<tgt>``; returns a list of
        ``[src, tgt]`` pairs with the trailing newline stripped from tgt.
        """
        # encoding="utf-8" is required for Bangla text; the original relied
        # on the platform default encoding, which breaks e.g. on Windows.
        with open(data_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        return [[ln.split("\t")[0], ln.split("\t")[1].replace("\n", "")]
                for ln in lines]

    def load_tokenizer(self, tokenizer_path: str = "") -> object:
        """Load and return a SentencePiece tokenizer from *tokenizer_path*."""
        return spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        """Load the pickled Bangla/English vocabularies.

        Parameter names (including the VOCAL spelling) are kept unchanged
        for backward compatibility with existing keyword callers.
        """
        return uobj.load_bn_vocal(BN_VOCAL_PATH), uobj.load_en_vocal(EN_VOCAL_PATH)

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        """Build the transformer, load checkpoint weights, set eval mode."""
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        # map_location is essential here: this app is pinned to CPU, and a
        # checkpoint saved from a CUDA run would otherwise fail to load.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Greedily decode: append the argmax token until ``eos_index`` or ``max_len``.

        ``src`` is sequence-first (seq_len, 1); returns the decoded token-id
        tensor including the start symbol.
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        # Inference only — no_grad avoids building an autograd graph.
        with torch.no_grad():
            # Encode once and move to device once (was re-moved every loop
            # iteration; the unused memory_mask has also been dropped).
            memory = model.encode(src, src_mask).to(device)
            ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
            for _ in range(max_len - 1):
                tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                            .type(torch.bool)).to(device)
                out = model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                prob = model.generator(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()
                # Sequence dimension is first, so the new token grows dim 0.
                ys = torch.cat(
                    [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
                if next_word == eos_index:
                    break
        return ys

    def get_bntoen_model(self):
        """Load tokenizer, vocabularies, and model; return them in a dict.

        Returns keys: ``bn_tokenizer``, ``bn_vocab``, ``en_vocab``, ``model``.
        """
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ...... : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(
            BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ...... : ", end="", flush=True)
        model = self.load_model(
            model_path=self.model,
            SRC_VOCAB_SIZE=len(bn_vocab),
            TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")
        return {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model,
        }

    def translate(self, text, models):
        """Translate Bangla *text* to English using the artifacts in *models*."""
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]
        # Unused PAD_IDX lookup removed; only BOS/EOS are needed here.
        BOS_IDX, EOS_IDX = src_vocab['<bos>'], src_vocab['<eos>']
        tokens = ([BOS_IDX]
                  + [src_vocab.get_stoi()[tok]
                     for tok in src_tokenizer.encode(text, out_type=str)]
                  + [EOS_IDX])
        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            model, src, src_mask,
            max_len=num_tokens + 5, start_symbol=BOS_IDX, eos_index=EOS_IDX).flatten()
        # Map ids back to subword strings, drop the special tokens, then
        # rebuild words from the SentencePiece "▁" word-boundary markers.
        p_text = " ".join(
            [tgt_vocab.get_itos()[tok] for tok in tgt_tokens]
        ).replace("<bos>", "").replace("<eos>", "")
        pts = " ".join(list(map(lambda x: x, p_text.replace(" ", "").split("▁"))))
        return pts.strip()
# Streamlit entry point: load models once at startup, then translate
# whatever the user types into the text area.
# (Removed a dead assignment that hard-coded `text` only to overwrite it
# immediately with the text_area value, plus commented-out CUDA probe code.)
obj = Bn2EnTranslation()
models = obj.get_bntoen_model()
text = st.text_area("Enter some text:এই উপজেলায় ১টি সরকারি কলেজ রয়েছে")
if text:
    pre = obj.translate(text, models)
    print(f"Input : {text}")
    print(f"Prediction : {pre}")