Vrk committed on
Commit
cbb407f
1 Parent(s): cecc69d
Files changed (1) hide show
  1. Tokenizer.py +71 -0
Tokenizer.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
from collections import Counter

import numpy as np
4
class Tokenizer(object):
    """Maps tokens to integer indices and back, for text vectorization.

    Index 0 is reserved for the padding token and index 1 for the
    out-of-vocabulary (OOV) token; fitted vocabulary entries start at 2.
    """

    def __init__(self, char_level, num_tokens=None,
                 pad_token="<PAD>", oov_token="<UNK>",
                 token_to_index=None):
        """Initialize the tokenizer.

        Args:
            char_level: if True, tokenize per character; otherwise split
                texts on single spaces into word tokens.
            num_tokens: optional cap on total vocabulary size, counting
                the two reserved pad/OOV entries. None keeps every token.
            pad_token: token reserved at index 0.
            oov_token: token used for any unknown token or index.
            token_to_index: optional pre-built vocabulary mapping; when
                omitted, a fresh mapping with only pad/OOV is created.
        """
        self.char_level = char_level
        # Character-level output joins with no separator; word-level with a space.
        self.separator = "" if self.char_level else " "
        if num_tokens:
            num_tokens -= 2  # reserve room for the pad + OOV tokens
        self.num_tokens = num_tokens
        self.pad_token = pad_token
        self.oov_token = oov_token
        if not token_to_index:
            token_to_index = {pad_token: 0, oov_token: 1}
        self.token_to_index = token_to_index
        # Inverse mapping for decoding; kept in sync by fit_on_texts.
        self.index_to_token = {v: k for k, v in self.token_to_index.items()}

    def __len__(self):
        """Return the current vocabulary size (including pad/OOV)."""
        return len(self.token_to_index)

    def __str__(self):
        return f"<Tokenizer(num_tokens={len(self)})>"

    def fit_on_texts(self, texts):
        """Build the vocabulary from an iterable of strings.

        Tokens are added in descending frequency order, truncated to
        ``self.num_tokens`` most common when a cap was given.
        Also records ``self.min_token_freq``, the count of the rarest
        token kept (0 when no tokens were seen).

        Returns:
            self, to allow ``Tokenizer(...).fit_on_texts(texts)`` chaining.
        """
        if not self.char_level:
            texts = [text.split(" ") for text in texts]
        all_tokens = [token for text in texts for token in text]
        # Counter is imported from collections at module level; the
        # original file omitted the import, making this raise NameError.
        counts = Counter(all_tokens).most_common(self.num_tokens)
        # Guard against empty input (the original raised IndexError here).
        self.min_token_freq = counts[-1][1] if counts else 0
        for token, count in counts:
            index = len(self)
            self.token_to_index[token] = index
            self.index_to_token[index] = token
        return self

    def texts_to_sequences(self, texts):
        """Convert each text into a numpy array of token indices.

        Unknown tokens map to the OOV index.
        """
        # Hoist the OOV index lookup out of the loops.
        oov_index = self.token_to_index[self.oov_token]
        sequences = []
        for text in texts:
            if not self.char_level:
                text = text.split(" ")
            sequence = [self.token_to_index.get(token, oov_index)
                        for token in text]
            sequences.append(np.asarray(sequence))
        return sequences

    def sequences_to_texts(self, sequences):
        """Convert each sequence of indices back into a string.

        Unknown indices decode to the OOV token.
        """
        texts = []
        for sequence in sequences:
            tokens = [self.index_to_token.get(index, self.oov_token)
                      for index in sequence]
            texts.append(self.separator.join(tokens))
        return texts

    def save(self, fp):
        """Serialize the tokenizer configuration to a JSON file.

        Args:
            fp: destination file path.

        ``pad_token`` is now saved as well, so ``load()`` round-trips a
        non-default padding token (the original dropped it on save).
        """
        contents = {
            "char_level": self.char_level,
            "pad_token": self.pad_token,
            "oov_token": self.oov_token,
            "token_to_index": self.token_to_index
        }
        # Avoid shadowing the path argument `fp` with the file handle.
        with open(fp, "w") as f:
            json.dump(contents, f, indent=4, sort_keys=False)

    @classmethod
    def load(cls, fp):
        """Reconstruct a Tokenizer from a JSON file written by save().

        Args:
            fp: source file path.
        """
        with open(fp, "r") as f:
            kwargs = json.load(fp=f)
        return cls(**kwargs)