""" Main file to launch training and testing experiments. """ import yaml import os import argparse import numpy as np import torch from .config.project_config import Config as cfg from .train import train_net from .export import export_predictions, export_homograpy_adaptation # Pytorch configurations torch.cuda.empty_cache() torch.backends.cudnn.benchmark = True def load_config(config_path): """Load configurations from a given yaml file.""" # Check file exists if not os.path.exists(config_path): raise ValueError("[Error] The provided config path is not valid.") # Load the configuration with open(config_path, "r") as f: config = yaml.safe_load(f) return config def update_config(path, model_cfg=None, dataset_cfg=None): """Update configuration file from the resume path.""" # Check we need to update or completely override. model_cfg = {} if model_cfg is None else model_cfg dataset_cfg = {} if dataset_cfg is None else dataset_cfg # Load saved configs with open(os.path.join(path, "model_cfg.yaml"), "r") as f: model_cfg_saved = yaml.safe_load(f) model_cfg.update(model_cfg_saved) with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f: dataset_cfg_saved = yaml.safe_load(f) dataset_cfg.update(dataset_cfg_saved) # Update the saved yaml file if not model_cfg == model_cfg_saved: with open(os.path.join(path, "model_cfg.yaml"), "w") as f: yaml.dump(model_cfg, f) if not dataset_cfg == dataset_cfg_saved: with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f: yaml.dump(dataset_cfg, f) return model_cfg, dataset_cfg def record_config(model_cfg, dataset_cfg, output_path): """Record dataset config to the log path.""" # Record model config with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f: yaml.safe_dump(model_cfg, f) # Record dataset config with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f: yaml.safe_dump(dataset_cfg, f) def train(args, dataset_cfg, model_cfg, output_path): """Training function.""" # Update model config from the resume path (only in resume mode) if args.resume: if os.path.realpath(output_path) != os.path.realpath(args.resume_path): record_config(model_cfg, dataset_cfg, output_path) # First time, then write the config file to the output path else: record_config(model_cfg, dataset_cfg, output_path) # Launch the training train_net(args, dataset_cfg, model_cfg, output_path) def export( args, dataset_cfg, model_cfg, output_path, export_dataset_mode=None, device=torch.device("cuda"), ): """Export function.""" # Choose between normal predictions export or homography adaptation if dataset_cfg.get("homography_adaptation") is not None: print("[Info] Export predictions with homography adaptation.") export_homograpy_adaptation( args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device ) else: print("[Info] Export predictions normally.") export_predictions( args, dataset_cfg, model_cfg, output_path, export_dataset_mode ) def main( args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda") ): """Main function.""" # Make the output path output_path = os.path.join(cfg.EXP_PATH, args.exp_name) if args.mode == "train": if not os.path.exists(output_path): os.makedirs(output_path) print("[Info] Training mode") print("\t Output path: %s" % output_path) train(args, dataset_cfg, model_cfg, output_path) elif args.mode == "export": # Different output_path in export mode output_path = os.path.join(cfg.export_dataroot, args.exp_name) print("[Info] Export mode") print("\t Output path: %s" % output_path) export( args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device, ) else: raise ValueError("[Error]: Unknown mode: " + args.mode) def set_random_seed(seed): np.random.seed(seed) torch.manual_seed(seed) if __name__ == "__main__": # Parse input arguments parser = argparse.ArgumentParser() parser.add_argument( "--mode", type=str, default="train", help="'train' or 'export'." ) parser.add_argument( "--dataset_config", type=str, default=None, help="Path to the dataset config." ) parser.add_argument( "--model_config", type=str, default=None, help="Path to the model config." ) parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") parser.add_argument( "--resume", action="store_true", default=False, help="Load a previously trained model.", ) parser.add_argument( "--pretrained", action="store_true", default=False, help="Start training from a pre-trained model.", ) parser.add_argument( "--resume_path", default=None, help="Path from which to resume training." ) parser.add_argument( "--pretrained_path", default=None, help="Path to the pre-trained model." ) parser.add_argument( "--checkpoint_name", default=None, help="Name of the checkpoint to use." ) parser.add_argument( "--export_dataset_mode", default=None, help="'train' or 'test'." ) parser.add_argument( "--export_batch_size", default=4, type=int, help="Export batch size." ) args = parser.parse_args() # Check if GPU is available # Get the model if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") # Check if dataset config and model config is given. if ( ((args.dataset_config is None) or (args.model_config is None)) and (not args.resume) and (args.mode == "train") ): raise ValueError( "[Error] The dataset config and model config should be given in non-resume mode" ) # If resume, check if the resume path has been given if args.resume and (args.resume_path is None): raise ValueError("[Error] Missing resume path.") # [Training] Load the config file. if args.mode == "train" and (not args.resume): # Check the pretrained checkpoint_path exists if args.pretrained: checkpoint_folder = args.resume_path checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name) if not os.path.exists(checkpoint_path): raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) dataset_cfg = load_config(args.dataset_config) model_cfg = load_config(args.model_config) # [resume Training, Test, Export] Load the config file. elif (args.mode == "train" and args.resume) or (args.mode == "export"): # Check checkpoint path exists checkpoint_folder = args.resume_path checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name) if not os.path.exists(checkpoint_path): raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) # Load model_cfg from checkpoint folder if not provided if args.model_config is None: print("[Info] No model config provided. Loading from checkpoint folder.") model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml") if not os.path.exists(model_cfg_path): raise ValueError("[Error] Missing model config in checkpoint path.") model_cfg = load_config(model_cfg_path) else: model_cfg = load_config(args.model_config) # Load dataset_cfg from checkpoint folder if not provided if args.dataset_config is None: print("[Info] No dataset config provided. Loading from checkpoint folder.") dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml") if not os.path.exists(dataset_cfg_path): raise ValueError("[Error] Missing dataset config in checkpoint path.") dataset_cfg = load_config(dataset_cfg_path) else: dataset_cfg = load_config(args.dataset_config) # Check the --export_dataset_mode flag if (args.mode == "export") and (args.export_dataset_mode is None): raise ValueError("[Error] Empty --export_dataset_mode flag.") else: raise ValueError("[Error] Unknown mode: " + args.mode) # Set the random seed seed = dataset_cfg.get("random_seed", 0) set_random_seed(seed) main( args, dataset_cfg, model_cfg, export_dataset_mode=args.export_dataset_mode, device=device, )