MarkusWesterwald's picture
Upload handler.py
619d92b
raw
history blame
No virus
2.55 kB
from typing import Dict, List, Any
from setfit import SetFitModel
class EndpointHandler:
    """Custom inference handler that scores text against 59 topic labels.

    A SetFit model is loaded once at construction time; each call returns a
    probability for every label in ``id2label``.
    """

    # Topic labels in model-output order: element ``i`` names probability
    # column ``i`` of ``predict_proba``. The ids are derived from position, so
    # they cannot drift out of sequence.
    # NOTE(review): the order must match the label ids used when the model was
    # trained; confirm against the training setup.
    _LABELS = (
        "Art",  # 0
        "Artificial Intelligence",
        "Beauty",
        "Blockchain",
        "Business",
        "Cities",
        "Cultural Studies",
        "Data Science",
        "Design",
        "Dev Ops",
        "Drugs",  # 10
        "Economics",
        "Education",
        "Equality",
        "Family",
        "Fashion",
        "Finance",
        "Food",
        "Gadgets",
        "Gaming",
        "Health",  # 20
        "Home",
        "Humor",
        "Language",
        "Law",
        "Leadership",
        "Makers",
        "Marketing",
        "Mathematics",
        "Mental Health",
        "Mindfulness",  # 30
        "Movies",
        "Music",
        "Nature",
        "News",
        "Operating Systems",
        "Pets",
        "Philosophy",
        "Photography",
        "Podcasts",
        "Politics",  # 40
        "Product Management",
        "Productivity",
        "Programming",
        "Programming Languages",
        "Race",
        "Relationships",
        "Religion",
        "Remote Work",
        "Science",
        "Security",  # 50
        "Sexuality",
        "Spirituality",
        "Sports",
        "Tech Companies",
        "Television",
        "Transportation",
        "Travel",
        "Writing",  # 58
    )

    def __init__(self, path=""):
        """Load the SetFit model and build the id -> label mapping.

        Args:
            path: Model location passed to ``SetFitModel.from_pretrained``.
        """
        # load model
        self.model = SetFitModel.from_pretrained(path)
        # Topic id -> label name (0..58), taken from the ordered label tuple.
        self.id2label = dict(enumerate(self._LABELS))

    def __call__(self, data: Dict[str, Any]) -> List[Any]:
        """Score the text in ``data["inputs"]`` against every topic label.

        Args:
            data: Request payload. ``data["inputs"]`` is a string or a list
                of strings. If the key is missing, ``data`` itself is used as
                the inputs, matching the previous behavior. The payload is
                not modified.

        Returns:
            For a single string or a one-element list: a list of
            ``{"label": str, "score": float}`` dicts, one per column of the
            model output, in ``id2label`` order. For a list of two or more
            strings: one such list per input, in input order. An empty input
            list yields ``[]``.
        """
        # Read instead of pop so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        if not inputs:
            # Nothing to score; avoids indexing into an empty prediction.
            return []
        probabilities = self.model.predict_proba(inputs)
        results = [self._to_label_scores(row) for row in probabilities]
        # A single input keeps the original flat response shape. Batches used
        # to return only the first row; now every input gets its own list.
        return results[0] if len(results) == 1 else results

    def _to_label_scores(self, scores) -> List[Dict[str, Any]]:
        """Pair each probability in *scores* with its label from ``id2label``."""
        # float() converts torch/numpy scalars as well as plain Python floats.
        return [
            {"label": self.id2label[i], "score": float(score)}
            for i, score in enumerate(scores)
        ]