Spaces:
Configuration error
Configuration error
import argparse | |
import os | |
from data_loader import load_and_process_data, CATEGORICAL_COLUMNS | |
from model_trainer import train_models | |
from model_manager import save_models, load_models | |
from model_predictor import predict | |
from config import MODEL_DIR, CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS | |
import wandb | |
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report | |
import pandas as pd | |
## =========================== | |
# MAIN FUNCTION | |
# =========================== | |
def main(train: bool = True, retrain: bool = False) -> None:
    """Train (or load) the models, score them on validation data, and predict.

    Steps performed, in order:
      1. Ensure ``MODEL_DIR`` exists and load the train/validation/test data.
      2. Train and save new models when ``train`` or ``retrain`` is set;
         otherwise load previously saved models.
      3. Predict on the validation split, compute accuracy, balanced accuracy
         and a classification report, and log them to a wandb run.
      4. Predict on the test set and write ``final_predictions.csv`` to the
         current working directory.

    Args:
        train: When true, train new models before predicting.
        retrain: Handled exactly like ``train`` in this function; either flag
            triggers training followed by saving the models.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test followed by os.makedirs().
    os.makedirs(MODEL_DIR, exist_ok=True)

    print("\nπ Loading data...")
    X_train, X_val, y_train, y_val, test_df = load_and_process_data()

    if train or retrain:
        print("\nπ Training models...")
        models = train_models(X_train, y_train, CATEGORICAL_COLUMNS)
        save_models(models)
    else:
        print("\nπ Loading existing models...")
        models = load_models()

    # Hyper-parameter sets recorded both as the run config and as a log entry.
    param_grid = {
        "CATBOOST_PARAMS": CATBOOST_PARAMS,
        "XGB_PARAMS": XGB_PARAMS,
        "RF_PARAMS": RF_PARAMS,
    }

    # NOTE(review): the previous bare os.getenv("WANDB_API_KEY") call threw
    # its result away and has been dropped; wandb is expected to read that
    # variable from the environment on its own -- confirm for the deployment.
    #
    # Using the run as a context manager guarantees it is closed even when
    # prediction, metric computation or logging raises.
    with wandb.init(project="is_click_predictor", config=param_grid) as run:
        print("\nπ Makings predictions for validation set...")
        predictions_val = predict(models, X_val)
        predicted_labels = predictions_val["is_click_predicted"]

        accuracy_val = accuracy_score(y_val, predicted_labels)
        balanced_accuracy_val = balanced_accuracy_score(y_val, predicted_labels)

        # output_dict=True yields a nested dict; transposing the DataFrame
        # built from it puts one report entry per row.
        classification_report_val = pd.DataFrame(
            classification_report(y_val, predicted_labels, output_dict=True)
        ).transpose()

        predictions_val_table = wandb.Table(dataframe=predictions_val)
        classification_report_val_table = wandb.Table(
            dataframe=classification_report_val
        )

        print("\nπ Making predictions for test set...")
        predictions = predict(models, test_df)

        # wandb logging
        run.log(
            {
                "param_grid": param_grid,
                "accuracy_val": accuracy_val,
                "balanced_accuracy_val": balanced_accuracy_val,
                "classification_report_val_table": classification_report_val_table,
                "predictions_val_table": predictions_val_table,
                "y_val": y_val.tolist(),
            }
        )

    # Save final predictions
    predictions.to_csv("final_predictions.csv", index=False)
    print("\nβ Predictions saved successfully as 'final_predictions.csv'!")
# =========================== | |
# COMMAND-LINE EXECUTION | |
# =========================== | |
if __name__ == "__main__":
    # Command-line flag parsing is switched off for now; the disabled parser
    # is kept here for reference. Until it is restored, the script is always
    # started with train=True and retrain=False.
    #
    # parser = argparse.ArgumentParser(description="Train, retrain or make predictions")
    # parser.add_argument("--train", action="store_true", help="Train new models")
    # parser.add_argument("--retrain", action="store_true", help="Retrain models with updated data")
    #
    # args = parser.parse_args()
    run_options = {"train": True, "retrain": False}
    main(**run_options)