boris committed on
Commit 9e08dc5
Parents: e226ca6, 849c5f3

Merge pull request #104 from borisdayma/feat-hf_hub

Files changed (1):
  dalle_mini/text.py  +32 −43
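This PR replaces the hand-rolled get_wiki_file() helper, which used requests to download the wikipedia word-frequency file into the working directory, with huggingface_hub's hf_hub_download, which fetches the file from the dalle-mini/dalle-mini repo and handles caching itself. It also drops the clip flag from TextNormalizer.__call__ (normalization is now unconditional), prefixes expanded hashtags with " , ", and merges quotes before commas. A minimal sketch of the new lookup, mirroring the call added in the diff below (the cache location in the comment is indicative only; the exact layout depends on your huggingface_hub version):

from huggingface_hub import hf_hub_download

# Downloads on first call, then reuses the local Hub cache;
# returns the path to the cached file.
wiki_word_frequency = hf_hub_download(
    "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
)
# wiki_word_frequency is now a local filesystem path,
# e.g. somewhere under ~/.cache/huggingface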
dalle_mini/text.py CHANGED
@@ -2,36 +2,28 @@
 Utilities for processing text.
 """
 
-import requests
 from pathlib import Path
 from unidecode import unidecode
 
 import re, math, random, html
 import ftfy
 
-WIKI_STATS_URL = "https://github.com/borisdayma/wikipedia-word-frequency/raw/feat-update/results/enwiki-20210820-words-frequency.txt"
-WIKI_STATS_LOCAL = Path(WIKI_STATS_URL).parts[-1]
+from huggingface_hub import hf_hub_download
 
 # based on wiki word occurence
 person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
 temp_token = "xtokx"  # avoid repeating chars
 
 
-def get_wiki_file():
-    if not Path(WIKI_STATS_LOCAL).exists():
-        r = requests.get(WIKI_STATS_URL, stream=True)
-        with open(WIKI_STATS_LOCAL, "wb") as fd:
-            for chunk in r.iter_content(chunk_size=128):
-                fd.write(chunk)
-    return WIKI_STATS_LOCAL
-
-
 class HashtagProcessor:
     # Adapted from wordninja library
     # We use our wikipedia word count + a good heuristic to make it work
     def __init__(self):
+        wiki_word_frequency = hf_hub_download(
+            "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
+        )
         self._word_cost = (
-            l.split()[0] for l in Path(get_wiki_file()).read_text().splitlines()
+            l.split()[0] for l in Path(wiki_word_frequency).read_text().splitlines()
         )
         self._word_cost = {
             str(k): math.log(float(i + 1)) for i, k in enumerate(self._word_cost)
@@ -158,7 +150,7 @@ def handle_special_chars(t):
 
 def expand_hashtags(t, hashtag_processor):
     "Remove # and try to split words"
-    return re.sub("#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
+    return re.sub("#(\w+)", lambda m: " , " + hashtag_processor(m.group(1)), t)
 
 
 _re_ignore_chars = """[_#\/\\%]"""
@@ -205,15 +197,13 @@ class TextNormalizer:
     def __init__(self):
         self._hashtag_processor = HashtagProcessor()
 
-    def __call__(self, t, clip=False):
-
+    def __call__(self, t):
         # fix some characters
         t = ftfy.fix_text(t)
         # fix html
         t = fix_html(t)
-        if not clip:
-            # decode and simplify text: see unidecode library
-            t = unidecode(t)
+        # decode and simplify text: see unidecode library
+        t = unidecode(t)
         # lower case
         t = t.lower()
         # replace <PERSON> (for CC12M)
@@ -226,32 +216,31 @@ class TextNormalizer:
         t = remove_urls(t)
         # remove commas in numbers
         t = remove_comma_numbers(t)
-        if not clip:
-            # handle dots in numbers and quotes - Part 1
-            t = pre_process_dot_numbers(t)
-            t = pre_process_quotes(t)
-            # handle special characters
-            t = handle_special_chars(t)
-            # handle hashtags
-            t = expand_hashtags(t, self._hashtag_processor)
-            # ignore useless characters
-            t = ignore_chars(t)
-            # simplify quotes
-            t = simplify_quotes(t)
-            # all punctuation becomes commas
-            t = replace_punctuation_with_commas(t)
-            # handle dots in numbers and quotes - Part 2
-            t = post_process_dot_numbers(t)
-            t = post_process_quotes(t)
-            # handle repeating characters
-            t = remove_repeating_chars(t)
-            # merge commas
-            t = merge_commas(t)
-            # merge quotes
-            t = merge_quotes(t)
+        # handle dots in numbers and quotes - Part 1
+        t = pre_process_dot_numbers(t)
+        t = pre_process_quotes(t)
+        # handle special characters
+        t = handle_special_chars(t)
+        # handle hashtags
+        t = expand_hashtags(t, self._hashtag_processor)
+        # ignore useless characters
+        t = ignore_chars(t)
+        # simplify quotes
+        t = simplify_quotes(t)
+        # all punctuation becomes commas
+        t = replace_punctuation_with_commas(t)
+        # handle dots in numbers and quotes - Part 2
+        t = post_process_dot_numbers(t)
+        t = post_process_quotes(t)
+        # handle repeating characters
+        t = remove_repeating_chars(t)
+        # merge quotes
+        t = merge_quotes(t)
+        # merge commas
+        t = merge_commas(t)
         # remove multiple spaces
         t = remove_extra_spaces(t)
         # remove first and last comma
         t = remove_first_last_commas(t)
         # always start with a space
-        return f" {t}" if not clip else t
+        return f" {t}"
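For reference, a minimal usage sketch of the normalizer after this change, assuming the dalle_mini package is importable (note that __call__ no longer takes a clip argument):

from dalle_mini.text import TextNormalizer

# First use triggers the hf_hub_download call inside HashtagProcessor,
# fetching the word-frequency file from the Hub (or the local cache).
normalizer = TextNormalizer()

# Hashtags are expanded into words, punctuation collapses to commas,
# and the result always starts with a space.
print(normalizer("Check out #SanFranciscoSunset!"))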