Upload 5 files
Browse files- 0.png +0 -0
- Model.py +20 -0
- inference.py +25 -0
- mnist.pkl +3 -0
- requirements.txt +4 -0
0.png
ADDED
Model.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class MNIST(torch.nn.Module):
|
4 |
+
def __init__(self):
|
5 |
+
super(MNIST, self).__init__()
|
6 |
+
self.conv = torch.nn.Sequential(torch.nn.Conv2d(1, 32, 3, 1, 1),
|
7 |
+
torch.nn.ReLU(),
|
8 |
+
torch.nn.Conv2d(32, 64, 3, 1, 1),
|
9 |
+
torch.nn.ReLU(),
|
10 |
+
torch.nn.MaxPool2d(2, 2))
|
11 |
+
self.dense = torch.nn.Sequential(torch.nn.Linear(14 * 14 * 64, 1024),
|
12 |
+
torch.nn.ReLU(),
|
13 |
+
torch.nn.Dropout(p=0.2),
|
14 |
+
torch.nn.Linear(1024, 10))
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
x = self.conv(x)
|
18 |
+
x = x.view(-1, 14 * 14 * 64)
|
19 |
+
x = self.dense(x)
|
20 |
+
return x
|
inference.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
import time
|
5 |
+
from Model import MNIST
|
6 |
+
|
7 |
+
|
8 |
+
def images2tensor(image):
|
9 |
+
img = cv2.imread(image)
|
10 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
11 |
+
transf = transforms.ToTensor()
|
12 |
+
img_tensor = torch.unsqueeze(transf(img), dim=0)
|
13 |
+
return img_tensor
|
14 |
+
|
15 |
+
|
16 |
+
if __name__ == "__main__":
|
17 |
+
start = time.time()
|
18 |
+
device = torch.device('cpu')
|
19 |
+
model = MNIST().to(device)
|
20 |
+
model.load_state_dict(torch.load('mnist.pkl')) # load
|
21 |
+
input_data = images2tensor("0.png")
|
22 |
+
res = model(input_data)
|
23 |
+
end = time.time()
|
24 |
+
print("手写数字图片检测的结果为:", res.argmax())
|
25 |
+
print("infer time: ", end - start)
|
mnist.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f23ce66930c4c38a18e1b51ab4306e2c70319bb3ce7ef1820bf3390876e0b751
|
3 |
+
size 51503055
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
opencv-python
|
4 |
+
opencv-python-headless
|