Commit fb27ada
Parent(s): c179430

Delete model.py
model.py (DELETED)
@@ -1,163 +0,0 @@
import os
import re
import unicodedata
from typing import Dict

import kenlm
import sentencepiece


class SentencePiece:
    def __init__(
        self,
        model: str,
    ):
        super().__init__()
        self.sp = sentencepiece.SentencePieceProcessor()
        self.sp.load(str(model))

    def do(self, text: str) -> str:
        # Split into SentencePiece pieces and rejoin with spaces, giving
        # the whitespace-separated tokens KenLM expects.
        tokenized = self.sp.encode_as_pieces(text)
        return " ".join(tokenized)


class KenlmModel:
    digit_re: re.Pattern = re.compile(r"\d")
    # cc_net-style mapping of full-width/typographic punctuation to ASCII.
    unicode_punct: Dict[str, str] = {
        "，": ",",
        "。": ".",
        "、": ",",
        "„": '"',
        "”": '"',
        "“": '"',
        "«": '"',
        "»": '"',
        "１": '"',
        "」": '"',
        "「": '"',
        "《": '"',
        "》": '"',
        "´": "'",
        "∶": ":",
        "：": ":",
        "？": "?",
        "！": "!",
        "（": "(",
        "）": ")",
        "；": ";",
        "–": "-",
        "—": " - ",
        "．": ". ",
        "～": "~",
        "’": "'",
        "…": "...",
        "━": "-",
        "〈": "<",
        "〉": ">",
        "【": "[",
        "】": "]",
        "％": "%",
        "►": "-",
    }
    unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
    non_printing_chars_re = re.compile(
        f"[{''.join(map(chr, list(range(0, 32)) + list(range(127, 160))))}]"
    )
    kenlm_model_dir = None
    sentence_piece_model_dir = None

    def __init__(
        self,
        model_dataset: str,
        language: str,
        lower_case: bool = False,
        remove_accents: bool = False,
        normalize_numbers: bool = True,
        punctuation: int = 1,
    ):
        self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
        self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
        self.accent = remove_accents
        self.case = lower_case
        self.numbers = normalize_numbers
        self.punct = punctuation

    @classmethod
    def from_pretrained(
        cls,
        model_dataset: str,
        language: str,
    ):
        return cls(
            model_dataset,
            language,
            False,
            False,
            True,
            1,
        )

    def pp(self, log_score, length):
        # Perplexity from a total log10 probability and a token count:
        # pp = 10 ** (-log_score / length)
        return 10.0 ** (-log_score / length)

    def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
        if normalize_cc_net:
            doc = self.normalize(
                doc,
                accent=self.accent,
                case=self.case,
                numbers=self.numbers,
                punct=self.punct,
            )
        # Tokenize (after normalizing): see https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for the full pipeline
        doc = self.tokenizer.do(doc)
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            log_score = self.model.score(line)
            length = len(line.split()) + 1  # +1 for the end-of-sentence token
            doc_log_score += log_score
            doc_length += length
        return round(self.pp(doc_log_score, doc_length), 1)

    def normalize(
        self,
        line: str,
        accent: bool = True,
        case: bool = True,
        numbers: bool = True,
        punct: int = 1,
    ) -> str:
        line = line.strip()
        if not line:
            return line
        if case:
            line = line.lower()
        if accent:
            line = self.strip_accents(line)
        if numbers:
            line = self.digit_re.sub("0", line)
        if punct == 1:
            line = self.replace_unicode_punct(line)
        elif punct == 2:
            line = self.remove_unicode_punct(line)
        line = self.remove_non_printing_char(line)
        return line

    def strip_accents(self, line: str) -> str:
        """Strips accents from a piece of text."""
        nfd = unicodedata.normalize("NFD", line)
        output = [c for c in nfd if unicodedata.category(c) != "Mn"]
        if len(output) == len(line):
            return line
        return "".join(output)

    def replace_unicode_punct(self, text: str) -> str:
        return "".join(self.unicode_punct.get(c, c) for c in text)

    def remove_unicode_punct(self, text: str) -> str:
        """More aggressive version of replace_unicode_punct but also faster."""
        return self.unicode_punct_re.sub("", text)

    def remove_non_printing_char(self, text: str) -> str:
        return self.non_printing_chars_re.sub("", text)