import os
import threading

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import spaces
import torch
from PIL import Image

from detect import run
from detect_deepsort import run_deepsort
from detect_strongsort import run_strongsort

# Flag reserved for cancelling a long-running inference; not wired up yet.
should_continue = True


@spaces.GPU(duration=240)
def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm=None):
    global should_continue
    img_extensions = ['.jpg', '.jpeg', '.png', '.gif']  # Add more image extensions if needed
    vid_extensions = ['.mp4', '.avi', '.mov', '.mkv']  # Add more video extensions if needed
    # assert img_path is not None or vid_path is not None, "Either img_path or vid_path must be provided."
    image_size = 640
    conf_threshold = 0.5
    iou_threshold = 0.5
    input_path = None
    output_path = None

    if img_path is not None:
        # Convert the numpy array coming from gr.Image to an image file on disk.
        img = Image.fromarray(img_path)
        img_path = 'output.png'
        img.save(img_path)
        input_path = img_path
        output_path, df, frame_counts_df = run(
            weights=model_id, imgsz=(image_size, image_size),
            conf_thres=conf_threshold, iou_thres=iou_threshold,
            source=input_path, device='0', hide_conf=True)
    elif vid_path is not None:
        # Re-encode the uploaded video into an mp4 file so the detectors and
        # trackers get a consistent input container.
        vid_name = 'output.mp4'
        cap = cv2.VideoCapture(vid_path)
        if not cap.isOpened():
            print("Error opening video file")
        # Preserve the source frame rate; fall back to 30 fps if it is not reported.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30

        # Read the video frame by frame.
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()

        # Write the frames back out to the mp4 file.
        out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), fps,
                              (frames[0].shape[1], frames[0].shape[0]))
        for frame in frames:
            out.write(frame)
        out.release()
        input_path = vid_name

        if tracking_algorithm == 'deep_sort':
            output_path, df, frame_counts_df = run_deepsort(
                weights=model_id, imgsz=(image_size, image_size),
                conf_thres=conf_threshold, iou_thres=iou_threshold,
                source=input_path, device='0', draw_trails=True)
        elif tracking_algorithm == 'strong_sort':
            device_strongsort = torch.device('cuda:0')
            output_path, df, frame_counts_df = run_strongsort(
                yolo_weights=model_id, imgsz=(image_size, image_size),
                conf_thres=conf_threshold, iou_thres=iou_threshold,
                source=input_path, device=device_strongsort,
                strong_sort_weights="osnet_x0_25_msmt17.pt", hide_conf=True)
        else:
            output_path, df, frame_counts_df = run(
                weights=model_id, imgsz=(image_size, image_size),
                conf_thres=conf_threshold, iou_thres=iou_threshold,
                source=input_path, device='0', hide_conf=True)

    # Decide how to present the result from the output file's extension.
    _, output_extension = os.path.splitext(output_path)
    palette = {"Bus": "red", "Bike": "blue", "Car": "green",
               "Pedestrian": "yellow", "Truck": "purple"}

    if output_extension.lower() in img_extensions:
        output_image = output_path  # path to the annotated image
        output_video = None
        plt.style.use("ggplot")
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.barplot(ax=ax, data=df, x='label', y='count', palette=palette, hue='label')
        ax.set_title('Count of Labels', fontsize=20, pad=20)
        ax.set_xlabel('Label', fontsize=16)
        ax.set_ylabel('Count', fontsize=16)
        ax.tick_params(axis='x', rotation=45, labelsize=12)  # rotate x-axis labels for readability
        ax.tick_params(axis='y', labelsize=12)
        sns.despine()  # remove the top and right spines
        ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
        ax.set_axisbelow(True)  # draw the grid behind the bars
        ax.legend(fontsize=10)
        plt.tight_layout()  # ensure the entire plot fits into the figure area
    elif output_extension.lower() in vid_extensions:
        output_video = output_path  # path to the annotated video
        output_image = None
        plt.style.use("ggplot")
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.lineplot(ax=ax, data=frame_counts_df, x='frame', y='count',
                     hue='label', palette=palette, linewidth=2.5)
        ax.set_title('Count of Labels over Frames', fontsize=20, pad=20)
        ax.set_xlabel('Frame', fontsize=16)
        ax.set_ylabel('Count', fontsize=16)
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(True, linestyle=':', linewidth=0.6, color='gray', alpha=0.6)
        ax.set_axisbelow(True)  # draw the grid behind the lines
        ax.set_facecolor('#F0F0F0')  # lighter plot background
        ax.legend(fontsize=10)
        plt.tight_layout()  # ensure the entire plot fits into the figure area

    return output_image, output_video, fig
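
# Minimal usage sketch (assumptions: the weight file and example image exist locally,
# and run/run_deepsort/run_strongsort return (output_path, df, frame_counts_df), where
# df has 'label'/'count' columns and frame_counts_df has 'frame'/'count'/'label' — the
# shapes the plots above consume):
#
#   image = np.array(Image.open('./img_examples/Exam_1.png').convert('RGB'))
#   out_img, out_vid, fig = yolov9_inference('Our_Model.pt', img_path=image)
#   # -> out_img: path to the annotated image, out_vid: None, fig: label-count bar chart
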
def app():
    with gr.Blocks(title="YOLOv9: Real-time Object Detection",
                   css=".gradio-container {background:lightyellow;}"):
        with gr.Row():
            with gr.Column():
                gr.HTML("""
                <h3 style='text-align: center'>
                Input Parameters
                </h3>
                """)
                img_path = gr.Image(label="Image", height=260, width=410)
                vid_path = gr.Video(label="Video", height=260, width=410)
            with gr.Column(min_width=270):
                gr.HTML("""
                <h3 style='text-align: center'>
                Output
                </h3>
                """)
                output_image = gr.Image(type="numpy", label="Output", height=260, width=410)
                output_video = gr.Video(label="Output", height=260, width=410)
                fig = gr.Plot(label="label")
            with gr.Column():
                gr.HTML("""
                <h3 style='text-align: center'>
                Example, Model and Tracking Algorithm
                </h3>
                """)
                model_id = gr.Dropdown(
                    label="Model",
                    choices=[
                        "Our_Model.pt",
                        "yolov9_e_trained.pt",
                    ],
                    value="Our_Model.pt",
                )
                tracking_algorithm = gr.Dropdown(
                    label="Tracking Algorithm",
                    choices=[
                        "None",
                        "deep_sort",
                        "strong_sort",
                    ],
                    value="None",
                )
                yolov9_infer = gr.Button(value="Inference")
                gr.Examples(
                    ['./img_examples/Exam_1.png', './img_examples/Exam_2.png',
                     './img_examples/Exam_3.png', './img_examples/Exam_4.png',
                     './img_examples/Exam_5.png'],
                    inputs=img_path, label="Image Example",
                    cache_examples=False, examples_per_page=3)
                gr.Examples(
                    ['./video_examples/video_1.mp4', './video_examples/video_2.mp4',
                     './video_examples/video_3.mp4', './video_examples/video_4.mp4',
                     './video_examples/video_5.mp4'],
                    inputs=vid_path, label="Video Example",
                    cache_examples=False, examples_per_page=3)

        # Only one of output_image / output_video is populated per run; the other stays empty.
        yolov9_infer.click(
            fn=yolov9_inference,
            inputs=[model_id, img_path, vid_path, tracking_algorithm],
            outputs=[output_image, output_video, fig],
        )
css = """
body {
    background-color: #f0f0f0;
}
h1 {
    color: #4CAF50;
}
"""

# Pass the CSS at construction time; a css string assigned after the Blocks
# is created has no effect.
gradio_app = gr.Blocks(css=css)
with gradio_app:
    gr.HTML("""
    <h1 style='text-align: center'>
    YOLOv9-FishEye: Improving model for realtime fisheye camera object detection
    </h1>
    """)
    with gr.Row():
        with gr.Column():
            app()

gradio_app.launch(debug=True)