File size: 1,747 Bytes
b8bf9dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b7934d
 
 
 
 
 
 
 
 
 
 
b8bf9dd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import time
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from config import CATBOOST_PARAMS, XGB_PARAMS, RF_PARAMS

def _fit_timed(name, model, X, y, **fit_kwargs):
    """Fit ``model`` on ``(X, y)``, print the elapsed wall time, and return it.

    Extra keyword arguments are forwarded unchanged to ``model.fit``.
    """
    start = time.perf_counter()  # monotonic clock, suited to measuring durations
    model.fit(X, y, **fit_kwargs)
    print(f"✅ {name} trained in {time.perf_counter() - start:.2f} sec")
    return model


def train_models(X_train, y_train, categorical_columns):
    """Train and return CatBoost, XGBoost and RandomForest classifiers.

    Rows whose label is missing (NaN/None/NA) are dropped once, up front, and
    the same filtered data is used for every model.

    Args:
        X_train: Feature table (pandas DataFrame; its ``.columns`` are used to
            resolve the positions of the categorical columns for CatBoost).
        y_train: Labels (pandas Series). Assumed to be row-for-row aligned with
            ``X_train`` by position — TODO confirm with callers.
        categorical_columns: Iterable of column names in ``X_train`` to pass to
            CatBoost as categorical features; ``None`` means no categoricals.

    Returns:
        dict with keys "CatBoost", "XGBoost", "RandomForest" mapping to fitted
        models. "XGBoost" maps to ``None`` when the (non-missing) labels are not
        a subset of {0, 1}. Presumably this is because XGB_PARAMS configures a
        binary objective — confirm against config.
    """
    # Positional boolean mask: indexing with a NumPy array (not a Series)
    # avoids pandas index alignment, so X and y stay in lockstep even when
    # their indexes differ.
    labeled = y_train.notna().to_numpy()
    if labeled.all():
        X_fit, y_fit = X_train, y_train
    else:
        X_fit, y_fit = X_train[labeled], y_train[labeled]
        print(f"⚠ Dropped {int((~labeled).sum())} rows with missing labels")

    cat_features = (
        []
        if categorical_columns is None
        else [X_fit.columns.get_loc(col) for col in categorical_columns]
    )

    models = {}

    models["CatBoost"] = _fit_timed(
        "CatBoost",
        CatBoostClassifier(**CATBOOST_PARAMS),
        X_fit,
        y_fit,
        cat_features=cat_features,
    )

    # Only train XGBoost when every remaining label is 0 or 1.
    if set(y_fit.unique()) <= {0, 1}:
        models["XGBoost"] = _fit_timed(
            "XGBoost", XGBClassifier(**XGB_PARAMS), X_fit, y_fit
        )
    else:
        models["XGBoost"] = None
        print("⚠ XGBoost training skipped due to invalid labels!")

    models["RandomForest"] = _fit_timed(
        "RandomForest", RandomForestClassifier(**RF_PARAMS), X_fit, y_fit
    )

    return models