Spaces:
Configuration error
Configuration error
File size: 1,747 Bytes
b8bf9dd 3b7934d b8bf9dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import time
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from config import CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS
def _fit_timed(name, model, X, y, **fit_kwargs):
    """Fit *model* on ``(X, y)``, print the elapsed time, and return the model."""
    start = time.perf_counter()  # monotonic clock intended for measuring durations
    model.fit(X, y, **fit_kwargs)
    print(f"✅ {name} trained in {time.perf_counter() - start:.2f} sec")
    return model


def train_models(X_train, y_train, categorical_columns):
    """Train and return CatBoost, XGBoost and RandomForest classifiers.

    Rows whose label is missing are dropped once, before any model is fitted,
    so all three models are trained on the same labelled rows.

    Args:
        X_train: Feature table. Categorical columns are resolved through
            ``X_train.columns.get_loc``, so this must be a pandas DataFrame.
        y_train: Label Series, row-aligned by position with ``X_train``.
        categorical_columns: Names of the ``X_train`` columns that CatBoost
            should treat as categorical features.

    Returns:
        dict: Keys ``"CatBoost"``, ``"XGBoost"`` and ``"RandomForest"`` map to
        the fitted models. ``"XGBoost"`` is ``None`` when the (non-missing)
        labels are not a subset of ``{0, 1}``.
    """
    # Use a positional numpy mask rather than a boolean Series: fit() pairs X
    # and y rows by position, and a Series mask would be aligned by index,
    # which fails when X_train and y_train carry different indexes.
    labelled = y_train.notna().to_numpy()
    if not labelled.all():
        X_train = X_train.loc[labelled]
        y_train = y_train.loc[labelled]

    models = {}

    # CatBoost takes categorical features as column positions.
    cat_feature_idx = [X_train.columns.get_loc(col) for col in categorical_columns]
    models["CatBoost"] = _fit_timed(
        "CatBoost",
        CatBoostClassifier(**CATBOOST_PARAMS),
        X_train,
        y_train,
        cat_features=cat_feature_idx,
    )

    # XGBoost is only trained on binary 0/1 labels; otherwise it is skipped.
    if set(y_train.unique()) <= {0, 1}:
        models["XGBoost"] = _fit_timed(
            "XGBoost", XGBClassifier(**XGB_PARAMS), X_train, y_train
        )
    else:
        models["XGBoost"] = None
        print("❌ XGBoost training skipped due to invalid labels!")

    models["RandomForest"] = _fit_timed(
        "RandomForest", RandomForestClassifier(**RF_PARAMS), X_train, y_train
    )
    return models
|