|
from typing import Dict, List, Any |
|
import transformers |
|
import torch |
|
|
|
# Token budget of 8192 (8 * 1024).
# NOTE(review): nothing in the visible code reads this constant; generation
# below is capped by its own hard-coded max_new_tokens value. Confirm whether
# this is consumed elsewhere or is a leftover.
MAX_TOKENS = 8 * 1024
|
|
|
class EndpointHandler:
    """Callable wrapper around a ``transformers`` text-generation pipeline.

    The pipeline is built once in ``__init__`` and reused for every
    ``__call__``.
    """

    def __init__(self, path=""):
        """Build the text-generation pipeline.

        Args:
            path: Accepted for interface compatibility but not used; the
                model is always loaded from the hard-coded identifier below.
        """
        self.pipeline = transformers.pipeline(
            "text-generation",
            model="humane-intelligence/gemma2-9b-cpt-sealionv3-instruct-endpoint",
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Any) -> List[Any]:
        """Run generation for one request payload.

        Args:
            data: The request payload. When it offers a ``get`` method
                (e.g. a dict) the value under ``"inputs"`` is used, falling
                back to the payload itself when that key is absent. Any
                other payload is forwarded to the pipeline as-is.

        Returns:
            The pipeline's result, unchanged.
        """
        # Only go through .get() when the payload actually has one, so a raw
        # (non-mapping) payload is forwarded instead of raising
        # AttributeError.
        getter = getattr(data, "get", None)
        inputs = getter("inputs", data) if callable(getter) else data

        outputs = self.pipeline(
            inputs,
            max_new_tokens=256,
        )

        # Debug preview of the last element of the first result's generated
        # text. Results that do not have that shape (empty, nested/batched,
        # missing key) are skipped: a log line must never fail a request
        # whose generation already succeeded.
        try:
            last_turn = outputs[0]["generated_text"][-1]
        except (IndexError, KeyError, TypeError):
            pass
        else:
            print(last_turn)

        return outputs
|
|