Jianfeng777 commited on
Commit
0efd334
·
1 Parent(s): 6550b65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py CHANGED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # 替换成你的 ONNX 模型文件路径
7
+ ONNX_MODEL_PATH = 'Car_Bike_Classification/end2end.onnx'
8
+
9
+ # 加载模型
10
+ ort_session = ort.InferenceSession(ONNX_MODEL_PATH)
11
+
12
+ # 定义预测函数
13
+ def classify_image(image):
14
+ # 确保image是一个PIL图像对象
15
+ if not isinstance(image, Image.Image):
16
+ image = Image.fromarray(image)
17
+
18
+ # 预处理输入图片
19
+ image = image.resize((224, 224))
20
+ image = np.array(image).astype('float32')
21
+ image = np.transpose(image, (2, 0, 1)) # Change data layout from HWC to CHW
22
+ image = np.expand_dims(image, axis=0)
23
+
24
+ # 使用 ONNX 运行推理
25
+ inputs = {ort_session.get_inputs()[0].name: image}
26
+ outputs = ort_session.run(None, inputs)
27
+
28
+ # 获取预测结果
29
+ predictions = outputs[0]
30
+
31
+ # 创建一个矩形图展示可能性
32
+ pred_probs = predictions[0].tolist() # 假设模型输出是一个概率列表
33
+ pred_probs = [float(i)/sum(pred_probs) for i in pred_probs] # 归一化概率值
34
+ class_names = ['Bike', 'Car']
35
+ result = {class_name: prob for class_name, prob in zip(class_names, pred_probs)}
36
+
37
+ # 生成结果文本
38
+ pred_text = f"该图片里是{'汽车' if np.argmax(predictions) else '单车'}。"
39
+
40
+ return result, pred_text
41
+
42
+ # 根据Gradio新的API更新了导入和组件
43
+ from gradio import Interface, components
44
+
45
+ # 创建 Gradio 界面
46
+ iface = Interface(
47
+ fn=classify_image,
48
+ inputs=components.Image(shape=(224, 224)),
49
+ outputs=[
50
+ components.Label(num_top_classes=2),
51
+ components.Textbox(label="分类结果")
52
+ ],
53
+ title="单车与汽车图像分类",
54
+ description="上传一张图片,模型将预测图片是单车还是汽车。"
55
+ )
56
+
57
+ # 启动界面
58
+ iface.launch()