MarkusWesterwald commited on
Commit
4bd497b
1 Parent(s): d3a8134

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -62
handler.py CHANGED
@@ -1,73 +1,16 @@
1
  from typing import Dict, List, Any
2
  from setfit import SetFitModel
 
3
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, path=""):
7
  # load model
8
  self.model = SetFitModel.from_pretrained(path)
9
- # ag_news id to label mapping
10
- self.id2label = {
11
- 0: "Art",
12
- 1: "Artificial Intelligence",
13
- 2: "Beauty",
14
- 3: "Blockchain",
15
- 4: "Business",
16
- 5: "Cities",
17
- 6: "Cultural Studies",
18
- 7: "Data Science",
19
- 8: "Design",
20
- 9: "Dev Ops",
21
- 10: "Drugs",
22
- 11: "Economics",
23
- 12: "Education",
24
- 13: "Equality",
25
- 14: "Family",
26
- 15: "Fashion",
27
- 16: "Finance",
28
- 17: "Food",
29
- 18: "Gadgets",
30
- 19: "Gaming",
31
- 20: "Health",
32
- 21: "Home",
33
- 22: "Humor",
34
- 23: "Language",
35
- 24: "Law",
36
- 25: "Leadership",
37
- 26: "Makers",
38
- 27: "Marketing",
39
- 28: "Mathematics",
40
- 29: "Mental Health",
41
- 30: "Mindfulness",
42
- 31: "Movies",
43
- 32: "Music",
44
- 33: "Nature",
45
- 34: "News",
46
- 35: "Operating Systems",
47
- 36: "Pets",
48
- 37: "Philosophy",
49
- 38: "Photography",
50
- 39: "Podcasts",
51
- 40: "Politics",
52
- 41: "Product Management",
53
- 42: "Productivity",
54
- 43: "Programming",
55
- 44: "Programming Languages",
56
- 45: "Race",
57
- 46: "Relationships",
58
- 47: "Religion",
59
- 48: "Remote Work",
60
- 49: "Science",
61
- 50: "Security",
62
- 51: "Sexuality",
63
- 52: "Spirituality",
64
- 53: "Sports",
65
- 54: "Tech Companies",
66
- 55: "Television",
67
- 56: "Transportation",
68
- 57: "Travel",
69
- 58: "Writing",
70
- }
71
 
72
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
73
  """
 
1
  from typing import Dict, List, Any
2
  from setfit import SetFitModel
3
+ import json
4
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, path=""):
8
  # load model
9
  self.model = SetFitModel.from_pretrained(path)
10
+
11
+ with open('/repository/label_config.json', 'r') as file:
12
+ raw = json.load(file)
13
+ self.id2label = {int(k): v for k, v in raw.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
  """