from typing import Any, Dict, List

from setfit import SetFitModel


class EndpointHandler:
    """Custom inference handler wrapping a SetFit topic classifier.

    Loads a SetFit model from ``path``. When called, it returns one
    ``{"label": ..., "score": ...}`` entry per class for the first input text.
    """

    def __init__(self, path=""):
        # Load the SetFit model stored at `path`.
        self.model = SetFitModel.from_pretrained(path)
        # Class index -> topic label. The keys must line up with the column
        # order of `predict_proba`'s output, because __call__ maps scores to
        # labels purely by position.
        # NOTE(review): this is a custom 59-topic label set (not ag_news);
        # confirm the order matches the label encoding used at training time.
        self.id2label = {
            0: "Art",
            1: "Artificial Intelligence",
            2: "Beauty",
            3: "Blockchain",
            4: "Business",
            5: "Cities",
            6: "Cultural Studies",
            7: "Data Science",
            8: "Design",
            9: "Dev Ops",
            10: "Drugs",
            11: "Economics",
            12: "Education",
            13: "Equality",
            14: "Family",
            15: "Fashion",
            16: "Finance",
            17: "Food",
            18: "Gadgets",
            19: "Gaming",
            20: "Health",
            21: "Home",
            22: "Humor",
            23: "Language",
            24: "Law",
            25: "Leadership",
            26: "Makers",
            27: "Marketing",
            28: "Mathematics",
            29: "Mental Health",
            30: "Mindfulness",
            31: "Movies",
            32: "Music",
            33: "Nature",
            34: "News",
            35: "Operating Systems",
            36: "Pets",
            37: "Philosophy",
            38: "Photography",
            39: "Podcasts",
            40: "Politics",
            41: "Product Management",
            42: "Productivity",
            43: "Programming",
            44: "Programming Languages",
            45: "Race",
            46: "Relationships",
            47: "Religion",
            48: "Remote Work",
            49: "Science",
            50: "Security",
            51: "Sexuality",
            52: "Spirituality",
            53: "Sports",
            54: "Tech Companies",
            55: "Television",
            56: "Transportation",
            57: "Travel",
            58: "Writing",
        }

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the request's input text against every label.

        Args:
            data: Request payload. The text is read from ``data["inputs"]``,
                either a ``str`` or a list of ``str``. If that key is absent,
                the payload itself is used as the inputs. The payload is not
                modified.

        Returns:
            A list with one ``{"label": str, "score": float}`` dict per class,
            in class-index order. When a list of texts is given, only the
            FIRST text's scores are returned. An empty list yields ``[]``.
        """
        # Read without popping so the caller's payload dict is left intact.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        if isinstance(inputs, list) and not inputs:
            # Nothing to score; indexing [0] below would otherwise fail.
            return []

        # NOTE: only the first input's probability row is returned. This
        # keeps the flat response shape that existing clients depend on.
        scores = self.model.predict_proba(inputs)[0]
        # `.item()` turns each array/tensor element into a plain Python
        # number so the response is JSON-serializable.
        return [
            {"label": self.id2label[i], "score": score.item()}
            for i, score in enumerate(scores)
        ]