tianliang@szlanyou.com commited on
Commit
87d20de
1 Parent(s): 5d46ea2

add application file

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ # from neural_network import neural_network
3
+ from neural_network_CNN import neural_network
4
+ from preprocess_data import pic_process
5
+ import matplotlib
6
+ matplotlib.use("Agg")
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import gradio as gr
10
+ from PIL import Image
11
+
12
+ # net=neural_network(784,512,10)
13
+ net=neural_network()
14
+ model_path=r"D:\PythonProject\Mnist_study\model\Mnist_model.pth"
15
+ output_path=r"D:\PythonProject\Mnist_study\img\output.jpg"
16
+
17
+
18
+ def net_predict(image):
19
+ images = pic_process(image)[0]
20
+ images_array = pic_process(image)[1]
21
+ # 载入训练好的神经网络模型
22
+ try:
23
+ net.load_state_dict(torch.load(model_path))
24
+ except:
25
+ print("fail to load model")
26
+ # labels=Variable(labels)
27
+ outputs=net.forward(images) #前向传播
28
+ _,predicts=torch.max(outputs.data,1) #输出概率最大的索引值
29
+ # 显示图片
30
+ # fig = plt.figure()
31
+ plt.imshow(images_array,cmap=plt.cm.binary)
32
+ plt.title(label='predicted number:'+str(np.array(predicts)[0]))
33
+ plt.savefig(output_path)
34
+ # plt.show()
35
+ return output_path
36
+
37
+ interface = gr.Interface(fn=net_predict, inputs=gr.inputs.Image(source="canvas",type="pil",image_mode="L"),outputs=gr.outputs.Image(type="file"))
38
+ interface.launch(share=True)
39
+
40
+