import os
import re
import unicodedata
from typing import Dict
from requests.exceptions import HTTPError

import kenlm
import sentencepiece
from huggingface_hub import cached_download, hf_hub_url

KENLM_MODEL_REPO = "edugp/kenlm"


class SentencePiece:
    def __init__(
        self,
        model: str,
    ):
        super().__init__()
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        """Tokenize text into SentencePiece pieces joined by spaces."""
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)


class KenlmModel:
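    """KenLM language model wrapper for scoring text perplexity.

    Downloads a KenLM model and its matching SentencePiece tokenizer for the
    requested dataset/language pair from the Hugging Face Hub repo
    "edugp/kenlm", and applies cc_net-style text normalization before scoring.
    """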
    digit_re: re.Pattern = re.compile(r"\d")
    unicode_punct: Dict[str, str] = {
        "，": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "１": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        "：": ":",
        "？": "?",
        "！": "!",
        "（": "(",
        "）": ")",
        "；": ";",
        "–": "-",
        "—": " - ",
        "．": ". ",
        "～": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "％": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        self.download_kenlm_model(model_dataset, language)
        try:
            self.model = kenlm.Model(self.kenlm_model_dir)
            self.tokenizer = SentencePiece(self.sentence_piece_model_dir)
        except OSError:
            os.remove(self.kenlm_model_dir)
            if os.path.exists(self.sentence_piece_model_dir):
                os.remove(self.sentence_piece_model_dir)
            raise OSError(
                "The downloaded model file was corrupt and has been removed. Please retry."
            )
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
        lower_case: bool,
        remove_accents: bool,
        normalize_numbers: bool,
        punctuation: int,
    ):
        return cls(
            model_dataset,
            language,
            lower_case,
            remove_accents,
            normalize_numbers,
            punctuation,
        )

    def pp(self, log_score, length):
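        """Convert a summed KenLM log10 score over `length` tokens into a perplexity."""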
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
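        """Score a document and return its perplexity (lower means closer to the training corpus).

        If `normalize_cc_net` is True, the document is normalized the same way
        cc_net does before SentencePiece tokenization and KenLM scoring.
        """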
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
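        """Apply cc_net-style normalization: optional lower-casing, accent
        stripping and digit squashing, punctuation handling (punct=1 replaces
        unicode punctuation with ASCII equivalents, punct=2 removes it), and
        removal of non-printing characters."""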
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strips accents from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        if len(output) == len(line):
            return line
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        return self.non_printing_chars_re.sub("", text)

    def download_kenlm_model(self, model_dataset: str, language: str):
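        """Fetch and cache the KenLM binary and SentencePiece model for the
        dataset/language pair, preferring the trie-format binary and falling
        back to the plain .arpa.bin if the trie file is not in the repo."""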
        try:
            kenlm_model_url = hf_hub_url(
                KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.arpa.trie.bin"
            )
            self.kenlm_model_dir = cached_download(kenlm_model_url)
        except HTTPError:
            kenlm_model_url = hf_hub_url(
                KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.arpa.bin"
            )
            self.kenlm_model_dir = cached_download(kenlm_model_url)
        sentence_piece_model_url = hf_hub_url(
            KENLM_MODEL_REPO, filename=f"{model_dataset}/{language}.sp.model"
        )
        self.sentence_piece_model_dir = cached_download(sentence_piece_model_url)
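

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The dataset and
# language values below ("wikipedia", "es") are assumptions about what the
# edugp/kenlm repo hosts; substitute the dataset/language pair you need.
if __name__ == "__main__":
    model = KenlmModel.from_pretrained(
        model_dataset="wikipedia",   # assumed dataset folder in edugp/kenlm
        language="es",               # assumed language code
        lower_case=True,
        remove_accents=True,
        normalize_numbers=True,
        punctuation=1,
    )
    # Lower perplexity means the text looks more like the model's training corpus.
    print(model.get_perplexity("He visto a un niño jugando en el parque."))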