from typing import Any, Dict, List

import torch
import transformers

# Upper bound on the number of newly generated tokens per request.
MAX_TOKENS = 8192


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Inference Endpoints instantiates the handler with the local model
        # directory as `path`; this handler loads a fixed model from the Hub instead.
        self.pipeline: transformers.Pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16, "low_cpu_mem_usage": True},
            device_map="auto",
        )

    def __call__(self, text_inputs: Any) -> List[Dict[str, Any]]:
        # Forward the request payload to the pipeline unchanged.
        outputs = self.pipeline(
            text_inputs,
            max_new_tokens=MAX_TOKENS,
        )
        # Log the last element of the first generation (for chat-style inputs,
        # this is the newly generated assistant message).
        print(outputs[0]["generated_text"][-1])
        return outputs
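

# Minimal local smoke test: a sketch, not part of the deployed endpoint. It
# assumes chat-style message lists are the expected payload (which the print of
# outputs[0]["generated_text"][-1] above suggests) and that the host has enough
# accelerator memory to load the 9B model in bfloat16.
if __name__ == "__main__":
    handler = EndpointHandler()
    messages = [{"role": "user", "content": "Hello! What can you do?"}]
    results = handler(messages)
    # results[0]["generated_text"] holds the whole conversation; the last
    # entry is the assistant reply generated above.
    print(results[0]["generated_text"][-1])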