# app.py
# ===============================================
# Algae Yield Predictor — Uncertainty + Response Plot + Bounds + DOI
# ===============================================
import re, json
from dataclasses import dataclass
from pathlib import Path
from difflib import get_close_matches
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gradio as gr
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer
from sklearn.neighbors import NearestNeighbors
# --- sklearn <-> wrappers compatibility shim ---
from sklearn.base import BaseEstimator
if not hasattr(BaseEstimator, "sklearn_tags"):
    # Older scikit-learn exposes tags via (_)get_tags(); provide a sklearn_tags() alias
    def _sklearn_tags(self):
        getter = getattr(self, "get_tags", None) or getattr(self, "_get_tags", None)
        return getter() if callable(getter) else {}
    BaseEstimator.sklearn_tags = _sklearn_tags
# Ensemble libs
import joblib
import xgboost as xgb
import lightgbm as lgb
from catboost import CatBoostRegressor
import tensorflow as tf
# ===== Robust sklearn_tags compatibility layer =====
# Works on sklearn<1.6 + 3rd-party wrappers that call super().sklearn_tags()
def _safe_sklearn_tags(self):
    """Return sklearn tags without relying on super().sklearn_tags()."""
    for name in ("get_tags", "_get_tags"):
        getter = getattr(self, name, None)
        if callable(getter):
            try:
                return getter()
            except Exception:
                pass
    return {}
def _patch_class_and_mro(cls):
"""Attach a safe sklearn_tags to cls and all parents in its MRO."""
if not cls or cls is object:
return
for c in getattr(cls, "mro", lambda: [])():
if c is object:
continue
try:
need = not hasattr(c, "sklearn_tags") or not callable(getattr(c, "sklearn_tags"))
if need:
setattr(c, "sklearn_tags", _safe_sklearn_tags)
except Exception:
pass
# Patch common estimator classes up-front
try:
_patch_class_and_mro(xgb.XGBRegressor)
_patch_class_and_mro(xgb.XGBClassifier)
_patch_class_and_mro(xgb.XGBRFRegressor)
_patch_class_and_mro(xgb.XGBRFClassifier)
except Exception:
pass
try:
_patch_class_and_mro(lgb.LGBMRegressor)
_patch_class_and_mro(lgb.LGBMClassifier)
except Exception:
pass
try:
_patch_class_and_mro(CatBoostRegressor)
except Exception:
pass
# ===== end compatibility layer =====
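# Quick sanity check (a sketch, kept as a comment so startup can never fail):
#   assert callable(getattr(xgb.XGBRegressor, "sklearn_tags", None))
# After the patch above, each targeted class should expose a callable sklearn_tags.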
# -----------------------------
# Paths (relative in a Space)
# -----------------------------
ROOT = Path(".")
RAW_PATH = ROOT / "ai_al.csv" # real data (for allowed-pairs + KNN uncertainty imputer)
DOI_PATH = ROOT / "doi.csv" # optional literature db
MODEL_DIR = ROOT / "models" # saved ensemble models
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# ------------------------------------------------
# Species × Medium literature bounds (normalized)
# ------------------------------------------------
BOUNDS_SM = {
"a. platensis": {
"zarrouks": {"biomass": (0.0, 6.5), "lipid": (0.0, 30), "protein": (0.0, 75), "carb": (0.0, 40)},
"bg 11": {"biomass": (0.0, 6.6), "lipid": (0.0, 33), "protein": (0.0, 75), "carb": (0.0, 40)},
},
"c. pyrenoidosa": {
"bg 11": {"biomass": (0.0, 6.5), "lipid": (0.0, 30), "protein": (0.0, 60), "carb": (0.0, 60)},
"bbm": {"biomass": (0.0, 6.5), "lipid": (0.0, 35), "protein": (0.0, 60), "carb": (0.0, 60)},
"selenite media": {"biomass": (0.0, 6.5), "lipid": (0.0, 30), "protein": (0.0, 55), "carb": (0.0, 60)},
},
"c. sorokiniana": {
"bg 11": {"biomass": (0.0, 5.5), "lipid": (0.0, 45), "protein": (0.0, 45), "carb": (0.0, 60)},
"tap": {"biomass": (0.0, 5.5), "lipid": (0.0, 45), "protein": (0.0, 45), "carb": (0.0, 60)},
},
"c. variabilis": {
"bg 11": {"biomass": (0.0, 5.5), "lipid": (0.0, 35), "protein": (0.0, 45), "carb": (0.0, 45)},
"tap": {"biomass": (0.0, 5.5), "lipid": (0.0, 35), "protein": (0.0, 45), "carb": (0.0, 45)},
"zarrouks":{"biomass": (0.0, 5.5), "lipid": (0.0, 35), "protein": (0.0, 45), "carb": (0.0, 45)},
},
"c. vulgaris": {
"bg 11": {"biomass": (0.0, 6.5), "lipid": (0.0, 45), "protein": (0.0, 50), "carb": (0.0, 55)},
"bbm": {"biomass": (0.0, 6.5), "lipid": (0.0, 45), "protein": (0.0, 50), "carb": (0.0, 55)},
},
"c. zofingiensis": {
"bg 11": {"biomass": (0.0, 6.5), "lipid": (0.0, 50), "protein": (0.0, 45), "carb": (0.0, 55)},
"bbm": {"biomass": (0.0, 6.5), "lipid": (0.0, 50), "protein": (0.0, 45), "carb": (0.0, 55)},
"tap": {"biomass": (0.0, 6.5), "lipid": (0.0, 50), "protein": (0.0, 45), "carb": (0.0, 55)},
},
"h. pluvialis": {
"bg 11": {"biomass": (0.0, 4.5), "lipid": (0.0, 60), "protein": (0.0, 50), "carb": (0.0, 55)},
},
"p. purpureum": {
"artificial sea water": {"biomass": (0.0, 6.5), "lipid": (0.0, 35), "protein": (0.0, 40), "carb": (0.0, 55)},
"f2": {"biomass": (0.0, 6.5), "lipid": (0.0, 35), "protein": (0.0, 50), "carb": (0.0, 55)},
"erdseirber and bold nv": {"biomass": (0.0, 6.5), "lipid": (0.0, 30), "protein": (0.0, 40), "carb": (0.0, 40)},
},
"scenedesmus sp.": {
"bg 11": {"biomass": (0.0, 5.5), "lipid": (0.0, 50), "protein": (0.0, 45), "carb": (0.0, 50)},
"bbm": {"biomass": (0.0, 5.5), "lipid": (0.0, 50), "protein": (0.0, 45), "carb": (0.0, 50)},
},
}
MEDIA_ALIASES = {
"zarrouks": ["zarrouk's", "zarrouks", "zarrouk"],
"zorrouks": ["zarrouk's", "zarrouks", "zarrouk"],
"bg 11": ["bg 11", "bg-11", "bg11"],
"bbm": ["bbm", "bold's basal medium", "bold basal medium", "bolds basal medium"],
"tap": ["tap", "tap water"],
"artificial sea water": ["artificial sea water", "artificial seawater", "asw"],
"erdseirber and bold nv": [
"erdschreiber and bold nv", "erdschreiber", "bold nv", "bold's nv", "erdschreiber & bold nv",
"erdseiber and bold 1nv"
],
"f2": ["f/2", "guillard f/2", "f2"],
"selenite media": ["selenite medium", "selenite media", "selenite enrichment"],
}
def normalize_str(x):
if pd.isna(x): return "nan"
return str(x).strip().lower()
def _canon_media_for_bounds(m: str) -> str:
m = normalize_str(m)
if m in MEDIA_ALIASES:
return m
for k, syns in MEDIA_ALIASES.items():
if m == k or m in [normalize_str(s) for s in syns]:
return k
return m
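# Examples (folding media spellings onto canonical tokens):
#   _canon_media_for_bounds("BG-11")               -> "bg 11"
#   _canon_media_for_bounds("bold's basal medium") -> "bbm"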
# Accepts dotted-without-space, dotted-with-space, synonyms, fuzzy fallback.
SPECIES_ALIASES_CANON = {
"a. platensis": ["a.platensis", "a platensis", "arthrospira platensis", "spirulina platensis"],
"c. pyrenoidosa": ["c.pyrenoidosa", "c pyrenoidosa", "chlorella pyrenoidosa"],
"c. sorokiniana": ["c.sorokiniana", "c sorokiniana", "chlorella sorokiniana"],
"c. variabilis": ["c.variabilis", "c variabilis", "chlorella variabilis"],
"c. vulgaris": ["c.vulgaris", "c vulgaris", "chlorella vulgaris"],
"c. zofingiensis": ["c.zofingiensis", "c zofingiensis", "chromochloris zofingiensis", "chlorella zofingiensis"],
"h. pluvialis": ["h.pluvialis", "h pluvialis", "haematococcus pluvialis"],
"p. purpureum": ["p.purpureum", "p purpureum", "porphyridium purpureum"],
"scenedesmus sp.": ["scenedesmus", "scenedesmus sp", "desmodesmus sp."],
}
# --- Map an arbitrary value to a token known by a saved LabelEncoder ---
def _canon_to_known(value, known_classes, alias_map):
"""
Return a token that is guaranteed to exist in known_classes.
- Canonicalize via alias_map
- Exact/normalized match
- Fuzzy fallback
- Else return 'nan' if present, otherwise the first class
"""
# Normalize list of known classes to strings
known = [str(k) for k in list(known_classes)]
# Canonicalize the incoming value using aliases (handles dotted forms etc.)
v = normalize_str(value)
v = _canon_from_alias(v, alias_map)
# Exact hit
if v in known:
return v
# If an alias key resolves to a known token, use it
for k, syns in alias_map.items():
if v == k or v in [normalize_str(s) for s in syns]:
if k in known:
return k
# Try fuzzy match against known tokens
hit = get_close_matches(v, known, n=1, cutoff=0.6)
if hit:
return hit[0]
# Graceful fallback
return "nan" if "nan" in known else known[0]
def _canon_from_alias(value: str, alias_map: dict[str, list[str]]) -> str:
v = normalize_str(value)
if v in alias_map:
return v
for k, syns in alias_map.items():
if v == k or v in [normalize_str(s) for s in syns]:
return k
v2 = v.replace(" .", ".").replace(". ", ".")
for k, syns in alias_map.items():
if v2 == k or v2 in [normalize_str(s) for s in syns]:
return k
v3 = v.replace(" .", ".").replace(".", ". ")
for k, syns in alias_map.items():
if v3 == k or v3 in [normalize_str(s) for s in syns]:
return k
return v
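# Example: _canon_from_alias("arthrospira platensis", SPECIES_ALIASES_CANON)
# returns "a. platensis" (synonym hit on the canonical key).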
def extract_first_float(x: str):
if pd.isna(x): return np.nan
s = str(x)
m = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", s)
return float(m.group(0)) if m else np.nan
def parse_cycle_first(x: str):
if pd.isna(x): return np.nan
s = str(x)
m = re.search(r"(\d+(?:\.\d+)?)\s*:\s*(\d+(?:\.\d+)?)", s)
return float(m.group(1)) if m else extract_first_float(s)
def coerce_numeric(series: pd.Series, mode: str = "float"):
return series.apply(parse_cycle_first if mode == "cycle_first" else extract_first_float)
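# Parsing examples:
#   extract_first_float("pH 7.4") -> 7.4
#   parse_cycle_first("16:8")     -> 16.0  (light:dark cycle; first number wins)
#   parse_cycle_first("150")      -> 150.0 (no colon; falls back to first float)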
def _clamp_scalar(v, lo, hi):
if lo is None or hi is None:
return float(v), False, 0.0
v = float(v)
vc = float(np.clip(v, lo, hi))
return vc, (abs(vc - v) > 1e-12), (vc - v)
def _clamp_array(arr, lo, hi):
arr = np.asarray(arr, dtype=float)
if lo is None or hi is None:
return arr, False
arrc = np.clip(arr, lo, hi)
return arrc, bool(np.any(arrc != arr))
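# Example: _clamp_scalar(7.2, 0.0, 6.5) -> (6.5, True, ≈ -0.7)
# i.e. (clamped value, whether clamping occurred, signed correction applied).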
def get_bounds(species: str, media: str, target: str):
# Canonicalize species + media before lookup
sp_raw = (species or "").strip().lower()
md = _canon_media_for_bounds(media)
tg = (target or "").strip().lower()
sp = _canon_from_alias(sp_raw, SPECIES_ALIASES_CANON)
rng = BOUNDS_SM.get(sp, {}).get(md)
if rng is None or tg not in rng:
return None, None
lo, hi = rng[tg]
return float(lo), float(hi)
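# Examples:
#   get_bounds("C. vulgaris", "BBM", "lipid")   -> (0.0, 45.0)
#   get_bounds("unknown sp.", "bg 11", "lipid") -> (None, None)  # no curated range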
# -----------------------------
# Curated suggestions
# -----------------------------
SPECIES_SUGGESTIONS = {
"a. platensis": {
"biomass": {"light": "60–300", "days": "15–25"},
"lipid": {"light": "High light intensity (stress)", "days": "15–25"},
"protein": {"light": "60–300", "days": "12–18"},
"carb": {"light": "60–300", "days": "15–25"},
},
"c. pyrenoidosa": {
"biomass": {"light": "50–150", "days": "12–25"},
"lipid": {"light": "High light intensity (stress)", "days": "12–25"},
"protein": {"light": "50–150", "days": "12–18"},
"carb": {"light": "50–150", "days": "12–25"},
},
"c. sorokiniana": {
"biomass": {"light": "60–300", "days": "15–25"},
"lipid": {"light": "High light intensity (stress)", "days": "15–25"},
"protein": {"light": "60–300", "days": "12–18"},
"carb": {"light": "60–300", "days": "15–25"},
},
"c. variabilis": {
"biomass": {"light": "60–250", "days": "15–25"},
"lipid": {"light": "High light intensity (stress)", "days": "15–25"},
"protein": {"light": "60–250", "days": "12–18"},
"carb": {"light": "60–250", "days": "15–25"},
},
"c. vulgaris": {
"biomass": {"light": "60–300", "days": "12–21"},
"lipid": {"light": "High light intensity (stress)", "days": "15–21"},
"protein": {"light": "60–300", "days": "12–18"},
"carb": {"light": "60–300", "days": "12–21"},
},
"c. zofingiensis": {
"biomass": {"light": "50–150", "days": "25–30"},
"lipid": {"light": "High light intensity (stress)", "days": "25–30"},
"protein": {"light": "50–150", "days": "25–30"},
"carb": {"light": "50–150", "days": "25–30"},
},
"h. pluvialis": {
"biomass": {"light": "50–250", "days": "25–30"},
"lipid": {"light": "High light intensity (stress)", "days": "25–30"},
"protein": {"light": "50–250", "days": "25–30"},
"carb": {"light": "50–250", "days": "25–30"},
},
"p. purpureum": {
"biomass": {"light": "100–250", "days": "17–19"},
"lipid": {"light": "High light intensity (stress)", "days": "17–19"},
"protein": {"light": "100–250", "days": "12–15"},
"carb": {"light": "100–250", "days": "17–19"},
},
"scenedesmus sp.": {
"biomass": {"light": "50–250", "days": "12–25"},
"lipid": {"light": "High light intensity (stress)", "days": "12–25"},
"protein": {"light": "50–250", "days": "12–20"},
"carb": {"light": "50–250", "days": "12–25"},
},
}
def _normalize_species_label(s: str) -> str:
if s is None: return ""
s0 = str(s).strip().lower()
    s1 = re.sub(r"\s+", " ", re.sub(r"[_\-]+", " ", s0)).strip()
    s2 = s1.replace(" .", ".")
alias = {
"a platensis": "a. platensis", "a.platensis": "a. platensis", "arthrospira platensis": "a. platensis",
"c pyrenoidosa": "c. pyrenoidosa", "c.pyrenoidosa": "c. pyrenoidosa", "chlorella pyrenoidosa": "c. pyrenoidosa",
"c sorokiniana": "c. sorokiniana", "c.sorokiniana": "c. sorokiniana",
"c variabilis": "c. variabilis", "c.variabilis": "c. variabilis",
"c vulgaris": "c. vulgaris", "c.vulgaris": "c. vulgaris", "chlorella vulgaris": "c. vulgaris",
"c zofingiensis": "c. zofingiensis", "c.zofingiensis": "c. zofingiensis",
"h pluvialis": "h. pluvialis", "h.pluvialis": "h. pluvialis", "haematococcus pluvialis": "h. pluvialis",
"p purpureum": "p. purpureum", "p.purpureum": "p. purpureum", "porphyridium purpureum": "p. purpureum",
"scenedesmus": "scenedesmus sp.", "scenedesmus sp": "scenedesmus sp.", "scenedesmus sp.": "scenedesmus sp.",
}
return alias.get(s2, s2)
def _format_suggestion_md(species: str, target: str) -> str:
sp = _normalize_species_label(species)
tg = (target or "").strip().lower()
data = SPECIES_SUGGESTIONS.get(sp, {}).get(tg)
if not data:
return f"> ℹ️ No curated suggestion for **{species}** and **{target}**."
return (
f"### 💡 Suggested conditions for *{sp}* → *{tg}*\n"
f"**Light:** {data['light']} &nbsp;|&nbsp; **Days:** {data['days']}"
)
def update_suggestion_panel(target, species):
return _format_suggestion_md(species, target)
# -----------------------------
# Load and normalize real data (for allowed pairs + KNN imputer)
# -----------------------------
if not RAW_PATH.exists():
raise FileNotFoundError("Missing 'ai_al.csv'. Please upload it to the Space root (same folder as app.py).")
df_raw = pd.read_csv(RAW_PATH)
df_raw.columns = (
df_raw.columns.str.strip()
.str.lower()
.str.replace("[^0-9a-zA-Z]+", "_", regex=True)
)
FEATURES = ["species","media","light","expo_day","expo_night","_c","ph","days"]
CATEGORICAL= ["species","media"]
NUM_CYCLE_FIRST = ["light"]
NUM_PLAIN = ["expo_day","expo_night","_c","ph","days"]
TARGETS = ["biomass","lipid","protein","carb"]
# Normalize cats for encoders
df_enc = df_raw.copy()
for col in CATEGORICAL:
if col in df_enc.columns:
df_enc[col] = df_enc[col].map(normalize_str)
# Fit encoders on CSV vocab (used only for allowed-pairs + KNN imputer)
encoders, value_lists = {}, {}
for col in CATEGORICAL:
le = LabelEncoder()
vals = df_enc[col].astype(str).fillna("nan")
le.fit(vals)
encoders[col] = le
value_lists[col] = sorted(set(vals) - {"nan"})
# Prepare numerics for KNN imputer fit
for c in NUM_CYCLE_FIRST:
if c in df_enc.columns:
df_enc[c] = coerce_numeric(df_enc[c], "cycle_first")
for c in NUM_PLAIN:
if c in df_enc.columns:
df_enc[c] = coerce_numeric(df_enc[c], "float")
def encode_frame(df_like: pd.DataFrame) -> pd.DataFrame:
X = pd.DataFrame()
for col in CATEGORICAL:
if col in df_like.columns:
X[col] = df_like[col].map(normalize_str)
X[col] = encoders[col].transform(X[col].astype(str).fillna("nan"))
for c in NUM_CYCLE_FIRST:
if c in df_like.columns:
X[c] = coerce_numeric(df_like[c], "cycle_first")
for c in NUM_PLAIN:
if c in df_like.columns:
X[c] = coerce_numeric(df_like[c], "float")
for c in FEATURES:
if c not in X.columns:
X[c] = np.nan
return X[FEATURES]
X_for_imputer = encode_frame(df_raw)
imputer = SimpleImputer(strategy="median").fit(X_for_imputer)
# -----------------------------
# Resolve allowed species–media pairs (aliases + fuzzy)
# -----------------------------
ALLOWED_PAIRS_ALIAS = {
"a.platensis": ["zarrouks", "bg 11"],
"c sorokiniana": ["tap", "bg 11"],
"c vulgaris": ["bg 11", "bbm"],
"scenedesmus": ["bg 11", "bbm"],
"p purpureum": ["artificial sea water", "erdseirber and bold nv", "f2"],
"h pluvalis": ["bg 11"],
"c pyreniidosa": ["bg 11", "bbm", "selenite media"],
"c zofingensis": ["bg 11", "bbm", "tap"],
"c variabilis": ["bg 11", "zorrouks", "tap"],
}
SPECIES_ALIASES = {
"a.platensis": ["arthrospira platensis", "spirulina platensis", "a. platensis"],
"c sorokiniana": ["chlorella sorokiniana", "c. sorokiniana"],
"c vulgaris": ["chlorella vulgaris", "c. vulgaris"],
"scendedesmus": ["scenedesmus", "scenedesmus sp.", "desmodesmus sp."],
"scenedesmus": ["scenedesmus", "scenedesmus sp.", "desmodesmus sp."],
"p purpureum": ["porphyridium purpureum", "p. purpureum"],
"h pluvalis": ["haematococcus pluvialis", "h. pluvialis", "h pluvalis"],
"c pyreniidosa": ["chlorella pyrenoidosa", "c. pyrenoidosa", "c pyreniidosa"],
"c zofingensis": ["chromochloris zofingiensis", "c. zofingiensis", "chlorella zofingiensis"],
"c variabilis": ["chlorella variabilis", "c. variabilis"],
}
def match_to_vocab(name: str, vocab: list[str], aliases: dict[str, list[str]], cutoff=0.6):
n = normalize_str(name)
if n in vocab: return n
for syn in aliases.get(n, []):
sn = normalize_str(syn)
if sn in vocab: return sn
hit = get_close_matches(n, vocab, n=1, cutoff=cutoff)
return hit[0] if hit else None
species_vocab = value_lists["species"]
media_vocab = value_lists["media"]
ALLOWED_PAIRS = {}
for s_alias, m_aliases in ALLOWED_PAIRS_ALIAS.items():
s_canon = match_to_vocab(s_alias, species_vocab, SPECIES_ALIASES)
if not s_canon:
continue
canon_media = []
for m_alias in m_aliases:
m_canon = match_to_vocab(m_alias, media_vocab, MEDIA_ALIASES)
if m_canon:
canon_media.append(m_canon)
if canon_media:
ALLOWED_PAIRS[s_canon] = sorted(set(canon_media))
if not ALLOWED_PAIRS:
ALLOWED_PAIRS = {s: sorted(set(media_vocab)) for s in species_vocab}
# -----------------------------
# Allowed-pairs helpers (robust to 'bg-11' vs 'bg 11')
# -----------------------------
def _canon_species_for_allowed(s: str) -> str:
"""Map incoming species to the token used in ALLOWED_PAIRS keys."""
s_norm = normalize_str(s)
if s_norm in ALLOWED_PAIRS:
return s_norm
# try alias-to-key match
for key in ALLOWED_PAIRS.keys():
if _canon_from_alias(s_norm, SPECIES_ALIASES) == key or _canon_from_alias(s_norm, SPECIES_ALIASES_CANON) == key:
return key
hit = get_close_matches(s_norm, list(ALLOWED_PAIRS.keys()), n=1, cutoff=0.6)
return hit[0] if hit else s_norm
def _canon_media_for_allowed(s_token: str, m: str) -> str | None:
"""Map incoming media to one of the allowed tokens for this species (via MEDIA_ALIASES)."""
m_norm = normalize_str(m)
allowed = ALLOWED_PAIRS.get(s_token, [])
if not allowed:
return None
if m_norm in allowed:
return m_norm
# alias-hit: fold both into canonical form and compare
m_norm_canon = _canon_media_for_bounds(m_norm)
for a in allowed:
if _canon_media_for_bounds(a) == m_norm_canon:
return a
hit = get_close_matches(m_norm, allowed, n=1, cutoff=0.6)
return hit[0] if hit else None
def allowed_media_for(species_norm):
s_token = _canon_species_for_allowed(species_norm)
return ALLOWED_PAIRS.get(s_token, [])
# -----------------------------
# Augmented paths (for KNN)
# -----------------------------
def get_augmented_path(target: str):
p200 = ROOT / f"augmented_{target}_200k.csv"
p20 = ROOT / f"augmented_{target}_20k.csv"
return p200 if p200.exists() else (p20 if p20.exists() else None)
# -----------------------------
# DOI database (load + scorer)
# -----------------------------
def _maybe_load_doi():
if not DOI_PATH.exists():
return None, None, None, False
try:
df_doi_raw = pd.read_csv(DOI_PATH)
df_doi_raw.columns = (
df_doi_raw.columns.str.strip()
.str.lower()
.str.replace("[^0-9a-zA-Z]+", "_", regex=True)
)
# normalize categoricals
for c in ["species", "media"]:
if c in df_doi_raw.columns:
df_doi_raw[c] = df_doi_raw[c].map(normalize_str)
# parse numerics
if "light" in df_doi_raw.columns:
df_doi_raw["light"] = coerce_numeric(df_doi_raw["light"], "cycle_first")
for c in ["expo_day","expo_night","_c","ph","days"]:
if c in df_doi_raw.columns:
df_doi_raw[c] = coerce_numeric(df_doi_raw[c], "float")
# find a DOI-like column to link
doi_col_candidates = [c for c in df_doi_raw.columns if c in {"doi","doi_id","reference","url","link"}]
doi_col = doi_col_candidates[0] if doi_col_candidates else None
# build scales for numeric cols, including any target columns present
base_num = ["light","expo_day","expo_night","_c","ph","days"]
target_cols_present = [t for t in TARGETS if t in df_doi_raw.columns]
num_cols = base_num + target_cols_present
scales = {}
for col in num_cols:
v = pd.to_numeric(df_doi_raw[col], errors="coerce").dropna()
if len(v) >= 4:
lo, hi = np.percentile(v, [5,95]); span = max(1e-6, hi - lo)
elif len(v) > 1:
span = max(1e-6, v.max() - v.min())
else:
span = 1.0
scales[col] = span
return df_doi_raw, doi_col, scales, True
except Exception:
return None, None, None, False
df_doi_raw, DOI_COL, DOI_SCALES, DOI_READY = _maybe_load_doi()
def _media_similarity(a, b):
a = normalize_str(a); b = normalize_str(b)
def canon(m):
if m in MEDIA_ALIASES: return m
for k, syns in MEDIA_ALIASES.items():
if m == k or m in [normalize_str(s) for s in syns]: return k
return m
from difflib import SequenceMatcher
ca, cb = canon(a), canon(b)
return 1.0 if ca == cb else SequenceMatcher(None, ca, cb).ratio()
def _doi_url(x):
if x is None or (isinstance(x, float) and np.isnan(x)): return None
s = str(x).strip()
if s.startswith("http://") or s.startswith("https://"): return s
s = s.lower().replace("doi:", "").strip()
return f"https://doi.org/{s}"
def _closest_doi(
target_name, # "biomass" | "lipid" | "protein" | "carb"
species, media,
light, expo_day, expo_night, temp_c, ph, days,
y_target=None, # float | None
topk=5
):
if not DOI_READY or df_doi_raw is None or len(df_doi_raw) == 0:
return "> ℹ️ doi.csv not found or not readable."
# narrow to species (with fuzzy fallback)
s_key = _normalize_species_label(normalize_str(species))
df_cand = df_doi_raw[df_doi_raw.get("species", "") == s_key]
if df_cand.empty and "species" in df_doi_raw.columns:
sp_unique = df_doi_raw["species"].dropna().unique().tolist()
best = get_close_matches(s_key, sp_unique, n=1, cutoff=0.6)
df_cand = df_doi_raw[df_doi_raw["species"] == (best[0] if best else s_key)]
if df_cand.empty:
df_cand = df_doi_raw # last-resort: search whole table
# require rows that at least *have* a value for the chosen target (if present)
if target_name in df_cand.columns:
df_cand = df_cand[pd.to_numeric(df_cand[target_name], errors="coerce").notna()].copy()
if df_cand.empty:
return f"> ℹ️ No entries with '{target_name}' found for species filter."
# query vector
q = {
"light": parse_cycle_first(light),
"expo_day": extract_first_float(expo_day),
"expo_night": extract_first_float(expo_night),
"_c": extract_first_float(temp_c),
"ph": extract_first_float(ph),
"days": extract_first_float(days),
}
# weights
w_media = 0.5
w_num = 1.0
w_tgt = 2.0 if y_target is not None else 0.0
rows = []
for _, r in df_cand.iterrows():
# media similarity
sim = _media_similarity(media, r.get("media", ""))
media_penalty = (1.0 - sim) * w_media
# numeric distance
dist = 0.0; denom = 0
for col in ["light","expo_day","expo_night","_c","ph","days"]:
if col in df_cand.columns:
rv, qv = r.get(col, np.nan), q[col]
if pd.notna(rv) and pd.notna(qv):
span = DOI_SCALES.get(col, 1.0) if DOI_SCALES else 1.0
dist += w_num * abs(float(qv) - float(rv)) / span
denom += 1
dist = dist/denom if denom>0 else 1.0
# target proximity (if we have both the column and a predicted y)
tgt_term = 0.0
if w_tgt > 0 and target_name in df_cand.columns:
rv = r.get(target_name, np.nan)
if pd.notna(rv):
span = DOI_SCALES.get(target_name, 1.0) if DOI_SCALES else 1.0
tgt_term = w_tgt * abs(float(y_target) - float(rv)) / span
score = media_penalty + dist + tgt_term
rows.append((score, r))
if not rows:
return "> ℹ️ No comparable rows in doi.csv."
# rank
rows.sort(key=lambda x: x[0])
top = rows[:topk]
# build markdown
head_note = f" (target: **{target_name}**"
if y_target is not None:
head_note += f", y≈**{float(y_target):.3f}**"
head_note += ")"
md = f"### 📚 Closest DOI matches{head_note}\n"
for rank, (score, r) in enumerate(top, 1):
sim_pct = max(0.0, min(100.0, 100.0 * np.exp(-score)))
doi_link = _doi_url(r.get(DOI_COL)) if DOI_COL else None
title = f"**{rank}. {r.get('species','?')}{r.get('media','?')}** · Similarity **{sim_pct:.1f}%**"
if doi_link:
title += f" · [DOI]({doi_link})"
md += title + "\n"
tgt_str = ""
if target_name in df_cand.columns and pd.notna(r.get(target_name, np.nan)):
tgt_str = f" · {target_name}: {r.get(target_name)}"
md += (
f"• Light: {r.get('light','NA')} · Day: {r.get('expo_day','NA')} · Night: {r.get('expo_night','NA')} · "
f"T(°C): {r.get('_c','NA')} · pH: {r.get('ph','NA')} · Days: {r.get('days','NA')}{tgt_str}\n"
)
return md
# -----------------------------
# Preprocess + validate pair (for KNN uncertainty only) — FIXED
# -----------------------------
def _canon_categorical_for_encoder(col: str, v, enc) -> str:
"""Map user's string to a label known by the saved LabelEncoder."""
s = "nan" if pd.isna(v) else str(v).strip().lower()
if col == "species":
s = _normalize_species_label(s)
elif col == "media":
s = _canon_media_for_bounds(s)
if s in enc.classes_:
return s
norm_map = {str(c).strip().lower(): c for c in enc.classes_}
if s in norm_map:
return norm_map[s]
s2 = s.replace(" .", ".").replace(". ", ".")
if s2 in norm_map:
return norm_map[s2]
hits = get_close_matches(s, list(norm_map.keys()), n=1, cutoff=0.6)
if hits:
return norm_map[hits[0]]
if "nan" in enc.classes_:
return "nan"
return enc.classes_[0]
def preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days):
# Canonicalize to what ALLOWED_PAIRS uses
s_allowed = _canon_species_for_allowed(species)
m_allowed = _canon_media_for_allowed(s_allowed, media)
if s_allowed not in ALLOWED_PAIRS:
raise ValueError(f"Species '{species}' not allowed.")
if m_allowed is None or m_allowed not in ALLOWED_PAIRS[s_allowed]:
exp = ", ".join(ALLOWED_PAIRS[s_allowed]) or "∅"
raise ValueError(f"Media '{media}' not allowed for species '{species}'. Expected one of: {exp}")
# Build raw row using canonical tokens
row = pd.DataFrame([{
"species": s_allowed, "media": m_allowed, "light": light,
"expo_day": expo_day, "expo_night": expo_night,
"_c": temp_c, "ph": ph, "days": days
}], columns=FEATURES)
# Encode categoricals safely
for col in CATEGORICAL:
enc = encoders[col]
def _to_known_code(v):
known = _canon_categorical_for_encoder(col, v, enc)
return enc.transform([known])[0]
row[col] = row[col].apply(_to_known_code)
# Parse numerics
row["light"] = row["light"].apply(parse_cycle_first)
for c in ["expo_day","expo_night","_c","ph","days"]:
row[c] = row[c].apply(extract_first_float)
# Impute
row = pd.DataFrame(imputer.transform(row[FEATURES]), columns=FEATURES)
return row
# -----------------------------
# Uncertainty engine (KNN from augmented)
# -----------------------------
_AUG = {} # target -> (X_aug_np (n,p), y_aug_np (n,))
_KNN = {} # target -> NearestNeighbors
_PERC = {} # target -> per-feature (p05, p95)
K_NEI = 200
Q_LO, Q_HI = 0.10, 0.90
def _load_aug_and_knn(target: str):
if target in _KNN: return
aug_path = get_augmented_path(target)
if aug_path is None:
raise FileNotFoundError(f"Missing augmented file for '{target}'. Place augmented_{target}_200k.csv (or _20k.csv) in repo root.")
df_aug = pd.read_csv(aug_path)
if df_aug.empty:
raise ValueError(f"Augmented file for '{target}' is empty.")
for c in FEATURES:
if c not in df_aug.columns: df_aug[c] = np.nan
X_aug = df_aug[FEATURES].copy()
X_aug_imp = pd.DataFrame(imputer.transform(X_aug), columns=FEATURES)
y_aug = df_aug[target].astype(float).values
X_np = X_aug_imp.values.astype(float)
perc = {}
for j, c in enumerate(FEATURES):
colv = X_np[:, j]
perc[c] = (np.nanpercentile(colv, 5), np.nanpercentile(colv, 95))
nn = NearestNeighbors(n_neighbors=min(K_NEI, len(X_np)), algorithm="auto")
nn.fit(X_np)
_AUG[target] = (X_np, y_aug)
_KNN[target] = nn
_PERC[target] = perc
def _local_interval(target: str, X_query: np.ndarray):
_load_aug_and_knn(target)
X_aug, y_aug = _AUG[target]
nn = _KNN[target]
k_use = min(K_NEI, len(X_aug))
_, idxs = nn.kneighbors(X_query, n_neighbors=k_use, return_distance=True)
qlo = np.quantile(y_aug[idxs], Q_LO, axis=1)
qhi = np.quantile(y_aug[idxs], Q_HI, axis=1)
return qlo, qhi
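# Minimal sketch of the band logic (hypothetical, standalone; not wired into the app):
#   nn = NearestNeighbors(n_neighbors=5).fit(X_aug)
#   _, idx = nn.kneighbors(x_query)                         # rows of the 5 nearest neighbours
#   lo, hi = np.quantile(y_aug[idx], [Q_LO, Q_HI], axis=1)  # empirical 10-90% band per query
# The band width reflects local disagreement among the augmented neighbours.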
# -----------------------------
# Ensemble loader & predictor
# -----------------------------
@dataclass
class EnsembleBundle:
encoders: dict
imputer: object
scaler: object | None
xgb: xgb.XGBRegressor
lgb_booster: lgb.Booster | None
lgb_model: lgb.LGBMRegressor | None
cat: CatBoostRegressor
mlp: tf.keras.Model
meta: object
feature_order: list[str]
categorical_cols: list[str]
num_cols_cycle_first: list[str]
num_cols_plain: list[str]
_ENSEMBLES: dict[str, EnsembleBundle] = {}
def _load_ensemble(target: str) -> EnsembleBundle:
if target in _ENSEMBLES:
return _ENSEMBLES[target]
base = MODEL_DIR / target
if not base.exists():
raise FileNotFoundError(f"Model folder not found: {base}")
# Preprocess artifacts
encoders_b = joblib.load(base / "encoders.joblib")
imputer_b = joblib.load(base / "imputer.joblib")
scaler_b = joblib.load(base / "scaler.joblib") if (base / "scaler.joblib").exists() else None
cfg = json.loads((base / "config.json").read_text())
feat_order = cfg["feature_order"]
cat_cols = cfg["categorical_cols"]
cyc_cols = cfg["num_cols_cycle_first"]
num_plain = cfg["num_cols_plain"]
# XGB
xgb_model = xgb.XGBRegressor()
xgb_model.load_model(str(base / "xgb.json"))
_patch_class_and_mro(xgb_model.__class__)
# LGBM
lgb_booster, lgb_model = None, None
if (base / "lgb.txt").exists():
lgb_booster = lgb.Booster(model_file=str(base / "lgb.txt"))
elif (base / "lgb.joblib").exists():
lgb_model = joblib.load(base / "lgb.joblib")
_patch_class_and_mro(lgb_model.__class__)
else:
raise FileNotFoundError("Neither lgb.txt nor lgb.joblib found for LGBM.")
# CAT
cat_model = CatBoostRegressor()
cat_model.load_model(str(base / "cat.cbm"))
_patch_class_and_mro(cat_model.__class__)
# MLP (Keras)
mlp_model = tf.keras.models.load_model(base / "mlp.keras")
# Meta
meta = joblib.load(base / "meta.joblib")
bundle = EnsembleBundle(
encoders=encoders_b, imputer=imputer_b, scaler=scaler_b,
xgb=xgb_model, lgb_booster=lgb_booster, lgb_model=lgb_model,
cat=cat_model, mlp=mlp_model, meta=meta,
feature_order=feat_order, categorical_cols=cat_cols,
num_cols_cycle_first=cyc_cols, num_cols_plain=num_plain
)
_ENSEMBLES[target] = bundle
return bundle
def _encode_df_for_bundle(bundle: EnsembleBundle, df_like: pd.DataFrame) -> pd.DataFrame:
"""
Apply the SAVED encoders + numeric parsing + SAVED imputer; returns imputed numeric DF in training feature order.
Canonicalizes species/media to avoid unseen-label errors.
"""
def _norm(x):
return "nan" if pd.isna(x) else str(x).strip().lower()
X = pd.DataFrame({c: df_like[c] if c in df_like.columns else np.nan for c in bundle.feature_order})
if "species" in X.columns:
X["species"] = X["species"].map(_norm).apply(
lambda v: _canon_to_known(v, bundle.encoders["species"].classes_, SPECIES_ALIASES_CANON)
)
if "media" in X.columns:
X["media"] = X["media"].map(_norm).apply(
lambda v: _canon_to_known(v, bundle.encoders["media"].classes_, MEDIA_ALIASES)
)
for col in bundle.categorical_cols:
X[col] = bundle.encoders[col].transform(X[col].astype(str))
def _extract_first_float(x):
if pd.isna(x): return np.nan
s = str(x); m = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", s)
return float(m.group(0)) if m else np.nan
def _parse_cycle_first(x):
if pd.isna(x): return np.nan
s = str(x); m = re.search(r"(\d+(?:\.\d+)?)\s*:\s*(\d+(?:\.\d+)?)", s)
return float(m.group(1)) if m else _extract_first_float(s)
for c in bundle.num_cols_cycle_first:
if c in X.columns:
X[c] = X[c].apply(_parse_cycle_first)
for c in bundle.num_cols_plain:
if c in X.columns:
X[c] = X[c].apply(_extract_first_float)
X_imp = pd.DataFrame(bundle.imputer.transform(X[bundle.feature_order]), columns=bundle.feature_order)
return X_imp
def predict_stack_batch(target: str, df_raw_rows: pd.DataFrame) -> tuple[np.ndarray, dict]:
b = _load_ensemble(target)
X_imp = _encode_df_for_bundle(b, df_raw_rows)
pred_xgb = b.xgb.predict(X_imp)
if b.lgb_booster is not None:
pred_lgb = b.lgb_booster.predict(X_imp)
else:
pred_lgb = b.lgb_model.predict(X_imp)
pred_cat = b.cat.predict(X_imp)
X_mlp = b.scaler.transform(X_imp) if b.scaler is not None else X_imp
pred_mlp = b.mlp.predict(X_mlp, verbose=0).reshape(-1)
meta_in = np.vstack([pred_xgb, pred_lgb, pred_cat, pred_mlp]).T
pred_stack = b.meta.predict(meta_in)
bases = {"XGB": pred_xgb, "LGBM": pred_lgb, "CAT": pred_cat, "MLP": pred_mlp}
return pred_stack, bases
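# Usage sketch (hypothetical input row):
#   row = {"species": "c. vulgaris", "media": "bbm", "light": "150",
#          "expo_day": 16, "expo_night": 8, "_c": 27, "ph": 7.0, "days": 18}
#   y, bases = predict_stack_batch("biomass", pd.DataFrame([row]))
#   # y[0] is the meta-learner output; bases["XGB"][0] etc. are the base predictions.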
def predict_with_ensemble_one(target: str, raw_row: dict) -> dict:
df = pd.DataFrame([raw_row])
stack, bases = predict_stack_batch(target, df)
return {"STACK": float(stack[0]), "XGB": float(bases["XGB"][0]), "LGBM": float(bases["LGBM"][0]),
"CAT": float(bases["CAT"][0]), "MLP": float(bases["MLP"][0])}
# ---- Model chooser support ----
MODEL_NAMES = ["STACK", "XGB", "LGBM", "CAT", "MLP"]
def _available_models_for_target(target: str) -> list[str]:
base = MODEL_DIR / target
avail = []
if (base / "meta.joblib").exists(): avail.append("STACK")
if (base / "xgb.json").exists(): avail.append("XGB")
if (base / "lgb.txt").exists() or (base / "lgb.joblib").exists(): avail.append("LGBM")
if (base / "cat.cbm").exists(): avail.append("CAT")
if (base / "mlp.keras").exists() or (base / "mlp_savedmodel").exists(): avail.append("MLP")
return [m for m in MODEL_NAMES if m in avail]
def _predict_with_model_choice(target: str, model_choice: str, df_rows: pd.DataFrame) -> np.ndarray:
avail = _available_models_for_target(target)
if not avail:
raise FileNotFoundError(f"No saved models found under models/{target}")
chosen = model_choice if model_choice in avail else avail[0]
if chosen == "STACK":
y, _ = predict_stack_batch(target, df_rows)
return y
b = _load_ensemble(target)
X_imp = _encode_df_for_bundle(b, df_rows)
if chosen == "XGB":
return np.asarray(b.xgb.predict(X_imp), dtype=float)
if chosen == "LGBM":
if b.lgb_booster is not None:
return np.asarray(b.lgb_booster.predict(X_imp), dtype=float)
return np.asarray(b.lgb_model.predict(X_imp), dtype=float)
if chosen == "CAT":
return np.asarray(b.cat.predict(X_imp), dtype=float)
if chosen == "MLP":
Xm = b.scaler.transform(X_imp) if b.scaler is not None else X_imp
return b.mlp.predict(Xm, verbose=0).reshape(-1).astype(float)
raise ValueError(f"Unknown model choice: {model_choice}")
# -----------------------------
# Predict + Uncertainty + Plot (with bounds clamping)
# -----------------------------
Q_LABEL = lambda ql, qh: int((qh - ql) * 100)
def predict_and_plot_ui(
target, model_choice, species, media, light, expo_day, expo_night, temp_c, ph, days, plot_var
):
try:
# 0) raw row for ensemble/base models
raw_row = {
"species": species, "media": media, "light": light,
"expo_day": expo_day, "expo_night": expo_night,
"_c": temp_c, "ph": ph, "days": days
}
# 1) KNN point for uncertainty
X_one = preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days)
# 2) Model point prediction (selected model)
df_one = pd.DataFrame([raw_row])
avail = _available_models_for_target(target)
chosen = model_choice if model_choice in avail else (avail[0] if avail else "STACK")
y_point = _predict_with_model_choice(target, chosen, df_one)
yhat_raw = float(y_point[0])
# (Optional) base outputs
preds_point = predict_with_ensemble_one(target, raw_row) if "STACK" in avail else {}
# 3) local uncertainty (make sure the point is within interval before clamping)
qlo, qhi = _local_interval(target, X_one.values)
lo_raw, hi_raw = float(qlo[0]), float(qhi[0])
if lo_raw > hi_raw:
lo_raw, hi_raw = hi_raw, lo_raw
if yhat_raw < lo_raw:
lo_raw = yhat_raw
elif yhat_raw > hi_raw:
hi_raw = yhat_raw
# 4) species×medium bounds
b_lo, b_hi = get_bounds(species, media, target)
# clamp point + interval
yhat, clamped_point, _ = _clamp_scalar(yhat_raw, b_lo, b_hi)
lo_pt, _, _ = _clamp_scalar(lo_raw, b_lo, b_hi)
hi_pt, _, _ = _clamp_scalar(hi_raw, b_lo, b_hi)
        # 5) response curve vs selected variable (same chosen model; x-axis starts at 0)
        plot_var = (plot_var or "light").strip().lower()
        if plot_var not in FEATURES:
            plot_var = "light"
        j = FEATURES.index(plot_var)
# Fixed x ranges so the curve starts at 0
DEFAULT_SPANS = {
"light": (0.0, 400.0), # μmol·m⁻²·s⁻¹
"days": (0.0, 45.0),
"expo_day": (0.0, 24.0),
"expo_night": (0.0, 24.0),
"_c": (0.0, 50.0),
"ph": (0.0, 14.0),
"media": (0.0, 0.0), # categorical (not swept)
"species": (0.0, 0.0), # categorical (not swept)
}
lo_x, hi_x = DEFAULT_SPANS.get(plot_var, (np.nan, np.nan))
if not (np.isfinite(lo_x) and np.isfinite(hi_x)) or hi_x <= lo_x:
_load_aug_and_knn(target)
lo_x, hi_x = _PERC[target][plot_var]
xs = np.linspace(lo_x, hi_x, 200)
# Build grid with plot_var swept, others fixed
grid_rows = []
for xv in xs:
row = dict(raw_row)
if plot_var in ["light", "expo_day", "expo_night", "_c", "ph", "days"]:
row[plot_var] = float(xv)
grid_rows.append(row)
raw_grid_df = pd.DataFrame(grid_rows)
# Predictions along the grid (chosen model)
y_grid_raw = _predict_with_model_choice(target, chosen, raw_grid_df)
# KNN local band along the grid (independent of model)
X_grid = np.repeat(X_one.values, len(xs), axis=0)
X_grid[:, j] = xs
qlo_g_raw, qhi_g_raw = _local_interval(target, X_grid)
# clamp curve + band (and remember if any clamping happened)
y_grid, cl_curve = _clamp_array(y_grid_raw, b_lo, b_hi)
qlo_g, cl_qlo = _clamp_array(qlo_g_raw, b_lo, b_hi)
qhi_g, cl_qhi = _clamp_array(qhi_g_raw, b_lo, b_hi)
clamped_curve = bool(cl_curve or cl_qlo or cl_qhi)
# 6) plot — force both axes from 0; force allowed-range shading from 0
fig, ax = plt.subplots(figsize=(7.0, 4.2))
# Allowed range shading (display from 0 -> max bound; even if lookup lower bound wasn't 0)
# Use literature hi bound if available; otherwise pick from data
b_lo_plot = 0.0
if b_hi is None:
# fallback: pick a reasonable ymax for shading
hi_cands = []
if np.size(qhi_g): hi_cands.append(float(np.nanmax(qhi_g)))
if np.size(y_grid): hi_cands.append(float(np.nanmax(y_grid)))
hi_cands.append(float(yhat))
b_hi_plot = max([v for v in hi_cands if np.isfinite(v)] + [1.0])
else:
b_hi_plot = float(b_hi)
ax.axhspan(b_lo_plot, b_hi_plot, alpha=0.10, label="Allowed range")
        # Band label, e.g. "Local 80% band" for (Q_LO, Q_HI) = (0.10, 0.90)
        band_label = f"Local {Q_LABEL(Q_LO, Q_HI)}% band"
# Predicted mean + uncertainty band
ax.plot(xs, y_grid, label=f"{chosen} (predicted mean)")
ax.fill_between(xs, qlo_g, qhi_g, alpha=0.25, label=band_label)
# Current point
x0 = float(X_one.values[0, j])
ax.axvline(x0, linestyle="--", alpha=0.6)
ax.scatter([x0], [yhat], zorder=3, label="Current point")
# Nice axis labels
label_map = {"_c": "Temperature (°C)", "ph": "pH", "expo_day": "Day Exposure (h)",
"expo_night": "Night Exposure (h)", "light": "Light (μmol·m⁻²·s⁻¹)", "days": "Days"}
ax.set_xlabel(label_map.get(plot_var, plot_var))
ax.set_ylabel(target)
ax.set_title(f"{target} vs {label_map.get(plot_var, plot_var)} (others fixed)")
ax.legend(loc="best")
# ---- Force x- and y-axes to start at 0
ax.set_xlim(lo_x, hi_x)
# Target-specific default ymax, then expand to include data/bounds
DEFAULT_Y_SPANS = {
"biomass": (0.0, 7.0),
"lipid": (0.0, 60.0),
"protein": (0.0, 80.0),
"carb": (0.0, 60.0),
}
y_lo_def, y_hi_def = DEFAULT_Y_SPANS.get(str(target).strip().lower(), (0.0, np.nan))
y_upper_candidates = []
if np.size(qhi_g): y_upper_candidates.append(float(np.nanmax(qhi_g)))
if np.size(y_grid): y_upper_candidates.append(float(np.nanmax(y_grid)))
y_upper_candidates.append(float(yhat))
y_upper_candidates.append(float(b_hi_plot))
if np.isfinite(y_hi_def): y_upper_candidates.append(float(y_hi_def))
y_max = max([v for v in y_upper_candidates if np.isfinite(v)] + [1.0])
pad = max(0.05 * y_max, 0.5)
ax.set_ylim(0.0, y_max + pad)
plt.tight_layout()
# ---- Markdown output ----
        clamp_note = f" _(clamped to literature range; raw {yhat_raw:.3f} → {yhat:.3f})_" if clamped_point else ""
md = (
f"### Prediction ({chosen})\n"
f"**{target}** = **{yhat:.3f}**{clamp_note} \n"
f"Local {Q_LABEL(Q_LO, Q_HI)}% interval: **[{lo_pt:.3f}, {hi_pt:.3f}]** \n"
f"*Exogenous factors may affect the value; DOI reference advised.*"
)
if clamped_curve:
md += "\n\n*Response curve clipped to species×medium range.*"
if preds_point:
md += (
"\n\n<details><summary>Base models</summary>\n"
f"XGB: {preds_point['XGB']:.4f} &nbsp;|&nbsp; "
f"LGBM: {preds_point['LGBM']:.4f} &nbsp;|&nbsp; "
f"CAT: {preds_point['CAT']:.4f} &nbsp;|&nbsp; "
f"MLP: {preds_point['MLP']:.4f}\n"
"</details>"
)
return md, fig
except Exception as e:
fig, ax = plt.subplots(figsize=(6,3))
ax.axis("off")
plt.tight_layout()
return f"Error: {e}", fig
def doi_matches_ui(target, species, media, light, expo_day, expo_night, temp_c, ph, days):
"""Find 5 closest DOI rows using condition + target proximity to ŷ."""
yhat = None
try:
raw_row = {
"species": species, "media": media, "light": light,
"expo_day": expo_day, "expo_night": expo_night,
"_c": temp_c, "ph": ph, "days": days
}
df_one = pd.DataFrame([raw_row])
avail = _available_models_for_target(target)
chosen = "STACK" if "STACK" in avail else (avail[0] if avail else None)
if chosen is not None:
y_point = _predict_with_model_choice(target, chosen, df_one)
yhat = float(y_point[0])
except Exception:
yhat = None
return _closest_doi(
target_name=target,
species=species, media=media,
light=light, expo_day=expo_day, expo_night=expo_night, temp_c=temp_c, ph=ph, days=days,
y_target=yhat, topk=5
)
# -----------------------------
# UI — professional layout
# -----------------------------
from gradio.themes import Soft
theme = Soft(primary_hue="emerald", neutral_hue="slate", radius_size="lg", spacing_size="sm")
CSS = """
.card { border: 1px solid var(--border-color-primary); border-radius: 12px; padding: 14px; background: var(--block-background-fill); }
.small { font-size: 0.92rem; opacity: 0.95; }
/* --- persistent footer bar --- */
.footer-bar {
position: fixed;
left: 0; right: 0; bottom: 0;
z-index: 9999;
display: flex; align-items: center; gap: .5rem; flex-wrap: wrap;
padding: 10px 16px;
border-top: 1px solid var(--border-color-primary);
background: rgba(17, 24, 39, 0.85);
color: white;
backdrop-filter: blur(6px);
-webkit-backdrop-filter: blur(6px);
font-size: 0.9rem;
}
.footer-bar a { color: #a7f3d0; text-decoration: none; }
.footer-bar a:hover { text-decoration: underline; }
/* Spacer to prevent content being hidden behind the fixed footer */
.footer-spacer { height: 56px; }
@media (max-width: 640px){
.footer-bar { font-size: .82rem; padding: 8px 12px; }
.footer-spacer { height: 48px; }
}
@media print {
.footer-bar, .footer-spacer { display: none !important; }
}
"""
def update_media(species):
# keep dropdown choices consistent with canonical species key in ALLOWED_PAIRS
s_token = _canon_species_for_allowed(species) if species else None
choices = allowed_media_for(s_token) if s_token else []
value = choices[0] if choices else None
return gr.update(choices=choices, value=value)
def allowed_species_choices():
return sorted(ALLOWED_PAIRS.keys())
# ---- restrict model choices per target ----
def update_model_choices(target):
avail = _available_models_for_target(target)
if not avail:
avail = ["STACK"]
value = "STACK" if "STACK" in avail else avail[0]
return gr.update(choices=avail, value=value)
allowed_species = allowed_species_choices()
first_species = allowed_species[0] if allowed_species else None
first_media_choices = allowed_media_for(first_species) if first_species else []
first_media = first_media_choices[0] if first_media_choices else None
with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
gr.Markdown(
f"<h1>Algae Yield Predictor</h1>"
f"<div class='small'>Predict <b>biomass / lipid / protein / carbohydrate</b> with "
f"a selectable model (<b>STACK / XGB / LGBM / CAT / MLP</b>), local uncertainty bands, "
f"and species×medium literature-range clamping."
f"{'' if DOI_READY else ' &nbsp;<em>(DOI file missing or lacks a doi column.)</em>'}"
f"</div>",
elem_classes=["card"]
)
with gr.Row():
with gr.Column(scale=6):
with gr.Group(elem_classes=["card"]):
gr.Markdown("### Inputs")
target_dd = gr.Dropdown(choices=TARGETS, value="biomass", label="Target", info="Choose outcome to predict")
model_dd = gr.Dropdown(choices=MODEL_NAMES, value="STACK", label="Model", info="Choose which trained model to use")
with gr.Row():
species_dd = gr.Dropdown(choices=allowed_species, value=first_species, label="Species", info="Only curated species")
media_dd = gr.Dropdown(choices=first_media_choices, value=first_media, label="Medium", info="Restricted by species")
gr.Markdown("#### Culture Conditions", elem_classes=["small"])
with gr.Row():
light_sl = gr.Slider(10, 400, value=150, step=5, label="Light (μmol·m⁻²·s⁻¹)")
days_sl = gr.Slider(1, 45, value=18, step=1, label="Days", info="Total culture duration")
with gr.Row():
day_sl = gr.Slider(0, 24, value=18, step=1, label="Day Exposure (h)")
night_sl = gr.Slider(0, 24, value=6, step=1, label="Night Exposure (h)")
with gr.Row():
temp_num = gr.Number(value=27, label="Temperature (°C)", precision=1)
ph_num = gr.Number(value=7.0, label="pH", precision=2)
with gr.Row():
plot_var_dd = gr.Dropdown(
choices=["light","days","expo_day","expo_night","_c","ph"], # 'ph' lowercase
value="light",
label="Plot variable",
info="Sweep one input to see response curve with uncertainty band"
)
with gr.Row():
go = gr.Button("Predict + Plot", variant="primary")
doi_btn = gr.Button("Find Closest DOI Matches", variant="secondary")
with gr.Group(elem_classes=["card"]):
gr.Markdown("### Suggested Conditions")
suggest_md = gr.Markdown(value=_format_suggestion_md(first_species or "", "biomass"))
with gr.Group(elem_classes=["card"]):
gr.Markdown("### Model Tips")
model_tips_md = gr.Markdown("""\
**Recommendations**
- **STACK (Ensemble)** — best overall accuracy (offline metrics ~R² 0.89 / MAE ~0.66).
- **XGB / LGBM** — fast, strong single models (R² ~0.69).
- **CAT** — robust to categorical quirks (R² ~0.62).
- **MLP** — requires scaler; slower cold start (R² ~0.55 here).
**Pick**: Use **STACK** by default. Choose **XGB**/**LGBM** for speed or to sanity-check disagreement across models.
""")
with gr.Column(scale=6):
with gr.Group(elem_classes=["card"]):
pred_md = gr.Markdown("Click **Predict + Plot** to run.")
with gr.Group(elem_classes=["card"]):
gr.Markdown("### Response Plot")
plot_out = gr.Plot()
with gr.Group(elem_classes=["card"]):
gr.Markdown("### Literature (DOI) Matches")
doi_md = gr.Markdown("Click **Find Closest DOI Matches** to see references.")
with gr.Group(elem_classes=["card"]):
gr.Markdown("""\
### Citation
If you use this predictor or dataset, please cite:
**Tiwari, A., Dubey, S., Sumathi, Y., Patel, A. K., & Kuo, T.-R. (2025).**
*Augmented and Real Microalgae Datasets for Biomass and Biochemical Composition Prediction* [Data set]. Zenodo.
[https://doi.org/10.5281/zenodo.17177597](https://doi.org/10.5281/zenodo.17177597)
""")
# Wiring
species_dd.change(fn=update_media, inputs=species_dd, outputs=media_dd)
target_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
species_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
target_dd.change(fn=update_model_choices, inputs=target_dd, outputs=model_dd)
go.click(
fn=predict_and_plot_ui,
inputs=[target_dd, model_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl, plot_var_dd],
outputs=[pred_md, plot_out]
)
doi_btn.click(
fn=doi_matches_ui,
        inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl],
outputs=doi_md
)
# ---- Persistent bottom bar ----
gr.HTML("<div class='footer-spacer'></div>")
gr.HTML("""
<div class="footer-bar">
<strong>Algae Yield Predictor</strong>
&nbsp;·&nbsp; Developed by <b>Ashutosh Tiwari (Lead)</b> &amp; <b>Siddhant Dubey (Co-Lead)</b>,
with contributions from Yamini Sumathi.
&nbsp;© 2025 Ashutosh Tiwari and collaborators. All rights reserved.
</div>
""")
# Spaces auto-runs this
if __name__ == "__main__":
demo.launch()