gabriel-p committed on
Commit
69558da
1 Parent(s): ee6a09d

Add Spanish true caser

Browse files
Files changed (4) hide show
  1. .gitattributes +2 -0
  2. TrueCaser.py +130 -0
  3. english.dist +3 -0
  4. spanish.dist +3 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ spanish.dist filter=lfs diff=lfs merge=lfs -text
36
+ english.dist filter=lfs diff=lfs merge=lfs -text
TrueCaser.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import pickle
3
+ import string
4
+
5
+ from nltk.tokenize import word_tokenize
6
+ from nltk.tokenize.treebank import TreebankWordDetokenizer
7
+
8
+
9
class TrueCaser(object):
    """Statistical true caser.

    Restores the most likely capitalization of (typically lower-cased) text
    using unigram, bigram and trigram casing statistics loaded from a pickled
    distributions file (e.g. ``english.dist`` / ``spanish.dist``).
    """

    def __init__(self, dist_file_path):
        """Load the pickled n-gram casing distributions.

        :param dist_file_path: path to a pickle file holding a dict with the
            keys ``uni_dist``, ``backward_bi_dist``, ``forward_bi_dist``,
            ``trigram_dist`` and ``word_casing_lookup``.
        """
        # NOTE(review): pickle.load can execute arbitrary code for a crafted
        # file — only load distribution files from a trusted source.
        with open(dist_file_path, "rb") as distributions_file:
            pickle_dict = pickle.load(distributions_file)

        self.uni_dist = pickle_dict["uni_dist"]
        self.backward_bi_dist = pickle_dict["backward_bi_dist"]
        self.forward_bi_dist = pickle_dict["forward_bi_dist"]
        self.trigram_dist = pickle_dict["trigram_dist"]
        self.word_casing_lookup = pickle_dict["word_casing_lookup"]
        self.detknzr = TreebankWordDetokenizer()

    def get_score(self, prev_token, possible_token, next_token):
        """Return the log-likelihood score of *possible_token*'s casing.

        Combines additively-smoothed unigram, backward-bigram, forward-bigram
        and trigram probabilities; the bigram/trigram factors default to 1
        (zero log contribution) when the corresponding context is missing.

        :param prev_token: already true-cased previous token, or ``None``.
        :param possible_token: candidate casing of the current token.
        :param next_token: raw next token, or ``None``.
        :return: sum of the four log-probabilities (a float ``<= 0``).
        """
        pseudo_count = 5.0

        # Unigram score: smoothed P(casing) over all observed casings of the
        # word. The distributions are expected to return 0 for unseen keys
        # (Counter / defaultdict semantics) — TODO confirm against the
        # training script that builds the .dist files.
        numerator = self.uni_dist[possible_token] + pseudo_count
        denominator = 0
        for alternative_token in self.word_casing_lookup[possible_token.lower()]:
            denominator += self.uni_dist[alternative_token] + pseudo_count
        unigram_score = numerator / denominator

        # Backward bigram score: P(casing | previous token).
        bigram_backward_score = 1
        if prev_token is not None:
            key = prev_token + "_" + possible_token
            numerator = self.backward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = prev_token + "_" + alternative_token
                denominator += self.backward_bi_dist[key] + pseudo_count
            bigram_backward_score = numerator / denominator

        # Forward bigram score: P(casing | next token).
        bigram_forward_score = 1
        if next_token is not None:
            next_token = next_token.lower()  # Ensure it is lower case
            key = possible_token + "_" + next_token
            numerator = self.forward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = alternative_token + "_" + next_token
                denominator += self.forward_bi_dist[key] + pseudo_count
            bigram_forward_score = numerator / denominator

        # Trigram score: P(casing | previous and next tokens).
        trigram_score = 1
        if prev_token is not None and next_token is not None:
            next_token = next_token.lower()  # Ensure it is lower case
            trigram_key = prev_token + "_" + possible_token + "_" + next_token
            numerator = self.trigram_dist[trigram_key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                trigram_key = prev_token + "_" + alternative_token + "_" + next_token
                denominator += self.trigram_dist[trigram_key] + pseudo_count
            trigram_score = numerator / denominator

        # Combine in log space; every score is strictly positive thanks to
        # the pseudo counts, so math.log never fails.
        return (
            math.log(unigram_score)
            + math.log(bigram_backward_score)
            + math.log(bigram_forward_score)
            + math.log(trigram_score)
        )

    @staticmethod
    def first_token_case(raw):
        """Upper-case only the first character of *raw*.

        Bug fix: the previous implementation used ``str.capitalize``, which
        lower-cases the rest of the token and thereby destroys mixed-case
        forms the model itself selected (e.g. "iPhone" -> "Iphone",
        "NASA" -> "Nasa"). Also safe on the empty string.
        """
        return raw[:1].upper() + raw[1:]

    def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
        """True-case a raw *sentence* string.

        Tokenizes with NLTK, true-cases the tokens and detokenizes back into
        a single string.

        :param sentence: input text (any casing).
        :param out_of_vocabulary_token_option: casing policy for unknown
            tokens: "title", "capitalize", "lower", or anything else to keep
            them as-is (lower-cased).
        :return: the true-cased sentence.
        """
        tokens = word_tokenize(sentence)
        tokens_true_case = self.get_true_case_from_tokens(
            tokens, out_of_vocabulary_token_option)
        return self.detknzr.detokenize(tokens_true_case)

    def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="title"):
        """Return the true-cased version of an already tokenized sentence.

        :param tokens: list of tokens (casing is ignored; they are lowered).
        :param out_of_vocabulary_token_option: casing policy for tokens
            absent from the casing lookup (see :meth:`get_true_case`).
        :return: list of true-cased tokens; the first token additionally gets
            its first character upper-cased.
        """
        tokens_true_case = []

        if not tokens:
            return tokens_true_case

        for token_idx, token in enumerate(tokens):
            # Punctuation and pure numbers carry no casing information.
            if token in string.punctuation or token.isdigit():
                tokens_true_case.append(token)
                continue

            token = token.lower()
            if token not in self.word_casing_lookup:  # Token out of vocabulary
                if out_of_vocabulary_token_option == "title":
                    tokens_true_case.append(token.title())
                elif out_of_vocabulary_token_option == "capitalize":
                    tokens_true_case.append(token.capitalize())
                elif out_of_vocabulary_token_option == "lower":
                    tokens_true_case.append(token.lower())
                else:
                    tokens_true_case.append(token)
                continue

            casings = self.word_casing_lookup[token]
            if len(casings) == 1:
                # Only one casing ever observed: use it directly.
                tokens_true_case.append(next(iter(casings)))
                continue

            # Ambiguous casing: score every observed casing in context and
            # keep the best one. The previous token is the already
            # true-cased output; the next token is still raw input.
            prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None
            next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None

            best_token = None
            highest_score = float("-inf")
            for possible_token in casings:
                score = self.get_score(prev_token, possible_token, next_token)
                if score > highest_score:
                    best_token = possible_token
                    highest_score = score

            tokens_true_case.append(best_token)

        # Sentences start with an upper-case letter.
        tokens_true_case[0] = self.first_token_case(tokens_true_case[0])
        return tokens_true_case
english.dist ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8a93297fa1e415e3c7dcee0d50068d71cf273a24b14f31a28e7d1415c84462a
3
+ size 57318894
spanish.dist ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46f93c82fcfeac191cd2f8e21ea4e582d8ab2d3e6d54ad822607d671bcdc3657
3
+ size 215488323