import argparse
import os
import tarfile
import warnings

import pytorch_lightning as lightning
import torch
import yaml
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from marcai.pl import MARCDataModule, SimilarityVectorModel
from marcai.utils import load_config
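
# Example usage, assuming this script is saved as train.py:
#   python train.py --run-name my_run
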
def train(name=None):
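    """Train the similarity model and save the run artifacts.

    Writes the best checkpoint, an ONNX export, the resolved config, and a
    .tar.gz archive of the run directory under the configured saved_models_dir.
    """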
    config_path = "config.yaml"
    config = load_config(config_path)
    model_config = config["model"]

    # Create data module from processed data
    warnings.filterwarnings("ignore", ".*does not have many workers.*")
    data = MARCDataModule(
        model_config["train_processed_path"],
        model_config["val_processed_path"],
        model_config["test_processed_path"],
        model_config["features"],
        model_config["batch_size"],
    )

    # Create model
    model = SimilarityVectorModel(
        model_config["lr"],
        model_config["weight_decay"],
        model_config["optimizer"],
        model_config["batch_size"],
        model_config["features"],
        model_config["hidden_sizes"],
    )

    save_dir = os.path.join(model_config["saved_models_dir"], name)
    os.makedirs(save_dir, exist_ok=True)

    # Save best models
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc", mode="max", dirpath=save_dir, filename="model"
    )
    callbacks = [checkpoint_callback]
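
    # Stop early when validation accuracy stops improving; a patience of -1 disables this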
    if model_config["patience"] != -1:
        early_stop_callback = EarlyStopping(
            monitor="val_acc",
            min_delta=0.00,
            patience=model_config["patience"],
            verbose=False,
            mode="max",
        )
        callbacks.append(early_stop_callback)
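
    # Train the model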
    trainer = lightning.Trainer(
        max_epochs=model_config["max_epochs"], callbacks=callbacks, accelerator="cpu"
    )
    trainer.fit(model, data)

    # Save ONNX
    onnx_path = os.path.join(save_dir, "model.onnx")
    input_sample = torch.randn((1, len(model.attrs)))
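    # Export with a dynamic batch axis so inference is not tied to a fixed batch size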
    torch.onnx.export(
        model,
        input_sample,
        onnx_path,
        export_params=True,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

    # Save config
    config_filename = os.path.join(save_dir, "config.yaml")
    with open(config_filename, "w") as f:
        dump = yaml.dump(config)
        f.write(dump)

    # Compress model directory files, skipping the archive being written so it
    # does not end up inside itself
    tar_path = os.path.join(save_dir, f"{name}.tar.gz")
    with tarfile.open(tar_path, mode="w:gz") as archive:
        archive.add(save_dir, arcname=os.path.basename(save_dir),
                    filter=lambda ti: None if ti.name.endswith(".tar.gz") else ti)


def args_parser():
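    """Build the command line argument parser for the training script."""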
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", "--run-name", help="Name for training run", required=True)
    return parser


def main(args):
    train(args.run_name)


if __name__ == "__main__":
    args = args_parser().parse_args()
    main(args)