"""Train (or finetune) a deepfake-audio detection model and save a test config.

Entry point: ``python <script> [--config ...] [--asv_path ...] ...``.
"""

import argparse
import logging
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml

from src.commons import set_seed
from src.datasets.detection_dataset import DetectionDataset
from src.models import models
from src.trainer import GDTrainer


def save_model(
    model: torch.nn.Module,
    model_dir: Union[Path, str],
    name: str,
) -> None:
    """Serialize ``model``'s state dict to ``<model_dir>/<name>/ckpt.pth``.

    Creates the directory tree if it does not yet exist.
    """
    full_model_dir = Path(model_dir) / name
    full_model_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), str(full_model_dir / "ckpt.pth"))


def get_datasets(
    datasets_paths: List[Union[Path, str]],
    amount_to_use: Tuple[Optional[int], Optional[int]],
) -> Tuple[DetectionDataset, DetectionDataset]:
    """Build the (train, test) dataset pair.

    Args:
        datasets_paths: dataset root directories.
            NOTE(review): only ``datasets_paths[0]`` is used; the remaining
            entries are currently ignored — confirm this is intentional.
        amount_to_use: (train_limit, test_limit); ``None`` means "use all".

    Returns:
        ``(train_dataset, test_dataset)``, both oversampled.
    """
    data_train = DetectionDataset(
        asvspoof_path=datasets_paths[0],
        subset="train",
        reduced_number=amount_to_use[0],
        oversample=True,
    )
    data_test = DetectionDataset(
        asvspoof_path=datasets_paths[0],
        subset="test",
        reduced_number=amount_to_use[1],
        oversample=True,
    )
    return data_train, data_test


def train_nn(
    datasets_paths: List[Union[Path, str]],
    batch_size: int,
    epochs: int,
    device: str,
    config: Dict,
    model_dir: Optional[Path] = None,
    amount_to_use: Tuple[Optional[int], Optional[int]] = (None, None),
    config_save_path: str = "configs",
) -> Tuple[str, str]:
    """Train a model described by ``config`` and optionally persist it.

    Args:
        datasets_paths: dataset root directories (see ``get_datasets``).
        batch_size: training batch size.
        epochs: number of training epochs.
        device: ``"cuda"`` or ``"cpu"``.
        config: parsed YAML config with ``model`` and ``checkpoint`` sections.
        model_dir: where to save the trained checkpoint; if ``None``, nothing
            is saved and the returned checkpoint path is ``""``.
        amount_to_use: optional (train, test) file-count limits.
        config_save_path: directory for the generated test config; when a
            checkpoint is saved, the return value is the full path of the
            written YAML file instead.

    Returns:
        ``(config_save_path, checkpoint_path)``.
    """
    logging.info("Loading data...")
    model_config = config["model"]
    model_name, model_parameters = model_config["name"], model_config["parameters"]
    optimizer_config = model_config["optimizer"]

    timestamp = time.time()
    checkpoint_path = ""

    data_train, data_test = get_datasets(
        datasets_paths=datasets_paths,
        amount_to_use=amount_to_use,
    )

    current_model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )

    # If provided weights, apply corresponding ones (from an appropriate fold).
    model_path = config["checkpoint"]["path"]
    if model_path:
        # map_location makes CUDA-saved checkpoints loadable on CPU-only hosts
        # (the script explicitly supports --cpu).
        current_model.load_state_dict(
            torch.load(model_path, map_location=device)
        )
        logging.info(
            f"Finetuning '{model_name}' model, weights path: '{model_path}', "
            f"on {len(data_train)} audio files."
        )
        if config["model"]["parameters"].get("freeze_encoder"):
            # Freeze the (Whisper) feature extractor; only the head is tuned.
            for param in current_model.whisper_model.parameters():
                param.requires_grad = False
    else:
        logging.info(f"Training '{model_name}' model on {len(data_train)} audio files.")

    current_model = current_model.to(device)

    # Only the RawNet3 family uses an LR scheduler in this pipeline.
    use_scheduler = "rawnet3" in model_name.lower()

    current_model = GDTrainer(
        device=device,
        batch_size=batch_size,
        epochs=epochs,
        optimizer_kwargs=optimizer_config,
        use_scheduler=use_scheduler,
    ).train(
        dataset=data_train,
        model=current_model,
        test_dataset=data_test,
    )

    if model_dir is not None:
        save_name = f"model__{model_name}__{timestamp}"
        save_model(
            model=current_model,
            model_dir=model_dir,
            name=save_name,
        )
        checkpoint_path = str(model_dir.resolve() / save_name / "ckpt.pth")

    # Save config for testing.
    if model_dir is not None:
        config["checkpoint"] = {"path": checkpoint_path}
        config_name = f"model__{model_name}__{timestamp}.yaml"
        config_save_path = str(Path(config_save_path) / config_name)
        with open(config_save_path, "w") as f:
            yaml.dump(config, f)
        logging.info("Test config saved at location '{}'!".format(config_save_path))

    return config_save_path, checkpoint_path


def main(args):
    # Configure root logging exactly once. The original code both attached a
    # StreamHandler manually AND called logging.basicConfig, which installed
    # two handlers and duplicated every log line.
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    seed = config["data"].get("seed", 42)
    # fix all seeds
    set_seed(seed)

    if not args.cpu and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    model_dir = Path(args.ckpt)
    model_dir.mkdir(parents=True, exist_ok=True)

    train_nn(
        datasets_paths=[
            args.asv_path,
            args.wavefake_path,
            args.celeb_path,
            args.asv19_path,
        ],
        device=device,
        amount_to_use=(args.train_amount, args.test_amount),
        batch_size=args.batch_size,
        epochs=args.epochs,
        model_dir=model_dir,
        config=config,
    )


def parse_args():
    """Parse CLI arguments for the training script."""
    parser = argparse.ArgumentParser()

    ASVSPOOF_DATASET_PATH = "../datasets/ASVspoof2021/DF"
    parser.add_argument(
        "--asv_path",
        type=str,
        default=ASVSPOOF_DATASET_PATH,
        help="Path to ASVspoof2021 dataset directory",
    )

    default_model_config = "config.yaml"
    parser.add_argument(
        "--config",
        help="Model config file path (default: config.yaml)",
        type=str,
        default=default_model_config,
    )

    default_train_amount = None
    parser.add_argument(
        "--train_amount",
        "-a",
        help="Amount of files to load for training.",
        type=int,
        default=default_train_amount,
    )

    default_test_amount = None
    parser.add_argument(
        "--test_amount",
        "-ta",
        help="Amount of files to load for testing.",
        type=int,
        default=default_test_amount,
    )

    default_batch_size = 8
    parser.add_argument(
        "--batch_size",
        "-b",
        help=f"Batch size (default: {default_batch_size}).",
        type=int,
        default=default_batch_size,
    )

    default_epochs = 10
    parser.add_argument(
        "--epochs",
        "-e",
        help=f"Epochs (default: {default_epochs}).",
        type=int,
        default=default_epochs,
    )

    default_model_dir = "trained_models"
    parser.add_argument(
        "--ckpt",
        help=f"Checkpoint directory (default: {default_model_dir}).",
        type=str,
        default=default_model_dir,
    )

    parser.add_argument("--cpu", "-c", help="Force using cpu?", action="store_true")

    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())