Spaces:
Sleeping
Sleeping
""" | |
Main file to launch training and testing experiments. | |
""" | |
import argparse
import os
import random

import numpy as np
import torch
import yaml

from .config.project_config import Config as cfg
from .train import train_net
from .export import export_predictions, export_homograpy_adaptation
# PyTorch runtime configuration, applied once when this module is imported.
# Enable cuDNN autotuning of convolution algorithms.
torch.backends.cudnn.benchmark = True
# Release cached, unused CUDA memory before any work starts.
torch.cuda.empty_cache()
def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path to the YAML file on disk.

    Returns:
        The object produced by ``yaml.safe_load`` for the file; callers in
        this script treat it as a dict.

    Raises:
        ValueError: If ``config_path`` does not exist.
    """
    path_is_valid = os.path.exists(config_path)
    if path_is_valid:
        with open(config_path, "r") as handle:
            return yaml.safe_load(handle)
    raise ValueError("[Error] The provided config path is not valid.")
def update_config(path, model_cfg=None, dataset_cfg=None):
    """Merge configs with the ones saved in ``path`` and persist the result.

    The YAML files ``model_cfg.yaml`` / ``dataset_cfg.yaml`` in ``path`` take
    precedence over the given dicts:
    - keys present in both keep the saved value;
    - keys present only in the given dict are added.
    A file is rewritten only when the merged dict differs from what was saved.

    Args:
        path: Directory containing ``model_cfg.yaml`` and ``dataset_cfg.yaml``.
        model_cfg: Optional dict of model settings to merge in. It is not
            mutated.
        dataset_cfg: Optional dict of dataset settings to merge in. It is not
            mutated.

    Returns:
        Tuple ``(model_cfg, dataset_cfg)`` holding the merged dicts.
    """
    merged_model_cfg = _merge_saved_config(path, "model_cfg.yaml", model_cfg)
    merged_dataset_cfg = _merge_saved_config(path, "dataset_cfg.yaml", dataset_cfg)
    return merged_model_cfg, merged_dataset_cfg


def _merge_saved_config(path, filename, cfg):
    """Overlay the YAML saved at ``path/filename`` onto ``cfg``; rewrite it if changed."""
    cfg_file = os.path.join(path, filename)
    with open(cfg_file, "r") as f:
        # An empty YAML file loads as None; treat it as an empty config.
        saved_cfg = yaml.safe_load(f) or {}
    base_cfg = {} if cfg is None else cfg
    # Build a new dict (saved values win) so the caller's dict is not mutated.
    merged_cfg = {**base_cfg, **saved_cfg}
    if merged_cfg != saved_cfg:
        with open(cfg_file, "w") as f:
            # safe_dump (as in record_config) keeps the file loadable by safe_load.
            yaml.safe_dump(merged_cfg, f)
    return merged_cfg
def record_config(model_cfg, dataset_cfg, output_path):
    """Save the model and dataset configs as YAML files inside ``output_path``.

    Writes ``model_cfg.yaml`` and then ``dataset_cfg.yaml``, overwriting any
    existing files with those names.
    """
    targets = (
        ("model_cfg.yaml", model_cfg),
        ("dataset_cfg.yaml", dataset_cfg),
    )
    for file_name, config in targets:
        with open(os.path.join(output_path, file_name), "w") as stream:
            yaml.safe_dump(config, stream)
def train(args, dataset_cfg, model_cfg, output_path):
    """Record the configs when needed, then launch training.

    Configs are written to ``output_path`` in two cases:
    - a fresh run;
    - a resumed run whose output directory resolves to a different location
      than ``args.resume_path``.
    When resuming in place, the configs already stored there are left as-is.
    """
    is_fresh_run = not args.resume
    # Short-circuit keeps realpath() from being called on a None resume_path.
    if is_fresh_run or os.path.realpath(output_path) != os.path.realpath(
        args.resume_path
    ):
        record_config(model_cfg, dataset_cfg, output_path)
    train_net(args, dataset_cfg, model_cfg, output_path)
def export(
    args,
    dataset_cfg,
    model_cfg,
    output_path,
    export_dataset_mode=None,
    device=torch.device("cuda"),
):
    """Export model predictions, optionally using homography adaptation.

    If ``dataset_cfg`` holds a non-None ``"homography_adaptation"`` entry, the
    homography-adaptation exporter is called and receives ``device``.
    Otherwise the plain prediction exporter is called.
    """
    use_homography = dataset_cfg.get("homography_adaptation") is not None
    if not use_homography:
        print("[Info] Export predictions normally.")
        # NOTE(review): `device` is not forwarded here; confirm that
        # export_predictions selects its own device.
        export_predictions(
            args, dataset_cfg, model_cfg, output_path, export_dataset_mode
        )
        return
    print("[Info] Export predictions with homography adaptation.")
    export_homograpy_adaptation(
        args, dataset_cfg, model_cfg, output_path, export_dataset_mode, device
    )
def main(
    args, dataset_cfg, model_cfg, export_dataset_mode=None, device=torch.device("cuda")
):
    """Dispatch to training or export according to ``args.mode``.

    Args:
        args: Parsed command-line arguments. ``args.mode`` selects the action
            and ``args.exp_name`` names the output sub-directory.
        dataset_cfg: Dataset configuration dict.
        model_cfg: Model configuration dict.
        export_dataset_mode: Dataset split to export (export mode only).
        device: Torch device forwarded to the export routine.

    Raises:
        ValueError: If ``args.mode`` is neither ``"train"`` nor ``"export"``.
    """
    if args.mode == "train":
        output_path = os.path.join(cfg.EXP_PATH, args.exp_name)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(output_path, exist_ok=True)
        print("[Info] Training mode")
        print("\t Output path: %s" % output_path)
        train(args, dataset_cfg, model_cfg, output_path)
    elif args.mode == "export":
        # Exports go under the export data root, not the experiment root.
        output_path = os.path.join(cfg.export_dataroot, args.exp_name)
        print("[Info] Export mode")
        print("\t Output path: %s" % output_path)
        export(
            args,
            dataset_cfg,
            model_cfg,
            output_path,
            export_dataset_mode,
            device=device,
        )
    else:
        # str() so a non-string mode still yields the intended ValueError.
        raise ValueError("[Error]: Unknown mode: " + str(args.mode))
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``seed``.

    Seeding all three makes runs reproducible regardless of which RNG a code
    path draws from.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
if __name__ == "__main__": | |
# Parse input arguments | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--mode", type=str, default="train", help="'train' or 'export'." | |
) | |
parser.add_argument( | |
"--dataset_config", type=str, default=None, help="Path to the dataset config." | |
) | |
parser.add_argument( | |
"--model_config", type=str, default=None, help="Path to the model config." | |
) | |
parser.add_argument("--exp_name", type=str, default="exp", help="Experiment name.") | |
parser.add_argument( | |
"--resume", | |
action="store_true", | |
default=False, | |
help="Load a previously trained model.", | |
) | |
parser.add_argument( | |
"--pretrained", | |
action="store_true", | |
default=False, | |
help="Start training from a pre-trained model.", | |
) | |
parser.add_argument( | |
"--resume_path", default=None, help="Path from which to resume training." | |
) | |
parser.add_argument( | |
"--pretrained_path", default=None, help="Path to the pre-trained model." | |
) | |
parser.add_argument( | |
"--checkpoint_name", default=None, help="Name of the checkpoint to use." | |
) | |
parser.add_argument( | |
"--export_dataset_mode", default=None, help="'train' or 'test'." | |
) | |
parser.add_argument( | |
"--export_batch_size", default=4, type=int, help="Export batch size." | |
) | |
args = parser.parse_args() | |
# Check if GPU is available | |
# Get the model | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
else: | |
device = torch.device("cpu") | |
# Check if dataset config and model config is given. | |
if ( | |
((args.dataset_config is None) or (args.model_config is None)) | |
and (not args.resume) | |
and (args.mode == "train") | |
): | |
raise ValueError( | |
"[Error] The dataset config and model config should be given in non-resume mode" | |
) | |
# If resume, check if the resume path has been given | |
if args.resume and (args.resume_path is None): | |
raise ValueError("[Error] Missing resume path.") | |
# [Training] Load the config file. | |
if args.mode == "train" and (not args.resume): | |
# Check the pretrained checkpoint_path exists | |
if args.pretrained: | |
checkpoint_folder = args.resume_path | |
checkpoint_path = os.path.join(args.pretrained_path, args.checkpoint_name) | |
if not os.path.exists(checkpoint_path): | |
raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) | |
dataset_cfg = load_config(args.dataset_config) | |
model_cfg = load_config(args.model_config) | |
# [resume Training, Test, Export] Load the config file. | |
elif (args.mode == "train" and args.resume) or (args.mode == "export"): | |
# Check checkpoint path exists | |
checkpoint_folder = args.resume_path | |
checkpoint_path = os.path.join(args.resume_path, args.checkpoint_name) | |
if not os.path.exists(checkpoint_path): | |
raise ValueError("[Error] Missing checkpoint: " + checkpoint_path) | |
# Load model_cfg from checkpoint folder if not provided | |
if args.model_config is None: | |
print("[Info] No model config provided. Loading from checkpoint folder.") | |
model_cfg_path = os.path.join(checkpoint_folder, "model_cfg.yaml") | |
if not os.path.exists(model_cfg_path): | |
raise ValueError("[Error] Missing model config in checkpoint path.") | |
model_cfg = load_config(model_cfg_path) | |
else: | |
model_cfg = load_config(args.model_config) | |
# Load dataset_cfg from checkpoint folder if not provided | |
if args.dataset_config is None: | |
print("[Info] No dataset config provided. Loading from checkpoint folder.") | |
dataset_cfg_path = os.path.join(checkpoint_folder, "dataset_cfg.yaml") | |
if not os.path.exists(dataset_cfg_path): | |
raise ValueError("[Error] Missing dataset config in checkpoint path.") | |
dataset_cfg = load_config(dataset_cfg_path) | |
else: | |
dataset_cfg = load_config(args.dataset_config) | |
# Check the --export_dataset_mode flag | |
if (args.mode == "export") and (args.export_dataset_mode is None): | |
raise ValueError("[Error] Empty --export_dataset_mode flag.") | |
else: | |
raise ValueError("[Error] Unknown mode: " + args.mode) | |
# Set the random seed | |
seed = dataset_cfg.get("random_seed", 0) | |
set_random_seed(seed) | |
main( | |
args, | |
dataset_cfg, | |
model_cfg, | |
export_dataset_mode=args.export_dataset_mode, | |
device=device, | |
) | |