"""Additive attention on top of a single-layer bidirectional recurrent encoder."""

from typing import Tuple

import torch
import torch.nn as nn


class Attention(nn.Module):
    """Additive (tanh-scored) attention over a sequence of encoder states.

    The score of position ``t`` is ``v(tanh(W1(encoder_outputs[:, t]) + W2(hidden)))``;
    scores are softmax-normalised along the sequence axis and used to take a
    weighted sum of ``encoder_outputs``.
    """

    def __init__(self, hidden_size: int) -> None:
        """Create the projection layers.

        Args:
            hidden_size: Feature size of both the query vector and each
                encoder state.
        """
        super().__init__()
        self.W1 = nn.Linear(hidden_size, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(
        self, hidden: torch.Tensor, encoder_outputs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``encoder_outputs`` using ``hidden`` as the query.

        Args:
            hidden: Query tensor of shape ``(batch, hidden_size)``.
            encoder_outputs: Tensor of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            A pair ``(context, attention_weights)`` where ``context`` has shape
            ``(batch, hidden_size)`` and ``attention_weights`` has shape
            ``(batch, seq_len)`` with each row summing to 1.
        """
        # Project the query once and let broadcasting spread it over the
        # sequence axis, instead of repeating it seq_len times and projecting
        # every copy.
        projected_query = self.W2(hidden).unsqueeze(1)  # (batch, 1, hidden)
        energy = torch.tanh(self.W1(encoder_outputs) + projected_query)
        scores = self.v(energy).squeeze(-1)  # (batch, seq_len)
        attention_weights = torch.softmax(scores, dim=1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        return context.squeeze(1), attention_weights


class SimpleRecurrentNetworkWithAttention(nn.Module):
    """Embedding -> bidirectional RNN/GRU/LSTM -> attention -> linear head."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        cell_type: str = 'RNN',
        device='cpu',
    ) -> None:
        """Build the network.

        Args:
            input_size: Number of rows in the embedding table.
            hidden_size: Embedding size and per-direction recurrent state size.
            output_size: Number of output features of the final linear layer.
            cell_type: ``'LSTM'`` or ``'GRU'`` select those cells; any other
                value falls back to a plain ``nn.RNN``.
            device: Stored as ``self.device`` for callers that read it. The
                constructor does not move the module to this device.
        """
        super().__init__()
        self.device = device
        self.embedding = nn.Embedding(input_size, hidden_size)
        # The encoder is bidirectional, so attention and the output head work
        # on concatenated forward/backward features (2 * hidden_size).
        self.attention = Attention(hidden_size * 2)

        if cell_type == 'LSTM':
            rnn_cls = nn.LSTM
        elif cell_type == 'GRU':
            rnn_cls = nn.GRU
        else:
            rnn_cls = nn.RNN
        self.rnn = rnn_cls(
            hidden_size, hidden_size, batch_first=True, bidirectional=True
        )
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network on a batch of index sequences.

        Args:
            inputs: Integer tensor of shape ``(batch, seq_len)`` holding
                embedding indices.

        Returns:
            A pair ``(output, attention_weights)`` with shapes
            ``(batch, output_size)`` and ``(batch, seq_len)``.
        """
        # Follow the device the parameters actually live on: ``self.device``
        # is only a stored constructor argument and goes stale as soon as the
        # module is moved with ``.to(...)`` / ``.cuda()``.
        param_device = self.embedding.weight.device
        embedded = self.embedding(inputs.to(param_device))

        rnn_output, state = self.rnn(embedded)
        if isinstance(state, tuple):
            # LSTM returns (h_n, c_n); only the hidden state is used as query.
            state = state[0]
        # For a single-layer bidirectional encoder the last two entries of h_n
        # are the final forward and backward states.
        hidden = torch.cat((state[-2], state[-1]), dim=1)

        context, attention_weights = self.attention(hidden, rnn_output)
        return self.fc(context), attention_weights