# YOLOv9 object-detection demo — Hugging Face Space (runs on ZeroGPU hardware).
# NOTE(review): `spaces` is kept as the first import — ZeroGPU presumably needs
# to patch CUDA initialization before anything else loads; confirm before reordering.
import spaces

import os
import threading

import gradio as gr

from detect import run
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort

# Global flag intended to let stop_processing() signal inference to halt.
should_continue = True
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
    """Run YOLOv9 detection (optionally with tracking) on an image or a video.

    Args:
        model_id: Path to the YOLOv9 weights file to load.
        img_path: Optional path to an input image; takes precedence over vid_path.
        vid_path: Optional path to an input video, used when img_path is None.
        tracking_algorithm: 'deep_sort', 'strong_sort', or anything else
            (including None / 'None') for plain detection; only meaningful
            for video input.

    Returns:
        Tuple (output_image, output_video, output_path). Exactly one of the
        first two is set, chosen by the output file's extension; all three
        are None when no input path was supplied or the extension is unknown.
    """
    global should_continue
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # Add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # Add more video extensions if needed
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5

    output_path = None
    if img_path is not None:
        print(img_path)
        output_path = run(weights=model_id, imgsz=(image_size, image_size),
                          conf_thres=conf_threshold, iou_thres=iou_threshold,
                          source=img_path, device='0', hide_conf=True)
    elif vid_path is not None:
        print(vid_path)
        if tracking_algorithm == 'deep_sort':
            output_path = run_deepsort(weights=model_id, imgsz=(image_size, image_size),
                                       conf_thres=conf_threshold, iou_thres=iou_threshold,
                                       source=vid_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size, image_size),
                                         conf_thres=conf_threshold, iou_thres=iou_threshold,
                                         source=vid_path, device='0',
                                         strong_sort_weights="osnet_x0_25_msmt17.pt",
                                         hide_conf=True)
        else:
            output_path = run(weights=model_id, imgsz=(image_size, image_size),
                              conf_thres=conf_threshold, iou_thres=iou_threshold,
                              source=vid_path, device='0', hide_conf=True)

    # Guard: the original crashed in os.path.splitext(None) when neither
    # input was provided; return an explicit empty result instead.
    if output_path is None:
        return None, None, None

    # Route the result to the image or the video output slot by extension.
    # Defaulting both to None also fixes an UnboundLocalError the original
    # raised for extensions matching neither list.
    output_image = None
    output_video = None
    _, output_extension = os.path.splitext(output_path)
    if output_extension.lower() in img_extensions:
        output_image = output_path  # Gradio loads the image from this path
    elif output_extension.lower() in vid_extensions:
        output_video = output_path  # Gradio loads the video from this path
    return output_image, output_video, output_path
def stop_processing():
    """Clear the global should_continue flag to signal inference to stop.

    Returns:
        A short status string shown to the user.
    """
    global should_continue
    should_continue = False
    return "Stop..."
def app():
    """Build the input/output UI and wire the inference and stop buttons.

    Must be called inside an active gr.Blocks context; it only adds
    components and event handlers, and returns nothing.
    """
    with gr.Blocks():
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                gr.HTML("<h2>Input Parameters</h2>")
                img_path = gr.File(label="Image")
                vid_path = gr.File(label="Video")
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "last_best_model.pt",
                        "best_model-converted.pt",
                        "yolov9_e_trained.pt",
                    ],
                    value="last_best_model.pt"
                )
                tracking_algorithm = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=[
                        "None",
                        "deep_sort",
                        "strong_sort"
                    ],
                    value="None"
                )
                yolov9_infer = gr.Button(value="Inference")
                stop_button = gr.Button(value="Stop")
            # Right column: rendered results.
            with gr.Column():
                gr.HTML("<h2>Output</h2>")
                output_image = gr.Image(type="numpy", label="Output Image")
                output_video = gr.Video(label="Output Video")
                output_path = gr.Textbox(label="Output path")
        yolov9_infer.click(
            fn=yolov9_inference,
            inputs=[
                model_id,
                img_path,
                vid_path,
                tracking_algorithm
            ],
            outputs=[output_image, output_video, output_path],
        )
        stop_button.click(stop_processing)
# Page-level styling. The original defined this without ever using it;
# passing it to gr.Blocks(css=...) actually applies it.
css = """
body {
    background-color: #f0f0f0;
}
h1 {
    color: #4CAF50;
}
"""

gradio_app = gr.Blocks(css=css)
with gradio_app:
    gr.HTML(
        """
        <h1 style='text-align: center'>
        YOLOv9: Real-time Object Detection
        </h1>
        """)
    with gr.Row():
        with gr.Column():
            app()

# debug=True surfaces tracebacks in the UI while developing the Space.
gradio_app.launch(debug=True)