File size: 1,355 Bytes
1f30dbc
 
 
 
 
 
 
 
 
9ec7b19
 
 
 
ab7449f
 
9ec7b19
1f30dbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import urllib.request

import kenlm


class KenlmModel:
    """Thin wrapper around a KenLM model that scores text by perplexity.

    The binary model is read from ``{language}.arpa.bin`` in the current
    working directory; ``download_kenlm_model`` is called first so the file is
    fetched when it is not already present.
    """

    def __init__(self, language: str) -> None:
        """Fetch (if needed) and load the KenLM model for *language*.

        Args:
            language: Prefix of the model files, used to build the
                ``{language}.arpa.bin`` / ``{language}.sp.model`` file names.

        Raises:
            OSError: If KenLM cannot load the binary model. The local model
                files are deleted first so that a retry downloads them again;
                the original loader error is attached as ``__cause__``.
        """
        download_kenlm_model(language)
        bin_path = f"{language}.arpa.bin"
        try:
            self.model = kenlm.Model(bin_path)
        except OSError as err:
            # Best-effort cleanup: a file that is already gone must not mask
            # the error we actually want to report to the caller.
            for stale_path in (bin_path, f"{language}.sp.model"):
                try:
                    os.remove(stale_path)
                except FileNotFoundError:
                    pass
            raise OSError("File was corrupt and should have been removed. Please, retry.") from err

    @classmethod
    def from_pretrained(cls, language: str) -> "KenlmModel":
        """Alternate constructor; equivalent to ``KenlmModel(language)``."""
        return cls(language)

    def get_perplexity(self, doc: str) -> float:
        """Return the perplexity of *doc* under the loaded model.

        The document is split on ``"\\n"`` and every line is scored
        separately with ``self.model.score``. Scores are treated as base-10
        logarithms: the result is ``10 ** (-total_score / total_length)``.

        Args:
            doc: Text to score; may span several lines.

        Returns:
            The perplexity as a float.
        """
        doc_log_score, doc_length = 0, 0
        # Explicit accumulation (rather than sum()) keeps the floating-point
        # addition order, and therefore the exact result, stable.
        for line in doc.split("\n"):
            doc_log_score += self.model.score(line)
            # Whitespace token count plus one; the extra unit presumably
            # stands for the end-of-sentence token -- confirm against KenLM.
            doc_length += len(line.split()) + 1
        # doc.split("\n") always yields at least one line, so doc_length >= 1.
        return 10.0 ** (-doc_log_score / doc_length)


def download_kenlm_model(language: str, root_url: str = "http://dl.fbaipublicfiles.com/cc_net/lm") -> None:
    """Ensure the KenLM model files for *language* exist in the working directory.

    Two files are handled, ``{language}.arpa.bin`` and ``{language}.sp.model``.
    Each one is downloaded from ``{root_url}/{file name}`` only when no regular
    file of that name is already present locally.

    Args:
        language: Prefix used to build both file names.
        root_url: Base URL the files are fetched from. A trailing slash is
            tolerated.

    Raises:
        OSError: Propagated from ``urllib`` / ``os`` when a download or the
            final rename fails (``urllib.error.URLError`` is an ``OSError``).
    """
    # NOTE(review): the default endpoint is plain HTTP and the payload is not
    # integrity-checked; consider HTTPS if the host supports it.
    base_url = root_url.rstrip("/")
    for file_name in (f"{language}.arpa.bin", f"{language}.sp.model"):
        if not os.path.isfile(file_name):
            _download_atomically(f"{base_url}/{file_name}", file_name)


def _download_atomically(url: str, destination: str) -> None:
    """Download *url* to *destination* without ever exposing a partial file."""
    partial_path = f"{destination}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Promote only a completed download; an interrupted transfer must not
        # leave a truncated file that a later isfile() check would accept.
        os.replace(partial_path, destination)
    finally:
        # After a successful replace the partial file no longer exists; on
        # failure this removes whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)