import spaces import gradio as gr from detect_deepsort import run_deepsort from detect_strongsort import run_strongsort from detect import run import os import torch from PIL import Image import numpy as np import threading import cv2 should_continue = True @spaces.GPU(duration=120) def yolov9_inference(model_id, img_path=None, vid_path=None, tracking_algorithm = None): global should_continue img_extensions = ['.jpg', '.jpeg', '.png', '.gif'] # Add more image extensions if needed vid_extensions = ['.mp4', '.avi', '.mov', '.mkv'] # Add more video extensions if needed #assert img_path is not None or vid_path is not None, "Either img_path or vid_path must be provided." image_size = 640 conf_threshold = 0.5 iou_threshold = 0.5 input_path = None output_path = None if img_path is not None: # Convert the numpy array to an image img = Image.fromarray(img_path) img_path = 'output.png' # Save the image img.save(img_path) input_path = img_path print(input_path) output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True) elif vid_path is not None: vid_name = 'output.mp4' # Create a VideoCapture object cap = cv2.VideoCapture(vid_path) # Check if video opened successfully if not cap.isOpened(): print("Error opening video file") # Read the video frame by frame frames = [] while cap.isOpened(): ret, frame = cap.read() if ret: frames.append(frame) else: break # Release the VideoCapture object cap.release() # Convert the list of frames to a numpy array vid_data = np.array(frames) # Create a VideoWriter object out = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frames[0].shape[1], frames[0].shape[0])) # Write the frames to the output video file for frame in frames: out.write(frame) # Release the VideoWriter object out.release() input_path = vid_name if tracking_algorithm == 'deep_sort': output_path = run_deepsort(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', draw_trails=True) elif tracking_algorithm == 'strong_sort': device_strongsort = torch.device('cuda:0') output_path = run_strongsort(yolo_weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device=device_strongsort, strong_sort_weights = "osnet_x0_25_msmt17.pt", hide_conf= True) else: output_path = run(weights=model_id, imgsz=(image_size,image_size), conf_thres=conf_threshold, iou_thres=iou_threshold, source=input_path, device='0', hide_conf= True) # Assuming output_path is the path to the output file _, output_extension = os.path.splitext(output_path) if output_extension.lower() in img_extensions: output_image = output_path # Load the image file here output_video = None elif output_extension.lower() in vid_extensions: output_image = None output_video = output_path # Load the video file here return output_image, output_video, output_path def app(model_id, img_path, vid_path, tracking_algorithm): return yolov9_inference(model_id, img_path, vid_path, tracking_algorithm) iface = gr.Interface( fn=app, inputs=[ gr.Dropdown( label="Model", choices=[ "our-converted.pt", "yolov9_e_trained-converted.pt", "last_best_model.pt" ], value="our-converted.pt" ), gr.Image(label="Image"), gr.Video(label="Video"), gr.Dropdown( label= "Tracking Algorithm", choices=[ "None", "deep_sort", "strong_sort" ], value="None" ) ], outputs=[ gr.Image(type="numpy",label="Output Image"), gr.Video(label="Output Video"), gr.Textbox(label="Output path") ], examples=[ ["last_best_model.pt", "camera1_A_133.png", None, "deep_sort"], ["last_best_model.pt", None, "test.mp4", "strong_sort"] ], title='YOLOv9: Real-time Object Detection', description='This is a real-time object detection system using YOLOv9.', theme='huggingface' ) iface.launch(debug=True)