"""main.py
Main entry point for training
"""
import sys
import tempfile
import warnings
from pathlib import Path

import mlflow
from datasets import load_dataset

from config import config
from config.config import logger
from yomikata import utils
from yomikata.dbert import dBert

warnings.filterwarnings("ignore")


def train_model(
    model_name: str = "dBert",
    dataset_name: str = "",
    experiment_name: str = "baselines",
    run_name: str = "dbert-default",
    training_args: dict = {},
) -> None:
"""Train a model given arguments.
Args:
dataset_name (str): name of the dataset to be trained on. Defaults to the full dataset.
args_fp (str): location of args.
experiment_name (str): name of experiment.
run_name (str): name of specific run in experiment.
"""
    mlflow.set_experiment(experiment_name=experiment_name)
    with mlflow.start_run(run_name=run_name):
        run_id = mlflow.active_run().info.run_id
        logger.info(f"Run ID: {run_id}")
        experiment_id = mlflow.get_run(run_id=run_id).info.experiment_id
        artifacts_dir = Path(config.RUN_REGISTRY, experiment_id, run_id, "artifacts")

        # Initialize the model
        if model_name == "dBert":
            reader = dBert(reinitialize=True, artifacts_dir=artifacts_dir)
        else:
            raise ValueError(f"model_name must be 'dBert' for now, got '{model_name}'")
        # Load train, val, and test data
        dataset = load_dataset(
            "csv",
            data_files={
                "train": str(Path(config.TRAIN_DATA_DIR, "train_" + dataset_name + ".csv")),
                "val": str(Path(config.VAL_DATA_DIR, "val_" + dataset_name + ".csv")),
                "test": str(Path(config.TEST_DATA_DIR, "test_" + dataset_name + ".csv")),
            },
        )
        # Train
        training_performance = reader.train(dataset, training_args=training_args)
        # general_performance = evaluate.evaluate(reader, max_evals=20)

        # Log the training performance as an mlflow artifact
        with tempfile.TemporaryDirectory() as dp:
            # reader.save(dp)
            # utils.save_dict(general_performance, Path(dp, "general_performance.json"))
            utils.save_dict(training_performance, Path(dp, "training_performance.json"))
            mlflow.log_artifacts(dp)
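

# Example call (a sketch with hypothetical values; "mydataset" and the training_args key
# are illustrative, and train_mydataset.csv etc. are assumed to exist in the data dirs):
#   train_model(
#       model_name="dBert",
#       dataset_name="mydataset",
#       experiment_name="baselines",
#       run_name="dbert-mydataset",
#       training_args={"num_train_epochs": 1},
#   )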


def get_artifacts_dir_from_run(run_id: str) -> Path:
    """Load the artifacts directory for a given run_id.

    Args:
        run_id (str): id of the run to load artifacts from.

    Returns:
        Path: path to the artifacts directory.
    """
    # Locate the run's artifacts directory
    experiment_id = mlflow.get_run(run_id=run_id).info.experiment_id
    artifacts_dir = Path(config.RUN_REGISTRY, experiment_id, run_id, "artifacts")
    return artifacts_dir
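
# Example (hypothetical run id): get_artifacts_dir_from_run("0a1b2c3d") returns
# Path(config.RUN_REGISTRY, "<experiment_id>", "0a1b2c3d", "artifacts"), with the
# experiment id resolved through mlflow.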
if __name__ == "__main__":
# get args filepath from input
args_fp = sys.argv[1]
# load the args_file
args = Namespace(**utils.load_dict(filepath=args_fp)).__dict__
# pop meta variables
model_name = args.pop("model")
dataset_name = args.pop("dataset")
experiment_name = args.pop("experiment")
run_name = args.pop("run")
# Perform training
train_model(
model_name=model_name,
dataset_name=dataset_name,
experiment_name=experiment_name,
run_name=run_name,
training_args=args,
)
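
# Example args file content (illustrative; any keys besides model/dataset/experiment/run
# are forwarded to reader.train as training_args):
# {
#     "model": "dBert",
#     "dataset": "",
#     "experiment": "baselines",
#     "run": "dbert-default",
#     "num_train_epochs": 10
# }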