Spaces:
Runtime error
Runtime error
import os | |
import urllib.request | |
import kenlm | |
class KenlmModel:
    """Wrapper around a KenLM language model that scores text by perplexity.

    The model files are looked up (and downloaded when absent) in the current
    working directory under the names ``{language}.arpa.bin`` and
    ``{language}.sp.model``.
    """

    def __init__(self, language):
        """Fetch the model files for *language* if needed and load the model.

        Args:
            language: Prefix of the model files, e.g. ``"en"`` for
                ``en.arpa.bin``.

        Raises:
            OSError: If the binary model cannot be loaded. The local model
                files are deleted first so that a later attempt downloads
                them again.
        """
        download_kenlm_model(language)
        try:
            self.model = kenlm.Model(f"{language}.arpa.bin")
        except OSError as err:
            # Drop whatever is on disk so the next attempt re-downloads it.
            # Both removals are guarded: an unguarded remove of a missing
            # file would raise FileNotFoundError and hide the message below.
            for stale_file in (f"{language}.arpa.bin", f"{language}.sp.model"):
                if os.path.exists(stale_file):
                    os.remove(stale_file)
            raise OSError(
                "File was corrupt and should have been removed. Please, retry."
            ) from err

    @classmethod
    def from_pretrained(cls, language: str) -> "KenlmModel":
        """Alternate constructor: build a model for *language*.

        Declared as a classmethod so it can be called on the class itself,
        ``KenlmModel.from_pretrained("en")``, and so subclasses get an
        instance of their own type.
        """
        return cls(language)

    def get_perplexity(self, doc: str) -> float:
        """Return the perplexity of *doc* under the loaded model.

        The document is split on newlines and every line (empty ones
        included) is scored separately with ``self.model.score``. The scores
        are treated as base-10 logs, matching the ``10.0 **`` conversion
        below.

        Args:
            doc: Text to score; lines are separated by ``"\\n"``.

        Returns:
            ``10 ** (-total_log_score / total_length)``.
        """
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            doc_log_score += self.model.score(line)
            # One extra unit per line on top of the whitespace-separated
            # tokens (presumably the end-of-sentence token -- confirm against
            # the kenlm scoring defaults). It also keeps doc_length >= 1, so
            # the division below cannot hit zero.
            doc_length += len(line.split()) + 1
        return 10.0 ** (-doc_log_score / doc_length)
def download_kenlm_model(language: str) -> None:
    """Download the KenLM files for *language* into the working directory.

    Two files are fetched from the cc_net bucket, each only when no regular
    file of that name exists yet: ``{language}.arpa.bin`` and
    ``{language}.sp.model``.

    Each download goes to a temporary ``*.part`` file and is moved into place
    only after the transfer finished. An interrupted transfer therefore never
    leaves a truncated file under the final name, where the existence check
    would accept it on every later call.

    Args:
        language: Prefix of the model files, e.g. ``"en"``.

    Raises:
        urllib.error.URLError: If a file cannot be retrieved (this includes
            ``ContentTooShortError`` for a truncated transfer).
        OSError: If the downloaded file cannot be written or moved.
    """
    # NOTE(review): the bucket is addressed over plain HTTP; consider HTTPS
    # if the host serves it.
    root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
    # Same order as before: binary model first, then the sentencepiece model.
    for file_name in (f"{language}.arpa.bin", f"{language}.sp.model"):
        if os.path.isfile(file_name):
            continue
        partial_name = f"{file_name}.part"
        try:
            urllib.request.urlretrieve(f"{root_url}/{file_name}", partial_name)
            # Atomic rename: the final name appears only once complete.
            os.replace(partial_name, file_name)
        finally:
            # After a successful replace there is nothing left to remove;
            # after a failure this clears the incomplete download.
            if os.path.exists(partial_name):
                os.remove(partial_name)