Vrk commited on
Commit
f0c2a78
1 Parent(s): 72322f6
Files changed (1) hide show
  1. LabelEncoder.py +46 -0
LabelEncoder.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+
4
+ class LabelEncoder(object):
5
+ """Label encoder for tag labels."""
6
+ def __init__(self, class_to_index={}):
7
+ self.class_to_index = class_to_index
8
+ self.index_to_class = {v: k for k, v in self.class_to_index.items()}
9
+ self.classes = list(self.class_to_index.keys())
10
+
11
+ def __len__(self):
12
+ return len(self.class_to_index)
13
+
14
+ def __str__(self):
15
+ return f"<LabelEncoder(num_classes={len(self)})>"
16
+
17
+ def fit(self, y):
18
+ classes = np.unique(y)
19
+ for i, class_ in enumerate(classes):
20
+ self.class_to_index[class_] = i
21
+ self.index_to_class = {v: k for k, v in self.class_to_index.items()}
22
+ self.classes = list(self.class_to_index.keys())
23
+ return self
24
+
25
+ def encode(self, y):
26
+ encoded = np.zeros((len(y)), dtype=int)
27
+ for i, item in enumerate(y):
28
+ encoded[i] = self.class_to_index[item]
29
+ return encoded
30
+
31
+ def decode(self, y):
32
+ classes = []
33
+ for i, item in enumerate(y):
34
+ classes.append(self.index_to_class[item])
35
+ return classes
36
+
37
+ def save(self, fp):
38
+ with open(fp, "w") as fp:
39
+ contents = {'class_to_index': self.class_to_index}
40
+ json.dump(contents, fp, indent=4, sort_keys=False)
41
+
42
+ @classmethod
43
+ def load(cls, fp):
44
+ with open(fp, "r") as fp:
45
+ kwargs = json.load(fp=fp)
46
+ return cls(**kwargs)