johnbridges commited on
Commit
81d8cd8
·
1 Parent(s): d74feed
Files changed (1) hide show
  1. timesfm_backend.py +96 -54
timesfm_backend.py CHANGED
@@ -1,9 +1,6 @@
1
  # timesfm_backend.py
2
- import time
3
- import json
4
- import logging
5
- from typing import Any, Dict, Optional
6
-
7
  import numpy as np
8
  import torch
9
 
@@ -12,33 +9,29 @@ from config import settings
12
 
13
  logger = logging.getLogger(__name__)
14
 
 
15
  try:
16
  from timesfm import TimesFm
17
  _TIMESFM_AVAILABLE = True
18
  except Exception as e:
19
- logger.warning("timesfm not available (%s)", e)
20
- TimesFm = None
21
  _TIMESFM_AVAILABLE = False
22
 
23
 
24
- # ---------- small helpers ----------
25
  def _parse_series(series: Any) -> np.ndarray:
26
  if series is None:
27
  raise ValueError("series is required")
28
  if isinstance(series, dict):
29
- if "values" in series:
30
- series = series["values"]
31
- elif "y" in series:
32
- series = series["y"]
33
 
34
- vals = []
35
  if isinstance(series, (list, tuple)):
36
  if series and isinstance(series[0], dict):
37
  for item in series:
38
- if "y" in item:
39
- vals.append(float(item["y"]))
40
- elif "value" in item:
41
- vals.append(float(item["value"]))
42
  else:
43
  vals = [float(x) for x in series]
44
  else:
@@ -56,66 +49,122 @@ def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
56
  return np.full((horizon,), base, dtype=np.float32)
57
 
58
 
59
- # ---------- backend ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  class TimesFMBackend(ChatBackend):
 
 
 
 
 
 
 
 
61
  def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
62
  self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
63
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
64
- self._model = None
65
 
66
- def _ensure_model(self):
67
  if self._model is not None or not _TIMESFM_AVAILABLE:
68
  return
69
  try:
70
- # you may need to adjust context_len/horizon_len to match checkpoint
71
- self._model = TimesFm(
72
- context_len=512,
73
- horizon_len=128,
74
- input_patch_len=32,
75
- )
76
  self._model.load_from_checkpoint(self.model_id)
77
- self._model.to(self.device)
78
- logger.info("TimesFM model loaded from %s on %s", self.model_id, self.device)
 
 
 
79
  except Exception as e:
80
- logger.exception("Failed to init TimesFM, fallback only. %s", e)
81
  self._model = None
82
 
83
  async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
84
- if "data" in payload and isinstance(payload["data"], dict):
 
85
  payload = {**payload, **payload["data"]}
86
- if "timeseries" in payload and isinstance(payload["timeseries"], dict):
87
  payload = {**payload, **payload["timeseries"]}
 
 
88
 
89
  y = _parse_series(payload.get("series"))
90
  horizon = int(payload.get("horizon", 0))
91
  freq = payload.get("freq")
92
-
93
  if horizon <= 0:
94
- raise ValueError("horizon must be positive")
95
 
96
  self._ensure_model()
97
 
98
  note = None
99
- if self._model is not None:
100
  try:
101
- x = torch.tensor(y, dtype=torch.float32, device=self.device)[None, :]
102
- preds = self._model.forecast_on_batch(x, horizon)
103
  fc = preds[0].detach().cpu().numpy().astype(float).tolist()
104
  except Exception as e:
105
- logger.exception("TimesFM forecast failed, using fallback. %s", e)
106
  fc = _fallback_forecast(y, horizon).tolist()
107
  note = "fallback_used_due_to_predict_error"
108
  else:
109
  fc = _fallback_forecast(y, horizon).tolist()
110
  note = "fallback_used_timesfm_missing"
111
 
112
- return {
113
- "model": self.model_id,
114
- "horizon": horizon,
115
- "freq": freq,
116
- "forecast": fc,
117
- "note": note,
118
- }
119
 
120
  async def stream(self, request: Dict[str, Any]):
121
  rid = f"chatcmpl-timesfm-{int(time.time())}"
@@ -135,16 +184,9 @@ class TimesFMBackend(ChatBackend):
135
  return
136
 
137
  content = json.dumps(
138
- {
139
- "model": result["model"],
140
- "horizon": result["horizon"],
141
- "freq": result["freq"],
142
- "forecast": result["forecast"],
143
- "note": result.get("note"),
144
- "backend": "timesfm",
145
- },
146
- separators=(",", ":"),
147
- ensure_ascii=False,
148
  )
149
  yield {
150
  "id": rid,
 
1
  # timesfm_backend.py
2
+ import time, json, logging
3
+ from typing import Any, Dict, List, Optional
 
 
 
4
  import numpy as np
5
  import torch
6
 
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
+ # --- TimesFM import (fallback-safe) ---
13
  try:
14
  from timesfm import TimesFm
15
  _TIMESFM_AVAILABLE = True
16
  except Exception as e:
17
+ logger.warning("timesfm not available (%s) — using naive fallback.", e)
18
+ TimesFm = None # type: ignore
19
  _TIMESFM_AVAILABLE = False
20
 
21
 
22
+ # --- helpers ---
23
  def _parse_series(series: Any) -> np.ndarray:
24
  if series is None:
25
  raise ValueError("series is required")
26
  if isinstance(series, dict):
27
+ series = series.get("values") or series.get("y")
 
 
 
28
 
29
+ vals: List[float] = []
30
  if isinstance(series, (list, tuple)):
31
  if series and isinstance(series[0], dict):
32
  for item in series:
33
+ if "y" in item: vals.append(float(item["y"]))
34
+ elif "value" in item: vals.append(float(item["value"]))
 
 
35
  else:
36
  vals = [float(x) for x in series]
37
  else:
 
49
  return np.full((horizon,), base, dtype=np.float32)
50
 
51
 
52
+ def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
53
+ s = s.strip()
54
+ # whole-string JSON
55
+ if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
56
+ try:
57
+ obj = json.loads(s)
58
+ return obj if isinstance(obj, dict) else None
59
+ except Exception:
60
+ pass
61
+ # fenced ```json ... ```
62
+ if "```" in s:
63
+ parts = s.split("```")
64
+ for i in range(1, len(parts), 2):
65
+ block = parts[i]
66
+ if block.lstrip().lower().startswith("json"):
67
+ block = block.split("\n", 1)[-1]
68
+ try:
69
+ obj = json.loads(block.strip())
70
+ return obj if isinstance(obj, dict) else None
71
+ except Exception:
72
+ continue
73
+ return None
74
+
75
+
76
+ def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
77
+ """
78
+ OpenAI chat format:
79
+ messages: [{role, content}, ...]
80
+ content can be a string or a list of parts: [{"type":"text","text":"..."}]
81
+ We scan from last to first user message and merge first JSON dict found.
82
+ """
83
+ msgs = payload.get("messages")
84
+ if not isinstance(msgs, list):
85
+ return payload
86
+
87
+ for m in reversed(msgs):
88
+ if not isinstance(m, dict) or m.get("role") != "user":
89
+ continue
90
+ c = m.get("content")
91
+ # Text parts array
92
+ if isinstance(c, list):
93
+ texts = [p.get("text") for p in c if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)]
94
+ for t in reversed(texts):
95
+ obj = _extract_json_from_text(t)
96
+ if isinstance(obj, dict):
97
+ return {**payload, **obj}
98
+ # Plain string
99
+ if isinstance(c, str):
100
+ obj = _extract_json_from_text(c)
101
+ if isinstance(obj, dict):
102
+ return {**payload, **obj}
103
+ return payload
104
+
105
+
106
+ # --- backend ---
107
  class TimesFMBackend(ChatBackend):
108
+ """
109
+ Accepts OpenAI chat-completions requests.
110
+ Pulls timeseries config from:
111
+ - top-level keys, OR
112
+ - payload['data'] (CloudEvents), OR
113
+ - last user message JSON (OpenAI format, string or text-part).
114
+ Keys: series: list[float|{y|value}], horizon: int, freq: str (optional)
115
+ """
116
  def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
117
  self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
118
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
119
+ self._model: Optional[TimesFm] = None # type: ignore
120
 
121
+ def _ensure_model(self) -> None:
122
  if self._model is not None or not _TIMESFM_AVAILABLE:
123
  return
124
  try:
125
+ self._model = TimesFm(context_len=512, horizon_len=128, input_patch_len=32)
 
 
 
 
 
126
  self._model.load_from_checkpoint(self.model_id)
127
+ try:
128
+ self._model.to(self.device) # type: ignore[attr-defined]
129
+ except Exception:
130
+ pass
131
+ logger.info("TimesFM loaded from %s on %s", self.model_id, self.device)
132
  except Exception as e:
133
+ logger.exception("TimesFM init failed; fallback only. %s", e)
134
  self._model = None
135
 
136
  async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
137
+ # unwrap CloudEvents .data and nested .timeseries
138
+ if isinstance(payload.get("data"), dict):
139
  payload = {**payload, **payload["data"]}
140
+ if isinstance(payload.get("timeseries"), dict):
141
  payload = {**payload, **payload["timeseries"]}
142
+ # merge JSON embedded in last user message (OpenAI format)
143
+ payload = _merge_openai_message_json(payload)
144
 
145
  y = _parse_series(payload.get("series"))
146
  horizon = int(payload.get("horizon", 0))
147
  freq = payload.get("freq")
 
148
  if horizon <= 0:
149
+ raise ValueError("horizon must be a positive integer")
150
 
151
  self._ensure_model()
152
 
153
  note = None
154
+ if _TIMESFM_AVAILABLE and self._model is not None:
155
  try:
156
+ x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0) # [1,T]
157
+ preds = self._model.forecast_on_batch(x, horizon) # -> [1,H]
158
  fc = preds[0].detach().cpu().numpy().astype(float).tolist()
159
  except Exception as e:
160
+ logger.exception("TimesFM forecast failed; fallback used. %s", e)
161
  fc = _fallback_forecast(y, horizon).tolist()
162
  note = "fallback_used_due_to_predict_error"
163
  else:
164
  fc = _fallback_forecast(y, horizon).tolist()
165
  note = "fallback_used_timesfm_missing"
166
 
167
+ return {"model": self.model_id, "horizon": horizon, "freq": freq, "forecast": fc, "note": note}
 
 
 
 
 
 
168
 
169
  async def stream(self, request: Dict[str, Any]):
170
  rid = f"chatcmpl-timesfm-{int(time.time())}"
 
184
  return
185
 
186
  content = json.dumps(
187
+ {"model": result["model"], "horizon": result["horizon"], "freq": result["freq"],
188
+ "forecast": result["forecast"], "note": result.get("note"), "backend": "timesfm"},
189
+ separators=(",", ":"), ensure_ascii=False
 
 
 
 
 
 
 
190
  )
191
  yield {
192
  "id": rid,