BadWordDetector / python.py
BogdanNotGay's picture
Upload python.py
d3eda1e
import torch
import torch.nn as nn
import torch.optim as optim
class TextFilter(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(TextFilter, self).__init__()
self.hidden_size = hidden_size
self.embedding = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size)
self.fc = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
def forward(self, input):
embedded = self.embedding(input).view(1, 1, -1)
output, hidden = self.gru(embedded)
output = self.fc(output[0])
output = self.sigmoid(output)
return output, hidden
def init_hidden(self):
return torch.zeros(1, 1, self.hidden_size)
input_size = # размер словаря
hidden_size = 256 # размер скрытого слоя
output_size = 1 # размер выходного слоя (плохое слово / не плохое слово)
model = TextFilter(input_size, hidden_size, output_size)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Обучение модели
def train(text, label):
hidden = model.init_hidden()
model.zero_grad()
for i in range(text.size()[0]):
output, hidden = model(text[i], hidden)
loss = criterion(output, label)
loss.backward()
optimizer.step()
return output, loss.item()
# Пример использования модели
text = torch.tensor([1, 2, 3, 4, 5]) # предложение в виде списка индексов слов
label = torch.tensor([1]) # метка класса (1 - плохое слово, 0 - не плохое слово)
output, loss = train(text, label)
print(output)