import argparse
import os

import pandas as pd
import wandb
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report

from data_loader import load_and_process_data, CATEGORICAL_COLUMNS
from model_trainer import train_models
from model_manager import save_models, load_models
from model_predictor import predict
from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS


# ===========================
# MAIN FUNCTION
# ===========================
def main(train: bool = True, retrain: bool = False) -> None:
    """Train (or load) the models, score them on validation data and predict on the test set.

    Steps, in order:

    1. Make sure ``MODEL_DIR`` exists.
    2. Load the train/validation/test data via ``load_and_process_data``.
    3. If ``train`` or ``retrain`` is true, fit new models and save them;
       otherwise load previously saved models.
    4. Open a Weights & Biases run, score the models on the validation split
       and log the metrics, the classification report and the validation
       predictions to that run.
    5. Predict on the test set and write the result to
       ``final_predictions.csv`` in the current working directory.

    Args:
        train: Fit new models instead of loading saved ones.
        retrain: Currently handled exactly like ``train`` -- either flag
            triggers a full ``train_models`` + ``save_models`` cycle.

    Returns:
        None. The results are the saved models, the W&B run and the CSV file.
    """
    # NOTE(review): the leading glyphs in the status messages below look like
    # mis-decoded emoji, and "Makings" looks like a typo. They are runtime
    # output and are left untouched here -- confirm the file encoding.

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(MODEL_DIR, exist_ok=True)

    print("\nšŸš€ Loading data...")
    X_train, X_val, y_train, y_val, test_df = load_and_process_data()

    if train or retrain:
        print("\nšŸš€ Training models...")
        models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
        save_models(models)
    else:
        print("\nšŸš€ Loading existing models...")
        models = load_models()

    # Hyper-parameters recorded as the run config and logged with the metrics.
    param_grid = {
        "CATBOOST_PARAMS": CATBOOST_PARAMS,
        "XGB_PARAMS": XGB_PARAMS,
        "RF_PARAMS": RF_PARAMS,
    }

    # wandb picks up WANDB_API_KEY from the environment by itself, so no
    # explicit lookup is needed here. Using the run as a context manager
    # guarantees it is finished (and marked as failed) if anything below
    # raises, instead of leaving a dangling run.
    with wandb.init(project="is_click_predictor", config=param_grid) as run:
        print("\nšŸ” Makings predictions for validation set...")
        predictions_val = predict(models, X_val)
        y_pred_val = predictions_val["is_click_predicted"]

        accuracy_val = accuracy_score(y_val, y_pred_val)
        balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)

        # output_dict=True yields a nested dict; transpose it so that each
        # row of the resulting frame is one entry of the report.
        classification_report_val = classification_report(y_val, y_pred_val, output_dict=True)
        classification_report_val = pd.DataFrame(classification_report_val).transpose()

        predictions_val_table = wandb.Table(dataframe=predictions_val)
        classification_report_val_table = wandb.Table(dataframe=classification_report_val)

        print("\nšŸ” Making predictions for test set...")
        predictions = predict(models, test_df)

        # Push everything gathered above to the W&B run in a single step.
        run.log({
            "param_grid": param_grid,
            "accuracy_val": accuracy_val,
            "balanced_accuracy_val": balanced_accuracy_val,
            "classification_report_val_table": classification_report_val_table,
            "predictions_val_table": predictions_val_table,
            "y_val": y_val.tolist(),
        })

    # Persist the test-set predictions once the run has been closed.
    predictions.to_csv("final_predictions.csv", index=False)
    print("\nāœ… Predictions saved successfully as 'final_predictions.csv'!")


# ===========================
# COMMAND-LINE EXECUTION
# ===========================
if __name__ == "__main__":
    # NOTE(review): command-line parsing is currently disabled, so running the
    # script always trains new models. Re-enabling the parser below as-is
    # would change that default to "load existing models" unless --train is
    # passed -- confirm the intended default before switching it back on.
    # parser = argparse.ArgumentParser(description="Train, retrain or make predictions")
    # parser.add_argument("--train", action="store_true", help="Train new models")
    # parser.add_argument("--retrain", action="store_true", help="Retrain models with updated data")
    #
    # args = parser.parse_args()
    main(train=True, retrain=False)