"""Train a similarity-vector model and package the run's artifacts.

Reads ``config.yaml``, fits a ``SimilarityVectorModel`` on a
``MARCDataModule``, then writes a checkpoint (tracked on ``val_acc``), an
ONNX export, a copy of the configuration, and a gzip-compressed tarball of
the run directory.
"""

import argparse
import os
import tarfile
import warnings

import pytorch_lightning as lightning
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from marcai.pl import MARCDataModule, SimilarityVectorModel
from marcai.utils import load_config


def _build_callbacks(model_config, save_dir):
    """Return the checkpoint callback plus, if enabled, early stopping."""
    # Keep the checkpoint with the highest validation accuracy.
    callbacks = [
        ModelCheckpoint(
            monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
        )
    ]

    # A patience of -1 in the config disables early stopping.
    if model_config["patience"] != -1:
        callbacks.append(
            EarlyStopping(
                monitor="val_acc",
                min_delta=0.00,
                patience=model_config["patience"],
                verbose=False,
                mode="max",
            )
        )
    return callbacks


def _export_onnx(model, onnx_path):
    """Export ``model`` to ONNX at ``onnx_path`` with a dynamic batch axis."""
    # One dummy row whose width is len(model.attrs) -- presumably the number
    # of input features; confirm against SimilarityVectorModel.
    input_sample = torch.randn((1, len(model.attrs)))
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        # Dimension 0 is declared dynamic so the export is not tied to the
        # batch size of the sample input.
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )


def _save_config(config, config_filename):
    """Write ``config`` as YAML to ``config_filename``."""
    with open(config_filename, "w", encoding="utf-8") as f:
        f.write(yaml.dump(config))


def _archive_run_dir(save_dir, name):
    """Compress ``save_dir`` into ``<save_dir>/<name>.tar.gz``."""
    tar_path = f"{save_dir}/{name}.tar.gz"
    with tarfile.open(tar_path, mode="w:gz") as archive:
        # Entries are stored under the run directory's own name rather than
        # its full path.
        archive.add(save_dir, arcname=os.path.basename(save_dir))


def train(name=None):
    """Train the model described by ``config.yaml`` and save its artifacts.

    Everything is written to ``<saved_models_dir>/<name>``, where
    ``saved_models_dir`` comes from the ``model`` section of the config:
    the checkpoint, ``model.onnx``, ``config.yaml`` and ``<name>.tar.gz``.

    Args:
        name: Run name, used as the output sub-directory and as the stem of
            the tarball. Required in practice; the ``None`` default is kept
            only so the signature stays backward-compatible.

    Raises:
        TypeError: If ``name`` is ``None``.
    """
    if name is None:
        # Fail before loading data or training, rather than later inside
        # os.path.join with an opaque message.
        raise TypeError("train() requires a run name; got None")

    config_path = "config.yaml"
    # Load the file once: the full config is saved alongside the model,
    # and its "model" section drives everything below.
    config = load_config(config_path)
    model_config = config["model"]

    # Create data module from processed data
    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )

    # Create model
    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )

    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"],
        callbacks=_build_callbacks(model_config, save_dir),
        accelerator="cpu",
    )
    trainer.fit(model, data)

    # Save ONNX
    _export_onnx(model, os.path.join(save_dir, "model.onnx"))

    # Save config
    _save_config(config, os.path.join(save_dir, "config.yaml"))

    # Compress model directory files
    _archive_run_dir(save_dir, name)


def args_parser():
    """Build the command-line parser (a single required ``--run-name``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
    return parser


def main(args):
    """Run training using the run name from the parsed arguments."""
    train(args.run_name)


if __name__ == "__main__":
    args = args_parser().parse_args()
    main(args)