# nlp-bert-team / models/model1/lstm_model.py
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from gensim.models import Word2Vec
from torch import nn

# Pre-fitted artifacts: token-to-index vocabulary and Word2Vec embeddings.
vocab_to_int = joblib.load('models/model1/lstm_vocab_to_int.pkl')
wv = Word2Vec.load("models/model1/word2vec_model.bin")
# Build the embedding layer from the pre-trained Word2Vec vectors.
# Shape (3379, 32): 3379 = vocabulary size, 32 = embedding dimension.
embedding_matrix = np.zeros((3379, 32))
for word, i in vocab_to_int.items():
    try:
        embedding_matrix[i] = wv.wv[word]
    except KeyError as e:
        # Word missing from the Word2Vec vocabulary: its row stays zero.
        print(f'{e}: word: {word}')

embedding_layer = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_matrix))
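
# Hedged sketch (not from the original file): how raw text would be mapped to
# the index tensors this embedding layer accepts. The whitespace tokenisation,
# the pad/unknown index 0, and the default seq_len are assumptions — the real
# preprocessing pipeline lives elsewhere in the repo.
def encode_text(text, seq_len=64):
    """Turn a string into a [1, seq_len] LongTensor of vocabulary indices."""
    ids = [vocab_to_int.get(tok, 0) for tok in text.lower().split()]  # 0 assumed = pad/unk
    ids = (ids + [0] * seq_len)[:seq_len]  # pad or truncate to a fixed length
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)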
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over the LSTM output sequence."""

    def __init__(self, hidden_size=32):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.linear_2 = nn.Linear(self.hidden_size, self.hidden_size)
        # Alignment scorer; "alogn" is a typo for "align", kept so saved
        # state_dict keys still match trained checkpoints.
        self.alogn = nn.Linear(self.hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, lstm_outputs, final_hidden):
        keys = self.linear_1(lstm_outputs)   # [batch_size, seq_len, hidden_size]
        query = self.linear_2(final_hidden)  # [batch_size, hidden_size]
        query = query.unsqueeze(1).expand(-1, lstm_outputs.size(1), -1)  # [batch_size, seq_len, hidden_size]
        keys_query = keys + query                        # [batch_size, seq_len, hidden_size]
        att_weights = self.tanh(keys_query)              # [batch_size, seq_len, hidden_size]
        att_weights = self.alogn(att_weights)            # [batch_size, seq_len, 1]
        att_weights = F.softmax(att_weights.squeeze(2), dim=1)  # [batch_size, seq_len]
        # Weighted sum of the LSTM outputs -> context vector.
        context = torch.bmm(lstm_outputs.transpose(1, 2), att_weights.unsqueeze(2))  # [batch_size, hidden_size, 1]
        context = context.squeeze(2)  # [batch_size, hidden_size]
        return context, att_weights
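
# Minimal shape check for the attention block (illustrative sketch; the random
# tensors and sizes are placeholders matching the hidden_size=32 used below).
if __name__ == "__main__":
    _attn = BahdanauAttention(hidden_size=32)
    _outs = torch.randn(4, 10, 32)  # [batch_size=4, seq_len=10, hidden_size=32]
    _h = torch.randn(4, 32)         # final hidden state: [batch_size, hidden_size]
    _ctx, _w = _attn(_outs, _h)
    assert _ctx.shape == (4, 32) and _w.shape == (4, 10)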
# Model: embedding -> single-layer LSTM -> Bahdanau attention -> MLP head.
class LSTMConcatAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(32, 32, batch_first=True)
        self.attn = BahdanauAttention(32)
        self.clf = nn.Sequential(
            nn.Linear(32, 128),
            nn.Dropout(),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        embeddings = self.embedding(x)             # [batch_size, seq_len, 32]
        outputs, (h_n, _) = self.lstm(embeddings)  # outputs: [batch_size, seq_len, 32]
        context, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(context)                    # [batch_size, 1] raw logit
        return out, att_weights
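
# Hedged end-to-end sketch: instantiate, (optionally) load trained weights, and
# score one input. The checkpoint path is hypothetical — this file does not say
# where the trained weights are stored — and the sigmoid assumes the single
# logit head is a binary classifier.
if __name__ == "__main__":
    model = LSTMConcatAttention()
    # model.load_state_dict(torch.load('models/model1/lstm_weights.pt'))  # hypothetical path
    model.eval()
    with torch.no_grad():
        logit, att = model(encode_text("some example text"))
        prob = torch.sigmoid(logit).item()
    print(f'positive-class probability: {prob:.3f}')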