# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import io
from pathlib import Path
from typing import Any, Optional

import cv2

from ultralytics import YOLO
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS


class Inference:
    """
    A class to perform object detection, image classification, image segmentation and pose estimation inference using
    Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings,
    uploading video files, and performing real-time inference.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Temporary dictionary holding the "model" entry plus any extra keyword arguments.
        model_path (str | None): Path to the user-provided model, or None when no model was given.
        model (YOLO | None): The YOLO model instance, set by `configure`.
        source (str | None): Selected video source, "webcam" or "video" once the sidebar is built.
        enable_trk (bool | str): Tracking option; False initially, then the sidebar radio value "Yes" or "No".
        conf (float): Confidence threshold.
        iou (float): IoU threshold for non-max suppression.
        org_frame (Any): Streamlit placeholder used to display the original frame.
        ann_frame (Any): Streamlit placeholder used to display the annotated frame.
        vid_file_name (str | int | None): Name of the uploaded video file, or 0 for the webcam.
        selected_ind (list): List of selected class indices.

    Methods:
        web_ui: Sets up the Streamlit web interface with custom HTML elements.
        sidebar: Configures the Streamlit sidebar for model and inference settings.
        source_upload: Handles video file uploads through the Streamlit interface.
        configure: Configures the model and loads selected classes for inference.
        inference: Performs real-time object detection inference.

    Examples:
        >>> inf = solutions.Inference(model="path/to/model.pt")  # Model is not necessary argument.
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any):
        """
        Initializes the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration. The optional "model" entry is used
                as the path of a custom model to offer in the sidebar.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Placeholder for video or webcam source details
        self.enable_trk = False  # Flag to toggle object tracking
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame to be displayed
        self.ann_frame = None  # Container for the annotated frame to be displayed
        self.vid_file_name = None  # Holds the name of the video file
        self.selected_ind = []  # List of selected classes for detection or tracking
        self.model = None  # Container for the loaded model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = None  # Store model file name with path
        if self.temp_dict["model"] is not None:
            self.model_path = self.temp_dict["model"]

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    def web_ui(self):
        """Sets up the Streamlit web interface with custom HTML elements."""
        # Hide main menu style. The string is empty here, so this markdown call renders nothing.
        menu_style_cfg = """"""

        # Main title of streamlit application
        main_title_cfg = "\n\nUltralytics YOLO Streamlit Application\n\n"

        # Subtitle of streamlit application
        sub_title_cfg = (
            "\n\nExperience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀\n\n"
        )

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

    def sidebar(self):
        """Configures the Streamlit sidebar for model and inference settings."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox(
            "Video",
            ("webcam", "video"),
        )  # Add source selection dropdown
        self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No"))  # Enable object tracking
        self.conf = float(
            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
        )  # Slider for confidence
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

        # Two side-by-side placeholders: original frame on the left, annotated frame on the right
        col1, col2 = self.st.columns(2)
        self.org_frame = col1.empty()
        self.ann_frame = col2.empty()

    def source_upload(self):
        """
        Handles video file uploads through the Streamlit interface.

        Sets `vid_file_name` to 0 for the webcam, to "ultralytics.mp4" once an uploaded video has been written to
        disk, or to an empty string when "video" is selected but nothing has been uploaded yet.
        """
        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"])
            if vid_file is not None:
                g = io.BytesIO(vid_file.read())  # BytesIO Object
                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                    out.write(g.read())  # Read bytes into file
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0

    @staticmethod
    def _resolve_weights(selected_model: str, default_models: list, custom_model: Optional[str] = None) -> str:
        """
        Translate the sidebar selection into the weights string handed to YOLO.

        Args:
            selected_model (str): Entry chosen in the "Model" select box.
            default_models (list): Display names of the built-in models (e.g. "YOLO11n").
            custom_model (str | None): User-provided model path, if any.

        Returns:
            (str): Lower-cased name with a ".pt" suffix for built-in models; otherwise the custom path exactly as
                given, with ".pt" appended only when the path has no suffix at all.
        """
        if selected_model in default_models or not custom_model:
            return f"{selected_model.lower()}.pt"
        # Never lower-case or truncate a user path: that breaks mixed-case paths and non-.pt formats
        return custom_model if Path(custom_model).suffix else f"{custom_model}.pt"

    def configure(self):
        """Configures the model and loads selected classes for inference."""
        # Add dropdown menu for model selection
        default_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")]
        available_models = list(default_models)
        if self.model_path:  # If user provided the custom model, offer it first, exactly as given
            available_models.insert(0, str(self.model_path))
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
            weights = self._resolve_weights(
                selected_model, default_models, str(self.model_path) if self.model_path else None
            )
            self.model = YOLO(weights)  # Load the YOLO model
            names = self.model.names  # Mapping of class index -> class name
            class_names = list(names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Multiselect box with class names and get indices of selected classes
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])

        # Map names back to the model's own class indices; the first occurrence wins for duplicate names
        name_to_id = {}
        for class_id, class_name in names.items():
            name_to_id.setdefault(class_name, class_id)
        self.selected_ind = [name_to_id[option] for option in selected_classes]

    def inference(self):
        """Performs real-time object detection inference."""
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video source
        self.configure()  # Configure the app

        if self.st.sidebar.button("Start"):
            stop_button = self.st.button("Stop")  # Button to stop the inference
            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
            try:
                if not cap.isOpened():
                    self.st.error("Could not open webcam or video source.")

                while cap.isOpened():
                    success, frame = cap.read()
                    if not success:
                        self.st.warning(
                            "Failed to read frame from webcam. "
                            "Please verify the webcam is connected properly."
                        )
                        break

                    # Store model predictions
                    if self.enable_trk == "Yes":
                        results = self.model.track(
                            frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                        )
                    else:
                        results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                    annotated_frame = results[0].plot()  # Add annotations on frame

                    if stop_button:
                        self.st.stop()  # Stop streamlit app; the capture is released in the finally clause

                    self.org_frame.image(frame, channels="BGR")  # Display original frame
                    self.ann_frame.image(annotated_frame, channels="BGR")  # Display processed frame
            finally:
                cap.release()  # Release the capture even if prediction or rendering raised
        cv2.destroyAllWindows()  # Destroy window


if __name__ == "__main__":
    import sys  # Import the sys module for accessing command-line arguments

    # Check if a model name is provided as a command-line argument
    args = len(sys.argv)
    model = sys.argv[1] if args > 1 else None  # assign first argument as the model name
    # Create an instance of the Inference class and run inference
    Inference(model=model).inference()