# time-series-score / run_experiment.py
# Uploaded by kashif (HF staff), commit a028d0b: "Upload run_experiment.py"
import click
import datetime
import pprint
from typing import Optional
from src import (
load_dataset,
fit_predict_with_model,
score_predictions,
AVAILABLE_DATASETS,
AVAILABLE_MODELS,
SEASONALITY_MAP,
)
def apply_ablation(ablation: str, model_kwargs: dict) -> dict:
    """Adjust ``model_kwargs`` in place for the requested ablation study.

    ``"NoEnsemble"`` sets ``enable_ensemble`` to ``False``. The three
    ``"No...Models"`` ablations replace ``hyperparameters`` with the models of
    the two remaining families, each mapped to an empty config dict. Any other
    value leaves ``model_kwargs`` untouched.

    Args:
        ablation: Name of the ablation to apply.
        model_kwargs: Keyword arguments for the model; modified in place.

    Returns:
        The same ``model_kwargs`` object that was passed in.
    """
    stat_models = (
        "Naive",
        "SeasonalNaive",
        "ARIMA",
        "ETS",
        "AutoETS",
        "AutoARIMA",
        "Theta",
    )
    tree_models = ("AutoGluonTabular",)
    deep_models = ("DeepAR", "SimpleFeedForward", "TemporalFusionTransformer")

    if ablation == "NoEnsemble":
        model_kwargs["enable_ensemble"] = False
        return model_kwargs

    # Each ablation keeps the two model families it does not exclude.
    models_to_keep = {
        "NoDeepModels": stat_models + tree_models,
        "NoStatModels": tree_models + deep_models,
        "NoTreeModels": stat_models + deep_models,
    }
    kept = models_to_keep.get(ablation)
    if kept is not None:
        # Build fresh dicts on every call so callers never share config objects.
        model_kwargs["hyperparameters"] = {model: {} for model in kept}
    return model_kwargs
def _parse_extra_args(extra_args: list) -> dict:
    """Convert leftover ``--name value`` command line tokens into a dict.

    Args:
        extra_args: Tokens that click did not recognise, in command line order.

    Returns:
        ``{"name": "value"}`` for every ``--name value`` pair. Values are kept
        as the raw command line strings; no type conversion is performed.

    Raises:
        click.UsageError: If a flag has no value or does not start with ``--``.
    """
    if len(extra_args) % 2 != 0:
        # An odd count means some flag is missing its value; pairing tokens up
        # blindly would index past the end of the list.
        raise click.UsageError(
            f"Extra arguments must be `--name value` pairs, got: {' '.join(extra_args)}"
        )
    parsed = {}
    for flag, value in zip(extra_args[::2], extra_args[1::2]):
        name = flag[2:]
        if not flag.startswith("--") or not name:
            # Stripping two characters from e.g. `-lr` would silently yield `r`.
            raise click.UsageError(
                f"Extra argument {flag!r} must have the form `--name value`"
            )
        parsed[name] = value
    return parsed


@click.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option(
    "--dataset_name",
    "-d",
    required=True,
    default="m3_other",
    help="The dataset to train the model on",
    type=click.Choice(AVAILABLE_DATASETS),
)
@click.option(
    "--model_name",
    "-m",
    default="autogluon",
    help="Model to train",
    type=click.Choice(AVAILABLE_MODELS),
)
@click.option(
    "--eval_metric",
    "-e",
    default="MASE",
    type=click.Choice(["MASE", "mean_wQuantileLoss"]),
)
@click.option(
    "--seed",
    "-s",
    default=1,
    type=int,
)
@click.option(
    "--time_limit",
    "-t",
    default=4 * 3600,
    type=int,
)
@click.option(
    "--ablation",
    "-a",
    default=None,
    type=click.Choice(["NoEnsemble", "NoDeepModels", "NoStatModels", "NoTreeModels"]),
)
@click.pass_context
def main(
    ctx,
    dataset_name: str,
    model_name: str,
    eval_metric: str,
    seed: int,
    time_limit: int,
    ablation: Optional[str],
):
    """Fit a model on a dataset, score its predictions and print the results.

    Unrecognised options given as `--name value` are forwarded to the model as
    keyword arguments.
    \f

    Args:
        ctx: Click context; ``ctx.args`` holds the unrecognised command line
            tokens.
        dataset_name: Name passed to ``load_dataset``.
        model_name: Name passed to ``fit_predict_with_model``.
        eval_metric: Metric name included in the task definition.
        seed: Forwarded to the model as the ``seed`` keyword argument.
        time_limit: Forwarded to the model as the ``time_limit`` keyword
            argument (presumably seconds, given the ``4 * 3600`` default --
            confirm against ``fit_predict_with_model``).
        ablation: Optional ablation name; only accepted when ``model_name`` is
            ``"autogluon"``.

    Raises:
        click.UsageError: If the extra arguments are not ``--name value`` pairs
            or an ablation is requested for a model other than ``"autogluon"``.
    """
    print(f"Evaluating {model_name} on {dataset_name}")
    dataset = load_dataset(dataset_name)

    task_kwargs = {
        "prediction_length": dataset.metadata.prediction_length,
        "freq": dataset.metadata.freq,
        "eval_metric": eval_metric,
        "seasonality": SEASONALITY_MAP[dataset.metadata.freq],
    }
    print("Task definition:")
    pprint.pprint(task_kwargs)

    # Additional command line arguments like `--name value` are parsed as {"name": "value"}
    model_kwargs = _parse_extra_args(ctx.args)
    # The dedicated options take precedence over same-named extra arguments.
    model_kwargs["seed"] = seed
    model_kwargs["time_limit"] = time_limit

    if ablation is not None:
        # Raise instead of `assert`: assertions are stripped under `python -O`.
        if model_name != "autogluon":
            raise click.UsageError(f"{model_name} does not support ablations")
        model_kwargs = apply_ablation(ablation, model_kwargs)

    if model_kwargs:
        print("Model kwargs:")
        pprint.pprint(model_kwargs)

    print(f"Starting training {datetime.datetime.now()}")
    predictions, info = fit_predict_with_model(
        model_name, dataset.train, **task_kwargs, **model_kwargs
    )
    metrics = score_predictions(
        dataset=dataset.test,
        predictions=predictions,
        prediction_length=task_kwargs["prediction_length"],
        seasonality=task_kwargs["seasonality"],
    )

    print("================================================")
    print(f"model: {model_name}")
    print(f"dataset: {dataset_name}")
    print(f"total_run_time: {info['run_time']:.2f}")
    print(f"mase: {metrics['MASE']:.4f}")
    print(f"mean_wQuantileLoss: {metrics['mean_wQuantileLoss']:.4f}")
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()