import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
from PIL import Image

img_path = 'number.jpg'
model_path = "Mnist_model.pth"
output_path = "output.jpg"


# Define the neural network
class neural_network(nn.Module):
    """Small CNN for 28x28 single-channel digit images.

    Pipeline: Conv1 >> ReLU >> MaxPool >> Conv2 >> ReLU >> MaxPool >>
    flatten >> Fc1 >> ReLU >> Fc2.  Activation and pooling have no learnable
    parameters, so they are called through ``torch.nn.functional``.
    """

    def __init__(self):
        super().__init__()
        self.Conv1 = nn.Conv2d(1, 10, kernel_size=5, stride=1)
        self.Conv2 = nn.Conv2d(10, 20, kernel_size=5, stride=1)
        self.Fc1 = nn.Linear(320, 50)
        self.Fc2 = nn.Linear(50, 10)

    def forward(self, input):
        """Return the 10 raw class scores for each image in the batch.

        No softmax is applied; the scores are returned as produced by Fc2.
        The shape comments below assume a (batch, 1, 28, 28) input.
        """
        out = F.relu(self.Conv1(input))   # 28x28x1, 5x5 conv: (28-5)/1+1 >> 24x24x10
        out = F.max_pool2d(out, 2, 2)     # 24x24x10, 2x2 pool >> 12x12x10
        out = F.relu(self.Conv2(out))     # 12x12x10, 5x5 conv: (12-5)/1+1 >> 8x8x20
        out = F.max_pool2d(out, 2, 2)     # 8x8x20, 2x2 pool >> 4x4x20
        out = out.view(-1, 4 * 4 * 20)    # flatten before the fully connected layers
        out = F.relu(self.Fc1(out))       # 320 >> 50
        out = self.Fc2(out)               # 50 >> 10 (no softmax here)
        return out


net = neural_network()

# Set to True once the checkpoint has been loaded successfully, so the weights
# are read from disk a single time instead of on every prediction.
_model_loaded = False


def _ensure_model_loaded():
    """Load the trained weights into ``net`` once; keep going on failure."""
    global _model_loaded
    if _model_loaded:
        return
    try:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        net.load_state_dict(torch.load(model_path, map_location="cpu"))
    except Exception as exc:
        # Best effort, as before: prediction continues with the current
        # weights.  The flag stays False so the next request retries the load.
        print(f"fail to load model: {exc}")
    else:
        _model_loaded = True


def pic_process(image):
    """Turn a PIL image into network input.

    Args:
        image: PIL image of the digit to classify.

    Returns:
        A tuple ``(image_tensor, image_array)``: a float tensor of shape
        (1, 1, 28, 28) and the inverted 28x28 uint8 array it was built from.

    Raises:
        ValueError: if ``image`` is None (e.g. nothing was submitted).
    """
    if image is None:
        raise ValueError("no image provided")
    size = (28, 28)
    # Force a single channel so the reshape below is valid for any input mode.
    gray = image.convert("L")
    # Keep writing the input to disk as before, but process the in-memory
    # image rather than re-reading the lossy JPEG.
    gray.save(img_path)
    resized = gray.resize(size)  # scale to 28x28
    # Invert intensities in one vectorized step.
    # NOTE(review): presumably the canvas is dark-on-light while the model
    # expects light-on-dark digits -- confirm against the training data.
    image_array = 255 - np.array(resized, dtype=np.uint8)
    # NOTE(review): values stay in the 0-255 range; verify this matches the
    # scaling used when the model was trained.
    image_tensor = torch.from_numpy(image_array).float()
    # The CNN takes a 4-D tensor: batch size, channels, height, width.
    image_tensor = image_tensor.view(1, 1, 28, 28)
    return image_tensor, image_array


def net_predict(image):
    """Classify a drawn digit and return the path of the annotated picture.

    Args:
        image: PIL image passed in by the Gradio interface.

    Returns:
        ``output_path``, the file the processed image and its predicted
        label were saved to.
    """
    # Preprocess once and unpack both results.
    images, images_array = pic_process(image)
    # Load the trained model (a no-op after the first successful load).
    _ensure_model_loaded()
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = net(images)  # forward pass
    _, predicts = torch.max(outputs, 1)  # index of the highest score
    digit = predicts[0].item()

    # Draw on a dedicated figure and close it afterwards so repeated requests
    # do not pile images onto (or leak) the global pyplot figure.
    fig, ax = plt.subplots()
    try:
        ax.imshow(images_array, cmap=plt.cm.binary)
        ax.set_title('predicted number:' + str(digit))
        fig.savefig(output_path)
    finally:
        plt.close(fig)
    return output_path


interface = gr.Interface(
    fn=net_predict,
    inputs=gr.inputs.Image(source="canvas", type="pil", image_mode="L"),
    outputs=gr.outputs.Image(type="file"),
)
# Launched at module level, as in the original script.
interface.launch()