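"""Training entry point for the P-FAD spoofed-audio detection models.

Loads a YAML config, builds train/test DetectionDatasets, trains (or
fine-tunes) the configured model with GDTrainer, and saves the resulting
checkpoint together with a config file for later evaluation.

Example invocation (paths and values are illustrative):

    python train_models.py --asv_path ../datasets/ASVspoof2021/DF --config config.yaml
"""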

import argparse
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import torch
import yaml

from src.commons import set_seed
from src.datasets.detection_dataset import DetectionDataset
from src.models import models
from src.trainer import GDTrainer


def save_model(
    model: torch.nn.Module,
    model_dir: Union[Path, str],
    name: str,
) -> None:
    """Save the model's state dict under <model_dir>/<name>/ckpt.pth."""
    full_model_dir = Path(model_dir) / name
    full_model_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), full_model_dir / "ckpt.pth")
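

# Illustrative restore of a saved checkpoint (mirrors the load in train_nn;
# the path and `device` value are assumptions):
#   model.load_state_dict(
#       torch.load("trained_models/<name>/ckpt.pth", map_location=device)
#   )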


def get_datasets(
    datasets_paths: List[Union[Path, str]],
    amount_to_use: Tuple[Optional[int], Optional[int]],
) -> Tuple[DetectionDataset, DetectionDataset]:
    """Build the train/test splits. Only datasets_paths[0] (the ASVspoof path) is used."""
    data_train = DetectionDataset(
        asvspoof_path=datasets_paths[0],
        subset="train",
        reduced_number=amount_to_use[0],
        oversample=True,
    )
    data_test = DetectionDataset(
        asvspoof_path=datasets_paths[0],
        subset="test",
        reduced_number=amount_to_use[1],
        oversample=True,
    )
    return data_train, data_test


def train_nn(
    datasets_paths: List[Union[Path, str]],
    batch_size: int,
    epochs: int,
    device: str,
    config: Dict,
    model_dir: Optional[Path] = None,
    amount_to_use: Tuple[Optional[int], Optional[int]] = (None, None),
    config_save_path: str = "configs",
) -> Tuple[str, str]:
    """Train (or fine-tune) the configured model; return (config path, checkpoint path)."""
    logging.info("Loading data...")
    model_config = config["model"]
    model_name, model_parameters = model_config["name"], model_config["parameters"]
    optimizer_config = model_config["optimizer"]
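
    # Expected config layout, inferred from the keys accessed in this script
    # (the values shown are illustrative, not prescriptive):
    #
    #   data:
    #     seed: 42
    #   checkpoint:
    #     path: ""              # optional weights to fine-tune from
    #   model:
    #     name: rawnet3
    #     parameters:
    #       freeze_encoder: false
    #     optimizer:
    #       lr: 0.0001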

    timestamp = time.time()
    checkpoint_path = ""

    data_train, data_test = get_datasets(
        datasets_paths=datasets_paths,
        amount_to_use=amount_to_use,
    )

    current_model = models.get_model(
        model_name=model_name,
        config=model_parameters,
        device=device,
    )

    # If checkpoint weights were provided, load them and fine-tune.
    model_path = config.get("checkpoint", {}).get("path")
    if model_path:
        current_model.load_state_dict(torch.load(model_path, map_location=device))
        logging.info(
            f"Finetuning '{model_name}' model, weights path: '{model_path}', on {len(data_train)} audio files."
        )
        # Optionally freeze the Whisper encoder so only the remaining layers train.
        if model_parameters.get("freeze_encoder"):
            for param in current_model.whisper_model.parameters():
                param.requires_grad = False
    else:
        logging.info(f"Training '{model_name}' model on {len(data_train)} audio files.")

    current_model = current_model.to(device)

    # Only the RawNet3 variants are trained with an LR scheduler.
    use_scheduler = "rawnet3" in model_name.lower()

    current_model = GDTrainer(
        device=device,
        batch_size=batch_size,
        epochs=epochs,
        optimizer_kwargs=optimizer_config,
        use_scheduler=use_scheduler,
    ).train(
        dataset=data_train,
        model=current_model,
        test_dataset=data_test,
    )

    if model_dir is not None:
        save_name = f"model__{model_name}__{timestamp}"
        save_model(
            model=current_model,
            model_dir=model_dir,
            name=save_name,
        )
        checkpoint_path = str(model_dir.resolve() / save_name / "ckpt.pth")

        # Save a config pointing at the new checkpoint, for later testing.
        config["checkpoint"] = {"path": checkpoint_path}
        config_name = f"model__{model_name}__{timestamp}.yaml"
        config_save_path = str(Path(config_save_path) / config_name)
        with open(config_save_path, "w") as f:
            yaml.dump(config, f)
        logging.info(f"Test config saved at location '{config_save_path}'!")

    return config_save_path, checkpoint_path
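

# Illustrative programmatic use of train_nn (paths and values are assumptions):
#   config = yaml.safe_load(open("config.yaml"))
#   train_nn(datasets_paths=["../datasets/ASVspoof2021/DF", None, None, None],
#            batch_size=8, epochs=10, device="cpu", config=config,
#            model_dir=Path("trained_models"))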


def main(args):
    # Route INFO-level logs through a single formatted stream handler.
    LOGGER = logging.getLogger()
    LOGGER.setLevel(logging.INFO)
    ch = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    ch.setFormatter(formatter)
    LOGGER.addHandler(ch)

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)

    # Fix all seeds for reproducibility.
    seed = config["data"].get("seed", 42)
    set_seed(seed)

    device = "cuda" if not args.cpu and torch.cuda.is_available() else "cpu"

    model_dir = Path(args.ckpt)
    model_dir.mkdir(parents=True, exist_ok=True)

    train_nn(
        datasets_paths=[
            args.asv_path,
            args.wavefake_path,
            args.celeb_path,
            args.asv19_path,
        ],
        device=device,
        amount_to_use=(args.train_amount, args.test_amount),
        batch_size=args.batch_size,
        epochs=args.epochs,
        model_dir=model_dir,
        config=config,
    )


def parse_args():
    parser = argparse.ArgumentParser()

    ASVSPOOF_DATASET_PATH = "../datasets/ASVspoof2021/DF"
    parser.add_argument(
        "--asv_path",
        type=str,
        default=ASVSPOOF_DATASET_PATH,
        help="Path to ASVspoof2021 dataset directory",
    )
    # main() forwards these three paths to train_nn() as well, so they must be
    # declared here; get_datasets() currently consumes only --asv_path.
    parser.add_argument(
        "--wavefake_path",
        type=str,
        default=None,
        help="Path to WaveFake dataset directory (currently unused)",
    )
    parser.add_argument(
        "--celeb_path",
        type=str,
        default=None,
        help="Path to celeb dataset directory (currently unused)",
    )
    parser.add_argument(
        "--asv19_path",
        type=str,
        default=None,
        help="Path to ASVspoof2019 dataset directory (currently unused)",
    )

    default_model_config = "config.yaml"
    parser.add_argument(
        "--config",
        help=f"Model config file path (default: {default_model_config})",
        type=str,
        default=default_model_config,
    )

    default_train_amount = None
    parser.add_argument(
        "--train_amount",
        "-a",
        help="Number of files to load for training (default: all).",
        type=int,
        default=default_train_amount,
    )

    default_test_amount = None
    parser.add_argument(
        "--test_amount",
        "-ta",
        help="Number of files to load for testing (default: all).",
        type=int,
        default=default_test_amount,
    )

    default_batch_size = 8
    parser.add_argument(
        "--batch_size",
        "-b",
        help=f"Batch size (default: {default_batch_size}).",
        type=int,
        default=default_batch_size,
    )

    default_epochs = 10
    parser.add_argument(
        "--epochs",
        "-e",
        help=f"Epochs (default: {default_epochs}).",
        type=int,
        default=default_epochs,
    )

    default_model_dir = "trained_models"
    parser.add_argument(
        "--ckpt",
        help=f"Checkpoint directory (default: {default_model_dir}).",
        type=str,
        default=default_model_dir,
    )

    parser.add_argument("--cpu", "-c", help="Force CPU training.", action="store_true")

    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())