Spaces:
Build error
Build error
tianliang@szlanyou.com
commited on
Commit
•
87d20de
1
Parent(s):
5d46ea2
add application file
Browse files
app.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
# from neural_network import neural_network
|
3 |
+
from neural_network_CNN import neural_network
|
4 |
+
from preprocess_data import pic_process
|
5 |
+
import matplotlib
|
6 |
+
matplotlib.use("Agg")
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import gradio as gr
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
# net=neural_network(784,512,10)
|
13 |
+
net=neural_network()
|
14 |
+
model_path=r"D:\PythonProject\Mnist_study\model\Mnist_model.pth"
|
15 |
+
output_path=r"D:\PythonProject\Mnist_study\img\output.jpg"
|
16 |
+
|
17 |
+
|
18 |
+
def net_predict(image):
|
19 |
+
images = pic_process(image)[0]
|
20 |
+
images_array = pic_process(image)[1]
|
21 |
+
# 载入训练好的神经网络模型
|
22 |
+
try:
|
23 |
+
net.load_state_dict(torch.load(model_path))
|
24 |
+
except:
|
25 |
+
print("fail to load model")
|
26 |
+
# labels=Variable(labels)
|
27 |
+
outputs=net.forward(images) #前向传播
|
28 |
+
_,predicts=torch.max(outputs.data,1) #输出概率最大的索引值
|
29 |
+
# 显示图片
|
30 |
+
# fig = plt.figure()
|
31 |
+
plt.imshow(images_array,cmap=plt.cm.binary)
|
32 |
+
plt.title(label='predicted number:'+str(np.array(predicts)[0]))
|
33 |
+
plt.savefig(output_path)
|
34 |
+
# plt.show()
|
35 |
+
return output_path
|
36 |
+
|
37 |
+
interface = gr.Interface(fn=net_predict, inputs=gr.inputs.Image(source="canvas",type="pil",image_mode="L"),outputs=gr.outputs.Image(type="file"))
|
38 |
+
interface.launch(share=True)
|
39 |
+
|
40 |
+
|