MarkusWesterwald commited on
Commit
619d92b
1 Parent(s): 68dcb64

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +90 -0
handler.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from setfit import SetFitModel
3
+
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # load model
8
+ self.model = SetFitModel.from_pretrained(path)
9
+ # ag_news id to label mapping
10
+ self.id2label = {
11
+ 0: "Art",
12
+ 1: "Artificial Intelligence",
13
+ 2: "Beauty",
14
+ 3: "Blockchain",
15
+ 4: "Business",
16
+ 5: "Cities",
17
+ 6: "Cultural Studies",
18
+ 7: "Data Science",
19
+ 8: "Design",
20
+ 9: "Dev Ops",
21
+ 10: "Drugs",
22
+ 11: "Economics",
23
+ 12: "Education",
24
+ 13: "Equality",
25
+ 14: "Family",
26
+ 15: "Fashion",
27
+ 16: "Finance",
28
+ 17: "Food",
29
+ 18: "Gadgets",
30
+ 19: "Gaming",
31
+ 20: "Health",
32
+ 21: "Home",
33
+ 22: "Humor",
34
+ 23: "Language",
35
+ 24: "Law",
36
+ 25: "Leadership",
37
+ 26: "Makers",
38
+ 27: "Marketing",
39
+ 28: "Mathematics",
40
+ 29: "Mental Health",
41
+ 30: "Mindfulness",
42
+ 31: "Movies",
43
+ 32: "Music",
44
+ 33: "Nature",
45
+ 34: "News",
46
+ 35: "Operating Systems",
47
+ 36: "Pets",
48
+ 37: "Philosophy",
49
+ 38: "Photography",
50
+ 39: "Podcasts",
51
+ 40: "Politics",
52
+ 41: "Product Management",
53
+ 42: "Productivity",
54
+ 43: "Programming",
55
+ 44: "Programming Languages",
56
+ 45: "Race",
57
+ 46: "Relationships",
58
+ 47: "Religion",
59
+ 48: "Remote Work",
60
+ 49: "Science",
61
+ 50: "Security",
62
+ 51: "Sexuality",
63
+ 52: "Spirituality",
64
+ 53: "Sports",
65
+ 54: "Tech Companies",
66
+ 55: "Television",
67
+ 56: "Transportation",
68
+ 57: "Travel",
69
+ 58: "Writing",
70
+ }
71
+
72
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
73
+ """
74
+ data args:
75
+ inputs (:obj: `str`)
76
+ Return:
77
+ A :obj:`list` | `dict`: will be serialized and returned
78
+ """
79
+ # get inputs
80
+ inputs = data.pop("inputs", data)
81
+ if isinstance(inputs, str):
82
+ inputs = [inputs]
83
+
84
+ # run normal prediction
85
+ scores = self.model.predict_proba(inputs)[0]
86
+
87
+ return [
88
+ {"label": self.id2label[i], "score": score.item()}
89
+ for i, score in enumerate(scores)
90
+ ]