Spaces:
Running
Running
| import streamlit as st | |
| from pathlib import Path | |
| from catboost import CatBoostClassifier | |
| # from xgboost import XGBClassifier | |
| # from lightgbm import LGBMClassifier | |
| from sklearn.ensemble import RandomForestClassifier | |
| MODEL_DIR = Path("src/params") | |
| # MODEL_DIR.mkdir(exist_ok=True) | |
| import yaml | |
def load_model_params(model_type, target="GVHD", mode="ensemble", path=MODEL_DIR / "model_params.yaml"):
    """Load hyperparameters for one model type from a YAML params file.

    Validates target/mode/model_type, normalizes the short mode alias
    'single' to 'single_model', and mirrors any 'random_seed' entry into
    the Streamlit session state before returning the parameter dict.
    """
    valid_targets = ("GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days")
    if target not in valid_targets:
        raise ValueError("target must be one of 'GVHD', 'Acute GVHD(<100 days)', or 'Chronic GVHD>100 days'")
    # Accept the short alias used by callers elsewhere in the app.
    if mode == "single":
        mode = "single_model"
    if mode not in ("ensemble", "single_model"):
        raise ValueError("mode must be either 'ensemble', 'single', or 'single_model'")
    if model_type not in ("CatBoost", "XGBoost", "LightGBM", "RandomForest"):
        raise ValueError("model_type must be one of 'CatBoost', 'XGBoost', 'LightGBM', or 'RandomForest'")
    with open(path, "r") as f:
        all_params = yaml.safe_load(f)
    params = all_params[model_type][mode]
    # Keep the app-wide seed in sync with the loaded params.
    if "random_seed" in params:
        st.session_state.random_seed = params["random_seed"]
    return params
def get_model(model_type, mode="ensemble", target="GVHD", best_iter=None, class_weights=None, auto_class_weights=None):
    """Build an unfitted classifier configured from the per-target YAML params.

    Parameters
    ----------
    model_type : 'CatBoost' or 'RandomForest' are constructible here
        (load_model_params also accepts 'XGBoost'/'LightGBM' but those
        raise ValueError below).
    mode : 'ensemble', 'single', or 'single_model'.
    target : one of the three GVHD target labels; selects the params file.
    best_iter : optional tree-count override (CatBoost 'iterations',
        RandomForest 'n_estimators').
    class_weights / auto_class_weights : optional CatBoost weighting;
        explicit class_weights wins and strips any auto setting.

    Raises ValueError for an unsupported target or model type.
    """
    if mode == "single":
        mode = "single_model"
    if target == "GVHD":
        path = MODEL_DIR / "model_params_gvhd.yaml"
    elif target == "Acute GVHD(<100 days)":
        path = MODEL_DIR / "model_params_acute.yaml"
    elif target == "Chronic GVHD>100 days":
        path = MODEL_DIR / "model_params_chronic.yaml"
    else:
        raise ValueError("Unsupported target.")
    params = load_model_params(model_type, target, mode, path)
    if model_type == "CatBoost":
        if best_iter is not None:
            params["iterations"] = best_iter
        if class_weights is not None:
            # Explicit weights and auto weights are mutually exclusive in CatBoost.
            params["class_weights"] = class_weights
            params.pop("auto_class_weights", None)
            params.pop("AutoClassWeights", None)
        elif auto_class_weights is not None:
            params["auto_class_weights"] = auto_class_weights
            params.pop("class_weights", None)
        return CatBoostClassifier(**params)
    elif model_type == "RandomForest":
        if best_iter is not None:
            # BUG FIX: 'iterations' is a CatBoost-only kwarg and previously
            # was injected for every model type, which made
            # RandomForestClassifier(**params) raise TypeError whenever
            # best_iter was supplied. RandomForest's equivalent knob is
            # n_estimators.
            params["n_estimators"] = best_iter
        return RandomForestClassifier(**params)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")
def save_model(
    model,
    user_model_name,
    metrics_result_single=None,
    orig_train_cols=None,
    extra_metadata=None,
    threshold=None,
):
    """Pickle a trained single model plus metadata and upload it to the
    Hugging Face dataset repo named by env HF_REPO_ID.

    The artifact is written as a one-row Parquet file at
    models/<filename>.parquet whose 'model_file' column holds the pickle.

    Parameters
    ----------
    model : fitted classifier to persist.
    user_model_name : human-readable label embedded in the generated name.
    metrics_result_single, orig_train_cols, extra_metadata, threshold :
        optional metadata stored alongside the model.

    Returns the generated model name (without the .parquet suffix).
    Raises EnvironmentError when HF_REPO_ID or HF_TOKEN is unset.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, HfApi
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # Name = timestamp + first letter of target column + label + mode suffix.
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_single"
    # Prepare model dict (same as before)
    model_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": model,
        "best_iteration": st.session_state.get("best_iteration"),
        "metrics_result_single": metrics_result_single,
        "orig_train_cols": orig_train_cols or [],
        "threshold": threshold,
        "extra_metadata": extra_metadata or {},
    }
    # Serialize (pickle) to bytes
    model_bytes = pickle.dumps(model_data)
    # Prepare Parquet row
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "single",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    # Write to in-memory buffer
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    # Upload to HF dataset.
    # BUG FIX: the upload path previously did not interpolate the model
    # filename, so every save targeted the same repo object and name-based
    # discovery could not distinguish models. Also replaced the misused
    # CommitScheduler (it was created only to reach `.api` and spawned a
    # background commit thread watching a dummy folder) with a plain HfApi
    # client; create_repo(exist_ok=True) preserves repo auto-creation.
    api = HfApi(token=os.environ["HF_TOKEN"])
    api.create_repo(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        private=True,
        exist_ok=True,
    )
    api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename
| # LEGACY: ensemble workflow retained temporarily for backward compatibility. | |
| # New training/inference should use single-model save/load. | |
def save_model_ensemble(
    models,
    user_model_name,
    best_iterations=None,
    fold_scores=None,
    metrics_result_ensemble=None,
    orig_train_cols=None,
    threshold=None,
):
    """Pickle a list of fold models plus metadata and upload them to the
    Hugging Face dataset repo named by env HF_REPO_ID.

    The artifact is written as a one-row Parquet file at
    models/<filename>.parquet whose 'model_file' column holds the pickle.

    Parameters
    ----------
    models : list of fitted fold models.
    user_model_name : human-readable label embedded in the generated name.
    best_iterations, fold_scores, metrics_result_ensemble,
    orig_train_cols, threshold : optional per-fold metadata.

    Returns the generated model name (without the .parquet suffix).
    Raises EnvironmentError when HF_REPO_ID or HF_TOKEN is unset.
    """
    from datetime import datetime
    import io
    import pickle
    import json
    import pyarrow as pa
    import pyarrow.parquet as pq
    from huggingface_hub import login, HfApi
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
    # Name = timestamp + first letter of target column + label + mode suffix.
    filename = f"{timestamp}{st.session_state.get('target_col', 'UNKNOWN')[0]}_{user_model_name}_ensemble"
    ensemble_data = {
        "timestamp": timestamp,
        "model_name": user_model_name,
        "target_col": st.session_state.get("target_col", "UNKNOWN"),
        "model": models,
        "best_iterations": best_iterations,
        "fold_scores": fold_scores,
        "metrics_result_ensemble": metrics_result_ensemble,
        "orig_train_cols": orig_train_cols or [],
        "threshold": threshold,
    }
    model_bytes = pickle.dumps(ensemble_data)
    row = {
        "filename": filename,
        "timestamp": timestamp,
        "type": "ensemble",
        "model_file": {"path": filename, "bytes": model_bytes},
    }
    table = pa.Table.from_pylist([row])
    table = table.replace_schema_metadata({
        "huggingface": json.dumps({"info": {
            "features": {
                "filename": {"_type": "Value", "dtype": "string"},
                "timestamp": {"_type": "Value", "dtype": "string"},
                "type": {"_type": "Value", "dtype": "string"},
                "model_file": {"_type": "Value", "dtype": "binary"},
            }
        }})
    })
    buf = io.BytesIO()
    pq.write_table(table, buf)
    buf.seek(0)
    # BUG FIX: same two issues as save_model — the upload path did not
    # interpolate the filename (so every ensemble overwrote one object),
    # and CommitScheduler was instantiated only to reach `.api`. Use HfApi
    # directly; create_repo(exist_ok=True) preserves repo auto-creation.
    api = HfApi(token=os.environ["HF_TOKEN"])
    api.create_repo(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        private=True,
        exist_ok=True,
    )
    api.upload_file(
        repo_id=os.environ["HF_REPO_ID"],
        repo_type="dataset",
        path_in_repo=f"models/{filename}.parquet",
        path_or_fileobj=buf,
    )
    return filename
def load_model(model_name):
    """Download a saved model Parquet from the HF dataset repo and unpickle it.

    PERF: tries the canonical path models/<model_name>.parquet first so the
    common case downloads exactly one file; falls back to scanning every
    Parquet row for a matching 'filename' (covers legacy uploads whose repo
    path does not follow the naming scheme).

    Raises EnvironmentError when HF_REPO_ID or HF_TOKEN is unset, and
    FileNotFoundError when no matching model exists.
    """
    from huggingface_hub import login, hf_hub_download, HfApi
    import pyarrow.parquet as pq
    import pickle
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    api = HfApi(token=os.environ["HF_TOKEN"])
    all_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    model_files = [f for f in all_files if f.startswith("models/") and f.endswith(".parquet")]
    # Probe the expected path first; keep the remaining files as fallback.
    direct = f"models/{model_name}.parquet"
    ordered = ([direct] if direct in model_files else []) + [f for f in model_files if f != direct]
    for f in ordered:
        downloaded = hf_hub_download(
            repo_id=os.environ["HF_REPO_ID"],
            repo_type="dataset",
            filename=f,
            token=os.environ["HF_TOKEN"]
        )
        row = pq.read_table(downloaded).to_pylist()[0]
        if row["filename"] == model_name:
            # NOTE(security): pickle.loads on repo content — only safe while
            # the dataset repo stays private and trusted.
            return pickle.loads(row["model_file"]["bytes"])
    raise FileNotFoundError(f"Model {model_name} not found in repo.")
| # LEGACY: ensemble workflow retained temporarily for backward compatibility. | |
| # New training/inference should use single-model save/load. | |
def load_model_ensemble(filename):
    """Legacy alias: ensemble artifacts are loaded exactly like single ones."""
    return load_model(filename)
| # LEGACY: ensemble workflow retained temporarily for backward compatibility. | |
| # New training/inference should use single-model save/load. | |
def ensemble_predict(models, X, cat_features):
    """Return the mean positive-class probability across the fold models.

    cat_features is unused here; it is kept for backward-compatible call
    sites.
    """
    total = 0.0
    for fold_model in models:
        total = total + fold_model.predict_proba(X)[:, 1]
    return total / len(models)
| # ------------------------- | |
| # Latest-model helpers (HF) | |
| # ------------------------- | |
import re
from datetime import datetime
# Matches generated artifact names, e.g. "260312_143055A_my_label_single":
# six-digit date (YYMMDD from save_model's "%y%m%d_%H%M%S" timestamp),
# underscore, six-digit time (HHMMSS), one capital target initial, the
# user-supplied label, and a trailing mode suffix.
_MODEL_NAME_RE = re.compile(
    r"^(?P<date>\d{6})_(?P<time>\d{6})(?P<tgt>[A-Z])_(?P<label>.+)_(?P<mode>single|ensemble)$"
)
def _parse_model_dt(model_name: str):
    """Extract the saved-at datetime from a model name.

    Returns None when the name does not follow the timestamped naming
    convention captured by _MODEL_NAME_RE.
    """
    match = _MODEL_NAME_RE.match(model_name)
    if match is None:
        return None
    stamp = match.group("date") + match.group("time")
    return datetime.strptime(stamp, "%y%m%d%H%M%S")
def list_saved_model_names_hf() -> list[str]:
    """
    Lists model names from HF dataset under models/*.parquet.
    Returns names without .parquet, e.g.:
        ["260312_143055A_some_model_name_single", ...]
    """
    from huggingface_hub import HfApi, login
    from pathlib import Path
    import os
    if "HF_TOKEN" in os.environ:
        login(token=os.environ["HF_TOKEN"])
    if "HF_REPO_ID" not in os.environ or "HF_TOKEN" not in os.environ:
        raise EnvironmentError("HF_REPO_ID or HF_TOKEN not set.")
    api = HfApi(token=os.environ.get("HF_TOKEN"))
    repo_files = api.list_repo_files(repo_id=os.environ["HF_REPO_ID"], repo_type="dataset")
    # Keep only the model artifacts and strip both the folder and suffix.
    return [
        Path(name).stem
        for name in repo_files
        if name.startswith("models/") and name.endswith(".parquet")
    ]
def get_latest_model_name_hf(target_initial: str, mode: str = "single", name_contains: str | None = None) -> str:
    """
    Pick latest model by timestamp in name.
    Filters:
      - target_initial: 'A' acute, 'C' chronic (from your naming convention)
      - mode: 'ensemble' or 'single'
      - name_contains: optional substring filter ("Acute GVHD", "Chronic GVHD")
    """
    target_initial = target_initial.upper().strip()
    mode = mode.lower().strip()
    if mode not in {"ensemble", "single"}:
        raise ValueError("mode must be 'ensemble' or 'single'")
    needle = name_contains.lower() if name_contains else None
    candidates = []
    for name in list_saved_model_names_hf():
        parsed = _MODEL_NAME_RE.match(name)
        if parsed is None:
            continue
        if parsed.group("tgt") != target_initial or parsed.group("mode") != mode:
            continue
        if needle is not None and needle not in name.lower():
            continue
        stamp = _parse_model_dt(name)
        if stamp is not None:
            candidates.append((stamp, name))
    if not candidates:
        msg = f"No matching {mode} model found for target_initial={target_initial}"
        if name_contains:
            msg += f" and name_contains='{name_contains}'"
        raise FileNotFoundError(msg)
    # max() returns the first of any timestamp ties, matching the stable
    # reverse sort the previous implementation used.
    return max(candidates, key=lambda pair: pair[0])[1]
def load_latest_acute_ensemble():
    """Load the newest acute-GVHD ensemble artifact from the HF repo."""
    return load_model_ensemble(
        get_latest_model_name_hf("A", mode="ensemble", name_contains="Acute GVHD")
    )
def load_latest_chronic_ensemble():
    """Load the newest chronic-GVHD ensemble artifact from the HF repo."""
    return load_model_ensemble(
        get_latest_model_name_hf("C", mode="ensemble", name_contains="Chronic GVHD")
    )
def load_latest_acute_single():
    """Load the newest acute-GVHD single-model artifact from the HF repo."""
    return load_model(get_latest_model_name_hf("A", mode="single"))
def load_latest_chronic_single():
    """Load the newest chronic-GVHD single-model artifact from the HF repo."""
    return load_model(get_latest_model_name_hf("C", mode="single"))
def load_latest_gvhd_single():
    """Load the newest overall-GVHD single-model artifact from the HF repo."""
    return load_model(get_latest_model_name_hf("G", mode="single"))