import datetime
import pprint
from typing import Optional

import click

from src import (
    load_dataset,
    fit_predict_with_model,
    score_predictions,
    AVAILABLE_DATASETS,
    AVAILABLE_MODELS,
    SEASONALITY_MAP,
)

# Model families used to build the ablation hyperparameter sets below.
# Each "No<Family>Models" ablation keeps the other two families.
_STAT_MODELS = ("Naive", "SeasonalNaive", "ARIMA", "ETS", "AutoETS", "AutoARIMA", "Theta")
_TREE_MODELS = ("AutoGluonTabular",)
_DEEP_MODELS = ("DeepAR", "SimpleFeedForward", "TemporalFusionTransformer")

# Ablation name -> names of the models that remain enabled for that ablation.
_ABLATION_MODELS = {
    "NoDeepModels": _STAT_MODELS + _TREE_MODELS,
    "NoStatModels": _TREE_MODELS + _DEEP_MODELS,
    "NoTreeModels": _STAT_MODELS + _DEEP_MODELS,
}


def apply_ablation(ablation: str, model_kwargs: dict) -> dict:
    """Adjust ``model_kwargs`` in place for the given ablation and return it.

    Args:
        ablation: One of ``"NoEnsemble"``, ``"NoDeepModels"``,
            ``"NoStatModels"`` or ``"NoTreeModels"``. Any other value leaves
            ``model_kwargs`` untouched.
        model_kwargs: Keyword arguments that will be forwarded to the model.
            This dict is mutated.

    Returns:
        The same ``model_kwargs`` object that was passed in.
    """
    if ablation == "NoEnsemble":
        model_kwargs["enable_ensemble"] = False
    elif ablation in _ABLATION_MODELS:
        # Build fresh (unshared) empty config dicts on every call so callers
        # can safely modify the per-model settings afterwards.
        model_kwargs["hyperparameters"] = {name: {} for name in _ABLATION_MODELS[ablation]}
    return model_kwargs


def _parse_extra_args(args: list) -> dict:
    """Turn leftover CLI tokens like ``--name value`` into ``{"name": "value"}``.

    Values are kept exactly as received from the command line (no type
    conversion is applied).

    Raises:
        click.UsageError: If the tokens do not form ``--name value`` pairs.
    """
    if len(args) % 2 != 0:
        raise click.UsageError(
            f"Extra arguments must be given as `--name value` pairs, got: {' '.join(args)}"
        )
    extra_kwargs = {}
    for flag, value in zip(args[::2], args[1::2]):
        name = flag[2:]
        if not flag.startswith("--") or not name:
            raise click.UsageError(
                f"Extra argument names must look like `--name`, got: {flag}"
            )
        extra_kwargs[name] = value
    return extra_kwargs


@click.command(
    context_settings=dict(
        ignore_unknown_options=True,
        allow_extra_args=True,
    )
)
@click.option(
    "--dataset_name",
    "-d",
    required=True,
    default="m3_other",
    help="The dataset to train the model on",
    type=click.Choice(AVAILABLE_DATASETS),
)
@click.option(
    "--model_name",
    "-m",
    default="autogluon",
    help="Model to train",
    type=click.Choice(AVAILABLE_MODELS),
)
@click.option(
    "--eval_metric",
    "-e",
    default="MASE",
    type=click.Choice(["MASE", "mean_wQuantileLoss"]),
)
@click.option(
    "--seed",
    "-s",
    default=1,
    type=int,
)
@click.option(
    "--time_limit",
    "-t",
    default=4 * 3600,
    type=int,
)
@click.option(
    "--ablation",
    "-a",
    default=None,
    type=click.Choice(["NoEnsemble", "NoDeepModels", "NoStatModels", "NoTreeModels"]),
)
@click.pass_context
def main(
    ctx,
    dataset_name: str,
    model_name: str,
    eval_metric: str,
    seed: int,
    time_limit: int,
    ablation: Optional[str],
):
    """Train a model on a dataset and print its evaluation metrics.

    Additional options given as `--name value` pairs are forwarded to the
    model as keyword arguments.
    """
    print(f"Evaluating {model_name} on {dataset_name}")
    dataset = load_dataset(dataset_name)

    task_kwargs = {
        "prediction_length": dataset.metadata.prediction_length,
        "freq": dataset.metadata.freq,
        "eval_metric": eval_metric,
        "seasonality": SEASONALITY_MAP[dataset.metadata.freq],
    }
    print("Task definition:")
    pprint.pprint(task_kwargs)

    # Additional command line arguments like `--name value` are parsed as {"name": "value"}
    model_kwargs = _parse_extra_args(ctx.args)
    # Explicit --seed / --time_limit options take precedence over extra args
    # of the same name.
    model_kwargs["seed"] = seed
    model_kwargs["time_limit"] = time_limit

    if ablation is not None:
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        if model_name != "autogluon":
            raise click.UsageError(f"{model_name} does not support ablations")
        model_kwargs = apply_ablation(ablation, model_kwargs)

    if model_kwargs:
        print("Model kwargs:")
        pprint.pprint(model_kwargs)

    print(f"Starting training {datetime.datetime.now()}")
    predictions, info = fit_predict_with_model(
        model_name, dataset.train, **task_kwargs, **model_kwargs
    )
    metrics = score_predictions(
        dataset=dataset.test,
        predictions=predictions,
        prediction_length=task_kwargs["prediction_length"],
        seasonality=task_kwargs["seasonality"],
    )

    print("================================================")
    print(f"model: {model_name}")
    print(f"dataset: {dataset_name}")
    print(f"total_run_time: {info['run_time']:.2f}")
    print(f"mase: {metrics['MASE']:.4f}")
    print(f"mean_wQuantileLoss: {metrics['mean_wQuantileLoss']:.4f}")


if __name__ == "__main__":
    main()