"""Request handler wrapping a ``transformers`` text-generation pipeline.

The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the Hugging Face
Inference Endpoints custom-handler convention — NOTE(review): confirm against the
deployment that loads this file.
"""
from typing import Any, Dict, List, Optional

import torch
import transformers

# Default cap on newly generated tokens, applied whenever the request does not
# specify its own length limit.
MAX_TOKENS = 1024


def _generation_kwargs(parameters: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the keyword arguments forwarded to the pipeline call.

    Caller-supplied parameters are copied as-is. ``max_new_tokens=MAX_TOKENS``
    is added only when the caller gave neither ``max_new_tokens`` nor
    ``max_length``. Without that default, any request that sets other
    parameters would lose the length cap and fall back to the model's
    default generation length.
    """
    kwargs: Dict[str, Any] = dict(parameters) if parameters else {}
    if "max_new_tokens" not in kwargs and "max_length" not in kwargs:
        kwargs["max_new_tokens"] = MAX_TOKENS
    return kwargs


class EndpointHandler:
    """Serve text generation from a fixed Hugging Face Hub model."""

    def __init__(self, path: str = "") -> None:
        """Load the text-generation pipeline.

        Args:
            path: Accepted for interface compatibility but not used. The model
                is always loaded from the Hub id below, not from ``path``.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            # Let accelerate place model weights across available devices.
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Run generation for one request payload.

        Args:
            data: Request dict. Must contain ``"inputs"`` and may contain a
                ``"parameters"`` dict of generation keyword arguments. Both keys
                are removed from ``data`` (``dict.pop``).

        Returns:
            Whatever the pipeline returns for ``inputs``. For a single string
            prompt that is a list of dicts with a ``"generated_text"`` key.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        inputs = data.pop("inputs")
        parameters = data.pop("parameters", None)
        return self.pipeline(inputs, **_generation_kwargs(parameters))