# app.py (updated: fixes species bug, adds caching, parallel monthly queries, monthly bell curve)
import os
import time
import math
import joblib
import ee
import pandas as pd
import numpy as np
from datetime import datetime, date, timedelta
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional, List, Dict, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import json
import scipy.signal
# google oauth helpers (used earlier)
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request as GoogleRequest
from google.oauth2 import service_account as google_service_account
# ------------------------------
# CONFIG / FILENAMES / TUNABLES
# ------------------------------
MODEL_FILE = Path("mil_bloom_model.joblib")
SCALER_FILE = Path("mil_scaler.joblib")
FEATURES_FILE = Path("mil_features.joblib")
PHENO_FILE = Path("phenologythingy.csv")
SPECIES_STATS_FILE = Path("species_stats.csv")
MIN_BLOOM_THRESHOLD = float(os.environ.get("MIN_BLOOM_THRESHOLD", 20.0))  # minimum monthly score for a month to count toward the bloom window
MIN_PEAK_FOR_BELL = float(os.environ.get("MIN_PEAK_FOR_BELL", 25.0))
ELEV_IMAGE_ID = "USGS/SRTMGL1_003"
BUFFER_METERS = int(os.environ.get("BUFFER_METERS", 200))
MAX_DAYS_BACK = int(os.environ.get("MAX_DAYS_BACK", 30))
MIN_COUNT_FOR_SPECIES = int(os.environ.get("MIN_COUNT_FOR_SPECIES", 20))
TOP_K_SPECIES = int(os.environ.get("TOP_K_SPECIES", 5))  # how many species to return
DOY_BINS = 366
DOY_SMOOTH = 15
EPS_STD = 1.0
# TUNABLES (additions)
ALPHA = float(os.environ.get("MONTH_ALPHA", 2.0)) # >1 sharpens peaks, 1.0 = no change
SMOOTH_SIGMA = float(os.environ.get("SMOOTH_SIGMA", 1.2)) # gaussian sigma in months
MIN_CURVE_PROB = float(os.environ.get("MIN_CURVE_PROB", 0.01)) # min per-month percentage floor
# Tune parallelism: how many months to fetch at once
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", 4))
# EE OAuth env vars expected in Hugging Face Space secrets
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
REFRESH_TOKEN = os.environ.get("REFRESH_TOKEN")
EE_PROJECT = os.environ.get("PROJECT") or os.environ.get("EE_PROJECT") or None
EE_SCOPES = [
"https://www.googleapis.com/auth/earthengine",
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/devstorage.full_control",
]
# ------------------------------
# Pydantic models
# ------------------------------
class BloomPredictionRequest(BaseModel):
lat: float = Field(..., ge=-90, le=90)
lon: float = Field(..., ge=-180, le=180)
date: str = Field(..., description="YYYY-MM-DD")
class SimplifiedMonthlyResult(BaseModel):
month: int
bloom_probability: float
prediction: str # "BLOOM" or "NO_BLOOM"
class SpeciesResult(BaseModel):
name: str
probability: float # as percentage
class BloomPredictionResponse(BaseModel):
success: bool
status: str # "BLOOM_DETECTED", "NO_BLOOM", "LOW_CONFIDENCE"
# Only include these if there's a valid bloom season
peak_month: Optional[int] = None
peak_probability: Optional[float] = None
    bloom_window: Optional[List[int]] = None  # months scoring above MIN_BLOOM_THRESHOLD
# Only include species if peak > threshold
top_species: Optional[List[SpeciesResult]] = None
# Simplified monthly data (probabilities only)
monthly_probabilities: Dict[int, float]
# ------------------------------
# Globals & cache
# ------------------------------
ML_MODEL = None
SCALER = None
FEATURE_COLUMNS = None
SPECIES_STATS_DF = None
DOY_HIST_MAP: Dict[str, np.ndarray] = {}
# simple in-memory cache (keyed by (lat,lon,date_str) -> sat_data)
ee_cache_lock = threading.Lock()
ee_cache: Dict[Tuple[float, float, str], dict] = {}
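# NOTE: the cache is unbounded and lives for the process lifetime. That is fine
# for a sleep-prone Space, but a long-running deployment might want a bounded
# LRU (e.g. an OrderedDict with eviction) instead.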
# ------------------------------
# Utility functions
# ------------------------------
def gaussian_kernel(length=12, sigma=1.2):
    """Return a normalized 1D Gaussian kernel with `sigma` in index units.
    Note: size is derived from `sigma`; `length` is currently unused, and the
    smoothing routine below builds its own kernel inline."""
    # half-width of at least 6 samples, or 3 sigma, whichever is larger
    half = max(6, int(3 * sigma))
xs = np.arange(-half, half + 1)
kern = np.exp(-0.5 * (xs / sigma) ** 2)
kern = kern / kern.sum()
return kern
def smooth_monthly_probs_preserve_magnitude(raw_probs,
alpha=ALPHA,
sigma=SMOOTH_SIGMA,
min_curve_prob=MIN_CURVE_PROB):
"""
Convert raw_probs (0..100) -> two arrays:
- monthly_scores: absolute smoothed scores scaled to 0..100 (preserves magnitude)
- monthly_perc: normalized percentages summing to 100 (for plotting)
Steps:
1. scale to 0..1
2. tiny contrast-stretch so differences matter
3. apply exponent alpha to emphasize peaks
4. circular gaussian smoothing
5. monthly_scores: scale smoothed values so max -> 100
6. monthly_perc: normalize the same smoothed values to sum -> 100
"""
a = np.asarray(raw_probs, dtype=float) / 100.0 # 0..1
# Guard: if all zeros, return uniform small floor
if a.sum() == 0 or np.allclose(a, 0.0):
uniform = np.ones(12) * (min_curve_prob / 100.0)
monthly_scores = (uniform / uniform.max()) * 100.0
monthly_perc = (uniform / uniform.sum()) * 100.0
return monthly_perc.tolist(), monthly_scores.tolist()
# Contrast stretch (min-max) to amplify differences
amin = float(a.min())
amax = float(a.max())
rng = amax - amin
if rng < 1e-6:
a_cs = a - amin
if a_cs.max() <= 1e-12:
a_cs = np.ones_like(a) * 1e-6
else:
a_cs = a_cs / (a_cs.max() + 1e-12)
else:
a_cs = (a - amin) / (rng + 1e-12)
# floor tiny values (avoid zero everywhere)
floor = min_curve_prob / 100.0
a_cs = np.clip(a_cs, floor * 1e-3, 1.0)
# sharpen peaks
if alpha != 1.0:
a_sh = np.power(a_cs, alpha)
else:
a_sh = a_cs
# circular pad and gaussian kernel
sigma = float(max(0.6, sigma))
pad = max(3, int(round(3 * sigma)))
padded = np.concatenate([a_sh[-pad:], a_sh, a_sh[:pad]])
kern_range = np.arange(-pad, pad + 1)
kern = np.exp(-0.5 * (kern_range / sigma) ** 2)
kern = kern / kern.sum()
smoothed = np.convolve(padded, kern, mode='same')
center = smoothed[pad:pad+12]
center = np.clip(center, 0.0, None)
# monthly_perc: normalized to sum 100 for plotting
if center.sum() <= 0:
center = np.ones_like(center)
norm = center / center.sum()
monthly_perc = (norm * 100.0).round(3)
# monthly_scores: scale by max to 0..100 (preserve magnitude)
maxv = float(center.max())
if maxv <= 0:
monthly_scores = np.ones_like(center) * (min_curve_prob)
else:
monthly_scores = (center / maxv) * 100.0
monthly_scores = np.round(monthly_scores, 3)
# final safety normalization
s = float(monthly_perc.sum())
if s <= 0:
monthly_perc = np.ones(12) * (100.0 / 12.0)
else:
monthly_perc = monthly_perc * (100.0 / monthly_perc.sum())
return monthly_perc.tolist(), monthly_scores.tolist()
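# Illustrative usage (made-up raw probabilities, peaking in May):
#   perc, scores = smooth_monthly_probs_preserve_magnitude(
#       [5, 8, 30, 70, 90, 55, 20, 10, 5, 5, 5, 5])
# By construction sum(perc) is ~100 (plot shape) and max(scores) is 100.0
# (absolute magnitude preserved), with the peak staying at or near May.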
def is_bell_shaped(perc_list):
"""
Basic unimodal/symmetric check:
- count peaks (local maxima) - expect 1
- compute skewness sign (negative/positive)
- compute ratio of mass on left vs right of peak (expect roughly symmetric)
Returns (is_bell, diagnostics_dict)
"""
arr = np.asarray(perc_list, dtype=float)
# small smoothing to remove noise
    if len(arr) >= 5:
        arr_smooth = np.asarray(scipy.signal.savgol_filter(arr, 5, 3))
    else:
        # too few points for a valid Savitzky-Golay window; skip smoothing
        arr_smooth = arr.copy()
    # find peaks: indices of local maxima (the +1 offsets the double diff)
    peaks = (np.diff(np.sign(np.diff(arr_smooth))) < 0).nonzero()[0] + 1
    num_peaks = len(peaks)
    # `peaks` already holds array indices, so no further offset is needed
    peak_idx = int(peaks[0]) if num_peaks >= 1 else int(arr_smooth.argmax())
# symmetry: mass difference left/right of peak
left_mass = arr[:peak_idx].sum()
right_mass = arr[peak_idx+1:].sum() if peak_idx+1 < len(arr) else 0.0
sym_ratio = left_mass / (right_mass + 1e-9)
# skewness:
m = arr.mean()
s = arr.std(ddof=0) if arr.std(ddof=0) > 0 else 1.0
skew = ((arr - m) ** 3).mean() / (s ** 3)
# heuristics
is_unimodal = (num_peaks <= 1)
is_symmish = 0.5 <= sym_ratio <= 2.0 # within factor 2
# final decision
is_bell = is_unimodal and is_symmish
diagnostics = {
"num_peaks": num_peaks,
"peak_month": int(np.argmax(arr))+1,
"sym_ratio_left_to_right": float(sym_ratio),
"skewness": float(skew),
"is_unimodal": bool(is_unimodal),
"is_symmetric_enough": bool(is_symmish)
}
return bool(is_bell), diagnostics
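# Illustrative check (made-up percentages): a single mid-year hump such as
# [1, 2, 6, 14, 22, 20, 13, 8, 6, 4, 2, 2] should register as unimodal and
# roughly symmetric, while a double-humped spring/autumn series should fail
# the single-peak test and come back as not bell-shaped.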
def gaussian_pdf_scalar(x_scalar, mean, std):
"""Return Gaussian PDF scalar for scalar x. If x_scalar is None or NaN, return 1.0 (neutral)."""
try:
if x_scalar is None or (isinstance(x_scalar, float) and np.isnan(x_scalar)):
return 1.0
std = max(float(std), 1e-6)
coef = 1.0 / (std * math.sqrt(2 * math.pi))
z = (float(x_scalar) - float(mean)) / std
return coef * math.exp(-0.5 * z * z)
except Exception:
# fallback neutral
return 1.0
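# Sanity check: at the mean the PDF equals 1 / (std * sqrt(2 * pi)), so
# gaussian_pdf_scalar(1500, 1500, 100) is ~0.004. A None/NaN input returns the
# neutral 1.0, so a missing elevation never zeroes out a species' posterior.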
def circular_histogram(doys, bins=DOY_BINS, smooth_window=DOY_SMOOTH):
if len(doys) == 0:
return np.ones(bins) / bins
counts = np.bincount(doys.astype(int), minlength=bins+1)[1:]
window = np.ones(smooth_window) / smooth_window
doubled = np.concatenate([counts, counts])
smoothed = np.convolve(doubled, window, mode='same')[:bins]
total = smoothed.sum()
if total <= 0:
return np.ones(bins) / bins
return smoothed / total
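# The double-and-convolve trick above makes the moving average wrap around the
# year boundary, so a species blooming from late December into January keeps a
# single contiguous peak instead of being split at day 1 / day 366.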
# ------------------------------
# Earth Engine init (OAuth refresh-token or service account fallback)
# ------------------------------
def initialize_ee_from_env():
# Try OAuth refresh token first
try:
if CLIENT_ID and CLIENT_SECRET and REFRESH_TOKEN:
creds = Credentials(
token=None,
refresh_token=REFRESH_TOKEN,
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET,
token_uri="https://oauth2.googleapis.com/token",
scopes=EE_SCOPES
)
request = GoogleRequest()
creds.refresh(request)
ee.Initialize(credentials=creds, project=EE_PROJECT)
print("β
Earth Engine initialized with OAuth refresh-token credentials")
return True
except Exception as e:
print("β οΈ OAuth refresh-token init failed:", e)
# Try service account JSON in env var EE_SERVICE_ACCOUNT_JSON
try:
sa_json_env = os.environ.get("EE_SERVICE_ACCOUNT_JSON")
sa_file_env = os.environ.get("EE_SERVICE_ACCOUNT_FILE")
if sa_json_env:
sa_data = json.loads(sa_json_env)
tmp_path = "/tmp/ee_service_account.json"
with open(tmp_path, "w") as f:
json.dump(sa_data, f)
creds = google_service_account.Credentials.from_service_account_file(tmp_path, scopes=EE_SCOPES)
ee.Initialize(credentials=creds, project=EE_PROJECT)
print("β
Earth Engine initialized with service-account (from EE_SERVICE_ACCOUNT_JSON)")
return True
if sa_file_env and Path(sa_file_env).exists():
creds = google_service_account.Credentials.from_service_account_file(sa_file_env, scopes=EE_SCOPES)
ee.Initialize(credentials=creds, project=EE_PROJECT)
print("β
Earth Engine initialized with service-account (from file)")
return True
except Exception as e:
print("β οΈ Service-account init failed:", e)
# fallback to default
try:
        if EE_PROJECT:
            ee.Initialize(project=EE_PROJECT)
        else:
            ee.Initialize()
        print("✅ Earth Engine initialized via default ee.Initialize() (fallback).")
return True
except Exception as e:
print("β Earth Engine initialization failed (all methods):", e)
return False
def get_elevation_from_ee(lat, lon):
try:
img = ee.Image(ELEV_IMAGE_ID)
pt = ee.Geometry.Point([float(lon), float(lat)])
rr = img.reduceRegion(ee.Reducer.first(), pt, scale=30, maxPixels=1e6)
if rr is None:
return None
try:
val = rr.get("elevation").getInfo()
return float(val) if val is not None else None
except Exception:
keys = rr.keys().getInfo()
for k in keys:
v = rr.get(k).getInfo()
if isinstance(v, (int, float)):
return float(v)
return None
except Exception as e:
print("β get_elevation_from_ee error:", e)
return None
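# Example (arbitrary coordinates): get_elevation_from_ee(46.6, 8.0) samples the
# SRTM image at ~30 m scale and returns elevation in metres as a float, or None
# when the point has no data (e.g. open ocean) or the request fails.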
# ------------------------------
# Satellite retrieval + caching
# ------------------------------
def get_single_date_satellite_data(lat, lon, date_str, satellite, buffer_meters, area):
collection_id = "LANDSAT/LC09/C02/T1_L2" if satellite == "Landsat-9" else "LANDSAT/LC08/C02/T1_L2"
try:
filtered = (ee.ImageCollection(collection_id)
.filterBounds(area)
.filterDate(date_str, f"{date_str}T23:59:59")
.sort("CLOUD_COVER")
.limit(1))
size = int(filtered.size().getInfo())
if size == 0:
return None
image = ee.Image(filtered.first())
info = image.getInfo().get("properties", {})
cloud_cover = float(info.get("CLOUD_COVER", 100.0))
if cloud_cover > 80:
return None
ndvi = image.normalizedDifference(["SR_B5", "SR_B4"]).rename("NDVI")
ndwi = image.normalizedDifference(["SR_B3", "SR_B5"]).rename("NDWI")
evi = image.expression(
"2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))",
{"NIR": image.select("SR_B5"), "RED": image.select("SR_B4"), "BLUE": image.select("SR_B2")},
).rename("EVI")
lst = image.select("ST_B10").multiply(0.00341802).add(149.0).subtract(273.15).rename("LST")
composite = ndvi.addBands([ndwi, evi, lst])
stats = composite.reduceRegion(
reducer=ee.Reducer.mean(), geometry=area, scale=30, maxPixels=1e6, bestEffort=True
).getInfo()
ndvi_val = stats.get("NDVI")
if ndvi_val is None:
return None
ndwi_val = stats.get("NDWI")
evi_val = stats.get("EVI")
lst_val = stats.get("LST")
current_dt = datetime.strptime(date_str, "%Y-%m-%d")
return {
"ndvi": float(ndvi_val),
"ndwi": float(ndwi_val) if ndwi_val is not None else None,
"evi": float(evi_val) if evi_val is not None else None,
"lst": float(lst_val) if lst_val is not None else None,
"cloud_cover": float(cloud_cover),
"month": current_dt.month,
"day_of_year": current_dt.timetuple().tm_yday,
"satellite": satellite,
"date": date_str,
"buffer_size": buffer_meters,
}
except Exception as e:
print("β get_single_date_satellite_data error:", e)
return None
def get_satellite_data_with_fallback(lat, lon, target_dt, satellite, buffer_meters, area, max_days_back=MAX_DAYS_BACK):
for days_back in range(0, max_days_back + 1):
current_date = (target_dt - timedelta(days=days_back)).strftime("%Y-%m-%d")
        # no per-satellite caching here; final results are cached one level up
        # in get_essential_vegetation_data_with_cache
data = get_single_date_satellite_data(lat, lon, current_date, satellite, buffer_meters, area)
if data and data.get("ndvi") is not None:
data["original_request_date"] = target_dt.strftime("%Y-%m-%d")
data["actual_data_date"] = current_date
data["days_offset"] = days_back
return data
return None
def get_essential_vegetation_data_with_cache(lat, lon, target_date, buffer_meters=BUFFER_METERS, max_days_back=MAX_DAYS_BACK):
key = (round(float(lat), 6), round(float(lon), 6), str(target_date))
with ee_cache_lock:
if key in ee_cache:
return ee_cache[key]
    # not in cache: fetch fresh (try Landsat-9 first, then fall back to Landsat-8)
point = ee.Geometry.Point([float(lon), float(lat)])
area = point.buffer(buffer_meters)
target_dt = datetime.strptime(target_date, "%Y-%m-%d")
data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-9", buffer_meters, area, max_days_back)
if not data:
data = get_satellite_data_with_fallback(lat, lon, target_dt, "Landsat-8", buffer_meters, area, max_days_back)
with ee_cache_lock:
ee_cache[key] = data
return data
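# Example (arbitrary inputs): repeated calls for the same rounded point and
# date are served from the in-memory cache, so parallel monthly tasks pay the
# Earth Engine round-trip only once per unique (lat, lon, date) key:
#   d = get_essential_vegetation_data_with_cache(46.6, 8.0, "2024-05-15")
#   # -> None, or a dict with "ndvi", "ndwi", "evi", "lst", "cloud_cover", ...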
# ------------------------------
# ML prediction wrapper (unchanged features; elevation NOT passed to ML)
# ------------------------------
def predict_bloom_with_ml(features_dict):
ndvi = features_dict.get("ndvi", 0.0) or 0.0
evi = features_dict.get("evi", 0.0) or 0.0
if ndvi < 0.05:
return {"bloom_probability": 8.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
if evi < 0.1 and ndvi < 0.1:
return {"bloom_probability": 10.0, "prediction": "NO_BLOOM", "confidence": "HIGH"}
if ML_MODEL is not None and SCALER is not None:
try:
features_array = np.array(
[
[
float(features_dict.get("ndvi", 0.0)),
float(features_dict.get("ndwi", 0.0) or 0.0),
float(features_dict.get("evi", 0.0) or 0.0),
float(features_dict.get("lst", 0.0) or 0.0),
float(features_dict.get("cloud_cover", 0.0) or 0.0),
float(features_dict.get("month", 0) or 0),
float(features_dict.get("day_of_year", 0) or 0),
]
],
dtype=np.float64,
)
features_scaled = SCALER.transform(features_array)
probabilities = ML_MODEL.predict_proba(features_scaled)
bloom_prob = probabilities[0, 1] if probabilities.shape[1] == 2 else probabilities[0, 0]
prediction = ML_MODEL.predict(features_scaled)[0]
bloom_prob_pct = round(float(bloom_prob * 100.0), 2)
if bloom_prob_pct > 75 or bloom_prob_pct < 25:
conf = "HIGH"
elif bloom_prob_pct > 60 or bloom_prob_pct < 40:
conf = "MEDIUM"
else:
conf = "LOW"
return {"bloom_probability": bloom_prob_pct, "prediction": "BLOOM" if prediction == 1 else "NO_BLOOM", "confidence": conf}
except Exception as e:
print("β ML model error:", e)
return predict_bloom_fallback(features_dict)
def predict_bloom_fallback(features_dict):
ndvi = float(features_dict.get("ndvi") or 0.0)
ndwi = float(features_dict.get("ndwi") or 0.0)
evi = float(features_dict.get("evi") or 0.0)
lst = float(features_dict.get("lst") or 0.0)
month = int(features_dict.get("month") or 1)
score = 0.0
if evi > 0.7:
score += 50
elif evi > 0.5:
score += 35
elif evi > 0.3:
score += 20
if ndvi > 0.5:
score += 25
elif ndvi > 0.3:
score += 15
if -0.2 < ndwi < 0.05:
score += 15
if 12 < lst < 32:
score += 12
if month in [3, 4, 5]:
score += 15
if month in [11, 12, 1, 2]:
score -= 3
prob = min(90, max(8, score))
if prob > 52:
pred = "BLOOM"
conf = "MEDIUM" if prob > 65 else "LOW"
else:
pred = "NO_BLOOM"
conf = "MEDIUM" if prob < 25 else "LOW"
return {"bloom_probability": round(prob, 2), "prediction": pred, "confidence": conf}
# ------------------------------
# Species stats builder / predictor (fixed)
# ------------------------------
def load_or_build_species_stats():
global PHENO_FILE, SPECIES_STATS_FILE
if SPECIES_STATS_FILE.exists():
df = pd.read_csv(SPECIES_STATS_FILE)
doy_map = {}
for s in df["species"].tolist():
doy_map[s] = np.ones(DOY_BINS) / DOY_BINS
return df, doy_map
if PHENO_FILE.exists():
ph = pd.read_csv(PHENO_FILE, low_memory=False)
if "phenophaseStatus" in ph.columns:
ph["phenophaseStatus"] = ph["phenophaseStatus"].astype(str).str.strip().str.lower()
ph_yes = ph[ph["phenophaseStatus"].str.startswith("y")].copy()
else:
ph_yes = ph.copy()
ph_yes = ph_yes.dropna(subset=["elevation"])
if "dayOfYear" in ph_yes.columns:
ph_yes["dayOfYear"] = pd.to_numeric(ph_yes["dayOfYear"], errors="coerce").dropna().astype(int).clip(1, 366)
rows = []
doy_map = {}
grouped = ph_yes.groupby("scientificName")
for name, g in grouped:
cnt = len(g)
mean_elev = float(g["elevation"].dropna().mean()) if cnt > 0 else np.nan
std_elev = float(g["elevation"].dropna().std(ddof=0)) if cnt > 0 else EPS_STD
std_elev = max(std_elev if not np.isnan(std_elev) else 0.0, EPS_STD)
rows.append({"species": name, "count": cnt, "mean_elev": mean_elev, "std_elev": std_elev})
if "dayOfYear" in g.columns:
doy_map[name] = circular_histogram(g["dayOfYear"].to_numpy(dtype=int))
else:
doy_map[name] = np.ones(DOY_BINS) / DOY_BINS
species_df = pd.DataFrame(rows)
total = species_df["count"].sum()
species_df["prior"] = species_df["count"] / total if total > 0 else 1.0 / max(1, len(species_df))
rare = species_df[species_df["count"] < MIN_COUNT_FOR_SPECIES]
frequent = species_df[species_df["count"] >= MIN_COUNT_FOR_SPECIES]
final_rows = frequent.to_dict("records")
if len(rare) > 0:
rare_names = rare["species"].tolist()
rare_obs = ph_yes[ph_yes["scientificName"].isin(rare_names)]
total_rare = len(rare_obs)
if total_rare > 0:
mean_other = float(rare_obs["elevation"].dropna().mean())
std_other = float(rare_obs["elevation"].dropna().std(ddof=0)) if total_rare > 1 else EPS_STD
std_other = max(std_other if not np.isnan(std_other) else 0.0, EPS_STD)
final_rows.append(
{
"species": "OTHER",
"count": int(total_rare),
"mean_elev": mean_other,
"std_elev": std_other,
"prior": int(total_rare) / total if total > 0 else int(total_rare),
}
)
doy_map["OTHER"] = circular_histogram(rare_obs["dayOfYear"].to_numpy(dtype=int)) if "dayOfYear" in rare_obs.columns else np.ones(DOY_BINS) / DOY_BINS
final_df = pd.DataFrame(final_rows).fillna(0)
if "prior" not in final_df.columns:
t2 = final_df["count"].sum()
final_df["prior"] = final_df["count"] / t2 if t2 > 0 else 1.0 / len(final_df)
return final_df, doy_map
return pd.DataFrame(columns=["species", "count", "mean_elev", "std_elev", "prior"]), {}
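# Note: when species_stats.csv already exists, the fast path above rebuilds the
# day-of-year histograms as uniform distributions (the CSV stores only
# elevation stats), so DOY weighting is effectively neutral on that path.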
def predict_species_by_elevation(elevation, doy=None, top_k=TOP_K_SPECIES):
"""
Return top_k list of tuples (species, prob_fraction in [0,1]).
If elevation is None, rely on priors and DOY histogram if available.
"""
global SPECIES_STATS_DF, DOY_HIST_MAP
if SPECIES_STATS_DF is None or SPECIES_STATS_DF.empty:
return []
species = SPECIES_STATS_DF["species"].tolist()
priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
if elevation is None or (isinstance(elevation, float) and np.isnan(elevation)):
# If no elevation, start from priors
post = priors.copy()
post = post / post.sum()
else:
# compute likelihood per species using scalar gaussian
likes = np.array([gaussian_pdf_scalar(elevation, means[i], stds[i]) for i in range(len(species))])
post = priors * likes
if post.sum() == 0:
post = np.ones(len(species)) / len(species)
else:
post = post / post.sum()
# incorporate DOY if available
if doy is not None and not np.isnan(doy):
doy_idx = int(doy) - 1
doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in species])
combined = post * doy_probs
if combined.sum() > 0:
post = combined / combined.sum()
order = np.argsort(-post)
top = []
for i in order[:top_k]:
top.append((species[i], float(post[i])))
return top
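# Example (hypothetical stats): with two species whose mean elevations are
# 400 m and 1600 m, predict_species_by_elevation(1500, doy=135) boosts the
# high-elevation species via the Gaussian elevation likelihood, then reweights
# by each species' smoothed day-of-year histogram at day 135 before ranking.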
# ------------------------------
# Lifespan: load model, initialize EE, build species stats
# ------------------------------
@asynccontextmanager
async def lifespan(app):
global ML_MODEL, SCALER, FEATURE_COLUMNS, SPECIES_STATS_DF, DOY_HIST_MAP
# load optional models
if MODEL_FILE.exists():
try:
ML_MODEL = joblib.load(MODEL_FILE)
print("β
MIL model loaded.")
except Exception as e:
print("β MIL model load error:", e)
if SCALER_FILE.exists():
try:
SCALER = joblib.load(SCALER_FILE)
print("β
Scaler loaded.")
except Exception as e:
print("β Scaler load error:", e)
if FEATURES_FILE.exists():
try:
FEATURE_COLUMNS = joblib.load(FEATURES_FILE)
print("β
Features list loaded.")
except Exception as e:
print("β Features list load error:", e)
ok = initialize_ee_from_env()
if not ok:
raise RuntimeError("Earth Engine initialization failed. Set credentials in Space secrets (EE_SERVICE_ACCOUNT_JSON or CLIENT_ID/CLIENT_SECRET/REFRESH_TOKEN).")
try:
SPECIES_STATS_DF, DOY_HIST_MAP = load_or_build_species_stats()
print("β
Species stats ready. species count:", len(SPECIES_STATS_DF))
except Exception as e:
print("β οΈ Species stats build error:", e)
SPECIES_STATS_DF = pd.DataFrame()
DOY_HIST_MAP = {}
yield
print("π Shutting down")
# ------------------------------
# FastAPI app & endpoint
# ------------------------------
app = FastAPI(title="Bloom Prediction (HF Space)", lifespan=lifespan)
@app.get("/")
async def root():
return {"message": "Bloom Prediction API (HF Space)", "model_loaded": ML_MODEL is not None}
def process_month_task(lat, lon, year, month, elevation):
"""
Perform the satellite retrieval and predictions for one month (15th).
Returns a dict compatible with MonthlyResult.
"""
sample_dt = date(year, month, 15)
sample_date_str = sample_dt.strftime("%Y-%m-%d")
# fetch satellite data (with caching)
sat_data = get_essential_vegetation_data_with_cache(lat, lon, sample_date_str)
result = {
"month": month,
"sample_date": sample_date_str,
"ml_bloom_probability": None,
"ml_prediction": None,
"ml_confidence": None,
"species_top": None,
"species_probs": None,
"elevation_m": elevation,
"data_quality": None,
"satellite": None,
"note": None,
}
if sat_data is None:
result["note"] = f"No satellite data within {MAX_DAYS_BACK} days for {sample_date_str}"
return result
# run ML
ml_out = predict_bloom_with_ml(sat_data)
result["ml_bloom_probability"] = float(ml_out.get("bloom_probability", 0.0))
result["ml_prediction"] = ml_out.get("prediction")
result["ml_confidence"] = ml_out.get("confidence")
result["data_quality"] = {
"satellite": sat_data.get("satellite"),
"cloud_cover": sat_data.get("cloud_cover"),
"days_offset": sat_data.get("days_offset"),
"buffer_radius_m": sat_data.get("buffer_size"),
}
result["satellite"] = sat_data.get("satellite")
# species prediction if bloom
try:
bloom_bool = (result["ml_prediction"] == "BLOOM") or (result["ml_bloom_probability"] >= 50.0)
if bloom_bool:
doy = sat_data.get("day_of_year", None)
top_species = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
# convert to percentages
result["species_top"] = [(s, round(p * 100.0, 2)) for s, p in top_species]
# build full species_probs
species_probs = {}
if (SPECIES_STATS_DF is not None) and (not SPECIES_STATS_DF.empty):
all_species = SPECIES_STATS_DF["species"].tolist()
priors = SPECIES_STATS_DF["prior"].to_numpy(dtype=float)
means = SPECIES_STATS_DF["mean_elev"].to_numpy(dtype=float)
stds = SPECIES_STATS_DF["std_elev"].to_numpy(dtype=float)
# If elevation missing, likes = 1 (neutral)
if elevation is None or (isinstance(elevation, float) and np.isnan(elevation)):
likes = np.ones(len(all_species))
else:
likes = np.array([gaussian_pdf_scalar(elevation, means[i], stds[i]) for i in range(len(all_species))])
post = priors * likes
if post.sum() == 0:
post = np.ones(len(all_species)) / len(all_species)
else:
post = post / post.sum()
# DOY adjust
if doy is not None and not np.isnan(doy):
doy_idx = int(doy) - 1
doy_probs = np.array([DOY_HIST_MAP.get(s, np.ones(DOY_BINS) / DOY_BINS)[doy_idx] for s in all_species])
combined = post * doy_probs
if combined.sum() > 0:
post = combined / combined.sum()
for s, p in zip(all_species, post):
species_probs[s] = round(float(p * 100.0), 6)
result["species_probs"] = species_probs
else:
result["species_top"] = []
result["species_probs"] = {}
except Exception as e:
print("β species prediction error:", e)
result["species_top"] = []
result["species_probs"] = {}
result["note"] = (result.get("note", "") + " ; species_pred_error") if result.get("note") else "species_pred_error"
return result
@app.post("/predict", response_model=BloomPredictionResponse)
async def predict_bloom(req: BloomPredictionRequest):
start_time = time.time()
# Validate date
try:
req_dt = datetime.strptime(req.date, "%Y-%m-%d")
except ValueError:
raise HTTPException(status_code=400, detail="date must be YYYY-MM-DD")
# Get elevation once
elevation = get_elevation_from_ee(req.lat, req.lon)
year = req_dt.year
monthly_results = [None] * 12
# Run monthly tasks in parallel
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
futures = {
ex.submit(process_month_task, req.lat, req.lon, year, month, elevation): month
for month in range(1, 13)
}
for fut in as_completed(futures):
month = futures[fut]
try:
res = fut.result()
except Exception as e:
print(f"β month {month} processing error:", e)
res = {
"month": month,
"sample_date": date(year, month, 15).strftime("%Y-%m-%d"),
"ml_bloom_probability": 0.0,
"ml_prediction": "NO_BLOOM",
"ml_confidence": "LOW",
"species_top": [],
"species_probs": {},
"elevation_m": elevation,
"data_quality": None,
"satellite": None,
"note": "processing_error"
}
monthly_results[month - 1] = res
# Extract raw probabilities
raw_probs = np.array([
(mr.get("ml_bloom_probability") or 0.0) if isinstance(mr, dict)
else (mr.ml_bloom_probability or 0.0)
for mr in monthly_results
], dtype=float)
# debug prints
print("DEBUG raw_probs (months 1..12):", raw_probs.tolist())
print("DEBUG raw_stats min,max,mean:", float(raw_probs.min()), float(raw_probs.max()), float(raw_probs.mean()))
# --- smoothing: returns (monthly_perc, monthly_scores)
try:
monthly_perc, monthly_scores = smooth_monthly_probs_preserve_magnitude(
raw_probs.tolist(), alpha=ALPHA, sigma=SMOOTH_SIGMA
)
except Exception as e:
print("β smoothing error:", e)
# fallback: simple normalization & identity scores
safe = np.clip(raw_probs, 0.0, 100.0)
s = float(safe.sum()) if safe.sum() > 0 else 1.0
monthly_perc = ((safe / s) * 100.0).round(3).tolist()
monthly_scores = (np.clip(safe, 0.0, 100.0)).round(3).tolist()
# Defensive flattening: sometimes elements are lists (bad), fix that
def _ensure_flat_float_list(x):
out = []
for v in x:
            if isinstance(v, (list, tuple)) or (hasattr(v, '__iter__') and not isinstance(v, (str, bytes))):
# take first numeric element found
fv = None
for cand in v:
if isinstance(cand, (int, float, np.floating, np.integer)):
fv = float(cand)
break
out.append(fv if fv is not None else 0.0)
else:
try:
out.append(float(v))
except Exception:
out.append(0.0)
return out
monthly_perc = _ensure_flat_float_list(monthly_perc)
monthly_scores = _ensure_flat_float_list(monthly_scores)
# safety lengths
if len(monthly_perc) != 12 or len(monthly_scores) != 12:
print("β smoothing returned wrong length; resizing to 12 with uniform fallback")
monthly_perc = (np.ones(12) * (100.0 / 12.0)).tolist()
monthly_scores = (np.ones(12) * np.clip(raw_probs.max(), 0.0, 100.0)).tolist()
# Build monthly_curve for frontend (percentages)
monthly_curve = {i+1: float(round(monthly_perc[i], 3)) for i in range(12)}
print("DEBUG monthly_perc:", monthly_perc)
print("DEBUG monthly_scores (abs 0..100):", monthly_scores)
# Use absolute scores for detection/thresholding
peak_idx = int(np.argmax(monthly_scores))
peak_month = peak_idx + 1
peak_score = float(monthly_scores[peak_idx])
# Decide status based on peak_score and bell shape
bell_ok, bell_diag = is_bell_shaped(list(monthly_perc))
mean_score = float(np.mean(monthly_scores))
median_score = float(np.median(monthly_scores))
near_peak_count = sum(1 for s in monthly_scores if s >= (peak_score * 0.9))
months_above_threshold = sum(1 for s in monthly_scores if s >= MIN_BLOOM_THRESHOLD)
print(f"DEBUG detection: peak_score={peak_score}, mean={mean_score:.2f}, median={median_score:.2f}, near_peak_count={near_peak_count}, months_above_{MIN_BLOOM_THRESHOLD}={months_above_threshold}, bell_ok={bell_ok}, bell_diag={bell_diag}")
# Heuristic decision:
# 1) If the absolute peak is below MIN_PEAK_FOR_BELL -> NO_BLOOM
if peak_score < MIN_PEAK_FOR_BELL:
status = "NO_BLOOM"
top_species = None
bloom_window = None
peak_month_out = None
peak_prob_out = None
else:
# 2) If peak is high enough, accept bloom even if not perfectly bell-shaped,
# unless the season is extremely ambiguous (many months nearly at peak).
# Use a dominance ratio: peak relative to mean/median.
dominance_ratio = (peak_score / (mean_score + 1e-9))
# Criteria to accept bloom:
accept_bloom = False
# a) clear bell-shaped + peak high -> accept
if bell_ok:
accept_bloom = True
# b) peak strongly dominates the rest -> accept
elif dominance_ratio >= 1.25 and near_peak_count <= 4:
accept_bloom = True
# c) several months above bloom threshold (broad season) but peak high -> accept but mark low confidence
elif months_above_threshold >= 3 and peak_score >= (MIN_BLOOM_THRESHOLD + 10):
accept_bloom = True
if not accept_bloom:
# Ambiguous/flat high-months case -> LOW_CONFIDENCE
status = "LOW_CONFIDENCE"
top_species = None
bloom_window = [i+1 for i, s in enumerate(monthly_scores) if s > (MIN_BLOOM_THRESHOLD * 0.5)]
peak_month_out = peak_month
peak_prob_out = peak_score
else:
# Strong enough to declare bloom
# If the peak is only moderately high, use LOW_CONFIDENCE; otherwise BLOOM_DETECTED
if peak_score < (MIN_BLOOM_THRESHOLD + 5):
status = "LOW_CONFIDENCE"
else:
status = "BLOOM_DETECTED"
bloom_window = [i+1 for i, p in enumerate(monthly_scores) if p > MIN_BLOOM_THRESHOLD]
peak_month_out = peak_month
peak_prob_out = peak_score
# species prediction when we accept bloom
try:
                peak_result = monthly_results[peak_idx]
                # per-month dicts do not store "day_of_year"; derive it from
                # the sampled mid-month date when the key is absent
                doy = peak_result.get("day_of_year") if isinstance(peak_result, dict) else None
                if doy is None:
                    doy = date(year, peak_month, 15).timetuple().tm_yday
species_predictions = predict_species_by_elevation(elevation, doy=doy, top_k=TOP_K_SPECIES)
top_species = [SpeciesResult(name=sp, probability=round(prob * 100.0, 2)) for sp, prob in species_predictions]
except Exception as e:
print(f"β species prediction error at final stage: {e}")
top_species = None
processing_time = round(time.time() - start_time, 2)
# Build response (ensure fields match your Pydantic model)
response = BloomPredictionResponse(
success=True,
status=status,
peak_month=peak_month_out,
peak_probability=peak_prob_out,
bloom_window=bloom_window,
top_species=top_species,
monthly_probabilities=monthly_curve,
)
# optionally attach processing_time as extra info in logs (not in model)
print(f"INFO processing_time: {processing_time}s, peak_score: {peak_score}, peak_month: {peak_month}")
return response
# ------------------------------
# Local run
# ------------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
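# Example request once the server is running (shell, not Python; placeholder
# coordinates and date):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"lat": 34.05, "lon": -118.25, "date": "2024-04-15"}'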