Ruslan-DS committed
Commit 3c80cef
1 Parent(s): 4c883db

Update models/LSTM.py

Files changed (1): models/LSTM.py +87 -0
models/LSTM.py CHANGED
@@ -0,0 +1,87 @@
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from models.preprocess_stage.preprocess_lstm import preprocess_lstm
+
+ EMBEDDING_DIM = 128
+ HIDDEN_SIZE = 16
+ MAX_LEN = 125
+
+ # Pre-trained embedding matrix; nn.Embedding.from_pretrained freezes the weights by default.
+ embedding_matrix = np.load('models/datasets/embedding_matrix.npy')
+ embedding_layer = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
+
+
+ class AttentionTest(nn.Module):
+     """Additive (Bahdanau-style) attention over the LSTM outputs."""
+
+     def __init__(self, hidden_size=HIDDEN_SIZE):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.tanh = nn.Tanh()
+         self.fc3 = nn.Linear(self.hidden_size, 1)
+
+     def forward(self, outputs_lstm, h_n):
+         # outputs_lstm: (batch, seq_len, hidden); h_n: (1, batch, hidden).
+         output_fc1 = self.fc1(outputs_lstm)
+         output_fc2 = self.fc2(h_n.squeeze(0))
+
+         # Broadcast the projected final hidden state across all time steps.
+         fc1_fc2_sum = output_fc1 + output_fc2.unsqueeze(1)
+
+         output_tanh = self.tanh(fc1_fc2_sum)
+
+         # One score per time step, normalized over the sequence dimension.
+         attention_weights = torch.softmax(self.fc3(output_tanh).squeeze(2), dim=1)
+
+         # Context vector: weighted sum of the projected outputs, shape (batch, hidden, 1).
+         output_finished = torch.bmm(output_fc1.transpose(1, 2), attention_weights.unsqueeze(2))
+
+         return output_finished, attention_weights
+
+
+ class LSTMnn(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+         self.embedding = embedding_layer
+         self.lstm = nn.LSTM(
+             input_size=EMBEDDING_DIM,
+             hidden_size=HIDDEN_SIZE,
+             num_layers=1,
+             batch_first=True
+         )
+         self.attention = AttentionTest(hidden_size=HIDDEN_SIZE)
+         self.fc_out = nn.Sequential(
+             nn.Linear(HIDDEN_SIZE, 128),
+             nn.Dropout(),
+             nn.Tanh(),
+             nn.Linear(128, 1)
+         )
+
+     def forward(self, x):
+         # x: (batch, seq_len) of token ids.
+         embedding = self.embedding(x)
+
+         output_lstm, (h_n, _) = self.lstm(embedding)
+
+         output_attention, attention_weights = self.attention(output_lstm, h_n)
+
+         # Drop the trailing singleton dim of the context vector before the output head.
+         output_finished = self.fc_out(output_attention.squeeze(2))
+
+         return torch.sigmoid(output_finished), attention_weights
+
+
+ model = LSTMnn()
+ # map_location='cpu' lets the checkpoint load on machines without a GPU.
+ model.load_state_dict(torch.load('models/weights/LSTMBestWeights.pt', map_location='cpu'))
+
+
+ def predict_3(text):
+
+     preprocessed_text = preprocess_lstm(text, MAX_LEN=MAX_LEN)
+
+     model.eval()
+     # No gradients are needed at inference time.
+     with torch.no_grad():
+         predict, attention = model(torch.tensor(preprocessed_text).unsqueeze(0))  # attention weights unused here
+
+     # Threshold the sigmoid output at 0.5 to get a 0/1 label.
+     predict = round(predict.item())
+
+     return predict
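
A minimal usage sketch of predict_3: the sample sentence here is hypothetical, it assumes preprocess_lstm maps raw text to a MAX_LEN-long sequence of token ids, and which class the 0/1 label denotes depends on the training labels.

    label = predict_3('The film exceeded all my expectations.')
    print(f'Predicted class: {label}')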