wangjie
commited on
Commit
•
f786a72
1
Parent(s):
1594b82
mnist test
Browse files- Model.py +20 -0
- inference.py +29 -0
- mnist.pkl +3 -0
Model.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class MNIST(torch.nn.Module):
|
4 |
+
def __init__(self):
|
5 |
+
super(MNIST, self).__init__()
|
6 |
+
self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 32, 3, 1, 1),
|
7 |
+
torch.nn.ReLU(),
|
8 |
+
torch.nn.Conv2d(32, 64, 3, 1, 1),
|
9 |
+
torch.nn.ReLU(),
|
10 |
+
torch.nn.MaxPool2d(2, 2))
|
11 |
+
self.dense = torch.nn.Sequential(torch.nn.Linear(14 * 14 * 64, 1024),
|
12 |
+
torch.nn.ReLU(),
|
13 |
+
torch.nn.Dropout(p=0.2),
|
14 |
+
torch.nn.Linear(1024, 10))
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
x = self.conv(x)
|
18 |
+
x = x.view(-1, 14 * 14 * 64)
|
19 |
+
x = self.dense(x)
|
20 |
+
return x
|
inference.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from torchvision import datasets, transforms
|
4 |
+
import time
|
5 |
+
from Model import MNIST
|
6 |
+
import numpy
|
7 |
+
# from torch.utils.data import DataLoader
|
8 |
+
def images2tensor(image):
|
9 |
+
img = cv2.imread(image)
|
10 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
11 |
+
transf = transforms.ToTensor()
|
12 |
+
img_tensor = torch.unsqueeze(transf(img), dim=0)
|
13 |
+
return img_tensor
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
device = torch.device('cpu')
|
18 |
+
model = MNIST().to(device)
|
19 |
+
model.load_state_dict(torch.load('mnist.pkl' , map_location=device)) # load
|
20 |
+
# test_dataset = datasets.MNIST(root='data/', train=False, transform=transforms.ToTensor(), download=True)
|
21 |
+
# test_image , l = test_dataset[0]
|
22 |
+
input_data = images2tensor("0.png")
|
23 |
+
start = time.time()
|
24 |
+
res = model(input_data)
|
25 |
+
end = time.time()
|
26 |
+
res = res.detach().numpy()
|
27 |
+
# res = numpy.array(res)
|
28 |
+
print("手写数字图片检测的结果为:", res.argmax())
|
29 |
+
# print("infer time: ", end - start)
|
mnist.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f23ce66930c4c38a18e1b51ab4306e2c70319bb3ce7ef1820bf3390876e0b751
|
3 |
+
size 51503055
|