import torch
import torch.nn as nn
import torch.optim as optim


class TextFilter(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)  # token ids -> dense vectors
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden):
        # Embed one token id and reshape to (seq_len=1, batch=1, hidden_size)
        embedded = self.embedding(input).view(1, 1, -1)
        # Thread the recurrent state through so it carries across calls
        output, hidden = self.gru(embedded, hidden)
        output = self.fc(output[0])
        output = self.sigmoid(output)
        return output, hidden

    def init_hidden(self):
        # Initial hidden state: (num_layers, batch, hidden_size)
        return torch.zeros(1, 1, self.hidden_size)


# The original snippet leaves input_size blank; 10 is an assumed vocabulary
# size, chosen only so it covers the example token ids used below.
input_size = 10
hidden_size = 256
output_size = 1

model = TextFilter(input_size, hidden_size, output_size)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)


def train(text, label):
    hidden = model.init_hidden()
    model.zero_grad()
    # Feed the sequence one token at a time, carrying the hidden state forward
    for i in range(text.size()[0]):
        output, hidden = model(text[i], hidden)
    # Score the whole sequence by its final output only
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()
    return output, loss.item()


# Toy run: a five-token sequence with a positive label. BCELoss expects a
# float target with the same shape as the model output, here (1, 1).
text = torch.tensor([1, 2, 3, 4, 5])
label = torch.tensor([[1.0]])

output, loss = train(text, label)
print(output)
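

# A minimal inference sketch (an addition, not part of the original): run a
# new sequence through the model under no_grad and threshold the sigmoid
# output at 0.5 to get a binary keep/filter decision. The helper name
# `predict` and the example token ids are assumptions for illustration.
def predict(text, threshold=0.5):
    with torch.no_grad():
        hidden = model.init_hidden()
        for i in range(text.size()[0]):
            output, hidden = model(text[i], hidden)
    return output.item() >= threshold

print(predict(torch.tensor([2, 4, 6])))  # ids must stay below the assumed input_size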