# spanish-truecasing / TrueCaser.py
import math
import pickle
import re
import string

from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
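

# ---------------------------------------------------------------------------
# Note: the sketch below is not part of the original file. It shows one way
# the pickled distributions loaded in `TrueCaser.__init__` could be built,
# inferred from the key formats used in `get_score`. The function name
# `build_distributions` and its arguments are illustrative. Counter is used
# so that unseen keys count as zero, which the pseudo-count smoothing in
# `get_score` relies on.
def build_distributions(cased_sentences, out_path):
    """Build a distributions pickle from an iterable of token lists that
    keep their original (true) casing."""
    from collections import Counter, defaultdict

    uni_dist = Counter()
    backward_bi_dist = Counter()  # keyed "prevCased_currentCased"
    forward_bi_dist = Counter()   # keyed "currentCased_nextlower"
    trigram_dist = Counter()      # keyed "prevCased_currentCased_nextlower"
    word_casing_lookup = defaultdict(set)

    for tokens in cased_sentences:
        for i, tok in enumerate(tokens):
            uni_dist[tok] += 1
            word_casing_lookup[tok.lower()].add(tok)
            if i > 0:
                backward_bi_dist[tokens[i - 1] + "_" + tok] += 1
            if i < len(tokens) - 1:
                forward_bi_dist[tok + "_" + tokens[i + 1].lower()] += 1
            if 0 < i < len(tokens) - 1:
                trigram_dist[tokens[i - 1] + "_" + tok + "_"
                             + tokens[i + 1].lower()] += 1

    with open(out_path, "wb") as out_file:
        pickle.dump({
            "uni_dist": uni_dist,
            "backward_bi_dist": backward_bi_dist,
            "forward_bi_dist": forward_bi_dist,
            "trigram_dist": trigram_dist,
            "word_casing_lookup": dict(word_casing_lookup),
        }, out_file)
# ---------------------------------------------------------------------------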


class TrueCaser(object):

    def __init__(self, dist_file_path):
        """Load the pickled casing distributions used for scoring."""
        with open(dist_file_path, "rb") as distributions_file:
            pickle_dict = pickle.load(distributions_file)

        # Counter-like frequency distributions over cased unigrams, bigrams
        # and trigrams, plus a lookup from each lowercased word to the set
        # of casings observed for it.
        self.uni_dist = pickle_dict["uni_dist"]
        self.backward_bi_dist = pickle_dict["backward_bi_dist"]
        self.forward_bi_dist = pickle_dict["forward_bi_dist"]
        self.trigram_dist = pickle_dict["trigram_dist"]
        self.word_casing_lookup = pickle_dict["word_casing_lookup"]

        self.detknzr = TreebankWordDetokenizer()

    def get_score(self, prev_token, possible_token, next_token):
        """Score `possible_token` as the casing of the current word, given
        the (already cased) previous token and the (not yet cased) next
        token. Combines unigram, backward-bigram, forward-bigram and
        trigram evidence with additive pseudo-count smoothing and returns
        the sum of the log scores."""
        pseudo_count = 5.0

        # Unigram score: frequency of this casing relative to all casings
        # observed for the same lowercased word.
        numerator = self.uni_dist[possible_token] + pseudo_count
        denominator = 0
        for alternative_token in self.word_casing_lookup[possible_token.lower()]:
            denominator += self.uni_dist[alternative_token] + pseudo_count
        unigram_score = numerator / denominator

        # Backward bigram score: casing given the previous token.
        bigram_backward_score = 1
        if prev_token is not None:
            key = prev_token + "_" + possible_token
            numerator = self.backward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = prev_token + "_" + alternative_token
                denominator += self.backward_bi_dist[key] + pseudo_count
            bigram_backward_score = numerator / denominator

        # Forward bigram score: casing given the next token, which is
        # lowercased because its own casing has not been decided yet.
        bigram_forward_score = 1
        if next_token is not None:
            next_token = next_token.lower()
            key = possible_token + "_" + next_token
            numerator = self.forward_bi_dist[key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                key = alternative_token + "_" + next_token
                denominator += self.forward_bi_dist[key] + pseudo_count
            bigram_forward_score = numerator / denominator

        # Trigram score: casing given both the previous and the next token.
        trigram_score = 1
        if prev_token is not None and next_token is not None:
            next_token = next_token.lower()
            trigram_key = prev_token + "_" + possible_token + "_" + next_token
            numerator = self.trigram_dist[trigram_key] + pseudo_count
            denominator = 0
            for alternative_token in self.word_casing_lookup[possible_token.lower()]:
                trigram_key = prev_token + "_" + alternative_token + "_" + next_token
                denominator += self.trigram_dist[trigram_key] + pseudo_count
            trigram_score = numerator / denominator

        result = (
            math.log(unigram_score)
            + math.log(bigram_backward_score)
            + math.log(bigram_forward_score)
            + math.log(trigram_score)
        )
        return result

    @staticmethod
    def first_token_case(raw):
        # The first token of a sentence is always capitalized.
        return raw.capitalize()

    @staticmethod
    def upper_replacement(match):
        # Turn a detokenized " . x" sequence into ". X": attach the period
        # to the previous word and capitalize the sentence that follows.
        return '. ' + match.group(0)[-1].upper()

    def get_true_case(self, sentence, out_of_vocabulary_token_option="title"):
        """Truecase a raw sentence: tokenize, recase each token, detokenize."""
        tokens = word_tokenize(sentence)
        tokens_true_case = self.get_true_case_from_tokens(tokens, out_of_vocabulary_token_option)
        text = self.detknzr.detokenize(tokens_true_case)
        # Repair sentence boundaries left as " . x" by the detokenizer.
        text = re.sub(r' \. .', self.upper_replacement, text)
        return text

    def get_true_case_from_tokens(self, tokens, out_of_vocabulary_token_option="title"):
        """Return the tokens with their predicted casing.

        `out_of_vocabulary_token_option` controls unknown tokens: "title"
        (default), "capitalize", "lower", or anything else to leave them
        unchanged."""
        tokens_true_case = []
        if not len(tokens):
            return tokens_true_case

        for token_idx, token in enumerate(tokens):
            # Punctuation and numbers are passed through unchanged.
            if token in string.punctuation or token.isdigit():
                tokens_true_case.append(token)
                continue

            token = token.lower()

            if token not in self.word_casing_lookup:  # Token out of vocabulary
                if out_of_vocabulary_token_option == "title":
                    tokens_true_case.append(token.title())
                elif out_of_vocabulary_token_option == "capitalize":
                    tokens_true_case.append(token.capitalize())
                elif out_of_vocabulary_token_option == "lower":
                    tokens_true_case.append(token.lower())
                else:
                    tokens_true_case.append(token)
                continue

            # Only one casing was ever observed for this word: use it.
            if len(self.word_casing_lookup[token]) == 1:
                tokens_true_case.append(list(self.word_casing_lookup[token])[0])
                continue

            # Several casings observed: score each candidate in context and
            # keep the best one.
            prev_token = tokens_true_case[token_idx - 1] if token_idx > 0 else None
            next_token = tokens[token_idx + 1] if token_idx < len(tokens) - 1 else None

            best_token = None
            highest_score = float("-inf")
            for possible_token in self.word_casing_lookup[token]:
                score = self.get_score(prev_token, possible_token, next_token)
                if score > highest_score:
                    best_token = possible_token
                    highest_score = score

            tokens_true_case.append(best_token)

        tokens_true_case[0] = self.first_token_case(tokens_true_case[0])
        return tokens_true_case
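

# ---------------------------------------------------------------------------
# Usage sketch, not part of the original file. The distributions file name
# and the sample sentence are illustrative; NLTK's "punkt" models must be
# available for `word_tokenize` to run.
if __name__ == "__main__":
    caser = TrueCaser("spanish.dist")
    print(caser.get_true_case("el presidente de méxico visitó madrid en enero."))
    # The intent is that words with a known casing recover it (e.g. "México",
    # "Madrid") and that the first word of the sentence is capitalized.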