kaluaim commited on
Commit
baf1a4b
·
verified ·
1 Parent(s): c89a48e

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +82 -0
handler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HF Inference Endpoints handler for ChatTS-14B.
3
+
4
+ Expected request JSON:
5
+ {
6
+ "inputs": {
7
+ "prompt": "Describe the trend of this series.",
8
+ "timeseries": [[0.1, 0.2, 0.3, ...]], # list of float lists, one per <ts><ts/>
9
+ "max_new_tokens": 300
10
+ }
11
+ }
12
+
13
+ The prompt MUST contain one `<ts><ts/>` placeholder per series in `timeseries`.
14
+
15
+ Response:
16
+ {"generated_text": "..."}
17
+ """
18
+ from __future__ import annotations
19
+
20
+ from typing import Any
21
+
22
+ import numpy as np
23
+ import torch
24
+ from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
25
+
26
+
27
+ class EndpointHandler:
28
+ def __init__(self, path: str = "") -> None:
29
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
30
+ dtype = torch.float16 if self.device == "cuda" else torch.float32
31
+
32
+ self.model = AutoModelForCausalLM.from_pretrained(
33
+ path,
34
+ trust_remote_code=True,
35
+ torch_dtype=dtype,
36
+ device_map=0 if self.device == "cuda" else None,
37
+ )
38
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
39
+ self.processor = AutoProcessor.from_pretrained(
40
+ path, trust_remote_code=True, tokenizer=self.tokenizer
41
+ )
42
+ self.model.eval()
43
+
44
+ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
45
+ inputs = data.get("inputs", {})
46
+ if isinstance(inputs, str):
47
+ return {
48
+ "error": "ChatTS requires structured inputs. "
49
+ "Use {'inputs': {'prompt': str, 'timeseries': [[...]], 'max_new_tokens': int}}"
50
+ }
51
+
52
+ prompt: str = inputs["prompt"]
53
+ ts_lists = inputs["timeseries"]
54
+ max_new_tokens: int = int(inputs.get("max_new_tokens", 300))
55
+
56
+ ts_arrays = [np.asarray(t, dtype=np.float64) for t in ts_lists]
57
+
58
+ formatted = (
59
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
60
+ f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
61
+ )
62
+
63
+ proc_inputs = self.processor(
64
+ text=[formatted],
65
+ timeseries=ts_arrays,
66
+ padding=True,
67
+ return_tensors="pt",
68
+ )
69
+ proc_inputs = {k: v.to(self.device) for k, v in proc_inputs.items()}
70
+
71
+ with torch.no_grad():
72
+ outputs = self.model.generate(
73
+ **proc_inputs,
74
+ max_new_tokens=max_new_tokens,
75
+ do_sample=False,
76
+ )
77
+
78
+ generated = self.tokenizer.batch_decode(
79
+ outputs[:, proc_inputs["input_ids"].shape[1] :],
80
+ skip_special_tokens=True,
81
+ )
82
+ return {"generated_text": generated[0]}