import time

from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier

from config import CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS


def _fit_timed(name, model, X, y, **fit_kwargs):
    """Fit ``model`` on ``(X, y)``, print the elapsed time, and return the model."""
    start_time = time.perf_counter()
    model.fit(X, y, **fit_kwargs)
    print(f"✅ {name} trained in {time.perf_counter() - start_time:.2f} sec")
    return model


def train_models(X_train, y_train, categorical_columns=None):
    """
    Train and return machine learning models.

    Rows whose label is missing are removed once, before any model is fitted,
    so CatBoost, XGBoost and RandomForest are all trained on the same rows.

    Args:
        X_train: Feature table; assumed to be a pandas DataFrame (its
            ``.columns.get_loc`` and ``.loc`` are used).
        y_train: Labels row-aligned with ``X_train``; assumed to be a pandas
            Series (``.notna()`` / ``.unique()`` are used).
        categorical_columns: Column names of ``X_train`` passed to CatBoost as
            categorical features. ``None`` means no categorical features.

    Returns:
        dict with keys "CatBoost", "XGBoost" and "RandomForest" mapping to the
        fitted estimators. "XGBoost" is ``None`` when the remaining labels are
        not all in {0, 1}.
    """
    if categorical_columns is None:
        categorical_columns = []

    # Drop unlabeled rows once so every model sees identical data. Previously
    # only XGBoost filtered NaN labels while CatBoost and RandomForest were fed
    # the raw labels. A positional numpy mask avoids any index-alignment issues
    # between X_train and y_train.
    labeled = y_train.notna().to_numpy()
    if not labeled.all():
        print(f"⚠ Dropping {int((~labeled).sum())} rows with missing labels")
        X_train = X_train.loc[labeled]
        y_train = y_train.loc[labeled]

    models = {}

    # CatBoost takes categorical features by column position.
    cat_feature_indices = [X_train.columns.get_loc(col) for col in categorical_columns]
    models["CatBoost"] = _fit_timed(
        "CatBoost",
        CatBoostClassifier(**CATBOOST_PARAMS),
        X_train,
        y_train,
        cat_features=cat_feature_indices,
    )

    # XGBoost is only trained when every label is 0 or 1 (ensure only valid labels exist).
    if set(y_train.unique()) <= {0, 1}:
        models["XGBoost"] = _fit_timed(
            "XGBoost", XGBClassifier(**XGB_PARAMS), X_train, y_train
        )
    else:
        models["XGBoost"] = None
        print("⚠ XGBoost training skipped due to invalid labels!")

    models["RandomForest"] = _fit_timed(
        "RandomForest", RandomForestClassifier(**RF_PARAMS), X_train, y_train
    )

    return models