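"""Custom handler for a Hugging Face Inference Endpoint.

Validates the request body, tokenizes the input with truncation, and
returns text generated by a seq2seq model (BART-style, per the 1024-token
limit below).
"""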
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Any


class EndpointHandler:
    def __init__(self, path=""):
        # device_map="auto" places the model on the GPU when one is available.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        inputs = data.get("inputs")
        parameters = data.get("parameters")

        if inputs is None:
            raise ValueError("'inputs' is missing from the request body")
        if not isinstance(inputs, str):
            raise ValueError(f"Expected 'inputs' to be a str, but found {type(inputs)}")
        if parameters is not None and not isinstance(parameters, dict):
            raise ValueError(f"Expected 'parameters' to be a dict, but found {type(parameters)}")

        # Truncate to 1024 tokens to stay within BART's maximum input length
        # and prevent errors on long text.
        tokens = self.tokenizer(
            inputs,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
            return_attention_mask=False,
        )

        # Move the input_ids to the model's device so both sit on the same
        # device (GPU or CPU) and generation does not fail with a mismatch.
        input_ids = tokens.input_ids.to(self.model.device)

        # Gradient calculation is not needed for inference.
        with torch.no_grad():
            if parameters is None:
                output = self.model.generate(input_ids)
            else:
                output = self.model.generate(input_ids, **parameters)

        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": generated_text}
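

# Minimal local smoke test: a sketch, not part of the deployed handler.
# Assumptions not in the original: the model/tokenizer files live in the
# current directory, and "max_new_tokens" is a valid generate() parameter
# for this model. Inference Endpoints invoke __call__ with the parsed JSON
# request body in exactly this shape.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    result = handler({
        "inputs": "A long passage of text to summarize.",
        "parameters": {"max_new_tokens": 128},
    })
    print(result["generated_text"])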