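"""Entry point for the is_click prediction pipeline.

Loads and processes the data, trains new models or loads saved ones,
scores the validation set, logs metrics and tables to Weights & Biases,
and writes test-set predictions to final_predictions.csv.
"""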
import argparse
import os
from data_loader import load_and_process_data, CATEGORICAL_COLUMNS
from model_trainer import train_models
from model_manager import save_models, load_models
from model_predictor import predict
from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS
import wandb
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
import pandas as pd

# ===========================
#  MAIN FUNCTION
# ===========================

def main(train=True, retrain=False):
    """Main entry point: train or retrain models, or load existing ones, then predict."""
    # Create the model directory if it doesn't exist
    os.makedirs(MODEL_DIR, exist_ok=True)

    print("\n🚀 Loading data...")
    X_train, X_val, y_train, y_val, test_df = load_and_process_data()

    if train or retrain:
        print("\n🚀 Training models...")
        models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
        save_models(models)
    else:
        print("\n🚀 Loading existing models...")
        models = load_models()

    # Log hyperparameters and validation-set scores to Weights & Biases
    param_grid = {"CATBOOST_PARAMS": CATBOOST_PARAMS,
                  "XGB_PARAMS": XGB_PARAMS,
                  "RF_PARAMS": RF_PARAMS}
    # wandb reads WANDB_API_KEY from the environment; warn early if it is missing
    if not os.getenv("WANDB_API_KEY"):
        print("⚠️ WANDB_API_KEY is not set; wandb may prompt for an interactive login.")
    run = wandb.init(project="is_click_predictor", config=param_grid)

    print("\nπŸ” Makings predictions for validation set...")
    predictions_val = predict(models, X_val)
    accuracy_val = accuracy_score(y_val, predictions_val["is_click_predicted"])
    balanced_accuracy_val = balanced_accuracy_score(y_val, predictions_val["is_click_predicted"])
    classification_report_val = classification_report(y_val, predictions_val["is_click_predicted"], output_dict=True)
    classification_report_val = pd.DataFrame(classification_report_val).transpose()
    predictions_val_table = wandb.Table(dataframe=predictions_val)
    classification_report_val_table = wandb.Table(dataframe=classification_report_val)

    print("\nπŸ” Making predictions for test set...")
    predictions = predict(models, test_df)

    # Log validation metrics, prediction tables and the raw labels to wandb
    run.log({"param_grid": param_grid,
             "accuracy_val": accuracy_val,
             "balanced_accuracy_val": balanced_accuracy_val,
             "classification_report_val_table": classification_report_val_table,
             "predictions_val_table": predictions_val_table,
             "y_val": y_val.tolist()})
    run.finish()

    # Save final predictions
    predictions.to_csv("final_predictions.csv", index=False)
    print("\nβœ… Predictions saved successfully as 'final_predictions.csv'!")

# ===========================
#  COMMAND-LINE EXECUTION
# ===========================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train, retrain or make predictions")
    parser.add_argument("--train", action="store_true", help="Train new models")
    parser.add_argument("--retrain", action="store_true", help="Retrain models with updated data")

    args = parser.parse_args()
    main(train=args.train, retrain=args.retrain)