# FishEye8K / app.py
# Author: thai thong — commit c412fb1 ("fix error"), 4.58 kB.
# (Header recovered from a Hugging Face file-viewer page: raw | history | blame | "No virus".)
import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import threading
# Cooperative stop flag; stop_processing() sets it to False when the Stop button
# is clicked. NOTE(review): no code visible in this file reads it (and nothing
# resets it to True), so it does not currently interrupt a running inference.
should_continue: bool = True
@spaces.GPU(duration=120)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
    """Run YOLOv9 detection (optionally with object tracking) on an uploaded file.

    Args:
        model_id: Weights file passed to the detector (one of the "Model"
            dropdown choices).
        img_path: Uploaded image from the "Image" gr.File input, or None.
            Presumably a filepath; confirm against the Gradio version in use.
        vid_path: Uploaded video from the "Video" gr.File input, or None.
            Only used when img_path is None.
        tracking_algorithm: 'deep_sort' or 'strong_sort' to track objects in a
            video. Any other value (e.g. the dropdown's "None" string) means
            plain detection. Ignored for images.

    Returns:
        A tuple (output_image, output_video, output_path). output_path is
        whatever the detector returned. output_image / output_video repeat it
        when its extension is a known image / video extension, and are None
        otherwise.

    Raises:
        gr.Error: if neither an image nor a video was provided.
    """
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # Add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # Add more video extensions if needed
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5

    if img_path is not None:
        input_path = img_path
        # Tracking only applies to videos; images always use plain detection.
        tracker = None
    elif vid_path is not None:
        input_path = vid_path
        tracker = tracking_algorithm
    else:
        # Without this guard, os.path.splitext(None) below fails with an
        # opaque TypeError instead of a message the user can act on.
        raise gr.Error("Please upload an image or a video.")
    print(input_path)

    # Arguments shared by every detector entry point.
    # device='0' presumably selects the first CUDA device; verify in detect.py.
    common_kwargs = dict(
        imgsz=(image_size, image_size),
        conf_thres=conf_threshold,
        iou_thres=iou_threshold,
        source=input_path,
        device='0',
    )
    if tracker == 'deep_sort':
        output_path = run_deepsort(weights=model_id, draw_trails=True, **common_kwargs)
    elif tracker == 'strong_sort':
        # run_strongsort names the detector weights `yolo_weights`, not `weights`.
        output_path = run_strongsort(
            yolo_weights=model_id,
            strong_sort_weights="osnet_x0_25_msmt17.pt",
            hide_conf=True,
            **common_kwargs,
        )
    else:
        output_path = run(weights=model_id, hide_conf=True, **common_kwargs)

    # Route the result to the matching output component. Both default to None
    # so that an unexpected extension no longer raises UnboundLocalError.
    output_image = None
    output_video = None
    if output_path is not None:
        output_extension = os.path.splitext(output_path)[1].lower()
        if output_extension in img_extensions:
            output_image = output_path
        elif output_extension in vid_extensions:
            output_video = output_path
    return output_image, output_video, output_path
def stop_processing():
    """Handle the Stop button by clearing the module-level ``should_continue`` flag.

    Returns:
        The status string "Stop...".

    NOTE(review): no code visible in this file reads ``should_continue``, so
    this does not interrupt an inference that is already running.
    """
    global should_continue
    status_message = "Stop..."
    should_continue = False
    return status_message
def app():
    """Build the two-column inference UI and wire its buttons.

    The left column holds the inputs: image/video uploads, model and tracker
    dropdowns, and the Inference/Stop buttons. The right column holds the
    outputs: image, video, and the raw output path.
    """
    model_choices = [
        "last_best_model.pt",
        "best_model-converted.pt",
        "yolov9_e_trained.pt",
    ]
    tracker_choices = ["None", "deep_sort", "strong_sort"]

    with gr.Blocks():
        with gr.Row():
            # Components are created in display order; do not reorder.
            with gr.Column():
                gr.HTML("<h2>Input Parameters</h2>")
                image_upload = gr.File(label="Image")
                video_upload = gr.File(label="Video")
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=model_choices,
                    value="last_best_model.pt",
                )
                tracker_dropdown = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=tracker_choices,
                    value="None",
                )
                infer_button = gr.Button(value="Inference")
                halt_button = gr.Button(value="Stop")
            with gr.Column():
                gr.HTML("<h2>Output</h2>")
                image_display = gr.Image(type="numpy", label="Output Image")
                video_display = gr.Video(label="Output Video")
                path_display = gr.Textbox(label="Output path")

        # Input order must match yolov9_inference's positional parameters.
        infer_button.click(
            fn=yolov9_inference,
            inputs=[model_dropdown, image_upload, video_upload, tracker_dropdown],
            outputs=[image_display, video_display, path_display],
        )
        # NOTE(review): stop_processing only clears a flag that no visible
        # code reads. Consider Gradio's `cancels=[...]` on this listener.
        halt_button.click(stop_processing)
# Page-level styling. It must be passed to gr.Blocks to take effect; it was
# previously built inside the layout block and never used.
css = """
body {
background-color: #f0f0f0;
}
h1 {
color: #4CAF50;
}
"""

# Top-level page: a centred title above the inference UI built by app().
gradio_app = gr.Blocks(css=css)
with gradio_app:
    gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv9: Real-time Object Detection
</h1>
""")
    with gr.Row():
        with gr.Column():
            app()

gradio_app.launch(debug=True)