"""HuggingFace Inference API handler for text normalization.

This enables the model to work with the HuggingFace Inference API
and the `text2text-generation` pipeline.
"""
|
|
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
|
|
|
class EndpointHandler:
    """Custom Inference Endpoints handler for a ByT5 text-normalization model."""

    def __init__(self, path: str = ""):
        """Load the fine-tuned seq2seq model from *path* and put it in eval mode.

        NOTE(review): the tokenizer is loaded from the hard-coded base
        checkpoint "google/byt5-base" rather than from *path* — confirm the
        deployed repo does not ship a customized tokenizer of its own.
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained("google/byt5-base")
        self.model.eval()

    def __call__(self, data: dict) -> list[dict]:
        """Handle inference request.

        Expected input format:
            {"inputs": "<de> Das kostet 12,50 €."}
        or:
            {"inputs": "Das kostet 12,50 €.", "parameters": {"language": "de"}}

        Returns:
            A list with a single ``{"generated_text": <normalized text>}`` dict,
            matching the ``text2text-generation`` pipeline output shape.
        """
        inputs = data.get("inputs", "")
        # Guard against an explicit {"parameters": null} in the request body,
        # which .get(..., {}) alone would pass through as None.
        params = data.get("parameters", {}) or {}

        # Prepend the language tag unless the caller already embedded one.
        if not inputs.startswith("<") and "language" in params:
            inputs = f"<{params['language']}> {inputs}"

        tokenized = self.tokenizer(
            inputs, return_tensors="pt", max_length=512, truncation=True
        )

        # Inference only — skip autograd bookkeeping.  torch is imported at
        # module level instead of per-call.
        with torch.no_grad():
            output = self.model.generate(
                **tokenized,
                max_new_tokens=params.get("max_new_tokens", 512),
                num_beams=params.get("num_beams", 1),
            )

        result = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return [{"generated_text": result}]
|
|