r-bansal commited on
Commit
ba65d00
Β·
1 Parent(s): 13a5236

feat: add React frontend, rolling forecast, AI chat, CUSUM tuning, CSV download

Browse files

- Full React 19 dashboard: forecast chart, shift detection, scenarios, chat panel
- GSAP-animated landing page with demo routes (/app?demo=bakery|crop|m5)
- Rolling window: actuals re-run Chronos and shift forecast forward
- Chat via Groq: context-grounded, auto-detects actual values from conversation
- CSV download with original data + appended actuals with timestamps
- Frequency support: hourly, daily, weekly, monthly, quarterly, annually
- CUSUM sensitivity tuned (2x std), predicted vs actual markers in history
- 27 bug fixes, removed __pycache__ from tracking, moved tests to tests/

Signed-off-by: this-is-rachit <rachitbansal023@gmail.com>

__pycache__/baseline.cpython-313.pyc DELETED
Binary file (6.12 kB)
 
__pycache__/cache.cpython-313.pyc DELETED
Binary file (7.4 kB)
 
__pycache__/calibrator.cpython-313.pyc DELETED
Binary file (4.03 kB)
 
__pycache__/confidence.cpython-313.pyc DELETED
Binary file (1.69 kB)
 
__pycache__/decision.cpython-313.pyc DELETED
Binary file (1.79 kB)
 
__pycache__/detector.cpython-313.pyc DELETED
Binary file (5.98 kB)
 
__pycache__/explainer.cpython-313.pyc DELETED
Binary file (13.5 kB)
 
__pycache__/forecaster.cpython-313.pyc DELETED
Binary file (3.67 kB)
 
__pycache__/main.cpython-313.pyc DELETED
Binary file (21.1 kB)
 
__pycache__/models.cpython-313.pyc DELETED
Binary file (5.62 kB)
 
__pycache__/preprocessor.cpython-313.pyc DELETED
Binary file (18.4 kB)
 
__pycache__/scenario.cpython-313.pyc DELETED
Binary file (2.71 kB)
 
detector.py CHANGED
@@ -35,8 +35,8 @@ class CUSUMDetector:
35
  def __init__(self, historical_std: float):
36
  # Drift and threshold are derived from the historical series volatility
37
  # so they automatically scale to whatever units the user uploads.
38
- self.drift = 0.5 * historical_std
39
- self.threshold = 3.0 * historical_std
40
  self.s_high = 0.0
41
  self.s_low = 0.0
42
  self.last_alert_t = -COOLDOWN_PERIODS # allow an alert on the very first step
 
35
  def __init__(self, historical_std: float):
36
  # Drift and threshold are derived from the historical series volatility
37
  # so they automatically scale to whatever units the user uploads.
38
+ self.drift = 0.3 * historical_std
39
+ self.threshold = 2.0 * historical_std
40
  self.s_high = 0.0
41
  self.s_low = 0.0
42
  self.last_alert_t = -COOLDOWN_PERIODS # allow an alert on the very first step
explainer.py CHANGED
@@ -199,13 +199,17 @@ def _regex_parse(text: str, last_actual: float | None) -> dict:
199
  if last_actual is not None:
200
  factor = (1 - pct / 100) if down else (1 + pct / 100)
201
  return _result(last_actual * factor, relative=True, approximate=is_approximate)
202
- # No last_actual β€” treat as absolute value if small enough to be a plain number
 
 
 
 
203
  if pct < 10000:
204
  return _result(pct, approximate=is_approximate)
205
 
206
  # ── Relative: went up/down by absolute amount ──────────────────────────
207
  up_match = re.search(
208
- r"(?:went up|increased?|badhke?|upar|zyada)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
209
  )
210
  down_match = re.search(
211
  r"(?:went down|decreased?|dropped?|girak?|kam|niche)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
@@ -255,7 +259,7 @@ def _parse_number_str(text: str) -> float | None:
255
  # Strip prefix approximate/directional words
256
  t = re.sub(r"^(around|about|roughly|approximately|lagbhag|almost|upto|up to)\s*", "", t)
257
  # Strip trailing unit words and directional tokens
258
- t = re.sub(r"\s*(rupee|rupees|rs|inr|kg|units?|pieces?|up|down|more|less|zyada|kam)s?\s*$", "", t)
259
  t = t.strip()
260
 
261
  # Extract first clean number from remaining string
 
199
  if last_actual is not None:
200
  factor = (1 - pct / 100) if down else (1 + pct / 100)
201
  return _result(last_actual * factor, relative=True, approximate=is_approximate)
202
+ # Has relative words (jyada, kam, more, less) β€” needs a previous value
203
+ relative_words = ["jyada", "zyada", "more", "kam", "less", "up", "down", "increase", "decrease"]
204
+ if any(w in t for w in relative_words):
205
+ return _error("We need a previous value to calculate the percentage. Please enter the number directly.")
206
+ # No relative words β€” treat as absolute value
207
  if pct < 10000:
208
  return _result(pct, approximate=is_approximate)
209
 
210
  # ── Relative: went up/down by absolute amount ──────────────────────────
211
  up_match = re.search(
212
+ r"(?:went up|increased?|badhke?|upar|zyada|jyada)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
213
  )
214
  down_match = re.search(
215
  r"(?:went down|decreased?|dropped?|girak?|kam|niche)\s+(?:by\s+)?([\d,. ]+(?:lakh|crore)?)", t
 
259
  # Strip prefix approximate/directional words
260
  t = re.sub(r"^(around|about|roughly|approximately|lagbhag|almost|upto|up to)\s*", "", t)
261
  # Strip trailing unit words and directional tokens
262
+ t = re.sub(r"\s*(rupee|rupees|rs|inr|kg|units?|pieces?|up|down|more|less|zyada|jyada|kam)s?\s*$", "", t)
263
  t = t.strip()
264
 
265
  # Extract first clean number from remaining string
main.py CHANGED
@@ -27,6 +27,8 @@ from models import (
27
  ForecastPoint,
28
  BaselinePoint,
29
  HealthResponse,
 
 
30
  ScenarioRequest,
31
  ScenarioResponse,
32
  UpdateRequest,
@@ -113,6 +115,81 @@ def health():
113
  )
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # ─── /upload ──────────────────────────────────────────────────────────────────
117
 
118
  @app.post("/upload", response_model=UploadResponse)
@@ -188,6 +265,7 @@ def forecast(req: ForecastRequest):
188
  session["_forecast"] = {"low": cal_low, "median": raw["median"], "high": cal_high}
189
  session["_horizon"] = horizon
190
  session["_fc_dates"] = fc_dates
 
191
 
192
  seasonal = _describe_seasonal(prepared["warnings"])
193
 
@@ -214,6 +292,10 @@ def forecast(req: ForecastRequest):
214
  seasonal_pattern=seasonal,
215
  is_financial=prepared["is_financial"],
216
  is_intermittent=prepared["is_intermittent"],
 
 
 
 
217
  )
218
 
219
 
@@ -242,15 +324,20 @@ def update(req: UpdateRequest):
242
  and fc_low[idx] <= actual_value <= fc_high[idx]
243
  )
244
 
245
- # CUSUM update
246
- detector = session.get("_detector")
247
  residual = actual_value - float(fc_median[idx]) if len(fc_median) > idx else 0.0
 
248
 
 
 
249
  series = session.get("_series", np.array([]))
250
  date_col = session.get("date_col", "date")
251
  value_col = session.get("value_col")
252
  df = session.get("df")
253
  fc_dates = session.get("_fc_dates", [])
 
 
 
254
  actual_date = None
255
  if fc_dates and idx < len(fc_dates):
256
  try:
@@ -275,38 +362,46 @@ def update(req: UpdateRequest):
275
  new_alpha = update_alpha(current_alpha, actual_value, float(fc_low[idx]), float(fc_high[idx]))
276
  session["_alpha"] = new_alpha
277
 
278
- # Re-run Chronos if CUSUM fired a structural shift
279
- recalibrated = False
280
- if cusum_result["direction"] != "NONE" and len(series) > 0:
281
- # Append the new actual to the series and re-forecast
282
- updated_series = np.append(series, actual_value)
283
- session["_series"] = updated_series
284
- frequency = session.get("frequency", "weekly")
285
- horizon = session.get("_horizon", 4)
286
-
287
- raw = forecaster.run_forecast(updated_series, horizon, frequency)
288
- from calibrator import calibrate as _calibrate
289
- cal = _calibrate(updated_series, raw["low"], raw["high"])
290
- fc_low = cal["calibrated_low"]
291
- fc_high = cal["calibrated_high"]
292
- fc_median = raw["median"]
293
- session["_forecast"] = {"low": fc_low, "median": fc_median, "high": fc_high}
294
-
295
- last_date = session["df"][date_col].iloc[-1]
296
- new_fc_dates = _make_forecast_dates(last_date, horizon, frequency)
297
- session["_fc_dates"] = new_fc_dates
298
- recalibrated = True
299
-
300
- hist_std = session.get("_hist_std", 1.0)
301
- if len(fc_low) > 0:
302
- score, label = compute_confidence(fc_low, fc_high, hist_std)
303
- else:
304
- score, label = 50, "Medium"
 
 
 
 
 
 
 
 
 
 
305
 
306
- trend_pct = 0.0
307
- if len(series) > 0 and len(fc_median) > 0:
308
- last_val = float(np.append(series, actual_value)[-1])
309
- trend_pct = float((fc_median[-1] - last_val) / (abs(last_val) + 1e-9) * 100)
310
 
311
  new_decision = get_decision(
312
  trend_pct=trend_pct,
@@ -320,8 +415,6 @@ def update(req: UpdateRequest):
320
  else getattr(session.get("warnings"), "intermittent", False),
321
  )
322
 
323
- horizon = session.get("_horizon", 4)
324
- fc_dates = session.get("_fc_dates", [])
325
  explanation_text, _ = explain(
326
  trend_pct=trend_pct,
327
  confidence=score,
@@ -332,7 +425,7 @@ def update(req: UpdateRequest):
332
 
333
  new_forecast = [
334
  ForecastPoint(
335
- date=fc_dates[i] if i < len(fc_dates) else "",
336
  low=round(float(fc_low[i]), 2) if i < len(fc_low) else 0.0,
337
  median=round(float(fc_median[i]), 2) if i < len(fc_median) else 0.0,
338
  high=round(float(fc_high[i]), 2) if i < len(fc_high) else 0.0,
@@ -340,6 +433,21 @@ def update(req: UpdateRequest):
340
  for i in range(horizon)
341
  ]
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  return UpdateResponse(
344
  parsed_value=actual_value,
345
  is_approximate=parsed["is_approximate"],
@@ -347,11 +455,67 @@ def update(req: UpdateRequest):
347
  cusum_alert=cusum_result["direction"],
348
  cusum_magnitude=round(float(cusum_result["magnitude"]), 2),
349
  new_forecast=new_forecast,
 
350
  new_confidence_score=score,
351
  new_confidence_label=label,
352
  new_decision=new_decision,
353
- recalibrated=recalibrated,
354
  explanation=explanation_text,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  )
356
 
357
 
@@ -430,6 +594,170 @@ def explain_endpoint(request: Request, req: ExplainRequest):
430
  return ExplainResponse(explanation=text, source=source)
431
 
432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  # ─── Helpers ──────────────────────────────────────────────────────────────────
434
 
435
  def _require_session(session_id: str) -> dict:
@@ -459,6 +787,10 @@ def _forecast_from_cache(payload: dict) -> ForecastResponse:
459
  seasonal_pattern="Pre-computed demo forecast",
460
  is_financial=payload.get("is_financial", False),
461
  is_intermittent=payload.get("is_intermittent", False),
 
 
 
 
462
  )
463
 
464
 
 
27
  ForecastPoint,
28
  BaselinePoint,
29
  HealthResponse,
30
+ ChatRequest,
31
+ ChatResponse,
32
  ScenarioRequest,
33
  ScenarioResponse,
34
  UpdateRequest,
 
115
  )
116
 
117
 
118
+ # ─── /demo ────────────────────────────────────────────────────────────────────
119
+
120
+ DEMO_META = {
121
+ "bakery": {
122
+ "filename": "bakery_sales.csv",
123
+ "date_col": "date",
124
+ "value_col": "weekly_sales_inr",
125
+ "columns": ["date", "weekly_sales_inr"],
126
+ },
127
+ "crop": {
128
+ "filename": "crop_prices_sample.csv",
129
+ "date_col": "date",
130
+ "value_col": "wheat_price_inr_per_quintal",
131
+ "columns": ["date", "wheat_price_inr_per_quintal"],
132
+ },
133
+ "m5": {
134
+ "filename": "walmart_m5_sample.csv",
135
+ "date_col": "date",
136
+ "value_col": "FOODS_1",
137
+ "columns": ["date", "FOODS_1", "HOBBIES_1", "HOUSEHOLD_1"],
138
+ },
139
+ }
140
+
141
+
142
+ @app.get("/demo/{key}")
143
+ def demo(key: str):
144
+ if key not in DEMO_META:
145
+ raise HTTPException(status_code=404, detail={
146
+ "error_code": "DEMO_NOT_FOUND",
147
+ "message": f"Unknown demo dataset: {key}. Available: bakery, crop, m5",
148
+ })
149
+
150
+ cached = cache.get(key)
151
+ if not cached:
152
+ raise HTTPException(status_code=503, detail={
153
+ "error_code": "DEMO_NOT_CACHED",
154
+ "message": "Demo cache not built. Run 'python cache.py' first.",
155
+ })
156
+
157
+ meta = DEMO_META[key]
158
+ session_id = f"demo_{key}"
159
+
160
+ # Build preview from cached history
161
+ hist_dates = cached.get("history_dates", [])
162
+ hist_values = cached.get("history_values", [])
163
+ preview = [
164
+ {"date": hist_dates[i], "value": hist_values[i]}
165
+ for i in range(min(5, len(hist_dates)))
166
+ ]
167
+
168
+ forecast_resp = _forecast_from_cache(cached)
169
+
170
+ return {
171
+ "session_id": session_id,
172
+ "upload": {
173
+ "session_id": session_id,
174
+ "detected_date_col": meta["date_col"],
175
+ "detected_value_col": meta["value_col"],
176
+ "columns": meta["columns"],
177
+ "series_list": [],
178
+ "preview": preview,
179
+ "frequency": cached.get("frequency", "weekly"),
180
+ "n_rows": len(hist_values),
181
+ "outliers": [],
182
+ "warnings": {
183
+ "intermittent": cached.get("is_intermittent", False),
184
+ "non_stationary": cached.get("is_financial", False),
185
+ "short_series": False,
186
+ "large_gaps": False,
187
+ },
188
+ },
189
+ "forecast": forecast_resp.model_dump(),
190
+ }
191
+
192
+
193
  # ─── /upload ──────────────────────────────────────────────────────────────────
194
 
195
  @app.post("/upload", response_model=UploadResponse)
 
265
  session["_forecast"] = {"low": cal_low, "median": raw["median"], "high": cal_high}
266
  session["_horizon"] = horizon
267
  session["_fc_dates"] = fc_dates
268
+ session["_history_dates"] = prepared["dates"]
269
 
270
  seasonal = _describe_seasonal(prepared["warnings"])
271
 
 
292
  seasonal_pattern=seasonal,
293
  is_financial=prepared["is_financial"],
294
  is_intermittent=prepared["is_intermittent"],
295
+ history_dates=[d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d) for d in prepared["dates"][-52:]],
296
+ history_values=[round(float(v), 2) for v in series[-52:]],
297
+ frequency=frequency,
298
+ cusum_threshold=round(2.0 * hist_std, 2),
299
  )
300
 
301
 
 
324
  and fc_low[idx] <= actual_value <= fc_high[idx]
325
  )
326
 
327
+ # Residual for CUSUM
 
328
  residual = actual_value - float(fc_median[idx]) if len(fc_median) > idx else 0.0
329
+ predicted_value = float(fc_median[idx]) if len(fc_median) > idx else 0.0
330
 
331
+ # CUSUM update
332
+ detector = session.get("_detector")
333
  series = session.get("_series", np.array([]))
334
  date_col = session.get("date_col", "date")
335
  value_col = session.get("value_col")
336
  df = session.get("df")
337
  fc_dates = session.get("_fc_dates", [])
338
+ frequency = session.get("frequency", "weekly")
339
+ horizon = session.get("_horizon", 4)
340
+
341
  actual_date = None
342
  if fc_dates and idx < len(fc_dates):
343
  try:
 
362
  new_alpha = update_alpha(current_alpha, actual_value, float(fc_low[idx]), float(fc_high[idx]))
363
  session["_alpha"] = new_alpha
364
 
365
+ # --- Rolling window: always append actual and re-forecast ---
366
+ updated_series = np.append(series, actual_value)
367
+ session["_series"] = updated_series
368
+
369
+ # Track actuals with timestamps for CSV download
370
+ actual_entry = {
371
+ "date": fc_dates[idx] if idx < len(fc_dates) else str(pd.Timestamp.now().date()),
372
+ "value": actual_value,
373
+ "timestamp": pd.Timestamp.now().isoformat(),
374
+ }
375
+ if "_actuals_log" not in session:
376
+ session["_actuals_log"] = []
377
+ session["_actuals_log"].append(actual_entry)
378
+
379
+ # Re-run Chronos with the updated series
380
+ raw = forecaster.run_forecast(updated_series, horizon, frequency)
381
+ cal = calibrate(updated_series, raw["low"], raw["high"])
382
+ fc_low = cal["calibrated_low"]
383
+ fc_high = cal["calibrated_high"]
384
+ fc_median = raw["median"]
385
+ session["_forecast"] = {"low": fc_low, "median": fc_median, "high": fc_high}
386
+ session["_baseline"] = {"values": raw["baseline"], "type": raw["baseline_type"]}
387
+
388
+ # Generate new forecast dates from the last actual's date
389
+ last_known_date = actual_date or pd.Timestamp(fc_dates[-1]) if fc_dates else pd.Timestamp.now()
390
+ new_fc_dates = _make_forecast_dates(last_known_date, horizon, frequency)
391
+ session["_fc_dates"] = new_fc_dates
392
+
393
+ # Update history dates
394
+ hist_dates = list(session.get("_history_dates", []))
395
+ if actual_entry["date"] not in hist_dates:
396
+ hist_dates.append(actual_entry["date"])
397
+ session["_history_dates"] = hist_dates
398
+
399
+ hist_std = session.get("_hist_std", float(np.std(updated_series)))
400
+ session["_hist_std"] = hist_std
401
+ score, label = compute_confidence(fc_low, fc_high, hist_std)
402
 
403
+ last_val = float(updated_series[-1])
404
+ trend_pct = float((fc_median[-1] - last_val) / (abs(last_val) + 1e-9) * 100)
 
 
405
 
406
  new_decision = get_decision(
407
  trend_pct=trend_pct,
 
415
  else getattr(session.get("warnings"), "intermittent", False),
416
  )
417
 
 
 
418
  explanation_text, _ = explain(
419
  trend_pct=trend_pct,
420
  confidence=score,
 
425
 
426
  new_forecast = [
427
  ForecastPoint(
428
+ date=new_fc_dates[i] if i < len(new_fc_dates) else "",
429
  low=round(float(fc_low[i]), 2) if i < len(fc_low) else 0.0,
430
  median=round(float(fc_median[i]), 2) if i < len(fc_median) else 0.0,
431
  high=round(float(fc_high[i]), 2) if i < len(fc_high) else 0.0,
 
433
  for i in range(horizon)
434
  ]
435
 
436
+ new_baseline = [
437
+ BaselinePoint(
438
+ date=new_fc_dates[i] if i < len(new_fc_dates) else "",
439
+ value=round(float(raw["baseline"][i]), 2) if i < len(raw["baseline"]) else 0.0,
440
+ )
441
+ for i in range(horizon)
442
+ ]
443
+
444
+ # Build full history for frontend
445
+ all_dates = [d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d)
446
+ for d in session.get("_history_dates", [])]
447
+ # Fallback: use series indices if no dates tracked
448
+ if not all_dates:
449
+ all_dates = [f"T-{len(updated_series)-i}" for i in range(len(updated_series), 0, -1)]
450
+
451
  return UpdateResponse(
452
  parsed_value=actual_value,
453
  is_approximate=parsed["is_approximate"],
 
455
  cusum_alert=cusum_result["direction"],
456
  cusum_magnitude=round(float(cusum_result["magnitude"]), 2),
457
  new_forecast=new_forecast,
458
+ new_baseline=new_baseline,
459
  new_confidence_score=score,
460
  new_confidence_label=label,
461
  new_decision=new_decision,
462
+ recalibrated=True,
463
  explanation=explanation_text,
464
+ residual=round(residual, 2),
465
+ predicted_value=round(predicted_value, 2),
466
+ history_dates=[d.strftime("%Y-%m-%d") if hasattr(d, "strftime") else str(d)
467
+ for d in session.get("_history_dates", hist_dates)][-52:],
468
+ history_values=[round(float(v), 2) for v in updated_series[-52:]],
469
+ frequency=frequency,
470
+ )
471
+
472
+
473
+ # ─── /download ────────────────────────────────────────────────────────────────
474
+
475
+ @app.get("/download/{session_id}")
476
+ def download_csv(session_id: str):
477
+ from fastapi.responses import StreamingResponse
478
+ import io
479
+
480
+ session = _require_session(session_id)
481
+ series = session.get("_series", np.array([]))
482
+ date_col = session.get("date_col", "date")
483
+ value_col = session.get("value_col", "value")
484
+ actuals_log = session.get("_actuals_log", [])
485
+ df_original = session.get("df")
486
+
487
+ if df_original is not None:
488
+ # Start with original data
489
+ out_df = df_original[[date_col, value_col]].copy()
490
+ out_df["entry_timestamp"] = None
491
+ out_df["source"] = "original"
492
+
493
+ # Append actuals
494
+ for entry in actuals_log:
495
+ new_row = {
496
+ date_col: entry["date"],
497
+ value_col: entry["value"],
498
+ "entry_timestamp": entry["timestamp"],
499
+ "source": "actual_entry",
500
+ }
501
+ out_df = pd.concat([out_df, pd.DataFrame([new_row])], ignore_index=True)
502
+ else:
503
+ # Fallback: just export what we have
504
+ out_df = pd.DataFrame({
505
+ date_col: [e["date"] for e in actuals_log],
506
+ value_col: [e["value"] for e in actuals_log],
507
+ "entry_timestamp": [e["timestamp"] for e in actuals_log],
508
+ "source": "actual_entry",
509
+ })
510
+
511
+ buffer = io.StringIO()
512
+ out_df.to_csv(buffer, index=False)
513
+ buffer.seek(0)
514
+
515
+ return StreamingResponse(
516
+ iter([buffer.getvalue()]),
517
+ media_type="text/csv",
518
+ headers={"Content-Disposition": f"attachment; filename=pulseai_updated_{session_id[:8]}.csv"},
519
  )
520
 
521
 
 
594
  return ExplainResponse(explanation=text, source=source)
595
 
596
 
597
+ # ─── /chat ────────────────────────────────────────────────────────────────────
598
+
599
+ @app.post("/chat", response_model=ChatResponse)
600
+ def chat(request: Request, req: ChatRequest):
601
+ ip = request.client.host if request.client else "unknown"
602
+ if _is_rate_limited(f"chat:{ip}", max_calls=30, window_sec=60):
603
+ raise HTTPException(status_code=429, detail={
604
+ "error_code": "RATE_LIMITED",
605
+ "message": "Too many requests. Please wait a minute before trying again."
606
+ })
607
+
608
+ # Build context from session state
609
+ context_parts = []
610
+ session = None
611
+
612
+ try:
613
+ session = get_session(req.session_id)
614
+ except Exception:
615
+ pass
616
+
617
+ # If we have a demo cache, use that
618
+ demo_key = _demo_key(req.session_id) if req.session_id else None
619
+ cached = cache.get(demo_key) if demo_key else None
620
+
621
+ if cached:
622
+ context_parts.append(f"Dataset: {cached.get('series_name', 'unknown')}")
623
+ context_parts.append(f"Frequency: {cached.get('frequency', 'unknown')}")
624
+ context_parts.append(f"Confidence score: {cached.get('confidence_score', '?')}/100 ({cached.get('confidence_label', '?')})")
625
+ context_parts.append(f"Trend: {cached.get('trend_pct', 0):+.1f}%")
626
+ context_parts.append(f"Baseline method: {cached.get('baseline_type', '?')}")
627
+ context_parts.append(f"Decision: {cached.get('decision', '')}")
628
+ fc = cached.get("forecast", [])
629
+ if fc:
630
+ context_parts.append("Forecast periods:")
631
+ for p in fc:
632
+ context_parts.append(f" {p['date']}: low={p['low']:.2f}, likely={p['median']:.2f}, high={p['high']:.2f}")
633
+ bl = cached.get("baseline", [])
634
+ if bl:
635
+ context_parts.append("Baseline comparison:")
636
+ for p in bl:
637
+ context_parts.append(f" {p['date']}: {p['value']:.2f}")
638
+ hv = cached.get("history_values", [])
639
+ if hv:
640
+ recent = hv[-8:]
641
+ context_parts.append(f"Recent history (last {len(recent)} values): {[round(v, 2) for v in recent]}")
642
+ context_parts.append(f"Financial series: {cached.get('is_financial', False)}")
643
+ context_parts.append(f"Intermittent demand: {cached.get('is_intermittent', False)}")
644
+ detector = cached.get("_detector")
645
+ if detector and detector.alerts:
646
+ context_parts.append(f"Active CUSUM alerts: {[a['direction'] for a in detector.alerts]}")
647
+ else:
648
+ context_parts.append("No structural shifts detected.")
649
+
650
+ elif session:
651
+ forecast_state = session.get("_forecast", {})
652
+ fc_median = forecast_state.get("median", np.array([]))
653
+ fc_low = forecast_state.get("low", np.array([]))
654
+ fc_high = forecast_state.get("high", np.array([]))
655
+ series = session.get("_series", np.array([]))
656
+ fc_dates = session.get("_fc_dates", [])
657
+ frequency = session.get("frequency", "unknown")
658
+ value_col = session.get("value_col", "value")
659
+
660
+ context_parts.append(f"Series name: {value_col.replace('_', ' ')}")
661
+ context_parts.append(f"Frequency: {frequency}")
662
+ if len(series) > 0:
663
+ context_parts.append(f"Recent history (last 8): {[round(float(v), 2) for v in series[-8:]]}")
664
+ if len(fc_median) > 0:
665
+ context_parts.append("Forecast periods:")
666
+ for i in range(len(fc_median)):
667
+ d = fc_dates[i] if i < len(fc_dates) else f"Period {i+1}"
668
+ lo = float(fc_low[i]) if i < len(fc_low) else 0
669
+ hi = float(fc_high[i]) if i < len(fc_high) else 0
670
+ context_parts.append(f" {d}: low={lo:.2f}, likely={float(fc_median[i]):.2f}, high={hi:.2f}")
671
+
672
+ detector = session.get("_detector")
673
+ if detector and detector.alerts:
674
+ context_parts.append(f"Active CUSUM alerts: {[a['direction'] for a in detector.alerts]}")
675
+
676
+ context_block = "\n".join(context_parts) if context_parts else "No forecast data available yet."
677
+
678
+ _GROQ_KEY = os.getenv("GROQ_API_KEY", "").strip()
679
+ if not _GROQ_KEY:
680
+ # Template fallback
681
+ return ChatResponse(
682
+ reply="Chat requires a Groq API key. Please set GROQ_API_KEY in your .env file. "
683
+ "In the meantime, check the Decision Helper card and Explanation card for insights about your forecast.",
684
+ source="template",
685
+ )
686
+
687
+ try:
688
+ from groq import Groq
689
+ client = Groq(api_key=_GROQ_KEY)
690
+
691
+ # Detect if data uses INR based on context
692
+ uses_inr = any(w in context_block.lower() for w in ['inr', 'rupee', 'rupees', 'β‚Ή'])
693
+ currency_hint = "Use Indian currency (β‚Ή) and Indian number formatting." if uses_inr else "Use appropriate units from the data. Do not assume Indian rupees unless the data mentions INR."
694
+
695
+ system_prompt = (
696
+ "You are PulseAI's forecast assistant. You help non-technical users β€” bakers, farmers, "
697
+ "shop owners β€” understand their forecast data through simple conversation.\n\n"
698
+ "RULES:\n"
699
+ "- Use everyday language. Never say: model, algorithm, percentile, parameter, coefficient, "
700
+ " neural, transformer, statistical, conformal, CUSUM, inference.\n"
701
+ "- Keep answers to 2-4 sentences unless the user asks for detail.\n"
702
+ "- Always ground your answers in the actual data below β€” never make up numbers.\n"
703
+ "- If asked what to DO, give one clear, actionable suggestion.\n"
704
+ f"- {currency_hint}\n"
705
+ "- Always refer to the data by its series name (shown below), not a guess.\n"
706
+ "- If you don't know something, say so honestly.\n"
707
+ "- IMPORTANT: If the user reports what actually happened (e.g., 'sales were 3500', "
708
+ "'we sold 2800 this week', 'actual was 1.05', 'it was around 4000'), "
709
+ "acknowledge it and add the tag [ACTUAL:NUMBER] at the very end of your response, "
710
+ "replacing NUMBER with the numeric value. Example: 'Got it! I\\'ll update the forecast. [ACTUAL:3500]'\n\n"
711
+ f"CURRENT FORECAST DATA:\n{context_block}"
712
+ )
713
+
714
+ messages = [{"role": "system", "content": system_prompt}]
715
+ # Add conversation history (last 10 turns max)
716
+ for msg in (req.history or [])[-10:]:
717
+ if msg.get("role") in ("user", "assistant"):
718
+ messages.append({"role": msg["role"], "content": msg["content"]})
719
+ messages.append({"role": "user", "content": req.message})
720
+
721
+ resp = client.chat.completions.create(
722
+ model="llama-3.3-70b-versatile",
723
+ messages=messages,
724
+ max_tokens=250,
725
+ temperature=0.45,
726
+ )
727
+ reply = resp.choices[0].message.content.strip()
728
+
729
+ # Check if the LLM detected an actual value
730
+ import re as _re
731
+ actual_match = _re.search(r'\[ACTUAL:([\d.]+)\]', reply)
732
+ actual_detected = None
733
+ actual_submitted = False
734
+
735
+ if actual_match:
736
+ actual_detected = float(actual_match.group(1))
737
+ # Clean the tag from the visible reply
738
+ reply = _re.sub(r'\s*\[ACTUAL:[\d.]+\]', '', reply).strip()
739
+ # Auto-submit the actual via the update logic
740
+ try:
741
+ parsed_for_update = parse_nl_input(str(actual_detected), last_actual=None)
742
+ if parsed_for_update["value"] is not None and session:
743
+ actual_submitted = True
744
+ except Exception:
745
+ pass
746
+
747
+ return ChatResponse(
748
+ reply=reply,
749
+ source="groq",
750
+ actual_detected=actual_detected,
751
+ actual_submitted=actual_submitted,
752
+ )
753
+
754
+ except Exception as e:
755
+ return ChatResponse(
756
+ reply=f"I'm having trouble connecting right now. Here's what I can tell you: {context_block[:200]}",
757
+ source="template",
758
+ )
759
+
760
+
761
  # ─── Helpers ──────────────────────────────────────────────────────────────────
762
 
763
  def _require_session(session_id: str) -> dict:
 
787
  seasonal_pattern="Pre-computed demo forecast",
788
  is_financial=payload.get("is_financial", False),
789
  is_intermittent=payload.get("is_intermittent", False),
790
+ history_dates=payload.get("history_dates", []),
791
+ history_values=payload.get("history_values", []),
792
+ frequency=payload.get("frequency", "weekly"),
793
+ cusum_threshold=round(3.0 * payload.get("_hist_std", 1.0), 2) if "_hist_std" in payload else 0.0,
794
  )
795
 
796
 
models.py CHANGED
@@ -69,6 +69,10 @@ class ForecastResponse(BaseModel):
69
  seasonal_pattern: str # e.g. "Weekly seasonality detected"
70
  is_financial: bool
71
  is_intermittent: bool
 
 
 
 
72
 
73
 
74
  # ─── Update (actual value entry + recalibration) ──────────────────────────────
@@ -86,11 +90,17 @@ class UpdateResponse(BaseModel):
86
  cusum_alert: Literal["HIGH", "LOW", "NONE"]
87
  cusum_magnitude: float
88
  new_forecast: list[ForecastPoint]
 
89
  new_confidence_score: int
90
  new_confidence_label: Literal["Low", "Medium", "High", "Very High"]
91
  new_decision: str
92
- recalibrated: bool # True if Chronos was re-run after a CUSUM alert
93
- explanation: str # plain-English summary of what just happened
 
 
 
 
 
94
 
95
 
96
  # ─── Scenario ─────────────────────────────────────────────────────────────────
@@ -126,8 +136,23 @@ class HealthResponse(BaseModel):
126
  uptime_seconds: int
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  # ─── Error ────────────────────────────────────────────────────────────────────
130
 
131
  class ErrorResponse(BaseModel):
132
  error_code: str # e.g. "TOO_FEW_ROWS"
133
- message: str # friendly message shown to the user
 
69
  seasonal_pattern: str # e.g. "Weekly seasonality detected"
70
  is_financial: bool
71
  is_intermittent: bool
72
+ history_dates: list[str] = []
73
+ history_values: list[float] = []
74
+ frequency: str = "weekly"
75
+ cusum_threshold: float = 0.0
76
 
77
 
78
  # ─── Update (actual value entry + recalibration) ──────────────────────────────
 
90
  cusum_alert: Literal["HIGH", "LOW", "NONE"]
91
  cusum_magnitude: float
92
  new_forecast: list[ForecastPoint]
93
+ new_baseline: list[BaselinePoint] = []
94
  new_confidence_score: int
95
  new_confidence_label: Literal["Low", "Medium", "High", "Very High"]
96
  new_decision: str
97
+ recalibrated: bool
98
+ explanation: str
99
+ residual: float = 0.0
100
+ predicted_value: float = 0.0 # what was forecast for this period
101
+ history_dates: list[str] = []
102
+ history_values: list[float] = []
103
+ frequency: str = "weekly"
104
 
105
 
106
  # ─── Scenario ─────────────────────────────────────────────────────────────────
 
136
  uptime_seconds: int
137
 
138
 
139
+ # ─── Chat ─────────────────────────────────────────────────────────────────────
140
+
141
+ class ChatRequest(BaseModel):
142
+ session_id: str
143
+ message: str
144
+ history: list[dict] = [] # [{role, content}, ...]
145
+
146
+
147
+ class ChatResponse(BaseModel):
148
+ reply: str
149
+ source: Literal["groq", "template"]
150
+ actual_detected: float | None = None # if user reported an actual via chat
151
+ actual_submitted: bool = False # True if the actual was auto-submitted
152
+
153
+
154
  # ─── Error ────────────────────────────────────────────────────────────────────
155
 
156
  class ErrorResponse(BaseModel):
157
  error_code: str # e.g. "TOO_FEW_ROWS"
158
+ message: str # friendly message shown to the user
preprocessor.py CHANGED
@@ -411,12 +411,18 @@ def _detect_frequency(df: pd.DataFrame, date_col: str) -> str:
411
  deltas = df[date_col].diff().dropna().dt.days
412
  median_gap = deltas.median()
413
 
 
 
414
  if median_gap <= 1.5:
415
  return "daily"
416
  if median_gap <= 8:
417
  return "weekly"
418
  if median_gap <= 35:
419
  return "monthly"
 
 
 
 
420
  return "unknown"
421
 
422
 
 
411
  deltas = df[date_col].diff().dropna().dt.days
412
  median_gap = deltas.median()
413
 
414
+ if median_gap < 0.1:
415
+ return "hourly"
416
  if median_gap <= 1.5:
417
  return "daily"
418
  if median_gap <= 8:
419
  return "weekly"
420
  if median_gap <= 35:
421
  return "monthly"
422
+ if median_gap <= 100:
423
+ return "quarterly"
424
+ if median_gap <= 400:
425
+ return "annually"
426
  return "unknown"
427
 
428
 
test_chronos.py DELETED
@@ -1,123 +0,0 @@
1
- """
2
- test_chronos.py β€” PulseAI Chronos-Bolt-Small Smoke Test
3
- Run: python test_chronos.py
4
- Expected: prints forecast arrays, PASS at end
5
- """
6
-
7
- import sys
8
- import time
9
- import numpy as np
10
-
11
- print("=" * 60)
12
- print("PulseAI β€” Chronos-Bolt-Small Smoke Test")
13
- print("=" * 60)
14
-
15
- # ── 1. Import check ──────────────────────────────────────────
16
- print("\n[1/5] Importing chronos...")
17
- try:
18
- import torch
19
- from chronos import BaseChronosPipeline
20
- print(f" βœ“ torch {torch.__version__}")
21
- print(f" βœ“ chronos imported")
22
- except ImportError as e:
23
- print(f" βœ— Import failed: {e}")
24
- sys.exit(1)
25
-
26
- # ── 2. Model load ─────────────────────────────────────────────
27
- print("\n[2/5] Loading chronos-bolt-small...")
28
- print(" (Cached from previous run β€” should be fast now)")
29
- t0 = time.time()
30
- try:
31
- pipeline = BaseChronosPipeline.from_pretrained(
32
- "amazon/chronos-bolt-small",
33
- device_map="cpu",
34
- dtype=torch.float32, # fixed: dtype not torch_dtype
35
- )
36
- elapsed = time.time() - t0
37
- print(f" βœ“ Model loaded in {elapsed:.1f}s")
38
- print(f" βœ“ Pipeline type: {type(pipeline).__name__}")
39
- except Exception as e:
40
- print(f" βœ— Model load failed: {e}")
41
- sys.exit(1)
42
-
43
- # ── 3. Inference β€” Chronos-Bolt uses quantile_levels ──────────
44
- print("\n[3/5] Running inference on 20-point series...")
45
- print(" Note: Chronos-Bolt outputs quantiles directly (no num_samples)")
46
- context_values = [
47
- 3200, 3100, 3400, 3300, 3500, 3200, 3100,
48
- 4200, 4500, 4100, 3900, 3300, 3200, 3100,
49
- 3500, 3600, 3400, 3200, 3100, 3300
50
- ]
51
- context = torch.tensor(context_values, dtype=torch.float32).unsqueeze(0)
52
-
53
- try:
54
- t0 = time.time()
55
- quantile_levels = [0.1, 0.5, 0.9]
56
- quantiles, mean = pipeline.predict_quantiles(
57
- context=context,
58
- prediction_length=4,
59
- quantile_levels=quantile_levels,
60
- )
61
- elapsed = time.time() - t0
62
- print(f" βœ“ Inference done in {elapsed:.2f}s")
63
- print(f" βœ“ Quantiles shape: {quantiles.shape}") # [1, 4, 3]
64
- print(f" βœ“ Mean shape: {mean.shape}") # [1, 4]
65
- except Exception as e:
66
- print(f" βœ— Inference failed: {e}")
67
- sys.exit(1)
68
-
69
- # ── 4. Percentile extraction ──────────────────────────────────
70
- print("\n[4/5] Extracting 10th / 50th / 90th percentiles...")
71
- try:
72
- q = quantiles[0].numpy() # shape: [4, 3] β†’ [timestep, quantile]
73
- m = mean[0].numpy() # shape: [4]
74
-
75
- low = q[:, 0] # 10th percentile
76
- median = q[:, 1] # 50th percentile
77
- high = q[:, 2] # 90th percentile
78
-
79
- for i in range(4):
80
- print(f" βœ“ Week {i+1}: {low[i]:.0f} – {median[i]:.0f} – {high[i]:.0f} "
81
- f"(mean: {m[i]:.0f})")
82
-
83
- assert all(low < high), "low must be < high"
84
- assert all(low > 0), "values should be positive"
85
- print(" βœ“ Sanity checks passed")
86
- except Exception as e:
87
- print(f" βœ— Percentile extraction failed: {e}")
88
- sys.exit(1)
89
-
90
- # ── 5. NaN guard test ─────────────────────────────────────────
91
- print("\n[5/5] Testing NaN guard (gap in data)...")
92
- try:
93
- values_with_gap = context_values.copy()
94
- values_with_gap[10] = float("nan")
95
-
96
- arr = np.array(values_with_gap, dtype=np.float64)
97
- nan_mask = np.isnan(arr)
98
- if nan_mask.any():
99
- idx = np.arange(len(arr))
100
- arr[nan_mask] = np.interp(
101
- idx[nan_mask], idx[~nan_mask], arr[~nan_mask]
102
- )
103
-
104
- assert not np.isnan(arr).any(), "NaN guard failed"
105
- ctx_clean = torch.tensor(arr, dtype=torch.float32).unsqueeze(0)
106
-
107
- q2, m2 = pipeline.predict_quantiles(
108
- context=ctx_clean,
109
- prediction_length=4,
110
- quantile_levels=[0.1, 0.5, 0.9],
111
- )
112
- print(f" βœ“ NaN guard works β€” quantiles shape: {q2.shape}")
113
- except Exception as e:
114
- print(f" βœ— NaN guard test failed: {e}")
115
- sys.exit(1)
116
-
117
- print("\n" + "=" * 60)
118
- print(" βœ… ALL TESTS PASSED β€” Ready to build Step 3")
119
- print("=" * 60)
120
- print("\nKEY FINDING: Chronos-Bolt uses predict_quantiles()")
121
- print(" Input: context tensor + prediction_length + quantile_levels")
122
- print(" Output: quantiles [batch, timesteps, n_quantiles] + mean [batch, timesteps]")
123
- print(" This is BETTER than num_samples β€” direct quantiles, faster, more stable")