import argparse
import os
import tarfile
import warnings

import pytorch_lightning as lightning
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from marcai.pl import MARCDataModule, SimilarityVectorModel
from marcai.utils import load_config


def train(name=None):
    config_path = "config.yaml"
    config = load_config(config_path)
    model_config = config["model"]

    # Create data module from processed data; silence Lightning's warning about
    # DataLoaders running with few workers.
    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )
    # Create model
    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )
    # Directory for this run's checkpoints and exported artifacts
    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    # Keep the checkpoint with the best validation accuracy
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
    )
    callbacks = [checkpoint_callback]
    # Optional early stopping on validation accuracy; a patience of -1 disables it
    if model_config["patience"] != -1:
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=0.00,
            patience=model_config["patience"],
            verbose=False,
            mode="max",
        )
        callbacks.append(early_stop_callback)
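    # Note: accelerator="cpu" below pins training to the CPU; recent Lightning
    # versions also accept values such as "gpu" or "auto" here.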
    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"], callbacks=callbacks, accelerator="cpu"
    )
    trainer.fit(model, data)
    # Export the trained model to ONNX with a dynamic batch dimension
    onnx_path = os.path.join(save_dir, "model.onnx")
    input_sample = torch.randn((1, len(model.attrs)))
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )
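    # A quick sanity check of the exported graph could be done with onnxruntime
    # (not a dependency of this script), e.g.:
    #   session = onnxruntime.InferenceSession(onnx_path)
    #   session.run(None, {"input": input_sample.numpy()})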
    # Save the config used for this run alongside the model
    config_filename = os.path.join(save_dir, "config.yaml")
    with open(config_filename, "w") as f:
        yaml.dump(config, f)

    # Compress the run directory; exclude the archive itself, since it is
    # written inside the directory being archived.
    tar_path = f"{save_dir}/{name}.tar.gz"
    with tarfile.open(tar_path, mode="w:gz") as archive:
        archive.add(
            save_dir,
            arcname=os.path.basename(save_dir),
            filter=lambda info: None if info.name.endswith(".tar.gz") else info,
        )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n", "--run-name", required=True, help="Name for the training run"
    )
    args = parser.parse_args()

    train(args.run_name)


if __name__ == "__main__":
    main()
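# Example invocation, assuming this file is saved as train.py and config.yaml
# points at the processed MARC data:
#   python train.py --run-name my_run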