# Spaces: Sleeping
# (page-status residue captured together with this source; not part of the program)
import json
import math
import os
from difflib import SequenceMatcher
from typing import Optional, Dict, Any

import numpy as np
import pandas as pd
from catboost import CatBoostRegressor
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Application object; title, description and version are the API metadata.
app = FastAPI(
    title="Districtmaps.ai API",
    description="District-level NCD risk intelligence for India. Pre-computed scores + live inference.",
    version="2.0.0",
)

# CORS is fully open: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── GLOBALS ───────────────────────────────────────────────────────────────────
# Module-level state shared by the helpers and routes. All of it is assigned
# by load_all(); until that has run every name below is None.
df = None               # pandas DataFrame of pre-computed district risk scores
model = None            # CatBoostRegressor used for live inference
feature_names = None    # contents of feature_names.json
feature_medians = None  # contents of feature_medians.json
feature_map = None      # contents of feature_map.json
# ── STARTUP ───────────────────────────────────────────────────────────────────
@app.on_event("startup")
def load_all():
    """Load the pre-computed risk table and the live-inference artefacts.

    Registered as a startup hook so the module globals are populated before
    the first request is served; nothing else in the visible file calls it.
    The function is idempotent: calling it again simply reloads everything.

    Reads:
        * the CSV named by the ``DATA_PATH`` environment variable
          (default ``india_all_districts_risk.csv``),
        * ``model_clean_inference.cbm``,
        * ``feature_names.json``, ``feature_medians.json`` and
          ``feature_map.json`` (decoded as UTF-8).

    Raises:
        Whatever pandas, CatBoost, ``open`` or ``json.load`` raise when a
        file is missing or malformed; errors are deliberately not swallowed
        so a broken deployment fails loudly at startup.
    """
    global df, model, feature_names, feature_medians, feature_map

    # Pre-computed CSV
    data_path = os.getenv("DATA_PATH", "india_all_districts_risk.csv")
    table = pd.read_csv(data_path)
    # Normalise headers to snake_case so the column lookups elsewhere are stable.
    table.columns = [c.strip().lower().replace(" ", "_") for c in table.columns]
    # Lower-cased copies used for case-insensitive district/state lookups.
    table["district_lower"] = table["district"].str.lower().str.strip()
    table["state_lower"] = table["state"].str.lower().str.strip()
    # Publish only the fully prepared frame.
    df = table
    print(f"✓ Loaded {len(df)} districts")

    # Live inference model
    model = CatBoostRegressor()
    model.load_model("model_clean_inference.cbm")
    print("✓ Model loaded")

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open("feature_names.json", encoding="utf-8") as f:
        feature_names = json.load(f)
    with open("feature_medians.json", encoding="utf-8") as f:
        feature_medians = json.load(f)
    with open("feature_map.json", encoding="utf-8") as f:
        feature_map = json.load(f)
    print(f"✓ Feature map loaded: {len(feature_map)} mappings")
# ── HELPERS ───────────────────────────────────────────────────────────────────
def safe_float(val):
    """Convert *val* to a float rounded to 4 decimals, or ``None``.

    ``None`` is returned when the value cannot be converted (wrong type,
    unparsable string, integer too large for a float) or when the result is
    NaN or infinite. Non-finite floats are rejected because they have no
    valid JSON representation.
    """
    try:
        number = float(val)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are expected here; anything else is a bug
        # and should surface instead of being silently swallowed.
        return None
    return round(number, 4) if math.isfinite(number) else None
def format_district(row):
    """Build the public JSON payload for one district row.

    Args:
        row: a mapping-like object exposing ``.get`` (the routes pass pandas
            Series produced by ``DataFrame.iterrows``).

    Returns:
        A dict with the district and state names, the four per-condition
        risk scores, the composite risk and ``risk_percentile``. Numeric
        fields go through ``safe_float`` and are therefore ``None`` when
        missing or not finite.
    """

    def text(key):
        # A missing name is read back from the CSV as NaN, which cannot be
        # JSON-encoded; fall back to the same "" default used for absent keys.
        value = row.get(key, "")
        return "" if pd.isna(value) else value

    return {
        "district": text("district"),
        "state": text("state"),
        "risk_scores": {
            "diabetes": safe_float(row.get("diabetes_risk")),
            "blood_pressure": safe_float(row.get("blood_pressure_risk")),
            "obesity": safe_float(row.get("obesity_risk")),
            "anaemia": safe_float(row.get("anaemia_risk")),
        },
        "composite_risk": safe_float(row.get("composite_risk")),
        # NOTE(review): the percentile is read from "diabetes_risk_norm";
        # confirm that column really holds a percentile.
        "risk_percentile": safe_float(row.get("diabetes_risk_norm")),
    }
def fuzzy_match_feature(input_key: str, threshold: float = 0.6):
    """Match an input column name to a model feature name.

    The key is lower-cased, stripped, and has ``_`` / ``-`` replaced by
    spaces. Matching then proceeds in stages, returning at the first hit:

    1. exact key in ``feature_map`` (score 1.0);
    2. exact match against a raw feature name, compared lower-cased with
       ``_`` replaced by spaces (score 1.0);
    3. best ``SequenceMatcher`` ratio over the ``feature_map`` keys, if it
       reaches *threshold*;
    4. best ratio over the raw feature names (continuing from the stage-3
       best), if it reaches *threshold*.

    Stage 2 exists so that a caller who sends an exact feature name is never
    redirected to a merely similar alias found in stage 3.

    Args:
        input_key: caller-supplied column name.
        threshold: minimum similarity ratio accepted for a fuzzy match.

    Returns:
        ``(feature_name, score)`` on success, ``(None, best_score)`` when
        nothing reaches the threshold.
    """
    input_clean = input_key.lower().strip().replace("_", " ").replace("-", " ")

    # Stage 1: direct match in the alias map.
    if input_clean in feature_map:
        return feature_map[input_clean], 1.0

    # Stage 2: direct match against the raw feature names.
    cleaned_features = [(feat, feat.lower().replace("_", " ")) for feat in feature_names]
    for feat, feat_clean in cleaned_features:
        if feat_clean == input_clean:
            return feat, 1.0

    # Stage 3: closest alias. Strict ">" keeps the first of equally good candidates.
    best_score, best_match = 0.0, None
    for map_key, feat_name in feature_map.items():
        score = SequenceMatcher(None, input_clean, map_key).ratio()
        if score > best_score:
            best_score, best_match = score, feat_name
    if best_score >= threshold:
        return best_match, best_score

    # Stage 4: closest raw feature name; must beat the best alias score.
    for feat, feat_clean in cleaned_features:
        score = SequenceMatcher(None, input_clean, feat_clean).ratio()
        if score > best_score:
            best_score, best_match = score, feat
    if best_score >= threshold:
        return best_match, best_score

    return None, best_score
# ── ROUTES ────────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Describe the service: product metadata, validation figures, endpoints.

    Safe to call before the data is loaded: the district and feature counts
    are reported as 0 while the corresponding globals are still unset.
    """
    district_count = len(df) if df is not None else 0
    feature_count = len(feature_names) if feature_names else 0
    return {
        "product": "Districtmaps.ai",
        "version": "2.0.0",
        "description": "District-level NCD risk intelligence for India",
        "districts": district_count,
        "conditions": ["diabetes", "blood_pressure", "obesity", "anaemia"],
        "validation": {
            "cross_sectional_r2": 0.7132,
            "temporal_r2": 0.6279,
            "temporal_gap": "4 years (NFHS-4 2015-16 → NFHS-5 2019-21)",
            "features": feature_count,
        },
        "endpoints": {
            "GET /risk": "Pre-computed risk scores for a named district",
            "GET /top": "Top N highest risk districts",
            "GET /state/{state}": "All districts within a state",
            "GET /districts": "Full ranked list",
            "POST /predict": "Live inference — send your own district data",
            "GET /features": "List all supported input features",
        },
    }
# ── PRE-COMPUTED ENDPOINTS ────────────────────────────────────────────────────
@app.get("/risk")
def get_district_risk(district: str, state: Optional[str] = None):
    """Return the pre-computed risk scores for a named district.

    Lookup is case-insensitive. An exact name match is tried first; when it
    finds nothing, a substring match is used as a fallback. ``state``
    narrows both passes when given.

    Args:
        district: district name to look up.
        state: optional state name used to disambiguate.

    Returns:
        ``{"query": district, "matches": [...]}`` with one formatted entry
        per matching row.

    Raises:
        HTTPException: 404 when no district matches.
    """
    district_key = district.lower().strip()
    state_key = state.lower().strip() if state else ""

    # Pass 1: exact, case-insensitive match.
    mask = df["district_lower"] == district_key
    if state_key:
        mask &= df["state_lower"] == state_key
    results = df[mask]

    # Pass 2: substring fallback. regex=False treats the user's text
    # literally (characters such as "(" would otherwise raise or act as a
    # pattern); a blank query is skipped because "" is a substring of every
    # name and would return the whole table.
    if results.empty and district_key:
        mask = df["district_lower"].str.contains(district_key, na=False, regex=False)
        if state_key:
            mask &= df["state_lower"].str.contains(state_key, na=False, regex=False)
        results = df[mask]

    if results.empty:
        raise HTTPException(status_code=404, detail=f"District '{district}' not found.")
    return {
        "query": district,
        "matches": [format_district(row) for _, row in results.iterrows()],
    }
@app.get("/districts")
def get_all_districts(sort_by: str = "composite_risk", order: str = "desc", limit: int = 708):
    """Return the ranked list of districts.

    Args:
        sort_by: column to sort on; unknown names fall back to
            ``composite_risk``.
        order: ``"asc"`` (any casing) for ascending; anything else sorts
            descending.
        limit: maximum number of rows returned; negative values are treated
            as 0.

    Returns:
        ``{"total": <rows returned>, "sorted_by": <column used>,
        "districts": [...]}``.
    """
    col = sort_by if sort_by in df.columns else "composite_risk"
    ascending = order.lower().strip() == "asc"
    # head() with a negative count means "all but the last n rows", which is
    # never what a limit is meant to do, so clamp at zero.
    sorted_df = df.sort_values(col, ascending=ascending).head(max(limit, 0))
    return {
        "total": len(sorted_df),
        "sorted_by": col,
        "districts": [format_district(row) for _, row in sorted_df.iterrows()],
    }
@app.get("/top")
def get_top_districts(n: int = 10, condition: str = "composite_risk"):
    """Return the *n* districts with the highest value of *condition*.

    Args:
        n: number of districts wanted; negative values yield an empty list.
        condition: column to rank by. Unknown or non-numeric columns fall
            back to ``composite_risk``.

    Returns:
        ``{"condition": <column used>, "top_n": n, "districts": [...]}``.
    """
    # nlargest() raises TypeError on object-dtype columns such as "district"
    # or "state", so only numeric columns are accepted for ranking.
    rankable = condition in df.columns and pd.api.types.is_numeric_dtype(df[condition])
    col = condition if rankable else "composite_risk"
    top = df.nlargest(max(n, 0), col)
    return {
        "condition": col,
        "top_n": n,
        "districts": [format_district(row) for _, row in top.iterrows()],
    }
@app.get("/state/{state}")
def get_state_districts(state: str):
    """Return every district of a state, ranked by composite risk.

    The state is matched as a case-insensitive substring of the stored state
    names, so a partial name is accepted.

    Args:
        state: full or partial state name.

    Returns:
        ``{"state": state, "districts": <count>, "ranked": [...]}`` sorted by
        ``composite_risk`` in descending order.

    Raises:
        HTTPException: 404 when no state matches.
    """
    state_key = state.lower().strip()
    # regex=False: match the caller's text literally instead of as a pattern.
    mask = df["state_lower"].str.contains(state_key, na=False, regex=False)
    results = df[mask].sort_values("composite_risk", ascending=False)
    if results.empty:
        raise HTTPException(status_code=404, detail=f"State '{state}' not found.")
    return {
        "state": state,
        "districts": len(results),
        "ranked": [format_district(row) for _, row in results.iterrows()],
    }
# ── LIVE INFERENCE ────────────────────────────────────────────────────────────
class PredictRequest(BaseModel):
    """Request body accepted by the live-inference endpoint."""

    # Caller-supplied indicators: column name -> value. Column names may
    # follow any convention; they are fuzzy-matched to model features.
    data: Dict[str, Any]
    # Label echoed back in the response.
    district_name: Optional[str] = "Unknown"
    # True: features the caller did not supply are filled from
    # feature_medians. False (or null): they are set to 0.
    fill_missing: Optional[bool] = True
def _risk_band(prediction: float) -> str:
    """Map a predicted diabetes risk (percent) to its named band."""
    if prediction > 15:
        return "VERY HIGH"
    if prediction > 10:
        return "HIGH"
    if prediction > 7:
        return "MODERATE"
    if prediction > 4:
        return "LOW"
    return "VERY LOW"


@app.post("/predict")
def predict(request: PredictRequest):
    """
    Live inference endpoint. Send any district-level health indicators
    in your own column naming convention. Each column is fuzzy-matched to a
    model feature, missing features are filled with the stored medians
    (or 0 when ``fill_missing`` is false), and a live prediction is returned.

    Columns whose value is not a finite number are skipped and listed under
    ``invalid_value_cols`` instead of failing the whole request. When two
    input columns map to the same feature, the later one wins.

    Example:
    {
        "district_name": "My District",
        "data": {
            "obesity": 32.1,
            "tobacco": 18.4,
            "literacy": 89.2,
            "insurance": 45.0,
            "anaemia": 52.3
        }
    }

    Raises:
        HTTPException: 503 when the model artefacts are not loaded;
            422 when no input column yields a usable feature value.
    """
    artefacts_ready = (
        model is not None
        and feature_names
        and feature_medians is not None
        and feature_map is not None
    )
    if not artefacts_ready:
        raise HTTPException(status_code=503, detail="Model artefacts are not loaded yet.")

    input_data = request.data
    matched = {}
    unmatched = []
    invalid_values = []
    match_report = []

    # Fuzzy match each input column to a model feature
    for input_key, input_val in input_data.items():
        feat_name, score = fuzzy_match_feature(str(input_key))
        if not feat_name:
            unmatched.append(input_key)
            continue
        # The payload is Dict[str, Any], so the value may be a string, null,
        # a list, ...; reject anything that is not a finite number rather
        # than letting float() turn the request into a 500.
        try:
            value = float(input_val)
        except (TypeError, ValueError, OverflowError):
            invalid_values.append(input_key)
            continue
        if not np.isfinite(value):
            invalid_values.append(input_key)
            continue
        matched[feat_name] = value
        match_report.append({
            "input_column": input_key,
            "mapped_to": feat_name,
            "confidence": round(score, 3),
            "value": value,
        })

    if not matched:
        raise HTTPException(
            status_code=422,
            detail="None of your columns could be matched to model features. "
                   "Call GET /features to see supported inputs.",
        )

    # Build full feature vector
    feature_vector = {}
    filled_with_median = []
    for feat in feature_names:
        if feat in matched:
            feature_vector[feat] = matched[feat]
        elif request.fill_missing:
            feature_vector[feat] = feature_medians.get(feat, 0)
            filled_with_median.append(feat)
        else:
            feature_vector[feat] = 0

    # Run prediction; indexing by feature_names fixes the column order.
    X_input = pd.DataFrame([feature_vector])[feature_names]
    prediction = float(model.predict(X_input)[0])
    # The output is a percentage, so clamp it to [0, 100].
    prediction = round(max(0, min(100, prediction)), 4)

    return {
        "district": request.district_name,
        "prediction": {
            "diabetes_risk_pct": prediction,
            "risk_band": _risk_band(prediction),
            "model": "CatBoost · R²=0.7132 · MAE=2.06%",
        },
        "input_summary": {
            "columns_received": len(input_data),
            "columns_matched": len(matched),
            "columns_unmatched": len(unmatched),
            "columns_invalid_value": len(invalid_values),
            "features_filled_median": len(filled_with_median),
            "coverage_pct": round(len(matched) / len(feature_names) * 100, 1),
        },
        "match_report": match_report,
        "unmatched_cols": unmatched,
        "invalid_value_cols": invalid_values,
        "note": f"Prediction based on {len(matched)} matched features out of {len(feature_names)} total. "
                f"{len(filled_with_median)} features filled with national district medians.",
    }
@app.get("/features")
def list_features():
    """Returns all supported feature names and common aliases for /predict.

    Reports an empty feature list (count 0) while the feature names have not
    been loaded yet, mirroring how the root endpoint reports its counts.
    """
    names = feature_names or []
    return {
        "total_features": len(names),
        "model_feature_names": names,
        "common_aliases": {
            "obesity / overweight / bmi": "Women overweight or obese (BMI ≥25)",
            "tobacco / smoking": "Men age 15+ who use tobacco",
            "literacy / education": "Female population who attended school",
            "insurance / health insurance": "Households with health insurance",
            "anaemia / anemia": "Women age 15-49 who are anaemic",
            "sanitation": "Households with improved sanitation",
            "teen pregnancy": "Women 15-19 already mothers or pregnant",
            "children overweight": "Children under 5 who are overweight",
        },
        "tip": "Column names are fuzzy-matched — send whatever naming convention you use.",
    }