# Spaces:
# Runtime error
# Runtime error
import torch | |
from model import build_transformer | |
from train import greedy_decode, get_model, get_or_build_tokenizer | |
from config import get_config, get_weights_file_path | |
from tokenizers import Tokenizer | |
from pathlib import Path | |
# Project configuration dict, read by the inference helpers and the UI below.
config = get_config()
# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def process_text(config, src_text, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
    """Encode one source sentence into a fixed-length encoder input and its padding mask.

    The sentence is tokenized with ``tokenizer_src`` and framed as
    ``[SOS] tokens... [EOS] [PAD]...`` so the result is exactly ``seq_len`` ids long.

    Args:
        config: Unused; kept so existing callers keep working.
        src_text: Raw source-language sentence.
        tokenizer_src: Tokenizer used to split ``src_text`` into ids.
        tokenizer_tgt: Tokenizer whose ``[SOS]``, ``[EOS]`` and ``[PAD]`` ids frame the input.
        src_lang: Unused; kept for interface compatibility.
        tgt_lang: Unused; kept for interface compatibility.
        seq_len: Total length of the encoder input, special tokens included.

    Returns:
        dict with
          ``"encoder_input"``: int64 tensor of shape ``(seq_len,)``;
          ``"encoder_mask"``: int32 tensor of shape ``(1, 1, seq_len)`` holding 1 for
          every non-padding position (including [SOS]/[EOS]) and 0 for padding.

    Raises:
        ValueError: if the tokenized sentence plus [SOS]/[EOS] does not fit in ``seq_len``.
    """
    # NOTE(review): the special-token ids are taken from the *target* tokenizer even
    # though they frame source text. This is only equivalent to using tokenizer_src if
    # both tokenizers assign the same ids to [SOS]/[EOS]/[PAD] — confirm against how
    # the training inputs were built before changing it.
    sos_id = tokenizer_tgt.token_to_id("[SOS]")
    eos_id = tokenizer_tgt.token_to_id("[EOS]")
    pad_id = tokenizer_tgt.token_to_id("[PAD]")

    enc_input_tokens = tokenizer_src.encode(src_text).ids

    # Two slots are reserved for [SOS] and [EOS]; the rest of seq_len is padding.
    enc_num_padding_tokens = seq_len - len(enc_input_tokens) - 2
    if enc_num_padding_tokens < 0:
        raise ValueError("Sentence is too long")

    encoder_input = torch.cat(
        [
            torch.tensor([sos_id], dtype=torch.int64),
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            torch.tensor([eos_id], dtype=torch.int64),
            torch.full((enc_num_padding_tokens,), pad_id, dtype=torch.int64),
        ],
        dim=0,
    )
    # Internal invariant: the padding arithmetic above always yields seq_len ids.
    assert encoder_input.size(0) == seq_len

    return {
        "encoder_input": encoder_input,
        "encoder_mask": (encoder_input != pad_id).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
    }
def causal_mask(size):
    """Return a boolean tensor of shape (1, size, size) that is True on and below the diagonal.

    Entry ``[0, i, j]`` is True exactly when ``j <= i``.
    """
    lower_with_diag = torch.ones(1, size, size).tril()
    return lower_with_diag.bool()
def _load_translation_resources(config, weights_path):
    """Load both tokenizers and the trained model; return (tokenizer_src, tokenizer_tgt, model)."""
    tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
    # Deserialize onto the CPU first so a checkpoint saved on a GPU also loads on CPU-only hosts.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state['model_state_dict'])
    # Put the model on the same device that greedy_decode is told to use.
    model.to(device)
    model.eval()
    return tokenizer_src, tokenizer_tgt, model


def infer(text, config, weights_path='tmodel_36.pt'):
    """Translate ``text`` from ``config['lang_src']`` to ``config['lang_tgt']``.

    Args:
        text: Source sentence.
        config: Project configuration dict. This function reads ``tokenizer_file``,
            ``lang_src``, ``lang_tgt`` and ``seq_len``, and passes it to ``get_model``.
        weights_path: Checkpoint file containing a ``'model_state_dict'`` entry.
            Defaults to the previously hard-coded ``'tmodel_36.pt'``.

    Returns:
        The decoded translation produced by ``tokenizer_tgt.decode``.

    Raises:
        ValueError: from ``process_text`` when ``text`` does not fit in ``config['seq_len']``.
    """
    # NOTE(review): tokenizers and checkpoint are reloaded on every call, and Streamlit
    # re-runs the whole script on each interaction; consider caching this load
    # (e.g. with st.cache_resource).
    tokenizer_src, tokenizer_tgt, model = _load_translation_resources(config, weights_path)

    with torch.no_grad():
        processed_text = process_text(
            config, text, tokenizer_src, tokenizer_tgt,
            config['lang_src'], config['lang_tgt'], config['seq_len'],
        )
        # Inputs must live on the same device as the model (which was moved to `device`).
        # NOTE(review): encoder_input has no batch axis here (shape (seq_len,)); verify
        # this matches what train.greedy_decode / the model's encoder expect.
        encoder_input = processed_text['encoder_input'].to(device)
        encoder_mask = processed_text['encoder_mask'].to(device)
        model_out = greedy_decode(
            model, encoder_input, encoder_mask,
            tokenizer_src, tokenizer_tgt, config['seq_len'], device,
        )
    # Hand the tokenizer plain Python ints rather than a numpy array.
    return tokenizer_tgt.decode(model_out.cpu().tolist())
import streamlit as st

# Streamlit UI: re-runs top to bottom on every interaction.
st.title("English to Hausa Translation")
user_input = st.text_input("Enter your text:")
text = user_input.strip() if user_input else ""
if text:
    try:
        result = infer(text, config)
    except ValueError as err:
        # process_text raises ValueError when the input does not fit in seq_len tokens;
        # show it in the page instead of crashing the app with a traceback.
        st.error(f"Could not translate this text: {err}")
    else:
        st.write("Inference Result:", result)