# saytext/handler.py — uploaded via huggingface_hub (revision 474b76e, verified)
"""HuggingFace Inference API handler for text normalization.
This enables the model to work with the HuggingFace Inference API
and the `text2text-generation` pipeline.
"""
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
class EndpointHandler:
    """Inference endpoint handler for ByT5-based text normalization.

    Loads a fine-tuned seq2seq checkpoint and answers requests in the
    `text2text-generation` pipeline format: the input text carries (or is
    given) a ``<lang>`` prefix, and the normalized text is returned.
    """

    def __init__(self, path: str = ""):
        """Load the model from *path* and the stock ByT5 tokenizer.

        Args:
            path: Local directory (or hub id) of the fine-tuned model,
                as supplied by the Inference Endpoints runtime.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # ByT5 is byte-level, so the stock base tokenizer works for any
        # fine-tuned checkpoint — no per-model tokenizer files required.
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
        self.model.eval()

    def __call__(self, data: dict) -> list[dict]:
        """Handle an inference request.

        Expected input format:
            {"inputs": "<de> Das kostet 12,50 €."}
        or:
            {"inputs": "Das kostet 12,50 €.", "parameters": {"language": "de"}}

        Supported ``parameters``: ``language``, ``max_length`` (input
        truncation, default 512), ``max_new_tokens`` (default 512),
        ``num_beams`` (default 1).

        Returns:
            A one-element list: ``[{"generated_text": "<normalized text>"}]``.
        """
        inputs = data.get("inputs", "")
        # `or {}` also guards an explicit `"parameters": null` in the JSON
        # payload, which `data.get("parameters", {})` would pass through as
        # None and crash the `in` test below.
        params = data.get("parameters") or {}
        # If the language is passed separately, prepend the "<xx>" prefix the
        # model was trained with (skip when the prompt already carries one).
        if not inputs.startswith("<") and "language" in params:
            inputs = f"<{params['language']}> {inputs}"
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            max_length=params.get("max_length", 512),
            truncation=True,
        )
        with torch.no_grad():  # inference only — skip autograd bookkeeping
            output = self.model.generate(
                **tokenized,
                max_new_tokens=params.get("max_new_tokens", 512),
                num_beams=params.get("num_beams", 1),
            )
        result = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return [{"generated_text": result}]