boris committed
Commit 1212a74
Parent: 1c83da9

feat: add text utilities

Files changed (1)
dalle_mini/text.py +268 -0
dalle_mini/text.py ADDED
@@ -0,0 +1,268 @@
+ """
+ Utilities for processing text.
+ """
+
+ import html
+ import math
+ import random
+ import re
+ from pathlib import Path
+
+ import requests
+ from unidecode import unidecode
+
+
+ WIKI_STATS_URL = "https://github.com/borisdayma/wikipedia-word-frequency/raw/feat-update/results/enwiki-20210820-words-frequency.txt"
+ WIKI_STATS_LOCAL = Path(WIKI_STATS_URL).parts[-1]
+
+ # based on wiki word occurrence
+ person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
+ temp_token = "xtokx"  # avoid repeating chars
+
+
+ def get_wiki_file():
+     if not Path(WIKI_STATS_LOCAL).exists():
+         r = requests.get(WIKI_STATS_URL, stream=True)
+         r.raise_for_status()  # fail early instead of caching an error page
+         with open(WIKI_STATS_LOCAL, "wb") as fd:
+             for chunk in r.iter_content(chunk_size=128):
+                 fd.write(chunk)
+     return WIKI_STATS_LOCAL
+
+
+ class HashtagProcessor:
+     # Adapted from the wordninja library
+     # We use our wikipedia word count + a good heuristic to make it work
+     def __init__(self):
+         self._word_cost = (
+             l.split()[0] for l in Path(get_wiki_file()).read_text().splitlines()
+         )
+         self._word_cost = {
+             str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
+         }
+         self._max_word = max(len(x) for x in self._word_cost.keys())
+         self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")
+
+     def __call__(self, s):
+         """Uses dynamic programming to infer the location of spaces in a string without spaces."""
+         l = [self._split(x) for x in self._SPLIT_RE.split(s)]
+         return " ".join([item for sublist in l for item in sublist])
+
+     def _split(self, s):
+         # Find the best match for the first i characters, assuming costs have
+         # been built for the first i - 1 characters.
+         # Returns a pair (match_cost, match_length).
+         def best_match(i):
+             candidates = enumerate(reversed(cost[max(0, i - self._max_word) : i]))
+             return min(
+                 (c + self._word_cost.get(s[i - k - 1 : i].lower(), 9e999), k + 1)
+                 for k, c in candidates
+             )
+
+         # Build the cost array
+         cost = [0]
+         for i in range(1, len(s) + 1):
+             c, _ = best_match(i)
+             cost.append(c)
+
+         # Backtrack to recover the minimal-cost segmentation
+         out = []
+         i = len(s)
+         while i > 0:
+             c, k = best_match(i)
+             assert c == cost[i]
+             newToken = True
+             if not s[i - k : i] == "'":  # ignore a lone apostrophe
+                 if len(out) > 0:
+                     # re-attach split 's and split digits
+                     if out[-1] == "'s" or (
+                         s[i - 1].isdigit() and out[-1][0].isdigit()
+                     ):  # digit followed by digit
+                         out[-1] = (
+                             s[i - k : i] + out[-1]
+                         )  # combine current token with previous token
+                         newToken = False
+
+             if newToken:
+                 out.append(s[i - k : i])
+
+             i -= k
+
+         return reversed(out)
+
+
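A quick usage sketch (commentary, not part of the commit): the segmenter assigns each word a cost from its Wikipedia frequency rank and backtracks through the cheapest segmentation. Exact splits depend on the downloaded frequency file, so the commented output is indicative only.

    from dalle_mini.text import HashtagProcessor

    processor = HashtagProcessor()  # downloads the word-frequency file on first use
    # a no-space string is cut along the lowest-cost word boundaries
    print(processor("artificialintelligence"))  # indicative output: "artificial intelligence"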
+ def replace_person_token(t):
+     "Used for CC12M"
+     t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
+     while "<person>" in t:
+         t = t.replace(
+             "<person>", f" {random.choices(*tuple(zip(*person_token)))[0]} ", 1
+         )
+     return t
+
+
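For illustration (not part of the commit), how the CC12M person tokens behave; the single-token replacement is a frequency-weighted random draw, so outputs vary run to run.

    from dalle_mini.text import replace_person_token

    # runs of <person> tokens (optionally joined by commas/"and") collapse to " people "
    print(replace_person_token("<person> and <person> at the beach"))
    # -> roughly " people  at the beach" (extra spaces are removed later in the pipeline)

    # a lone token becomes "a person", "someone", or "somebody", weighted by wiki counts
    print(replace_person_token("<person> riding a bike"))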
+ def fix_html(t):
+     "Adapted from fastai"
+     t = (
+         t.replace("#39;", "'")
+         .replace("&amp;", "&")
+         .replace("amp;", "&")
+         .replace("#146;", "'")
+         .replace("nbsp;", " ")
+         .replace("#36;", "$")
+         .replace("\\n", "\n")
+         .replace("quot;", "'")
+         .replace("<br />", "\n")
+         .replace('\\"', '"')
+         .replace("<unk>", " ")
+         .replace(" @.@ ", ".")
+         .replace(" @-@ ", "-")
+     )
+     return html.unescape(t)
+
+
+ def replace_punctuation_with_commas(t):
+     return re.sub(r"([()[\].,|:;?!=+~\-])", ",", t)
+
+
+ def simplify_quotes(t):
+     return re.sub(r"""['"`]""", ' " ', t)
+
+
+ def merge_quotes(t):
+     return re.sub(r'(\s*"+\s*)+', ' " ', t)
+
+
+ def remove_comma_numbers(t):
+     def _f(t):
+         return re.sub(r"(\d),(\d{3})", r"\1\2", t)
+
+     # apply twice for numbers with several separators, e.g. "1,234,567"
+     return _f(_f(t))
+
+
+ def pre_process_dot_numbers(t):
+     return re.sub(r"(\d)\.(\d)", fr"\1{temp_token}dot{temp_token}\2", t)
+
+
+ def post_process_dot_numbers(t):
+     return re.sub(f"{temp_token}dot{temp_token}", ".", t)
+
+
+ def pre_process_quotes(t):
+     # allows quotes only for 's, 't, 'd, 'm, 'll, 're, 've
+     return re.sub(
+         r"'(?=([stdm]|(ll)|(re)|(ve))\b)", fr"{temp_token}quote{temp_token}", t
+     )
+
+
+ def post_process_quotes(t):
+     return re.sub(f"{temp_token}quote{temp_token}", "'", t)
+
+
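The pre_/post_ pairs above exist to protect decimal points and contraction apostrophes from replace_punctuation_with_commas by swapping them for the xtokx sentinel and restoring them afterwards. A minimal round trip (not part of the commit):

    from dalle_mini.text import (
        post_process_dot_numbers,
        pre_process_dot_numbers,
        replace_punctuation_with_commas,
    )

    t = pre_process_dot_numbers("pi is 3.14")  # "pi is 3xtokxdotxtokx14"
    t = replace_punctuation_with_commas(t)     # sentinel survives; "." would have become ","
    t = post_process_dot_numbers(t)            # "pi is 3.14"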
+ def merge_commas(t):
+     return re.sub(r"(\s*,+\s*)+", ", ", t)
+
+
+ def add_space_after_commas(t):
+     return re.sub(",", ", ", t)
+
+
+ def handle_special_chars(t):
+     "Handle special characters"
+     # replace "-" with a space when between words without space
+     t = re.sub(r"([a-zA-Z])-([a-zA-Z])", r"\1 \2", t)
+     # always add space around &
+     return re.sub("&", " & ", t)
+
+
+ def expand_hashtags(t, hashtag_processor):
+     "Remove # and try to split words"
+     return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
+
+
+ _re_ignore_chars = r"""[_#\/\\%]"""
+
+
+ def ignore_chars(t):
+     "Ignore useless characters"
+     return re.sub(_re_ignore_chars, " ", t)
+
+
+ def remove_extra_spaces(t):
+     r"Remove extra spaces (including \t and \n)"
+     return re.sub(r"\s+", " ", t)
+
+
+ def remove_repeating_chars(t):
+     "If the same character is present 4+ times (not 3 because of roman 'VIII'), replace with a single instance"
+     return re.sub(r"(\D)(\1{3,})", r"\1", t)
+
+
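A small check of the 4+ threshold (commentary, not part of the commit); both outputs follow directly from the regex:

    from dalle_mini.text import remove_repeating_chars

    print(remove_repeating_chars("coooool!!!!"))  # "col!" - runs of 4+ collapse to one
    print(remove_repeating_chars("VIII"))         # "VIII" - 3 repeats are preserved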
+ def remove_urls(t):
+     return re.sub(r"http\S+", "", t)
+
+
+ def remove_html_tags(t):
+     return re.sub("<[^<]+?>", "", t)
+
+
+ def remove_first_last_commas(t):
+     t = t.strip()
+     t = t[:-1] if t and t[-1] == "," else t
+     t = t[1:] if t and t[0] == "," else t
+     return t.strip()
+
+
+ def remove_wiki_ref(t):
+     t = re.sub(r"\A\s*\[\d+\]", "", t)
+     return re.sub(r"\[\d+\]\s*\Z", "", t)
+
+
+ class TextNormalizer:
+     "Normalize text"
+
+     def __init__(self):
+         self._hashtag_processor = HashtagProcessor()
+
+     def __call__(self, t, clip=False):
+         # fix html
+         t = fix_html(t)
+         if not clip:
+             # decode and simplify text: see unidecode library
+             t = unidecode(t)
+             # lower case
+             t = t.lower()
+         # replace <PERSON> (for CC12M)
+         t = replace_person_token(t)
+         # remove wiki reference (for WIT)
+         t = remove_wiki_ref(t)
+         # remove html tags
+         t = remove_html_tags(t)
+         # remove urls
+         t = remove_urls(t)
+         # remove commas in numbers
+         t = remove_comma_numbers(t)
+         if not clip:
+             # handle dots in numbers and quotes - Part 1
+             t = pre_process_dot_numbers(t)
+             t = pre_process_quotes(t)
+             # handle special characters
+             t = handle_special_chars(t)
+             # handle hashtags
+             t = expand_hashtags(t, self._hashtag_processor)
+             # ignore useless characters
+             t = ignore_chars(t)
+             # simplify quotes
+             t = simplify_quotes(t)
+             # all punctuation becomes commas
+             t = replace_punctuation_with_commas(t)
+             # handle dots in numbers and quotes - Part 2
+             t = post_process_dot_numbers(t)
+             t = post_process_quotes(t)
+             # handle repeating characters
+             t = remove_repeating_chars(t)
+             # merge commas
+             t = merge_commas(t)
+             # merge quotes
+             t = merge_quotes(t)
+         # remove multiple spaces
+         t = remove_extra_spaces(t)
+         # remove first and last comma
+         t = remove_first_last_commas(t)
+         # always start with a space
+         return f" {t}" if not clip else t
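End-to-end usage sketch (not part of the commit). Instantiation builds a HashtagProcessor, which downloads the word-frequency file on first use; hashtag splits depend on that file, so no exact output is shown.

    from dalle_mini.text import TextNormalizer

    normalizer = TextNormalizer()

    # full pipeline: html fixes, lower-casing, hashtag expansion,
    # punctuation -> commas, and a leading space in the result
    print(normalizer("Check out https://example.com &amp; my #sunnyday pics!!!"))

    # clip=True keeps case and punctuation (lighter cleaning for CLIP scoring)
    print(normalizer("A photo of a dog.", clip=True))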