|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from torch.utils.data import DataLoader, Dataset |
|
import json |
|
import os |
|
|
|
|
|
class CustomDataset(Dataset):
    """Map-style dataset pairing each input with its label by position.

    Args:
        texts: Indexable collection of inputs (in this script, a tensor whose
            first dimension is the sample index).
        labels: Indexable collection of targets, aligned with ``texts``.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length. Otherwise,
            shorter labels would raise ``IndexError`` mid-epoch, and longer
            labels would be silently ignored.
    """

    def __init__(self, texts, labels):
        if len(texts) != len(labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(texts)} and {len(labels)}"
            )
        self.texts = texts
        self.labels = labels

    def __len__(self):
        """Return the number of (input, label) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair at position ``idx``."""
        return self.texts[idx], self.labels[idx]
|
|
|
|
|
class LSTMModel(nn.Module):
    """LSTM sequence classifier that predicts from the final time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of output logits.
        num_layers: Number of stacked LSTM layers. The default of 1 matches
            the original single-layer model (same parameters and state_dict
            keys).
        dropout: Dropout that ``nn.LSTM`` applies to the outputs of every
            layer except the last. It only takes effect when
            ``num_layers > 1``; PyTorch warns if it is non-zero with a
            single layer. Defaults to 0.0.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.0):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, seq_len, input_size).
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to logits of shape (batch, output_size)."""
        lstm_out, _ = self.lstm(x)
        # Use the top layer's output at the last time step as the sequence summary.
        return self.fc(lstm_out[:, -1, :])
|
|
|
|
|
# Architecture: features per time step, LSTM hidden width, number of classes.
input_size, hidden_size, output_size = 100, 64, 10

# Optimisation schedule.
num_epochs, learning_rate = 5, 1e-3
|
|
|
|
|
# Objective for integer class targets; creating it draws no random numbers,
# so building it before the model leaves weight initialisation unchanged.
criterion = nn.CrossEntropyLoss()

# The classifier and an Adam optimiser over all of its parameters.
model = LSTMModel(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
|
|
|
|
|
# Synthetic stand-in data (random features, not real text): 100 sequences of
# 10 steps x `input_size` features, each with a label drawn from [0, output_size).
texts = torch.randn((100, 10, input_size))
labels = torch.randint(low=0, high=output_size, size=(100,))

# Mini-batches of 16, reshuffled on every pass over the data.
dataset = CustomDataset(texts, labels)
data_loader = DataLoader(dataset, batch_size=16, shuffle=True)
|
|
|
|
|
# Train for `num_epochs` passes over the data. Each epoch reports the
# sample-weighted mean loss over all batches, rather than the loss of
# whichever (possibly tiny, remainder) batch happened to come last.
model.train()  # make the mode explicit; matters once dropout layers are used
for epoch in range(num_epochs):
    running_loss = 0.0
    samples_seen = 0
    for inputs, targets in data_loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # nn.CrossEntropyLoss() averages over the batch by default, so
        # re-weight by batch size to get a true per-sample epoch mean.
        n_in_batch = targets.size(0)
        running_loss += loss.item() * n_in_batch
        samples_seen += n_in_batch

    # max(..., 1) avoids dividing by zero if the loader yields no batches.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
|
|
|
|
|
# Persist the trained weights plus a JSON description of the architecture.
model_save_path = "model"
os.makedirs(model_save_path, exist_ok=True)

torch.save(model.state_dict(), os.path.join(model_save_path, "pytorch_model.bin"))

# Read every hyperparameter back off the modules themselves, so the config
# always describes the weights that were actually saved. A hard-coded
# value (e.g. a dropout the LSTM never applied) would silently drift.
config = {
    "input_size": model.lstm.input_size,
    "hidden_size": model.lstm.hidden_size,
    "output_size": model.fc.out_features,
    "num_layers": model.lstm.num_layers,
    "dropout": model.lstm.dropout,
}

with open(os.path.join(model_save_path, "config.json"), "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

print("Model and configuration saved successfully!")
|
|