import itertools
import json

import torch

from transformer_model import TransformerModel

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model hyperparameters; these must match the values used at training time.
d_model = 512
seq_length = 10
vocab_size = 25672
batch_size = 32  # not used below
num_heads = 8
dim_feedforward = 2048

# Rebuild the model and restore the trained weights.
model = TransformerModel(vocab_size, d_model, num_heads, dim_feedforward, seq_length)
model.load_state_dict(torch.load('transformer_model.pth', map_location=device))
model.to(device)
model.eval()
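# Note: 'transformer_model.pth' is assumed to hold a state_dict saved from this
# same TransformerModel configuration; map_location keeps the load working on
# CPU-only machines when the checkpoint was written on a GPU.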

# Load the token -> index vocabulary produced during training.
with open('vocabulary.json', 'r') as vocab_file:
    vocab = json.load(vocab_file)

# Make sure an unknown-token entry exists.
if '<unk>' not in vocab:
    vocab['<unk>'] = len(vocab)
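# Caveat: appending '<unk>' here assigns it index len(vocab), which the trained
# embedding (sized for vocab_size) may not cover; this assumes the training
# vocabulary either already contains '<unk>' or reserved room for it.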


def text_to_tensor(text, vocab, seq_length):
    """Whitespace-tokenise text, map tokens to indices, pad/truncate to seq_length."""
    tokens = text.split()
    # Unknown tokens fall back to the '<unk>' index.
    indices = [vocab.get(token, vocab['<unk>']) for token in tokens]
    indices = indices[:seq_length]
    # Right-pad short sequences with the '<pad>' index.
    indices += [vocab['<pad>']] * (seq_length - len(indices))
    # Shape (1, seq_length): a single-example batch.
    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
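
# For example, with seq_length = 5 (assuming both words are in the vocabulary):
#   text_to_tensor("please make", vocab, 5)
#   -> tensor([[vocab['please'], vocab['make'], pad, pad, pad]]), pad = vocab['<pad>']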


# Encode a prompt; the same tensor serves as encoder and decoder input here.
input_text = "please make the"
input_tensor = text_to_tensor(input_text, vocab, seq_length).to(device)
src = input_tensor
tgt = input_tensor


def generate_square_subsequent_mask(sz):
    """Boolean causal mask; True above the diagonal marks positions to hide."""
    mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
    return mask
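# e.g. generate_square_subsequent_mask(3) (True = masked):
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]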


def create_padding_mask(seq):
    """Boolean padding mask of shape (batch, seq_len); True where seq is padding."""
    # Compare against the actual '<pad>' index rather than a hard-coded 0
    # (the two coincide only when '<pad>' happens to be index 0).
    return seq == vocab['<pad>']
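
# Passing the mask batch-first assumes TransformerModel follows PyTorch's
# nn.Transformer convention, where key padding masks have shape
# (batch_size, seq_len); if the model expects (seq_len, batch_size) instead,
# transpose the mask before the forward call.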

# Build causal and padding masks for a single forward pass.
src_seq_len = src.size(1)
tgt_seq_len = tgt.size(1)

src_mask = generate_square_subsequent_mask(src_seq_len)
tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
src_key_padding_mask = create_padding_mask(src)
tgt_key_padding_mask = create_padding_mask(tgt)

# Sanity-check the tensor shapes before the forward pass.
print(src.size(), tgt.size(), src_mask.size(), tgt_mask.size(),
      src_key_padding_mask.size(), tgt_key_padding_mask.size())

with torch.no_grad():
    output = model(src, tgt, src_mask, tgt_mask,
                   src_key_padding_mask, tgt_key_padding_mask)

# Greedy decode: take the highest-scoring token at every position.
predicted_indices = torch.argmax(output, dim=-1).squeeze(0).tolist()
print(predicted_indices)

# Invert the vocabulary so indices can be mapped back to tokens.
inverse_vocab = {value: key for key, value in vocab.items()}

# Flatten the nested list; this assumes the model output is sequence-first,
# so tolist() yields shape (seq_len, batch). Skip the flatten if the output
# is batch-first and the list is already flat.
flattened_list = list(itertools.chain.from_iterable(predicted_indices))
print([inverse_vocab[key] for key in flattened_list])
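

# The single forward pass above scores one token per position; for actual text
# continuation, tokens are usually generated one step at a time. A minimal
# greedy-decoding sketch, assuming model(src, tgt) accepts a growing target and
# returns batch-first logits of shape (batch, seq_len, vocab_size):
def greedy_generate(model, src, vocab, inverse_vocab, max_len=10):
    generated = [vocab['<pad>']]  # assumed seed token; use a real <bos> if the vocabulary has one
    with torch.no_grad():
        for _ in range(max_len):
            tgt = torch.tensor(generated, dtype=torch.long, device=device).unsqueeze(0)
            logits = model(src, tgt)
            next_token = logits[0, -1].argmax().item()  # best token at the last position
            generated.append(next_token)
    return [inverse_vocab[i] for i in generated[1:]]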


def generate_prediction(text, model, vocab, seq_length):
    model.eval()

    input_tensor = text_to_tensor(text, vocab, seq_length).to(device)

    with torch.no_grad():
        output = model(input_tensor, input_tensor)

    # squeeze() (rather than squeeze(0)) drops the batch dimension whether the
    # output is batch-first or sequence-first, leaving a flat index list.
    predicted_indices = torch.argmax(output, dim=-1).squeeze().tolist()
    # vocab maps token -> index, so look tokens up in the inverted vocabulary.
    predicted_tokens = [inverse_vocab[index] for index in predicted_indices]

    return predicted_tokens


text = """Here were the servants of your adversary And
yours"""

prediction = generate_prediction(text, model, vocab, seq_length)
print(prediction)
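
# Note: generate_prediction returns the per-position argmax for identical src
# and tgt inputs, not an autoregressive continuation; for free-running
# generation, use a loop like greedy_generate above.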
|