|
import cv2 |
|
import torch |
|
import torchvision.transforms as transforms |
|
import time |
|
from Model import MNIST |
|
|
|
|
|
def images2tensor(image: str) -> torch.Tensor:
    """Load an image file and turn it into a single-item grayscale batch.

    The file is read with OpenCV, converted from BGR to grayscale, passed
    through torchvision's ``ToTensor`` and given a leading batch dimension.

    Args:
        image: Path of the image file to read.

    Returns:
        The ``ToTensor`` output with one extra dimension inserted at
        position 0 (per the torchvision docs a 2-D uint8 array becomes a
        ``(1, H, W)`` float tensor scaled to ``[0, 1]``, so the result is
        expected to be ``(1, 1, H, W)``).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (it is missing
            or is not a decodable image).
    """
    img = cv2.imread(image)
    # cv2.imread does not raise on failure; it returns None, which would
    # otherwise only show up as an opaque cv2.error inside cvtColor.
    if img is None:
        raise FileNotFoundError(
            f"Cannot read image (missing or unsupported file): {image!r}"
        )

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    to_tensor = transforms.ToTensor()
    # Add the batch dimension the model's forward pass is called with.
    return torch.unsqueeze(to_tensor(gray), dim=0)
|
|
|
|
|
if __name__ == "__main__":
    # NOTE: the timer starts before the model is built and the weights are
    # read, so the reported "infer time" covers model loading and image
    # preprocessing as well as the forward pass.
    start = time.time()

    device = torch.device('cpu')
    model = MNIST().to(device)

    # map_location keeps loading working on a CPU-only machine even when the
    # checkpoint was saved from a GPU.
    # SECURITY: torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    model.load_state_dict(torch.load('mnist.pkl', map_location=device))

    # Switch dropout / batch-norm layers (if the model has any) to their
    # inference behaviour so the prediction is deterministic.
    model.eval()

    input_data = images2tensor("0.png").to(device)

    # No gradients are needed for a single prediction; skipping autograd
    # bookkeeping saves time and memory.
    with torch.no_grad():
        res = model(input_data)

    end = time.time()

    # .item() prints the predicted class as a plain int instead of tensor(k).
    print("手写数字图片检测的结果为:", res.argmax().item())
    print("infer time: ", end - start)
|
|