In [1]:
from transformer import Transformer
import torch
import numpy as np

In [2]:
english_file = 'dataset/english.txt'
spanish_file = 'dataset/spanish.txt'

START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '’', '‘', ';', '₂',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@',
                      '[', '\\', ']', '^', '_', '`', 
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                      'y', 'z', 
                      'á', 'é', 'í', 'ó', 'ú', 'ñ', 'ü',
                      '¿', '¡',
                      'Á', 'É', 'Í', 'Ó', 'Ú', 'Ñ', 'Ü',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN,
                      'à', 'è', 'ì', 'ò', 'ù', 'À', 'È', 'Ì', 'Ò', 'Ù',
                      'â', 'ê', 'î', 'ô', 'û', 'Â', 'Ê', 'Î', 'Ô', 'Û',
                      'ä', 'ë', 'ï', 'ö', 'ü', 'Ä', 'Ë', 'Ï', 'Ö',
                      'ã', 'õ', 'Ã', 'Õ',
                      'ā', 'ē', 'ī', 'ō', 'ū', 'Ā', 'Ē', 'Ī', 'Ō', 'Ū',
                      'ą', 'ę', 'į', 'ǫ', 'ų', 'Ą', 'Ę', 'Į', 'Ǫ', 'Ų',
                      'ç', 'Ç', 'ş', 'Ş', 'ğ', 'Ğ', 'ń', 'Ń', 'ś', 'Ś', 'ź', 'Ź', 'ż', 'Ż',
                      'č', 'Č', 'ć', 'Ć', 'đ', 'Đ', 'ł', 'Ł', 'ř', 'Ř', 'š', 'Š', 'ť', 'Ť',
                      'ý', 'ÿ', 'Ý', 'Ÿ', 'ž', 'Ž', 'ß', 'œ', 'Œ', 'æ', 'Æ', 'å', 'Å', 'ø', 'Ø', 'å', 'Å',
                      'æ', 'Æ', 'œ', 'Œ']

spanish_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '’', '‘', ';', '₂',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@',
                      '[', '\\', ']', '^', '_', '`', 
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                      'y', 'z', 
                      'á', 'é', 'í', 'ó', 'ú', 'ñ', 'ü',
                      '¿', '¡',
                      'Á', 'É', 'Í', 'Ó', 'Ú', 'Ñ', 'Ü',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN,
                      'à', 'è', 'ì', 'ò', 'ù', 'À', 'È', 'Ì', 'Ò', 'Ù',
                      'â', 'ê', 'î', 'ô', 'û', 'Â', 'Ê', 'Î', 'Ô', 'Û',
                      'ä', 'ë', 'ï', 'ö', 'ü', 'Ä', 'Ë', 'Ï', 'Ö',
                      'ã', 'õ', 'Ã', 'Õ',
                      'ā', 'ē', 'ī', 'ō', 'ū', 'Ā', 'Ē', 'Ī', 'Ō', 'Ū',
                      'ą', 'ę', 'į', 'ǫ', 'ų', 'Ą', 'Ę', 'Į', 'Ǫ', 'Ų',
                      'ç', 'Ç', 'ş', 'Ş', 'ğ', 'Ğ', 'ń', 'Ń', 'ś', 'Ś', 'ź', 'Ź', 'ż', 'Ż',
                      'č', 'Č', 'ć', 'Ć', 'đ', 'Đ', 'ł', 'Ł', 'ř', 'Ř', 'š', 'Š', 'ť', 'Ť',
                      'ý', 'ÿ', 'Ý', 'Ÿ', 'ž', 'Ž', 'ß', 'œ', 'Œ', 'æ', 'Æ', 'å', 'Å', 'ø', 'Ø', 'å', 'Å',
                      'æ', 'Æ', 'œ', 'Œ']


In [3]:
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}
index_to_spanish = {k:v for k,v in enumerate(spanish_vocabulary)}
spanish_to_index = {v:k for k,v in enumerate(spanish_vocabulary)}

In [4]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(spanish_file, 'r') as file:
    spanish_sentences = file.readlines()

TOTAL_SENTENCES = 110000
english_sentences = english_sentences[:TOTAL_SENTENCES]
spanish_sentences = spanish_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
spanish_sentences = [sentence.rstrip('\n').lower() for sentence in spanish_sentences]

In [5]:
english_sentences[:10]

['english', 'go.', 'go.', 'go.', 'go.', 'hi.', 'run!', 'run.', 'who?', 'fire!']

In [6]:
spanish_sentences[:10]

['spanish',
 've.',
 'vete.',
 'vaya.',
 'váyase.',
 'hola.',
 '¡corre!',
 'corred.',
 '¿quién?',
 '¡fuego!']

In [7]:
import numpy as np
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in english_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length Spanish: {np.percentile([len(x) for x in spanish_sentences], PERCENTILE)}" )


97th percentile length English: 48.0
97th percentile length Spanish: 53.0


In [8]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(spanish_sentences)):
    spanish_sentence, english_sentence = spanish_sentences[index], english_sentences[index]
    if is_valid_length(spanish_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(spanish_sentence, spanish_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(spanish_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 110000
Number of valid sentences: 109964


In [9]:
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]
spanish_sentences = [spanish_sentences[i] for i in valid_sentence_indicies]

In [10]:
import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(kannada_vocabulary)
es_vocab_size = len(spanish_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          es_vocab_size,
                          english_to_index,
                          spanish_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [11]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(186, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding)

In [12]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, spanish_sentences):
        self.english_sentences = english_sentences
        self.spanish_sentences = spanish_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.spanish_sentences[idx]

In [13]:
dataset = TextDataset(english_sentences, spanish_sentences)

In [14]:
len(dataset)

109964

In [15]:
dataset[1]

('go.', 've.')

In [16]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [17]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num > 3:
        break

[('english', 'go.', 'go.', 'go.', 'go.', 'hi.', 'run!', 'run.', 'who?', 'fire!', 'fire!', 'fire!', 'help!', 'help!', 'help!', 'jump!', 'jump.', 'stop!', 'stop!', 'stop!', 'wait!', 'wait.', 'go on.', 'go on.', 'hello!', 'i ran.', 'i ran.', 'i try.', 'i won!', 'oh no!'), ('spanish', 've.', 'vete.', 'vaya.', 'váyase.', 'hola.', '¡corre!', 'corred.', '¿quién?', '¡fuego!', '¡incendio!', '¡disparad!', '¡ayuda!', '¡socorro! ¡auxilio!', '¡auxilio!', '¡salta!', 'salte.', '¡parad!', '¡para!', '¡pare!', '¡espera!', 'esperen.', 'continúa.', 'continúe.', 'hola.', 'corrí.', 'corría.', 'lo intento.', '¡he ganado!', '¡oh, no!')]
[('relax.', 'smile.', 'attack!', 'attack!', 'get up.', 'go now.', 'got it!', 'got it?', 'got it?', 'he ran.', 'hop in.', 'hug me.', 'i fell.', 'i know.', 'i left.', 'i lied.', 'i lost.', 'i quit.', 'i quit.', 'i work.', "i'm 19.", "i'm up.", 'listen.', 'listen.', 'listen.', 'no way!', 'no way!', 'no way!', 'no way!', 'no way!'), ('tomátelo con soda.', 'sonríe.', '¡al ataque!',

In [18]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=spanish_to_index[PADDING_TOKEN],
                                reduction='none')

for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = "cpu"

In [19]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [20]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, es_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, es_batch)
        optim.zero_grad()
        es_predictions = transformer(eng_batch,
                                     es_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(es_batch, start_token=False, end_token=True)
        loss = criterian(
            es_predictions.view(-1, es_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == spanish_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Spanish Translation: {es_batch[0]}")
            es_sentence_predicted = torch.argmax(es_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in es_sentence_predicted:
              if idx == spanish_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_spanish[idx.item()]
            print(f"Spanish Prediction: {predicted_sentence}")


            transformer.eval()
            es_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, es_sentence)
                predictions = transformer(eng_sentence,
                                          es_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_spanish[next_token_index]
                es_sentence = (es_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {es_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 6.019732475280762
English: english
Spanish Translation: spanish
Spanish Prediction: íú77e33<START>333ų33.3Ä/åeĐĐŒ..?Ć?Ā???????????????????????8????????????????????????q??]77ÛĐÒĐ88Ō33?Ū&??Ū?eŪ?eešeąÝeeeeå]Ý7e/*ÝÝÒq?ą?ş3Ý8Ý8ĪeÝÝŒeÝÝ8eåå/Öe/eǫ8ö?q/ÝeÝqqqqÖÖ`qÝĐÝqqÒÒÒÒ?Ò
Evaluation translation (should we go to the mall?) : ('aa<END>',)
-------------------------------------------
Iteration 100 : 2.9844906330108643
English: i'm very hot.
Spanish Translation: estoy muy cachondo.
Spanish Prediction: oooe e eo oooo
Evaluation translation (should we go to the mall?) : ('eoo   e oooo.<END>',)
-------------------------------------------
Iteration 200 : 2.9213802814483643
English: i have nothing.
Spanish Translation: no tengo nada.
Spanish Prediction: eo o  eo  ooa.
Evaluation translation (should we go to the mall?) : ('eo   e e e eo.<END>',)
-------------------------------------------
Iteration 300 : 2.6449875831604004
English: there's no rush.
Spanish Translation: no hay pri

## Inference

In [24]:
transformer.eval()
def translate(eng_sentence):
  eng_sentence = (eng_sentence,)
  es_sentence = ("",)
  for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, es_sentence)
    predictions = transformer(eng_sentence,
                              es_sentence,
                              encoder_self_attention_mask.to(device), 
                              decoder_self_attention_mask.to(device), 
                              decoder_cross_attention_mask.to(device),
                              enc_start_token=False,
                              enc_end_token=False,
                              dec_start_token=True,
                              dec_end_token=False)
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = index_to_spanish[next_token_index]
    es_sentence = (es_sentence[0] + next_token, )
    if next_token == END_TOKEN:
      break
  return es_sentence[0]

In [21]:
torch.save(transformer.state_dict(), 'englishTOspanish.pt')

In [33]:
transformer.load_state_dict(torch.load("englishTOspanish.pt"))

<All keys matched successfully>

In [None]:
#i'm happy to see you here
#i have nothing to do with it
#what did you say yesterday?

In [34]:
translation = translate("what should we do when the day starts?")
print(translation)

¿qué deberías cuando el días días días?<END>


In [35]:
translation = translate("i cannot stand this smell")
print(translation)

no puedo estaba de la lentiempo de la las las las informa.<END>
