exx committed
Commit 93517af · 1 parent: 0aa0c9c

Added TimesFM 2.5

Files changed (1)
  1. models/TimesFM2.py +665 -0
models/TimesFM2.py ADDED
@@ -0,0 +1,665 @@
1
+ """Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import dataclasses
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ try:
13
+ from safetensors.torch import load_file as _load_safetensors
14
+ except ImportError: # pragma: no cover - optional dependency
15
+ _load_safetensors = None
16
+
17
+ _TOLERANCE = 1e-6
18
+
19
+
20
+ @dataclasses.dataclass(frozen=True)
21
+ class ResidualBlockConfig:
22
+ input_dims: int
23
+ hidden_dims: int
24
+ output_dims: int
25
+ use_bias: bool
26
+ activation: str
27
+
28
+
29
+ @dataclasses.dataclass(frozen=True)
30
+ class TransformerConfig:
31
+ model_dims: int
32
+ hidden_dims: int
33
+ num_heads: int
34
+ attention_norm: str
35
+ feedforward_norm: str
36
+ qk_norm: str
37
+ use_bias: bool
38
+ use_rotary_position_embeddings: bool
39
+ ff_activation: str
40
+ fuse_qkv: bool
41
+
42
+
43
+ @dataclasses.dataclass(frozen=True)
44
+ class StackedTransformersConfig:
45
+ num_layers: int
46
+ transformer: TransformerConfig
47
+
48
+
49
+ @dataclasses.dataclass(frozen=True)
50
+ class TimesFM2Definition:
51
+ """Framework-agnostic description of TimesFM 2.5 (200M parameters)."""
52
+
53
+ context_limit: int = 16384
54
+ input_patch_len: int = 32
55
+ output_patch_len: int = 128
56
+ output_quantile_len: int = 1024
57
+ quantiles: tuple[float, ...] = (
58
+ 0.1,
59
+ 0.2,
60
+ 0.3,
61
+ 0.4,
62
+ 0.5,
63
+ 0.6,
64
+ 0.7,
65
+ 0.8,
66
+ 0.9,
67
+ )
68
+ decode_index: int = 5
69
+ tokenizer: ResidualBlockConfig = dataclasses.field(
70
+ default_factory=lambda: ResidualBlockConfig(
71
+ input_dims=64,
72
+ hidden_dims=1280,
73
+ output_dims=1280,
74
+ use_bias=True,
75
+ activation="swish",
76
+ )
77
+ )
78
+ stacked_transformers: StackedTransformersConfig = dataclasses.field(
79
+ default_factory=lambda: StackedTransformersConfig(
80
+ num_layers=20,
81
+ transformer=TransformerConfig(
82
+ model_dims=1280,
83
+ hidden_dims=1280,
84
+ num_heads=16,
85
+ attention_norm="rms",
86
+ feedforward_norm="rms",
87
+ qk_norm="rms",
88
+ use_bias=False,
89
+ use_rotary_position_embeddings=True,
90
+ ff_activation="swish",
91
+ fuse_qkv=True,
92
+ ),
93
+ )
94
+ )
95
+ output_projection_point: ResidualBlockConfig = dataclasses.field(
96
+ default_factory=lambda: ResidualBlockConfig(
97
+ input_dims=1280,
98
+ hidden_dims=1280,
99
+ output_dims=1280,
100
+ use_bias=False,
101
+ activation="swish",
102
+ )
103
+ )
104
+ output_projection_quantiles: ResidualBlockConfig = dataclasses.field(
105
+ default_factory=lambda: ResidualBlockConfig(
106
+ input_dims=1280,
107
+ hidden_dims=1280,
108
+ output_dims=10240,
109
+ use_bias=False,
110
+ activation="swish",
111
+ )
112
+ )
113
+
114
+
115
+ @dataclasses.dataclass(frozen=False)
116
+ class DecodeCache:
117
+ next_index: torch.Tensor
118
+ num_masked: torch.Tensor
119
+ key: torch.Tensor
120
+ value: torch.Tensor
121
+
122
+
123
+ def update_running_stats(
124
+ n: torch.Tensor,
125
+ mu: torch.Tensor,
126
+ sigma: torch.Tensor,
127
+ x: torch.Tensor,
128
+ mask: torch.Tensor,
129
+ ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
130
+ """Updates reversible normalization statistics for a new patch."""
131
+ is_legit = torch.logical_not(mask)
132
+ inc_n = torch.sum(is_legit.to(x.dtype), dim=-1)
133
+
134
+ inc_mu_numerator = torch.sum(x * is_legit, dim=-1)
135
+ inc_n_safe = torch.where(inc_n == 0, 1.0, inc_n)
136
+ inc_mu = inc_mu_numerator / inc_n_safe
137
+ inc_mu = torch.where(inc_n == 0, 0.0, inc_mu)
138
+
139
+ inc_var_numerator = torch.sum(((x - inc_mu.unsqueeze(-1)) ** 2) * is_legit, dim=-1)
140
+ inc_var = inc_var_numerator / inc_n_safe
141
+ inc_var = torch.where(inc_n == 0, 0.0, inc_var)
142
+ inc_sigma = torch.sqrt(inc_var)
143
+
144
+ new_n = n + inc_n
145
+ new_n_safe = torch.where(new_n == 0, 1.0, new_n)
146
+
147
+ new_mu = (n * mu + inc_mu * inc_n) / new_n_safe
148
+ new_mu = torch.where(new_n == 0, 0.0, new_mu)
149
+
150
+ term1 = n * sigma.pow(2)
151
+ term2 = inc_n * inc_sigma.pow(2)
152
+ term3 = n * (mu - new_mu).pow(2)
153
+ term4 = inc_n * (inc_mu - new_mu).pow(2)
154
+
155
+ new_var = (term1 + term2 + term3 + term4) / new_n_safe
156
+ new_var = torch.where(new_n == 0, 0.0, new_var)
157
+ new_sigma = torch.sqrt(torch.clamp(new_var, min=0.0))
158
+
159
+ return (new_n, new_mu, new_sigma), (new_n, new_mu, new_sigma)
160
+
161
+
162
+ def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor:
163
+ """Reversible instance normalization."""
164
+ if len(mu.shape) == len(x.shape) - 1:
165
+ mu = mu[..., None]
166
+ sigma = sigma[..., None]
167
+ elif len(mu.shape) == len(x.shape) - 2:
168
+ mu = mu[..., None, None]
169
+ sigma = sigma[..., None, None]
170
+
171
+ if reverse:
172
+ return x * sigma + mu
173
+
174
+ sigma_safe = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma)
175
+ return (x - mu) / sigma_safe
176
+
177
+
178
+ class ResidualBlock(nn.Module):
179
+ """Residual block composed of a pair of linear layers."""
180
+
181
+ def __init__(self, config: ResidualBlockConfig):
182
+ super().__init__()
183
+ self.activation = self._resolve_activation(config.activation)
184
+ self.hidden_layer = nn.Linear(config.input_dims, config.hidden_dims, bias=config.use_bias)
185
+ self.output_layer = nn.Linear(config.hidden_dims, config.output_dims, bias=config.use_bias)
186
+ self.residual_layer = nn.Linear(config.input_dims, config.output_dims, bias=config.use_bias)
187
+
188
+ @staticmethod
189
+ def _resolve_activation(name: str) -> nn.Module:
190
+ if name == "relu":
191
+ return nn.ReLU()
192
+ if name == "swish":
193
+ return nn.SiLU()
194
+ if name == "none":
195
+ return nn.Identity()
196
+ raise ValueError(f"Unsupported activation: {name}")
197
+
198
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
199
+ hidden = self.activation(self.hidden_layer(x))
200
+ return self.output_layer(hidden) + self.residual_layer(x)
201
+
202
+
203
+ class RMSNorm(nn.Module):
204
+ """Root-mean-square normalization."""
205
+
206
+ def __init__(self, num_features: int, epsilon: float = 1e-6):
207
+ super().__init__()
208
+ self.scale = nn.Parameter(torch.zeros(num_features))
209
+ self.epsilon = epsilon
210
+
211
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
212
+ var = torch.mean(torch.square(inputs), dim=-1, keepdim=True)
213
+ normed_inputs = inputs * torch.rsqrt(var + self.epsilon)
214
+ return normed_inputs * self.scale
215
+
216
+
217
+ def make_attn_mask(
218
+ query_length: int,
219
+ num_all_masked_kv: torch.Tensor,
220
+ query_index_offset: torch.Tensor | None = None,
221
+ kv_length: int = 0,
222
+ ) -> torch.Tensor:
223
+ """Creates a causal mask consistent with cached decoding."""
224
+ if kv_length == 0:
225
+ kv_length = query_length
226
+
227
+ q_index = torch.arange(query_length, device=num_all_masked_kv.device)[None, None, :, None]
228
+ if query_index_offset is not None:
229
+ q_index = q_index + query_index_offset[:, None, None, None]
230
+ kv_index = torch.arange(kv_length, device=num_all_masked_kv.device)[None, None, None, :]
231
+
232
+ return torch.logical_and(q_index >= kv_index, kv_index >= num_all_masked_kv[:, None, None, None])
233
+
234
+
235
+ class RotaryPositionalEmbedding(nn.Module):
236
+ """Applies rotary position embeddings to query/key projections."""
237
+
238
+ def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
239
+ super().__init__()
240
+ self.embedding_dims = embedding_dims
241
+ self.min_timescale = min_timescale
242
+ self.max_timescale = max_timescale
243
+
244
+ def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor:
245
+ if self.embedding_dims != inputs.shape[-1]:
246
+ raise ValueError("Rotary embedding dimension must equal the head dimension.")
247
+
248
+ half_dim = self.embedding_dims // 2
249
+ fraction = 2 * torch.arange(half_dim, device=inputs.device) / self.embedding_dims
250
+ timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(inputs.device)
251
+
252
+ if position is None:
253
+ position = torch.arange(inputs.shape[1], dtype=torch.float32, device=inputs.device)[None, :]
254
+
255
+ if len(inputs.shape) == 4:
256
+ position = position[..., None, None]
257
+ timescale = timescale[None, None, None, :]
258
+ elif len(inputs.shape) == 3:
259
+ position = position[..., None]
260
+ timescale = timescale[None, None, :]
261
+ else:
262
+ raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.")
263
+
264
+ sinusoid = position / timescale
265
+ sin = torch.sin(sinusoid)
266
+ cos = torch.cos(sinusoid)
267
+
268
+ first_half, second_half = torch.chunk(inputs, 2, dim=-1)
269
+ rotated_first = first_half * cos - second_half * sin
270
+ rotated_second = second_half * cos + first_half * sin
271
+ return torch.cat([rotated_first, rotated_second], dim=-1)
272
+
273
+
274
+ class PerDimScale(nn.Module):
275
+ """Learned per-dimension scaling used prior to attention."""
276
+
277
+ def __init__(self, num_dims: int):
278
+ super().__init__()
279
+ self.num_dims = num_dims
280
+ self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))
281
+
282
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
283
+ scale_factor = 1.442695041 / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
284
+ return x * scale_factor
285
+
286
+
287
+ class MultiHeadAttention(nn.Module):
288
+ """Multi-head attention supporting fused QKV projections and caching."""
289
+
290
+ def __init__(
291
+ self,
292
+ num_heads: int,
293
+ in_features: int,
294
+ *,
295
+ use_per_dim_scale: bool = True,
296
+ use_rotary_position_embeddings: bool = True,
297
+ use_bias: bool = False,
298
+ attention_fn=F.scaled_dot_product_attention,
299
+ qk_norm: str = "rms",
300
+ fuse_qkv: bool = False,
301
+ ):
302
+ super().__init__()
303
+ self.num_heads = num_heads
304
+ self.in_features = in_features
305
+ self.head_dim = in_features // num_heads
306
+ self.use_bias = use_bias
307
+ self.attention_fn = attention_fn
308
+ self.qk_norm = qk_norm
309
+ self.fuse_qkv = fuse_qkv
310
+
311
+ if in_features % num_heads != 0:
312
+ raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.")
313
+
314
+ if fuse_qkv:
315
+ self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias)
316
+ else:
317
+ self.query = nn.Linear(in_features, in_features, bias=use_bias)
318
+ self.key = nn.Linear(in_features, in_features, bias=use_bias)
319
+ self.value = nn.Linear(in_features, in_features, bias=use_bias)
320
+
321
+ self.out = nn.Linear(in_features, in_features, bias=use_bias)
322
+
323
+ if qk_norm == "rms":
324
+ self.query_ln = RMSNorm(self.head_dim)
325
+ self.key_ln = RMSNorm(self.head_dim)
326
+ else:
327
+ self.query_ln = nn.Identity()
328
+ self.key_ln = nn.Identity()
329
+
330
+ self.use_rotary_position_embeddings = use_rotary_position_embeddings
331
+ if use_rotary_position_embeddings:
332
+ self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim)
333
+
334
+ self.use_per_dim_scale = use_per_dim_scale
335
+ if use_per_dim_scale:
336
+ self.per_dim_scale = PerDimScale(self.head_dim)
337
+
338
+ def forward(
339
+ self,
340
+ inputs_q: torch.Tensor,
341
+ *,
342
+ decode_cache: DecodeCache | None = None,
343
+ patch_mask: torch.Tensor | None = None,
344
+ ) -> tuple[torch.Tensor, DecodeCache | None]:
345
+ batch, num_patches, _ = inputs_q.shape
346
+ if patch_mask is None:
347
+ patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device)
348
+
349
+ if self.fuse_qkv:
350
+ qkv = self.qkv_proj(inputs_q)
351
+ query, key, value = torch.chunk(qkv, 3, dim=-1)
352
+ query = query.view(batch, num_patches, self.num_heads, self.head_dim)
353
+ key = key.view(batch, num_patches, self.num_heads, self.head_dim)
354
+ value = value.view(batch, num_patches, self.num_heads, self.head_dim)
355
+ else:
356
+ query = self.query(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)
357
+ key = self.key(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)
358
+ value = self.value(inputs_q).view(batch, num_patches, self.num_heads, self.head_dim)
359
+
360
+ if decode_cache is None:
361
+ num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
362
+ next_index = torch.zeros_like(num_masked, dtype=torch.int32)
363
+ else:
364
+ num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1) + decode_cache.num_masked
365
+ next_index = decode_cache.next_index.clone()
366
+
367
+ if self.use_rotary_position_embeddings:
368
+ position = (
369
+ torch.arange(num_patches, device=inputs_q.device)[None, :]
370
+ + next_index[:, None]
371
+ - num_masked[:, None]
372
+ )
373
+ query = self.rotary_position_embedding(query, position)
374
+ key = self.rotary_position_embedding(key, position)
375
+
376
+ query = self.query_ln(query)
377
+ key = self.key_ln(key)
378
+
379
+ if self.use_per_dim_scale:
380
+ query = self.per_dim_scale(query)
381
+
382
+ if decode_cache is not None:
383
+ _, cache_size, _, _ = decode_cache.value.shape
384
+ start = decode_cache.next_index[0]
385
+ end = start + num_patches
386
+
387
+ decode_cache.key[:, start:end] = key
388
+ decode_cache.value[:, start:end] = value
389
+
390
+ key = decode_cache.key
391
+ value = decode_cache.value
392
+ decode_cache.next_index += num_patches
393
+ decode_cache.num_masked = num_masked
394
+ attn_mask = make_attn_mask(
395
+ query_length=num_patches,
396
+ num_all_masked_kv=num_masked,
397
+ query_index_offset=next_index,
398
+ kv_length=cache_size,
399
+ )
400
+ else:
401
+ attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked)
402
+
403
+ attn_output = F.scaled_dot_product_attention(
404
+ query.permute(0, 2, 1, 3),
405
+ key.permute(0, 2, 1, 3),
406
+ value.permute(0, 2, 1, 3),
407
+ attn_mask=attn_mask,
408
+ scale=1.0,
409
+ )
410
+ attn_output = attn_output.permute(0, 2, 1, 3)
411
+ attn_output = attn_output.reshape(batch, num_patches, self.in_features)
412
+ return self.out(attn_output), decode_cache
413
+
414
+
415
+ class Transformer(nn.Module):
416
+ """Transformer block used by TimesFM."""
417
+
418
+ def __init__(self, config: TransformerConfig):
419
+ super().__init__()
420
+ if config.attention_norm != "rms" or config.feedforward_norm != "rms":
421
+ raise ValueError("Only RMS normalization is supported.")
422
+
423
+ self.pre_attn_ln = RMSNorm(config.model_dims)
424
+ self.post_attn_ln = RMSNorm(config.model_dims)
425
+ self.attn = MultiHeadAttention(
426
+ num_heads=config.num_heads,
427
+ in_features=config.model_dims,
428
+ use_per_dim_scale=True,
429
+ use_rotary_position_embeddings=config.use_rotary_position_embeddings,
430
+ qk_norm=config.qk_norm,
431
+ fuse_qkv=config.fuse_qkv,
432
+ )
433
+
434
+ self.pre_ff_ln = RMSNorm(config.model_dims)
435
+ self.post_ff_ln = RMSNorm(config.model_dims)
436
+ self.ff0 = nn.Linear(config.model_dims, config.hidden_dims, bias=config.use_bias)
437
+ self.ff1 = nn.Linear(config.hidden_dims, config.model_dims, bias=config.use_bias)
438
+ self.activation = ResidualBlock._resolve_activation(config.ff_activation)
439
+
440
+ def forward(
441
+ self,
442
+ input_embeddings: torch.Tensor,
443
+ patch_mask: torch.Tensor,
444
+ decode_cache: DecodeCache | None = None,
445
+ ) -> tuple[torch.Tensor, DecodeCache | None]:
446
+ attn_output, decode_cache = self.attn(
447
+ inputs_q=self.pre_attn_ln(input_embeddings),
448
+ decode_cache=decode_cache,
449
+ patch_mask=patch_mask,
450
+ )
451
+ attn_output = self.post_attn_ln(attn_output) + input_embeddings
452
+ feedforward = self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output))))
453
+ output_embeddings = self.post_ff_ln(feedforward) + attn_output
454
+ return output_embeddings, decode_cache
455
+
456
+
457
+ class TimesFM2Core(nn.Module):
458
+ """Core TimesFM 2.x backbone without external dependencies."""
459
+
460
+ def __init__(self, definition: TimesFM2Definition | None = None):
461
+ super().__init__()
462
+ self.config = definition or TimesFM2Definition()
463
+
464
+ self.p = self.config.input_patch_len
465
+ self.o = self.config.output_patch_len
466
+ self.os = self.config.output_quantile_len
467
+ self.m = self.o // self.p
468
+ self.x = self.config.stacked_transformers.num_layers
469
+ self.h = self.config.stacked_transformers.transformer.num_heads
470
+ self.md = self.config.stacked_transformers.transformer.model_dims
471
+ self.hd = self.md // self.h
472
+ self.q = len(self.config.quantiles) + 1
473
+ self.aridx = self.config.decode_index
474
+
475
+ self.tokenizer = ResidualBlock(self.config.tokenizer)
476
+ self.stacked_xf = nn.ModuleList(
477
+ [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)]
478
+ )
479
+ self.output_projection_point = ResidualBlock(self.config.output_projection_point)
480
+ self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles)
481
+
482
+ def load_safetensors(self, path: str, strict: bool = True) -> None:
483
+ if _load_safetensors is None:
484
+ raise ImportError("Install safetensors to load TimesFM2 checkpoints.")
485
+ tensors = _load_safetensors(path)
486
+ self.load_state_dict(tensors, strict=strict)
487
+ self.eval()
488
+
489
+ def forward(
490
+ self,
491
+ inputs: torch.Tensor,
492
+ masks: torch.Tensor,
493
+ decode_caches: list[DecodeCache] | None = None,
494
+ ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache]]:
495
+ tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
496
+ input_embeddings = self.tokenizer(tokenizer_inputs)
497
+
498
+ if decode_caches is None:
499
+ decode_caches = [None] * self.x # type: ignore[list-item]
500
+
501
+ output_embeddings = input_embeddings
502
+ new_decode_caches: list[DecodeCache] = []
503
+ for layer, cache in zip(self.stacked_xf, decode_caches):
504
+ output_embeddings, new_cache = layer(output_embeddings, masks[..., -1], cache)
505
+ new_decode_caches.append(new_cache)
506
+
507
+ output_ts = self.output_projection_point(output_embeddings)
508
+ output_quantile_spread = self.output_projection_quantiles(output_embeddings)
509
+ return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches
510
+
511
+ def decode(
512
+ self,
513
+ horizon: int,
514
+ inputs: torch.Tensor,
515
+ masks: torch.Tensor,
516
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
517
+ """Autoregressively decodes a batch of sequences."""
518
+ with torch.no_grad():
519
+ batch_size, context = inputs.shape
520
+ num_decode_steps = (horizon - 1) // self.o
521
+ num_input_patches = context // self.p
522
+ decode_cache_size = num_input_patches + num_decode_steps * self.m
523
+
524
+ patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
525
+ patched_masks = torch.reshape(masks, (batch_size, -1, self.p))
526
+
527
+ n = torch.zeros(batch_size, device=inputs.device)
528
+ mu = torch.zeros(batch_size, device=inputs.device)
529
+ sigma = torch.zeros(batch_size, device=inputs.device)
530
+ patch_mu: list[torch.Tensor] = []
531
+ patch_sigma: list[torch.Tensor] = []
532
+ for i in range(num_input_patches):
533
+ (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patched_inputs[:, i], patched_masks[:, i])
534
+ patch_mu.append(mu)
535
+ patch_sigma.append(sigma)
536
+
537
+ last_n, last_mu, last_sigma = n, mu, sigma
538
+ context_mu = torch.stack(patch_mu, dim=1)
539
+ context_sigma = torch.stack(patch_sigma, dim=1)
540
+
541
+ decode_caches = [
542
+ DecodeCache(
543
+ next_index=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
544
+ num_masked=torch.zeros(batch_size, dtype=torch.int32, device=inputs.device),
545
+ key=torch.zeros(batch_size, decode_cache_size, self.h, self.hd, device=inputs.device),
546
+ value=torch.zeros(batch_size, decode_cache_size, self.h, self.hd, device=inputs.device),
547
+ )
548
+ for _ in range(self.x)
549
+ ]
550
+
551
+ normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
552
+ normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
553
+ (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(normed_inputs, patched_masks, decode_caches)
554
+
555
+ renormed_outputs = torch.reshape(
556
+ revin(normed_outputs, context_mu, context_sigma, reverse=True),
557
+ (batch_size, -1, self.o, self.q),
558
+ )
559
+ renormed_quantile_spread = torch.reshape(
560
+ revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
561
+ (batch_size, -1, self.os, self.q),
562
+ )[:, -1, ...]
563
+
564
+ ar_outputs: list[torch.Tensor] = []
565
+ last_renormed_output = renormed_outputs[:, -1, :, self.aridx]
566
+
567
+ for _ in range(num_decode_steps):
568
+ new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
569
+ new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)
570
+
571
+ n, mu, sigma = last_n, last_mu, last_sigma
572
+ new_mus: list[torch.Tensor] = []
573
+ new_sigmas: list[torch.Tensor] = []
574
+ for i in range(self.m):
575
+ (n, mu, sigma), _ = update_running_stats(n, mu, sigma, new_patched_input[:, i], new_mask[:, i])
576
+ new_mus.append(mu)
577
+ new_sigmas.append(sigma)
578
+ last_n, last_mu, last_sigma = n, mu, sigma
579
+ new_mu = torch.stack(new_mus, dim=1)
580
+ new_sigma = torch.stack(new_sigmas, dim=1)
581
+
582
+ new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)
583
+ (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
584
+
585
+ new_renormed_output = torch.reshape(
586
+ revin(new_normed_output, new_mu, new_sigma, reverse=True),
587
+ (batch_size, self.m, self.o, self.q),
588
+ )
589
+ ar_outputs.append(new_renormed_output[:, -1, ...])
590
+ last_renormed_output = new_renormed_output[:, -1, :, self.aridx]
591
+
592
+ ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if num_decode_steps > 0 else None
593
+
594
+ return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
595
+
596
+
597
+ class TimesFM2(nn.Module):
598
+ """High-level TimesFM 2.x wrapper mirroring the TimesFM interface."""
599
+
600
+ def __init__(self, lookback: int = 512, lookahead: int = 96):
601
+ super().__init__()
602
+ self.lookback = lookback
603
+ self.lookahead = lookahead
604
+ self.core = TimesFM2Core()
605
+
606
+ if lookback > self.core.config.context_limit:
607
+ raise ValueError(
608
+ f"lookback ({lookback}) exceeds maximum context limit ({self.core.config.context_limit})."
609
+ )
610
+
611
+ def load_state_dict(self, state_dict, strict: bool = True):
612
+ return self.core.load_state_dict(state_dict, strict=strict)
613
+
614
+ def state_dict(self, *args, **kwargs):
615
+ return self.core.state_dict(*args, **kwargs)
616
+
617
+ def load_safetensors(self, path: str, strict: bool = True) -> None:
618
+ self.core.load_safetensors(path, strict=strict)
619
+
620
+ def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
621
+ if x.shape[1] < self.lookback:
622
+ raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.")
623
+ context = x[:, -self.lookback:]
624
+ pad_len = (-context.shape[1]) % self.core.p
625
+ if pad_len > 0:
626
+ context = F.pad(context, (pad_len, 0))
627
+ pad_mask = torch.ones(context.shape[0], pad_len, dtype=torch.bool, device=context.device)
628
+ mask = torch.cat(
629
+ [pad_mask, torch.zeros(context.shape[0], self.lookback, dtype=torch.bool, device=context.device)],
630
+ dim=1,
631
+ )
632
+ else:
633
+ mask = torch.zeros_like(context, dtype=torch.bool)
634
+
635
+ if context.shape[1] > self.core.config.context_limit:
636
+ context = context[:, -self.core.config.context_limit :]
637
+ mask = mask[:, -self.core.config.context_limit :]
638
+
639
+ return context, mask
640
+
641
+ def forward(
642
+ self,
643
+ x: torch.Tensor,
644
+ *,
645
+ return_quantiles: bool = False,
646
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
647
+ if x.dim() != 2:
648
+ raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.")
649
+
650
+ inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32))
651
+ renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask)
652
+ batch_size = inputs.shape[0]
653
+
654
+ to_cat = [renormed_outputs[:, -1, ...]]
655
+ if ar_outputs is not None:
656
+ to_cat.append(ar_outputs.reshape(batch_size, -1, self.core.q))
657
+ full_forecast = torch.cat(to_cat, dim=1)[:, : self.lookahead, :]
658
+
659
+ point_forecast = full_forecast[..., self.core.aridx]
660
+ if return_quantiles:
661
+ return point_forecast, full_forecast
662
+ return point_forecast
663
+
664
+
665
+ __all__ = ["TimesFM2", "TimesFM2Core", "TimesFM2Definition"]
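
A minimal usage sketch for the TimesFM2 wrapper added in this commit (illustration only, not part of the committed file). The import path, checkpoint filename, and random input below are assumptions; adjust them to wherever the module and weights actually live.

import torch

from models.TimesFM2 import TimesFM2  # assumed module path for this repo

# Wrapper defaults match the file above: 512-step context, 96-step horizon.
model = TimesFM2(lookback=512, lookahead=96)

# Optionally load pretrained TimesFM 2.5 weights (requires `safetensors`);
# the filename here is hypothetical.
# model.load_safetensors("timesfm25_200m.safetensors")
model.eval()

# Input is a (batch, time) tensor; time must be at least `lookback`.
history = torch.randn(4, 512)

with torch.no_grad():
    point = model(history)  # (4, 96) point forecast
    point, full = model(history, return_quantiles=True)  # full: (4, 96, 10) mean + 9 quantiles

print(point.shape, full.shape)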