# Importing necessary libraries
import io
import os
import utils
import random
import shutil
import zipfile
import numpy as np
import pandas as pd
import streamlit as st
from ultralytics import YOLO
import plotly.graph_objs as go
from onnx.defs import onnx_opset_version
from plotly.subplots import make_subplots

# Path components (relative to the project root) for every supported path type.
_PATH_PARTS = {
    "train": ("model_data", "input_files", "datasets", "train"),
    "val": ("model_data", "input_files", "datasets", "val"),
    "test": ("model_data", "input_files", "datasets", "test"),
    "config": ("model_data", "input_files"),
    "models": ("model_data", "models"),
    "output": ("model_data", "output_files"),
}


# Function to get the dataset directory path based on the specified path type
def get_path(path_type):
    """Return the absolute directory path for *path_type*.

    The project root is the parent of the directory that contains this file.
    The returned directory is neither created nor checked for existence.

    Args:
        path_type: One of "train", "val", "test", "config", "models", "output".

    Returns:
        The absolute path as a string.

    Raises:
        ValueError: If *path_type* is not a supported name.
    """
    try:
        parts = _PATH_PARTS[path_type]
    except (KeyError, TypeError):
        raise ValueError(f"Invalid path_type: {path_type}") from None
    main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    return os.path.join(main_directory_path, *parts)


# Function to check minimum images in training and validation set
def check_min_images(total_files, train_pct, val_pct, test_pct):
    """Return True if the requested split gives at least one train and one val file.

    Each count is ``int(total_files * pct / 100)``, i.e. truncated toward zero.
    ``test_pct`` is accepted for call-site symmetry but does not influence the
    result, because an empty test split is allowed.
    """
    train_count = int(total_files * train_pct / 100)
    val_count = int(total_files * val_pct / 100)
    return train_count >= 1 and val_count >= 1


# Function to clear data folders
def clear_data_folders():
    """Empty the images/ and labels/ sub-folders of the train, test and val splits.

    Each sub-folder is removed if present, then recreated empty. The base
    directory ./model_data/input_files/datasets is resolved against the current
    working directory.

    NOTE(review): get_path() resolves paths from this file's location instead,
    so the two only agree when the app is launched from the project root.
    """
    base_path = "./model_data/input_files/datasets"
    for folder in ["train", "test", "val"]:
        for subfolder in ["images", "labels"]:
            folder_path = os.path.join(base_path, folder, subfolder)
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
            os.makedirs(folder_path, exist_ok=True)
pairs image and label files based on their filenames def pair_files(files): paired_files = {} for file in files: # Split the filename into name and extension file_name, file_ext = os.path.splitext(file.name) # Initialize a dict for each unique file name if file_name not in paired_files: paired_files[file_name] = {"image": None, "label": None} # Assign the file to its corresponding type (image or label) based on extension if file_ext.lower() in [".jpg", ".png"]: paired_files[file_name]["image"] = file elif file_ext.lower() == ".txt": paired_files[file_name]["label"] = file return paired_files # Function to split the paired files into training, testing, and validation sets based on specified percentages and saves them in corresponding folders def split_and_save_files(paired_files, train_pct, test_pct): base_path = "./model_data/input_files/datasets" all_keys = list(paired_files.keys()) random.shuffle(all_keys) # Determine the size of each dataset split total_files = len(all_keys) train_size = int(total_files * train_pct / 100) test_size = int(total_files * test_pct / 100) # Split the file keys into training, testing, and validation sets train_keys = all_keys[:train_size] test_keys = all_keys[train_size : train_size + test_size] val_keys = all_keys[train_size + test_size :] # Iterate through each split and save the files to their respective directories for folder_name, keys in zip( ["train", "test", "val"], [train_keys, test_keys, val_keys] ): for key in keys: image_file = paired_files[key]["image"] label_file = paired_files[key]["label"] # Save the image and label files if they exist if image_file: save_file_to_folder( image_file, os.path.join(base_path, folder_name, "images") ) if label_file: save_file_to_folder( label_file, os.path.join(base_path, folder_name, "labels") ) # Function to save an individual file to a specified folder def save_file_to_folder(file, folder_path): os.makedirs(folder_path, exist_ok=True) file_path = os.path.join(folder_path, file.name) 
with open(file_path, "wb") as f: f.write(file.getbuffer()) # Function to save uploaded files to a specific folder within the base path def save_files_to_folder(uploaded_files, folder_name): # Define the base path for saving the files base_path = "./model_data/input_files/datasets" # Iterate through each uploaded file for file in uploaded_files: if file: # Determine the file type based on file extension file_type = ( "images" if os.path.splitext(file.name)[1].lower() in [".jpg", ".png"] else "labels" ) # Save the file to the appropriate subfolder (images or labels) save_file_to_folder(file, os.path.join(base_path, folder_name, file_type)) # Function to validate each line in the label file for bounding box data def check_bboxes_label(label_file, class_dict): for line in label_file: try: # Decode the line, strip whitespace, split into parts, and convert each part to float class_id, x_center, y_center, width, height = map( float, line.decode().strip().split() ) # Check if bounding box coordinates and class ID are valid if not ( 0 <= x_center <= 1 and 0 <= y_center <= 1 and 0 <= width <= 1 and 0 <= height <= 1 and class_id in class_dict.keys() ): # Return False if any condition is not met (invalid data) return False except Exception as e: # Return False in case of any exception (e.g., parsing error) return False # Return True if all lines in the label file pass the validation return True # Function to validate each line in the label file for mask data def check_masks_label(label_file, class_dict): for line in label_file: try: # Decode the line and split into parts: class ID and points parts = line.decode().strip().split() class_id = int( parts[0] ) # Convert the first part to an integer for class ID points = [ float(p) for p in parts[1:] ] # Convert the remaining parts to float for coordinates # Check if class ID exists in the class dictionary and all points are within [0, 1] if not (class_id in class_dict.keys() and all(0 <= p <= 1 for p in points)): return False # 
Return False if validation fails except Exception as e: # Return False in case of any exception (e.g., parsing error) return False return True # Return True if all lines in the label file pass the validation # Function to read label from YOLO format def read_label(file, selected_option, class_dict): # Read the content of the file file_content = file.readlines() # Check and validate bounding box labels if the selected option is 'Bboxes' if selected_option == "Bboxes": return check_bboxes_label(file_content, class_dict) # Validate bbox labels # Check and validate mask labels if the selected option is 'Masks' elif selected_option == "Masks": return check_masks_label(file_content, class_dict) # Validate mask labels # Return False if the selected option is neither 'Bboxes' nor 'Masks' return False # Function to check for duplicates def check_file_duplicates(file_names): unique_names = set(file_names) return len(unique_names) == len(file_names) # Function to validates the uploaded image and label files def validate_files(image_names, label_names): # Check for duplicate filenames in both images and labels if not check_file_duplicates(image_names) or not check_file_duplicates(label_names): # Show warning if duplicates are found st.warning( "Duplicate file names detected. 
Please ensure each image and label has a unique name.", icon="⚠️", ) return False # Return False indicating validation failed # Check if the number of images matches the number of labels if len(image_names) != len(label_names): # Show warning if counts don't match st.warning( "Count Mismatch: The number of uploaded images and labels does not match.", icon="⚠️", ) return False # Return False indicating validation failed # Display a success message if the above checks pass st.info( f"Validated: {len(image_names)} images and labels successfully matched.", icon="✅", ) return True # Return True indicating successful validation # Function to check labels format @st.cache_resource(show_spinner=False) def check_valid_labels(uploaded_files, selected_option, class_dict): # Check if no files were uploaded if len(uploaded_files) == 0: st.warning("Please upload images and labels.", icon="⚠️") return False # Initialize lists to store names of image and label files image_names, label_names = [], [] # Initialize a progress bar and progress text progress_bar = st.progress(0) progress_text = st.empty() total_files = len(uploaded_files) # Iterate over each uploaded file for index, file in enumerate(uploaded_files): # Reset the file pointer to the beginning file.seek(0) # Check file type and categorize as image or label if file.type in ["image/jpeg", "image/png"]: # Add to image names list if file is an image image_names.append(file.name) elif file.type == "text/plain": # Read and validate label file if not read_label(file, selected_option, class_dict): # Show warning if label format or data is invalid st.warning( f"Invalid label format or data in file: {file.name}", icon="⚠️" ) return False # Add to label names list if file is a valid label label_names.append(file.name) # Update progress bar and display current progress progress_percentage = (index + 1) / total_files progress_bar.progress(progress_percentage) progress_text.text(f"Validating file {index + 1} of {total_files}") # 
Remove progress bar and progress text after processing progress_bar.empty() progress_text.empty() # Validate if all images have corresponding labels and vice versa return validate_files(image_names, label_names) # Function to get training, validation and export configurations def get_training_validation_export_configuration(selected_training): with st.expander("Training Configuration"): # User Instruction for Default Values st.markdown( """
User Instructions: If you are unsure about the specific values to use for training parameters, it is recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance between performance and resource utilization for most scenarios. You can always come back and tweak these settings once you have more experience or specific requirements for your model training.
""", unsafe_allow_html=True, ) # Padding utils.top_padding(2) # Training Configuration st.markdown("### Training Configuration") # Model Selection st.write("**Model Selection**") selected_model = st.selectbox( "Choose a YOLOv8 model variant", list(utils.models_info.keys()) ) model_spec = utils.models_info[selected_model] spec_string = ( "
" f"The selected model, {selected_model}, is benchmarked on an image size of 640x640 pixels. It has a Mean Average Precision (mAPval) of {model_spec['mAPval']}, " f"operates with a speed of {model_spec['speed_cpu']} ms on CPU (ONNX) and {model_spec['speed_gpu']} ms on GPU (TensorRT). " f"It consists of approximately {model_spec['params']} million parameters and requires about {model_spec['flops']} billion Floating Point Operations (FLOPs)." "
" ) st.markdown(spec_string, unsafe_allow_html=True) # Spacer st.markdown("---") # Time Configuration st.write("**Time Configuration**") col1_time, col2_time = st.columns([1, 3]) with col1_time: top_padding_time = st.container() time_allow = st.checkbox("Enable Time", value=False) if time_allow: with top_padding_time: utils.top_padding(2) time = col2_time.number_input( "Time (hours)", min_value=1, max_value=100, value=1, step=1 ) else: time = None st.markdown( "
Set the training duration in hours. This option overrides the epochs setting. Useful for limiting training time in scenarios with constrained resources.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Epochs Configuration st.write("**Epochs Configuration**") epochs = st.number_input( "Epochs", min_value=1, max_value=1000, value=50, step=10 ) st.markdown( "
Define the number of epochs for the training process. An epoch represents a complete pass over the entire dataset. More epochs can improve accuracy but increase training time.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Patience Configuration st.write("**Patience Configuration**") col1_patience, col2_patience = st.columns([1, 3]) with col1_patience: top_padding_patience = st.container() patience_allow = st.checkbox("Enable Patience", value=False) if patience_allow: with top_padding_patience: utils.top_padding(2) patience = col2_patience.number_input( "Patience (epochs)", min_value=5, max_value=50, value=5, step=1 ) else: patience = None st.markdown( "
Configure the early stopping mechanism. Patience denotes the number of epochs to wait for improvement in performance before stopping the training, helping to avoid overfitting.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Batch Size Configuration st.write("**Batch Size Configuration**") batch = st.number_input( "Batch Size", min_value=-1, max_value=128, value=-1, step=1 ) st.markdown( "
Determine the number of images processed together in one pass (batch). A larger batch size can lead to faster training but requires more memory. Use -1 for automatic batch sizing.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Image Size Configuration st.write("**Image Size Configuration**") imgsz = st.number_input( "Image Size (pixels)", min_value=64, max_value=4096, value=640, step=32 ) st.markdown( "
Specify the size of the input images. Larger images can capture more details but require more computational resources. The size is typically a square dimension, like 640x640 pixels.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Cache Configuration st.write("**Cache Configuration**") cache = st.selectbox("Cache Option", ["False", "True/ram", "disk"]) st.markdown( "
Choose a caching method for data loading to speed up training. 'True/ram' caches data in RAM, 'disk' caches on disk, and 'False' disables caching.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Optimizer Configuration st.write("**Optimizer Configuration**") optimizer = st.selectbox( "Optimizer", ["SGD", "Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "auto"], index=7, ) st.markdown( "
Select the optimizer for training. The optimizer adjusts weights to minimize the loss function. Choices include SGD, Adam, and others, with 'auto' selecting automatically based on the model.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # AMP Configuration st.write("**AMP Configuration**") amp = st.checkbox("Enable AMP", value=True) st.markdown( "
Enable Automatic Mixed Precision (AMP) to accelerate training on compatible hardware. AMP uses lower precision to reduce memory usage and speed up computations.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Deterministic Mode Configuration st.write("**Deterministic Mode Configuration**") deterministic = st.checkbox("Enable Deterministic Mode", value=False) st.markdown( "
Activate deterministic mode to ensure reproducible results. This mode might slow down the training but is useful for experimentation and debugging.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Rectangular Training Configuration st.write("**Rectangular Training Configuration**") rect = st.checkbox("Enable Rectangular Training", value=False) st.markdown( "
Enable rectangular training to process batches with minimal padding by reshaping images. This can lead to performance improvements but may affect accuracy.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Cosine Learning Rate Scheduler Configuration st.write("**Cosine Learning Rate Scheduler**") cos_lr = st.checkbox("Use Cosine LR Scheduler", value=False) st.markdown( "
Use a cosine learning rate scheduler to adjust the learning rate following a cosine curve, potentially leading to better convergence during training.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Freeze Layer Configuration st.write("**Freeze Layer Configuration**") col1_freeze, col2_freeze = st.columns([1, 3]) with col1_freeze: top_padding_freeze = st.container() freeze_allow = st.checkbox("Enable Freeze Layers", value=False) if freeze_allow: with top_padding_freeze: utils.top_padding(2) freeze = col2_freeze.number_input( "Freeze Layers", min_value=1, max_value=1000, value=10, placeholder="Enter number of layers", ) else: freeze = None st.markdown( "
Enable freezing the initial layers of the model during training. Specify the number of layers to freeze or a comma-separated list of specific layer indices. Useful for fine-tuning pre-trained models without modifying early layers.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Initial Learning Rate Configuration st.write("**Initial Learning Rate (lr0)**") lr0 = st.number_input( "Initial Learning Rate (lr0)", min_value=0.00001, max_value=1.0, value=0.01, format="%.5f", ) st.markdown( "
Specify the initial learning rate (lr0) for the training process. The initial rate is crucial as it determines the starting step size for weight updates. A well-chosen initial rate helps in achieving a balance between fast convergence and overshooting the optimal solution.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Final Learning Rate Configuration st.write("**Final Learning Rate (lrf)**") lrf = st.number_input( "Final Learning Rate (lrf)", min_value=0.00001, max_value=1.0, value=0.01, format="%.5f", ) st.markdown( "
Determine the final learning rate, which is a factor (lrf) of the initial learning rate (lr0). This parameter is used to adjust the learning rate over the course of training, gradually decreasing it to fine-tune model weights and stabilize training as it approaches the minimum of the loss function.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Momentum Configuration st.write("**Momentum Configuration**") momentum = st.number_input( "Momentum", min_value=0.0, max_value=1.0, value=0.937, format="%.3f" ) st.markdown( "
Set the momentum value for the optimizer. Momentum helps in accelerating the optimizer in the relevant direction and dampens oscillations, facilitating faster convergence.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Weight Decay Configuration st.write("**Weight Decay Configuration**") weight_decay = st.number_input( "Weight Decay", min_value=0.0, max_value=0.1, value=0.0005, format="%.5f" ) st.markdown( "
Specify the weight decay, a regularization technique that adds a small penalty to the loss function for larger weights. It helps in preventing overfitting by encouraging simpler models.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Warmup Epochs Configuration st.write("**Warmup Epochs Configuration**") warmup_epochs = st.number_input( "Warmup Epochs", min_value=0.0, max_value=10.0, value=3.0, step=0.1 ) st.markdown( "
Define the number of warmup epochs. During warmup, the learning rate gradually increases to its initial value, which helps in stabilizing the training process in its early stages.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Warmup Momentum Configuration st.write("**Warmup Momentum Configuration**") warmup_momentum = st.number_input( "Warmup Momentum", min_value=0.0, max_value=1.0, value=0.8, format="%.1f" ) st.markdown( "
Configure the momentum during the warmup phase. A lower momentum at the start can help in stabilizing the optimization process before reaching the specified momentum for the remaining epochs.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Warmup Bias Learning Rate Configuration st.write("**Warmup Bias Learning Rate Configuration**") warmup_bias_lr = st.number_input( "Warmup Bias Learning Rate", min_value=0.0, max_value=1.0, value=0.1, format="%.1f", ) st.markdown( "
Adjust the bias learning rate during the warmup period. This parameter can be tuned to manage the initial learning rate specifically for the bias parameters in the early training phase.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Box Loss Gain Configuration st.write("**Box Loss Gain Configuration**") box = st.number_input( "Box Loss Gain", min_value=0.0, max_value=10.0, value=7.5, step=0.1 ) st.markdown( "
Configure the gain factor for the box loss. This gain helps in adjusting the importance of the box size and location accuracy in the loss function, affecting how the model prioritizes bounding box precision.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Class Loss Gain Configuration st.write("**Class Loss Gain Configuration**") cls = st.number_input( "Class Loss Gain", min_value=0.0, max_value=10.0, value=0.5, step=0.1 ) st.markdown( "
Set the gain factor for the class loss. This parameter scales the contribution of class prediction accuracy in the total loss, influencing how the model prioritizes correct class identification.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # DFL Loss Gain Configuration st.write("**DFL Loss Gain Configuration**") dfl = st.number_input( "DFL Loss Gain", min_value=0.0, max_value=10.0, value=1.5, step=0.1 ) st.markdown( "
Determine the gain factor for the DFL loss. Adjusting this gain influences the model's focus on the Directional Focal Loss component, which is critical for precise object localization and classification.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Label Smoothing Configuration st.write("**Label Smoothing Configuration**") label_smoothing = st.number_input( "Label Smoothing (fraction)", min_value=0.0, max_value=1.0, value=0.0, format="%.1f", ) st.markdown( "
Specify the label smoothing value, a technique that introduces softening to the target labels. It promotes model generalization and reduces the impact of noisy labels on the training process.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Nominal Batch Size Configuration st.write("**Nominal Batch Size Configuration**") nbs = st.number_input( "Nominal Batch Size", min_value=1, max_value=128, value=64, step=1 ) st.markdown( "
Set the nominal batch size, which is used for normalizing the loss. This size does not affect the actual batch size but is used to scale the loss to a standard reference batch size.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Overlap Mask Configuration st.write("**Overlap Mask Configuration**") overlap_mask = st.checkbox("Masks Overlap during Training", value=True) st.markdown( "
Choose whether to allow masks to overlap during instance segmentation training. Overlapping can lead to more precise segmentation but may increase complexity.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Mask Ratio Configuration st.write("**Mask Ratio Configuration**") mask_ratio = st.number_input( "Mask Downsample Ratio", min_value=1, max_value=10, value=4, step=1 ) st.markdown( "
Set the downsample ratio for masks in instance segmentation. A higher ratio reduces the mask resolution, which can speed up computations but might decrease segmentation accuracy.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Dropout Configuration st.write("**Dropout Configuration**") dropout = st.number_input( "Dropout Regularization", min_value=0.0, max_value=1.0, value=0.0, format="%.1f", ) st.markdown( "
Configure the dropout rate, which randomly disables a proportion of neurons during training. This prevents the model from relying too much on certain features and promotes better generalization.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Validation/Test Configuration st.write("**Validation/Test Configuration**") val = st.checkbox("Validate/Test during Training", value=True) st.markdown( "
Decide whether to perform validation and testing during the training process. Regular validation helps monitor model performance and adjust training accordingly.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Save Plots Configuration st.write("**Save Plots Configuration**") plots = st.checkbox("Save Plots and Images during Training", value=True) st.markdown( "
Enable saving of plots and images during training. This feature provides visual insights into the training progress and helps in diagnosing model performance across epochs.
", unsafe_allow_html=True, ) # Padding utils.top_padding(2) with st.expander("Validation Configuration"): # User Instruction for Default Values st.markdown( """
User Instructions: If you are unsure about the specific values to use for validation parameters, it is recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance between performance and resource utilization for most scenarios. You can always come back and tweak these settings once you have more experience or specific requirements for your model validation.
""", unsafe_allow_html=True, ) # Padding utils.top_padding(2) # Validation Configuration st.markdown("### Validation Configuration") # Object Confidence Threshold st.write("**Object Confidence Threshold**") conf = st.number_input( "Confidence Threshold", min_value=0.0, max_value=1.0, value=0.001, format="%.3f", ) st.markdown( "
Set the confidence threshold for object detection. This threshold filters out detections with lower confidence, reducing false positives and focusing on more likely object detections.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Intersection Over Union (IoU) Threshold st.write("**IoU Threshold for NMS**") iou = st.number_input( "IoU Threshold", min_value=0.0, max_value=1.0, value=0.6, format="%.1f" ) st.markdown( "
Define the IoU threshold for Non-Maximum Suppression. NMS is used to refine the bounding boxes by eliminating redundancies and retaining the most probable ones.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Maximum Number of Detections st.write("**Maximum Number of Detections**") max_det = st.number_input( "Max Detections", min_value=1, max_value=1000, value=300, step=1 ) st.markdown( "
Limit the maximum number of detections per image. This setting is crucial for controlling the computational load and focusing the model on the most confident and relevant detections.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Use Half Precision st.write("**Use Half Precision (FP16)**") half = st.checkbox("Enable Half Precision", value=True) st.markdown( "
Enable half precision (FP16) training for enhanced performance on compatible GPUs. It reduces memory requirements and accelerates computation, beneficial for larger models and datasets.
", unsafe_allow_html=True, ) # Padding utils.top_padding(2) with st.expander("Export Configuration"): # User Instruction for Default Values st.markdown( """
User Instructions: If you are unsure about the specific values to use for export parameters, it is recommended to stick with the default values provided. These defaults are carefully chosen to provide a good balance between performance and resource utilization for most scenarios. You can always come back and tweak these settings once you have more experience or specific requirements for your model export.
""", unsafe_allow_html=True, ) # Padding utils.top_padding(2) # Validation Configuration st.markdown("### Export Configuration") # Select Export Format st.write("**Export Format**") export_format = st.selectbox( "Select Export Format", [ "Only PyTorch", "TorchScript", "ONNX", "OpenVINO", "TensorRT", "CoreML", "TF SavedModel", "TF GraphDef", "TF Lite", "TF Edge TPU", "TF.js", "PaddlePaddle", "ncnn", ], ) # Dynamically generate description if export_format == "Only PyTorch": st.markdown( """
You have selected PyTorch as the export format. This will export the model in the standard PyTorch .pt format. There are no additional format-specific parameters to consider for this selection. The exported model will be the same as selected during training.
""", unsafe_allow_html=True, ) else: format_info = utils.export_formats[export_format] # Handling additional arguments if len(format_info["arguments"]) > 0: additional_arguments = ", ".join(format_info["arguments"]) arguments_info = f"Consider the following arguments for the {export_format} format: {additional_arguments}." else: arguments_info = ( "No additional parameters need to be considered for this format." ) st.markdown( f"""
You have selected {export_format} as the export format. Along with the PyTorch model, this selection will also export the model in the {export_format} format. The image size of the exported model will be the same as selected during training. {arguments_info}
""", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Use Keras for TF SavedModel export st.write("**Use Keras for TF SavedModel Export**") keras = st.checkbox("Enable Keras", value=False) st.markdown( "
Enabling Keras optimizes the TensorFlow SavedModel export for compatibility with the Keras API, making it easier to work with in Keras-centric workflows.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Optimize for mobile (TorchScript) st.write("**Optimize TorchScript for Mobile**") optimize = st.checkbox("Enable Optimization", value=False) st.markdown( "
Optimizing for mobile reduces the model size and computational needs, enhancing performance on mobile devices with limited resources.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # FP16 quantization st.write("**FP16 Quantization**") half = st.checkbox("Enable FP16 Quantization", value=False) st.markdown( "
FP16 quantization reduces model size and speeds up inference, especially on GPUs with Tensor Cores, while maintaining model accuracy.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # INT8 quantization st.write("**INT8 Quantization**") int8 = st.checkbox("Enable INT8 Quantization", value=False) st.markdown( "
INT8 quantization further reduces model size and inference time, ideal for edge devices, at the cost of a slight decrease in accuracy.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Dynamic axes for ONNX/TensorRT st.write("**Dynamic Axes for ONNX/TensorRT**") dynamic = st.checkbox("Enable Dynamic Axes", value=False) st.markdown( "
Dynamic axes allow the ONNX/TensorRT models to handle variable input sizes, increasing the model's flexibility in deployment.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Simplify model for ONNX/TensorRT st.write("**Simplify Model for ONNX/TensorRT**") simplify = st.checkbox("Enable Model Simplification", value=False) st.markdown( "
Simplification optimizes the ONNX/TensorRT models by removing redundant operations, improving efficiency without impacting accuracy.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # ONNX Opset Version Configuration st.write("**ONNX Opset Version Configuration**") col1_opset, col2_opset = st.columns([1, 3]) with col1_opset: top_padding_opset = st.container() opset_allow = st.checkbox("Specify Opset Version", value=False) if opset_allow: with top_padding_opset: utils.top_padding(2) # Create a range of opset versions for the dropdown opset_versions = list(range(1, onnx_opset_version() + 1)) with col2_opset: opset = st.selectbox( "Select Opset Version", opset_versions, index=len(opset_versions) - 1, ) else: opset = None st.markdown( "
Select the ONNX opset version for the export. " "Specifying an opset version can ensure compatibility with specific ONNX versions. " "The latest version is recommended to ensure the most up-to-date features and optimizations. " "If unsure, leave the checkbox unchecked to use the default opset version.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # TensorRT workspace size st.write("**TensorRT Workspace Size (GB)**") workspace = st.number_input( "Workspace Size", min_value=1, max_value=32, value=4, step=1 ) st.markdown( "
Set the TensorRT workspace size in GB. A larger workspace can lead to more optimized models but requires more memory.
", unsafe_allow_html=True, ) # Spacer st.markdown("---") # Add NMS for CoreML st.write("**Add NMS for CoreML**") nms = st.checkbox("Enable NMS", value=False) st.markdown( "
Enabling NMS (Non-Maximum Suppression) for CoreML models helps in reducing overlapping bounding boxes and improves the clarity of object detection results.
", unsafe_allow_html=True, ) # Padding utils.top_padding(2) if selected_training == "Object Detection": model_path = os.path.join( get_path("models"), selected_model.lower() + ".pt" ) task = "detect" elif selected_training == "Instance Segmentation": model_path = os.path.join( get_path("models"), selected_model.lower() + "-seg.pt" ) task = "segment" export_settings = { "format": None if export_format == "Only PyTorch" else export_format, "keras": keras, "optimize": optimize, "half": half, "int8": int8, "dynamic": dynamic, "simplify": simplify, "opset": opset, "workspace": workspace, "nms": nms, } return { "model_path": model_path, "task": task, "model": selected_model, "time": time, "epochs": epochs, "patience": patience, "batch": batch, "imgsz": imgsz, "cache": cache, "optimizer": optimizer, "amp": amp, "deterministic": deterministic, "rect": rect, "cos_lr": cos_lr, "freeze": freeze, "lr0": lr0, "lrf": lrf, "momentum": momentum, "weight_decay": weight_decay, "warmup_epochs": warmup_epochs, "warmup_momentum": warmup_momentum, "warmup_bias_lr": warmup_bias_lr, "box": box, "cls": cls, "dfl": dfl, "label_smoothing": label_smoothing, "nbs": nbs, "overlap_mask": overlap_mask, "mask_ratio": mask_ratio, "dropout": dropout, "val": val, "plots": plots, "conf": conf, "iou": iou, "max_det": max_det, "half": half, "export_settings": export_settings, } # Function to generate python code for model training def generate_python_code_model_training(training_configuration): # Copy the original configuration and update with additional parameters training_configuration_code = training_configuration.copy() training_configuration_code["data"] = r".\config.yaml" # Path to config file training_configuration_code["save_dir"] = r".\output\train" # Output directory training_configuration_code["pretrained"] = True # Use a pretrained model training_configuration_code["save"] = True # Save the trained model training_configuration_code["save_period"] = -1 # Save period configuration 
training_configuration_code["augment"] = False # Augmentation setting training_configuration_code["seed"] = 0 # Seed for reproducibility training_configuration_code["verbose"] = True # Verbose output training_configuration_code["single_cls"] = False # Single class setting training_configuration_code["resume"] = False # Resume training setting training_configuration_code["exist_ok"] = True # Overwrite existing files training_configuration_code["project"] = r".\output" # Project directory training_configuration_code["name"] = "train" # Project name # Extract the model name from the model path model_name = training_configuration_code["model_path"].split("\\")[-1] # Start with necessary library imports and model initialization code_str = "# Importing necessary libraries\n" code_str += "from ultralytics import YOLO\n\n" # Initialize the YOLO model code_str += f"# Initialize the YOLO model '{model_name}'\n" code_str += f"model = YOLO('{model_name}')\n" # Add the model training code code_str += "\n# Start the training process\n" code_str += "model.train(\n" for key, value in training_configuration_code.items(): if key not in [ "model_path", "model", "export_settings", ]: # Exclude specific keys code_str += f" {key}={value},\n" code_str = code_str.rstrip(",\n") + "\n)\n" # Add model export code code_str += "\n# Model export process\n" code_str += "model.export(\n" for key, value in training_configuration_code["export_settings"].items(): if key == "format" and value is None: continue # Skip format if it's None code_str += f" {key}={value},\n" code_str = code_str.rstrip(",\n") + "\n)\n" return code_str # Function to overwrites a Python file with new code def overwrite_python_file(code_str, file_path): # Open the file in write mode, which automatically deletes old content with open(file_path, "w") as file: file.write(code_str) # Function to generate a downloadable file def display_code_and_download_button(generated_code): # Display the generated code in Streamlit with 
description and download button in columns with st.expander("Plug and Play Code"): col1, col2 = st.columns([7, 3]) with col1: st.markdown( """ ### Description of the Code Pipeline """ ) st.markdown( """
This Python script is configured for training a YOLO model. It includes necessary configurations and parameters for a custom YOLO model training session. **To use this script:** - Ensure you have the necessary dependencies installed. - Place your image and label files in the `'datasets/train'`, `'datasets/val'`, and `'datasets/test'` folders respectively. - The `'config.yaml'` file and the training script are set up based on your provided configurations. ### Python Code
""", unsafe_allow_html=True, ) # Display python code st.code(generated_code, language="python") # Determine the main directory path main_directory_path = os.path.dirname( os.path.dirname(os.path.abspath(__file__)) ) # Overwrites a Python file with new code overwrite_python_file( generated_code, os.path.join( main_directory_path, "model_data", "model_training_code_pipline", "model_training.py", ), ) # Determine the main directory path main_directory_path = os.path.dirname( os.path.dirname(os.path.abspath(__file__)) ) # Prepare a ZIP file of the training output folder in memory for download zip_bytes_io = zip_folder_to_bytesio( os.path.join( main_directory_path, "model_data", "model_training_code_pipline" ) ) with col2: # Create a button for downloading the training pipeline st.download_button( label="Download Training Pipeline", data=zip_bytes_io, file_name="model_training_code.zip", mime="application/zip", use_container_width=True, ) # Function to generates a YOLO model training code snippet and displays it with a download button def generate_and_display_yolo_training_code(class_labels, training_configuration): # Determine the main directory path main_directory_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # Construct the path to the config file directory config_file_path = os.path.join( main_directory_path, "model_data", "model_training_code_pipline" ) # Define the path to the dataset directory dataset_directory_path = "./datasets" # Create YOLO config file using provided class labels and dataset directory create_yolo_config_file(config_file_path, class_labels, dataset_directory_path) # Generate the Python code for YOLO model training generated_code = generate_python_code_model_training(training_configuration) # Display the generated code and a download button display_code_and_download_button(generated_code) # Function to create a yolo config file def create_yolo_config_file( config_file_path, class_labels, dataset_directory_path=None ): if 
dataset_directory_path is None: dataset_directory_path = os.path.join(config_file_path, "datasets") # Number of classes num_classes = len(class_labels) # Create the configuration content config_content = f"""path: {dataset_directory_path} # Path to the dataset directory train: train # Path to the training set directory val: val # Path to the validation set directory test: test # Path to the testing set directory nc: {num_classes} # Number of classes names: {class_labels} # List of class names """ # Write the configuration to a file with open(os.path.join(config_file_path, "config.yaml"), "w") as file: file.write(config_content) # Function to delete and recreate a folder def delete_and_recreate_folder(folder_path): try: # Use shutil.rmtree to delete the folder and its contents shutil.rmtree(folder_path) # Recreate the folder at the same location os.makedirs(folder_path) except Exception as e: print(f"Error deleting or recreating folder {folder_path}: {e}") # Function to read csv and get values def read_csv_and_get_values(csv_file_path): # Read the CSV file into a pandas DataFrame df = pd.read_csv(csv_file_path) # Initialize an empty dictionary to store the results result_dict = {} # Iterate through the columns of the DataFrame for column in df.columns: # Remove leading and trailing spaces from the column name clean_column_name = column.strip() # Get the values in the column column_values = df[column].astype(float) # Add the cleaned column name and values to the result dictionary result_dict[clean_column_name] = np.array(column_values) return result_dict # Global variables plot_container = None val_dataframe_container = None progress_bar = None progress_text = None # Function to define a custom callback function for on_pretrain_routine_start def on_pretrain_routine_start(trainer): global progress_text, progress_bar progress_bar = st.empty() progress_text = st.empty() progress_text.info( "Loading selected model...", icon="✅", ) # Function to define a custom callback 
# function for on_train_start
def on_train_start(trainer):
    """Swap the loading placeholder for a real progress bar when training begins.

    Registered by ``callback_add`` as the Ultralytics ``on_train_start``
    callback; ``trainer`` is supplied by Ultralytics and is not used here.
    """
    global progress_bar, progress_text
    progress_bar = st.progress(0)
    progress_text.info(
        "Training Started...",
        icon="✅",
    )


# Function to display metrics plot.
# NOTE: a bare ``st.cache_resource(show_spinner=False)`` statement (missing
# ``@``) used to sit here; it created a decorator and threw it away. It is
# removed rather than activated: on a cache hit the body would not run, so the
# module-level ``plot_container`` and the chart would not be updated.
def display_metrics_plot(output_data):
    """Render (or re-render) the training-metrics grid in a single placeholder.

    Args:
        output_data: Mapping of results.csv column name -> array of values, as
            built by ``read_csv_and_get_values``. Box "(B)" metrics and the
            train/val box/cls/dfl losses are always plotted. A mask row
            ("(M)" metrics, titled "... R") is added only when
            ``metrics/precision(M)`` is present. Missing seg losses are drawn
            as a zero line.
    """
    global plot_container

    epoch_history = output_data.get("epoch")

    train_seg_loss_history = output_data.get("train/seg_loss")
    val_seg_loss_history = output_data.get("val/seg_loss")
    if train_seg_loss_history is None:
        # Detection runs report no seg loss; plot flat zeros so the grid keeps 4 columns.
        train_seg_loss_history = epoch_history * 0
        val_seg_loss_history = epoch_history * 0

    # Each row is a list of (y-values, trace name); trace names double as subplot titles.
    rows = [
        [
            (output_data.get("metrics/precision(B)"), "Precision B"),
            (output_data.get("metrics/recall(B)"), "Recall B"),
            (output_data.get("metrics/mAP50(B)"), "mAP50 B"),
            (output_data.get("metrics/mAP50-95(B)"), "mAP50-95 B"),
        ]
    ]
    precision_M_history = output_data.get("metrics/precision(M)")
    if precision_M_history is not None:
        rows.append(
            [
                (precision_M_history, "Precision R"),
                (output_data.get("metrics/recall(M)"), "Recall R"),
                (output_data.get("metrics/mAP50(M)"), "mAP50 R"),
                (output_data.get("metrics/mAP50-95(M)"), "mAP50-95 R"),
            ]
        )
    rows.append(
        [
            (output_data.get("train/box_loss"), "Train Box Loss"),
            (output_data.get("train/cls_loss"), "Train Class Loss"),
            (output_data.get("train/dfl_loss"), "Train DFL Loss"),
            (train_seg_loss_history, "Train Seg Loss"),
        ]
    )
    rows.append(
        [
            (output_data.get("val/box_loss"), "Val Box Loss"),
            (output_data.get("val/cls_loss"), "Val Class Loss"),
            (output_data.get("val/dfl_loss"), "Val DFL Loss"),
            (val_seg_loss_history, "Val Seg Loss"),
        ]
    )

    fig = make_subplots(
        rows=len(rows),
        cols=4,
        subplot_titles=[name for row in rows for _, name in row],
        vertical_spacing=0.05,
    )
    for row_number, row in enumerate(rows, start=1):
        for col_number, (y_values, name) in enumerate(row, start=1):
            fig.add_trace(
                go.Scatter(x=epoch_history, y=y_values, mode="lines", name=name),
                row=row_number,
                col=col_number,
            )

    # Reuse one placeholder so each epoch replaces the previous chart.
    if plot_container is None:
        plot_container = st.empty()

    fig.update_layout(
        height=1200,
        width=1600,
        title_text="Metrics",
        legend=dict(orientation="h", yanchor="bottom", xanchor="left"),
    )
    plot_container.plotly_chart(fig, use_container_width=True)


# Function to define a custom callback function for on_fit_epoch_end
def on_fit_epoch_end(trainer):
    """Refresh the metrics plot and the progress bar after each epoch.

    Parses ``<output>/train/results.csv``, stores the result in
    ``st.session_state["plot_data"]`` so it can be redrawn on later reruns,
    then advances the progress bar and epoch counter.
    """
    current_epoch = int(trainer.epoch)
    total_epochs = int(trainer.epochs)

    output_csv_path = os.path.join(get_path("output"), "train", "results.csv")
    st.session_state["plot_data"] = read_csv_and_get_values(output_csv_path)
    display_metrics_plot(st.session_state["plot_data"])

    completed_epochs = current_epoch + 1
    # st.progress rejects floats above 1.0. Clamp in case the trainer's epoch
    # count drops below the current epoch (presumably when a `time` budget
    # makes Ultralytics recompute it) or is 0.
    fraction = min(completed_epochs / total_epochs, 1.0) if total_epochs > 0 else 1.0
    progress_bar.progress(fraction)
    progress_text.write(f"Epoch {completed_epochs}/{total_epochs}")


# Function to define a custom callback function for on_train_end
def on_train_end(trainer):
    """Remove the progress bar and report that the model weights were saved."""
    global progress_bar, progress_text
    progress_bar.empty()
    progress_text.info(
        "Best and last model save completed successfully.",
        icon="✅",
    )


# Function to add various callbacks to the YOLO model for different stages of the training process
def callback_add(model):
    """Register this module's Streamlit progress callbacks on ``model``."""
    model.add_callback("on_pretrain_routine_start", on_pretrain_routine_start)
    model.add_callback("on_train_start", on_train_start)
    model.add_callback("on_fit_epoch_end", on_fit_epoch_end)
    model.add_callback("on_train_end", on_train_end)


# Function to zip a folder and all its subfolders and return a BytesIO object
def zip_folder_to_bytesio(folder_path):
    """Zip ``folder_path`` recursively into an in-memory, DEFLATE-compressed archive.

    Archive member names are relative to ``folder_path``. Empty
    sub-directories are kept as ``name/`` entries. An empty or missing folder
    yields an empty archive (``os.walk`` yields nothing for a missing path).

    Returns:
        io.BytesIO rewound to offset 0, ready to hand to a download widget.
    """
    bytes_io = io.BytesIO()
    folder_path_abs = os.path.abspath(folder_path)
    with zipfile.ZipFile(bytes_io, "w", zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path_abs):
            folder_rel_path = os.path.relpath(root, folder_path_abs)
            at_root = folder_rel_path == os.curdir

            # Empty directories need an explicit entry with a trailing slash.
            # The root itself is skipped; an empty arcname would become "./".
            if not dirs and not files and not at_root:
                zipf.write(root, f"{folder_rel_path}/")

            for file_name in files:
                arcname = file_name if at_root else os.path.join(folder_rel_path, file_name)
                zipf.write(os.path.join(root, file_name), arcname)

    bytes_io.seek(0)
    return bytes_io


# Function to display Metrics Table.
# NOTE: the bare no-op ``st.cache_resource(show_spinner=False)`` statement that
# preceded this function was removed for the same reason as above.
def display_val_dataframe(val_dataframe):
    """Show the validation metrics table inside a single reusable container."""
    global val_dataframe_container
    if val_dataframe_container is None:
        val_dataframe_container = st.container()
    with val_dataframe_container:
        st.markdown("**Metrics Table**", unsafe_allow_html=True)
        st.dataframe(val_dataframe)


def _metric_columns(results_dict, metric, class_index, suffix):
    """Build the Precision/Recall/mAP50-95 columns for one metric family ("B" or "M").

    Each column starts with the aggregate "All" value from ``results_dict``,
    followed by one value per class in ``class_index`` order.
    """
    return {
        f"Precision ({suffix})": [results_dict.get(f"metrics/precision({suffix})"), *metric.p],
        f"Recall ({suffix})": [results_dict.get(f"metrics/recall({suffix})"), *metric.r],
        f"mAP50-95 ({suffix})": [
            results_dict.get(f"metrics/mAP50-95({suffix})"),
            *(metric.maps[i] for i in class_index),
        ],
    }


# Function to display the DataFrame
def val_dataframe(model):
    """Run ``model.val()`` and show per-class box (and, if present, mask) metrics.

    The table is also stored in ``st.session_state["val_dataframe"]`` so it
    can be redrawn on later reruns without re-validating.
    """
    message = st.empty()
    message.markdown("**Generating Metrics Table...**", unsafe_allow_html=True)

    metrics = model.val()
    class_index = metrics.ap_class_index
    class_names = metrics.names
    results_dict = metrics.results_dict

    columns = {"Class Name": ["All"] + [str(class_names[i]) for i in class_index]}
    columns.update(_metric_columns(results_dict, metrics.box, class_index, "B"))

    # Detection metrics have no ``seg`` attribute. Probe for exactly that case
    # instead of swallowing every exception with a bare ``except``.
    metrics_mask = getattr(metrics, "seg", None)
    if metrics_mask is not None:
        columns.update(_metric_columns(results_dict, metrics_mask, class_index, "M"))

    st.session_state["val_dataframe"] = pd.DataFrame(columns)

    message.empty()
    display_val_dataframe(st.session_state["val_dataframe"])


# Function to train the YOLO model
def train_yolo_model(training_configuration):
    """Train a YOLO model from ``training_configuration`` and return it.

    The output folder is wiped first. Results go to ``<output>/train`` (which
    ``on_fit_epoch_end`` reads). Streamlit progress callbacks are attached.
    """
    delete_and_recreate_folder(get_path("output"))

    model = YOLO(training_configuration["model_path"])
    callback_add(model)

    model.train(
        task=training_configuration["task"],
        data=os.path.join(get_path("config"), "config.yaml"),
        epochs=training_configuration["epochs"],
        time=training_configuration["time"],
        patience=training_configuration["patience"],
        batch=training_configuration["batch"],
        imgsz=training_configuration["imgsz"],
        save=True,
        save_period=-1,
        cache=training_configuration["cache"],
        pretrained=True,
        optimizer=training_configuration["optimizer"],
        verbose=True,
        seed=0,
        deterministic=training_configuration["deterministic"],
        single_cls=False,
        rect=training_configuration["rect"],
        cos_lr=training_configuration["cos_lr"],
        resume=False,
        amp=training_configuration["amp"],
        fraction=1.0,
        freeze=training_configuration["freeze"],
        lr0=training_configuration["lr0"],
        lrf=training_configuration["lrf"],
        momentum=training_configuration["momentum"],
        weight_decay=training_configuration["weight_decay"],
        warmup_epochs=training_configuration["warmup_epochs"],
        warmup_momentum=training_configuration["warmup_momentum"],
        warmup_bias_lr=training_configuration["warmup_bias_lr"],
        box=training_configuration["box"],
        cls=training_configuration["cls"],
        dfl=training_configuration["dfl"],
        label_smoothing=training_configuration["label_smoothing"],
        nbs=training_configuration["nbs"],
        overlap_mask=training_configuration["overlap_mask"],
        mask_ratio=training_configuration["mask_ratio"],
        dropout=training_configuration["dropout"],
        val=training_configuration["val"],
        plots=training_configuration["plots"],
        save_dir=os.path.join(get_path("output"), "train"),
        project=get_path("output"),
        name="train",
        augment=False,
        exist_ok=True,
    )
    return model


# Function to export the model with the given parameters
def export_model_with_parameters(model, export_params):
    """Export ``model`` with ``export_params``; do nothing when no format is set.

    ``export_params["format"]`` is ``None`` when "Only PyTorch" was chosen, in
    which case the trained ``.pt`` weights are the only artefact.
    """
    global progress_text
    if export_params["format"] is None:
        return

    progress_text.info(
        "Starting the export process with the specified settings.",
        icon="✅",
    )
    model.export(
        format=export_params["format"],
        keras=export_params["keras"],
        optimize=export_params["optimize"],
        half=export_params["half"],
        int8=export_params["int8"],
        dynamic=export_params["dynamic"],
        simplify=export_params["simplify"],
        opset=export_params["opset"],
        workspace=export_params["workspace"],
        nms=export_params["nms"],
    )
    progress_text.info(
        "The model has been successfully saved using the specified export settings.",
        icon="✅",
    )


# Function to start the YOLO model training process
def start_yolo_training(selected_training, class_labels):
    """Render the training page: code preview, Start Training, results and download.

    On "Start Training", trains, exports and validates the model. Otherwise,
    if a previous run's plot data and table are in session state, they are
    redrawn. The download button zips ``<output>/train``.
    """
    global plot_container, val_dataframe_container

    training_configuration = get_training_validation_export_configuration(
        selected_training
    )
    generate_and_display_yolo_training_code(class_labels, training_configuration)

    col1, col2 = st.columns(2)

    if col1.button("Start Training", use_container_width=True):
        # Fresh placeholders so this run's widgets are created anew.
        plot_container = None
        val_dataframe_container = None
        with st.spinner("Training in Progress..."):
            trained_model = train_yolo_model(training_configuration)
            export_model_with_parameters(
                trained_model, training_configuration["export_settings"]
            )
            val_dataframe(trained_model)
    elif "plot_data" in st.session_state and "val_dataframe" in st.session_state:
        plot_container = None
        val_dataframe_container = None
        display_metrics_plot(st.session_state["plot_data"])
        display_val_dataframe(st.session_state["val_dataframe"])

    # NOTE(review): indentation was lost in the source; the download button is
    # assumed to render on every rerun (function level) — confirm.
    zip_bytes_io = zip_folder_to_bytesio(os.path.join(get_path("output"), "train"))
    col2.download_button(
        label="Download",
        data=zip_bytes_io,
        file_name="model_training_output.zip",
        mime="application/zip",
        use_container_width=True,
    )