# app.py (updated: fixes species bug, adds caching, parallel monthly queries, monthly bell curve)
#
# Bloom prediction API (Hugging Face Space). For a location and a year it
# samples Landsat 8/9 imagery around the 15th of every month, scores each
# month with a bloom classifier (ML model or rule-based fallback), smooths the
# 12 monthly scores into a seasonal curve and, when a bloom season is found,
# ranks likely species from elevation + day-of-year phenology statistics.
import os
import time
import math
import json
import threading
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import asynccontextmanager
from datetime import datetime, date, timedelta, timezone
from pathlib import Path
from typing import Optional, List, Dict, Tuple

import ee
import joblib
import numpy as np
import pandas as pd
import scipy.signal
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# google oauth helpers (used for Earth Engine authentication)
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request as GoogleRequest
from google.oauth2 import service_account as google_service_account

# ------------------------------
# CONFIG / FILENAMES / TUNABLES
# ------------------------------
MODEL_FILE = Path("mil_bloom_model.joblib")
SCALER_FILE = Path("mil_scaler.joblib")
FEATURES_FILE = Path("mil_features.joblib")
PHENO_FILE = Path("phenologythingy.csv")
SPECIES_STATS_FILE = Path("species_stats.csv")

# Minimum monthly score (0..100) for a month to count as part of the bloom window.
MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 20.0))
# A seasonal peak score below this is reported as NO_BLOOM.
MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 25.0))
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
MAX_DAYS_BACK = int(os.environ.get("MAX_DAYS_BACK", 30))
MIN_COUNT_FOR_SPECIES = int(os.environ.get("MIN_COUNT_FOR_SPECIES", 20))
TOP_K_SPECIES = int(os.environ.get("TOP_K_SPECIES", 5))  # how many species to return
DOY_BINS = 366
DOY_SMOOTH = 15
EPS_STD = 1.0

# Seasonal-curve tunables
ALPHA = float(os.environ.get("MONTH_ALPHA", 2.0))  # >1 sharpens peaks, 1.0 = no change
SMOOTH_SIGMA = float(os.environ.get("SMOOTH_SIGMA", 1.2))  # gaussian sigma in months
MIN_CURVE_PROB = float(os.environ.get("MIN_CURVE_PROB", 0.01))  # min per-month percentage floor

# Tune parallelism: how many months to fetch at once
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", 4))
# Upper bound on cached satellite lookups (LRU eviction beyond this).
EE_CACHE_MAX_ENTRIES = int(os.environ.get("EE_CACHE_MAX_ENTRIES", 4096))

# EE OAuth env vars expected in Hugging Face Space secrets
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
REFRESH_TOKEN = os.environ.get("REFRESH_TOKEN")
EE_PROJECT = os.environ.get("PROJECT") or os.environ.get("EE_PROJECT") or None
EE_SCOPES = [
    "https://www.googleapis.com/auth/earthengine",
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/drive",
    "https://www.googleapis.com/auth/devstorage.full_control",
]


# ------------------------------
# Pydantic models
# ------------------------------
class BloomPredictionRequest(BaseModel):
    lat: float = Field(..., ge=-90, le=90)
    lon: float = Field(..., ge=-180, le=180)
    date: str = Field(..., description="YYYY-MM-DD")


class SimplifiedMonthlyResult(BaseModel):
    month: int
    bloom_probability: float
    prediction: str  # "BLOOM" or "NO_BLOOM"


class SpeciesResult(BaseModel):
    name: str
    probability: float  # as percentage


class BloomPredictionResponse(BaseModel):
    success: bool
    status: str  # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
    # Only included when a bloom season peak was found
    peak_month: Optional[int] = None
    peak_probability: Optional[float] = None
    # Months whose score exceeds MIN_BLOOM_THRESHOLD (or half of it for LOW_CONFIDENCE)
    bloom_window: Optional[List[int]] = None
    # Only included when a bloom is accepted
    top_species: Optional[List[SpeciesResult]] = None
    # Smoothed seasonal curve: month (1..12) -> percentage, summing to ~100
    monthly_probabilities: Dict[int, float]


# ------------------------------
# Globals & cache
# ------------------------------
ML_MODEL = None
SCALER = None
FEATURE_COLUMNS = None  # loaded for reference; the ML wrapper uses a fixed feature order
SPECIES_STATS_DF = None
DOY_HIST_MAP: Dict[str, np.ndarray] = {}

# Bounded LRU cache of successful satellite lookups, keyed by
# (lat, lon, date_str, buffer_meters, max_days_back) -> sat_data dict.
# Guarded by ee_cache_lock because monthly lookups run in worker threads.
ee_cache_lock = threading.Lock()
ee_cache: "OrderedDict[Tuple[float, float, str, int, int], dict]" = OrderedDict()


# ------------------------------
# Utility functions
# ------------------------------
def _is_missing(value):
    """Return True when ``value`` is None or a float NaN."""
    return value is None or (isinstance(value, (float, np.floating)) and np.isnan(value))


def gaussian_kernel(length=12, sigma=1.2):
    """Return a normalized, symmetric 1-D Gaussian kernel (sums to 1).

    The kernel spans ``max(6, int(3 * sigma))`` samples on each side of the
    centre, with ``sigma`` in index units. ``length`` is accepted for backward
    compatibility but does not affect the result.
    """
    half = max(6, int(3 * sigma))
    xs = np.arange(-half, half + 1)
    kern = np.exp(-0.5 * (xs / sigma) ** 2)
    return kern / kern.sum()


def smooth_monthly_probs_preserve_magnitude(raw_probs, alpha=ALPHA, sigma=SMOOTH_SIGMA, min_curve_prob=MIN_CURVE_PROB):
    """Turn raw monthly bloom probabilities into a smoothed seasonal curve.

    Args:
        raw_probs: per-month probabilities in percent (0..100), normally 12
            values for January..December. Values outside 0..100 are clipped.
        alpha: exponent applied after contrast stretching; >1 sharpens peaks,
            1.0 leaves the shape unchanged.
        sigma: Gaussian smoothing width in months (at least 0.6). Smoothing is
            circular, so December and January are neighbours.
        min_curve_prob: small floor (percent) that keeps the curve non-zero.

    Returns:
        ``(monthly_perc, monthly_scores)``, two lists of floats:
        - monthly_perc: the smoothed shape normalised to sum to 100 (plotting).
        - monthly_scores: the same shape scaled so its maximum equals the
          largest raw probability. This is an absolute 0..100 score that
          downstream thresholds (MIN_PEAK_FOR_BELL, MIN_BLOOM_THRESHOLD) can
          meaningfully be compared against.
    """
    a = np.clip(np.asarray(raw_probs, dtype=float).ravel(), 0.0, 100.0) / 100.0
    n = a.size
    if n == 0:
        return [], []

    if np.allclose(a, 0.0):
        # No signal at all (e.g. no satellite data in any month): flat plotting
        # curve and a negligible absolute score so no bloom is reported.
        monthly_perc = np.full(n, 100.0 / n)
        monthly_scores = np.full(n, float(min_curve_prob))
        return monthly_perc.tolist(), monthly_scores.tolist()

    # Contrast stretch (min-max) so relative differences between months matter.
    amin = float(a.min())
    amax = float(a.max())
    rng = amax - amin
    if rng < 1e-6:
        a_cs = a - amin
        if a_cs.max() <= 1e-12:
            a_cs = np.full_like(a, 1e-6)
        else:
            a_cs = a_cs / (a_cs.max() + 1e-12)
    else:
        a_cs = (a - amin) / (rng + 1e-12)

    # Floor tiny values (avoid zero everywhere), then sharpen peaks.
    floor = min_curve_prob / 100.0
    a_cs = np.clip(a_cs, floor * 1e-3, 1.0)
    a_sh = np.power(a_cs, alpha) if alpha != 1.0 else a_cs

    # Circular Gaussian smoothing. Modulo indexing wraps correctly even when the
    # pad is longer than the series (large sigma); slicing did not.
    sigma = float(max(0.6, sigma))
    pad = max(3, int(round(3 * sigma)))
    padded = a_sh[np.arange(-pad, n + pad) % n]
    offsets = np.arange(-pad, pad + 1)
    kern = np.exp(-0.5 * (offsets / sigma) ** 2)
    kern = kern / kern.sum()
    # 'valid' on an (n + 2*pad) signal with a (2*pad + 1) kernel yields exactly n values.
    center = np.clip(np.convolve(padded, kern, mode="valid"), 0.0, None)

    # monthly_perc: shape normalised to sum to 100 (rounded, then re-normalised).
    if center.sum() <= 0:
        center = np.ones_like(center)
    monthly_perc = (center / center.sum() * 100.0).round(3)
    monthly_perc = monthly_perc * (100.0 / monthly_perc.sum())

    # monthly_scores: shape rescaled to the raw peak so magnitude is preserved.
    maxv = float(center.max())
    if maxv <= 0:
        monthly_scores = np.full(n, float(min_curve_prob))
    else:
        monthly_scores = np.round(center / maxv * (amax * 100.0), 3)
    return monthly_perc.tolist(), monthly_scores.tolist()


def is_bell_shaped(perc_list):
    """Heuristic check that a monthly curve is a single, roughly symmetric hump.

    - lightly smooths the series (Savitzky-Golay, window 5, order 3) when it
      has at least 5 points
    - counts interior local maxima (expect at most one)
    - compares the mass left vs right of the peak (accept within a factor 2)
    - reports skewness for diagnostics only

    Returns:
        (is_bell, diagnostics_dict)
    """
    arr = np.asarray(perc_list, dtype=float)
    n = len(arr)
    if n == 0:
        return False, {
            "num_peaks": 0,
            "peak_month": 0,
            "sym_ratio_left_to_right": 0.0,
            "skewness": 0.0,
            "is_unimodal": False,
            "is_symmetric_enough": False,
        }

    arr_smooth = np.asarray(scipy.signal.savgol_filter(arr, 5, 3)) if n >= 5 else arr.copy()

    # Indices of interior local maxima (the +1 already maps back to arr indices).
    peaks = (np.diff(np.sign(np.diff(arr_smooth))) < 0).nonzero()[0] + 1
    num_peaks = len(peaks)
    if num_peaks:
        peak_idx = int(peaks[np.argmax(arr_smooth[peaks])])
    else:
        peak_idx = int(arr_smooth.argmax())

    # Symmetry: mass strictly left vs strictly right of the peak.
    left_mass = arr[:peak_idx].sum()
    right_mass = arr[peak_idx + 1:].sum()
    sym_ratio = left_mass / (right_mass + 1e-9)

    mean = arr.mean()
    std = arr.std(ddof=0)
    scale = std if std > 0 else 1.0
    skew = ((arr - mean) ** 3).mean() / (scale ** 3)

    is_unimodal = num_peaks <= 1
    is_symmish = 0.5 <= sym_ratio <= 2.0  # within factor 2
    is_bell = is_unimodal and is_symmish

    diagnostics = {
        "num_peaks": num_peaks,
        "peak_month": int(np.argmax(arr)) + 1,
        "sym_ratio_left_to_right": float(sym_ratio),
        "skewness": float(skew),
        "is_unimodal": bool(is_unimodal),
        "is_symmetric_enough": bool(is_symmish),
    }
    return bool(is_bell), diagnostics


def gaussian_pdf_scalar(x_scalar, mean, std):
    """Return the Gaussian PDF at scalar ``x_scalar``.

    A None/NaN ``x_scalar``, or unconvertible arguments, yield 1.0 (a neutral
    likelihood). ``std`` is floored at 1e-6.
    """
    try:
        if _is_missing(x_scalar):
            return 1.0
        std = max(float(std), 1e-6)
        coef = 1.0 / (std * math.sqrt(2 * math.pi))
        z = (float(x_scalar) - float(mean)) / std
        return coef * math.exp(-0.5 * z * z)
    except (TypeError, ValueError, OverflowError):
        # fallback neutral
        return 1.0


def circular_histogram(doys, bins=DOY_BINS, smooth_window=DOY_SMOOTH):
    """Return a normalized, circularly smoothed day-of-year histogram.

    ``doys`` are 1-based day-of-year integers (1..bins). The result has length
    ``bins``, sums to 1, and uses a moving average of ``smooth_window`` days
    that wraps across the year boundary in both directions. Empty input gives
    a uniform distribution.
    """
    doys = np.asarray(doys)
    if doys.size == 0:
        return np.ones(bins) / bins
    counts = np.bincount(doys.astype(int), minlength=bins + 1)[1:bins + 1]
    window = np.ones(smooth_window) / smooth_window
    # Three copies so both the start and the end of the year see wrapped neighbours.
    tripled = np.concatenate([counts, counts, counts])
    smoothed = np.convolve(tripled, window, mode="same")[bins:2 * bins]
    total = smoothed.sum()
    if total <= 0:
        return np.ones(bins) / bins
    return smoothed / total


# ------------------------------
# Earth Engine init (OAuth refresh-token or service account fallback)
# ------------------------------
def initialize_ee_from_env():
    """Initialize Earth Engine from environment credentials.

    Tries, in order:
      1. OAuth refresh token (CLIENT_ID / CLIENT_SECRET / REFRESH_TOKEN),
      2. service-account JSON in EE_SERVICE_ACCOUNT_JSON, or a key file path
         in EE_SERVICE_ACCOUNT_FILE,
      3. plain ``ee.Initialize()`` (default credentials).

    Returns True on the first success, False if every method fails.
    """
    if CLIENT_ID and CLIENT_SECRET and REFRESH_TOKEN:
        try:
            creds = Credentials(
                token=None,
                refresh_token=REFRESH_TOKEN,
                client_id=CLIENT_ID,
                client_secret=CLIENT_SECRET,
                token_uri="https://oauth2.googleapis.com/token",
                scopes=EE_SCOPES,
            )
            creds.refresh(GoogleRequest())
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with OAuth refresh-token credentials")
            return True
        except Exception as e:
            print("⚠️ OAuth refresh-token init failed:", e)

    try:
        sa_json_env = os.environ.get("EE_SERVICE_ACCOUNT_JSON")
        sa_file_env = os.environ.get("EE_SERVICE_ACCOUNT_FILE")
        if sa_json_env:
            # Build credentials straight from the parsed key; no need to write the
            # secret to a temporary file on disk.
            creds = google_service_account.Credentials.from_service_account_info(
                json.loads(sa_json_env), scopes=EE_SCOPES
            )
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with service-account (from EE_SERVICE_ACCOUNT_JSON)")
            return True
        if sa_file_env and Path(sa_file_env).exists():
            creds = google_service_account.Credentials.from_service_account_file(sa_file_env, scopes=EE_SCOPES)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with service-account (from file)")
            return True
    except Exception as e:
        print("⚠️ Service-account init failed:", e)

    try:
        if EE_PROJECT:
            ee.Initialize(project=EE_PROJECT)
        else:
            ee.Initialize()
        print("✅ Earth Engine initialized via default ee.Initialize() (fallback).")
        return True
    except Exception as e:
        print("❌ Earth Engine initialization failed (all methods):", e)
        return False


def get_elevation_from_ee(lat, lon):
    """Return SRTM elevation (metres) at (lat, lon), or None on failure.

    Reads the "elevation" band first; if that fails, returns the first numeric
    band value found.
    """
    try:
        img = ee.Image(ELEV_IMAGE_ID)
        pt = ee.Geometry.Point([float(lon), float(lat)])
        rr = img.reduceRegion(ee.Reducer.first(), pt, scale=30, maxPixels=1e6)
        if rr is None:
            return None
        try:
            val = rr.get("elevation").getInfo()
            return float(val) if val is not None else None
        except Exception:
            for k in rr.keys().getInfo():
                v = rr.get(k).getInfo()
                if isinstance(v, (int, float)):
                    return float(v)
            return None
    except Exception as e:
        print("❌ get_elevation_from_ee error:", e)
        return None


# ------------------------------
# Satellite retrieval + caching
# ------------------------------
def _landsat_collection_id(satellite):
    """Map "Landsat-9" to the LC09 C2 L2 collection; anything else to LC08."""
    return "LANDSAT/LC09/C02/T1_L2" if satellite == "Landsat-9" else "LANDSAT/LC08/C02/T1_L2"


def get_single_date_satellite_data(lat, lon, date_str, satellite, buffer_meters, area):
    """Compute vegetation indices for the least-cloudy image on one UTC day.

    Returns a dict (ndvi, ndwi, evi, lst, cloud_cover, month, day_of_year,
    satellite, date, buffer_size), or None when there is no image that day,
    the best image has CLOUD_COVER > 80, NDVI is unavailable, or EE errors.
    ``lat``/``lon`` are unused here; ``area`` defines the sampled region.
    """
    try:
        day_start = ee.Date(date_str)
        filtered = (ee.ImageCollection(_landsat_collection_id(satellite))
                    .filterBounds(area)
                    .filterDate(day_start, day_start.advance(1, "day"))
                    .sort("CLOUD_COVER")
                    .limit(1))
        if int(filtered.size().getInfo()) == 0:
            return None
        image = ee.Image(filtered.first())
        # Only CLOUD_COVER is needed, so fetch that property instead of all metadata.
        cloud_raw = image.get("CLOUD_COVER").getInfo()
        cloud_cover = float(cloud_raw) if cloud_raw is not None else 100.0
        if cloud_cover > 80:
            return None

        # NOTE(review): indices are computed on raw Collection-2 SR digital numbers,
        # without the 0.0000275 / -0.2 reflectance scaling. The offset changes
        # NDVI/NDWI and EVI's constants assume reflectance. Left as-is because the
        # ML model was presumably trained on these exact features — confirm before changing.
        ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
        ndwi = image.normalizedDifference(["SR_B3", "SR_B5"]).rename("NDWI")
        evi = image.expression(
            "2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))",
            {"NIR": image.select("SR_B5"), "RED": image.select("SR_B4"), "BLUE": image.select("SR_B2")},
        ).rename("EVI")
        # Surface temperature: Collection-2 ST scale/offset, then Kelvin -> Celsius.
        lst = image.select("ST_B10").multiply(0.00341802).add(149.0).subtract(273.15).rename("LST")
        composite = ndvi.addBands([ndwi, evi, lst])
        stats = composite.reduceRegion(
            reducer=ee.Reducer.mean(), geometry=area, scale=30, maxPixels=1e6, bestEffort=True
        ).getInfo()

        ndvi_val = stats.get("NDVI")
        if ndvi_val is None:
            return None
        ndwi_val = stats.get("NDWI")
        evi_val = stats.get("EVI")
        lst_val = stats.get("LST")
        current_dt = datetime.strptime(date_str, "%Y-%m-%d")
        return {
            "ndvi": float(ndvi_val),
            "ndwi": float(ndwi_val) if ndwi_val is not None else None,
            "evi": float(evi_val) if evi_val is not None else None,
            "lst": float(lst_val) if lst_val is not None else None,
            "cloud_cover": float(cloud_cover),
            "month": current_dt.month,
            "day_of_year": current_dt.timetuple().tm_yday,
            "satellite": satellite,
            "date": date_str,
            "buffer_size": buffer_meters,
        }
    except Exception as e:
        print("❌ get_single_date_satellite_data error:", e)
        return None


def _candidate_days(satellite, area, start_dt, end_dt):
    """List UTC days in [start_dt, end_dt] that have a usable image, newest first.

    A day is a candidate when at least one image over ``area`` has
    CLOUD_COVER <= 80 (the same cut-off get_single_date_satellite_data
    applies). Returns None if the listing query fails.
    """
    try:
        times = (ee.ImageCollection(_landsat_collection_id(satellite))
                 .filterBounds(area)
                 .filterDate(start_dt.strftime("%Y-%m-%d"), (end_dt + timedelta(days=1)).strftime("%Y-%m-%d"))
                 .filter(ee.Filter.lte("CLOUD_COVER", 80))
                 .aggregate_array("system:time_start")
                 .getInfo())
    except Exception as e:
        print("⚠️ candidate-day listing failed, probing day by day:", e)
        return None
    days = {
        datetime.fromtimestamp(t / 1000.0, tz=timezone.utc).strftime("%Y-%m-%d")
        for t in (times or [])
    }
    return sorted(days, reverse=True)


def get_satellite_data_with_fallback(lat, lon, target_dt, satellite, buffer_meters, area, max_days_back=MAX_DAYS_BACK):
    """Return data for the most recent usable day within ``max_days_back`` of target_dt.

    A single listing query finds the days that have a usable image, and only
    those days are probed. The result is the same as probing every day
    newest-first, with far fewer Earth Engine round trips. If the listing
    fails, every day is probed. Adds original_request_date, actual_data_date
    and days_offset to the returned dict; returns None if nothing is found.
    """
    start_dt = target_dt - timedelta(days=max_days_back)
    days = _candidate_days(satellite, area, start_dt, target_dt)
    if days is None:
        days = [(target_dt - timedelta(days=d)).strftime("%Y-%m-%d") for d in range(max_days_back + 1)]
    for day_str in days:
        data = get_single_date_satellite_data(lat, lon, day_str, satellite, buffer_meters, area)
        if data and data.get("ndvi") is not None:
            data["original_request_date"] = target_dt.strftime("%Y-%m-%d")
            data["actual_data_date"] = day_str
            data["days_offset"] = (target_dt.date() - datetime.strptime(day_str, "%Y-%m-%d").date()).days
            return data
    return None


def get_essential_vegetation_data_with_cache(lat, lon, target_date, buffer_meters=BUFFER_METERS, max_days_back=MAX_DAYS_BACK):
    """Fetch satellite features for ``target_date`` ("YYYY-MM-DD"), Landsat-9 then Landsat-8.

    Successful results go into a bounded LRU cache. ``None`` (no data, or an
    Earth Engine error) is not cached, so transient failures are retried on
    the next request.
    """
    key = (round(float(lat), 6), round(float(lon), 6), str(target_date), int(buffer_meters), int(max_days_back))
    with ee_cache_lock:
        cached = ee_cache.get(key)
        if cached is not None:
            ee_cache.move_to_end(key)
            return cached

    point = ee.Geometry.Point([float(lon), float(lat)])
    area = point.buffer(buffer_meters)
    target_dt = datetime.strptime(target_date, "%Y-%m-%d")
    data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-9", buffer_meters, area, max_days_back)
    if not data:
        data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-8", buffer_meters, area, max_days_back)

    if data is not None:
        with ee_cache_lock:
            ee_cache[key] = data
            ee_cache.move_to_end(key)
            while len(ee_cache) > EE_CACHE_MAX_ENTRIES:
                ee_cache.popitem(last=False)
    return data


# ------------------------------
# ML prediction wrapper (elevation NOT passed to ML)
# ------------------------------
def predict_bloom_with_ml(features_dict):
    """Return {bloom_probability (percent), prediction, confidence} for one observation.

    Very low NDVI/EVI short-circuits to NO_BLOOM. Otherwise the loaded model and
    scaler are used with a fixed feature order (ndvi, ndwi, evi, lst,
    cloud_cover, month, day_of_year). The rule-based fallback is used when the
    model is unavailable or fails.
    """
    ndvi = features_dict.get("ndvi", 0.0) or 0.0
    evi = features_dict.get("evi", 0.0) or 0.0
    if ndvi < 0.05:
        return {"bloom_probability": 8.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if evi < 0.1 and ndvi < 0.1:
        return {"bloom_probability": 10.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}

    if ML_MODEL is not None and SCALER is not None:
        try:
            # NOTE(review): FEATURE_COLUMNS is loaded but not used; this order must
            # match the order used at training time.
            features_array = np.array(
                [[
                    float(features_dict.get("ndvi", 0.0)),
                    float(features_dict.get("ndwi", 0.0) or 0.0),
                    float(features_dict.get("evi", 0.0) or 0.0),
                    float(features_dict.get("lst", 0.0) or 0.0),
                    float(features_dict.get("cloud_cover", 0.0) or 0.0),
                    float(features_dict.get("month", 0) or 0),
                    float(features_dict.get("day_of_year", 0) or 0),
                ]],
                dtype=np.float64,
            )
            features_scaled = SCALER.transform(features_array)
            probabilities = ML_MODEL.predict_proba(features_scaled)
            bloom_prob = probabilities[0, 1] if probabilities.shape[1] == 2 else probabilities[0, 0]
            prediction = ML_MODEL.predict(features_scaled)[0]
            bloom_prob_pct = round(float(bloom_prob * 100.0), 2)
            if bloom_prob_pct > 75 or bloom_prob_pct < 25:
                conf = "HIGH"
            elif bloom_prob_pct > 60 or bloom_prob_pct < 40:
                conf = "MEDIUM"
            else:
                conf = "LOW"
            return {"bloom_probability": bloom_prob_pct,
                    "prediction": "BLOOM" if prediction == 1 else "NO_BLOOM",
                    "confidence": conf}
        except Exception as e:
            print("❌ ML model error:", e)
    return predict_bloom_fallback(features_dict)


def predict_bloom_fallback(features_dict):
    """Rule-based bloom score from EVI, NDVI, NDWI, LST and month, clamped to 8..90 percent."""
    ndvi = float(features_dict.get("ndvi") or 0.0)
    ndwi = float(features_dict.get("ndwi") or 0.0)
    evi = float(features_dict.get("evi") or 0.0)
    lst = float(features_dict.get("lst") or 0.0)
    month = int(features_dict.get("month") or 1)
    score = 0.0
    if evi > 0.7:
        score += 50
    elif evi > 0.5:
        score += 35
    elif evi > 0.3:
        score += 20
    if ndvi > 0.5:
        score += 25
    elif ndvi > 0.3:
        score += 15
    if -0.2 < ndwi < 0.05:
        score += 15
    if 12 < lst < 32:
        score += 12
    if month in [3, 4, 5]:
        score += 15
    if month in [11, 12, 1, 2]:
        score -= 3
    prob = min(90, max(8, score))
    if prob > 52:
        pred = "BLOOM"
        conf = "MEDIUM" if prob > 65 else "LOW"
    else:
        pred = "NO_BLOOM"
        conf = "MEDIUM" if prob < 25 else "LOW"
    return {"bloom_probability": round(prob, 2), "prediction": pred, "confidence": conf}


# ------------------------------
# Species stats builder / predictor
# ------------------------------
def _doy_histogram(frame):
    """Circular DOY histogram for the rows of ``frame``; uniform if no usable dayOfYear values."""
    if "dayOfYear" not in frame.columns:
        return np.ones(DOY_BINS) / DOY_BINS
    doys = frame["dayOfYear"].dropna().to_numpy(dtype=int)
    return circular_histogram(doys)


def load_or_build_species_stats():
    """Return (species_df, doy_map).

    species_df columns: species, count, mean_elev, std_elev, prior.
    doy_map maps each species name to a length-DOY_BINS DOY distribution.

    SPECIES_STATS_FILE is used when present; it carries no DOY data, so every
    species gets a flat DOY distribution. Otherwise stats are built from the
    "yes" observations in PHENO_FILE. Species with fewer than
    MIN_COUNT_FOR_SPECIES observations are pooled into "OTHER". Returns an
    empty frame and map when neither file is usable.
    """
    empty = pd.DataFrame(columns=["species", "count", "mean_elev", "std_elev", "prior"])

    if SPECIES_STATS_FILE.exists():
        df = pd.read_csv(SPECIES_STATS_FILE)
        doy_map = {s: np.ones(DOY_BINS) / DOY_BINS for s in df["species"].tolist()}
        return df, doy_map

    if not PHENO_FILE.exists():
        return empty, {}

    ph = pd.read_csv(PHENO_FILE, low_memory=False)
    if "phenophaseStatus" in ph.columns:
        ph["phenophaseStatus"] = ph["phenophaseStatus"].astype(str).str.strip().str.lower()
        ph_yes = ph[ph["phenophaseStatus"].str.startswith("y")].copy()
    else:
        ph_yes = ph.copy()
    ph_yes = ph_yes.dropna(subset=["elevation"])
    if "dayOfYear" in ph_yes.columns:
        # Unparsable values become NaN and are skipped per group in _doy_histogram.
        ph_yes["dayOfYear"] = pd.to_numeric(ph_yes["dayOfYear"], errors="coerce").clip(1, 366)

    rows = []
    doy_map = {}
    for name, g in ph_yes.groupby("scientificName"):
        elev = g["elevation"].dropna()
        cnt = len(g)
        mean_elev = float(elev.mean()) if cnt > 0 else np.nan
        std_elev = float(elev.std(ddof=0)) if cnt > 0 else EPS_STD
        std_elev = max(std_elev if not np.isnan(std_elev) else 0.0, EPS_STD)
        rows.append({"species": name, "count": cnt, "mean_elev": mean_elev, "std_elev": std_elev})
        doy_map[name] = _doy_histogram(g)

    if not rows:
        return empty, {}

    species_df = pd.DataFrame(rows)
    total = species_df["count"].sum()
    species_df["prior"] = species_df["count"] / total if total > 0 else 1.0 / max(1, len(species_df))

    rare = species_df[species_df["count"] < MIN_COUNT_FOR_SPECIES]
    frequent = species_df[species_df["count"] >= MIN_COUNT_FOR_SPECIES]
    final_rows = frequent.to_dict("records")

    if len(rare) > 0:
        rare_obs = ph_yes[ph_yes["scientificName"].isin(rare["species"].tolist())]
        total_rare = len(rare_obs)
        if total_rare > 0:
            mean_other = float(rare_obs["elevation"].dropna().mean())
            std_other = float(rare_obs["elevation"].dropna().std(ddof=0)) if total_rare > 1 else EPS_STD
            std_other = max(std_other if not np.isnan(std_other) else 0.0, EPS_STD)
            final_rows.append({
                "species": "OTHER",
                "count": int(total_rare),
                "mean_elev": mean_other,
                "std_elev": std_other,
                "prior": int(total_rare) / total if total > 0 else int(total_rare),
            })
            doy_map["OTHER"] = _doy_histogram(rare_obs)

    final_df = pd.DataFrame(final_rows).fillna(0)
    if "prior" not in final_df.columns:
        t2 = final_df["count"].sum()
        final_df["prior"] = final_df["count"] / t2 if t2 > 0 else 1.0 / len(final_df)
    return final_df, doy_map


def _species_posterior(elevation, doy=None):
    """Return (species_names, posterior array summing to 1) over SPECIES_STATS_DF.

    posterior ∝ prior × N(elevation; mean_elev, std_elev) × DOY_hist[doy].
    A missing elevation or doy leaves that factor neutral. Returns
    ([], empty array) when no species stats are loaded.
    """
    if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
        return [], np.zeros(0)

    species = SPECIES_STATS_DF["species"].tolist()
    n = len(species)
    priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
    if _is_missing(elevation):
        post = priors.copy()
    else:
        means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
        stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
        likes = np.array([gaussian_pdf_scalar(elevation, m, s) for m, s in zip(means, stds)])
        post = priors * likes
    total = post.sum()
    post = post / total if total > 0 else np.full(n, 1.0 / n)

    if not _is_missing(doy):
        # Clamp to a valid histogram index (DOY 1..DOY_BINS -> 0..DOY_BINS-1).
        doy_idx = min(max(int(doy), 1), DOY_BINS) - 1
        uniform = np.ones(DOY_BINS) / DOY_BINS
        doy_probs = np.array([DOY_HIST_MAP.get(s, uniform)[doy_idx] for s in species])
        combined = post * doy_probs
        if combined.sum() > 0:
            post = combined / combined.sum()
    return species, post


def predict_species_by_elevation(elevation, doy=None, top_k=TOP_K_SPECIES):
    """Return up to ``top_k`` (species, probability in [0, 1]) tuples, most likely first.

    If elevation is None/NaN, ranking relies on priors (and the DOY histogram
    when ``doy`` is given).
    """
    species, post = _species_posterior(elevation, doy)
    order = np.argsort(-post)[:top_k]
    return [(species[i], float(post[i])) for i in order]


# ------------------------------
# Lifespan: load model, initialize EE, build species stats
# ------------------------------
@asynccontextmanager
async def lifespan(app):
    """Load optional model artifacts, initialize Earth Engine (required) and species stats."""
    global ML_MODEL, SCALER, FEATURE_COLUMNS, SPECIES_STATS_DF, DOY_HIST_MAP
    if MODEL_FILE.exists():
        try:
            ML_MODEL = joblib.load(MODEL_FILE)
            print("✅ MIL model loaded.")
        except Exception as e:
            print("❌ MIL model load error:", e)
    if SCALER_FILE.exists():
        try:
            SCALER = joblib.load(SCALER_FILE)
            print("✅ Scaler loaded.")
        except Exception as e:
            print("❌ Scaler load error:", e)
    if FEATURES_FILE.exists():
        try:
            FEATURE_COLUMNS = joblib.load(FEATURES_FILE)
            print("✅ Features list loaded.")
        except Exception as e:
            print("❌ Features list load error:", e)

    ok = initialize_ee_from_env()
    if not ok:
        raise RuntimeError("Earth Engine initialization failed. Set credentials in Space secrets (EE_SERVICE_ACCOUNT_JSON or CLIENT_ID/CLIENT_SECRET/REFRESH_TOKEN).")

    try:
        SPECIES_STATS_DF, DOY_HIST_MAP = load_or_build_species_stats()
        print("✅ Species stats ready. species count:", len(SPECIES_STATS_DF))
    except Exception as e:
        print("⚠️ Species stats build error:", e)
        SPECIES_STATS_DF = pd.DataFrame()
        DOY_HIST_MAP = {}
    yield
    print("🔄 Shutting down")


# ------------------------------
# FastAPI app & endpoint
# ------------------------------
app = FastAPI(title="Bloom Prediction (HF Space)", lifespan=lifespan)


@app.get("/")
async def root():
    return {"message": "Bloom Prediction API (HF Space)", "model_loaded": ML_MODEL is not None}


def process_month_task(lat, lon, year, month, elevation):
    """Fetch satellite data for the 15th of ``month`` and score it.

    Returns a dict with the month's ML probability/prediction/confidence, data
    quality, and (when the month looks like a bloom) species rankings. When no
    satellite data is found, the ML fields stay None and ``note`` explains why.
    """
    sample_date_str = date(year, month, 15).strftime("%Y-%m-%d")
    sat_data = get_essential_vegetation_data_with_cache(lat, lon, sample_date_str)

    result = {
        "month": month,
        "sample_date": sample_date_str,
        "ml_bloom_probability": None,
        "ml_prediction": None,
        "ml_confidence": None,
        "species_top": None,
        "species_probs": None,
        "elevation_m": elevation,
        "data_quality": None,
        "satellite": None,
        "note": None,
    }
    if sat_data is None:
        result["note"] = f"No satellite data within {MAX_DAYS_BACK} days for {sample_date_str}"
        return result

    ml_out = predict_bloom_with_ml(sat_data)
    result["ml_bloom_probability"] = float(ml_out.get("bloom_probability", 0.0))
    result["ml_prediction"] = ml_out.get("prediction")
    result["ml_confidence"] = ml_out.get("confidence")
    result["data_quality"] = {
        "satellite": sat_data.get("satellite"),
        "cloud_cover": sat_data.get("cloud_cover"),
        "days_offset": sat_data.get("days_offset"),
        "buffer_radius_m": sat_data.get("buffer_size"),
    }
    result["satellite"] = sat_data.get("satellite")

    try:
        bloom_bool = (result["ml_prediction"] == "BLOOM") or (result["ml_bloom_probability"] >= 50.0)
        if bloom_bool:
            species, post = _species_posterior(elevation, sat_data.get("day_of_year"))
            order = np.argsort(-post)[:TOP_K_SPECIES]
            result["species_top"] = [(species[i], round(float(post[i]) * 100.0, 2)) for i in order]
            result["species_probs"] = {s: round(float(p * 100.0), 6) for s, p in zip(species, post)}
        else:
            result["species_top"] = []
            result["species_probs"] = {}
    except Exception as e:
        print("❌ species prediction error:", e)
        result["species_top"] = []
        result["species_probs"] = {}
        result["note"] = (result["note"] + " ; species_pred_error") if result.get("note") else "species_pred_error"
    return result


def _failed_month_result(year, month, elevation):
    """Placeholder month result used when a month's task raises."""
    return {
        "month": month,
        "sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
        "ml_bloom_probability": 0.0,
        "ml_prediction": "NO_BLOOM",
        "ml_confidence": "LOW",
        "species_top": [],
        "species_probs": {},
        "elevation_m": elevation,
        "data_quality": None,
        "satellite": None,
        "note": "processing_error",
    }


def _run_monthly_tasks(lat, lon, year, elevation):
    """Run process_month_task for months 1..12 in parallel; return results ordered Jan..Dec."""
    monthly_results = [None] * 12
    with ThreadPoolExecutor(max_workers=max(1, MAX_WORKERS)) as ex:
        futures = {
            ex.submit(process_month_task, lat, lon, year, month, elevation): month
            for month in range(1, 13)
        }
        for fut in as_completed(futures):
            month = futures[fut]
            try:
                res = fut.result()
            except Exception as e:
                print(f"❌ month {month} processing error:", e)
                res = _failed_month_result(year, month, elevation)
            monthly_results[month - 1] = res
    return monthly_results


def _as_float_list(values):
    """Coerce ``values`` to a flat list of floats.

    Nested items use their first numeric element; anything unusable becomes 0.0.
    """
    out = []
    for v in values:
        if isinstance(v, (list, tuple, np.ndarray)):
            v = next((c for c in np.ravel(v) if isinstance(c, (int, float, np.integer, np.floating))), 0.0)
        try:
            out.append(float(v))
        except (TypeError, ValueError):
            out.append(0.0)
    return out


def _top_species_at_peak(elevation, year, peak_month):
    """Rank species for mid-month of ``peak_month``; return None if ranking fails."""
    try:
        doy = date(year, peak_month, 15).timetuple().tm_yday
        predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
        return [SpeciesResult(name=sp, probability=round(prob * 100.0, 2)) for sp, prob in predictions]
    except Exception as e:
        print(f"❌ species prediction error at final stage: {e}")
        return None


@app.post("/predict", response_model=BloomPredictionResponse)
def predict_bloom(req: BloomPredictionRequest):
    """Predict the bloom season for the year of ``req.date`` at (lat, lon).

    Declared as a plain ``def`` so FastAPI runs it in its threadpool. All work
    here (Earth Engine calls, waiting on worker threads) is blocking and would
    otherwise stall the event loop.
    """
    start_time = time.time()
    try:
        req_dt = datetime.strptime(req.date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD") from None

    elevation = get_elevation_from_ee(req.lat, req.lon)
    year = req_dt.year
    monthly_results = _run_monthly_tasks(req.lat, req.lon, year, elevation)

    # Months without data contribute 0.
    raw_probs = np.array([float(mr.get("ml_bloom_probability") or 0.0) for mr in monthly_results], dtype=float)
    print("DEBUG raw_probs (months 1..12):", raw_probs.tolist())
    print("DEBUG raw_stats min,max,mean:", float(raw_probs.min()), float(raw_probs.max()), float(raw_probs.mean()))

    try:
        monthly_perc, monthly_scores = smooth_monthly_probs_preserve_magnitude(
            raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA
        )
    except Exception as e:
        print("❌ smoothing error:", e)
        # Fallback: simple normalization & identity scores
        safe = np.clip(raw_probs, 0.0, 100.0)
        s = float(safe.sum()) if safe.sum() > 0 else 1.0
        monthly_perc = ((safe / s) * 100.0).round(3).tolist()
        monthly_scores = safe.round(3).tolist()

    monthly_perc = _as_float_list(monthly_perc)
    monthly_scores = _as_float_list(monthly_scores)
    if len(monthly_perc) != 12 or len(monthly_scores) != 12:
        print("❌ smoothing returned wrong length; resizing to 12 with uniform fallback")
        monthly_perc = (np.ones(12) * (100.0 / 12.0)).tolist()
        monthly_scores = (np.ones(12) * np.clip(raw_probs.max(), 0.0, 100.0)).tolist()

    monthly_curve = {i + 1: float(round(monthly_perc[i], 3)) for i in range(12)}
    print("DEBUG monthly_perc:", monthly_perc)
    print("DEBUG monthly_scores (abs 0..100):", monthly_scores)

    # Detection uses the absolute scores; the bell-shape test uses the normalized curve.
    peak_idx = int(np.argmax(monthly_scores))
    peak_month = peak_idx + 1
    peak_score = float(monthly_scores[peak_idx])
    bell_ok, bell_diag = is_bell_shaped(monthly_perc)
    mean_score = float(np.mean(monthly_scores))
    median_score = float(np.median(monthly_scores))
    near_peak_count = sum(1 for s in monthly_scores if s >= peak_score * 0.9)
    months_above_threshold = sum(1 for s in monthly_scores if s >= MIN_BLOOM_THRESHOLD)
    print(f"DEBUG detection: peak_score={peak_score}, mean={mean_score:.2f}, median={median_score:.2f}, near_peak_count={near_peak_count}, months_above_{MIN_BLOOM_THRESHOLD}={months_above_threshold}, bell_ok={bell_ok}, bell_diag={bell_diag}")

    status = "NO_BLOOM"
    top_species = None
    bloom_window = None
    peak_month_out = None
    peak_prob_out = None

    if peak_score >= MIN_PEAK_FOR_BELL:
        dominance_ratio = peak_score / (mean_score + 1e-9)
        accept_bloom = (
            bell_ok  # clear single symmetric season
            or (dominance_ratio >= 1.25 and near_peak_count <= 4)  # peak dominates the rest
            or (months_above_threshold >= 3 and peak_score >= MIN_BLOOM_THRESHOLD + 10)  # broad but strong season
        )
        peak_month_out = peak_month
        peak_prob_out = peak_score
        if not accept_bloom:
            # Ambiguous / flat high-months case
            status = "LOW_CONFIDENCE"
            bloom_window = [i + 1 for i, s in enumerate(monthly_scores) if s > MIN_BLOOM_THRESHOLD * 0.5]
        else:
            status = "LOW_CONFIDENCE" if peak_score < MIN_BLOOM_THRESHOLD + 5 else "BLOOM_DETECTED"
            bloom_window = [i + 1 for i, s in enumerate(monthly_scores) if s > MIN_BLOOM_THRESHOLD]
            top_species = _top_species_at_peak(elevation, year, peak_month)

    processing_time = round(time.time() - start_time, 2)
    response = BloomPredictionResponse(
        success=True,
        status=status,
        peak_month=peak_month_out,
        peak_probability=peak_prob_out,
        bloom_window=bloom_window,
        top_species=top_species,
        monthly_probabilities=monthly_curve,
    )
    print(f"INFO processing_time: {processing_time}s, peak_score: {peak_score}, peak_month: {peak_month}")
    return response


# ------------------------------
# Local run
# ------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))