# Lwasinam's picture
# Upload 8 files
# 61e1114
# raw
# history blame
# 4.61 kB
import torch
from model import build_transformer
from train import greedy_decode, get_model, get_or_build_tokenizer
from config import get_config, get_weights_file_path
from tokenizers import Tokenizer
from pathlib import Path
# Project configuration object returned by get_config().
config = get_config()
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def process_text(config, src_text, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    """Tokenize ``src_text`` and build a fixed-length encoder input and mask.

    The produced sequence is ``[SOS] + source tokens + [EOS]`` followed by
    ``[PAD]`` ids up to ``seq_len``. The text is encoded with
    ``tokenizer_src``; the ids of the special tokens are looked up in
    ``tokenizer_tgt``.

    Args:
        config: Unused; kept so existing callers keep working.
        src_text: Source sentence to encode.
        tokenizer_src: Tokenizer whose ``encode(text).ids`` yields the source
            token ids.
        tokenizer_tgt: Tokenizer whose ``token_to_id`` resolves ``[SOS]``,
            ``[EOS]`` and ``[PAD]``.
        src_lang: Unused; kept so existing callers keep working.
        tgt_lang: Unused; kept so existing callers keep working.
        seq_len: Exact length of the returned encoder input.

    Returns:
        A dict with:
            ``'encoder_input'``: int64 tensor of shape ``(seq_len,)``.
            ``'encoder_mask'``: integer tensor of shape ``(1, 1, seq_len)``
            holding 1 where the id differs from the ``[PAD]`` id, else 0.

    Raises:
        ValueError: If the sentence plus ``[SOS]`` and ``[EOS]`` does not fit
            into ``seq_len`` positions.
    """
    sos_id = tokenizer_tgt.token_to_id("[SOS]")
    eos_id = tokenizer_tgt.token_to_id("[EOS]")
    pad_id = tokenizer_tgt.token_to_id("[PAD]")

    # Transform the text into token ids.
    enc_input_tokens = tokenizer_src.encode(src_text).ids

    # Two positions are reserved for [SOS] and [EOS].
    enc_num_padding_tokens = seq_len - len(enc_input_tokens) - 2

    # A negative padding count means the sentence does not fit in seq_len.
    if enc_num_padding_tokens < 0:
        raise ValueError("Sentence is too long")

    # [SOS] tokens... [EOS] [PAD]... assembled as plain ids, converted once.
    token_ids = [sos_id, *enc_input_tokens, eos_id]
    token_ids += [pad_id] * enc_num_padding_tokens
    encoder_input = torch.tensor(token_ids, dtype=torch.int64)

    # Sanity check: the encoder input must be exactly seq_len long.
    assert encoder_input.size(0) == seq_len

    return {
        'encoder_input': encoder_input,
        "encoder_mask": (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
    }
def causal_mask(size):
    """Return a boolean mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is True when ``j <= i`` (on or below the diagonal)
    and False for positions strictly above the diagonal.
    """
    # Ones strictly above the diagonal, zeros elsewhere.
    above_diagonal = torch.ones((1, size, size)).triu(diagonal=1).int()
    # Invert: everything that is NOT above the diagonal is allowed.
    return above_diagonal.eq(0)
def infer(text, config, weights_path='tmodel_36.pt'):
    """Translate ``text`` with the trained model and return the decoded string.

    Loads the source and target tokenizers from the files named by
    ``config['tokenizer_file']`` (formatted with ``config['lang_src']`` and
    ``config['lang_tgt']``), builds the model with ``get_model``, restores the
    weights stored under ``'model_state_dict'`` in the checkpoint, and runs
    ``greedy_decode`` on the processed input.

    Args:
        text: Source sentence to translate.
        config: Mapping providing ``'tokenizer_file'``, ``'lang_src'``,
            ``'lang_tgt'`` and ``'seq_len'``; it is also passed to
            ``get_model`` and ``process_text``.
        weights_path: Checkpoint file to load. Defaults to the previously
            hard-coded ``'tmodel_36.pt'``.

    Returns:
        The output of ``tokenizer_tgt.decode`` for the generated token ids.

    Raises:
        ValueError: Propagated from ``process_text`` when the sentence does
            not fit into ``config['seq_len']``.
    """
    src_tokenizer_path = Path(config['tokenizer_file'].format(config['lang_src']))
    tgt_tokenizer_path = Path(config['tokenizer_file'].format(config['lang_tgt']))
    tokenizer_src = Tokenizer.from_file(str(src_tokenizer_path))
    tokenizer_tgt = Tokenizer.from_file(str(tgt_tokenizer_path))

    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
    # Load onto the CPU first so the checkpoint opens on machines without a GPU.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model_state_dict'])
    # greedy_decode receives the module-level `device`; keep the model and its
    # inputs on that same device so nothing is mixed between CPU and GPU.
    model.to(device)
    model.eval()

    with torch.no_grad():
        processed_text = process_text(
            config,
            text,
            tokenizer_src,
            tokenizer_tgt,
            config['lang_src'],
            config['lang_tgt'],
            config['seq_len'],
        )
        # NOTE(review): encoder_input is 1-D (seq_len,) with no batch
        # dimension -- confirm greedy_decode accepts unbatched input.
        encoder_input = processed_text['encoder_input'].to(device)
        encoder_mask = processed_text['encoder_mask'].to(device)
        model_out = greedy_decode(
            model,
            encoder_input,
            encoder_mask,
            tokenizer_src,
            tokenizer_tgt,
            config['seq_len'],
            device,
        )
        model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
    return model_out_text
import streamlit as st
# Streamlit UI: read one sentence from the user and show its translation.
st.title("English to Hausa Translation")
user_input = st.text_input("Enter your text:")
if user_input:
    try:
        result = infer(user_input, config)
    except ValueError as exc:
        # process_text raises ValueError when the sentence does not fit in the
        # configured sequence length; show the message instead of a traceback.
        st.error(str(exc))
    else:
        st.write("Inference Result:", result)