|
from typing import Dict, List, Any |
|
from setfit import SetFitModel |
|
|
|
|
|
class EndpointHandler:
    """Inference handler that serves a SetFit topic classifier.

    The model is loaded once from ``path``. Each request is answered with
    one probability per topic label listed in ``ID2LABEL``.
    """

    # Class index -> human-readable topic label.
    # NOTE(review): assumed to match the class order the model was trained
    # with; confirm against the training config before editing entries.
    ID2LABEL: Dict[int, str] = {
        0: "Art",
        1: "Artificial Intelligence",
        2: "Beauty",
        3: "Blockchain",
        4: "Business",
        5: "Cities",
        6: "Cultural Studies",
        7: "Data Science",
        8: "Design",
        9: "Dev Ops",
        10: "Drugs",
        11: "Economics",
        12: "Education",
        13: "Equality",
        14: "Family",
        15: "Fashion",
        16: "Finance",
        17: "Food",
        18: "Gadgets",
        19: "Gaming",
        20: "Health",
        21: "Home",
        22: "Humor",
        23: "Language",
        24: "Law",
        25: "Leadership",
        26: "Makers",
        27: "Marketing",
        28: "Mathematics",
        29: "Mental Health",
        30: "Mindfulness",
        31: "Movies",
        32: "Music",
        33: "Nature",
        34: "News",
        35: "Operating Systems",
        36: "Pets",
        37: "Philosophy",
        38: "Photography",
        39: "Podcasts",
        40: "Politics",
        41: "Product Management",
        42: "Productivity",
        43: "Programming",
        44: "Programming Languages",
        45: "Race",
        46: "Relationships",
        47: "Religion",
        48: "Remote Work",
        49: "Science",
        50: "Security",
        51: "Sexuality",
        52: "Spirituality",
        53: "Sports",
        54: "Tech Companies",
        55: "Television",
        56: "Transportation",
        57: "Travel",
        58: "Writing",
    }

    def __init__(self, path=""):
        """Load the SetFit model from ``path``.

        Args:
            path: Location passed straight to ``SetFitModel.from_pretrained``.
        """
        self.model = SetFitModel.from_pretrained(path)
        # Per-instance copy, so mutating one handler's mapping cannot leak
        # into the class-level default or into other handlers.
        self.id2label = dict(self.ID2LABEL)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score the request text against every topic label.

        Args:
            data: Request payload. The text is read from ``data["inputs"]``,
                which must be a string or a list/tuple of strings. The
                payload is not modified.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts, one per model
            class, in class-index order, for the FIRST input only. An empty
            list of inputs yields an empty list.

        Raises:
            TypeError: If ``"inputs"`` is missing or is neither a string nor
                a list/tuple.
        """
        # ``get`` rather than ``pop``: do not mutate the caller's payload.
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = [inputs]
        elif isinstance(inputs, (list, tuple)):
            inputs = list(inputs)
        else:
            # Covers a missing "inputs" key (``inputs`` is then the payload
            # dict itself), which the model cannot encode meaningfully.
            raise TypeError(
                "expected 'inputs' to be a string or a list of strings, "
                f"got {type(inputs).__name__}"
            )

        if not inputs:
            # Nothing to score; indexing ``[0]`` below would fail.
            return []

        # NOTE: only the first input's probabilities are returned. This keeps
        # the flat response shape that existing callers rely on; any extra
        # inputs are scored by the model but their results are discarded.
        scores = self.model.predict_proba(inputs)[0]

        # ``float`` accepts torch 0-d tensors, numpy scalars and plain floats.
        return [
            {"label": self.id2label[i], "score": float(score)}
            for i, score in enumerate(scores)
        ]
|
|