|
import logging
from typing import Any, Dict, List, Optional

import torch
import transformers
|
|
|
# Default cap on generated tokens, applied when a request supplies no parameters.
MAX_TOKENS = 1024
|
|
|
class EndpointHandler(object):
    """Hugging Face Inference Endpoints handler around a text-generation pipeline."""

    def __init__(self, path: str = ''):
        """Load the pinned text-generation model.

        :param path: local model directory supplied by the endpoint runtime;
            unused here because the model is pinned by repository id.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one generation request.

        :param data: request payload containing
            ``inputs`` — the prompt / message list forwarded to the pipeline, and
            ``parameters`` (optional) — keyword arguments for the pipeline call.
        :return: the raw pipeline output — for a single prompt this is a list of
            generation dicts (each with a ``generated_text`` string); the exact
            nesting follows the transformers pipeline contract.
        :raises KeyError: if ``data`` has no ``inputs`` key.
        """
        # Lazy %-formatting: the payload (which may contain user prompts) is
        # only rendered when DEBUG logging is enabled, instead of always
        # printing it to stdout.
        logging.getLogger(__name__).debug("data: %s", data)

        inputs = data.pop("inputs")
        parameters: Optional[Dict] = data.pop("parameters", None)

        # Fall back to a bounded generation length when the caller supplies
        # no explicit pipeline parameters.
        if parameters is None:
            parameters = {"max_new_tokens": MAX_TOKENS}
        return self.pipeline(inputs, **parameters)
|
|