File size: 885 Bytes
b8bf9dd
 
 
 
 
 
 
 
b0bc543
 
 
b8bf9dd
 
 
 
 
 
 
 
b0bc543
b8bf9dd
 
 
 
b0bc543
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import joblib
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
from config import CATBOOST_MODEL_PATH, XGB_MODEL_PATH, RF_MODEL_PATH

def save_models(models):
    """Persist the trained models to the paths defined in ``config``.

    Args:
        models: Mapping with the keys ``"CatBoost"``, ``"XGBoost"`` and
            ``"RandomForest"``. The ``"XGBoost"`` entry may be ``None``, in
            which case no XGBoost file is written.

    Raises:
        KeyError: If one of the three expected keys is missing.
    """
    models["CatBoost"].save_model(CATBOOST_MODEL_PATH)

    xgb_model = models["XGBoost"]
    if xgb_model is not None:
        # Save through the scikit-learn wrapper instead of the raw Booster:
        # load_models() restores this file with XGBClassifier.load_model, and
        # the wrapper's save_model keeps the estimator metadata alongside the
        # booster so the round trip yields a complete classifier.
        xgb_model.save_model(XGB_MODEL_PATH)

    joblib.dump(models["RandomForest"], RF_MODEL_PATH)
    print("✅ Models saved successfully!")

def load_models():
    """Load the persisted models from the paths defined in ``config``.

    Returns:
        dict: Mapping with the keys ``"CatBoost"``, ``"XGBoost"`` and
        ``"RandomForest"``. ``"XGBoost"`` is ``None`` when no file exists at
        ``XGB_MODEL_PATH``, which is the state ``save_models`` leaves behind
        when it was given ``None`` for that model.
    """
    import os  # local import: only needed for the existence check below

    catboost = CatBoostClassifier()
    catboost.load_model(CATBOOST_MODEL_PATH)

    # save_models() skips the XGBoost file when that model is None, so a
    # missing file is a legitimate state rather than an error.
    xgb = None
    if os.path.exists(XGB_MODEL_PATH):
        xgb = XGBClassifier()
        xgb.load_model(XGB_MODEL_PATH)

    # NOTE(review): joblib.load unpickles the file; only load artifacts from a
    # trusted location.
    rf = joblib.load(RF_MODEL_PATH)

    return {"CatBoost": catboost, "XGBoost": xgb, "RandomForest": rf}