FishEye8K / app.py
tuansunday05's picture
Update default model
0845688 verified
raw
history blame
No virus
7.72 kB
import spaces
import gradio as gr
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort
from detect import run
import os
import torch
import seaborn as sns
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
import threading
# Module-level flag; nothing in this file reads it. Kept in case other code
# references it — NOTE(review): confirm and remove if truly unused.
should_continue = True

# Extensions used to decide whether the detector produced an image or a video.
_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')  # Add more image extensions if needed
_VID_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')  # Add more video extensions if needed
# Fixed colour per class label so both count plots use consistent colours.
_LABEL_PALETTE = {"Bus": "red", "Bike": "blue", "Car": "green", "Pedestrian": "yellow", "Truck": "purple"}


def _save_image_input(img_array, dest='output.png'):
    """Save an image array (as delivered by ``gr.Image``) to *dest* and return *dest*."""
    Image.fromarray(img_array).save(dest)
    return dest


def _reencode_video(src_path, dest='output.mp4', fallback_fps=30.0):
    """Copy every frame of *src_path* into an mp4v-encoded file at *dest*.

    Frames are streamed straight from the reader to the writer instead of
    being buffered in memory. The source frame rate is preserved so the
    output keeps the original playback speed; *fallback_fps* is used only
    when OpenCV cannot report it.

    Returns:
        *dest*.

    Raises:
        gr.Error: if the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(src_path)
    writer = None
    try:
        if not cap.isOpened():
            raise gr.Error("Error opening video file")
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # OpenCV reports 0 (or NaN) when the container lacks the property
            fps = fallback_fps
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            if writer is None:
                # Size the writer from the first decoded frame; VideoWriter wants (width, height).
                height, width = frame.shape[:2]
                writer = cv2.VideoWriter(dest, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
            writer.write(frame)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    if writer is None:
        raise gr.Error("The video file contains no readable frames")
    return dest


def _build_count_plot(is_video, df, frame_counts_df):
    """Return a matplotlib Figure summarising detection counts.

    For videos: a line plot of ``frame_counts_df`` (count per label over frames).
    For images: a bar plot of ``df`` (count per label).
    """
    # A pyplot-free Figure is not registered with pyplot's figure manager, so it
    # is garbage-collected once Gradio has rendered it (no per-request leak).
    from matplotlib.figure import Figure

    plt.style.use("ggplot")  # global rcParams; applied before the Figure is created so it picks them up
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    if is_video:
        sns.lineplot(ax=ax, data=frame_counts_df, x='frame', y='count', hue='label', palette=_LABEL_PALETTE)
        ax.set_xlabel('Frame')
        ax.set_ylabel('Count')
        ax.set_title('Count of Labels over Frames')
        ax.legend()
        ax.grid(True)
        ax.set_facecolor('#D3D3D3')
    else:
        sns.barplot(ax=ax, data=df, x='label', y='count', palette=_LABEL_PALETTE, hue='label', legend=False)
        ax.set_title('Count of Labels', fontsize=20)
        ax.set_xlabel('Label', fontsize=15)
        ax.set_ylabel('Count', fontsize=15)
        ax.tick_params(axis='x', rotation=45)  # Rotate x-axis labels for better readability
        # Pass ax explicitly: without it seaborn would act on pyplot's current figure, not this one.
        sns.despine(ax=ax)
        ax.grid(True)
    return fig


@spaces.GPU(duration=120)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None):
    """Run YOLOv9 detection (optionally with tracking for videos) on one input.

    Args:
        model_id: weights file handed to the detector, e.g. "Our_Model.pt".
        img_path: image as a numpy array from ``gr.Image`` (despite the name,
            not a path). Takes precedence when both inputs are given.
        vid_path: path of the uploaded video file from ``gr.Video``.
        tracking_algorithm: 'deep_sort' or 'strong_sort' select a tracker;
            any other value runs plain detection. Only used for videos.

    Returns:
        ``(output_image, output_video, fig)``. The first is the annotated image
        path or None, the second the annotated video path or None, and the
        third a matplotlib Figure with label counts.

    Raises:
        gr.Error: if no input is supplied, the video cannot be read, or the
            detector output has an unrecognised extension.

    NOTE(review): inputs are written to fixed files ('output.png' /
    'output.mp4') in the working directory; concurrent runs would overwrite
    each other if the queue ever allows parallel execution.
    """
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    imgsz = (image_size, image_size)

    if img_path is not None:
        input_path = _save_image_input(img_path)
        output_path, df, frame_counts_df = run(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True, hide_labels=True)
    elif vid_path is not None:
        input_path = _reencode_video(vid_path)
        if tracking_algorithm == 'deep_sort':
            output_path, df, frame_counts_df = run_deepsort(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            device_strongsort = torch.device('cuda:0')
            output_path, df, frame_counts_df = run_strongsort(yolo_weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True, hide_labels=True)
        else:
            output_path, df, frame_counts_df = run(weights=model_id, imgsz=imgsz, conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf=True, hide_labels=True)
    else:
        raise gr.Error("Please provide an image or a video.")

    # The detector's output extension decides which output component is filled.
    output_extension = os.path.splitext(output_path)[1].lower()
    if output_extension in _IMG_EXTENSIONS:
        return output_path, None, _build_count_plot(False, df, frame_counts_df)
    if output_extension in _VID_EXTENSIONS:
        return None, output_path, _build_count_plot(True, df, frame_counts_df)
    raise gr.Error(f"Unsupported output file type: {output_path}")
def app():
    """Build the two-column YOLOv9 interface and bind the Inference button.

    The left column holds the inputs (image, video, model and tracker
    dropdowns) plus bundled examples. The right column holds the annotated
    image/video outputs and a plot component. Clicking "Inference" calls
    ``yolov9_inference`` with the inputs in its positional parameter order.
    """
    with gr.Blocks(title="YOLOv9: Real-time Object Detection", css=".gradio-container {background:lightyellow;}"):
        with gr.Row():
            # ---- Inputs -------------------------------------------------
            with gr.Column():
                gr.HTML("<h2>Input Parameters</h2>")
                image_input = gr.Image(label="Image", height=370, width=600)
                video_input = gr.Video(label="Video", height=370, width=600)
                model_choice = gr.Dropdown(
                    label="Model",
                    choices=["Our_Model.pt", "yolov9_e_trained.pt"],
                    value="Our_Model.pt",
                )
                tracker_choice = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=["None", "deep_sort", "strong_sort"],
                    value="None",
                )
                run_button = gr.Button(value="Inference")
                gr.Examples(
                    [
                        './img_examples/Exam_1.png',
                        './img_examples/Exam_2.png',
                        './img_examples/Exam_3.png',
                        './img_examples/Exam_4.png',
                        './img_examples/Exam_5.png',
                    ],
                    inputs=image_input,
                    label="Image Example",
                    cache_examples=False,
                )
                gr.Examples(
                    [
                        './video_examples/video_1.mp4',
                        './video_examples/video_2.mp4',
                        './video_examples/video_3.mp4',
                        './video_examples/video_4.mp4',
                        './video_examples/video_5.mp4',
                    ],
                    inputs=video_input,
                    label="Video Example",
                    cache_examples=False,
                )
            # ---- Outputs ------------------------------------------------
            with gr.Column():
                gr.HTML("<h2>Output</h2>")
                image_result = gr.Image(type="numpy", label="Output")
                video_result = gr.Video(label="Output")
                count_plot = gr.Plot(label="Plot")

        # Input order mirrors yolov9_inference(model_id, img_path, vid_path, tracking_algorithm);
        # output order mirrors its (output_image, output_video, fig) return tuple.
        run_button.click(
            fn=yolov9_inference,
            inputs=[model_choice, image_input, video_input, tracker_choice],
            outputs=[image_result, video_result, count_plot],
        )
# Page-wide styling. It must exist before gr.Blocks() is constructed; the
# original defined it afterwards, so it was never applied.
css = """
body {
background-color: #f0f0f0;
}
h1 {
color: #4CAF50;
}
"""

# Top-level page: a centred title above the interface built by app().
gradio_app = gr.Blocks(css=css)
with gradio_app:
    gr.HTML(
"""
<h1 style='text-align: center'>
YOLOv9: Real-time Object Detection
</h1>
""")
    with gr.Row():
        with gr.Column():
            app()

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    gradio_app.launch(debug=True)