binaryaaron
committed on
update handler for inputs and parameters
Browse files- handler.py +13 -8
handler.py
CHANGED
@@ -2,21 +2,26 @@ from typing import Dict, List, Any
|
|
2 |
import transformers
|
3 |
import torch
|
4 |
|
5 |
-
MAX_TOKENS=
|
6 |
|
7 |
class EndpointHandler(object):
|
8 |
def __init__(self, path=''):
|
9 |
self.pipeline: transformers.Pipeline = transformers.pipeline(
|
10 |
"text-generation",
|
11 |
model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
|
12 |
-
model_kwargs={"torch_dtype": torch.bfloat16
|
13 |
device_map="auto",
|
14 |
)
|
15 |
|
16 |
-
def __call__(self,
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
22 |
return outputs
|
|
|
2 |
import transformers
|
3 |
import torch
|
4 |
|
5 |
+
MAX_TOKENS=1024
|
6 |
|
7 |
class EndpointHandler(object):
    """Hugging Face Inference Endpoint handler around a text-generation pipeline.

    The endpoint runtime instantiates this once at startup and then invokes the
    instance per request with the parsed JSON payload.
    """

    def __init__(self, path=''):
        # `path` is the local checkpoint directory supplied by the endpoint
        # runtime; unused here because the model is pinned by its hub id.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            # bfloat16 halves memory versus fp32 while keeping its exponent range
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for one request payload.

        Args:
            data: request payload. Must contain "inputs" (the prompt text); may
                contain "parameters", a dict of generation kwargs that is
                forwarded verbatim to the pipeline. Both keys are popped, so the
                caller's dict is mutated.

        Returns:
            The pipeline output — for a single prompt, a list of dicts each
            holding a "generated_text" entry. (The previous annotation
            List[List[Dict[str, float]]] matched a classification handler, not
            text-generation.)

        Raises:
            KeyError: if the payload has no "inputs" entry.
        """
        inputs = data.pop("inputs")

        # A caller-supplied, non-empty "parameters" dict takes full control of
        # generation; otherwise cap output length with the MAX_TOKENS default.
        if parameters := data.pop("parameters", None):
            outputs = self.pipeline(
                inputs,
                **parameters
            )
        else:
            outputs = self.pipeline(inputs, max_new_tokens=MAX_TOKENS)

        return outputs
|