import torch
import torch.nn as nn


class LSTMClassifier(nn.Module):
    """LSTM-based sequence classifier.

    Encodes a batch of sequences with an (optionally bidirectional) LSTM
    and classifies each sequence from its final hidden state(s).
    """

    def __init__(self, input_size=1, hidden_size=64, num_layers=1,
                 bidirectional=True, dropout=0.0, num_classes=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers, so it
            # is a no-op (and triggers a warning) when num_layers == 1.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates forward and backward states,
        # doubling the feature size fed to the classification head.
        direction_factor = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_size * direction_factor, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, input_size) because batch_first=True.
        # hn: (num_layers * num_directions, batch, hidden_size).
        _, (hn, _) = self.lstm(x)
        if self.bidirectional:
            # The last two slices of hn are the final layer's forward and
            # backward hidden states, respectively.
            forward = hn[-2]
            backward = hn[-1]
            combined = torch.cat((forward, backward), dim=1)
        else:
            combined = hn[-1]
        return self.fc(combined)  # logits: (batch, num_classes)
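

# A minimal usage sketch (illustrative only; the batch size and sequence
# length below are assumptions, not part of the original module): classify
# a batch of 8 univariate sequences of length 50.
if __name__ == "__main__":
    model = LSTMClassifier(input_size=1, hidden_size=64,
                           bidirectional=True, num_classes=2)
    x = torch.randn(8, 50, 1)     # (batch, seq_len, input_size)
    logits = model(x)
    print(logits.shape)           # torch.Size([8, 2])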