""" Main file to launch training and testing experiments. """ import yaml import os import argparse import numpy as np import torch from .config.project_config import Config as cfg from .train import train_net from .export import export_predictions, export_homograpy_adaptation # Pytorch configurations torch.cuda.empty_cache() torch.backends.cudnn.benchmark = True def load_config(config_path): """ Load configurations from a given yaml file. """ # Check file exists if not os.path.exists(config_path): raise ValueError("[Error] The provided config path is not valid.") # Load the configuration with open(config_path, "r") as f: config = yaml.safe_load(f) return config def update_config(path, model_cfg=None, dataset_cfg=None): """ Update configuration file from the resume path. """ # Check we need to update or completely override. model_cfg = {} if model_cfg is None else model_cfg dataset_cfg = {} if dataset_cfg is None else dataset_cfg # Load saved configs with open(os.path.join(path, "model_cfg.yaml"), "r") as f: model_cfg_saved = yaml.safe_load(f) model_cfg.update(model_cfg_saved) with open(os.path.join(path, "dataset_cfg.yaml"), "r") as f: dataset_cfg_saved = yaml.safe_load(f) dataset_cfg.update(dataset_cfg_saved) # Update the saved yaml file if not model_cfg == model_cfg_saved: with open(os.path.join(path, "model_cfg.yaml"), "w") as f: yaml.dump(model_cfg, f) if not dataset_cfg == dataset_cfg_saved: with open(os.path.join(path, "dataset_cfg.yaml"), "w") as f: yaml.dump(dataset_cfg, f) return model_cfg, dataset_cfg def record_config(model_cfg, dataset_cfg, output_path): """ Record dataset config to the log path. """ # Record model config with open(os.path.join(output_path, "model_cfg.yaml"), "w") as f: yaml.safe_dump(model_cfg, f) # Record dataset config with open(os.path.join(output_path, "dataset_cfg.yaml"), "w") as f: yaml.safe_dump(dataset_cfg, f) def train(args, dataset_cfg, model_cfg, output_path): """ Training function. """ # Update model config from the resume path (only in resume mode) if args.resume: if os.path.realpath(output_path) != os.path.realpath(args.resume_path): record_config(model_cfg, dataset_cfg, output_path) # First time, then write the config file to the output path else: record_config(model_cfg, dataset_cfg, output_path) # Launch the training train_net(args, dataset_cfg, model_cfg, output_path) def export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode=None, device=torch.device("cuda")): """ Export function. """ # Choose between normal predictions export or homography adaptation if dataset_cfg.get("homography_adaptation") is not None: print("[Info] Export predictions with homography adaptation.") export_homograpy_adaptation(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device) else: print("[Info] Export predictions normally.") export_predictions(args, dataset_cfg, model_cfg, output_path, export_dataset_mode) def main(args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda")): """ Main function. """ # Make the output path output_path = os.path.join(cfg.EXP_PATH, args.exp_name) if args.mode == "train": if not os.path.exists(output_path): os.makedirs(output_path) print("[Info] Training mode") print("\t Output path: %s" % output_path) train(args, dataset_cfg, model_cfg, output_path) elif args.mode == "export": # Different output_path in export mode output_path = os.path.join(cfg.export_dataroot, args.exp_name) print("[Info] Export mode") print("\t Output path: %s" % output_path) export(args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device=device) else: raise ValueError("[Error]: Unknown mode: " + args.mode) def set_random_seed(seed): np.random.seed(seed) torch.manual_seed(seed) if __name__ == "__main__": # Parse input arguments parser = argparse.ArgumentParser() parser.add_argument("--mode", type=str, default="train", help="'train' or 'export'.") parser.add_argument("--dataset_config", type=str, default=None, help="Path to the dataset config.") parser.add_argument("--model_config", type=str, default=None, help="Path to the model config.") parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") parser.add_argument("--resume", action="store_true", default=False, help="Load a previously trained model.") parser.add_argument("--pretrained", action="store_true", default=False, help="Start training from a pre-trained model.") parser.add_argument("--resume_path", default=None, help="Path from which to resume training.") parser.add_argument("--pretrained_path", default=None, help="Path to the pre-trained model.") parser.add_argument("--checkpoint_name", default=None, help="Name of the checkpoint to use.") parser.add_argument("--export_dataset_mode", default=None, help="'train' or 'test'.") parser.add_argument("--export_batch_size", default=4, type=int, help="Export batch size.") args = parser.parse_args() # Check if GPU is available # Get the model if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") # Check if dataset config and model config is given. if (((args.dataset_config is None) or (args.model_config is None)) and (not args.resume) and (args.mode == "train")): raise ValueError( "[Error] The dataset config and model config should be given in non-resume mode") # If resume, check if the resume path has been given if args.resume and (args.resume_path is None): raise ValueError( "[Error] Missing resume path.") # [Training] Load the config file. if args.mode == "train" and (not args.resume): # Check the pretrained checkpoint_path exists if args.pretrained: checkpoint_folder = args.resume_path checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name) if not os.path.exists(checkpoint_path): raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) dataset_cfg = load_config(args.dataset_config) model_cfg = load_config(args.model_config) # [resume Training, Test, Export] Load the config file. elif (args.mode == "train" and args.resume) or (args.mode == "export"): # Check checkpoint path exists checkpoint_folder = args.resume_path checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name) if not os.path.exists(checkpoint_path): raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) # Load model_cfg from checkpoint folder if not provided if args.model_config is None: print("[Info] No model config provided. Loading from checkpoint folder.") model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml") if not os.path.exists(model_cfg_path): raise ValueError( "[Error] Missing model config in checkpoint path.") model_cfg = load_config(model_cfg_path) else: model_cfg = load_config(args.model_config) # Load dataset_cfg from checkpoint folder if not provided if args.dataset_config is None: print("[Info] No dataset config provided. Loading from checkpoint folder.") dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml") if not os.path.exists(dataset_cfg_path): raise ValueError( "[Error] Missing dataset config in checkpoint path.") dataset_cfg = load_config(dataset_cfg_path) else: dataset_cfg = load_config(args.dataset_config) # Check the --export_dataset_mode flag if (args.mode == "export") and (args.export_dataset_mode is None): raise ValueError("[Error] Empty --export_dataset_mode flag.") else: raise ValueError("[Error] Unknown mode: " + args.mode) # Set the random seed seed = dataset_cfg.get("random_seed", 0) set_random_seed(seed) main(args, dataset_cfg, model_cfg, export_dataset_mode=args.export_dataset_mode, device=device)