# FST / app.py
# Hakureirm's picture
# Update app.py
# 9b94ed9 verified
import os
import cv2
import numpy as np
import gradio as gr
import torch
from mouse_tracker import MouseTrackerAnalyzer
from huggingface_hub import hf_hub_download
# Detect whether we are running inside a Hugging Face Spaces environment.
# A successful import of the `spaces` package is used as the signal.
try:
    import spaces
except ImportError:
    is_spaces = False
    print("在本地环境运行")
else:
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")

# Global configuration.
# Base name of the model file, without the format suffix.
model_base_name = "fst-v1.3-n"
# Frame count of the most recently loaded video (written by select_video).
total_frames = 0
def get_model_file_path(model_suffix):
    """Return the relative model file path for the given format suffix.

    The result is ``./`` followed by the module-level ``model_base_name``
    and then ``model_suffix`` (for example ``.onnx``).
    """
    return "./{}{}".format(model_base_name, model_suffix)
def extract_frame(video_path, frame_num):
    """Return one frame of a video as an RGB image.

    Args:
        video_path: Path of the video file. A falsy value yields ``None``.
        frame_num: Frame position written to ``cv2.CAP_PROP_POS_FRAMES``
            before reading.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when no path was
        given, the video could not be opened, or the frame could not be read.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including the "not opened"
        # early return and any exception raised while seeking/reading.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
def select_video(video_file, model_suffix):
    """Handle selection of a video: load its first frame and size the sliders.

    Updates the module-level ``total_frames`` with the frame count reported
    by OpenCV for ``video_file``.

    Args:
        video_file: Path of the selected video; a falsy value means nothing
            is selected.
        model_suffix: Model format suffix, used only to show the model file
            name in the status message.

    Returns:
        A 4-tuple ``(first_frame_rgb, status_message, start_slider,
        end_slider)``. On failure the image is ``None`` and both sliders are
        reset to ``0``.
    """
    global total_frames

    if not video_file:
        return None, "请选择视频文件", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)

    # Open the file once for both the frame count and the first frame, and
    # always release the capture afterwards.
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, frame = cap.read()
    finally:
        cap.release()

    if not ret:
        return None, "无法读取视频帧", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Clamp so a reported frame count of 0 cannot yield maximum == -1.
    last_frame = max(total_frames - 1, 0)
    start = gr.Slider(minimum=0, maximum=last_frame, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=last_frame, value=last_frame, step=1)

    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"
    return frame_rgb, status, start, end
def preview_frame(video_file, frame_num):
    """Return ``(image, status)`` for previewing one frame of a video.

    The image is ``None`` when no video is selected or when
    ``extract_frame`` cannot produce the requested frame; the status string
    explains which case occurred, or names the frame on success.
    """
    if video_file:
        image = extract_frame(video_file, frame_num)
        if image is not None:
            return image, f"帧 {frame_num}"
        return None, "无法读取指定帧"
    return None, "请先选择视频文件"
# 分析实现
def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
if not video:
return None, None, "请选择视频文件"
if start_frame >= end_frame:
return None, None, "起始帧必须小于结束帧"
# 构造路径
video_name = os.path.splitext(os.path.basename(video))[0]
output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = get_model_file_path(model_suffix)
if not os.path.exists(model_path):
if is_spaces:
try:
model_path = hf_hub_download(
repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
filename=f"weights/{model_base_name}{model_suffix}"
)
except Exception:
print(f"下载模型失败: {model_path}")
else:
print(f"警告: 本地未找到模型文件 {model_path}")
# 初始化分析器
analyzer = MouseTrackerAnalyzer(
model_path=model_path,
conf=conf,
iou=iou,
max_det=max_det,
verbose=True
)
analyzer.struggle_threshold = threshold
# 运行分析
analyzer.process_video(
video_path=video,
output_path=output_path,
start_frame=start_frame,
end_frame=end_frame,
callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项")
)
analyzer.save_results(csv_path)
# 生成图表
plot_path = None
if analyzer.results:
plot_path = analyzer.generate_time_series_plot()
status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
if plot_path:
status += f", 图表: {plot_path}"
return output_path, plot_path, status
def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    """UI entry point for the analysis; forwards all arguments to ``_start_analysis_impl``."""
    return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)


# On Hugging Face Spaces, wrap the entry point with the `spaces.GPU`
# decorator (duration=120); elsewhere it stays a plain function.
if is_spaces:
    start_analysis = spaces.GPU(duration=120)(start_analysis)
def create_interface():
    """Build and return the Gradio ``Blocks`` application.

    The layout has a narrow column with the video input, model format,
    detection parameters, frame range sliders and action buttons, and a wide
    column with a preview tab and a results tab. Event handlers for loading
    a video, previewing a frame and starting the analysis are bound here.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as demo:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")

        with gr.Row():
            # Left column: inputs and parameters.
            with gr.Column(scale=1):
                video_in = gr.Video(label="输入视频")
                model_choice = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True,
                )
                # Read-only box showing whether CUDA is available.
                device_label = "GPU" if torch.cuda.is_available() else "CPU"
                gr.Textbox(
                    label="系统信息",
                    value=f"设备: {device_label}",
                    interactive=False,
                )
                conf_slider = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
                iou_slider = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det_slider = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
                struggle_slider = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
                first_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
                last_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
                preview_button = gr.Button("预览帧")
                run_button = gr.Button("开始分析", variant="primary")

            # Right column: preview and results.
            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_img = gr.Image(label="预览图像", type="numpy", height=400)
                    preview_status = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    result_video = gr.Video(label="分析结果视频")
                    result_image = gr.Image(label="挣扎分数时间序列")
                    analysis_status = gr.Textbox(label="分析状态", interactive=False)

        # Event bindings; the model format is passed along with the video.
        video_in.change(
            select_video,
            inputs=[video_in, model_choice],
            outputs=[preview_img, preview_status, first_frame, last_frame],
        )
        preview_button.click(
            preview_frame,
            inputs=[video_in, first_frame],
            outputs=[preview_img, preview_status],
        )
        run_button.click(
            start_analysis,
            inputs=[
                video_in,
                model_choice,
                conf_slider,
                iou_slider,
                max_det_slider,
                first_frame,
                last_frame,
                struggle_slider,
            ],
            outputs=[result_video, result_image, analysis_status],
        )

    return demo
if __name__ == "__main__":
    # Remove proxy settings from the environment before launching.
    for proxy_var in ("http_proxy", "https_proxy", "all_proxy"):
        os.environ.pop(proxy_var, None)

    device_label = "GPU" if torch.cuda.is_available() else "CPU"
    print(f"设备: {device_label}")
    print(f"默认模型路径: {get_model_file_path('.onnx')}")

    app = create_interface()

    # On Spaces use the launch defaults; locally bind to all interfaces on
    # port 7860 without a share link.
    if is_spaces:
        launch_options = {}
    else:
        launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    app.launch(**launch_options)