|
import pytorch_lightning as lightning |
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
|
import warnings |
|
import yaml |
|
import argparse |
|
import os |
|
import torch |
|
from marcai.pl import MARCDataModule, SimilarityVectorModel |
|
from marcai.utils import load_config |
|
import tarfile |
|
|
|
|
|
def train(name=None):
    """Train the similarity model, export it to ONNX, and archive the run.

    Reads ``config.yaml``, trains a ``SimilarityVectorModel`` on the data
    provided by ``MARCDataModule``, saves the best checkpoint (by validation
    accuracy), exports the trained model to ONNX, writes the config next to
    the model for reproducibility, and packs the whole run directory into a
    ``.tar.gz`` archive.

    Args:
        name: Name for the training run; used as the save directory name
            and the archive name. Must not be None.

    Raises:
        ValueError: If ``name`` is None.
    """
    if name is None:
        raise ValueError("A run name is required to create the save directory.")

    config_path = "config.yaml"
    config = load_config(config_path)
    # Reuse the already-parsed config instead of reading the file a second time.
    model_config = config["model"]

    # Lightning warns when DataLoaders use few workers; not actionable here.
    warnings.filterwarnings("ignore", ".*does not have many workers.*")

    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )

    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )

    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    # Keep the checkpoint with the best validation accuracy.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
    )
    callbacks = [checkpoint_callback]

    # A patience of -1 disables early stopping entirely.
    if model_config["patience"] != -1:
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=0.00,
            patience=model_config["patience"],
            verbose=False,
            mode="max",
        )
        callbacks.append(early_stop_callback)

    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"], callbacks=callbacks, accelerator="cpu"
    )
    trainer.fit(model, data)

    # Export the trained model to ONNX with a dynamic batch dimension.
    onnx_path = os.path.join(save_dir, "model.onnx")
    input_sample = torch.randn((1, len(model.attrs)))
    # Switch to inference mode before export so dropout/batchnorm behave
    # as they will at serving time.
    model.eval()
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # Save the full config alongside the model for reproducibility.
    config_filename = os.path.join(save_dir, "config.yaml")
    with open(config_filename, "w") as f:
        f.write(yaml.dump(config))

    # The tarball lives inside save_dir, so archiving save_dir would pull the
    # half-written archive into itself; filter it out of its own contents.
    archive_name = f"{name}.tar.gz"
    tar_path = os.path.join(save_dir, archive_name)

    def _skip_self(tarinfo):
        # Returning None excludes the member from the archive.
        if os.path.basename(tarinfo.name) == archive_name:
            return None
        return tarinfo

    with tarfile.open(tar_path, mode="w:gz") as archive:
        archive.add(
            save_dir, arcname=os.path.basename(save_dir), filter=_skip_self
        )
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and start a training run.

    Requires ``-n``/``--run-name`` so that ``train`` always receives a
    usable run name; without it, argparse exits with a usage error instead
    of the script crashing later when building the save directory path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--run-name",
        required=True,
        help="Name for training run",
    )
    args = parser.parse_args()

    train(args.run_name)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":

    main()
|
|