Jianfeng777's picture
Update app.py
8f7d5fd
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
# 替换成你的 ONNX 模型文件路径
ONNX_MODEL_PATH = 'end2end.onnx'
# 加载模型
ort_session = ort.InferenceSession(ONNX_MODEL_PATH)
# 定义预测函数
def classify_image(image):
# 确保image是一个PIL图像对象
if not isinstance(image, Image.Image):
image = Image.fromarray(image)
# 预处理输入图片
image = image.resize((224, 224))
image = np.array(image).astype('float32')
image = np.transpose(image, (2, 0, 1)) # Change data layout from HWC to CHW
image = np.expand_dims(image, axis=0)
# 使用 ONNX 运行推理
inputs = {ort_session.get_inputs()[0].name: image}
outputs = ort_session.run(None, inputs)
# 获取预测结果
predictions = outputs[0]
# 创建一个矩形图展示可能性
pred_probs = predictions[0].tolist() # 假设模型输出是一个概率列表
pred_probs = [float(i)/sum(pred_probs) for i in pred_probs] # 归一化概率值
class_names = ['Bike', 'Car']
result = {class_name: prob for class_name, prob in zip(class_names, pred_probs)}
# 生成结果文本
pred_text = f"该图片里是{'汽车' if np.argmax(predictions) else '单车'}。"
return result, pred_text
# 根据Gradio新的API更新了导入和组件
from gradio import Interface, components
# 创建 Gradio 界面
iface = Interface(
fn=classify_image,
inputs=components.Image(),
outputs=[
components.Label(num_top_classes=2),
components.Textbox(label="分类结果")
],
title="单车与汽车图像分类",
description="上传一张图片,模型将预测图片是单车还是汽车。"
)
# 启动界面
iface.launch(share=True)