philschmid's picture
philschmid HF staff
add custom handler
3ed65c3
raw
history blame contribute delete
No virus
919 Bytes
from typing import Dict, List, Any
from transformers import pipeline
import holidays
class EndpointHandler:
    """Request handler that wraps a text-classification pipeline.

    On top of the model prediction, a request may carry a ``date`` field;
    when that date is found in the US holiday calendar the handler skips the
    model and answers with a fixed ``happy`` label.
    """

    def __init__(self, path=""):
        """Load the model and the holiday calendar.

        Args:
            path: Passed as ``model=`` to the ``text-classification``
                pipeline.
        """
        self.pipeline = pipeline("text-classification", model=path)
        self.holidays = holidays.US()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one request payload.

        Args:
            data: Request payload. Recognised keys:

                * ``inputs``: value handed to the pipeline. When the key is
                  absent, the remaining payload (without ``date``) is handed
                  to the pipeline instead.
                * ``date`` (optional): value looked up in the holiday
                  calendar. NOTE(review): accepted formats are whatever the
                  ``holidays`` membership test supports -- confirm before
                  documenting a format to clients.

        Returns:
            ``[{"label": "happy", "score": 1}]`` when ``date`` is a holiday,
            otherwise whatever the pipeline returns for ``inputs``.
        """
        # Work on a shallow copy so the caller's dict is not stripped of its
        # "inputs" / "date" keys as a side effect of handling the request.
        payload = dict(data)

        # Fall back to the payload itself when no "inputs" key was sent; the
        # "date" pop below then also removes the date from what the pipeline
        # receives.
        inputs = payload.pop("inputs", payload)
        date = payload.pop("date", None)

        # Holiday short-circuit: no model call is made in this branch.
        if date is not None and date in self.holidays:
            return [{"label": "happy", "score": 1}]

        return self.pipeline(inputs)