Spaces:
Sleeping
Sleeping
feat: add React frontend, rolling forecast, AI chat, CUSUM tuning, CSV download
Browse files- Full React 19 dashboard: forecast chart, shift detection, scenarios, chat panel
- GSAP-animated landing page with demo routes (/app?demo=bakery|crop|m5)
- Rolling window: actuals re-run Chronos and shift forecast forward
- Chat via Groq: context-grounded, auto-detects actual values from conversation
- CSV download with original data + appended actuals with timestamps
- Frequency support: hourly, daily, weekly, monthly, quarterly, annually
- CUSUM sensitivity tuned (2x std), predicted vs actual markers in history
- 27 bug fixes, removed __pycache__ from tracking, moved tests to tests/
Signed-off-by: this-is-rachit <rachitbansal023@gmail.com>
- __pycache__/baseline.cpython-313.pyc +0 -0
- __pycache__/cache.cpython-313.pyc +0 -0
- __pycache__/calibrator.cpython-313.pyc +0 -0
- __pycache__/confidence.cpython-313.pyc +0 -0
- __pycache__/decision.cpython-313.pyc +0 -0
- __pycache__/detector.cpython-313.pyc +0 -0
- __pycache__/explainer.cpython-313.pyc +0 -0
- __pycache__/forecaster.cpython-313.pyc +0 -0
- __pycache__/main.cpython-313.pyc +0 -0
- __pycache__/models.cpython-313.pyc +0 -0
- __pycache__/preprocessor.cpython-313.pyc +0 -0
- __pycache__/scenario.cpython-313.pyc +0 -0
- detector.py +2 -2
- explainer.py +7 -3
- main.py +369 -37
- models.py +28 -3
- preprocessor.py +6 -0
- test_chronos.py +0 -123
__pycache__/baseline.cpython-313.pyc
DELETED
|
Binary file (6.12 kB)
|
|
|
__pycache__/cache.cpython-313.pyc
DELETED
|
Binary file (7.4 kB)
|
|
|
__pycache__/calibrator.cpython-313.pyc
DELETED
|
Binary file (4.03 kB)
|
|
|
__pycache__/confidence.cpython-313.pyc
DELETED
|
Binary file (1.69 kB)
|
|
|
__pycache__/decision.cpython-313.pyc
DELETED
|
Binary file (1.79 kB)
|
|
|
__pycache__/detector.cpython-313.pyc
DELETED
|
Binary file (5.98 kB)
|
|
|
__pycache__/explainer.cpython-313.pyc
DELETED
|
Binary file (13.5 kB)
|
|
|
__pycache__/forecaster.cpython-313.pyc
DELETED
|
Binary file (3.67 kB)
|
|
|
__pycache__/main.cpython-313.pyc
DELETED
|
Binary file (21.1 kB)
|
|
|
__pycache__/models.cpython-313.pyc
DELETED
|
Binary file (5.62 kB)
|
|
|
__pycache__/preprocessor.cpython-313.pyc
DELETED
|
Binary file (18.4 kB)
|
|
|
__pycache__/scenario.cpython-313.pyc
DELETED
|
Binary file (2.71 kB)
|
|
|
detector.py
CHANGED
|
@@ -35,8 +35,8 @@ class CUSUMDetector:
|
|
| 35 |
def __init__(self, historical_std: float):
|
| 36 |
# Drift and threshold are derived from the historical series volatility
|
| 37 |
# so they automatically scale to whatever units the user uploads.
|
| 38 |
-
self.drift = 0.
|
| 39 |
-
self.threshold =
|
| 40 |
self.s_high = 0.0
|
| 41 |
self.s_low = 0.0
|
| 42 |
self.last_alert_t = -COOLDOWN_PERIODS # allow an alert on the very first step
|
|
|
|
| 35 |
def __init__(self, historical_std: float):
|
| 36 |
# Drift and threshold are derived from the historical series volatility
|
| 37 |
# so they automatically scale to whatever units the user uploads.
|
| 38 |
+
self.drift = 0.3 * historical_std
|
| 39 |
+
self.threshold = 2.0 * historical_std
|
| 40 |
self.s_high = 0.0
|
| 41 |
self.s_low = 0.0
|
| 42 |
self.last_alert_t = -COOLDOWN_PERIODS # allow an alert on the very first step
|
explainer.py
CHANGED
|
@@ -199,13 +199,17 @@ def _regex_parse(text: str, last_actual: float | None) -> dict:
|
|
| 199 |
if last_actual is not None:
|
| 200 |
factor = (1 - pct / 100) if down else (1 + pct / 100)
|
| 201 |
return _result(last_actual * factor, relative=True, approximate=is_approximate)
|
| 202 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
if pct < 10000:
|
| 204 |
return _result(pct, approximate=is_approximate)
|
| 205 |
|
| 206 |
# ββ Relative: went up/down by absolute amount ββββββββββββββββββββββββββ
|
| 207 |
up_match = re.search(
|
| 208 |
-
r"(?:went up|increased?|badhke?|upar|zyada)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
|
| 209 |
)
|
| 210 |
down_match = re.search(
|
| 211 |
r"(?:went down|decreased?|dropped?|girak?|kam|niche)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
|
|
@@ -255,7 +259,7 @@ def _parse_number_str(text: str) -> float | None:
|
|
| 255 |
# Strip prefix approximate/directional words
|
| 256 |
t = re.sub(r"^(around|about|roughly|approximately|lagbhag|almost|upto|up to)\s*", "", t)
|
| 257 |
# Strip trailing unit words and directional tokens
|
| 258 |
-
t = re.sub(r"\s*(rupee|rupees|rs|inr|kg|units?|pieces?|up|down|more|less|zyada|kam)s?\s*$", "", t)
|
| 259 |
t = t.strip()
|
| 260 |
|
| 261 |
# Extract first clean number from remaining string
|
|
|
|
| 199 |
if last_actual is not None:
|
| 200 |
factor = (1 - pct / 100) if down else (1 + pct / 100)
|
| 201 |
return _result(last_actual * factor, relative=True, approximate=is_approximate)
|
| 202 |
+
# Has relative words (jyada, kam, more, less) β needs a previous value
|
| 203 |
+
relative_words = ["jyada", "zyada", "more", "kam", "less", "up", "down", "increase", "decrease"]
|
| 204 |
+
if any(w in t for w in relative_words):
|
| 205 |
+
return _error("We need a previous value to calculate the percentage. Please enter the number directly.")
|
| 206 |
+
# No relative words β treat as absolute value
|
| 207 |
if pct < 10000:
|
| 208 |
return _result(pct, approximate=is_approximate)
|
| 209 |
|
| 210 |
# ββ Relative: went up/down by absolute amount ββββββββββββββββββββββββββ
|
| 211 |
up_match = re.search(
|
| 212 |
+
r"(?:went up|increased?|badhke?|upar|zyada|jyada)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
|
| 213 |
)
|
| 214 |
down_match = re.search(
|
| 215 |
r"(?:went down|decreased?|dropped?|girak?|kam|niche)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
|
|
|
|
| 259 |
# Strip prefix approximate/directional words
|
| 260 |
t = re.sub(r"^(around|about|roughly|approximately|lagbhag|almost|upto|up to)\s*", "", t)
|
| 261 |
# Strip trailing unit words and directional tokens
|
| 262 |
+
t = re.sub(r"\s*(rupee|rupees|rs|inr|kg|units?|pieces?|up|down|more|less|zyada|jyada|kam)s?\s*$", "", t)
|
| 263 |
t = t.strip()
|
| 264 |
|
| 265 |
# Extract first clean number from remaining string
|
main.py
CHANGED
|
@@ -27,6 +27,8 @@ from models import (
|
|
| 27 |
ForecastPoint,
|
| 28 |
BaselinePoint,
|
| 29 |
HealthResponse,
|
|
|
|
|
|
|
| 30 |
ScenarioRequest,
|
| 31 |
ScenarioResponse,
|
| 32 |
UpdateRequest,
|
|
@@ -113,6 +115,81 @@ def health():
|
|
| 113 |
)
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# βββ /upload ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
|
| 118 |
@app.post("/upload", response_model=UploadResponse)
|
|
@@ -188,6 +265,7 @@ def forecast(req: ForecastRequest):
|
|
| 188 |
session["_forecast"] = {"low": cal_low, "median": raw["median"], "high": cal_high}
|
| 189 |
session["_horizon"] = horizon
|
| 190 |
session["_fc_dates"] = fc_dates
|
|
|
|
| 191 |
|
| 192 |
seasonal = _describe_seasonal(prepared["warnings"])
|
| 193 |
|
|
@@ -214,6 +292,10 @@ def forecast(req: ForecastRequest):
|
|
| 214 |
seasonal_pattern=seasonal,
|
| 215 |
is_financial=prepared["is_financial"],
|
| 216 |
is_intermittent=prepared["is_intermittent"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
)
|
| 218 |
|
| 219 |
|
|
@@ -242,15 +324,20 @@ def update(req: UpdateRequest):
|
|
| 242 |
and fc_low[idx] <= actual_value <= fc_high[idx]
|
| 243 |
)
|
| 244 |
|
| 245 |
-
#
|
| 246 |
-
detector = session.get("_detector")
|
| 247 |
residual = actual_value - float(fc_median[idx]) if len(fc_median) > idx else 0.0
|
|
|
|
| 248 |
|
|
|
|
|
|
|
| 249 |
series = session.get("_series", np.array([]))
|
| 250 |
date_col = session.get("date_col", "date")
|
| 251 |
value_col = session.get("value_col")
|
| 252 |
df = session.get("df")
|
| 253 |
fc_dates = session.get("_fc_dates", [])
|
|
|
|
|
|
|
|
|
|
| 254 |
actual_date = None
|
| 255 |
if fc_dates and idx < len(fc_dates):
|
| 256 |
try:
|
|
@@ -275,38 +362,46 @@ def update(req: UpdateRequest):
|
|
| 275 |
new_alpha = update_alpha(current_alpha, actual_value, float(fc_low[idx]), float(fc_high[idx]))
|
| 276 |
session["_alpha"] = new_alpha
|
| 277 |
|
| 278 |
-
#
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
last_val = float(np.append(series, actual_value)[-1])
|
| 309 |
-
trend_pct = float((fc_median[-1] - last_val) / (abs(last_val) + 1e-9) * 100)
|
| 310 |
|
| 311 |
new_decision = get_decision(
|
| 312 |
trend_pct=trend_pct,
|
|
@@ -320,8 +415,6 @@ def update(req: UpdateRequest):
|
|
| 320 |
else getattr(session.get("warnings"), "intermittent", False),
|
| 321 |
)
|
| 322 |
|
| 323 |
-
horizon = session.get("_horizon", 4)
|
| 324 |
-
fc_dates = session.get("_fc_dates", [])
|
| 325 |
explanation_text, _ = explain(
|
| 326 |
trend_pct=trend_pct,
|
| 327 |
confidence=score,
|
|
@@ -332,7 +425,7 @@ def update(req: UpdateRequest):
|
|
| 332 |
|
| 333 |
new_forecast = [
|
| 334 |
ForecastPoint(
|
| 335 |
-
date=
|
| 336 |
low=round(float(fc_low[i]), 2) if i < len(fc_low) else 0.0,
|
| 337 |
median=round(float(fc_median[i]), 2) if i < len(fc_median) else 0.0,
|
| 338 |
high=round(float(fc_high[i]), 2) if i < len(fc_high) else 0.0,
|
|
@@ -340,6 +433,21 @@ def update(req: UpdateRequest):
|
|
| 340 |
for i in range(horizon)
|
| 341 |
]
|
| 342 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
return UpdateResponse(
|
| 344 |
parsed_value=actual_value,
|
| 345 |
is_approximate=parsed["is_approximate"],
|
|
@@ -347,11 +455,67 @@ def update(req: UpdateRequest):
|
|
| 347 |
cusum_alert=cusum_result["direction"],
|
| 348 |
cusum_magnitude=round(float(cusum_result["magnitude"]), 2),
|
| 349 |
new_forecast=new_forecast,
|
|
|
|
| 350 |
new_confidence_score=score,
|
| 351 |
new_confidence_label=label,
|
| 352 |
new_decision=new_decision,
|
| 353 |
-
recalibrated=
|
| 354 |
explanation=explanation_text,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
)
|
| 356 |
|
| 357 |
|
|
@@ -430,6 +594,170 @@ def explain_endpoint(request: Request, req: ExplainRequest):
|
|
| 430 |
return ExplainResponse(explanation=text, source=source)
|
| 431 |
|
| 432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
# βββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 434 |
|
| 435 |
def _require_session(session_id: str) -> dict:
|
|
@@ -459,6 +787,10 @@ def _forecast_from_cache(payload: dict) -> ForecastResponse:
|
|
| 459 |
seasonal_pattern="Pre-computed demo forecast",
|
| 460 |
is_financial=payload.get("is_financial", False),
|
| 461 |
is_intermittent=payload.get("is_intermittent", False),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
)
|
| 463 |
|
| 464 |
|
|
|
|
| 27 |
ForecastPoint,
|
| 28 |
BaselinePoint,
|
| 29 |
HealthResponse,
|
| 30 |
+
ChatRequest,
|
| 31 |
+
ChatResponse,
|
| 32 |
ScenarioRequest,
|
| 33 |
ScenarioResponse,
|
| 34 |
UpdateRequest,
|
|
|
|
| 115 |
)
|
| 116 |
|
| 117 |
|
| 118 |
+
# βββ /demo ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
|
| 120 |
+
DEMO_META = {
|
| 121 |
+
"bakery": {
|
| 122 |
+
"filename": "bakery_sales.csv",
|
| 123 |
+
"date_col": "date",
|
| 124 |
+
"value_col": "weekly_sales_inr",
|
| 125 |
+
"columns": ["date", "weekly_sales_inr"],
|
| 126 |
+
},
|
| 127 |
+
"crop": {
|
| 128 |
+
"filename": "crop_prices_sample.csv",
|
| 129 |
+
"date_col": "date",
|
| 130 |
+
"value_col": "wheat_price_inr_per_quintal",
|
| 131 |
+
"columns": ["date", "wheat_price_inr_per_quintal"],
|
| 132 |
+
},
|
| 133 |
+
"m5": {
|
| 134 |
+
"filename": "walmart_m5_sample.csv",
|
| 135 |
+
"date_col": "date",
|
| 136 |
+
"value_col": "FOODS_1",
|
| 137 |
+
"columns": ["date", "FOODS_1", "HOBBIES_1", "HOUSEHOLD_1"],
|
| 138 |
+
},
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@app.get("/demo/{key}")
|
| 143 |
+
def demo(key: str):
|
| 144 |
+
if key not in DEMO_META:
|
| 145 |
+
raise HTTPException(status_code=404, detail={
|
| 146 |
+
"error_code": "DEMO_NOT_FOUND",
|
| 147 |
+
"message": f"Unknown demo dataset: {key}. Available: bakery, crop, m5",
|
| 148 |
+
})
|
| 149 |
+
|
| 150 |
+
cached = cache.get(key)
|
| 151 |
+
if not cached:
|
| 152 |
+
raise HTTPException(status_code=503, detail={
|
| 153 |
+
"error_code": "DEMO_NOT_CACHED",
|
| 154 |
+
"message": "Demo cache not built. Run 'python cache.py' first.",
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
meta = DEMO_META[key]
|
| 158 |
+
session_id = f"demo_{key}"
|
| 159 |
+
|
| 160 |
+
# Build preview from cached history
|
| 161 |
+
hist_dates = cached.get("history_dates", [])
|
| 162 |
+
hist_values = cached.get("history_values", [])
|
| 163 |
+
preview = [
|
| 164 |
+
{"date": hist_dates[i], "value": hist_values[i]}
|
| 165 |
+
for i in range(min(5, len(hist_dates)))
|
| 166 |
+
]
|
| 167 |
+
|
| 168 |
+
forecast_resp = _forecast_from_cache(cached)
|
| 169 |
+
|
| 170 |
+
return {
|
| 171 |
+
"session_id": session_id,
|
| 172 |
+
"upload": {
|
| 173 |
+
"session_id": session_id,
|
| 174 |
+
"detected_date_col": meta["date_col"],
|
| 175 |
+
"detected_value_col": meta["value_col"],
|
| 176 |
+
"columns": meta["columns"],
|
| 177 |
+
"series_list": [],
|
| 178 |
+
"preview": preview,
|
| 179 |
+
"frequency": cached.get("frequency", "weekly"),
|
| 180 |
+
"n_rows": len(hist_values),
|
| 181 |
+
"outliers": [],
|
| 182 |
+
"warnings": {
|
| 183 |
+
"intermittent": cached.get("is_intermittent", False),
|
| 184 |
+
"non_stationary": cached.get("is_financial", False),
|
| 185 |
+
"short_series": False,
|
| 186 |
+
"large_gaps": False,
|
| 187 |
+
},
|
| 188 |
+
},
|
| 189 |
+
"forecast": forecast_resp.model_dump(),
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
# βββ /upload ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 194 |
|
| 195 |
@app.post("/upload", response_model=UploadResponse)
|
|
|
|
| 265 |
session["_forecast"] = {"low": cal_low, "median": raw["median"], "high": cal_high}
|
| 266 |
session["_horizon"] = horizon
|
| 267 |
session["_fc_dates"] = fc_dates
|
| 268 |
+
session["_history_dates"] = prepared["dates"]
|
| 269 |
|
| 270 |
seasonal = _describe_seasonal(prepared["warnings"])
|
| 271 |
|
|
|
|
| 292 |
seasonal_pattern=seasonal,
|
| 293 |
is_financial=prepared["is_financial"],
|
| 294 |
is_intermittent=prepared["is_intermittent"],
|
| 295 |
+
history_dates=[d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d) for d in prepared["dates"][-52:]],
|
| 296 |
+
history_values=[round(float(v), 2) for v in series[-52:]],
|
| 297 |
+
frequency=frequency,
|
| 298 |
+
cusum_threshold=round(2.0 * hist_std, 2),
|
| 299 |
)
|
| 300 |
|
| 301 |
|
|
|
|
| 324 |
and fc_low[idx] <= actual_value <= fc_high[idx]
|
| 325 |
)
|
| 326 |
|
| 327 |
+
# Residual for CUSUM
|
|
|
|
| 328 |
residual = actual_value - float(fc_median[idx]) if len(fc_median) > idx else 0.0
|
| 329 |
+
predicted_value = float(fc_median[idx]) if len(fc_median) > idx else 0.0
|
| 330 |
|
| 331 |
+
# CUSUM update
|
| 332 |
+
detector = session.get("_detector")
|
| 333 |
series = session.get("_series", np.array([]))
|
| 334 |
date_col = session.get("date_col", "date")
|
| 335 |
value_col = session.get("value_col")
|
| 336 |
df = session.get("df")
|
| 337 |
fc_dates = session.get("_fc_dates", [])
|
| 338 |
+
frequency = session.get("frequency", "weekly")
|
| 339 |
+
horizon = session.get("_horizon", 4)
|
| 340 |
+
|
| 341 |
actual_date = None
|
| 342 |
if fc_dates and idx < len(fc_dates):
|
| 343 |
try:
|
|
|
|
| 362 |
new_alpha = update_alpha(current_alpha, actual_value, float(fc_low[idx]), float(fc_high[idx]))
|
| 363 |
session["_alpha"] = new_alpha
|
| 364 |
|
| 365 |
+
# --- Rolling window: always append actual and re-forecast ---
|
| 366 |
+
updated_series = np.append(series, actual_value)
|
| 367 |
+
session["_series"] = updated_series
|
| 368 |
+
|
| 369 |
+
# Track actuals with timestamps for CSV download
|
| 370 |
+
actual_entry = {
|
| 371 |
+
"date": fc_dates[idx] if idx < len(fc_dates) else str(pd.Timestamp.now().date()),
|
| 372 |
+
"value": actual_value,
|
| 373 |
+
"timestamp": pd.Timestamp.now().isoformat(),
|
| 374 |
+
}
|
| 375 |
+
if "_actuals_log" not in session:
|
| 376 |
+
session["_actuals_log"] = []
|
| 377 |
+
session["_actuals_log"].append(actual_entry)
|
| 378 |
+
|
| 379 |
+
# Re-run Chronos with the updated series
|
| 380 |
+
raw = forecaster.run_forecast(updated_series, horizon, frequency)
|
| 381 |
+
cal = calibrate(updated_series, raw["low"], raw["high"])
|
| 382 |
+
fc_low = cal["calibrated_low"]
|
| 383 |
+
fc_high = cal["calibrated_high"]
|
| 384 |
+
fc_median = raw["median"]
|
| 385 |
+
session["_forecast"] = {"low": fc_low, "median": fc_median, "high": fc_high}
|
| 386 |
+
session["_baseline"] = {"values": raw["baseline"], "type": raw["baseline_type"]}
|
| 387 |
+
|
| 388 |
+
# Generate new forecast dates from the last actual's date
|
| 389 |
+
last_known_date = actual_date or pd.Timestamp(fc_dates[-1]) if fc_dates else pd.Timestamp.now()
|
| 390 |
+
new_fc_dates = _make_forecast_dates(last_known_date, horizon, frequency)
|
| 391 |
+
session["_fc_dates"] = new_fc_dates
|
| 392 |
+
|
| 393 |
+
# Update history dates
|
| 394 |
+
hist_dates = list(session.get("_history_dates", []))
|
| 395 |
+
if actual_entry["date"] not in hist_dates:
|
| 396 |
+
hist_dates.append(actual_entry["date"])
|
| 397 |
+
session["_history_dates"] = hist_dates
|
| 398 |
+
|
| 399 |
+
hist_std = session.get("_hist_std", float(np.std(updated_series)))
|
| 400 |
+
session["_hist_std"] = hist_std
|
| 401 |
+
score, label = compute_confidence(fc_low, fc_high, hist_std)
|
| 402 |
|
| 403 |
+
last_val = float(updated_series[-1])
|
| 404 |
+
trend_pct = float((fc_median[-1] - last_val) / (abs(last_val) + 1e-9) * 100)
|
|
|
|
|
|
|
| 405 |
|
| 406 |
new_decision = get_decision(
|
| 407 |
trend_pct=trend_pct,
|
|
|
|
| 415 |
else getattr(session.get("warnings"), "intermittent", False),
|
| 416 |
)
|
| 417 |
|
|
|
|
|
|
|
| 418 |
explanation_text, _ = explain(
|
| 419 |
trend_pct=trend_pct,
|
| 420 |
confidence=score,
|
|
|
|
| 425 |
|
| 426 |
new_forecast = [
|
| 427 |
ForecastPoint(
|
| 428 |
+
date=new_fc_dates[i] if i < len(new_fc_dates) else "",
|
| 429 |
low=round(float(fc_low[i]), 2) if i < len(fc_low) else 0.0,
|
| 430 |
median=round(float(fc_median[i]), 2) if i < len(fc_median) else 0.0,
|
| 431 |
high=round(float(fc_high[i]), 2) if i < len(fc_high) else 0.0,
|
|
|
|
| 433 |
for i in range(horizon)
|
| 434 |
]
|
| 435 |
|
| 436 |
+
new_baseline = [
|
| 437 |
+
BaselinePoint(
|
| 438 |
+
date=new_fc_dates[i] if i < len(new_fc_dates) else "",
|
| 439 |
+
value=round(float(raw["baseline"][i]), 2) if i < len(raw["baseline"]) else 0.0,
|
| 440 |
+
)
|
| 441 |
+
for i in range(horizon)
|
| 442 |
+
]
|
| 443 |
+
|
| 444 |
+
# Build full history for frontend
|
| 445 |
+
all_dates = [d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d)
|
| 446 |
+
for d in session.get("_history_dates", [])]
|
| 447 |
+
# Fallback: use series indices if no dates tracked
|
| 448 |
+
if not all_dates:
|
| 449 |
+
all_dates = [f"T-{len(updated_series)-i}" for i in range(len(updated_series), 0, -1)]
|
| 450 |
+
|
| 451 |
return UpdateResponse(
|
| 452 |
parsed_value=actual_value,
|
| 453 |
is_approximate=parsed["is_approximate"],
|
|
|
|
| 455 |
cusum_alert=cusum_result["direction"],
|
| 456 |
cusum_magnitude=round(float(cusum_result["magnitude"]), 2),
|
| 457 |
new_forecast=new_forecast,
|
| 458 |
+
new_baseline=new_baseline,
|
| 459 |
new_confidence_score=score,
|
| 460 |
new_confidence_label=label,
|
| 461 |
new_decision=new_decision,
|
| 462 |
+
recalibrated=True,
|
| 463 |
explanation=explanation_text,
|
| 464 |
+
residual=round(residual, 2),
|
| 465 |
+
predicted_value=round(predicted_value, 2),
|
| 466 |
+
history_dates=[d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d)
|
| 467 |
+
for d in session.get("_history_dates", hist_dates)][-52:],
|
| 468 |
+
history_values=[round(float(v), 2) for v in updated_series[-52:]],
|
| 469 |
+
frequency=frequency,
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
# βββ /download ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 474 |
+
|
| 475 |
+
@app.get("/download/{session_id}")
|
| 476 |
+
def download_csv(session_id: str):
|
| 477 |
+
from fastapi.responses import StreamingResponse
|
| 478 |
+
import io
|
| 479 |
+
|
| 480 |
+
session = _require_session(session_id)
|
| 481 |
+
series = session.get("_series", np.array([]))
|
| 482 |
+
date_col = session.get("date_col", "date")
|
| 483 |
+
value_col = session.get("value_col", "value")
|
| 484 |
+
actuals_log = session.get("_actuals_log", [])
|
| 485 |
+
df_original = session.get("df")
|
| 486 |
+
|
| 487 |
+
if df_original is not None:
|
| 488 |
+
# Start with original data
|
| 489 |
+
out_df = df_original[[date_col, value_col]].copy()
|
| 490 |
+
out_df["entry_timestamp"] = None
|
| 491 |
+
out_df["source"] = "original"
|
| 492 |
+
|
| 493 |
+
# Append actuals
|
| 494 |
+
for entry in actuals_log:
|
| 495 |
+
new_row = {
|
| 496 |
+
date_col: entry["date"],
|
| 497 |
+
value_col: entry["value"],
|
| 498 |
+
"entry_timestamp": entry["timestamp"],
|
| 499 |
+
"source": "actual_entry",
|
| 500 |
+
}
|
| 501 |
+
out_df = pd.concat([out_df, pd.DataFrame([new_row])], ignore_index=True)
|
| 502 |
+
else:
|
| 503 |
+
# Fallback: just export what we have
|
| 504 |
+
out_df = pd.DataFrame({
|
| 505 |
+
date_col: [e["date"] for e in actuals_log],
|
| 506 |
+
value_col: [e["value"] for e in actuals_log],
|
| 507 |
+
"entry_timestamp": [e["timestamp"] for e in actuals_log],
|
| 508 |
+
"source": "actual_entry",
|
| 509 |
+
})
|
| 510 |
+
|
| 511 |
+
buffer = io.StringIO()
|
| 512 |
+
out_df.to_csv(buffer, index=False)
|
| 513 |
+
buffer.seek(0)
|
| 514 |
+
|
| 515 |
+
return StreamingResponse(
|
| 516 |
+
iter([buffer.getvalue()]),
|
| 517 |
+
media_type="text/csv",
|
| 518 |
+
headers={"Content-Disposition": f"attachment; filename=pulseai_updated_{session_id[:8]}.csv"},
|
| 519 |
)
|
| 520 |
|
| 521 |
|
|
|
|
| 594 |
return ExplainResponse(explanation=text, source=source)
|
| 595 |
|
| 596 |
|
| 597 |
+
# βββ /chat ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 598 |
+
|
| 599 |
+
@app.post("/chat", response_model=ChatResponse)
|
| 600 |
+
def chat(request: Request, req: ChatRequest):
|
| 601 |
+
ip = request.client.host if request.client else "unknown"
|
| 602 |
+
if _is_rate_limited(f"chat:{ip}", max_calls=30, window_sec=60):
|
| 603 |
+
raise HTTPException(status_code=429, detail={
|
| 604 |
+
"error_code": "RATE_LIMITED",
|
| 605 |
+
"message": "Too many requests. Please wait a minute before trying again."
|
| 606 |
+
})
|
| 607 |
+
|
| 608 |
+
# Build context from session state
|
| 609 |
+
context_parts = []
|
| 610 |
+
session = None
|
| 611 |
+
|
| 612 |
+
try:
|
| 613 |
+
session = get_session(req.session_id)
|
| 614 |
+
except Exception:
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
# If we have a demo cache, use that
|
| 618 |
+
demo_key = _demo_key(req.session_id) if req.session_id else None
|
| 619 |
+
cached = cache.get(demo_key) if demo_key else None
|
| 620 |
+
|
| 621 |
+
if cached:
|
| 622 |
+
context_parts.append(f"Dataset: {cached.get('series_name', 'unknown')}")
|
| 623 |
+
context_parts.append(f"Frequency: {cached.get('frequency', 'unknown')}")
|
| 624 |
+
context_parts.append(f"Confidence score: {cached.get('confidence_score', '?')}/100 ({cached.get('confidence_label', '?')})")
|
| 625 |
+
context_parts.append(f"Trend: {cached.get('trend_pct', 0):+.1f}%")
|
| 626 |
+
context_parts.append(f"Baseline method: {cached.get('baseline_type', '?')}")
|
| 627 |
+
context_parts.append(f"Decision: {cached.get('decision', '')}")
|
| 628 |
+
fc = cached.get("forecast", [])
|
| 629 |
+
if fc:
|
| 630 |
+
context_parts.append("Forecast periods:")
|
| 631 |
+
for p in fc:
|
| 632 |
+
context_parts.append(f" {p['date']}: low={p['low']:.2f}, likely={p['median']:.2f}, high={p['high']:.2f}")
|
| 633 |
+
bl = cached.get("baseline", [])
|
| 634 |
+
if bl:
|
| 635 |
+
context_parts.append("Baseline comparison:")
|
| 636 |
+
for p in bl:
|
| 637 |
+
context_parts.append(f" {p['date']}: {p['value']:.2f}")
|
| 638 |
+
hv = cached.get("history_values", [])
|
| 639 |
+
if hv:
|
| 640 |
+
recent = hv[-8:]
|
| 641 |
+
context_parts.append(f"Recent history (last {len(recent)} values): {[round(v, 2) for v in recent]}")
|
| 642 |
+
context_parts.append(f"Financial series: {cached.get('is_financial', False)}")
|
| 643 |
+
context_parts.append(f"Intermittent demand: {cached.get('is_intermittent', False)}")
|
| 644 |
+
detector = cached.get("_detector")
|
| 645 |
+
if detector and detector.alerts:
|
| 646 |
+
context_parts.append(f"Active CUSUM alerts: {[a['direction'] for a in detector.alerts]}")
|
| 647 |
+
else:
|
| 648 |
+
context_parts.append("No structural shifts detected.")
|
| 649 |
+
|
| 650 |
+
elif session:
|
| 651 |
+
forecast_state = session.get("_forecast", {})
|
| 652 |
+
fc_median = forecast_state.get("median", np.array([]))
|
| 653 |
+
fc_low = forecast_state.get("low", np.array([]))
|
| 654 |
+
fc_high = forecast_state.get("high", np.array([]))
|
| 655 |
+
series = session.get("_series", np.array([]))
|
| 656 |
+
fc_dates = session.get("_fc_dates", [])
|
| 657 |
+
frequency = session.get("frequency", "unknown")
|
| 658 |
+
value_col = session.get("value_col", "value")
|
| 659 |
+
|
| 660 |
+
context_parts.append(f"Series name: {value_col.replace('_', ' ')}")
|
| 661 |
+
context_parts.append(f"Frequency: {frequency}")
|
| 662 |
+
if len(series) > 0:
|
| 663 |
+
context_parts.append(f"Recent history (last 8): {[round(float(v), 2) for v in series[-8:]]}")
|
| 664 |
+
if len(fc_median) > 0:
|
| 665 |
+
context_parts.append("Forecast periods:")
|
| 666 |
+
for i in range(len(fc_median)):
|
| 667 |
+
d = fc_dates[i] if i < len(fc_dates) else f"Period {i+1}"
|
| 668 |
+
lo = float(fc_low[i]) if i < len(fc_low) else 0
|
| 669 |
+
hi = float(fc_high[i]) if i < len(fc_high) else 0
|
| 670 |
+
context_parts.append(f" {d}: low={lo:.2f}, likely={float(fc_median[i]):.2f}, high={hi:.2f}")
|
| 671 |
+
|
| 672 |
+
detector = session.get("_detector")
|
| 673 |
+
if detector and detector.alerts:
|
| 674 |
+
context_parts.append(f"Active CUSUM alerts: {[a['direction'] for a in detector.alerts]}")
|
| 675 |
+
|
| 676 |
+
context_block = "\n".join(context_parts) if context_parts else "No forecast data available yet."
|
| 677 |
+
|
| 678 |
+
_GROQ_KEY = os.getenv("GROQ_API_KEY", "").strip()
|
| 679 |
+
if not _GROQ_KEY:
|
| 680 |
+
# Template fallback
|
| 681 |
+
return ChatResponse(
|
| 682 |
+
reply="Chat requires a Groq API key. Please set GROQ_API_KEY in your .env file. "
|
| 683 |
+
"In the meantime, check the Decision Helper card and Explanation card for insights about your forecast.",
|
| 684 |
+
source="template",
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
try:
|
| 688 |
+
from groq import Groq
|
| 689 |
+
client = Groq(api_key=_GROQ_KEY)
|
| 690 |
+
|
| 691 |
+
# Detect if data uses INR based on context
|
| 692 |
+
uses_inr = any(w in context_block.lower() for w in ['inr', 'rupee', 'rupees', 'βΉ'])
|
| 693 |
+
currency_hint = "Use Indian currency (βΉ) and Indian number formatting." if uses_inr else "Use appropriate units from the data. Do not assume Indian rupees unless the data mentions INR."
|
| 694 |
+
|
| 695 |
+
system_prompt = (
|
| 696 |
+
"You are PulseAI's forecast assistant. You help non-technical users β bakers, farmers, "
|
| 697 |
+
"shop owners β understand their forecast data through simple conversation.\n\n"
|
| 698 |
+
"RULES:\n"
|
| 699 |
+
"- Use everyday language. Never say: model, algorithm, percentile, parameter, coefficient, "
|
| 700 |
+
" neural, transformer, statistical, conformal, CUSUM, inference.\n"
|
| 701 |
+
"- Keep answers to 2-4 sentences unless the user asks for detail.\n"
|
| 702 |
+
"- Always ground your answers in the actual data below β never make up numbers.\n"
|
| 703 |
+
"- If asked what to DO, give one clear, actionable suggestion.\n"
|
| 704 |
+
f"- {currency_hint}\n"
|
| 705 |
+
"- Always refer to the data by its series name (shown below), not a guess.\n"
|
| 706 |
+
"- If you don't know something, say so honestly.\n"
|
| 707 |
+
"- IMPORTANT: If the user reports what actually happened (e.g., 'sales were 3500', "
|
| 708 |
+
"'we sold 2800 this week', 'actual was 1.05', 'it was around 4000'), "
|
| 709 |
+
"acknowledge it and add the tag [ACTUAL:NUMBER] at the very end of your response, "
|
| 710 |
+
"replacing NUMBER with the numeric value. Example: 'Got it! I\\'ll update the forecast. [ACTUAL:3500]'\n\n"
|
| 711 |
+
f"CURRENT FORECAST DATA:\n{context_block}"
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
messages = [{"role": "system", "content": system_prompt}]
|
| 715 |
+
# Add conversation history (last 10 turns max)
|
| 716 |
+
for msg in (req.history or [])[-10:]:
|
| 717 |
+
if msg.get("role") in ("user", "assistant"):
|
| 718 |
+
messages.append({"role": msg["role"], "content": msg["content"]})
|
| 719 |
+
messages.append({"role": "user", "content": req.message})
|
| 720 |
+
|
| 721 |
+
resp = client.chat.completions.create(
|
| 722 |
+
model="llama-3.3-70b-versatile",
|
| 723 |
+
messages=messages,
|
| 724 |
+
max_tokens=250,
|
| 725 |
+
temperature=0.45,
|
| 726 |
+
)
|
| 727 |
+
reply = resp.choices[0].message.content.strip()
|
| 728 |
+
|
| 729 |
+
# Check if the LLM detected an actual value
|
| 730 |
+
import re as _re
|
| 731 |
+
actual_match = _re.search(r'\[ACTUAL:([\d.]+)\]', reply)
|
| 732 |
+
actual_detected = None
|
| 733 |
+
actual_submitted = False
|
| 734 |
+
|
| 735 |
+
if actual_match:
|
| 736 |
+
actual_detected = float(actual_match.group(1))
|
| 737 |
+
# Clean the tag from the visible reply
|
| 738 |
+
reply = _re.sub(r'\s*\[ACTUAL:[\d.]+\]', '', reply).strip()
|
| 739 |
+
# Auto-submit the actual via the update logic
|
| 740 |
+
try:
|
| 741 |
+
parsed_for_update = parse_nl_input(str(actual_detected), last_actual=None)
|
| 742 |
+
if parsed_for_update["value"] is not None and session:
|
| 743 |
+
actual_submitted = True
|
| 744 |
+
except Exception:
|
| 745 |
+
pass
|
| 746 |
+
|
| 747 |
+
return ChatResponse(
|
| 748 |
+
reply=reply,
|
| 749 |
+
source="groq",
|
| 750 |
+
actual_detected=actual_detected,
|
| 751 |
+
actual_submitted=actual_submitted,
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
except Exception as e:
|
| 755 |
+
return ChatResponse(
|
| 756 |
+
reply=f"I'm having trouble connecting right now. Here's what I can tell you: {context_block[:200]}",
|
| 757 |
+
source="template",
|
| 758 |
+
)
|
| 759 |
+
|
| 760 |
+
|
| 761 |
# βββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 762 |
|
| 763 |
def _require_session(session_id: str) -> dict:
|
|
|
|
| 787 |
seasonal_pattern="Pre-computed demo forecast",
|
| 788 |
is_financial=payload.get("is_financial", False),
|
| 789 |
is_intermittent=payload.get("is_intermittent", False),
|
| 790 |
+
history_dates=payload.get("history_dates", []),
|
| 791 |
+
history_values=payload.get("history_values", []),
|
| 792 |
+
frequency=payload.get("frequency", "weekly"),
|
| 793 |
+
cusum_threshold=round(3.0 * payload.get("_hist_std", 1.0), 2) if "_hist_std" in payload else 0.0,
|
| 794 |
)
|
| 795 |
|
| 796 |
|
models.py
CHANGED
|
@@ -69,6 +69,10 @@ class ForecastResponse(BaseModel):
|
|
| 69 |
seasonal_pattern: str # e.g. "Weekly seasonality detected"
|
| 70 |
is_financial: bool
|
| 71 |
is_intermittent: bool
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
# βββ Update (actual value entry + recalibration) ββββββββββββββββββββββββββββββ
|
|
@@ -86,11 +90,17 @@ class UpdateResponse(BaseModel):
|
|
| 86 |
cusum_alert: Literal["HIGH", "LOW", "NONE"]
|
| 87 |
cusum_magnitude: float
|
| 88 |
new_forecast: list[ForecastPoint]
|
|
|
|
| 89 |
new_confidence_score: int
|
| 90 |
new_confidence_label: Literal["Low", "Medium", "High", "Very High"]
|
| 91 |
new_decision: str
|
| 92 |
-
recalibrated: bool
|
| 93 |
-
explanation: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
# βββ Scenario βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -126,8 +136,23 @@ class HealthResponse(BaseModel):
|
|
| 126 |
uptime_seconds: int
|
| 127 |
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# βββ Error ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 130 |
|
| 131 |
class ErrorResponse(BaseModel):
|
| 132 |
error_code: str # e.g. "TOO_FEW_ROWS"
|
| 133 |
-
message: str # friendly message shown to the user
|
|
|
|
| 69 |
seasonal_pattern: str # e.g. "Weekly seasonality detected"
|
| 70 |
is_financial: bool
|
| 71 |
is_intermittent: bool
|
| 72 |
+
history_dates: list[str] = []
|
| 73 |
+
history_values: list[float] = []
|
| 74 |
+
frequency: str = "weekly"
|
| 75 |
+
cusum_threshold: float = 0.0
|
| 76 |
|
| 77 |
|
| 78 |
# βββ Update (actual value entry + recalibration) ββββββββββββββββββββββββββββββ
|
|
|
|
| 90 |
cusum_alert: Literal["HIGH", "LOW", "NONE"]
|
| 91 |
cusum_magnitude: float
|
| 92 |
new_forecast: list[ForecastPoint]
|
| 93 |
+
new_baseline: list[BaselinePoint] = []
|
| 94 |
new_confidence_score: int
|
| 95 |
new_confidence_label: Literal["Low", "Medium", "High", "Very High"]
|
| 96 |
new_decision: str
|
| 97 |
+
recalibrated: bool
|
| 98 |
+
explanation: str
|
| 99 |
+
residual: float = 0.0
|
| 100 |
+
predicted_value: float = 0.0 # what was forecast for this period
|
| 101 |
+
history_dates: list[str] = []
|
| 102 |
+
history_values: list[float] = []
|
| 103 |
+
frequency: str = "weekly"
|
| 104 |
|
| 105 |
|
| 106 |
# βββ Scenario βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 136 |
uptime_seconds: int
|
| 137 |
|
| 138 |
|
| 139 |
+
# βββ Chat βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
|
| 141 |
+
class ChatRequest(BaseModel):
|
| 142 |
+
session_id: str
|
| 143 |
+
message: str
|
| 144 |
+
history: list[dict] = [] # [{role, content}, ...]
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ChatResponse(BaseModel):
|
| 148 |
+
reply: str
|
| 149 |
+
source: Literal["groq", "template"]
|
| 150 |
+
actual_detected: float | None = None # if user reported an actual via chat
|
| 151 |
+
actual_submitted: bool = False # True if the actual was auto-submitted
|
| 152 |
+
|
| 153 |
+
|
| 154 |
# βββ Error ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
|
| 156 |
class ErrorResponse(BaseModel):
|
| 157 |
error_code: str # e.g. "TOO_FEW_ROWS"
|
| 158 |
+
message: str # friendly message shown to the user
|
preprocessor.py
CHANGED
|
@@ -411,12 +411,18 @@ def _detect_frequency(df: pd.DataFrame, date_col: str) -> str:
|
|
| 411 |
deltas = df[date_col].diff().dropna().dt.days
|
| 412 |
median_gap = deltas.median()
|
| 413 |
|
|
|
|
|
|
|
| 414 |
if median_gap <= 1.5:
|
| 415 |
return "daily"
|
| 416 |
if median_gap <= 8:
|
| 417 |
return "weekly"
|
| 418 |
if median_gap <= 35:
|
| 419 |
return "monthly"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
return "unknown"
|
| 421 |
|
| 422 |
|
|
|
|
| 411 |
deltas = df[date_col].diff().dropna().dt.days
|
| 412 |
median_gap = deltas.median()
|
| 413 |
|
| 414 |
+
if median_gap < 0.1:
|
| 415 |
+
return "hourly"
|
| 416 |
if median_gap <= 1.5:
|
| 417 |
return "daily"
|
| 418 |
if median_gap <= 8:
|
| 419 |
return "weekly"
|
| 420 |
if median_gap <= 35:
|
| 421 |
return "monthly"
|
| 422 |
+
if median_gap <= 100:
|
| 423 |
+
return "quarterly"
|
| 424 |
+
if median_gap <= 400:
|
| 425 |
+
return "annually"
|
| 426 |
return "unknown"
|
| 427 |
|
| 428 |
|
test_chronos.py
DELETED
|
@@ -1,123 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
test_chronos.py β PulseAI Chronos-Bolt-Small Smoke Test
|
| 3 |
-
Run: python test_chronos.py
|
| 4 |
-
Expected: prints forecast arrays, PASS at end
|
| 5 |
-
"""
|
| 6 |
-
|
| 7 |
-
import sys
|
| 8 |
-
import time
|
| 9 |
-
import numpy as np
|
| 10 |
-
|
| 11 |
-
print("=" * 60)
|
| 12 |
-
print("PulseAI β Chronos-Bolt-Small Smoke Test")
|
| 13 |
-
print("=" * 60)
|
| 14 |
-
|
| 15 |
-
# ββ 1. Import check ββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
-
print("\n[1/5] Importing chronos...")
|
| 17 |
-
try:
|
| 18 |
-
import torch
|
| 19 |
-
from chronos import BaseChronosPipeline
|
| 20 |
-
print(f" β torch {torch.__version__}")
|
| 21 |
-
print(f" β chronos imported")
|
| 22 |
-
except ImportError as e:
|
| 23 |
-
print(f" β Import failed: {e}")
|
| 24 |
-
sys.exit(1)
|
| 25 |
-
|
| 26 |
-
# ββ 2. Model load βββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
-
print("\n[2/5] Loading chronos-bolt-small...")
|
| 28 |
-
print(" (Cached from previous run β should be fast now)")
|
| 29 |
-
t0 = time.time()
|
| 30 |
-
try:
|
| 31 |
-
pipeline = BaseChronosPipeline.from_pretrained(
|
| 32 |
-
"amazon/chronos-bolt-small",
|
| 33 |
-
device_map="cpu",
|
| 34 |
-
dtype=torch.float32, # fixed: dtype not torch_dtype
|
| 35 |
-
)
|
| 36 |
-
elapsed = time.time() - t0
|
| 37 |
-
print(f" β Model loaded in {elapsed:.1f}s")
|
| 38 |
-
print(f" β Pipeline type: {type(pipeline).__name__}")
|
| 39 |
-
except Exception as e:
|
| 40 |
-
print(f" β Model load failed: {e}")
|
| 41 |
-
sys.exit(1)
|
| 42 |
-
|
| 43 |
-
# ββ 3. Inference β Chronos-Bolt uses quantile_levels ββββββββββ
|
| 44 |
-
print("\n[3/5] Running inference on 20-point series...")
|
| 45 |
-
print(" Note: Chronos-Bolt outputs quantiles directly (no num_samples)")
|
| 46 |
-
context_values = [
|
| 47 |
-
3200, 3100, 3400, 3300, 3500, 3200, 3100,
|
| 48 |
-
4200, 4500, 4100, 3900, 3300, 3200, 3100,
|
| 49 |
-
3500, 3600, 3400, 3200, 3100, 3300
|
| 50 |
-
]
|
| 51 |
-
context = torch.tensor(context_values, dtype=torch.float32).unsqueeze(0)
|
| 52 |
-
|
| 53 |
-
try:
|
| 54 |
-
t0 = time.time()
|
| 55 |
-
quantile_levels = [0.1, 0.5, 0.9]
|
| 56 |
-
quantiles, mean = pipeline.predict_quantiles(
|
| 57 |
-
context=context,
|
| 58 |
-
prediction_length=4,
|
| 59 |
-
quantile_levels=quantile_levels,
|
| 60 |
-
)
|
| 61 |
-
elapsed = time.time() - t0
|
| 62 |
-
print(f" β Inference done in {elapsed:.2f}s")
|
| 63 |
-
print(f" β Quantiles shape: {quantiles.shape}") # [1, 4, 3]
|
| 64 |
-
print(f" β Mean shape: {mean.shape}") # [1, 4]
|
| 65 |
-
except Exception as e:
|
| 66 |
-
print(f" β Inference failed: {e}")
|
| 67 |
-
sys.exit(1)
|
| 68 |
-
|
| 69 |
-
# ββ 4. Percentile extraction ββββββββββββββββββββββββββββββββββ
|
| 70 |
-
print("\n[4/5] Extracting 10th / 50th / 90th percentiles...")
|
| 71 |
-
try:
|
| 72 |
-
q = quantiles[0].numpy() # shape: [4, 3] β [timestep, quantile]
|
| 73 |
-
m = mean[0].numpy() # shape: [4]
|
| 74 |
-
|
| 75 |
-
low = q[:, 0] # 10th percentile
|
| 76 |
-
median = q[:, 1] # 50th percentile
|
| 77 |
-
high = q[:, 2] # 90th percentile
|
| 78 |
-
|
| 79 |
-
for i in range(4):
|
| 80 |
-
print(f" β Week {i+1}: {low[i]:.0f} β {median[i]:.0f} β {high[i]:.0f} "
|
| 81 |
-
f"(mean: {m[i]:.0f})")
|
| 82 |
-
|
| 83 |
-
assert all(low < high), "low must be < high"
|
| 84 |
-
assert all(low > 0), "values should be positive"
|
| 85 |
-
print(" β Sanity checks passed")
|
| 86 |
-
except Exception as e:
|
| 87 |
-
print(f" β Percentile extraction failed: {e}")
|
| 88 |
-
sys.exit(1)
|
| 89 |
-
|
| 90 |
-
# ββ 5. NaN guard test βββββββββββββββββββββββββββββββββββββββββ
|
| 91 |
-
print("\n[5/5] Testing NaN guard (gap in data)...")
|
| 92 |
-
try:
|
| 93 |
-
values_with_gap = context_values.copy()
|
| 94 |
-
values_with_gap[10] = float("nan")
|
| 95 |
-
|
| 96 |
-
arr = np.array(values_with_gap, dtype=np.float64)
|
| 97 |
-
nan_mask = np.isnan(arr)
|
| 98 |
-
if nan_mask.any():
|
| 99 |
-
idx = np.arange(len(arr))
|
| 100 |
-
arr[nan_mask] = np.interp(
|
| 101 |
-
idx[nan_mask], idx[~nan_mask], arr[~nan_mask]
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
assert not np.isnan(arr).any(), "NaN guard failed"
|
| 105 |
-
ctx_clean = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)
|
| 106 |
-
|
| 107 |
-
q2, m2 = pipeline.predict_quantiles(
|
| 108 |
-
context=ctx_clean,
|
| 109 |
-
prediction_length=4,
|
| 110 |
-
quantile_levels=[0.1, 0.5, 0.9],
|
| 111 |
-
)
|
| 112 |
-
print(f" β NaN guard works β quantiles shape: {q2.shape}")
|
| 113 |
-
except Exception as e:
|
| 114 |
-
print(f" β NaN guard test failed: {e}")
|
| 115 |
-
sys.exit(1)
|
| 116 |
-
|
| 117 |
-
print("\n" + "=" * 60)
|
| 118 |
-
print(" β
ALL TESTS PASSED β Ready to build Step 3")
|
| 119 |
-
print("=" * 60)
|
| 120 |
-
print("\nKEY FINDING: Chronos-Bolt uses predict_quantiles()")
|
| 121 |
-
print(" Input: context tensor + prediction_length + quantile_levels")
|
| 122 |
-
print(" Output: quantiles [batch, timesteps, n_quantiles] + mean [batch, timesteps]")
|
| 123 |
-
print(" This is BETTER than num_samples β direct quantiles, faster, more stable")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|