Spaces:
Runtime error
Runtime error
File size: 5,211 Bytes
94bea44 f52cf42 8fdaf65 f52cf42 234e9f7 f52cf42 b4b95df 4e70c1c 7d77c34 f52cf42 94bea44 f52cf42 8fdaf65 f52cf42 8fdaf65 f52cf42 8fdaf65 f52cf42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
import torch
import argparse
import streamlit as st
import sentencepiece as spm
from utils import utils_cls
from model import BanglaTransformer
from config import config as cfg
# ---------------------------------------------------------------------------
# Module metadata
# ---------------------------------------------------------------------------
__MODULE__ = "Bangla Language Translation"
__MAIL__ = "saifulbrur79@gmail.com"
__MODIFICAIOTN__ = "28/03/2023"  # identifier spelling kept unchanged on purpose
__LICENSE__ = "MIT"

# Directory that holds the tokenizer, vocabulary and checkpoint files
# referenced by Bn2EnTranslation.__init__.
BASE_URL = "./model"

# Seed torch's random number generator before any helper object is created.
torch.manual_seed(0)

# Everything runs on the CPU; the CUDA auto-detecting variant is left here,
# commented out, for reference:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cpu")

# Shared helper object (vocabulary loaders, mask generation) bound to `device`.
uobj = utils_cls(device=device)

# Page heading.
st.write(""" Bangla to English Translation """)
class Bn2EnTranslation:
    """Bangla-to-English translation pipeline built around ``BanglaTransformer``.

    An instance only stores the paths of the artifacts located under
    ``BASE_URL``. The artifacts themselves are loaded by
    :meth:`get_bntoen_model` and consumed by :meth:`translate`.
    """

    def __init__(self):
        # SentencePiece model files.
        self.bn_tokenizer = os.path.join(BASE_URL, "bn_model.model")
        self.en_tokenizer = os.path.join(BASE_URL, 'en_model.model')
        # Pickled vocabularies.
        self.bn_vocab = os.path.join(BASE_URL, 'bn_vocab.pkl')
        self.en_vocab = os.path.join(BASE_URL, 'en_vocab.pkl')
        # Checkpoint containing a 'model_state_dict' entry (see load_model).
        self.model = os.path.join(BASE_URL, 'pytorch_model.pt')

    def read_data(self, data_path):
        """Read a tab-separated file into ``[first_field, second_field]`` pairs.

        The file is decoded as UTF-8 explicitly so that Bangla text does not
        depend on the platform's default encoding. Empty lines are skipped;
        fields after the second one are ignored.

        Args:
            data_path: Path of the tab-separated text file.

        Returns:
            A list with one ``[first, second]`` list of strings per line.

        Raises:
            IndexError: If a non-empty line contains no tab character.
        """
        pairs = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                row = line.rstrip("\n")
                if not row:
                    # A blank (often trailing) line carries no sentence pair.
                    continue
                fields = row.split("\t")
                pairs.append([fields[0], fields[1]])
        return pairs

    def load_tokenizer(self, tokenizer_path: str = "") -> object:
        """Return a ``SentencePieceProcessor`` loaded from ``tokenizer_path``."""
        return spm.SentencePieceProcessor(model_file=tokenizer_path)

    def get_vocab(self, BN_VOCAL_PATH: str = "", EN_VOCAL_PATH: str = ""):
        """Load both vocabularies through the shared ``uobj`` helper.

        Args:
            BN_VOCAL_PATH: Path handed to ``uobj.load_bn_vocal``.
            EN_VOCAL_PATH: Path handed to ``uobj.load_en_vocal``.

        Returns:
            The tuple ``(bn_vocab, en_vocab)``.
        """
        bn_vocal = uobj.load_bn_vocal(BN_VOCAL_PATH)
        en_vocal = uobj.load_en_vocal(EN_VOCAL_PATH)
        return bn_vocal, en_vocal

    def load_model(self, model_path: str = "", SRC_VOCAB_SIZE: int = 0, TGT_VOCAB_SIZE: int = 0):
        """Build a ``BanglaTransformer`` and restore its weights.

        Args:
            model_path: Checkpoint file; it must contain a
                ``'model_state_dict'`` entry.
            SRC_VOCAB_SIZE: Source vocabulary size passed to the model.
            TGT_VOCAB_SIZE: Target vocabulary size passed to the model.

        Returns:
            The model on ``device``, switched to evaluation mode.
        """
        model = BanglaTransformer(
            cfg.NUM_ENCODER_LAYERS, cfg.NUM_DECODER_LAYERS, cfg.EMB_SIZE, SRC_VOCAB_SIZE,
            TGT_VOCAB_SIZE, cfg.FFN_HID_DIM, nhead=cfg.NHEAD)
        model.to(device)
        # Remap stored tensors onto the configured device. Without this, a
        # checkpoint saved from a GPU cannot be deserialized on a CPU-only host.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    def greedy_decode(self, model, src, src_mask, max_len, start_symbol, eos_index):
        """Generate target token ids one at a time, always taking the arg-max.

        Args:
            model: Object exposing ``encode``, ``decode`` and ``generator``.
            src: Source token ids (``translate`` passes a ``(num_tokens, 1)``
                long tensor).
            src_mask: Mask forwarded to ``model.encode``.
            max_len: Upper bound on the output length, start symbol included.
            start_symbol: Token id that seeds the output sequence.
            eos_index: Token id that stops decoding once it is produced.

        Returns:
            A long tensor of shape ``(generated_len, 1)`` that starts with
            ``start_symbol`` and includes ``eos_index`` if it was produced.
        """
        src = src.to(device)
        src_mask = src_mask.to(device)
        # Pure inference: no autograd graph is needed for any step below.
        with torch.no_grad():
            # The encoder output does not change between decoding steps.
            memory = model.encode(src, src_mask).to(device)
            ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(device)
            for _ in range(max_len - 1):
                tgt_mask = (uobj.generate_square_subsequent_mask(ys.size(0))
                            .type(torch.bool)).to(device)
                out = model.decode(ys, memory, tgt_mask)
                out = out.transpose(0, 1)
                # Score the vocabulary using only the most recent position.
                prob = model.generator(out[:, -1])
                _, next_word = torch.max(prob, dim=1)
                next_word = next_word.item()
                ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
                if next_word == eos_index:
                    break
        return ys

    def get_bntoen_model(self):
        """Load every artifact required by :meth:`translate`.

        Progress messages are printed to stdout while loading.

        Returns:
            A dict with the keys ``"bn_tokenizer"``, ``"bn_vocab"``,
            ``"en_vocab"`` and ``"model"``.
        """
        print("Tokenizer Loading ...... : ", end="", flush=True)
        bn_tokenizer = self.load_tokenizer(tokenizer_path=self.bn_tokenizer)
        print("Done")
        print("Vocab Loading ...... : ", end="", flush=True)
        bn_vocab, en_vocab = self.get_vocab(BN_VOCAL_PATH=self.bn_vocab, EN_VOCAL_PATH=self.en_vocab)
        print("Done")
        print("Model Loading ...... : ", end="", flush=True)
        model = self.load_model(model_path=self.model, SRC_VOCAB_SIZE=len(bn_vocab), TGT_VOCAB_SIZE=len(en_vocab))
        print("Done")
        models = {
            "bn_tokenizer": bn_tokenizer,
            "bn_vocab": bn_vocab,
            "en_vocab": en_vocab,
            "model": model,
        }
        return models

    def translate(self, text, models):
        """Translate ``text`` using the bundle returned by ``get_bntoen_model``.

        Args:
            text: Source sentence.
            models: Dict with the keys ``"model"``, ``"bn_vocab"``,
                ``"en_vocab"`` and ``"bn_tokenizer"``.

        Returns:
            The decoded text with ``<bos>``/``<eos>`` removed, SentencePiece
            word markers (``▁``) turned into spaces, and surrounding
            whitespace stripped.

        Raises:
            KeyError: If a tokenizer piece is missing from the source
                vocabulary and that vocabulary has no ``'<unk>'`` entry.
        """
        model = models["model"]
        src_vocab = models["bn_vocab"]
        tgt_vocab = models["en_vocab"]
        src_tokenizer = models["bn_tokenizer"]

        # NOTE(review): the start/end ids come from the *source* vocabulary
        # but also seed and terminate target decoding; this presumes both
        # vocabularies share these indices - confirm.
        BOS_IDX, EOS_IDX = src_vocab['<bos>'], src_vocab['<eos>']

        # Build the string->index table once instead of once per token.
        stoi = src_vocab.get_stoi()
        # Assumes '<unk>' (when present) is the unknown-token entry - confirm.
        unk_idx = stoi.get('<unk>')
        pieces = src_tokenizer.encode(text, out_type=str)
        if unk_idx is None:
            # No fallback entry available: unknown pieces raise KeyError.
            piece_ids = [stoi[piece] for piece in pieces]
        else:
            piece_ids = [stoi.get(piece, unk_idx) for piece in pieces]
        tokens = [BOS_IDX] + piece_ids + [EOS_IDX]

        num_tokens = len(tokens)
        src = torch.LongTensor(tokens).reshape(num_tokens, 1)
        # All-False boolean mask of shape (num_tokens, num_tokens).
        src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
        tgt_tokens = self.greedy_decode(
            model, src, src_mask, max_len=num_tokens + 5,
            start_symbol=BOS_IDX, eos_index=EOS_IDX).flatten()

        # Build the index->string table once as well.
        itos = tgt_vocab.get_itos()
        p_text = " ".join(itos[tok] for tok in tgt_tokens.tolist())
        p_text = p_text.replace("<bos>", "").replace("<eos>", "")
        # Glue sub-word pieces together, then split on the word marker.
        pts = " ".join(p_text.replace(" ", "").split("▁"))
        return pts.strip()
# Streamlit entry point.
obj = Bn2EnTranslation()


def _load_models():
    """Load the tokenizer, vocabularies and model used by ``obj.translate``."""
    return obj.get_bntoen_model()


# Streamlit re-executes this script on every widget interaction. Where the
# installed Streamlit provides st.cache_resource, use it so the artifacts are
# loaded once per process; otherwise load them on each run as before.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_models = _cache_resource(_load_models)
models = _load_models()

text = st.text_area("Enter some text:এই উপজেলায় ১টি সরকারি কলেজ রয়েছে")
if text:
    pre = obj.translate(text, models)
    # Kept for the server log.
    print(f"Input : {text}")
    print(f"Prediction : {pre}")
    # print() never reaches the browser; render the translation in the page.
    st.write(pre)
|