Time Series Forecasting
TiRex
Nikita committed · Commit 155c2a4 · 1 Parent(s): e015236

inference and reqs

Files changed (2)
  1. inference.py +45 -0
  2. requirements.txt +1 -0
inference.py ADDED
@@ -0,0 +1,45 @@
+import torch
+import os
+from tirex import load_model, ForecastModel
+
+# Disable CUDA for Hugging Face endpoints unless explicitly enabled
+os.environ['TIREX_NO_CUDA'] = '1'
+
+class EndpointModel:
+    def __init__(self):
+        """
+        This class is used by Hugging Face Inference Endpoints
+        to initialize the model once at startup.
+        """
+        # Load the TiRex model from the Hugging Face Hub
+        # (this resolves to the NX-AI/TiRex repository)
+        self.model: ForecastModel = load_model("NX-AI/TiRex")
+
+    def __call__(self, inputs: dict) -> dict:
+        """
+        This method is called for every inference request.
+        Inputs must be JSON-serializable.
+        Example request:
+        {
+            "data": [[0.1, 0.2, 0.3, ...], [0.5, 0.6, ...]],  # 2D array: batch_size x context_length
+            "prediction_length": 64
+        }
+        """
+        # Convert input data to a torch tensor
+        data = torch.tensor(inputs["data"], dtype=torch.float32)
+
+        # Default prediction length if not provided
+        prediction_length = inputs.get("prediction_length", 64)
+
+        # Run forecast
+        quantiles, mean = self.model.forecast(
+            context=data,
+            prediction_length=prediction_length
+        )
+
+        # Return quantiles and mean as JSON-safe lists (quantiles may come back as a tensor or a dict of tensors)
+        return {
+            "quantiles": quantiles.tolist() if torch.is_tensor(quantiles) else {k: v.tolist() for k, v in quantiles.items()},
+            "mean": mean.tolist()
+        }
+
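To sanity-check the handler locally before deploying, a minimal sketch along these lines can exercise it with random data. It assumes the tirex package from requirements.txt is installed and that inference.py is on the import path; the payload shapes are only illustrative.

# local smoke test for the handler (illustrative only)
import random

from inference import EndpointModel

# Build a request in the same shape the endpoint expects:
# two series of length 128, forecasting 32 steps ahead.
payload = {
    "data": [[random.random() for _ in range(128)] for _ in range(2)],
    "prediction_length": 32,
}

model = EndpointModel()    # loads NX-AI/TiRex once, as at endpoint startup
response = model(payload)  # run a forecast on the sample payload

print(len(response["mean"]))  # number of series forecast (2 here)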
requirements.txt ADDED
@@ -0,0 +1 @@
+git+https://github.com/NX-AI/tirex.git
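Once deployed, the endpoint can be queried over HTTP with the same JSON schema the handler documents. A hedged client sketch follows; it assumes the endpoint forwards the JSON body to the handler unchanged, and the endpoint URL and token are placeholders, not part of this commit.

# example client call; ENDPOINT_URL and HF_TOKEN are placeholders
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

payload = {
    "data": [[0.1, 0.2, 0.3, 0.4, 0.5] * 20],  # one series of length 100
    "prediction_length": 16,
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
resp.raise_for_status()
forecast = resp.json()
print(forecast["mean"])  # mean forecast per series, 16 steps each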