"""Streamlit front-end for a character-level English -> Spanish Transformer."""
import functools

import numpy as np
import streamlit as st
import torch
from torch import nn
from torch.utils.data import Dataset

from transformer import Transformer

english_file = 'dataset/english.txt'
spanish_file = 'dataset/spanish.txt'

# Path of the trained weights loaded for inference.
MODEL_PATH = "englishTOspanish.pt"

# Special tokens. They must be distinct, non-empty strings: the index maps
# below are built with dict comprehensions, so identical tokens would collapse
# onto a single key (START_TOKEN would then resolve to END_TOKEN's index).
START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'

# Character-level vocabulary, shared by both languages (the English and
# Spanish lists were identical). A token's position in this list is its
# index, and es_vocab_size is derived from the list length. Do NOT reorder
# entries or drop the duplicated characters ('ü', 'å', 'Å', 'æ', 'Æ', 'œ',
# 'Œ'): that would change indices/sizes and presumably no longer match the
# saved checkpoint.
_VOCABULARY = [
    START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+',
    ',', '-', '.', '/', '’', '‘', ';', '₂',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    ':', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
    'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
    'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
    'á', 'é', 'í', 'ó', 'ú', 'ñ', 'ü', '¿', '¡',
    'Á', 'É', 'Í', 'Ó', 'Ú', 'Ñ', 'Ü',
    '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN,
    'à', 'è', 'ì', 'ò', 'ù', 'À', 'È', 'Ì', 'Ò', 'Ù',
    'â', 'ê', 'î', 'ô', 'û', 'Â', 'Ê', 'Î', 'Ô', 'Û',
    'ä', 'ë', 'ï', 'ö', 'ü', 'Ä', 'Ë', 'Ï', 'Ö',
    'ã', 'õ', 'Ã', 'Õ',
    'ā', 'ē', 'ī', 'ō', 'ū', 'Ā', 'Ē', 'Ī', 'Ō', 'Ū',
    'ą', 'ę', 'į', 'ǫ', 'ų', 'Ą', 'Ę', 'Į', 'Ǫ', 'Ų',
    'ç', 'Ç', 'ş', 'Ş', 'ğ', 'Ğ', 'ń', 'Ń', 'ś', 'Ś', 'ź', 'Ź', 'ż', 'Ż',
    'č', 'Č', 'ć', 'Ć', 'đ', 'Đ', 'ł', 'Ł', 'ř', 'Ř', 'š', 'Š', 'ť', 'Ť',
    'ý', 'ÿ', 'Ý', 'Ÿ', 'ž', 'Ž', 'ß',
    'œ', 'Œ', 'æ', 'Æ', 'å', 'Å', 'ø', 'Ø', 'å', 'Å', 'æ', 'Æ', 'œ', 'Œ',
]

english_vocabulary = list(_VOCABULARY)
spanish_vocabulary = list(_VOCABULARY)

index_to_english = {k: v for k, v in enumerate(english_vocabulary)}
english_to_index = {v: k for k, v in enumerate(english_vocabulary)}
index_to_spanish = {k: v for k, v in enumerate(spanish_vocabulary)}
spanish_to_index = {v: k for k, v in enumerate(spanish_vocabulary)}

# Model hyperparameters; must match the ones used to train MODEL_PATH.
d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
es_vocab_size = len(spanish_vocabulary)


class TextDataset(Dataset):
    """Pairs of (English sentence, Spanish sentence), indexed together."""

    def __init__(self, english_sentences, spanish_sentences):
        self.english_sentences = english_sentences
        self.spanish_sentences = spanish_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.spanish_sentences[idx]


# Inference device. CPU by default; to use a GPU set it to
# torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu').
# Both the model weights and the attention masks are moved to this device.
device = "cpu"

# Additive mask value for disallowed attention positions.
NEG_INFTY = -1e9


def create_masks(eng_batch, kn_batch):
    """Build additive attention masks for a batch of sentence pairs.

    Args:
        eng_batch: sequence of source (English) strings.
        kn_batch: sequence of target (Spanish) strings, same length as
            ``eng_batch``.

    Returns:
        ``(encoder_self_attention_mask, decoder_self_attention_mask,
        decoder_cross_attention_mask)``. Each is a float tensor of shape
        ``(len(eng_batch), max_sequence_length, max_sequence_length)`` holding
        0 where attention is allowed and ``NEG_INFTY`` where it is masked.

    For a sentence of length L, positions 0..L (inclusive) are treated as
    real tokens and positions L+1.. as padding. The one extra slot is
    presumably for a start/end token added by the Transformer — TODO confirm.
    """
    positions = torch.arange(max_sequence_length)
    eng_lengths = torch.tensor([len(s) for s in eng_batch], dtype=torch.long)
    kn_lengths = torch.tensor([len(s) for s in kn_batch], dtype=torch.long)

    # (N, T) booleans: True where position p > L, i.e. p >= L + 1 (padding).
    eng_is_pad = positions.unsqueeze(0) > eng_lengths.unsqueeze(1)
    kn_is_pad = positions.unsqueeze(0) > kn_lengths.unsqueeze(1)

    # (T, T) causal mask: True strictly above the diagonal (future positions).
    look_ahead_mask = torch.triu(
        torch.ones(max_sequence_length, max_sequence_length, dtype=torch.bool),
        diagonal=1,
    )

    # A (query row, key column) pair is masked if either side is padding.
    # unsqueeze(1) masks key columns, unsqueeze(2) masks query rows.
    encoder_padding_mask = eng_is_pad.unsqueeze(1) | eng_is_pad.unsqueeze(2)
    decoder_padding_mask_self_attention = (
        kn_is_pad.unsqueeze(1) | kn_is_pad.unsqueeze(2)
    )
    # Cross attention: Spanish queries (rows) attend to English keys (columns).
    decoder_padding_mask_cross_attention = (
        eng_is_pad.unsqueeze(1) | kn_is_pad.unsqueeze(2)
    )

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask = torch.where(
        look_ahead_mask | decoder_padding_mask_self_attention, NEG_INFTY, 0
    )
    decoder_cross_attention_mask = torch.where(
        decoder_padding_mask_cross_attention, NEG_INFTY, 0
    )
    return (
        encoder_self_attention_mask,
        decoder_self_attention_mask,
        decoder_cross_attention_mask,
    )


def _resource_cache(func):
    """Cache *func*'s result, surviving Streamlit reruns when supported.

    Uses ``st.cache_resource`` when this Streamlit version provides it.
    Otherwise it falls back to ``functools.lru_cache``, which only caches
    within a single script run because Streamlit re-executes the script.
    """
    cache = getattr(st, "cache_resource", None)
    if cache is None:
        return functools.lru_cache(maxsize=None)(func)
    return cache(func)


@_resource_cache
def load_transformer():
    """Build the Transformer, load the trained weights, and return it in eval mode."""
    model = Transformer(d_model, ffn_hidden, num_heads, drop_prob, num_layers,
                        max_sequence_length, es_vocab_size, english_to_index,
                        spanish_to_index, START_TOKEN, END_TOKEN, PADDING_TOKEN)
    # NOTE(security): torch.load unpickles the checkpoint file; only ever
    # point MODEL_PATH at a trusted file.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()
    return model


def translate(eng_sentence):
    """Greedily decode a Spanish translation of *eng_sentence*, one character at a time.

    At each step the model is re-run on the full partial output, and the
    argmax character at the current position is appended. Decoding stops
    after ``END_TOKEN`` is produced, or after ``max_sequence_length`` steps.

    Returns:
        The decoded string. It ends with ``END_TOKEN`` if that token was
        produced; callers strip it for display.
    """
    model = load_transformer()
    eng_batch = (eng_sentence,)
    es_sentence = ""
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for word_counter in range(max_sequence_length):
            es_batch = (es_sentence,)
            (encoder_self_attention_mask,
             decoder_self_attention_mask,
             decoder_cross_attention_mask) = create_masks(eng_batch, es_batch)
            predictions = model(eng_batch,
                                es_batch,
                                encoder_self_attention_mask.to(device),
                                decoder_self_attention_mask.to(device),
                                decoder_cross_attention_mask.to(device),
                                enc_start_token=False,
                                enc_end_token=False,
                                dec_start_token=True,
                                dec_end_token=False)
            next_token_prob_distribution = predictions[0][word_counter]
            next_token_index = torch.argmax(next_token_prob_distribution).item()
            next_token = index_to_spanish[next_token_index]
            es_sentence += next_token
            if next_token == END_TOKEN:
                break
    return es_sentence


def _strip_end_token(text):
    """Remove a trailing END_TOKEN from *text*, if present.

    The ``END_TOKEN`` truthiness check avoids ``text[:-0]``, which would
    return an empty string.
    """
    if END_TOKEN and text.endswith(END_TOKEN):
        return text[:-len(END_TOKEN)]
    return text


def main():
    """Render the translation UI. Streamlit re-runs this on every interaction."""
    st.title("seq2seq Machine Translation")
    st.write("Translate English to Spanish")
    st.write("\n")
    st.write("Some example sentences:")
    st.write("i'm happy to see you here")
    st.write("i have nothing to do with it")
    st.write("what did you say yesterday?")
    st.write("\n")

    input_text = st.text_area("Enter English text:")

    if not st.button("Translate"):
        return
    # The vocabulary has no uppercase ASCII letters and no newline, so
    # normalise the input before translating.
    text = input_text.strip().lower()
    if not text:
        st.warning("Please enter some text.")
        return
    unsupported = sorted({ch for ch in text if ch not in english_to_index})
    if unsupported:
        st.warning("Unsupported characters: " + ", ".join(map(repr, unsupported)))
        return
    # Presumably the model cannot encode more than max_sequence_length
    # characters — TODO confirm the exact limit inside Transformer.
    if len(text) > max_sequence_length:
        st.warning(f"Please enter at most {max_sequence_length} characters.")
        return

    translated_text = translate(text)
    st.write("Your Text (English):")
    st.title(input_text)
    st.write("Translated Text (Spanish):")
    st.title(_strip_end_token(translated_text))


# Streamlit executes this script top to bottom on every interaction.
main()