# app.py (updated: fixes species bug, adds caching, parallel monthly queries, monthly bell curve)
import os
import time
import math
import joblib
import ee
import pandas as pd
import numpy as np
from datetime import datetime, date, timedelta
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import json
import scipy.signal

# google oauth helpers (used earlier)
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request as GoogleRequest
from google.oauth2 import service_account as google_service_account
# ------------------------------
# CONFIG / FILENAMES / TUNABLES
# ------------------------------
MODEL_FILE = Path("mil_bloom_model.joblib")
SCALER_FILE = Path("mil_scaler.joblib")
FEATURES_FILE = Path("mil_features.joblib")
PHENO_FILE = Path("phenologythingy.csv")
SPECIES_STATS_FILE = Path("species_stats.csv")

MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 20.0))  # minimum probability required before species are predicted
MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 25.0))
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
MAX_DAYS_BACK = int(os.environ.get("MAX_DAYS_BACK", 30))
MIN_COUNT_FOR_SPECIES = int(os.environ.get("MIN_COUNT_FOR_SPECIES", 20))
TOP_K_SPECIES = int(os.environ.get("TOP_K_SPECIES", 5))  # how many species to return
DOY_BINS = 366
DOY_SMOOTH = 15
EPS_STD = 1.0

# TUNABLES (additions)
ALPHA = float(os.environ.get("MONTH_ALPHA", 2.0))               # >1 sharpens peaks, 1.0 = no change
SMOOTH_SIGMA = float(os.environ.get("SMOOTH_SIGMA", 1.2))       # gaussian sigma in months
MIN_CURVE_PROB = float(os.environ.get("MIN_CURVE_PROB", 0.01))  # min per-month percentage floor

# Tune parallelism: how many months to fetch at once
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", 4))

# EE OAuth env vars expected in Hugging Face Space secrets
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
REFRESH_TOKEN = os.environ.get("REFRESH_TOKEN")
EE_PROJECT = os.environ.get("PROJECT") or os.environ.get("EE_PROJECT") or None
EE_SCOPES = [
    "https://www.googleapis.com/auth/earthengine",
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/drive",
    "https://www.googleapis.com/auth/devstorage.full_control",
]
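
# Illustrative tuning (values hypothetical, not recommendations): the seasonal
# curve can be tightened before deployment via Space secrets / env vars, e.g.
#   MONTH_ALPHA=3.0     -> sharper peak emphasis
#   SMOOTH_SIGMA=0.8    -> narrower gaussian smoothing across months
#   MAX_WORKERS=6       -> more concurrent monthly Earth Engine queries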
# ------------------------------
# Pydantic models
# ------------------------------
class BloomPredictionRequest(BaseModel):
    lat: float = Field(..., ge=-90, le=90)
    lon: float = Field(..., ge=-180, le=180)
    date: str = Field(..., description="YYYY-MM-DD")


class SimplifiedMonthlyResult(BaseModel):
    month: int
    bloom_probability: float
    prediction: str  # "BLOOM" or "NO_BLOOM"


class SpeciesResult(BaseModel):
    name: str
    probability: float  # as percentage


class BloomPredictionResponse(BaseModel):
    success: bool
    status: str  # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
    # Only include these if there's a valid bloom season
    peak_month: Optional[int] = None
    peak_probability: Optional[float] = None
    bloom_window: Optional[List[int]] = None  # months with scores above MIN_BLOOM_THRESHOLD
    # Only include species if peak > threshold
    top_species: Optional[List[SpeciesResult]] = None
    # Simplified monthly data (probabilities only)
    monthly_probabilities: Dict[int, float]
# ------------------------------
# Globals & cache
# ------------------------------
ML_MODEL = None
SCALER = None
FEATURE_COLUMNS = None
SPECIES_STATS_DF = None
DOY_HIST_MAP: Dict[str, np.ndarray] = {}

# simple in-memory cache, keyed by (lat, lon, date_str) -> sat_data
ee_cache_lock = threading.Lock()
ee_cache: Dict[Tuple[float, float, str], dict] = {}

# ------------------------------
# Utility functions
# ------------------------------
def gaussian_kernel(sigma=1.2):
    """Return a normalized 1D gaussian kernel with `sigma` in index units.

    The kernel half-width is max(6, 3*sigma), so the tails are effectively zero.
    """
    half = max(6, int(3 * sigma))
    xs = np.arange(-half, half + 1)
    kern = np.exp(-0.5 * (xs / sigma) ** 2)
    kern = kern / kern.sum()
    return kern
def smooth_monthly_probs_preserve_magnitude(raw_probs,
                                            alpha=ALPHA,
                                            sigma=SMOOTH_SIGMA,
                                            min_curve_prob=MIN_CURVE_PROB):
    """
    Convert raw_probs (0..100) -> two arrays:
      - monthly_scores: absolute smoothed scores scaled to 0..100 (preserves magnitude)
      - monthly_perc: normalized percentages summing to 100 (for plotting)
    Steps:
      1. scale to 0..1
      2. tiny contrast-stretch so differences matter
      3. apply exponent alpha to emphasize peaks
      4. circular gaussian smoothing
      5. monthly_scores: scale smoothed values so max -> 100
      6. monthly_perc: normalize the same smoothed values to sum -> 100
    """
    a = np.asarray(raw_probs, dtype=float) / 100.0  # 0..1
    # Guard: if all zeros, return uniform small floor
    if a.sum() == 0 or np.allclose(a, 0.0):
        uniform = np.ones(12) * (min_curve_prob / 100.0)
        monthly_scores = (uniform / uniform.max()) * 100.0
        monthly_perc = (uniform / uniform.sum()) * 100.0
        return monthly_perc.tolist(), monthly_scores.tolist()
    # Contrast stretch (min-max) to amplify differences
    amin = float(a.min())
    amax = float(a.max())
    rng = amax - amin
    if rng < 1e-6:
        a_cs = a - amin
        if a_cs.max() <= 1e-12:
            a_cs = np.ones_like(a) * 1e-6
        else:
            a_cs = a_cs / (a_cs.max() + 1e-12)
    else:
        a_cs = (a - amin) / (rng + 1e-12)
    # floor tiny values (avoid zero everywhere)
    floor = min_curve_prob / 100.0
    a_cs = np.clip(a_cs, floor * 1e-3, 1.0)
    # sharpen peaks
    if alpha != 1.0:
        a_sh = np.power(a_cs, alpha)
    else:
        a_sh = a_cs
    # circular pad and gaussian kernel
    sigma = float(max(0.6, sigma))
    pad = max(3, int(round(3 * sigma)))
    padded = np.concatenate([a_sh[-pad:], a_sh, a_sh[:pad]])
    kern_range = np.arange(-pad, pad + 1)
    kern = np.exp(-0.5 * (kern_range / sigma) ** 2)
    kern = kern / kern.sum()
    smoothed = np.convolve(padded, kern, mode='same')
    center = smoothed[pad:pad + 12]
    center = np.clip(center, 0.0, None)
    # monthly_perc: normalized to sum 100 for plotting
    if center.sum() <= 0:
        center = np.ones_like(center)
    norm = center / center.sum()
    monthly_perc = (norm * 100.0).round(3)
    # monthly_scores: scale by max to 0..100 (preserve magnitude)
    maxv = float(center.max())
    if maxv <= 0:
        monthly_scores = np.ones_like(center) * min_curve_prob
    else:
        monthly_scores = (center / maxv) * 100.0
    monthly_scores = np.round(monthly_scores, 3)
    # final safety normalization
    s = float(monthly_perc.sum())
    if s <= 0:
        monthly_perc = np.ones(12) * (100.0 / 12.0)
    else:
        monthly_perc = monthly_perc * (100.0 / monthly_perc.sum())
    return monthly_perc.tolist(), monthly_scores.tolist()
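
# Illustrative usage (hypothetical raw ML probabilities for Jan..Dec):
#   perc, scores = smooth_monthly_probs_preserve_magnitude(
#       [5, 8, 20, 55, 80, 60, 30, 12, 6, 4, 3, 2])
#   `scores` is max-scaled so the strongest month reads 100.0, while
#   `perc` sums to 100 and gives the relative seasonal shape for plotting.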
def is_bell_shaped(perc_list):
    """
    Basic unimodal/symmetric check:
      - count peaks (local maxima) - expect 1
      - compute skewness sign (negative/positive)
      - compute ratio of mass on left vs right of peak (expect roughly symmetric)
    Returns (is_bell, diagnostics_dict)
    """
    arr = np.asarray(perc_list, dtype=float)
    # small smoothing to remove noise (window must be odd and > polyorder;
    # the 12-month curve satisfies the window=5, polyorder=3 requirement)
    arr_smooth = np.array(scipy.signal.savgol_filter(arr, 5, 3)) if len(arr) >= 5 else arr.copy()
    # find peaks: the `+ 1` converts the second-difference index to the index of the local maximum
    peaks = (np.diff(np.sign(np.diff(arr_smooth))) < 0).nonzero()[0] + 1
    num_peaks = len(peaks)
    peak_idx = int(peaks[0]) if num_peaks >= 1 else int(arr_smooth.argmax())
    # symmetry: mass difference left/right of peak
    left_mass = arr[:peak_idx].sum()
    right_mass = arr[peak_idx + 1:].sum() if peak_idx + 1 < len(arr) else 0.0
    sym_ratio = left_mass / (right_mass + 1e-9)
    # skewness:
    m = arr.mean()
    s = arr.std(ddof=0) if arr.std(ddof=0) > 0 else 1.0
    skew = ((arr - m) ** 3).mean() / (s ** 3)
    # heuristics
    is_unimodal = (num_peaks <= 1)
    is_symmish = 0.5 <= sym_ratio <= 2.0  # within a factor of 2
    # final decision
    is_bell = is_unimodal and is_symmish
    diagnostics = {
        "num_peaks": num_peaks,
        "peak_month": int(np.argmax(arr)) + 1,
        "sym_ratio_left_to_right": float(sym_ratio),
        "skewness": float(skew),
        "is_unimodal": bool(is_unimodal),
        "is_symmetric_enough": bool(is_symmish),
    }
    return bool(is_bell), diagnostics
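
# Illustrative check (hypothetical curve peaking in May):
#   is_bell, diag = is_bell_shaped([1, 2, 5, 15, 30, 20, 12, 8, 4, 2, 1, 0])
# A single local maximum with roughly balanced mass on each side yields
# is_bell=True; a twin-peaked curve fails the unimodal test.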
def gaussian_pdf_scalar(x_scalar, mean, std):
    """Return Gaussian PDF value for a scalar x. If x_scalar is None or NaN, return 1.0 (neutral)."""
    try:
        if x_scalar is None or (isinstance(x_scalar, float) and np.isnan(x_scalar)):
            return 1.0
        std = max(float(std), 1e-6)
        coef = 1.0 / (std * math.sqrt(2 * math.pi))
        z = (float(x_scalar) - float(mean)) / std
        return coef * math.exp(-0.5 * z * z)
    except Exception:
        # fallback neutral
        return 1.0
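
# Example: at the species' mean elevation the density is 1/(std*sqrt(2*pi)),
# e.g. gaussian_pdf_scalar(1500.0, 1500.0, 300.0) ~ 0.00133, while an
# observation 2 std away (2100 m) returns exp(-2) ~ 0.135 of that peak value.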
def circular_histogram(doys, bins=DOY_BINS, smooth_window=DOY_SMOOTH):
    if len(doys) == 0:
        return np.ones(bins) / bins
    # day-of-year values are 1-based, so drop bin 0 after counting
    counts = np.bincount(doys.astype(int), minlength=bins + 1)[1:]
    # box-smooth on a doubled array to approximate circular (year-wrap) smoothing
    window = np.ones(smooth_window) / smooth_window
    doubled = np.concatenate([counts, counts])
    smoothed = np.convolve(doubled, window, mode='same')[:bins]
    total = smoothed.sum()
    if total <= 0:
        return np.ones(bins) / bins
    return smoothed / total
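
# Sketch of the output (hypothetical data): 100 observations clustered near
# DOY 120 produce a normalized 366-bin histogram concentrated around late
# April, spread over the DOY_SMOOTH (15-day) window; all bins sum to 1.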
# ------------------------------
# Earth Engine init (OAuth refresh-token or service account fallback)
# ------------------------------
def initialize_ee_from_env():
    # Try OAuth refresh token first
    try:
        if CLIENT_ID and CLIENT_SECRET and REFRESH_TOKEN:
            creds = Credentials(
                token=None,
                refresh_token=REFRESH_TOKEN,
                client_id=CLIENT_ID,
                client_secret=CLIENT_SECRET,
                token_uri="https://oauth2.googleapis.com/token",
                scopes=EE_SCOPES,
            )
            request = GoogleRequest()
            creds.refresh(request)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with OAuth refresh-token credentials")
            return True
    except Exception as e:
        print("⚠️ OAuth refresh-token init failed:", e)
    # Try service account JSON in env var EE_SERVICE_ACCOUNT_JSON
    try:
        sa_json_env = os.environ.get("EE_SERVICE_ACCOUNT_JSON")
        sa_file_env = os.environ.get("EE_SERVICE_ACCOUNT_FILE")
        if sa_json_env:
            sa_data = json.loads(sa_json_env)
            tmp_path = "/tmp/ee_service_account.json"
            with open(tmp_path, "w") as f:
                json.dump(sa_data, f)
            creds = google_service_account.Credentials.from_service_account_file(tmp_path, scopes=EE_SCOPES)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with service-account (from EE_SERVICE_ACCOUNT_JSON)")
            return True
        if sa_file_env and Path(sa_file_env).exists():
            creds = google_service_account.Credentials.from_service_account_file(sa_file_env, scopes=EE_SCOPES)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("✅ Earth Engine initialized with service-account (from file)")
            return True
    except Exception as e:
        print("⚠️ Service-account init failed:", e)
    # fallback to default initialization
    try:
        if EE_PROJECT:
            ee.Initialize(project=EE_PROJECT)
        else:
            ee.Initialize()
        print("✅ Earth Engine initialized via default ee.Initialize() (fallback).")
        return True
    except Exception as e:
        print("❌ Earth Engine initialization failed (all methods):", e)
        return False
def get_elevation_from_ee(lat, lon):
    try:
        img = ee.Image(ELEV_IMAGE_ID)
        pt = ee.Geometry.Point([float(lon), float(lat)])
        rr = img.reduceRegion(ee.Reducer.first(), pt, scale=30, maxPixels=1e6)
        if rr is None:
            return None
        try:
            val = rr.get("elevation").getInfo()
            return float(val) if val is not None else None
        except Exception:
            # band name may differ; fall back to the first numeric value found
            keys = rr.keys().getInfo()
            for k in keys:
                v = rr.get(k).getInfo()
                if isinstance(v, (int, float)):
                    return float(v)
            return None
    except Exception as e:
        print("❌ get_elevation_from_ee error:", e)
        return None
# ------------------------------
# Satellite retrieval + caching
# ------------------------------
def get_single_date_satellite_data(lat, lon, date_str, satellite, buffer_meters, area):
    collection_id = "LANDSAT/LC09/C02/T1_L2" if satellite == "Landsat-9" else "LANDSAT/LC08/C02/T1_L2"
    try:
        filtered = (ee.ImageCollection(collection_id)
                    .filterBounds(area)
                    .filterDate(date_str, f"{date_str}T23:59:59")
                    .sort("CLOUD_COVER")
                    .limit(1))
        size = int(filtered.size().getInfo())
        if size == 0:
            return None
        image = ee.Image(filtered.first())
        info = image.getInfo().get("properties", {})
        cloud_cover = float(info.get("CLOUD_COVER", 100.0))
        if cloud_cover > 80:
            return None
        # spectral indices from the Level-2 surface-reflectance bands
        ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
        ndwi = image.normalizedDifference(["SR_B3", "SR_B5"]).rename("NDWI")
        evi = image.expression(
            "2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))",
            {"NIR": image.select("SR_B5"), "RED": image.select("SR_B4"), "BLUE": image.select("SR_B2")},
        ).rename("EVI")
        # land surface temperature: scale ST_B10 to kelvin, then convert to Celsius
        lst = image.select("ST_B10").multiply(0.00341802).add(149.0).subtract(273.15).rename("LST")
        composite = ndvi.addBands([ndwi, evi, lst])
        stats = composite.reduceRegion(
            reducer=ee.Reducer.mean(), geometry=area, scale=30, maxPixels=1e6, bestEffort=True
        ).getInfo()
        ndvi_val = stats.get("NDVI")
        if ndvi_val is None:
            return None
        ndwi_val = stats.get("NDWI")
        evi_val = stats.get("EVI")
        lst_val = stats.get("LST")
        current_dt = datetime.strptime(date_str, "%Y-%m-%d")
        return {
            "ndvi": float(ndvi_val),
            "ndwi": float(ndwi_val) if ndwi_val is not None else None,
            "evi": float(evi_val) if evi_val is not None else None,
            "lst": float(lst_val) if lst_val is not None else None,
            "cloud_cover": float(cloud_cover),
            "month": current_dt.month,
            "day_of_year": current_dt.timetuple().tm_yday,
            "satellite": satellite,
            "date": date_str,
            "buffer_size": buffer_meters,
        }
    except Exception as e:
        print("❌ get_single_date_satellite_data error:", e)
        return None
def get_satellite_data_with_fallback(lat, lon, target_dt, satellite, buffer_meters, area, max_days_back=MAX_DAYS_BACK):
    for days_back in range(0, max_days_back + 1):
        current_date = (target_dt - timedelta(days=days_back)).strftime("%Y-%m-%d")
        # no per-satellite caching here; only the final result of
        # get_essential_vegetation_data_with_cache is cached
        data = get_single_date_satellite_data(lat, lon, current_date, satellite, buffer_meters, area)
        if data and data.get("ndvi") is not None:
            data["original_request_date"] = target_dt.strftime("%Y-%m-%d")
            data["actual_data_date"] = current_date
            data["days_offset"] = days_back
            return data
    return None


def get_essential_vegetation_data_with_cache(lat, lon, target_date, buffer_meters=BUFFER_METERS, max_days_back=MAX_DAYS_BACK):
    key = (round(float(lat), 6), round(float(lon), 6), str(target_date))
    with ee_cache_lock:
        if key in ee_cache:
            return ee_cache[key]
    # cache miss: fetch, preferring Landsat-9 and falling back to Landsat-8
    point = ee.Geometry.Point([float(lon), float(lat)])
    area = point.buffer(buffer_meters)
    target_dt = datetime.strptime(target_date, "%Y-%m-%d")
    data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-9", buffer_meters, area, max_days_back)
    if not data:
        data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-8", buffer_meters, area, max_days_back)
    with ee_cache_lock:
        ee_cache[key] = data
    return data
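
# Note on the cache key: coordinates are rounded to 6 decimal places (~0.1 m),
# so repeated queries for the same point and date hit the in-memory cache
# instead of re-querying Earth Engine. Example (hypothetical point):
#   get_essential_vegetation_data_with_cache(27.98, 86.92, "2024-05-15")
#   # a second identical call returns the cached dict without an EE round-trip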
# ------------------------------
# ML prediction wrapper (unchanged features; elevation NOT passed to ML)
# ------------------------------
def predict_bloom_with_ml(features_dict):
    ndvi = features_dict.get("ndvi", 0.0) or 0.0
    evi = features_dict.get("evi", 0.0) or 0.0
    # quick rejections: essentially no vegetation signal
    if ndvi < 0.05:
        return {"bloom_probability": 8.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if evi < 0.1 and ndvi < 0.1:
        return {"bloom_probability": 10.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if ML_MODEL is not None and SCALER is not None:
        try:
            features_array = np.array(
                [
                    [
                        float(features_dict.get("ndvi", 0.0)),
                        float(features_dict.get("ndwi", 0.0) or 0.0),
                        float(features_dict.get("evi", 0.0) or 0.0),
                        float(features_dict.get("lst", 0.0) or 0.0),
                        float(features_dict.get("cloud_cover", 0.0) or 0.0),
                        float(features_dict.get("month", 0) or 0),
                        float(features_dict.get("day_of_year", 0) or 0),
                    ]
                ],
                dtype=np.float64,
            )
            features_scaled = SCALER.transform(features_array)
            probabilities = ML_MODEL.predict_proba(features_scaled)
            bloom_prob = probabilities[0, 1] if probabilities.shape[1] == 2 else probabilities[0, 0]
            prediction = ML_MODEL.predict(features_scaled)[0]
            bloom_prob_pct = round(float(bloom_prob * 100.0), 2)
            if bloom_prob_pct > 75 or bloom_prob_pct < 25:
                conf = "HIGH"
            elif bloom_prob_pct > 60 or bloom_prob_pct < 40:
                conf = "MEDIUM"
            else:
                conf = "LOW"
            return {"bloom_probability": bloom_prob_pct, "prediction": "BLOOM" if prediction == 1 else "NO_BLOOM", "confidence": conf}
        except Exception as e:
            print("❌ ML model error:", e)
    return predict_bloom_fallback(features_dict)
def predict_bloom_fallback(features_dict):
    ndvi = float(features_dict.get("ndvi") or 0.0)
    ndwi = float(features_dict.get("ndwi") or 0.0)
    evi = float(features_dict.get("evi") or 0.0)
    lst = float(features_dict.get("lst") or 0.0)
    month = int(features_dict.get("month") or 1)
    score = 0.0
    if evi > 0.7:
        score += 50
    elif evi > 0.5:
        score += 35
    elif evi > 0.3:
        score += 20
    if ndvi > 0.5:
        score += 25
    elif ndvi > 0.3:
        score += 15
    if -0.2 < ndwi < 0.05:
        score += 15
    if 12 < lst < 32:
        score += 12
    if month in [3, 4, 5]:
        score += 15
    if month in [11, 12, 1, 2]:
        score -= 3
    prob = min(90, max(8, score))
    if prob > 52:
        pred = "BLOOM"
        conf = "MEDIUM" if prob > 65 else "LOW"
    else:
        pred = "NO_BLOOM"
        conf = "MEDIUM" if prob < 25 else "LOW"
    return {"bloom_probability": round(prob, 2), "prediction": pred, "confidence": conf}
# ------------------------------
# Species stats builder / predictor (fixed)
# ------------------------------
def load_or_build_species_stats():
    global PHENO_FILE, SPECIES_STATS_FILE
    if SPECIES_STATS_FILE.exists():
        df = pd.read_csv(SPECIES_STATS_FILE)
        # precomputed stats carry no per-day data, so fall back to uniform DOY histograms
        doy_map = {}
        for s in df["species"].tolist():
            doy_map[s] = np.ones(DOY_BINS) / DOY_BINS
        return df, doy_map
    if PHENO_FILE.exists():
        ph = pd.read_csv(PHENO_FILE, low_memory=False)
        if "phenophaseStatus" in ph.columns:
            ph["phenophaseStatus"] = ph["phenophaseStatus"].astype(str).str.strip().str.lower()
            ph_yes = ph[ph["phenophaseStatus"].str.startswith("y")].copy()
        else:
            ph_yes = ph.copy()
        ph_yes = ph_yes.dropna(subset=["elevation"])
        if "dayOfYear" in ph_yes.columns:
            # coerce to numeric but keep row alignment (a .dropna() here would
            # misalign the assignment); NaNs are dropped per-species below
            ph_yes["dayOfYear"] = pd.to_numeric(ph_yes["dayOfYear"], errors="coerce").clip(1, 366)
        rows = []
        doy_map = {}
        grouped = ph_yes.groupby("scientificName")
        for name, g in grouped:
            cnt = len(g)
            mean_elev = float(g["elevation"].dropna().mean()) if cnt > 0 else np.nan
            std_elev = float(g["elevation"].dropna().std(ddof=0)) if cnt > 0 else EPS_STD
            std_elev = max(std_elev if not np.isnan(std_elev) else 0.0, EPS_STD)
            rows.append({"species": name, "count": cnt, "mean_elev": mean_elev, "std_elev": std_elev})
            if "dayOfYear" in g.columns:
                doy_map[name] = circular_histogram(g["dayOfYear"].dropna().astype(int).to_numpy())
            else:
                doy_map[name] = np.ones(DOY_BINS) / DOY_BINS
        species_df = pd.DataFrame(rows)
        total = species_df["count"].sum()
        species_df["prior"] = species_df["count"] / total if total > 0 else 1.0 / max(1, len(species_df))
        # pool species with too few observations into a single "OTHER" bucket
        rare = species_df[species_df["count"] < MIN_COUNT_FOR_SPECIES]
        frequent = species_df[species_df["count"] >= MIN_COUNT_FOR_SPECIES]
        final_rows = frequent.to_dict("records")
        if len(rare) > 0:
            rare_names = rare["species"].tolist()
            rare_obs = ph_yes[ph_yes["scientificName"].isin(rare_names)]
            total_rare = len(rare_obs)
            if total_rare > 0:
                mean_other = float(rare_obs["elevation"].dropna().mean())
                std_other = float(rare_obs["elevation"].dropna().std(ddof=0)) if total_rare > 1 else EPS_STD
                std_other = max(std_other if not np.isnan(std_other) else 0.0, EPS_STD)
                final_rows.append(
                    {
                        "species": "OTHER",
                        "count": int(total_rare),
                        "mean_elev": mean_other,
                        "std_elev": std_other,
                        "prior": int(total_rare) / total if total > 0 else 0.0,
                    }
                )
                doy_map["OTHER"] = circular_histogram(rare_obs["dayOfYear"].dropna().astype(int).to_numpy()) if "dayOfYear" in rare_obs.columns else np.ones(DOY_BINS) / DOY_BINS
        final_df = pd.DataFrame(final_rows).fillna(0)
        if "prior" not in final_df.columns:
            t2 = final_df["count"].sum()
            final_df["prior"] = final_df["count"] / t2 if t2 > 0 else 1.0 / len(final_df)
        return final_df, doy_map
    return pd.DataFrame(columns=["species", "count", "mean_elev", "std_elev", "prior"]), {}
def predict_species_by_elevation(elevation, doy=None, top_k=TOP_K_SPECIES):
    """
    Return top_k list of tuples (species, prob_fraction in [0, 1]).
    If elevation is None, rely on priors and the DOY histogram if available.
    """
    global SPECIES_STATS_DF, DOY_HIST_MAP
    if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
        return []
    species = SPECIES_STATS_DF["species"].tolist()
    priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
    means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
    stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
    if elevation is None or (isinstance(elevation, float) and np.isnan(elevation)):
        # If no elevation, start from priors
        post = priors.copy()
        post = post / post.sum()
    else:
        # compute likelihood per species using scalar gaussian
        likes = np.array([gaussian_pdf_scalar(elevation, means[i], stds[i]) for i in range(len(species))])
        post = priors * likes
        if post.sum() == 0:
            post = np.ones(len(species)) / len(species)
        else:
            post = post / post.sum()
    # incorporate DOY if available
    if doy is not None and not np.isnan(doy):
        doy_idx = int(doy) - 1
        doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in species])
        combined = post * doy_probs
        if combined.sum() > 0:
            post = combined / combined.sum()
    order = np.argsort(-post)
    top = []
    for i in order[:top_k]:
        top.append((species[i], float(post[i])))
    return top
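
# The update above is a naive-Bayes style posterior over species s:
#   P(s | elev, doy) ∝ prior(s) * N(elev; mean_s, std_s) * hist_s[doy]
# Example (hypothetical two-species case): priors [0.7, 0.3] with elevation
# likelihoods [0.001, 0.004] give unnormalized [0.0007, 0.0012], i.e. the
# rarer but elevation-matched species ends up on top after normalization.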
# ------------------------------
# Lifespan: load model, initialize EE, build species stats
# ------------------------------
@asynccontextmanager
async def lifespan(app):
    global ML_MODEL, SCALER, FEATURE_COLUMNS, SPECIES_STATS_DF, DOY_HIST_MAP
    # load optional models
    if MODEL_FILE.exists():
        try:
            ML_MODEL = joblib.load(MODEL_FILE)
            print("✅ MIL model loaded.")
        except Exception as e:
            print("❌ MIL model load error:", e)
    if SCALER_FILE.exists():
        try:
            SCALER = joblib.load(SCALER_FILE)
            print("✅ Scaler loaded.")
        except Exception as e:
            print("❌ Scaler load error:", e)
    if FEATURES_FILE.exists():
        try:
            FEATURE_COLUMNS = joblib.load(FEATURES_FILE)
            print("✅ Features list loaded.")
        except Exception as e:
            print("❌ Features list load error:", e)
    ok = initialize_ee_from_env()
    if not ok:
        raise RuntimeError("Earth Engine initialization failed. Set credentials in Space secrets (EE_SERVICE_ACCOUNT_JSON or CLIENT_ID/CLIENT_SECRET/REFRESH_TOKEN).")
    try:
        SPECIES_STATS_DF, DOY_HIST_MAP = load_or_build_species_stats()
        print("✅ Species stats ready. species count:", len(SPECIES_STATS_DF))
    except Exception as e:
        print("⚠️ Species stats build error:", e)
        SPECIES_STATS_DF = pd.DataFrame()
        DOY_HIST_MAP = {}
    yield
    print("🛑 Shutting down")
# ------------------------------
# FastAPI app & endpoint
# ------------------------------
app = FastAPI(title="Bloom Prediction (HF Space)", lifespan=lifespan)


@app.get("/")
async def root():
    return {"message": "Bloom Prediction API (HF Space)", "model_loaded": ML_MODEL is not None}
def process_month_task(lat, lon, year, month, elevation):
    """
    Perform the satellite retrieval and predictions for one month (sampled on the 15th).
    Returns a plain dict of per-month results.
    """
    sample_dt = date(year, month, 15)
    sample_date_str = sample_dt.strftime("%Y-%m-%d")
    # fetch satellite data (with caching)
    sat_data = get_essential_vegetation_data_with_cache(lat, lon, sample_date_str)
    result = {
        "month": month,
        "sample_date": sample_date_str,
        "ml_bloom_probability": None,
        "ml_prediction": None,
        "ml_confidence": None,
        "species_top": None,
        "species_probs": None,
        "elevation_m": elevation,
        "data_quality": None,
        "satellite": None,
        "note": None,
    }
    if sat_data is None:
        result["note"] = f"No satellite data within {MAX_DAYS_BACK} days for {sample_date_str}"
        return result
    # run ML
    ml_out = predict_bloom_with_ml(sat_data)
    result["ml_bloom_probability"] = float(ml_out.get("bloom_probability", 0.0))
    result["ml_prediction"] = ml_out.get("prediction")
    result["ml_confidence"] = ml_out.get("confidence")
    result["data_quality"] = {
        "satellite": sat_data.get("satellite"),
        "cloud_cover": sat_data.get("cloud_cover"),
        "days_offset": sat_data.get("days_offset"),
        "buffer_radius_m": sat_data.get("buffer_size"),
    }
    result["satellite"] = sat_data.get("satellite")
    # species prediction if bloom
    try:
        bloom_bool = (result["ml_prediction"] == "BLOOM") or (result["ml_bloom_probability"] >= 50.0)
        if bloom_bool:
            doy = sat_data.get("day_of_year", None)
            top_species = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
            # convert to percentages
            result["species_top"] = [(s, round(p * 100.0, 2)) for s, p in top_species]
            # build full species_probs
            species_probs = {}
            if (SPECIES_STATS_DF is not None) and (not SPECIES_STATS_DF.empty):
                all_species = SPECIES_STATS_DF["species"].tolist()
                priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
                means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
                stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
                # if elevation missing, likes = 1 (neutral)
                if elevation is None or (isinstance(elevation, float) and np.isnan(elevation)):
                    likes = np.ones(len(all_species))
                else:
                    likes = np.array([gaussian_pdf_scalar(elevation, means[i], stds[i]) for i in range(len(all_species))])
                post = priors * likes
                if post.sum() == 0:
                    post = np.ones(len(all_species)) / len(all_species)
                else:
                    post = post / post.sum()
                # DOY adjust
                if doy is not None and not np.isnan(doy):
                    doy_idx = int(doy) - 1
                    doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in all_species])
                    combined = post * doy_probs
                    if combined.sum() > 0:
                        post = combined / combined.sum()
                for s, p in zip(all_species, post):
                    species_probs[s] = round(float(p * 100.0), 6)
            result["species_probs"] = species_probs
        else:
            result["species_top"] = []
            result["species_probs"] = {}
    except Exception as e:
        print("❌ species prediction error:", e)
        result["species_top"] = []
        result["species_probs"] = {}
        result["note"] = (result.get("note", "") + " ; species_pred_error") if result.get("note") else "species_pred_error"
    return result
@app.post("/predict")  # route path assumed; adjust to match your frontend
async def predict_bloom(req: BloomPredictionRequest):
    start_time = time.time()
    # Validate date
    try:
        req_dt = datetime.strptime(req.date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
    # Get elevation once
    elevation = get_elevation_from_ee(req.lat, req.lon)
    year = req_dt.year
    monthly_results = [None] * 12
    # Run monthly tasks in parallel
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures = {
            ex.submit(process_month_task, req.lat, req.lon, year, month, elevation): month
            for month in range(1, 13)
        }
        for fut in as_completed(futures):
            month = futures[fut]
            try:
                res = fut.result()
            except Exception as e:
                print(f"❌ month {month} processing error:", e)
                res = {
                    "month": month,
                    "sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
                    "ml_bloom_probability": 0.0,
                    "ml_prediction": "NO_BLOOM",
                    "ml_confidence": "LOW",
                    "species_top": [],
                    "species_probs": {},
                    "elevation_m": elevation,
                    "data_quality": None,
                    "satellite": None,
                    "note": "processing_error",
                }
            monthly_results[month - 1] = res
    # Extract raw probabilities
    raw_probs = np.array([
        (mr.get("ml_bloom_probability") or 0.0) if isinstance(mr, dict)
        else (mr.ml_bloom_probability or 0.0)
        for mr in monthly_results
    ], dtype=float)
    # debug prints
    print("DEBUG raw_probs (months 1..12):", raw_probs.tolist())
    print("DEBUG raw_stats min,max,mean:", float(raw_probs.min()), float(raw_probs.max()), float(raw_probs.mean()))
    # --- smoothing: returns (monthly_perc, monthly_scores)
    try:
        monthly_perc, monthly_scores = smooth_monthly_probs_preserve_magnitude(
            raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA
        )
    except Exception as e:
        print("❌ smoothing error:", e)
        # fallback: simple normalization & identity scores
        safe = np.clip(raw_probs, 0.0, 100.0)
        s = float(safe.sum()) if safe.sum() > 0 else 1.0
        monthly_perc = ((safe / s) * 100.0).round(3).tolist()
        monthly_scores = np.clip(safe, 0.0, 100.0).round(3).tolist()

    # Defensive flattening: sometimes elements are lists (bad); fix that
    def _ensure_flat_float_list(x):
        out = []
        for v in x:
            # parenthesized so the iterable check is not short-circuited by `or`
            if isinstance(v, (list, tuple)) or (hasattr(v, '__iter__') and not isinstance(v, (str, bytes))):
                # take first numeric element found
                fv = None
                for cand in v:
                    if isinstance(cand, (int, float, np.floating, np.integer)):
                        fv = float(cand)
                        break
                out.append(fv if fv is not None else 0.0)
            else:
                try:
                    out.append(float(v))
                except Exception:
                    out.append(0.0)
        return out

    monthly_perc = _ensure_flat_float_list(monthly_perc)
    monthly_scores = _ensure_flat_float_list(monthly_scores)
    # safety lengths
    if len(monthly_perc) != 12 or len(monthly_scores) != 12:
        print("❌ smoothing returned wrong length; resizing to 12 with uniform fallback")
        monthly_perc = (np.ones(12) * (100.0 / 12.0)).tolist()
        monthly_scores = (np.ones(12) * np.clip(raw_probs.max(), 0.0, 100.0)).tolist()
    # Build monthly_curve for frontend (percentages)
    monthly_curve = {i + 1: float(round(monthly_perc[i], 3)) for i in range(12)}
    print("DEBUG monthly_perc:", monthly_perc)
    print("DEBUG monthly_scores (abs 0..100):", monthly_scores)
    # Use absolute scores for detection/thresholding
    peak_idx = int(np.argmax(monthly_scores))
    peak_month = peak_idx + 1
    peak_score = float(monthly_scores[peak_idx])
    # Decide status based on peak_score and bell shape
    bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
    mean_score = float(np.mean(monthly_scores))
    median_score = float(np.median(monthly_scores))
    # ambiguous-season signal: how many months sit near the peak
    near_peak_count = sum(1 for s in monthly_scores if s >= (peak_score * 0.9))
    months_above_threshold = sum(1 for s in monthly_scores if s >= MIN_BLOOM_THRESHOLD)
    print(f"DEBUG detection: peak_score={peak_score}, mean={mean_score:.2f}, median={median_score:.2f}, near_peak_count={near_peak_count}, months_above_{MIN_BLOOM_THRESHOLD}={months_above_threshold}, bell_ok={bell_ok}, bell_diag={bell_diag}")
    # Heuristic decision:
    # 1) If the absolute peak is below MIN_PEAK_FOR_BELL -> NO_BLOOM
    if peak_score < MIN_PEAK_FOR_BELL:
        status = "NO_BLOOM"
        top_species = None
        bloom_window = None
        peak_month_out = None
        peak_prob_out = None
    else:
        # 2) If the peak is high enough, accept bloom even if not perfectly bell-shaped,
        #    unless the season is extremely ambiguous (many months nearly at peak).
        #    Use a dominance ratio: peak relative to the mean.
        dominance_ratio = peak_score / (mean_score + 1e-9)
        # Criteria to accept bloom:
        accept_bloom = False
        # a) clear bell shape with a high peak -> accept
        if bell_ok:
            accept_bloom = True
        # b) peak strongly dominates the rest -> accept
        elif dominance_ratio >= 1.25 and near_peak_count <= 4:
            accept_bloom = True
        # c) broad season (several months above threshold) but clearly high peak -> accept
        elif months_above_threshold >= 3 and peak_score >= (MIN_BLOOM_THRESHOLD + 10):
            accept_bloom = True
        if not accept_bloom:
            # Ambiguous/flat high-months case -> LOW_CONFIDENCE
            status = "LOW_CONFIDENCE"
            top_species = None
            bloom_window = [i + 1 for i, s in enumerate(monthly_scores) if s > (MIN_BLOOM_THRESHOLD * 0.5)]
            peak_month_out = peak_month
            peak_prob_out = peak_score
        else:
            # Strong enough to declare bloom.
            # If the peak is only moderately high, report LOW_CONFIDENCE; otherwise BLOOM_DETECTED
            if peak_score < (MIN_BLOOM_THRESHOLD + 5):
                status = "LOW_CONFIDENCE"
            else:
                status = "BLOOM_DETECTED"
            bloom_window = [i + 1 for i, p in enumerate(monthly_scores) if p > MIN_BLOOM_THRESHOLD]
            peak_month_out = peak_month
            peak_prob_out = peak_score
            # species prediction when we accept bloom
            try:
                peak_result = monthly_results[peak_idx]
                if isinstance(peak_result, dict) and peak_result.get("day_of_year") is not None:
                    doy = peak_result.get("day_of_year")
                else:
                    # monthly result dicts don't carry day_of_year, so default to the sampled 15th
                    doy = date(year, peak_month, 15).timetuple().tm_yday
                species_predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
                top_species = [SpeciesResult(name=sp, probability=round(prob * 100.0, 2)) for sp, prob in species_predictions]
            except Exception as e:
                print(f"❌ species prediction error at final stage: {e}")
                top_species = None
    processing_time = round(time.time() - start_time, 2)
    # Build response (fields must match the Pydantic model; requested_date was removed from it earlier)
    response = BloomPredictionResponse(
        success=True,
        status=status,
        peak_month=peak_month_out,
        peak_probability=peak_prob_out,
        bloom_window=bloom_window,
        top_species=top_species,
        monthly_probabilities=monthly_curve,
    )
    # processing_time is logged rather than returned (it is not part of the model)
    print(f"INFO processing_time: {processing_time}s, peak_score: {peak_score}, peak_month: {peak_month}")
    return response
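
# Example request (route path assumed above; values hypothetical):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"lat": 27.98, "lon": 86.92, "date": "2024-05-15"}'
# Returns status, peak month/probability, bloom window, top species, and the
# 12-month probability curve.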
# ------------------------------
# Local run
# ------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))