"""Custom Hugging Face Inference Endpoints handler serving a text-generation pipeline."""

import logging
from typing import Any, Dict, List, Tuple

import torch
import transformers

logger = logging.getLogger(__name__)

# Hub repository the pipeline weights are loaded from.
MODEL_ID = "humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint"

# Default cap on newly generated tokens; a request may override it through
# its "parameters" mapping.
MAX_TOKENS = 8192


def _split_payload(data: Any) -> Tuple[Any, Dict[str, Any]]:
    """Split a request into pipeline inputs and extra generation kwargs.

    A dict with an ``"inputs"`` key is treated as the Inference Endpoints
    request shape ``{"inputs": ..., "parameters": {...}}``. Anything else
    (a prompt string, a list of chat messages, or a batch of either) is
    returned unchanged with no extra kwargs.

    Raises:
        TypeError: if ``"parameters"`` is present but is not a dict.
    """
    if not (isinstance(data, dict) and "inputs" in data):
        return data, {}
    parameters = data.get("parameters") or {}
    if not isinstance(parameters, dict):
        raise TypeError(
            f"'parameters' must be a mapping, got {type(parameters).__name__}"
        )
    # Copy so later merging never mutates the caller's payload.
    return data["inputs"], dict(parameters)


def _last_generated(outputs: Any) -> Any:
    """Return the final generated item of the first result, or ``None``.

    Tolerates both a single-input result (``[{"generated_text": ...}]``) and
    a batched result (a list of such lists). When ``generated_text`` is a
    list (chat input) its last element is returned; otherwise the value is
    returned whole. Any other shape yields ``None`` instead of raising.
    """
    first = outputs[0] if isinstance(outputs, list) and outputs else None
    if isinstance(first, list):  # batched input: inspect the first sample
        first = first[0] if first else None
    if not isinstance(first, dict):
        return None
    generated = first.get("generated_text")
    if isinstance(generated, list):
        return generated[-1] if generated else None
    return generated


class EndpointHandler:
    """Inference Endpoints handler wrapping a ``text-generation`` pipeline."""

    def __init__(self, path: str = "") -> None:
        """Load the text-generation pipeline for ``MODEL_ID``.

        Args:
            path: Model directory supplied by the Inference Endpoints toolkit,
                which constructs the handler as ``EndpointHandler(path=...)``.
                Accepted so that construction succeeds; NOTE(review): it is
                currently unused and the weights always come from
                ``MODEL_ID`` — confirm this is intended.
        """
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model=MODEL_ID,
            model_kwargs={
                "torch_dtype": torch.bfloat16,
                # Reduce peak CPU RAM while the weights are being loaded.
                "low_cpu_mem_usage": True,
            },
            # Let transformers/accelerate choose device placement.
            device_map="auto",
        )

    def __call__(self, text_inputs: Any) -> List[Any]:
        """Run generation for one request and return the raw pipeline output.

        Args:
            text_inputs: Either raw pipeline input (a prompt string, a chat
                message list, or a batch of either) or an Inference Endpoints
                payload ``{"inputs": ..., "parameters": {...}}`` whose
                parameters are forwarded to the pipeline as generation kwargs.

        Returns:
            The pipeline output, unchanged.

        Raises:
            TypeError: if the payload's ``"parameters"`` is not a dict.
        """
        inputs, parameters = _split_payload(text_inputs)
        # Request parameters take precedence over the default token budget.
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": MAX_TOKENS,
            **parameters,
        }
        outputs = self.pipeline(inputs, **generate_kwargs)
        last = _last_generated(outputs)
        if last is not None:
            logger.info("Last generated item: %s", last)
        return outputs