Multiplication / train.py
xcx0902's picture
Upload folder using huggingface_hub
0613d9c verified
raw
history blame
1.22 kB
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from MultiplicationNet import MultiplicationNet
from device import device
def generate_data(num_samples, min_val=0, max_val=100):
    """Sample random integer operand pairs together with their products.

    Both operands come from numpy's global RNG via ``np.random.randint``
    (first operand drawn before the second), so every value lies in the
    half-open range ``[min_val, max_val)``.

    Args:
        num_samples: Number of operand pairs to draw.
        min_val: Inclusive lower bound for each operand.
        max_val: Exclusive upper bound for each operand.

    Returns:
        ``(features, targets)`` where ``features`` has shape
        ``(num_samples, 2)`` with one operand per column, and ``targets``
        has shape ``(num_samples, 1)`` holding the row-wise product.
    """
    column_shape = (num_samples, 1)
    left = np.random.randint(min_val, max_val, size=column_shape)
    right = np.random.randint(min_val, max_val, size=column_shape)
    features = np.concatenate((left, right), axis=1)
    return features, left * right
def train(num_samples=10000, num_epochs=30000, learning_rate=0.01,
          save_path="model.pth", log_every=1):
    """Train a MultiplicationNet on random integer products and save it.

    Draws ``num_samples`` operand pairs with ``generate_data`` (its default
    operand range), then runs full-batch Adam on an MSE loss for
    ``num_epochs`` steps under a ``StepLR`` schedule (``step_size=100``,
    ``gamma=0.95``). The whole model object is written with ``torch.save``
    when training finishes.

    Every parameter defaults to the value this script previously hard-coded,
    so a bare ``train()`` call behaves as before.

    Args:
        num_samples: Number of training pairs to generate.
        num_epochs: Number of full-batch optimisation steps.
        learning_rate: Initial Adam learning rate.
        save_path: Destination passed to ``torch.save``.
        log_every: Print the loss every ``log_every`` epochs (1 = every
            epoch, the original behaviour); a value <= 0 disables logging.

    Returns:
        The trained model, which is also the object saved to ``save_path``.
    """
    x, y = generate_data(num_samples)
    # NOTE(review): operands are raw integers in [0, 100) and targets reach
    # 99 * 99 = 9801 with no scaling applied -- confirm MultiplicationNet is
    # designed for inputs/targets in this range.
    # Allocate directly on the target device rather than building on the
    # CPU and copying with ``.to(device)``.
    x_train = torch.tensor(x, dtype=torch.float, device=device)
    y_train = torch.tensor(y, dtype=torch.float, device=device)

    model = MultiplicationNet().to(device)
    criterion = nn.MSELoss()  # parameterless; no device move needed
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): the LR shrinks by 0.95 every 100 epochs, so after
    # num_epochs it is learning_rate * 0.95 ** (num_epochs // 100); with the
    # defaults that is about 2e-9, i.e. late epochs barely move the weights.
    # Confirm this decay schedule is intended.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.95)

    for epoch in range(num_epochs):
        # Full-batch step: the entire dataset is passed through at once.
        outputs = model(x_train)
        loss = criterion(outputs, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        if log_every > 0 and epoch % log_every == 0:
            print(f"Epoch {epoch}, loss = {loss.item()}")

    # NOTE(review): this pickles the whole module, so loading requires
    # MultiplicationNet to be importable and, on recent PyTorch,
    # torch.load(..., weights_only=False). Saving model.state_dict() would
    # be more portable, but changing it would change the file format that
    # existing loaders expect.
    torch.save(model, save_path)
    return model
# Script entry point: run training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    train()