# is_click_predictor / model_trainer.py
# KaiquanMah's picture
# Yair - added error handling for NA y_train labels
# 3b7934d verified
import time
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from config import CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS
def _fit_and_report(label, estimator, X, y, **fit_kwargs):
    """Fit ``estimator`` on ``(X, y)``, print the elapsed time, and return it."""
    start_time = time.time()
    estimator.fit(X, y, **fit_kwargs)
    print(f"βœ… {label} trained in {time.time() - start_time:.2f} sec")
    return estimator


def train_models(X_train, y_train, categorical_columns):
    """Train CatBoost, XGBoost and RandomForest classifiers.

    Rows whose label is missing (NaN/None) are dropped once, up front, so
    that all three models are fitted on the same labelled subset.

    Args:
        X_train: Feature table; must expose ``.columns.get_loc`` and ``.loc``
            (i.e. behave like a pandas DataFrame).
        y_train: Labels; must expose ``.notna``, ``.unique`` and ``.loc``
            (i.e. behave like a pandas Series) and have one entry per row
            of ``X_train``.
        categorical_columns: Names of the columns in ``X_train`` that are
            passed to CatBoost as categorical features.

    Returns:
        dict: ``{"CatBoost": model, "XGBoost": model_or_None,
        "RandomForest": model}``. ``"XGBoost"`` is ``None`` when the
        non-missing labels are not a subset of ``{0, 1}``.

    Raises:
        ValueError: If every label in ``y_train`` is missing.
    """
    # Drop unlabelled rows for every model, not only XGBoost. The mask is
    # converted to a plain boolean array so rows are selected by position
    # and the result does not depend on X_train / y_train sharing an index.
    labelled = y_train.notna()
    if not labelled.all():
        keep = labelled.to_numpy()
        X_train = X_train.loc[keep]
        y_train = y_train.loc[keep]
        if len(y_train) == 0:
            raise ValueError("y_train contains no non-missing labels")

    models = {}

    # Train CatBoost (categorical features are passed as column positions).
    cat_feature_idx = [X_train.columns.get_loc(col) for col in categorical_columns]
    models["CatBoost"] = _fit_and_report(
        "CatBoost",
        CatBoostClassifier(**CATBOOST_PARAMS),
        X_train,
        y_train,
        cat_features=cat_feature_idx,
    )

    # Train XGBoost only when the labels are strictly binary 0/1.
    if set(y_train.unique()) <= {0, 1}:
        models["XGBoost"] = _fit_and_report(
            "XGBoost", XGBClassifier(**XGB_PARAMS), X_train, y_train
        )
    else:
        models["XGBoost"] = None
        print("⚠ XGBoost training skipped due to invalid labels!")

    # Train RandomForest
    models["RandomForest"] = _fit_and_report(
        "RandomForest", RandomForestClassifier(**RF_PARAMS), X_train, y_train
    )

    return models