"""Train a model on ASVspoof2021 DF and evaluate it on the In The Wild dataset."""
import argparse
import logging
from pathlib import Path

import torch
import yaml

import train_models
import evaluate_models
from src.commons import set_seed

# Log to stderr with timestamps and severity levels.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

ch = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
ch.setFormatter(formatter)
LOGGER.addHandler(ch)


def parse_args():
    parser = argparse.ArgumentParser()

    ASVSPOOF_DATASET_PATH = "../datasets/ASVspoof2021/DF"
    IN_THE_WILD_DATASET_PATH = "../datasets/release_in_the_wild"

    parser.add_argument(
        "--asv_path",
        type=str,
        default=ASVSPOOF_DATASET_PATH,
        help="Path to the ASVspoof2021 DF dataset directory.",
    )
    parser.add_argument(
        "--in_the_wild_path",
        type=str,
        default=IN_THE_WILD_DATASET_PATH,
        help="Path to the In The Wild dataset directory.",
    )

    default_model_config = "config.yaml"
    parser.add_argument(
        "--config",
        help=f"Model config file path (default: {default_model_config}).",
        type=str,
        default=default_model_config,
    )

    default_train_amount = None
    parser.add_argument(
        "--train_amount",
        "-a",
        help="Number of files to load for training (default: all).",
        type=int,
        default=default_train_amount,
    )

    default_valid_amount = None
    parser.add_argument(
        "--valid_amount",
        "-va",
        help="Number of files to load for validation (default: all).",
        type=int,
        default=default_valid_amount,
    )

    default_test_amount = None
    parser.add_argument(
        "--test_amount",
        "-ta",
        help="Number of files to load for testing (default: all).",
        type=int,
        default=default_test_amount,
    )

    default_batch_size = 8
    parser.add_argument(
        "--batch_size",
        "-b",
        help=f"Batch size (default: {default_batch_size}).",
        type=int,
        default=default_batch_size,
    )

    default_epochs = 10  # it was 5 originally
    parser.add_argument(
        "--epochs",
        "-e",
        help=f"Number of epochs (default: {default_epochs}).",
        type=int,
        default=default_epochs,
    )

    default_model_dir = "trained_models"
    parser.add_argument(
        "--ckpt",
        help=f"Checkpoint directory (default: {default_model_dir}).",
        type=str,
        default=default_model_dir,
    )

    parser.add_argument(
        "--cpu", "-c", help="Force running on CPU.", action="store_true"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # TRAIN MODEL
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Fix all seeds for reproducibility.
    seed = config["data"].get("seed", 42)
    set_seed(seed)

    device = "cuda" if not args.cpu and torch.cuda.is_available() else "cpu"

    model_dir = Path(args.ckpt)
    model_dir.mkdir(parents=True, exist_ok=True)

    # Train on ASVspoof2021 DF; train_nn returns the path to an evaluation config
    # describing the trained checkpoint.
    evaluation_config_path, model_path = train_models.train_nn(
        datasets_paths=[
            args.asv_path,
        ],
        device=device,
        amount_to_use=(args.train_amount, args.valid_amount),
        batch_size=args.batch_size,
        epochs=args.epochs,
        model_dir=model_dir,
        config=config,
    )

    # EVALUATE MODEL
    with open(evaluation_config_path, "r") as f:
        config = yaml.safe_load(f)

    # Evaluate the freshly trained checkpoint on the In The Wild dataset.
    evaluate_models.evaluate_nn(
        model_paths=config["checkpoint"].get("path", []),
        batch_size=args.batch_size,
        datasets_paths=[args.in_the_wild_path],
        model_config=config["model"],
        amount_to_use=args.test_amount,
        device=device,
    )
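
# Example invocation (illustrative sketch: the script filename "train_and_test.py"
# and the dataset locations are assumptions; adjust them to your local layout).
# The flag values shown simply repeat the defaults defined in parse_args() above.
#
#   python train_and_test.py \
#       --asv_path ../datasets/ASVspoof2021/DF \
#       --in_the_wild_path ../datasets/release_in_the_wild \
#       --config config.yaml \
#       --batch_size 8 --epochs 10 --ckpt trained_models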