nlp-project / models /LSTM.py
Ruslan-DS's picture
Update models/LSTM.py
ec33c28
import numpy as np
import torch
from torch import nn
from models.preprocess_stage.preprocess_lstm import preprocess_lstm
# Hyper-parameters shared by the network definition and the predictor below.
EMBEDDING_DIM = 128  # passed to nn.LSTM as input_size
HIDDEN_SIZE = 16  # LSTM hidden size; also the width of the attention layers
MAX_LEN = 125  # forwarded to preprocess_lstm as its MAX_LEN argument

# Pre-trained embedding table, read from disk once at import time.
# NOTE(review): its second dimension presumably equals EMBEDDING_DIM - confirm.
embedding_matrix = np.load('models/datasets/embedding_matrix.npy')

# Wrap the table in an embedding layer (float32 copy of the numpy array).
embedding_layer = nn.Embedding.from_pretrained(
    torch.tensor(embedding_matrix, dtype=torch.float32)
)
class AtenttionTest(nn.Module):
    """Additive attention over per-step LSTM outputs, driven by the last hidden state.

    Each time step is scored by ``fc3(tanh(fc1(step) + fc2(h_n)))``; the scores
    are soft-maxed over the step axis and used to pool the ``fc1``-projected
    steps into one context vector per batch item.
    """

    def __init__(self, hidden_size=HIDDEN_SIZE):
        """Create the three linear projections used for scoring.

        Args:
            hidden_size: Feature width of the LSTM outputs and hidden state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # Attribute names (including the misspelt ``tahn``) and creation order
        # are kept as-is so saved checkpoints and seeded init stay compatible.
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.tahn = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, outputs_lmst, h_n):
        """Pool the projected LSTM outputs with attention weights.

        Args:
            outputs_lmst: Per-step LSTM outputs; assumed to be
                (batch, steps, hidden), matching a ``batch_first=True`` LSTM.
            h_n: Final hidden state with a leading axis of size one, which is
                squeezed away here; assumed to be (1, batch, hidden).

        Returns:
            A pair ``(context, attention_weights)``: the result of the batched
            matrix product (trailing axis of size one) and the soft-max
            weights taken over axis 1.
        """
        projected_steps = self.fc1(outputs_lmst)
        # Insert a step axis so the state projection broadcasts over every step.
        projected_state = self.fc2(h_n.squeeze(0)).unsqueeze(1)

        scores = self.fc3(self.tahn(projected_steps + projected_state)).squeeze(2)
        attention_weights = scores.softmax(dim=1)

        # Weighted sum over steps expressed as a batched matrix product.
        context = torch.bmm(
            projected_steps.transpose(1, 2),
            attention_weights.unsqueeze(2),
        )
        return context, attention_weights
class LSTMnn(nn.Module):
    """Single-layer LSTM with attention pooling and a sigmoid output head.

    Pipeline: shared embedding layer -> LSTM -> attention -> small MLP head
    -> sigmoid.
    """

    def __init__(self):
        """Assemble the sub-modules from the module-level hyper-parameters."""
        super().__init__()
        # Reuses the module-level embedding layer rather than creating a new one.
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(
            EMBEDDING_DIM,
            HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )
        self.attention = AtenttionTest(hidden_size=HIDDEN_SIZE)
        # Layer order fixes the Sequential indices used as checkpoint keys.
        head_layers = [
            nn.Linear(HIDDEN_SIZE, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1),
        ]
        self.fc_out = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the network on a batch of token-index sequences.

        Args:
            x: Index tensor accepted by the embedding layer; assumed to be
                (batch, steps) because the LSTM is ``batch_first=True``.

        Returns:
            A pair ``(probability, attention_weights)`` where ``probability``
            is the sigmoid of the head output and ``attention_weights`` comes
            straight from the attention module.
        """
        embedded = self.embedding(x)
        lstm_out, (last_hidden, _) = self.lstm(embedded)
        context, attention_weights = self.attention(lstm_out, last_hidden)
        # Drop the trailing singleton axis produced by the attention module.
        logits = self.fc_out(context.squeeze(2))
        return logits.sigmoid(), attention_weights
# Build the network once at import time and restore the trained weights,
# mapping every stored tensor onto the CPU.
model = LSTMnn()
model.load_state_dict(
    torch.load('models/weights/LSTMBestWeights.pt', map_location='cpu')
)
def predict_3(text):
    """Classify a single text with the module-level LSTM model.

    Args:
        text: Raw input passed unchanged to ``preprocess_lstm``.

    Returns:
        The model's sigmoid output rounded to the nearest integer with the
        built-in ``round`` (so the result is 0 or 1).
    """
    preprocessed_text = preprocess_lstm(text, MAX_LEN=MAX_LEN)
    # Put the model in evaluation mode so dropout is disabled.
    model.eval()
    # Add a leading batch axis: the model expects a batch of sequences.
    batch = torch.tensor(preprocessed_text).unsqueeze(0)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        probability, _ = model(batch)
    return round(probability.item())