# BloomAI / app.py
# Author: NaseefNazrul
# Commit 0e40918 (verified): Update app.py
# app.py (updated: fixes species bug, adds caching, parallel monthly queries, monthly bell curve)
import os
import time
import math
import joblib
import ee
import pandas as pd
import numpy as np
from datetime import datetime, date, timedelta
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import json
import scipy.signal
# google oauth helpers (used earlier)
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request as GoogleRequest
from google.oauth2 import service_account as google_service_account
# ------------------------------
# CONFIG / FILENAMES / TUNABLES
# ------------------------------
MODEL_FILE = Path("mil_bloom_model.joblib")
SCALER_FILE = Path("mil_scaler.joblib")
FEATURES_FILE = Path("mil_features.joblib")
PHENO_FILE = Path("phenologythingy.csv")  # raw phenology observations used to build species stats
SPECIES_STATS_FILE = Path("species_stats.csv")  # precomputed species stats; preferred when present
# Thresholds below are on the 0..100 bloom-score scale used by /predict.
MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 20.0))  # months scoring above this form the bloom window
MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 25.0))  # a peak score below this means NO_BLOOM
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"  # Earth Engine elevation image sampled per request
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))  # radius averaged around the requested point
MAX_DAYS_BACK = int(os.environ.get("MAX_DAYS_BACK", 30))  # look-back window when a date has no usable scene
MIN_COUNT_FOR_SPECIES = int(os.environ.get("MIN_COUNT_FOR_SPECIES", 20))  # rarer species are pooled into "OTHER"
DOY_BINS = 366  # day-of-year histogram bins (one per day, leap day included)
DOY_SMOOTH = 15  # moving-average window (days) for the DOY histograms
EPS_STD = 1.0  # floor for the per-species elevation standard deviation
# TUNABLES (additions)
ALPHA = float(os.environ.get("MONTH_ALPHA", 2.0))  # >1 sharpens peaks, 1.0 = no change
SMOOTH_SIGMA = float(os.environ.get("SMOOTH_SIGMA", 1.2))  # gaussian sigma in months
TOP_K_SPECIES = int(os.environ.get("TOP_K_SPECIES", 5))  # how many species to return
MIN_CURVE_PROB = float(os.environ.get("MIN_CURVE_PROB", 0.01))  # min per-month percentage floor
# Tune parallelism: how many months to fetch at once
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", 4))
# EE OAuth env vars expected in Hugging Face Space secrets
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
REFRESH_TOKEN = os.environ.get("REFRESH_TOKEN")
EE_PROJECT = os.environ.get("PROJECT") or os.environ.get("EE_PROJECT") or None
# NOTE(review): the drive and devstorage.full_control scopes look broader than
# Earth Engine access needs; confirm before narrowing.
EE_SCOPES = [
    "https://www.googleapis.com/auth/earthengine",
    "https://www.googleapis.com/auth/cloud-platform",
    "https://www.googleapis.com/auth/drive",
    "https://www.googleapis.com/auth/devstorage.full_control",
]
# ------------------------------
# Pydantic models
# ------------------------------
class BloomPredictionRequest(BaseModel):
    """Request body for POST /predict: a point location and a reference date."""
    lat: float = Field(..., ge=-90, le=90)  # latitude in degrees
    lon: float = Field(..., ge=-180, le=180)  # longitude in degrees
    # Only the year of this date is used: /predict samples the 15th of each month of that year.
    date: str = Field(..., description="YYYY-MM-DD")
class SimplifiedMonthlyResult(BaseModel):
    """Compact per-month prediction record.

    NOTE(review): not referenced by the endpoints in this file; confirm it is
    still needed.
    """
    month: int
    bloom_probability: float
    prediction: str # "BLOOM" or "NO_BLOOM"
class SpeciesResult(BaseModel):
    """One ranked species candidate returned by /predict."""
    name: str  # scientificName from the phenology data, or "OTHER" for pooled rare species
    probability: float # as percentage
class BloomPredictionResponse(BaseModel):
    """Response body for POST /predict."""
    success: bool
    status: str # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
    # Left as None when status is "NO_BLOOM"
    peak_month: Optional[int] = None  # 1..12
    peak_probability: Optional[float] = None  # peak value of the smoothed monthly score curve
    # Months whose smoothed score exceeds a cutoff derived from MIN_BLOOM_THRESHOLD
    # (the full threshold for accepted seasons, half of it for ambiguous ones).
    bloom_window: Optional[List[int]] = None
    # Filled only when /predict accepts a bloom season; None otherwise
    top_species: Optional[List[SpeciesResult]] = None
    # Month (1..12) -> value of the smoothed plotting curve. The 12 values are
    # shares normalised to sum to ~100, not independent per-month probabilities.
    monthly_probabilities: Dict[int, float]
# ------------------------------
# Globals & cache
# ------------------------------
# Model artifacts; each stays None unless lifespan() loads it at startup.
ML_MODEL = SCALER = FEATURE_COLUMNS = None
# Species statistics table and species -> day-of-year histogram map,
# both filled by lifespan() via load_or_build_species_stats().
SPECIES_STATS_DF = None
DOY_HIST_MAP: Dict[str, np.ndarray] = dict()
# In-memory satellite-result cache shared by the monthly worker threads;
# lookups and inserts are done while holding ee_cache_lock.
ee_cache_lock = threading.Lock()
ee_cache: Dict[Tuple[float, float, str], dict] = dict()
# ------------------------------
# Utility functions
# ------------------------------
def gaussian_kernel(length=12, sigma=1.2):
    """Return a normalised, symmetric 1-D Gaussian kernel.

    The kernel has ``2 * half + 1`` taps with ``half = max(6, int(3 * sigma))``,
    so it always spans at least +/-3 sigma and sums to 1. ``length`` is
    accepted for backward compatibility but does not affect the result.

    Raises:
        ValueError: if ``sigma`` is not a positive finite number (previously
            sigma=0 silently produced an all-NaN kernel).
    """
    sigma = float(sigma)
    if not (sigma > 0 and math.isfinite(sigma)):
        raise ValueError(f"sigma must be a positive finite number, got {sigma!r}")
    half = max(6, int(3 * sigma))
    offsets = np.arange(-half, half + 1)
    kern = np.exp(-0.5 * (offsets / sigma) ** 2)
    return kern / kern.sum()
def smooth_monthly_probs_preserve_magnitude(raw_probs,
                                            alpha=ALPHA,
                                            sigma=SMOOTH_SIGMA,
                                            min_curve_prob=MIN_CURVE_PROB):
    """
    Smooth monthly bloom probabilities into a plotting curve and detection scores.

    Parameters
    ----------
    raw_probs : sequence of float
        Per-month bloom probabilities on a 0..100 scale (normally 12 values,
        January first).
    alpha : float
        Exponent applied after contrast stretching; values > 1 sharpen peaks.
    sigma : float
        Gaussian smoothing width in months (clamped to >= 0.6). Smoothing is
        circular, so December and January influence each other.
    min_curve_prob : float
        Per-month floor (percent units) that keeps the stretched curve from
        being exactly zero.

    Returns
    -------
    (monthly_perc, monthly_scores) : tuple of two lists of float
        monthly_perc   -- the smoothed shape normalised to sum to 100 (plotting).
        monthly_scores -- the same smoothed shape rescaled so its maximum equals
                          the largest raw probability. It therefore stays on
                          the original 0..100 scale and absolute thresholds
                          remain meaningful. All-zero input gives all-zero
                          scores. (Previously the maximum was always forced to
                          100, which made every location look like a strong
                          bloom.)

    Steps: scale to 0..1, min-max contrast stretch, raise to ``alpha``,
    circular Gaussian smoothing, then derive both outputs from the result.
    """
    a = np.asarray(raw_probs, dtype=float) / 100.0  # 0..1
    n = a.size

    # No signal at all: flat plotting curve and zero absolute scores.
    if n == 0 or a.sum() == 0 or np.allclose(a, 0.0):
        monthly_perc = np.full(n, 100.0 / n) if n else np.zeros(0)
        return monthly_perc.tolist(), np.zeros(n).tolist()

    # Min-max contrast stretch so month-to-month differences dominate.
    amin = float(a.min())
    amax = float(a.max())
    rng = amax - amin
    if rng > 1e-12:
        a_cs = (a - amin) / (rng + 1e-12)
    else:
        # Perfectly flat, non-zero input: keep a tiny uniform shape.
        a_cs = np.full(n, 1e-6)
    # Floor tiny values so the curve is never exactly zero.
    a_cs = np.clip(a_cs, (min_curve_prob / 100.0) * 1e-3, 1.0)
    # Emphasise peaks.
    a_sh = np.power(a_cs, alpha) if alpha != 1.0 else a_cs

    # Circular Gaussian smoothing. np.take(..., mode="wrap") builds the
    # wrap-around padding correctly even when the pad is longer than the
    # series (large sigma). Slicing a_sh[-pad:] silently came up short in that
    # case and misaligned the output.
    sigma = float(max(0.6, sigma))
    pad = max(3, int(round(3 * sigma)))
    padded = np.take(a_sh, np.arange(-pad, n + pad), mode="wrap")
    offsets = np.arange(-pad, pad + 1)
    kern = np.exp(-0.5 * (offsets / sigma) ** 2)
    kern = kern / kern.sum()
    smoothed = np.convolve(padded, kern, mode="same")
    center = np.clip(smoothed[pad:pad + n], 0.0, None)
    if center.sum() <= 0:
        center = np.ones(n)

    # Plotting curve: normalised so the values sum to 100.
    monthly_perc = (center / center.sum() * 100.0).round(3)
    perc_total = float(monthly_perc.sum())
    if perc_total > 0:
        monthly_perc = monthly_perc * (100.0 / perc_total)
    else:
        monthly_perc = np.full(n, 100.0 / n)

    # Absolute scores: same shape, with the peak pinned to the largest raw probability.
    peak_pct = max(amax, 0.0) * 100.0
    monthly_scores = np.round(center / center.max() * peak_pct, 3)
    return monthly_perc.tolist(), monthly_scores.tolist()
def is_bell_shaped(perc_list):
    """
    Heuristically decide whether a monthly curve looks like a single bell.

    When the curve has at least 5 points it is lightly smoothed first
    (Savitzky-Golay, window 5, order 3). Then:
    - local maxima are located with ``scipy.signal.find_peaks`` (a flat-topped
      peak counts once); at most one peak is allowed;
    - the mass strictly left of the tallest peak is compared with the mass
      strictly right of it, and the ratio must lie within [0.5, 2.0];
    - skewness is reported as a diagnostic only.

    Returns
    -------
    (is_bell, diagnostics) : (bool, dict)
    """
    arr = np.asarray(perc_list, dtype=float)
    # Savitzky-Golay needs window > polyorder, so very short curves are used as-is.
    if arr.size >= 5:
        arr_smooth = np.asarray(scipy.signal.savgol_filter(arr, 5, 3))
    else:
        arr_smooth = arr.copy()

    peaks, _ = scipy.signal.find_peaks(arr_smooth)
    num_peaks = len(peaks)
    # find_peaks returns the peak's own index. The previous code added an extra
    # +1, which counted the peak month in the left-hand mass, so even a
    # perfectly symmetric bell failed the symmetry test.
    if num_peaks:
        peak_idx = int(peaks[int(np.argmax(arr_smooth[peaks]))])
    else:
        peak_idx = int(np.argmax(arr_smooth))

    # Symmetry: mass on each side of the peak (peak itself excluded).
    left_mass = float(arr[:peak_idx].sum())
    right_mass = float(arr[peak_idx + 1:].sum())
    sym_ratio = left_mass / (right_mass + 1e-9)

    # Skewness (diagnostic only).
    std = float(arr.std(ddof=0))
    if not std > 0:
        std = 1.0
    skew = float(((arr - arr.mean()) ** 3).mean() / (std ** 3))

    is_unimodal = num_peaks <= 1
    is_symmish = 0.5 <= sym_ratio <= 2.0  # within a factor of 2
    is_bell = is_unimodal and is_symmish
    diagnostics = {
        "num_peaks": num_peaks,
        "peak_month": int(np.argmax(arr)) + 1,
        "sym_ratio_left_to_right": float(sym_ratio),
        "skewness": float(skew),
        "is_unimodal": bool(is_unimodal),
        "is_symmetric_enough": bool(is_symmish),
    }
    return bool(is_bell), diagnostics
def gaussian_pdf_scalar(x_scalar, mean, std):
    """
    Evaluate the normal density N(mean, std**2) at a single point.

    A missing observation is treated as uninformative. If ``x_scalar`` is
    None, NaN (a Python float or any NumPy scalar) or not convertible to
    float, the neutral value 1.0 is returned. ``std`` is floored at 1e-6 so
    that a zero spread cannot divide by zero.
    """
    if x_scalar is None:
        return 1.0
    try:
        sd = max(float(std), 1e-6)
        x = float(x_scalar)
        mu = float(mean)
    except (TypeError, ValueError, OverflowError):
        # Unusable input: fall back to the neutral likelihood.
        return 1.0
    # math.isnan also catches NumPy scalars such as np.float32('nan'), which
    # are not Python floats and slipped past the old isinstance() check.
    if math.isnan(x):
        return 1.0
    coef = 1.0 / (sd * math.sqrt(2 * math.pi))
    z = (x - mu) / sd
    return coef * math.exp(-0.5 * z * z)
def circular_histogram(doys, bins=DOY_BINS, smooth_window=DOY_SMOOTH):
    """
    Build a smoothed, normalised day-of-year histogram with wrap-around.

    Parameters
    ----------
    doys : array-like of day-of-year values (1-based). Values outside
        1..bins are ignored.
    bins : number of histogram bins (one per day).
    smooth_window : width, in days, of the moving-average smoother. The
        smoother wraps across the year boundary in both directions, so late
        December observations also support early January and vice versa.
        (Previously only the December side wrapped.)

    Returns
    -------
    numpy array of length ``bins`` summing to 1; uniform when there is no
    usable data.
    """
    uniform = np.ones(bins) / bins
    doys = np.asarray(doys).astype(int)
    doys = doys[(doys >= 1) & (doys <= bins)]
    if doys.size == 0:
        return uniform
    counts = np.bincount(doys - 1, minlength=bins).astype(float)

    window_len = max(1, int(smooth_window))
    left = (window_len - 1) // 2
    right = window_len - 1 - left
    # Pad circularly on both sides, then a "valid" convolution returns exactly
    # `bins` centred moving averages.
    padded = np.take(counts, np.arange(-left, bins + right), mode="wrap")
    smoothed = np.convolve(padded, np.ones(window_len) / window_len, mode="valid")

    total = smoothed.sum()
    if total <= 0:
        return uniform
    return smoothed / total
# ------------------------------
# Earth Engine init (OAuth refresh-token or service account fallback)
# ------------------------------
def initialize_ee_from_env():
    """
    Initialise the Earth Engine client from environment-provided credentials.

    Methods are tried in order, each falling through to the next on failure:
    1. OAuth refresh token (CLIENT_ID / CLIENT_SECRET / REFRESH_TOKEN).
    2. Service-account key JSON in EE_SERVICE_ACCOUNT_JSON, parsed in memory.
    3. Service-account key file at the path given by EE_SERVICE_ACCOUNT_FILE.
    4. Plain ee.Initialize() with default credentials.

    Returns True as soon as one method succeeds and False if all fail.
    Failures are printed, never raised.
    """
    # 1) OAuth refresh token supplied as Space secrets.
    try:
        if CLIENT_ID and CLIENT_SECRET and REFRESH_TOKEN:
            creds = Credentials(
                token=None,
                refresh_token=REFRESH_TOKEN,
                client_id=CLIENT_ID,
                client_secret=CLIENT_SECRET,
                token_uri="https://oauth2.googleapis.com/token",
                scopes=EE_SCOPES,
            )
            creds.refresh(GoogleRequest())
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("βœ… Earth Engine initialized with OAuth refresh-token credentials")
            return True
    except Exception as e:
        print("⚠️ OAuth refresh-token init failed:", e)

    # 2) / 3) Service account: inline JSON key first, then a key file path.
    try:
        sa_json_env = os.environ.get("EE_SERVICE_ACCOUNT_JSON")
        sa_file_env = os.environ.get("EE_SERVICE_ACCOUNT_FILE")
        if sa_json_env:
            # Build the credentials from the parsed key in memory instead of
            # writing the private key to a fixed, shared path under /tmp.
            sa_info = json.loads(sa_json_env)
            creds = google_service_account.Credentials.from_service_account_info(sa_info, scopes=EE_SCOPES)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("βœ… Earth Engine initialized with service-account (from EE_SERVICE_ACCOUNT_JSON)")
            return True
        if sa_file_env and Path(sa_file_env).exists():
            creds = google_service_account.Credentials.from_service_account_file(sa_file_env, scopes=EE_SCOPES)
            ee.Initialize(credentials=creds, project=EE_PROJECT)
            print("βœ… Earth Engine initialized with service-account (from file)")
            return True
    except Exception as e:
        print("⚠️ Service-account init failed:", e)

    # 4) Default credentials.
    try:
        if EE_PROJECT:
            ee.Initialize(project=EE_PROJECT)
        else:
            ee.Initialize()
        print("βœ… Earth Engine initialized via default ee.Initialize() (fallback).")
        return True
    except Exception as e:
        print("❌ Earth Engine initialization failed (all methods):", e)
        return False
def get_elevation_from_ee(lat, lon):
    """
    Sample the elevation image ELEV_IMAGE_ID at a point (30 m scale).

    The reduced dictionary is fetched with a single getInfo() round trip
    (previously one call per key). The "elevation" entry is used when
    present; otherwise the first numeric value is used.

    Returns the elevation as a float, or None if the pixel value is empty or
    any Earth Engine call fails (errors are printed, not raised).
    """
    try:
        img = ee.Image(ELEV_IMAGE_ID)
        pt = ee.Geometry.Point([float(lon), float(lat)])
        values = img.reduceRegion(ee.Reducer.first(), pt, scale=30, maxPixels=1e6).getInfo() or {}
        if "elevation" in values:
            val = values["elevation"]
            return float(val) if val is not None else None
        for v in values.values():
            if isinstance(v, (int, float)):
                return float(v)
        return None
    except Exception as e:
        print("❌ get_elevation_from_ee error:", e)
        return None
# ------------------------------
# Satellite retrieval + caching
# ------------------------------
def get_single_date_satellite_data(lat, lon, date_str, satellite, buffer_meters, area):
    """
    Average spectral indices over ``area`` from the least-cloudy Landsat
    Collection-2 Level-2 scene acquired on ``date_str`` (YYYY-MM-DD).

    ``satellite`` == "Landsat-9" selects LC09; anything else selects LC08.
    ``lat``/``lon`` are not used here (``area`` already encodes the location),
    and ``buffer_meters`` is only echoed back in the result.

    Returns None when there is no scene that day, the scene's CLOUD_COVER is
    above 80 (or missing), NDVI could not be computed, or any Earth Engine
    call fails (errors are printed). Otherwise returns a dict with mean
    ndvi/ndwi/evi/lst plus cloud_cover, month, day_of_year, satellite, date
    and buffer_size.
    """
    collection_id = "LANDSAT/LC09/C02/T1_L2" if satellite == "Landsat-9" else "LANDSAT/LC08/C02/T1_L2"
    try:
        filtered = (ee.ImageCollection(collection_id)
                    .filterBounds(area)
                    .filterDate(date_str, f"{date_str}T23:59:59")
                    .sort("CLOUD_COVER")
                    .limit(1))
        if int(filtered.size().getInfo()) == 0:
            return None
        image = ee.Image(filtered.first())
        # Fetch only the property we need. image.getInfo() downloaded the full
        # image description (every band and property) on each call.
        cloud_raw = image.get("CLOUD_COVER").getInfo()
        cloud_cover = float(cloud_raw) if cloud_raw is not None else 100.0
        if cloud_cover > 80:
            return None
        # NOTE(review): the optical indices below use the SR bands without any
        # reflectance scale/offset (ST_B10 *is* scaled), while EVI's constants
        # assume reflectance. Left unchanged on the assumption that the ML model
        # was trained on these same features; confirm before changing.
        ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
        ndwi = image.normalizedDifference(["SR_B3", "SR_B5"]).rename("NDWI")
        evi = image.expression(
            "2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))",
            {"NIR": image.select("SR_B5"), "RED": image.select("SR_B4"), "BLUE": image.select("SR_B2")},
        ).rename("EVI")
        # Thermal band -> degrees Celsius.
        lst = image.select("ST_B10").multiply(0.00341802).add(149.0).subtract(273.15).rename("LST")
        composite = ndvi.addBands([ndwi, evi, lst])
        stats = composite.reduceRegion(
            reducer=ee.Reducer.mean(), geometry=area, scale=30, maxPixels=1e6, bestEffort=True
        ).getInfo()
        ndvi_val = stats.get("NDVI")
        if ndvi_val is None:
            return None
        ndwi_val = stats.get("NDWI")
        evi_val = stats.get("EVI")
        lst_val = stats.get("LST")
        current_dt = datetime.strptime(date_str, "%Y-%m-%d")
        return {
            "ndvi": float(ndvi_val),
            "ndwi": float(ndwi_val) if ndwi_val is not None else None,
            "evi": float(evi_val) if evi_val is not None else None,
            "lst": float(lst_val) if lst_val is not None else None,
            "cloud_cover": float(cloud_cover),
            "month": current_dt.month,
            "day_of_year": current_dt.timetuple().tm_yday,
            "satellite": satellite,
            "date": date_str,
            "buffer_size": buffer_meters,
        }
    except Exception as e:
        print("❌ get_single_date_satellite_data error:", e)
        return None
def get_satellite_data_with_fallback(lat, lon, target_dt, satellite, buffer_meters, area, max_days_back=MAX_DAYS_BACK):
    """
    Walk backwards one day at a time from ``target_dt`` (inclusive), for up to
    ``max_days_back`` days, and return the first scene with a usable NDVI.

    The single-date record is annotated with original_request_date,
    actual_data_date and days_offset. Returns None if the whole window
    produced nothing.
    """
    requested = target_dt.strftime("%Y-%m-%d")
    for offset in range(max_days_back + 1):
        day_str = (target_dt - timedelta(days=offset)).strftime("%Y-%m-%d")
        record = get_single_date_satellite_data(lat, lon, day_str, satellite, buffer_meters, area)
        if not record or record.get("ndvi") is None:
            continue
        record.update(original_request_date=requested, actual_data_date=day_str, days_offset=offset)
        return record
    return None
def get_essential_vegetation_data_with_cache(lat, lon, target_date, buffer_meters=BUFFER_METERS,
                                             max_days_back=MAX_DAYS_BACK, max_cache_entries=4096):
    """
    Fetch, or reuse from ``ee_cache``, the vegetation record for a point and date.

    Landsat-9 is tried first, then Landsat-8, each looking back up to
    ``max_days_back`` days from ``target_date`` (YYYY-MM-DD).

    The result is memoised, including None. None means either "no usable
    scene" or an Earth Engine error that the lower layers swallowed.
    - The cache key includes the buffer radius and look-back window as well
      as the location and date. Calls with non-default settings therefore no
      longer receive results computed for different settings.
    - At most ``max_cache_entries`` entries are kept. The oldest insertions
      are evicted first, so a long-running server does not grow without bound.
    """
    key = (round(float(lat), 6), round(float(lon), 6), str(target_date), buffer_meters, max_days_back)
    with ee_cache_lock:
        if key in ee_cache:
            return ee_cache[key]
    # Cache miss: query outside the lock so other months can proceed in parallel.
    point = ee.Geometry.Point([float(lon), float(lat)])
    area = point.buffer(buffer_meters)
    target_dt = datetime.strptime(target_date, "%Y-%m-%d")
    data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-9", buffer_meters, area, max_days_back)
    if not data:
        data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-8", buffer_meters, area, max_days_back)
    with ee_cache_lock:
        ee_cache[key] = data
        # Dicts keep insertion order, so the first key is the oldest entry.
        limit = max(1, int(max_cache_entries))
        while len(ee_cache) > limit:
            del ee_cache[next(iter(ee_cache))]
    return data
# ------------------------------
# ML prediction wrapper (unchanged features; elevation NOT passed to ML)
# ------------------------------
def predict_bloom_with_ml(features_dict):
    """
    Predict bloom probability for one month's satellite features.

    - Sparse vegetation short-circuits to a fixed high-confidence NO_BLOOM:
      NDVI < 0.05 gives 8%, and EVI < 0.1 with NDVI < 0.1 gives 10%.
    - Otherwise, when both ML_MODEL and SCALER are loaded, the model scores the
      fixed feature order [ndvi, ndwi, evi, lst, cloud_cover, month,
      day_of_year], with missing values set to 0. The bloom probability is
      read from the column of class 1 in the model's ``classes_`` when
      available, instead of assuming it is column 1.
    - A missing model/scaler, or any model error (printed), falls back to
      predict_bloom_fallback().

    Returns a dict with "bloom_probability" (0..100, 2 dp), "prediction"
    ("BLOOM"/"NO_BLOOM") and "confidence" ("HIGH"/"MEDIUM"/"LOW").
    """
    ndvi = features_dict.get("ndvi", 0.0) or 0.0
    evi = features_dict.get("evi", 0.0) or 0.0
    if ndvi < 0.05:
        return {"bloom_probability": 8.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if evi < 0.1 and ndvi < 0.1:
        return {"bloom_probability": 10.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
    if ML_MODEL is None or SCALER is None:
        return predict_bloom_fallback(features_dict)
    try:
        feature_names = ("ndvi", "ndwi", "evi", "lst", "cloud_cover", "month", "day_of_year")
        features_array = np.array(
            [[float(features_dict.get(name) or 0.0) for name in feature_names]],
            dtype=np.float64,
        )
        features_scaled = SCALER.transform(features_array)
        probabilities = np.asarray(ML_MODEL.predict_proba(features_scaled))
        classes = list(getattr(ML_MODEL, "classes_", []))
        if 1 in classes:
            bloom_prob = probabilities[0, classes.index(1)]
        else:
            # No usable class list: keep the historical column convention.
            bloom_prob = probabilities[0, 1] if probabilities.shape[1] == 2 else probabilities[0, 0]
        prediction = ML_MODEL.predict(features_scaled)[0]
        bloom_prob_pct = round(float(bloom_prob * 100.0), 2)
        # Confidence reflects distance from the 50% decision boundary.
        if bloom_prob_pct > 75 or bloom_prob_pct < 25:
            conf = "HIGH"
        elif bloom_prob_pct > 60 or bloom_prob_pct < 40:
            conf = "MEDIUM"
        else:
            conf = "LOW"
        return {"bloom_probability": bloom_prob_pct, "prediction": "BLOOM" if prediction == 1 else "NO_BLOOM", "confidence": conf}
    except Exception as e:
        print("❌ ML model error:", e)
        return predict_bloom_fallback(features_dict)
def predict_bloom_fallback(features_dict):
    """
    Rule-based bloom estimate, used when the ML model is unavailable or fails.

    Points are awarded for strong EVI and NDVI, near-neutral NDWI, mild land
    surface temperature and spring months (Mar-May), with a small penalty
    for Nov-Feb. The total is clamped to [8, 90]; anything above 52 is BLOOM.
    """
    def _value(name, default):
        return features_dict.get(name) or default

    def _tiered(value, tiers):
        # Points for the first (cutoff, points) tier the value exceeds.
        for cutoff, points in tiers:
            if value > cutoff:
                return points
        return 0

    ndvi = float(_value("ndvi", 0.0))
    ndwi = float(_value("ndwi", 0.0))
    evi = float(_value("evi", 0.0))
    lst = float(_value("lst", 0.0))
    month = int(_value("month", 1))

    score = 0.0
    score += _tiered(evi, ((0.7, 50), (0.5, 35), (0.3, 20)))
    score += _tiered(ndvi, ((0.5, 25), (0.3, 15)))
    score += 15 if -0.2 < ndwi < 0.05 else 0
    score += 12 if 12 < lst < 32 else 0
    if month in (3, 4, 5):
        score += 15
    if month in (11, 12, 1, 2):
        score -= 3

    prob = min(90, max(8, score))
    is_bloom = prob > 52
    if is_bloom:
        conf = "MEDIUM" if prob > 65 else "LOW"
    else:
        conf = "MEDIUM" if prob < 25 else "LOW"
    return {
        "bloom_probability": round(prob, 2),
        "prediction": "BLOOM" if is_bloom else "NO_BLOOM",
        "confidence": conf,
    }
# ------------------------------
# Species stats builder / predictor (fixed)
# ------------------------------
def load_or_build_species_stats():
    """
    Load or build per-species elevation statistics and day-of-year histograms.

    Resolution order:
    1. SPECIES_STATS_FILE exists: read it. That CSV carries no day-of-year
       information, so every species gets a flat DOY histogram.
    2. PHENO_FILE exists: build from raw observations.
       - Keep rows whose phenophaseStatus starts with "y" (when that column
         exists) and that have an elevation.
       - Per scientificName, compute the count, the mean/std elevation (std
         floored at EPS_STD) and a circular DOY histogram. Rows with an
         unparseable dayOfYear are skipped for the histogram only.
       - Pool species with fewer than MIN_COUNT_FOR_SPECIES observations into
         a single "OTHER" row.
       - "prior" is each row's share of all kept observations.
    3. Otherwise, or if no observations survive filtering: an empty frame
       with the expected columns and an empty map.

    Returns
    -------
    (DataFrame[species, count, mean_elev, std_elev, prior],
     dict species -> numpy array of DOY_BINS probabilities)
    """
    empty = pd.DataFrame(columns=["species", "count", "mean_elev", "std_elev", "prior"])

    if SPECIES_STATS_FILE.exists():
        df = pd.read_csv(SPECIES_STATS_FILE)
        doy_map = {s: np.ones(DOY_BINS) / DOY_BINS for s in df["species"].tolist()}
        return df, doy_map

    if not PHENO_FILE.exists():
        return empty, {}

    ph = pd.read_csv(PHENO_FILE, low_memory=False)
    if "phenophaseStatus" in ph.columns:
        ph["phenophaseStatus"] = ph["phenophaseStatus"].astype(str).str.strip().str.lower()
        ph_yes = ph[ph["phenophaseStatus"].str.startswith("y")].copy()
    else:
        ph_yes = ph.copy()
    ph_yes = ph_yes.dropna(subset=["elevation"])

    has_doy = "dayOfYear" in ph_yes.columns
    if has_doy:
        # Unparseable days stay NaN here and are dropped when each histogram is
        # built. (Previously the NaN rows survived and later broke the int
        # conversion / bincount.)
        ph_yes["dayOfYear"] = pd.to_numeric(ph_yes["dayOfYear"], errors="coerce")

    def _doy_hist(frame):
        """Circular DOY histogram from the valid dayOfYear values in ``frame``."""
        if not has_doy:
            return np.ones(DOY_BINS) / DOY_BINS
        doys = frame["dayOfYear"].dropna().astype(int).clip(1, 366).to_numpy(dtype=int)
        return circular_histogram(doys)

    def _floored_std(values):
        """Population std of ``values`` floored at EPS_STD (NaN counts as 0)."""
        std = float(values.std(ddof=0))
        return max(0.0 if np.isnan(std) else std, EPS_STD)

    rows = []
    doy_map = {}
    for name, g in ph_yes.groupby("scientificName"):
        elev = g["elevation"].dropna()
        rows.append({
            "species": name,
            "count": len(g),
            "mean_elev": float(elev.mean()),
            "std_elev": _floored_std(elev),
        })
        doy_map[name] = _doy_hist(g)
    if not rows:
        return empty, {}

    species_df = pd.DataFrame(rows)
    total = species_df["count"].sum()
    species_df["prior"] = species_df["count"] / total if total > 0 else 1.0 / max(1, len(species_df))

    frequent = species_df[species_df["count"] >= MIN_COUNT_FOR_SPECIES]
    rare_names = species_df.loc[species_df["count"] < MIN_COUNT_FOR_SPECIES, "species"].tolist()
    final_rows = frequent.to_dict("records")
    if rare_names:
        rare_obs = ph_yes[ph_yes["scientificName"].isin(rare_names)]
        total_rare = len(rare_obs)
        if total_rare > 0:
            rare_elev = rare_obs["elevation"].dropna()
            final_rows.append({
                "species": "OTHER",
                "count": int(total_rare),
                "mean_elev": float(rare_elev.mean()),
                "std_elev": _floored_std(rare_elev) if total_rare > 1 else EPS_STD,
                "prior": int(total_rare) / total if total > 0 else int(total_rare),
            })
            doy_map["OTHER"] = _doy_hist(rare_obs)

    # Every species is either frequent or pooled into OTHER, so final_rows is
    # non-empty and every row carries a prior.
    final_df = pd.DataFrame(final_rows).fillna(0)
    return final_df, doy_map
def predict_species_by_elevation(elevation, doy=None, top_k=TOP_K_SPECIES):
    """
    Rank species by posterior probability given elevation and day of year.

    posterior ∝ prior × N(elevation | mean_elev, std_elev) × DOY_hist[doy]

    - Missing elevation (None/NaN): the elevation term is dropped.
    - Missing doy (None/NaN): the DOY term is dropped. ``doy`` is 1-based and
      wraps around the DOY_BINS-day year.
    - If the elevation step sums to zero (or NaN), a uniform distribution is
      used. A zero DOY product is ignored.

    Returns up to ``top_k`` (species, probability) tuples with probability in
    [0, 1], highest first (ties keep table order). Returns [] when no species
    statistics are loaded.
    """
    if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
        return []
    species = SPECIES_STATS_DF["species"].tolist()
    n = len(species)
    priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
    means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
    stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)

    if elevation is None or (isinstance(elevation, float) and np.isnan(elevation)):
        likes = np.ones(n)
    else:
        likes = np.array([gaussian_pdf_scalar(elevation, m, s) for m, s in zip(means, stds)])
    post = priors * likes
    total = post.sum()
    post = post / total if total > 0 else np.full(n, 1.0 / n)

    if doy is not None and not np.isnan(doy):
        # Wrap so out-of-range DOYs index the circular year instead of raising.
        doy_idx = (int(doy) - 1) % DOY_BINS
        flat = np.ones(DOY_BINS) / DOY_BINS
        doy_probs = np.array([DOY_HIST_MAP.get(s, flat)[doy_idx] for s in species])
        combined = post * doy_probs
        if combined.sum() > 0:
            post = combined / combined.sum()

    order = np.argsort(-post, kind="stable")
    return [(species[i], float(post[i])) for i in order[:top_k]]
# ------------------------------
# Lifespan: load model, initialize EE, build species stats
# ------------------------------
@asynccontextmanager
async def lifespan(app):
    """
    FastAPI lifespan hook.

    On startup it loads the optional joblib artifacts (model, scaler, feature
    list) and initialises Earth Engine, raising RuntimeError if every method
    fails. It then builds the species statistics, falling back to empty ones
    on error. On shutdown it only logs a message.
    """
    global ML_MODEL, SCALER, FEATURE_COLUMNS, SPECIES_STATS_DF, DOY_HIST_MAP

    def _load_artifact(path, label, current):
        # A missing file or a load error leaves the current value untouched.
        if not path.exists():
            return current
        try:
            loaded = joblib.load(path)
        except Exception as e:
            print(f"❌ {label} load error:", e)
            return current
        print(f"βœ… {label} loaded.")
        return loaded

    ML_MODEL = _load_artifact(MODEL_FILE, "MIL model", ML_MODEL)
    SCALER = _load_artifact(SCALER_FILE, "Scaler", SCALER)
    FEATURE_COLUMNS = _load_artifact(FEATURES_FILE, "Features list", FEATURE_COLUMNS)

    if not initialize_ee_from_env():
        raise RuntimeError("Earth Engine initialization failed. Set credentials in Space secrets (EE_SERVICE_ACCOUNT_JSON or CLIENT_ID/CLIENT_SECRET/REFRESH_TOKEN).")

    try:
        SPECIES_STATS_DF, DOY_HIST_MAP = load_or_build_species_stats()
        print("βœ… Species stats ready. species count:", len(SPECIES_STATS_DF))
    except Exception as e:
        print("⚠️ Species stats build error:", e)
        SPECIES_STATS_DF = pd.DataFrame()
        DOY_HIST_MAP = {}

    yield
    print("πŸ”„ Shutting down")
# ------------------------------
# FastAPI app & endpoint
# ------------------------------
app = FastAPI(title="Bloom Prediction (HF Space)", lifespan=lifespan)


@app.get("/")
async def root():
    """Info endpoint: API banner plus whether the ML model was loaded at startup."""
    model_loaded = ML_MODEL is not None
    return {"message": "Bloom Prediction API (HF Space)", "model_loaded": model_loaded}
def process_month_task(lat, lon, year, month, elevation):
    """
    Fetch satellite data for the 15th of ``month`` and run the bloom and
    species predictions for it. /predict runs this in a worker thread.

    Returns a plain dict with keys month, sample_date, day_of_year,
    ml_bloom_probability, ml_prediction, ml_confidence, species_top,
    species_probs, elevation_m, data_quality, satellite and note.
    ``day_of_year`` is the DOY of the scene actually used, which may be up to
    MAX_DAYS_BACK days before the sample date. When no scene was found it is
    the DOY of the sample date. /predict reads this key for the final species
    ranking; it was previously missing, so that ranking never used the season.
    """
    sample_dt = date(year, month, 15)
    sample_date_str = sample_dt.strftime("%Y-%m-%d")
    sat_data = get_essential_vegetation_data_with_cache(lat, lon, sample_date_str)
    result = {
        "month": month,
        "sample_date": sample_date_str,
        "day_of_year": sample_dt.timetuple().tm_yday,
        "ml_bloom_probability": None,
        "ml_prediction": None,
        "ml_confidence": None,
        "species_top": None,
        "species_probs": None,
        "elevation_m": elevation,
        "data_quality": None,
        "satellite": None,
        "note": None,
    }
    if sat_data is None:
        result["note"] = f"No satellite data within {MAX_DAYS_BACK} days for {sample_date_str}"
        return result

    if sat_data.get("day_of_year") is not None:
        result["day_of_year"] = sat_data["day_of_year"]

    ml_out = predict_bloom_with_ml(sat_data)
    result["ml_bloom_probability"] = float(ml_out.get("bloom_probability", 0.0))
    result["ml_prediction"] = ml_out.get("prediction")
    result["ml_confidence"] = ml_out.get("confidence")
    result["data_quality"] = {
        "satellite": sat_data.get("satellite"),
        "cloud_cover": sat_data.get("cloud_cover"),
        "days_offset": sat_data.get("days_offset"),
        "buffer_radius_m": sat_data.get("buffer_size"),
    }
    result["satellite"] = sat_data.get("satellite")

    # Species are only ranked for months that look like a bloom.
    try:
        is_bloom = (result["ml_prediction"] == "BLOOM") or (result["ml_bloom_probability"] >= 50.0)
        if is_bloom:
            # One full ranking serves both the top-k list and the complete
            # distribution; the posterior used to be re-derived inline here.
            n_species = 0 if SPECIES_STATS_DF is None else len(SPECIES_STATS_DF)
            ranked = predict_species_by_elevation(elevation, doy=result["day_of_year"], top_k=n_species)
            result["species_top"] = [(s, round(p * 100.0, 2)) for s, p in ranked[:TOP_K_SPECIES]]
            result["species_probs"] = {s: round(float(p * 100.0), 6) for s, p in ranked}
        else:
            result["species_top"] = []
            result["species_probs"] = {}
    except Exception as e:
        print("❌ species prediction error:", e)
        result["species_top"] = []
        result["species_probs"] = {}
        result["note"] = (result.get("note", "") + " ; species_pred_error") if result.get("note") else "species_pred_error"
    return result
def _failed_month_result(year, month, elevation):
    """Placeholder record for a month whose worker raised an exception."""
    return {
        "month": month,
        "sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
        "ml_bloom_probability": 0.0,
        "ml_prediction": "NO_BLOOM",
        "ml_confidence": "LOW",
        "species_top": [],
        "species_probs": {},
        "elevation_m": elevation,
        "data_quality": None,
        "satellite": None,
        "note": "processing_error",
    }


def _ensure_flat_float_list(values):
    """Coerce each element to float; for nested iterables use the first number found (else 0.0)."""
    out = []
    for v in values:
        if isinstance(v, (list, tuple)) or (hasattr(v, "__iter__") and not isinstance(v, (str, bytes))):
            first = next(
                (float(c) for c in v if isinstance(c, (int, float, np.floating, np.integer))),
                None,
            )
            out.append(first if first is not None else 0.0)
        else:
            try:
                out.append(float(v))
            except Exception:
                out.append(0.0)
    return out


def _collect_monthly_results(lat, lon, year, elevation):
    """Run process_month_task for months 1..12 in parallel; results come back ordered by month."""
    monthly_results = [None] * 12
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures = {
            ex.submit(process_month_task, lat, lon, year, month, elevation): month
            for month in range(1, 13)
        }
        for fut in as_completed(futures):
            month = futures[fut]
            try:
                res = fut.result()
            except Exception as e:
                print(f"❌ month {month} processing error:", e)
                res = _failed_month_result(year, month, elevation)
            monthly_results[month - 1] = res
    return monthly_results


def _smoothed_curves(raw_probs):
    """Return (monthly_perc, monthly_scores) as 12-float lists, with fallbacks if smoothing misbehaves."""
    try:
        monthly_perc, monthly_scores = smooth_monthly_probs_preserve_magnitude(
            raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA
        )
    except Exception as e:
        print("❌ smoothing error:", e)
        # Fallback: simple normalisation and identity scores.
        safe = np.clip(raw_probs, 0.0, 100.0)
        s = float(safe.sum()) if safe.sum() > 0 else 1.0
        monthly_perc = ((safe / s) * 100.0).round(3).tolist()
        monthly_scores = safe.round(3).tolist()
    monthly_perc = _ensure_flat_float_list(monthly_perc)
    monthly_scores = _ensure_flat_float_list(monthly_scores)
    if len(monthly_perc) != 12 or len(monthly_scores) != 12:
        print("❌ smoothing returned wrong length; resizing to 12 with uniform fallback")
        monthly_perc = (np.ones(12) * (100.0 / 12.0)).tolist()
        monthly_scores = (np.ones(12) * np.clip(raw_probs.max(), 0.0, 100.0)).tolist()
    return monthly_perc, monthly_scores


def _top_species_for_peak(peak_result, elevation, year, peak_month):
    """Top species for the peak month as SpeciesResult objects, or None on error."""
    try:
        doy = peak_result.get("day_of_year") if isinstance(peak_result, dict) else None
        if doy is None:
            # Month records may lack a DOY; use the month's sample date (the 15th).
            doy = date(year, peak_month, 15).timetuple().tm_yday
        species_predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
        return [SpeciesResult(name=sp, probability=round(prob * 100.0, 2)) for sp, prob in species_predictions]
    except Exception as e:
        print(f"❌ species prediction error at final stage: {e}")
        return None


def _predict_bloom_sync(req):
    """Blocking body of POST /predict (Earth Engine calls, thread pool, model inference)."""
    start_time = time.time()
    try:
        req_dt = datetime.strptime(req.date, "%Y-%m-%d")
    except ValueError:
        raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")

    # Elevation is looked up once and shared by all twelve monthly tasks.
    elevation = get_elevation_from_ee(req.lat, req.lon)
    year = req_dt.year
    monthly_results = _collect_monthly_results(req.lat, req.lon, year, elevation)

    raw_probs = np.array(
        [mr.get("ml_bloom_probability") or 0.0 for mr in monthly_results],
        dtype=float,
    )
    print("DEBUG raw_probs (months 1..12):", raw_probs.tolist())
    print("DEBUG raw_stats min,max,mean:", float(raw_probs.min()), float(raw_probs.max()), float(raw_probs.mean()))

    monthly_perc, monthly_scores = _smoothed_curves(raw_probs)
    # Plotting curve returned to the client.
    monthly_curve = {i + 1: float(round(monthly_perc[i], 3)) for i in range(12)}
    print("DEBUG monthly_perc:", monthly_perc)
    print("DEBUG monthly_scores (abs 0..100):", monthly_scores)

    # Detection uses the score curve, not the plotting curve.
    peak_idx = int(np.argmax(monthly_scores))
    peak_month = peak_idx + 1
    peak_score = float(monthly_scores[peak_idx])
    bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
    mean_score = float(np.mean(monthly_scores))
    median_score = float(np.median(monthly_scores))
    near_peak_count = sum(1 for s in monthly_scores if s >= (peak_score * 0.9))
    months_above_threshold = sum(1 for s in monthly_scores if s >= MIN_BLOOM_THRESHOLD)
    print(f"DEBUG detection: peak_score={peak_score}, mean={mean_score:.2f}, median={median_score:.2f}, near_peak_count={near_peak_count}, months_above_{MIN_BLOOM_THRESHOLD}={months_above_threshold}, bell_ok={bell_ok}, bell_diag={bell_diag}")

    status = "NO_BLOOM"
    top_species = None
    bloom_window = None
    peak_month_out = None
    peak_prob_out = None
    if peak_score >= MIN_PEAK_FOR_BELL:
        # Accept a season when it is bell-shaped, when the peak clearly
        # dominates the mean without a broad plateau, or when several months
        # clear the threshold and the peak is comfortably above it.
        dominance_ratio = peak_score / (mean_score + 1e-9)
        accept_bloom = (
            bell_ok
            or (dominance_ratio >= 1.25 and near_peak_count <= 4)
            or (months_above_threshold >= 3 and peak_score >= (MIN_BLOOM_THRESHOLD + 10))
        )
        peak_month_out = peak_month
        peak_prob_out = peak_score
        if accept_bloom:
            status = "LOW_CONFIDENCE" if peak_score < (MIN_BLOOM_THRESHOLD + 5) else "BLOOM_DETECTED"
            bloom_window = [i + 1 for i, s in enumerate(monthly_scores) if s > MIN_BLOOM_THRESHOLD]
            top_species = _top_species_for_peak(monthly_results[peak_idx], elevation, year, peak_month)
        else:
            # Flat/ambiguous high season: report it, but without species.
            status = "LOW_CONFIDENCE"
            bloom_window = [i + 1 for i, s in enumerate(monthly_scores) if s > (MIN_BLOOM_THRESHOLD * 0.5)]

    processing_time = round(time.time() - start_time, 2)
    response = BloomPredictionResponse(
        success=True,
        status=status,
        peak_month=peak_month_out,
        peak_probability=peak_prob_out,
        bloom_window=bloom_window,
        top_species=top_species,
        monthly_probabilities=monthly_curve,
    )
    print(f"INFO processing_time: {processing_time}s, peak_score: {peak_score}, peak_month: {peak_month}")
    return response


@app.post("/predict", response_model=BloomPredictionResponse)
async def predict_bloom(req: BloomPredictionRequest):
    """
    POST /predict: monthly bloom curve, season classification and top species.

    The work is blocking: Earth Engine getInfo() calls, a thread pool waited
    on synchronously, and model inference. It therefore runs in FastAPI's
    worker thread pool rather than directly on the event loop, where it used
    to stall every other request for the duration of a prediction.
    HTTPException (400 for a malformed date) propagates unchanged.
    """
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_predict_bloom_sync, req)
# ------------------------------
# Local run
# ------------------------------
if __name__ == "__main__":
    # Local development entry point; the port comes from $PORT (default 7860).
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)