Update handler.py
Browse files- handler.py +13 -9
handler.py
CHANGED
@@ -4,8 +4,11 @@ from typing import Dict, List, Any
|
|
4 |
import torch
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
|
|
9 |
|
10 |
class EndpointHandler():
|
11 |
def __init__(self, data):
|
@@ -17,19 +20,20 @@ class EndpointHandler():
|
|
17 |
|
18 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
19 |
inputs = data.pop("inputs", "Q: What is the chemical composition of common concrete in 2024?\nA: ")
|
20 |
-
|
21 |
try:
|
22 |
-
|
23 |
except Exception as e:
|
24 |
return json.dumps({
|
25 |
"status": "error",
|
26 |
"reason": "max_length was passed as something that was absolutely not a plain old int"
|
27 |
})
|
28 |
|
29 |
-
res =
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
34 |
|
35 |
return res
|
|
|
4 |
import torch
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
6 |
|
7 |
+
PROMPT_FORMAT= """
|
8 |
+
<|user|>
|
9 |
+
{inputs} <|end|>
|
10 |
+
<|assistant|>
|
11 |
+
"""
|
12 |
|
13 |
class EndpointHandler():
|
14 |
def __init__(self, data):
|
|
|
20 |
|
21 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
22 |
inputs = data.pop("inputs", "Q: What is the chemical composition of common concrete in 2024?\nA: ")
|
23 |
+
max_new_tokens = data.pop("max_length", 1024)
|
24 |
try:
|
25 |
+
max_new_tokens = int(max_new_tokens)
|
26 |
except Exception as e:
|
27 |
return json.dumps({
|
28 |
"status": "error",
|
29 |
"reason": "max_length was passed as something that was absolutely not a plain old int"
|
30 |
})
|
31 |
|
32 |
+
res = PROMPT_FORMAT.format(inputs=inputs)
|
33 |
+
return model(
|
34 |
+
res,
|
35 |
+
max_new_tokens=max_new_tokens,
|
36 |
+
do_sample=False
|
37 |
+
)
|
38 |
|
39 |
return res
|