philschmid committed
Commit fb27ada (1 parent: c179430)

Delete model.py

Files changed (1)
  1. model.py +0 -163
model.py DELETED
@@ -1,163 +0,0 @@
- import os
- import re
- import unicodedata
- from typing import Dict
-
- import kenlm
- import sentencepiece
- from huggingface_hub import cached_download, hf_hub_url
-
-
- class SentencePiece:
-     def __init__(
-         self,
-         model: str,
-     ):
-         super().__init__()
-         self.sp = sentencepiece.SentencePieceProcessor()
-         self.sp.load(str(model))
-
-     def do(self, text: str) -> str:
-         tokenized = self.sp.encode_as_pieces(text)
-         return " ".join(tokenized)
-
-
- class KenlmModel:
-     digit_re: re.Pattern = re.compile(r"\d")
-     unicode_punct: Dict[str, str] = {
-         "，": ",",
-         "。": ".",
-         "、": ",",
-         "„": '"',
-         "”": '"',
-         "“": '"',
-         "«": '"',
-         "»": '"',
-         "１": '"',
-         "」": '"',
-         "「": '"',
-         "《": '"',
-         "》": '"',
-         "´": "'",
-         "∶": ":",
-         "：": ":",
-         "？": "?",
-         "！": "!",
-         "（": "(",
-         "）": ")",
-         "；": ";",
-         "–": "-",
-         "—": " - ",
-         "．": ". ",
-         "～": "~",
-         "’": "'",
-         "…": "...",
-         "━": "-",
-         "〈": "<",
-         "〉": ">",
-         "【": "[",
-         "】": "]",
-         "％": "%",
-         "►": "-",
-     }
-     unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
-     non_printing_chars_re = re.compile(
-         f"[{''.join(map(chr, list(range(0, 32)) + list(range(127, 160))))}]"
-     )
-     kenlm_model_dir = None
-     sentence_piece_model_dir = None
-
-     def __init__(
-         self,
-         model_dataset: str,
-         language: str,
-         lower_case: bool = False,
-         remove_accents: bool = False,
-         normalize_numbers: bool = True,
-         punctuation: int = 1,
-     ):
-         self.model = kenlm.Model(os.path.join(model_dataset, f"{language}.arpa.bin"))
-         self.tokenizer = SentencePiece(os.path.join(model_dataset, f"{language}.sp.model"))
-         self.accent = remove_accents
-         self.case = lower_case
-         self.numbers = normalize_numbers
-         self.punct = punctuation
-
-     @classmethod
-     def from_pretrained(
-         cls,
-         model_dataset: str,
-         language: str,
-     ):
-         return cls(
-             model_dataset,
-             language,
-             False,
-             False,
-             True,
-             1,
-         )
-
-     def pp(self, log_score, length):
-         # kenlm scores are log10 probabilities, so this is the standard
-         # perplexity: 10 ** (-log10 P(doc) / length)
-         return 10.0 ** (-log_score / length)
-
-     def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
-         if normalize_cc_net:
-             doc = self.normalize(
-                 doc,
-                 accent=self.accent,
-                 case=self.case,
-                 numbers=self.numbers,
-                 punct=self.punct,
-             )
-         # Tokenize (after normalizing): see https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for the full pipeline
-         doc = self.tokenizer.do(doc)
-         doc_log_score, doc_length = 0, 0
-         for line in doc.split("\n"):
-             log_score = self.model.score(line)
-             length = len(line.split()) + 1
-             doc_log_score += log_score
-             doc_length += length
-         return round(self.pp(doc_log_score, doc_length), 1)
-
-     def normalize(
-         self,
-         line: str,
-         accent: bool = True,
-         case: bool = True,
-         numbers: bool = True,
-         punct: int = 1,
-     ) -> str:
-         line = line.strip()
-         if not line:
-             return line
-         if case:
-             line = line.lower()
-         if accent:
-             line = self.strip_accents(line)
-         if numbers:
-             line = self.digit_re.sub("0", line)
-         if punct == 1:
-             line = self.replace_unicode_punct(line)
-         elif punct == 2:
-             line = self.remove_unicode_punct(line)
-         line = self.remove_non_printing_char(line)
-         return line
-
-     def strip_accents(self, line: str) -> str:
-         """Strips accents from a piece of text."""
-         nfd = unicodedata.normalize("NFD", line)
-         output = [c for c in nfd if unicodedata.category(c) != "Mn"]
-         if len(output) == len(line):
-             return line
-         return "".join(output)
-
-     def replace_unicode_punct(self, text: str) -> str:
-         return "".join(self.unicode_punct.get(c, c) for c in text)
-
-     def remove_unicode_punct(self, text: str) -> str:
-         """More aggressive version of replace_unicode_punct but also faster."""
-         return self.unicode_punct_re.sub("", text)
-
-     def remove_non_printing_char(self, text: str) -> str:
-         return self.non_printing_chars_re.sub("", text)
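
For reference, a minimal sketch of how the deleted KenlmModel was used before this commit. The directory name "wikipedia" and language code "en" below are illustrative placeholders, not paths confirmed by this commit; the constructor only assumes a local folder containing <language>.arpa.bin and <language>.sp.model.

# Hypothetical usage of the deleted class; "wikipedia"/"en" are placeholder
# values for a local model directory and a language code, not files in this repo.
model = KenlmModel.from_pretrained("wikipedia", "en")
# Lower perplexity means the text looks more like the corpus the kenlm model was trained on.
print(model.get_perplexity("He is an architect and his architecture is beautiful."))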