Spaces:
Sleeping
Sleeping
File size: 4,583 Bytes
68e3bf5 f4c379b 839f10e 9e56ba5 839f10e ada3dea 839f10e 9e56ba5 839f10e 9e56ba5 839f10e 9e56ba5 839f10e c412fb1 9e56ba5 839f10e c412fb1 839f10e c412fb1 839f10e 9e56ba5 f4c379b 9e56ba5 f4c379b 787692b f4c379b 839f10e 9e56ba5 f4c379b 839f10e f4c379b 9e56ba5 f4c379b 9e56ba5 f4c379b 9e56ba5 f4c379b 839f10e f4c379b 9e56ba5 839f10e f4c379b 9e56ba5 f4c379b 9e56ba5 f4c379b 9e56ba5 f4c379b 9e56ba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import threading
# Module-level "keep processing" flag; stop_processing() sets it to False.
# NOTE(review): nothing visible in this file reads this flag, so flipping it
# does not by itself interrupt an inference that is already running.
should_continue: bool = True
# Extensions used to decide which Gradio output component receives the result.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')  # add more image extensions if needed
_VID_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')  # add more video extensions if needed


def _route_output(output_path):
    """Map a result file path onto the ``(image, video)`` output pair.

    Returns ``(output_path, None)`` for a known image extension,
    ``(None, output_path)`` for a known video extension, and ``(None, None)``
    for an unknown extension or an empty/``None`` path (instead of leaving the
    outputs undefined).
    """
    if not output_path:
        return None, None
    extension = os.path.splitext(output_path)[1].lower()
    if extension in _IMG_EXTENSIONS:
        return output_path, None
    if extension in _VID_EXTENSIONS:
        return None, output_path
    return None, None


@spaces.GPU(duration=120)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    """Run YOLOv9 detection (optionally with tracking) on an uploaded file.

    Args:
        model_id: Weights file passed to the detector (one of the "Model"
            dropdown choices).
        img_path: Value of the "Image" ``gr.File`` input, or ``None``.
            Presumably a filepath string (Gradio 4 default) — TODO confirm.
            Takes precedence over ``vid_path`` when both are given.
        vid_path: Value of the "Video" ``gr.File`` input, or ``None``.
        tracking_algorithm: ``'deep_sort'`` or ``'strong_sort'`` selects a
            tracker for videos; any other value (including the dropdown's
            ``"None"`` string) runs plain detection. Ignored for images.

    Returns:
        ``(output_image, output_video, output_path)``. ``output_path`` is what
        the detector returned. At most one of the first two holds that path,
        chosen by its file extension.

    Raises:
        gr.Error: if neither an image nor a video was provided.
    """
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5

    if img_path is not None:
        input_path, allow_tracking = img_path, False
    elif vid_path is not None:
        input_path, allow_tracking = vid_path, True
    else:
        # Previously this fell through to os.path.splitext(None) -> TypeError.
        raise gr.Error("Please upload an image or a video.")
    print(input_path)

    # Arguments shared by every detector/tracker entry point.
    detect_kwargs = dict(
        imgsz=(image_size, image_size),
        conf_thres=conf_threshold,
        iou_thres=iou_threshold,
        source=input_path,
        device='0',
    )
    tracker = tracking_algorithm if allow_tracking else None
    if tracker == 'deep_sort':
        output_path = run_deepsort(weights=model_id, draw_trails=True, **detect_kwargs)
    elif tracker == 'strong_sort':
        output_path = run_strongsort(
            yolo_weights=model_id,
            strong_sort_weights="osnet_x0_25_msmt17.pt",
            hide_conf=True,
            **detect_kwargs,
        )
    else:
        output_path = run(weights=model_id, hide_conf=True, **detect_kwargs)

    output_image, output_video = _route_output(output_path)
    return output_image, output_video, output_path
def stop_processing():
    """Clear the module-level ``should_continue`` flag and report it.

    Returns the status string ``"Stop..."``.
    NOTE(review): no code visible in this file reads ``should_continue``.
    """
    global should_continue
    status_message = "Stop..."
    should_continue = False
    return status_message
def app():
    """Build the input/output panel and wire the Inference and Stop buttons.

    Inputs: image file, video file, model weights dropdown and tracking
    algorithm dropdown. Outputs: annotated image, annotated video and the
    result path, all filled by ``yolov9_inference``.
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                gr.HTML("<h2>Input Parameters</h2>")
                img_path = gr.File(label="Image")
                vid_path = gr.File(label="Video")
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "last_best_model.pt",
                        "best_model-converted.pt",
                        "yolov9_e_trained.pt",
                    ],
                    value="last_best_model.pt",
                )
                tracking_algorithm = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=[
                        "None",
                        "deep_sort",
                        "strong_sort",
                    ],
                    value="None",
                )
                yolov9_infer = gr.Button(value="Inference")
                stop_button = gr.Button(value="Stop")
            with gr.Column():
                gr.HTML("<h2>Output</h2>")
                output_image = gr.Image(type="numpy", label="Output Image")
                output_video = gr.Video(label="Output Video")
                output_path = gr.Textbox(label="Output path")

        # Keep the event handle so the Stop button can cancel it.
        inference_event = yolov9_infer.click(
            fn=yolov9_inference,
            inputs=[
                model_id,
                img_path,
                vid_path,
                tracking_algorithm,
            ],
            outputs=[output_image, output_video, output_path],
        )
        # stop_processing only flips a flag that nothing reads; `cancels` is
        # what actually aborts the pending/running inference event.
        stop_button.click(fn=stop_processing, cancels=[inference_event])
# Page styling. It must be passed to gr.Blocks to take effect; previously it
# was defined inside the layout and never used.
css = """
body {
background-color: #f0f0f0;
}
h1 {
color: #4CAF50;
}
"""

gradio_app = gr.Blocks(css=css)
with gradio_app:
    gr.HTML(
        """
<h1 style='text-align: center'>
YOLOv9: Real-time Object Detection
</h1>
""")
    with gr.Row():
        with gr.Column():
            app()

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    gradio_app.launch(debug=True)
|