from typing import Any, Dict, List

import torch
import transformers

# Generation budget applied when a request does not carry its own parameters.
MAX_TOKENS = 1024


class EndpointHandler:
    """Callable wrapper around a ``transformers`` text-generation pipeline.

    NOTE(review): the class name and call signature look like the custom
    handler convention of Hugging Face Inference Endpoints -- confirm against
    the deployment setup.
    """

    def __init__(self, path: str = "") -> None:
        """Build the text-generation pipeline.

        :param path: accepted for interface compatibility but not used; the
            model identifier is hard-coded below.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Run the pipeline on one request payload.

        :param data: request payload. Both recognised keys are popped from
            the dict, so the caller's mapping is mutated.

            * ``inputs`` (required): passed to the pipeline as its first
              positional argument.
            * ``parameters`` (optional): mapping of keyword arguments
              forwarded to the pipeline call. When the key is absent or its
              value is ``None``, ``max_new_tokens=MAX_TOKENS`` is used
              instead.
        :return: whatever the pipeline returns, unchanged.
        :raises KeyError: if ``data`` has no ``inputs`` key.
        """
        inputs = data.pop("inputs")

        # Pop with a default so a request without "parameters" falls through
        # to the MAX_TOKENS branch instead of raising KeyError. The value is
        # bound *before* the None test: the previous
        # `parameters := data.pop(...) is not None` bound a bool (``:=`` has
        # lower precedence than ``is not``), so `**parameters` always failed.
        parameters = data.pop("parameters", None)
        if parameters is not None:
            return self.pipeline(inputs, **parameters)

        return self.pipeline(inputs, max_new_tokens=MAX_TOKENS)