# hv_model / app.py
# Author: keffy — last commit dae4253 ("Update app.py", verified)
import gradio as gr
import cv2
from ultralytics import YOLO
import tempfile
import os
# Load the YOLO model once at import time; yolo_inference reuses this single
# module-level instance for every request.
# The weights path defaults to "best.pt" and can be overridden with the
# YOLO_MODEL_PATH environment variable instead of editing this file.
model = YOLO(os.environ.get("YOLO_MODEL_PATH", "best.pt"))
# Supported upload types. Matching is done against the lower-cased path, so
# e.g. "photo.JPG" is accepted as well as "photo.jpg".
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
_VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov")
# Frame rate used for the output video when the source reports none
# (OpenCV's VideoCapture.get returns 0 for properties the backend can't supply).
_FALLBACK_FPS = 30.0


def _resolve_path(input_file):
    """Return the filesystem path for *input_file* as a string.

    Accepts a ``str`` / ``os.PathLike`` path, or a file-like object exposing
    its path as ``.name`` (presumably what older Gradio ``gr.File`` versions
    pass — verify against the installed Gradio).
    """
    # Check PathLike first: pathlib.Path also has a ``.name`` attribute, but
    # it is only the final path component, not the full path.
    if isinstance(input_file, (str, os.PathLike)):
        return os.fspath(input_file)
    return input_file.name


def _annotate_image(path, show):
    """Run the model on the image at *path*; optionally display the result.

    Returns the annotated image produced by ``results[0].plot()``.
    Raises ValueError if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports missing/unreadable files by returning None,
        # not by raising; catch it before handing None to the model.
        raise ValueError(f"Could not read image file: {path}")
    annotated_img = model(img)[0].plot()
    if show:
        cv2.imshow("YOLO Detection", annotated_img)
        cv2.waitKey(0)  # Block until a key is pressed in the window.
        cv2.destroyAllWindows()
    return annotated_img


def _annotate_video(path, show):
    """Annotate every frame of the video at *path* and write it to a new file.

    The annotated video is written as ``output.mp4`` (mp4v codec) in a fresh
    temporary directory, and that path is returned. When *show* is true, each
    annotated frame is also displayed, and pressing 'q' stops early.
    NOTE(review): yolo_inference discards the returned path, and the
    temporary directory is never removed.

    Raises ValueError if OpenCV cannot open the video.
    """
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video file: {path}")
    out = None
    try:
        # Keep the FPS as a float: truncating e.g. 29.97 to 29 would change
        # the output's playback speed. `not fps > 0` also catches NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = _FALLBACK_FPS
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        temp_dir = tempfile.mkdtemp()
        output_video_path = os.path.join(temp_dir, "output.mp4")
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break  # End of stream (or read failure).
            annotated_frame = model(frame)[0].plot()
            # Write before the quit check so the last displayed frame is saved.
            out.write(annotated_frame)
            if show:
                cv2.imshow("YOLO Detection", annotated_frame)
                if cv2.waitKey(1) & 0xFF == ord("q"):  # Press 'q' to quit early.
                    break
    finally:
        # Release capture/writer handles even if inference raises mid-video.
        cap.release()
        if out is not None:
            out.release()
    if show:
        cv2.destroyAllWindows()
    return output_video_path


def yolo_inference(input_file, *, show=True):
    """Run YOLO object detection on an uploaded image or video.

    Images are annotated once. Videos are annotated frame by frame and saved to
    a temporary ``output.mp4``.

    Args:
        input_file: Path to the upload (``str`` or ``os.PathLike``), or a
            file-like object whose ``.name`` is that path.
        show: If true (the default, matching previous behavior), display the
            annotated output in an OpenCV window. Pass False where OpenCV GUI
            calls are not available (e.g. a server with no display).

    Returns:
        The input file's path as a ``str``. Annotated results are not returned.

    Raises:
        ValueError: If the extension is not a supported image/video type, or
            the file cannot be read/opened by OpenCV.
    """
    path = _resolve_path(input_file)
    lowered = path.lower()
    if lowered.endswith(_IMAGE_EXTENSIONS):
        _annotate_image(path, show)
    elif lowered.endswith(_VIDEO_EXTENSIONS):
        _annotate_video(path, show)
    else:
        raise ValueError("Unsupported file format. Please upload an image or video.")
    return path
# Gradio UI: a single file upload feeding yolo_inference. The function's
# return value (the uploaded file's path) is shown in a text box; the
# annotated detections are shown in an OpenCV window, not in the browser.
interface = gr.Interface(
    fn=yolo_inference,
    inputs=gr.File(label="Upload an Image or Video"),
    outputs="text",  # Shows the string returned by yolo_inference.
    title="YOLO Object Detection",
    description=(
        "Upload an image or video for object detection. Annotated results "
        "are displayed in an OpenCV window on the machine running this app; "
        "the text box shows the uploaded file's path."
    ),
)

# Launch the app when run as a script; share=True also requests a public
# Gradio share link.
if __name__ == "__main__":
    interface.launch(share=True)