# gvhd-intel-pro / src/model_utils.py
# (Hugging Face Space file header: last updated by Synav, commit ac14a48,
# "Update src/model_utils.py")
from pathlib import Path

import streamlit as st
import yaml
from catboost import CatBoostClassifier
# from xgboost import XGBClassifier
# from lightgbm import LGBMClassifier
from sklearn.ensemble import RandomForestClassifier

# Directory holding the per-target hyperparameter YAML files.
MODEL_DIR = Path("src/params")
# MODEL_DIR.mkdir(exist_ok=True)  # left disabled: the params dir ships with the repo
def load_model_params(model_type, target="GVHD", mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
    """Load hyperparameters for one model/mode combination from a YAML file.

    Args:
        model_type: One of 'CatBoost', 'XGBoost', 'LightGBM', 'RandomForest'.
        target: Prediction target label; validated against the known set.
        mode: 'ensemble', 'single', or 'single_model' ('single' is an alias
            for 'single_model').
        path: YAML file containing a {model_type: {mode: params}} mapping.

    Returns:
        dict of hyperparameters, i.e. the YAML entry [model_type][mode].

    Raises:
        ValueError: on an unknown target, mode, or model_type.
        KeyError: if the YAML file lacks the requested model_type/mode entry.

    Side effect: if the params contain 'random_seed', it is copied into
    st.session_state.random_seed so other app pages can reuse the same seed.
    """
    if target not in ["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]:
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    if mode == "single":  # accept the shorthand used by callers
        mode = "single_model"
    if mode not in ["ensemble", "single_model"]:
        raise ValueError("mode must be either 'ensemble', 'single', or 'single_model'")
    if model_type not in ["CatBoost", "XGBoost", "LightGBM", "RandomForest"]:
        raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
    # Explicit encoding so the file reads identically on every platform.
    with open(path, "r", encoding="utf-8") as f:
        all_params = yaml.safe_load(f)
    params = all_params[model_type][mode]
    if "random_seed" in params:
        st.session_state.random_seed = params["random_seed"]
    return params
def get_model(model_type, mode="ensemble", target="GVHD", best_iter=None, class_weights=None, auto_class_weights=None):
    """Instantiate an unfitted classifier configured from the saved YAML params.

    Args:
        model_type: 'CatBoost' or 'RandomForest' (other types validated by
            load_model_params but not constructible here).
        mode: 'ensemble', 'single', or 'single_model'.
        target: Prediction target; selects which params YAML file is read.
        best_iter: Optional iteration count (e.g. from early stopping) that
            overrides the stored value.
        class_weights: Optional explicit class weights (CatBoost only).
        auto_class_weights: Optional CatBoost auto-weight strategy; ignored
            when class_weights is given.

    Returns:
        An unfitted CatBoostClassifier or RandomForestClassifier.

    Raises:
        ValueError: for an unsupported target or model_type.
    """
    if mode == "single":
        mode = "single_model"
    # One params file per prediction target.
    if target == "GVHD":
        path = MODEL_DIR / "model_params_gvhd.yaml"
    elif target == "Acute GVHD(<100 days)":
        path = MODEL_DIR / "model_params_acute.yaml"
    elif target == "Chronic GVHD>100 days":
        path = MODEL_DIR / "model_params_chronic.yaml"
    else:
        raise ValueError("Unsupported target.")
    params = load_model_params(model_type, target, mode, path)
    if model_type == "CatBoost":
        if best_iter is not None:
            # Pin the iteration count found via early stopping.
            params["iterations"] = best_iter
        if class_weights is not None:
            # Explicit weights win; drop any conflicting auto-weight settings.
            params["class_weights"] = class_weights
            params.pop("auto_class_weights", None)
            params.pop("AutoClassWeights", None)
        elif auto_class_weights is not None:
            params["auto_class_weights"] = auto_class_weights
            params.pop("class_weights", None)
        return CatBoostClassifier(**params)
    elif model_type == "RandomForest":
        if best_iter is not None:
            # BUG FIX: 'iterations' was previously injected unconditionally,
            # which made RandomForestClassifier(**params) raise TypeError.
            # The RandomForest analogue of an iteration count is n_estimators.
            params["n_estimators"] = best_iter
        return RandomForestClassifier(**params)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
def save_model(
    model,
    user_model_name,
    metrics_result_single=None,
    orig_train_cols=None,
    extra_metadata=None,
    threshold=None,
):
    """Pickle a trained single model plus metadata and upload it to the HF
    dataset repo as a one-row Parquet file under models/<filename>.parquet.

    Args:
        model: Fitted classifier to persist.
        user_model_name: Human-readable label embedded in the filename.
        metrics_result_single: Optional evaluation metrics stored alongside.
        orig_train_cols: Column names used at training time (defaults to []).
        extra_metadata: Optional free-form dict of extra info (defaults to {}).
        threshold: Optional decision threshold to store.

    Returns:
        The generated model filename (without the .parquet extension).

    Raises:
        EnvironmentError: if HF_REPO_ID or HF_TOKEN is not set.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, HfApi
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # Name format: <yymmdd_HHMMSS><first letter of target>_<label>_single
    # (parsed back by _MODEL_NAME_RE in the latest-model helpers below).
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_single"
    # Everything needed to reuse the model later goes into one pickled dict.
    model_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": model,
        "best_iteration": st.session_state.get("best_iteration"),
        "metrics_result_single": metrics_result_single,
        "orig_train_cols": orig_train_cols or [],
        "threshold": threshold,
        "extra_metadata": extra_metadata or {},
    }
    model_bytes = pickle.dumps(model_data)
    # One Parquet row per saved model.
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "single",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    # Embed HF dataset-viewer feature metadata so the file renders correctly.
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write to an in-memory buffer and upload without touching disk.
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    api = HfApi(token=os.environ["HF_TOKEN"])
    # BUG FIX: this previously uploaded to the literal path
    # "models/(unknown).parquet", so every save clobbered the same file.
    # Upload under the generated unique filename instead. The CommitScheduler
    # that was constructed here only to reach its .api attribute (while
    # watching a dummy folder in the background) is also gone.
    api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename
# LEGACY: ensemble workflow retained temporarily for backward compatibility.
# New training/inference should use single-model save/load.
def save_model_ensemble(
    models,
    user_model_name,
    best_iterations=None,
    fold_scores=None,
    metrics_result_ensemble=None,
    orig_train_cols=None,
    threshold=None,
):
    """Pickle a list of fold models plus metadata and upload it to the HF
    dataset repo as a one-row Parquet file under models/<filename>.parquet.

    Args:
        models: List of fitted fold classifiers.
        user_model_name: Human-readable label embedded in the filename.
        best_iterations: Optional per-fold best iteration counts.
        fold_scores: Optional per-fold validation scores.
        metrics_result_ensemble: Optional ensemble evaluation metrics.
        orig_train_cols: Column names used at training time (defaults to []).
        threshold: Optional decision threshold to store.

    Returns:
        The generated model filename (without the .parquet extension).

    Raises:
        EnvironmentError: if HF_REPO_ID or HF_TOKEN is not set.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, HfApi
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # Name format: <yymmdd_HHMMSS><first letter of target>_<label>_ensemble
    # (parsed back by _MODEL_NAME_RE in the latest-model helpers below).
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_ensemble"
    ensemble_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": models,
        "best_iterations": best_iterations,
        "fold_scores": fold_scores,
        "metrics_result_ensemble": metrics_result_ensemble,
        "orig_train_cols": orig_train_cols or [],
        "threshold": threshold,
    }
    model_bytes = pickle.dumps(ensemble_data)
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "ensemble",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    # Embed HF dataset-viewer feature metadata so the file renders correctly.
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write to an in-memory buffer and upload without touching disk.
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    api = HfApi(token=os.environ["HF_TOKEN"])
    # BUG FIX: this previously uploaded to the literal path
    # "models/(unknown).parquet", so every save clobbered the same file
    # (and collided with save_model's uploads). Upload under the generated
    # unique filename; the dummy CommitScheduler is removed as in save_model.
    api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename
def load_model(model_name: str):
    """Find and unpickle a saved model by its stored 'filename' field.

    Downloads models/*.parquet files from the HF dataset repo one at a time
    until a row whose 'filename' equals *model_name* is found, then unpickles
    that row's payload (the dict written by save_model / save_model_ensemble).

    Raises:
        EnvironmentError: if HF_REPO_ID or HF_TOKEN is not set.
        FileNotFoundError: if no parquet row matches model_name.

    NOTE(review): this may download many files before hitting a match; the
    repo filename is not relied on to equal the stored 'filename' field.
    SECURITY NOTE: pickle.loads executes arbitrary code — only load models
    from a repo you control.
    """
    from huggingface_hub import login, hf_hub_download
    import pyarrow.parquet as pq
    import pickle
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    from huggingface_hub import HfApi
    api = HfApi(token=os.environ["HF_TOKEN"])
    all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
    # Scan the parquet files until one contains the requested filename.
    target_file = None
    for f in model_files:
        downloaded = hf_hub_download(
            repo_id=os.environ["HF_REPO_ID"],
            repo_type="dataset",
            filename=f,
            token=os.environ["HF_TOKEN"]
        )
        table = pq.read_table(downloaded)
        # Each file written by save_model holds exactly one row.
        row = table.to_pylist()[0]
        if row["filename"] == model_name:
            target_file = downloaded
            break
    if not target_file:
        raise FileNotFoundError(f"Model {model_name} not found in repo.")
    # 'row' still holds the matched row (loop variable leaks intentionally).
    model_bytes = row["model_file"]["bytes"]
    return pickle.loads(model_bytes)
# LEGACY: ensemble workflow retained temporarily for backward compatibility.
# New training/inference should use single-model save/load.
def load_model_ensemble(filename):
    """Legacy alias for load_model; ensemble payloads share the same format."""
    return load_model(filename)
# LEGACY: ensemble workflow retained temporarily for backward compatibility.
# New training/inference should use single-model save/load.
def ensemble_predict(models, X, cat_features):
    """Average the positive-class probabilities over an ensemble of models.

    Args:
        models: Iterable of fitted classifiers exposing predict_proba.
        X: Feature matrix passed unchanged to each model.
        cat_features: Unused; retained for backward compatibility with callers.

    Returns:
        Array of mean P(class=1) across all models.

    Raises:
        ValueError: if models is empty (previously ZeroDivisionError).
    """
    models = list(models)
    if not models:
        raise ValueError("ensemble_predict requires at least one model")
    # Mean of the positive-class column across all fold models.
    return sum(model.predict_proba(X)[:, 1] for model in models) / len(models)
# -------------------------
# Latest-model helpers (HF)
# -------------------------
import re
from datetime import datetime
# Saved-model name convention (written by save_model / save_model_ensemble):
# "<yymmdd>_<HHMMSS><target initial>_<label>_<mode>",
# e.g. "260312_143055A_my_model_single".
_MODEL_NAME_RE = re.compile(
    r"^(?P<date>\d{6})_(?P<time>\d{6})(?P<tgt>[A-Z])_(?P<label>.+)_(?P<mode>single|ensemble)$"
)
def _parse_model_dt(model_name: str):
    """Return the save timestamp embedded in a model name as a datetime,
    or None if the name does not follow the saved-model convention."""
    match = _MODEL_NAME_RE.match(model_name)
    if match is None:
        return None
    stamp = match.group("date") + match.group("time")
    return datetime.strptime(stamp, "%y%m%d%H%M%S")
def list_saved_model_names_hf() -> list[str]:
    """
    Lists model names from HF dataset under models/*.parquet.
    Returns names without .parquet, e.g.:
        ["260312_143055A_some_model_name_single", ...]
    """
    from huggingface_hub import HfApi, login
    from pathlib import Path
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    api = HfApi(token=os.environ.get("HF_TOKEN"))
    repo_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    # Keep only the saved-model parquet files, stripped to their bare names.
    return [
        Path(name).stem
        for name in repo_files
        if name.startswith("models/") and name.endswith(".parquet")
    ]
def get_latest_model_name_hf(target_initial: str, mode: str = "single", name_contains: str | None = None) -> str:
    """
    Pick latest model by timestamp in name.
    Filters:
      - target_initial: 'A' acute, 'C' chronic (from your naming convention)
      - mode: 'ensemble' or 'single'
      - name_contains: optional substring filter ("Acute GVHD", "Chronic GVHD")
    """
    target_initial = target_initial.upper().strip()
    mode = mode.lower().strip()
    if mode not in {"ensemble", "single"}:
        raise ValueError("mode must be 'ensemble' or 'single'")
    # Collect (timestamp, name) pairs that survive every filter.
    candidates = []
    for name in list_saved_model_names_hf():
        parsed = _MODEL_NAME_RE.match(name)
        if parsed is None:
            continue
        if parsed.group("tgt") != target_initial or parsed.group("mode") != mode:
            continue
        if name_contains and name_contains.lower() not in name.lower():
            continue
        stamp = _parse_model_dt(name)
        if stamp is not None:
            candidates.append((stamp, name))
    if not candidates:
        msg = f"No matching {mode} model found for target_initial={target_initial}"
        if name_contains:
            msg += f" and name_contains='{name_contains}'"
        raise FileNotFoundError(msg)
    # max() with a key returns the first of any tied entries, matching the
    # stable reverse-sort the original used.
    latest = max(candidates, key=lambda pair: pair[0])
    return latest[1]
def load_latest_acute_ensemble():
    """Load the most recently saved acute-GVHD ensemble model from HF."""
    return load_model_ensemble(
        get_latest_model_name_hf("A", mode="ensemble", name_contains="Acute GVHD")
    )
def load_latest_chronic_ensemble():
    """Load the most recently saved chronic-GVHD ensemble model from HF."""
    return load_model_ensemble(
        get_latest_model_name_hf("C", mode="ensemble", name_contains="Chronic GVHD")
    )
def load_latest_acute_single():
    """Load the most recently saved acute-GVHD single model from HF."""
    return load_model(get_latest_model_name_hf("A", mode="single"))
def load_latest_chronic_single():
    """Load the most recently saved chronic-GVHD single model from HF."""
    return load_model(get_latest_model_name_hf("C", mode="single"))
def load_latest_gvhd_single():
    """Load the most recently saved overall-GVHD single model from HF."""
    return load_model(get_latest_model_name_hf("G", mode="single"))