Spaces:
Configuration error
Configuration error
import time | |
from catboost import CatBoostClassifier | |
from xgboost import XGBClassifier | |
from sklearn.ensemble import RandomForestClassifier | |
from config import CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS | |
def _fit_and_report(label, estimator_cls, params, features, target, **fit_kwargs):
    """Build ``estimator_cls(**params)``, fit it, print the elapsed time, return it."""
    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments that happen while the model is fitting.
    started = time.perf_counter()
    model = estimator_cls(**params)
    model.fit(features, target, **fit_kwargs)
    print(f"β {label} trained in {time.perf_counter() - started:.2f} sec")
    return model


def train_models(X_train, y_train, categorical_columns):
    """Train the CatBoost, XGBoost and RandomForest classifiers.

    Args:
        X_train: Training features. Must expose ``columns.get_loc`` and
            accept boolean-mask indexing (both are used below).
        y_train: Training labels. Must expose ``isna``, ``dropna`` and
            ``unique``.
        categorical_columns: Names of the columns in ``X_train`` whose
            positions are handed to CatBoost as ``cat_features``.

    Returns:
        A dict with the keys ``"CatBoost"``, ``"XGBoost"`` and
        ``"RandomForest"``, in that order, mapping to the fitted models.
        ``"XGBoost"`` maps to ``None`` when, after rows with missing labels
        are dropped, the remaining labels are not a subset of ``{0, 1}``.
    """
    models = {}

    # CatBoost takes categorical features as positional column indices.
    cat_feature_indices = [
        X_train.columns.get_loc(col) for col in categorical_columns
    ]
    models["CatBoost"] = _fit_and_report(
        "CatBoost",
        CatBoostClassifier,
        CATBOOST_PARAMS,
        X_train,
        y_train,
        cat_features=cat_feature_indices,
    )

    # XGBoost is only trained on labels drawn from {0, 1}. Rows whose label
    # is missing are dropped first; the data is left untouched (no masking,
    # no copy) when nothing is missing.
    missing = y_train.isna()
    if missing.any():
        xgb_features = X_train[~missing]
        xgb_target = y_train.dropna()
    else:
        xgb_features = X_train
        xgb_target = y_train

    if set(xgb_target.unique()) <= {0, 1}:
        models["XGBoost"] = _fit_and_report(
            "XGBoost", XGBClassifier, XGB_PARAMS, xgb_features, xgb_target
        )
    else:
        models["XGBoost"] = None
        print("β XGBoost training skipped due to invalid labels!")

    # NOTE(review): CatBoost and RandomForest are fitted on the unfiltered
    # labels. If y_train can contain missing values, confirm that those
    # estimators accept them.
    models["RandomForest"] = _fit_and_report(
        "RandomForest", RandomForestClassifier, RF_PARAMS, X_train, y_train
    )

    return models