isLandLZ committed on
Commit
9e14f32
1 Parent(s): 908deda

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +113 -0
train.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Training script for the MNIST digit-classification model.
import torch
import torchvision
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from Pytorch_MNIST图片识别.model import Net
import matplotlib.pyplot as plt
import os

# Work around "duplicate OpenMP runtime" crashes when torch and matplotlib
# both load libiomp (common on Windows/conda setups).
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Number of full passes over the training set.
n_epochs = 3

# Train with batches of 64; evaluate the test set in batches of 1000.
batch_size_train = 64
batch_size_test = 1000

# Optimizer hyper-parameters.
learning_rate = 0.01
momentum = 0.5

log_interval = 10  # print/record a progress point every `log_interval` batches
random_seed = 1
torch.manual_seed(random_seed)  # make runs reproducible

# Shared preprocessing: 0.1307 / 0.3081 are the global mean and std of the
# MNIST training set.  Defined once instead of duplicating the Compose for
# the train and test loaders.
_mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# The MNIST dataset is downloaded automatically into ./data/ on first run.
train_loader = DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=_mnist_transform),
    batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=_mnist_transform),
    batch_size=batch_size_test, shuffle=False)

# Initialize the network and the optimizer.
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

# Bookkeeping for the loss curves plotted at the end of the script.
train_losses = []
train_counter = []
test_losses = []
# One test point per epoch plus the pre-training evaluation (n_epochs + 1).
test_counter = [i * len(train_loader.dataset) for i in range(n_epochs + 1)]

# Where the trained weights and the optimizer state are stored.
# NOTE(review): the ./model1 directory must already exist before saving.
model_path = './model1/model.pth'
optimizer_path = './model1/optimizer.pth'
def train(epoch):
    """Run one training epoch over `train_loader`, logging progress.

    Args:
        epoch: 1-based index of the current epoch (as produced by the
            main loop `range(1, n_epochs + 1)`).

    Side effects: appends to `train_losses`/`train_counter`, and saves
    the model and optimizer state during the final epoch.
    """
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # PyTorch accumulates gradients by default, so clear them each step.
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            # Use the configured batch size instead of the hard-coded 64 the
            # original had, so the x-axis stays correct if it is ever changed.
            train_counter.append(
                (batch_idx * batch_size_train)
                + ((epoch - 1) * len(train_loader.dataset)))
            # Checkpoint during the LAST epoch.  The original compared against
            # n_epochs - 1, so the final epoch's weights were never saved.
            if epoch == n_epochs:
                torch.save(network.state_dict(), model_path)
                torch.save(optimizer.state_dict(), optimizer_path)
def test():
    """Evaluate the network on the test set.

    Side effects: appends the average test loss to `test_losses` and
    prints the average loss and accuracy.
    """
    network.eval()
    test_loss = 0
    correct = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum per-sample losses across the whole set, then divide once
            # below.  `size_average=False` is deprecated in torch.nn.functional;
            # `reduction='sum'` is the supported equivalent.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Predicted class = index of the max log-probability.
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
# Evaluate once with the untrained network so that test_losses ends up with
# n_epochs + 1 entries, matching test_counter (which includes the i = 0
# pre-training point).  Without this, plt.scatter below fails on mismatched
# lengths — the bug the original's leftover print(len(...)) lines were chasing.
test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

# Plot the training-loss curve and the per-epoch test losses.
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red')
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()