from typing import Dict, List, Any

import torch
import transformers
# Generation budget used when a request carries no "parameters" of its own.
MAX_TOKENS = 1024


class EndpointHandler(object):
    """Callable request handler around a ``transformers`` text-generation pipeline."""

    def __init__(self, path=''):
        """
        Build the text-generation pipeline once, at handler construction.

        :param path: accepted but not used by this handler; the model is
            always loaded through its hard-coded identifier below.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
        """
        Run the pipeline on one request payload.

        :param data:
            inputs: message format; forwarded to the pipeline as its first
                positional argument. Required.
            parameters: optional dict of keyword arguments for the pipeline.
                When the key is absent or its value is ``None``, the pipeline
                is called with ``max_new_tokens=MAX_TOKENS`` instead.
        :return: whatever the pipeline returns, unmodified.
        :raises KeyError: if ``data`` has no ``"inputs"`` key.

        Note: both keys are popped, so the caller's ``data`` dict is mutated.
        """
        inputs = data.pop("inputs")
        # Bind the popped value itself. The previous walrus form assigned the
        # boolean result of the ``is not None`` comparison (``:=`` binds more
        # loosely than ``is not``), so ``**parameters`` raised TypeError; the
        # default of None also keeps a missing key from raising KeyError.
        parameters = data.pop("parameters", None)
        if parameters is None:
            return self.pipeline(inputs, max_new_tokens=MAX_TOKENS)
        # An explicit dict (even an empty one) is passed through untouched.
        return self.pipeline(inputs, **parameters)