jb / app.py
caozibo's picture
Update app.py
4de1c20 verified
import torch
from PIL import Image
import gradio as gr
# 从官方仓库加载模型
model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt', force_reload=True)
model.conf = 0.6 # 设置置信度阈值
model.iou = 0.2 # 设置交并比阈值
# 创建预测函数
def predict(image):
# 将 PIL Image 转换为适合模型输入的格式
results = model(image, size=1088) # 这里我使用了更小的 size 值来尝试减少 NMS 的计算量
# 获取绘制了检测框的图像
annotated_image = results.render() # results.render() 现在返回的是一个包含绘制了边界框的图像列表
# 获取处理后的图像和物体个数
# 我们需要确保 results.render() 返回的图像可以正确处理
labeled_image = Image.fromarray(annotated_image[0]) if annotated_image else Image.fromarray(np.array(image))
num_objects = len(results.xyxy[0]) # 检测到的物体数量
return labeled_image, num_objects
# 创建 Gradio 接口
iface = gr.Interface(
fn=predict,
inputs=gr.components.Image(type='pil', label="上传mwr~~的图片"),
outputs=[gr.components.Image(type='pil', label="含标记的图片结果"), gr.components.Label(label="总币数")],
title="数金币",
description="只能识别个数"
)
# 启动应用
iface.launch(debug=True)