johnbridges committed on
Commit
16f3ba1
·
1 Parent(s): 81d8cd8
Files changed (1) hide show
  1. timesfm_backend.py +91 -36
timesfm_backend.py CHANGED
@@ -1,6 +1,9 @@
1
  # timesfm_backend.py
2
- import time, json, logging
 
 
3
  from typing import Any, Dict, List, Optional
 
4
  import numpy as np
5
  import torch
6
 
@@ -9,9 +12,9 @@ from config import settings
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
- # --- TimesFM import (fallback-safe) ---
13
  try:
14
- from timesfm import TimesFm
15
  _TIMESFM_AVAILABLE = True
16
  except Exception as e:
17
  logger.warning("timesfm not available (%s) — using naive fallback.", e)
@@ -19,29 +22,41 @@ except Exception as e:
19
  _TIMESFM_AVAILABLE = False
20
 
21
 
22
- # --- helpers ---
23
  def _parse_series(series: Any) -> np.ndarray:
 
 
 
 
24
  if series is None:
25
  raise ValueError("series is required")
 
26
  if isinstance(series, dict):
 
27
  series = series.get("values") or series.get("y")
28
 
29
  vals: List[float] = []
30
  if isinstance(series, (list, tuple)):
31
  if series and isinstance(series[0], dict):
32
  for item in series:
33
- if "y" in item: vals.append(float(item["y"]))
34
- elif "value" in item: vals.append(float(item["value"]))
 
 
35
  else:
36
  vals = [float(x) for x in series]
37
  else:
38
  raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
 
39
  if not vals:
40
  raise ValueError("series is empty")
41
  return np.asarray(vals, dtype=np.float32)
42
 
43
 
44
  def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
 
 
 
45
  if horizon <= 0:
46
  return np.zeros((0,), dtype=np.float32)
47
  k = 4 if y.shape[0] >= 4 else y.shape[0]
@@ -50,15 +65,19 @@ def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
50
 
51
 
52
  def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
 
 
 
 
53
  s = s.strip()
54
- # whole-string JSON
55
  if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
56
  try:
57
  obj = json.loads(s)
58
  return obj if isinstance(obj, dict) else None
59
  except Exception:
60
  pass
61
- # fenced ```json ... ```
62
  if "```" in s:
63
  parts = s.split("```")
64
  for i in range(1, len(parts), 2):
@@ -75,10 +94,10 @@ def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
75
 
76
  def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
77
  """
78
- OpenAI chat format:
79
- messages: [{role, content}, ...]
80
- content can be a string or a list of parts: [{"type":"text","text":"..."}]
81
- We scan from last to first user message and merge first JSON dict found.
82
  """
83
  msgs = payload.get("messages")
84
  if not isinstance(msgs, list):
@@ -87,32 +106,39 @@ def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
87
  for m in reversed(msgs):
88
  if not isinstance(m, dict) or m.get("role") != "user":
89
  continue
90
- c = m.get("content")
91
- # Text parts array
92
- if isinstance(c, list):
93
- texts = [p.get("text") for p in c if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)]
94
- for t in reversed(texts):
95
- obj = _extract_json_from_text(t)
96
- if isinstance(obj, dict):
97
- return {**payload, **obj}
98
- # Plain string
99
- if isinstance(c, str):
100
- obj = _extract_json_from_text(c)
 
 
101
  if isinstance(obj, dict):
102
  return {**payload, **obj}
 
103
  return payload
104
 
105
 
106
- # --- backend ---
107
  class TimesFMBackend(ChatBackend):
108
  """
109
  Accepts OpenAI chat-completions requests.
110
  Pulls timeseries config from:
111
  - top-level keys, OR
112
- - payload['data'] (CloudEvents), OR
113
- - last user message JSON (OpenAI format, string or text-part).
114
- Keys: series: list[float|{y|value}], horizon: int, freq: str (optional)
 
 
 
115
  """
 
116
  def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
117
  self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
118
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
@@ -122,7 +148,12 @@ class TimesFMBackend(ChatBackend):
122
  if self._model is not None or not _TIMESFM_AVAILABLE:
123
  return
124
  try:
125
- self._model = TimesFm(context_len=512, horizon_len=128, input_patch_len=32)
 
 
 
 
 
126
  self._model.load_from_checkpoint(self.model_id)
127
  try:
128
  self._model.to(self.device) # type: ignore[attr-defined]
@@ -139,6 +170,7 @@ class TimesFMBackend(ChatBackend):
139
  payload = {**payload, **payload["data"]}
140
  if isinstance(payload.get("timeseries"), dict):
141
  payload = {**payload, **payload["timeseries"]}
 
142
  # merge JSON embedded in last user message (OpenAI format)
143
  payload = _merge_openai_message_json(payload)
144
 
@@ -153,8 +185,8 @@ class TimesFMBackend(ChatBackend):
153
  note = None
154
  if _TIMESFM_AVAILABLE and self._model is not None:
155
  try:
156
- x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0) # [1,T]
157
- preds = self._model.forecast_on_batch(x, horizon) # -> [1,H]
158
  fc = preds[0].detach().cpu().numpy().astype(float).tolist()
159
  except Exception as e:
160
  logger.exception("TimesFM forecast failed; fallback used. %s", e)
@@ -164,12 +196,23 @@ class TimesFMBackend(ChatBackend):
164
  fc = _fallback_forecast(y, horizon).tolist()
165
  note = "fallback_used_timesfm_missing"
166
 
167
- return {"model": self.model_id, "horizon": horizon, "freq": freq, "forecast": fc, "note": note}
 
 
 
 
 
 
168
 
169
  async def stream(self, request: Dict[str, Any]):
 
 
 
 
170
  rid = f"chatcmpl-timesfm-{int(time.time())}"
171
  now = int(time.time())
172
  payload = dict(request) if isinstance(request, dict) else {}
 
173
  try:
174
  result = await self.forecast(payload)
175
  except Exception as e:
@@ -179,24 +222,36 @@ class TimesFMBackend(ChatBackend):
179
  "object": "chat.completion.chunk",
180
  "created": now,
181
  "model": self.model_id,
182
- "choices": [{"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}],
 
 
183
  }
184
  return
185
 
186
  content = json.dumps(
187
- {"model": result["model"], "horizon": result["horizon"], "freq": result["freq"],
188
- "forecast": result["forecast"], "note": result.get("note"), "backend": "timesfm"},
189
- separators=(",", ":"), ensure_ascii=False
 
 
 
 
 
 
 
190
  )
191
  yield {
192
  "id": rid,
193
  "object": "chat.completion.chunk",
194
  "created": now,
195
  "model": self.model_id,
196
- "choices": [{"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}],
 
 
197
  }
198
 
199
 
 
200
  class StubImagesBackend(ImagesBackend):
201
  async def generate_b64(self, request: Dict[str, Any]) -> str:
202
  logger.warning("Image generation not supported in TimesFM backend.")
 
1
  # timesfm_backend.py
2
+ import time
3
+ import json
4
+ import logging
5
  from typing import Any, Dict, List, Optional
6
+
7
  import numpy as np
8
  import torch
9
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
+ # ---------------- TimesFM import (fallback-safe) ----------------
16
  try:
17
+ from timesfm import TimesFm # Google TimesFM 2.5+
18
  _TIMESFM_AVAILABLE = True
19
  except Exception as e:
20
  logger.warning("timesfm not available (%s) — using naive fallback.", e)
 
22
  _TIMESFM_AVAILABLE = False
23
 
24
 
25
+ # ---------------- helpers ----------------
26
  def _parse_series(series: Any) -> np.ndarray:
27
+ """
28
+ Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'.
29
+ Returns: 1D float32 numpy array.
30
+ """
31
  if series is None:
32
  raise ValueError("series is required")
33
+
34
  if isinstance(series, dict):
35
+ # allow {"values":[...]} or {"y":[...]}
36
  series = series.get("values") or series.get("y")
37
 
38
  vals: List[float] = []
39
  if isinstance(series, (list, tuple)):
40
  if series and isinstance(series[0], dict):
41
  for item in series:
42
+ if "y" in item:
43
+ vals.append(float(item["y"]))
44
+ elif "value" in item:
45
+ vals.append(float(item["value"]))
46
  else:
47
  vals = [float(x) for x in series]
48
  else:
49
  raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
50
+
51
  if not vals:
52
  raise ValueError("series is empty")
53
  return np.asarray(vals, dtype=np.float32)
54
 
55
 
56
  def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
57
+ """
58
+ Naive fallback: mean of last 4 (or all if <4), repeated H times.
59
+ """
60
  if horizon <= 0:
61
  return np.zeros((0,), dtype=np.float32)
62
  k = 4 if y.shape[0] >= 4 else y.shape[0]
 
65
 
66
 
67
  def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
68
+ """
69
+ Try to parse JSON from a plain string or a fenced ```json block.
70
+ Returns dict or None.
71
+ """
72
  s = s.strip()
73
+ # whole-string JSON object/array
74
  if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
75
  try:
76
  obj = json.loads(s)
77
  return obj if isinstance(obj, dict) else None
78
  except Exception:
79
  pass
80
+ # fenced code blocks
81
  if "```" in s:
82
  parts = s.split("```")
83
  for i in range(1, len(parts), 2):
 
94
 
95
  def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
96
  """
97
+ OpenAI chat format compatibility:
98
+ payload["messages"] may hold user JSON in the last user message.
99
+ content can be a plain string or a list of parts [{"type":"text","text":...}].
100
+ If a JSON object is found, merge its keys into payload.
101
  """
102
  msgs = payload.get("messages")
103
  if not isinstance(msgs, list):
 
106
  for m in reversed(msgs):
107
  if not isinstance(m, dict) or m.get("role") != "user":
108
  continue
109
+ content = m.get("content")
110
+ texts: List[str] = []
111
+ if isinstance(content, list):
112
+ texts = [
113
+ p.get("text")
114
+ for p in content
115
+ if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
116
+ ]
117
+ elif isinstance(content, str):
118
+ texts = [content]
119
+
120
+ for t in reversed(texts):
121
+ obj = _extract_json_from_text(t)
122
  if isinstance(obj, dict):
123
  return {**payload, **obj}
124
+ break # only inspect last user
125
  return payload
126
 
127
 
128
+ # ---------------- backend ----------------
129
  class TimesFMBackend(ChatBackend):
130
  """
131
  Accepts OpenAI chat-completions requests.
132
  Pulls timeseries config from:
133
  - top-level keys, OR
134
+ - payload['data'] (CloudEvents wrapper), OR
135
+ - last user message JSON (OpenAI format).
136
+ Keys:
137
+ series: list[float|int|{y|value}]
138
+ horizon: int (>0)
139
+ freq: optional str
140
  """
141
+
142
  def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
143
  self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
144
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
148
  if self._model is not None or not _TIMESFM_AVAILABLE:
149
  return
150
  try:
151
+ # Set lengths compatible with the 2.5 checkpoints.
152
+ self._model = TimesFm(
153
+ context_len=512,
154
+ horizon_len=128,
155
+ input_patch_len=32,
156
+ )
157
  self._model.load_from_checkpoint(self.model_id)
158
  try:
159
  self._model.to(self.device) # type: ignore[attr-defined]
 
170
  payload = {**payload, **payload["data"]}
171
  if isinstance(payload.get("timeseries"), dict):
172
  payload = {**payload, **payload["timeseries"]}
173
+
174
  # merge JSON embedded in last user message (OpenAI format)
175
  payload = _merge_openai_message_json(payload)
176
 
 
185
  note = None
186
  if _TIMESFM_AVAILABLE and self._model is not None:
187
  try:
188
+ x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0) # [1, T]
189
+ preds = self._model.forecast_on_batch(x, horizon) # -> [1, H]
190
  fc = preds[0].detach().cpu().numpy().astype(float).tolist()
191
  except Exception as e:
192
  logger.exception("TimesFM forecast failed; fallback used. %s", e)
 
196
  fc = _fallback_forecast(y, horizon).tolist()
197
  note = "fallback_used_timesfm_missing"
198
 
199
+ return {
200
+ "model": self.model_id,
201
+ "horizon": horizon,
202
+ "freq": freq,
203
+ "forecast": fc,
204
+ "note": note,
205
+ }
206
 
207
  async def stream(self, request: Dict[str, Any]):
208
+ """
209
+ OA-compatible streaming shim:
210
+ Emits exactly one chat.completion.chunk with compact JSON content.
211
+ """
212
  rid = f"chatcmpl-timesfm-{int(time.time())}"
213
  now = int(time.time())
214
  payload = dict(request) if isinstance(request, dict) else {}
215
+
216
  try:
217
  result = await self.forecast(payload)
218
  except Exception as e:
 
222
  "object": "chat.completion.chunk",
223
  "created": now,
224
  "model": self.model_id,
225
+ "choices": [
226
+ {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
227
+ ],
228
  }
229
  return
230
 
231
  content = json.dumps(
232
+ {
233
+ "model": result["model"],
234
+ "horizon": result["horizon"],
235
+ "freq": result["freq"],
236
+ "forecast": result["forecast"],
237
+ "note": result.get("note"),
238
+ "backend": "timesfm",
239
+ },
240
+ separators=(",", ":"),
241
+ ensure_ascii=False,
242
  )
243
  yield {
244
  "id": rid,
245
  "object": "chat.completion.chunk",
246
  "created": now,
247
  "model": self.model_id,
248
+ "choices": [
249
+ {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
250
+ ],
251
  }
252
 
253
 
254
+ # ---------------- images stub ----------------
255
  class StubImagesBackend(ImagesBackend):
256
  async def generate_b64(self, request: Dict[str, Any]) -> str:
257
  logger.warning("Image generation not supported in TimesFM backend.")