Pietro Lesci committed on
Commit
e2db848
·
1 Parent(s): 6114f21

reorder preprocessing steps and add new ones

Browse files
Files changed (1) hide show
  1. src/utils.py +36 -16
src/utils.py CHANGED
@@ -18,7 +18,7 @@ from stqdm import stqdm
18
  from textacy.preprocessing import make_pipeline, normalize, remove, replace
19
 
20
  from .configs import Languages, ModelConfigs, SupportedFiles
21
-
22
  stqdm.pandas()
23
 
24
 
@@ -27,13 +27,17 @@ def get_logo(path):
27
  return Image.open(path)
28
 
29
 
30
- # @st.cache(suppress_st_warning=True)
 
31
  def read_file(uploaded_file) -> pd.DataFrame:
32
 
33
  file_type = uploaded_file.name.split(".")[-1]
34
  if file_type in set(i.name for i in SupportedFiles):
35
  read_f = SupportedFiles[file_type].value[0]
36
- return read_f(uploaded_file, dtype=str)
 
 
 
37
 
38
  else:
39
  st.error("File type not supported")
@@ -155,16 +159,20 @@ def wordifier(X, y, X_names: List[str], y_names: List[str], configs=ModelConfigs
155
 
156
  # more [here](https://github.com/fastai/fastai/blob/master/fastai/text/core.py#L42)
157
  # and [here](https://textacy.readthedocs.io/en/latest/api_reference/preprocessing.html)
158
- _re_space = re.compile(" {2,}")
 
 
159
 
 
 
 
160
 
 
161
  def normalize_useless_spaces(t):
162
  return _re_space.sub(" ", t)
163
 
164
 
165
  _re_rep = re.compile(r"(\S)(\1{2,})")
166
-
167
-
168
  def normalize_repeating_chars(t):
169
  def _replace_rep(m):
170
  c, cc = m.groups()
@@ -174,8 +182,6 @@ def normalize_repeating_chars(t):
174
 
175
 
176
  _re_wrep = re.compile(r"(?:\s|^)(\w+)\s+((?:\1\s+)+)\1(\s|\W|$)")
177
-
178
-
179
  def normalize_repeating_words(t):
180
  def _replace_wrep(m):
181
  c, cc, e = m.groups()
@@ -248,18 +254,20 @@ class TextPreprocessor:
248
  ("normalize_hyphenated_words", normalize.hyphenated_words),
249
  ("normalize_quotation_marks", normalize.quotation_marks),
250
  ("normalize_whitespace", normalize.whitespace),
251
- ("remove_accents", remove.accents),
252
- ("remove_brackets", remove.brackets),
253
- ("remove_html_tags", remove.html_tags),
254
- ("remove_punctuation", remove.punctuation),
255
  ("replace_currency_symbols", replace.currency_symbols),
256
  ("replace_emails", replace.emails),
257
  ("replace_emojis", replace.emojis),
258
  ("replace_hashtags", replace.hashtags),
259
  ("replace_numbers", replace.numbers),
260
  ("replace_phone_numbers", replace.phone_numbers),
261
- ("replace_urls", replace.urls),
262
  ("replace_user_handles", replace.user_handles),
 
 
 
 
 
 
263
  ("normalize_useless_spaces", normalize_useless_spaces),
264
  ("normalize_repeating_chars", normalize_repeating_chars),
265
  ("normalize_repeating_words", normalize_repeating_words),
@@ -286,15 +294,27 @@ class TextPreprocessor:
286
 
287
  def plot_labels_prop(data: pd.DataFrame, label_column: str):
288
 
289
- source = data["label"].value_counts().reset_index().rename(columns={"index": "Labels", label_column: "Counts"})
 
 
 
 
 
 
 
 
 
 
290
 
291
- source["Proportions"] = ((source["Counts"] / source["Counts"].sum()).round(3) * 100).map("{:,.2f}".format) + "%"
 
 
292
 
293
  bars = (
294
  alt.Chart(source)
295
  .mark_bar()
296
  .encode(
297
- x="Labels:O",
298
  y="Counts:Q",
299
  )
300
  )
 
18
  from textacy.preprocessing import make_pipeline, normalize, remove, replace
19
 
20
  from .configs import Languages, ModelConfigs, SupportedFiles
21
+ import string
22
  stqdm.pandas()
23
 
24
 
 
27
  return Image.open(path)
28
 
29
 
30
# @st.cache(suppress_st_warning=True)
@st.cache(allow_output_mutation=True)
def read_file(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded file into a DataFrame, dropping rows with any NA.

    The file extension selects the reader: the ``SupportedFiles`` enum maps
    each extension name to a tuple whose first element is the reader callable.

    Args:
        uploaded_file: A Streamlit uploaded-file object (must expose ``.name``).

    Returns:
        The parsed ``pd.DataFrame`` with NA rows removed, or ``None`` when the
        extension is unsupported (a Streamlit error is shown in that case).
    """
    file_type = uploaded_file.name.split(".")[-1]
    # Membership test against the extension names declared on SupportedFiles
    if file_type in {i.name for i in SupportedFiles}:
        read_f = SupportedFiles[file_type].value[0]
        df = read_f(uploaded_file)
        # remove any NA
        return df.dropna()

    st.error("File type not supported")
    # Explicit None so the unsupported branch is visible to callers
    return None
 
159
 
160
  # more [here](https://github.com/fastai/fastai/blob/master/fastai/text/core.py#L42)
161
  # and [here](https://textacy.readthedocs.io/en/latest/api_reference/preprocessing.html)
162
# A dotted acronym: two or more letter-period pairs, e.g. "u.s.a.".
_re_normalize_acronyms = re.compile(r"(?:[a-zA-Z]\.){2,}")


def normalize_acronyms(t: str) -> str:
    """Rewrite dotted acronyms to their bare uppercase form ("u.s.a." -> "USA").

    Bug fix: the previous version passed the whole de-punctuated, uppercased
    text as a *literal* replacement string for every match. Each matched
    acronym must instead be transformed individually, so the replacement is
    now a callable applied per match.
    """
    return _re_normalize_acronyms.sub(
        lambda m: m.group(0).translate(str.maketrans("", "", string.punctuation)).upper(),
        t,
    )
165
 
166
# Any single non-word character, i.e. anything outside [a-zA-Z0-9_]
# (Unicode-aware by default). Raw string fixes the invalid "\W" escape.
_re_non_word = re.compile(r"\W")


def remove_non_word(t: str) -> str:
    """Replace every non-word character in ``t`` with a single space.

    Runs of non-word characters become runs of spaces; collapse them
    afterwards (the pipeline applies ``normalize_useless_spaces`` later).
    """
    return _re_non_word.sub(" ", t)
169
 
170
# Two or more consecutive ASCII spaces.
_re_space = re.compile(" {2,}")


def normalize_useless_spaces(t: str) -> str:
    """Collapse every run of two-or-more spaces in ``t`` to one space."""
    collapsed = _re_space.sub(" ", t)
    return collapsed
173
 
174
 
175
  _re_rep = re.compile(r"(\S)(\1{2,})")
 
 
176
  def normalize_repeating_chars(t):
177
  def _replace_rep(m):
178
  c, cc = m.groups()
 
182
 
183
 
184
  _re_wrep = re.compile(r"(?:\s|^)(\w+)\s+((?:\1\s+)+)\1(\s|\W|$)")
 
 
185
  def normalize_repeating_words(t):
186
  def _replace_wrep(m):
187
  c, cc, e = m.groups()
 
254
  ("normalize_hyphenated_words", normalize.hyphenated_words),
255
  ("normalize_quotation_marks", normalize.quotation_marks),
256
  ("normalize_whitespace", normalize.whitespace),
257
+ ("replace_urls", replace.urls),
 
 
 
258
  ("replace_currency_symbols", replace.currency_symbols),
259
  ("replace_emails", replace.emails),
260
  ("replace_emojis", replace.emojis),
261
  ("replace_hashtags", replace.hashtags),
262
  ("replace_numbers", replace.numbers),
263
  ("replace_phone_numbers", replace.phone_numbers),
 
264
  ("replace_user_handles", replace.user_handles),
265
+ ("normalize_acronyms", normalize_acronyms),
266
+ ("remove_accents", remove.accents),
267
+ ("remove_brackets", remove.brackets),
268
+ ("remove_html_tags", remove.html_tags),
269
+ ("remove_punctuation", remove.punctuation),
270
+ ("remove_non_words", remove_non_word),
271
  ("normalize_useless_spaces", normalize_useless_spaces),
272
  ("normalize_repeating_chars", normalize_repeating_chars),
273
  ("normalize_repeating_words", normalize_repeating_words),
 
294
 
295
  def plot_labels_prop(data: pd.DataFrame, label_column: str):
296
 
297
+ unique_value_limit = 100
298
+
299
+ if data[label_column].nunique() > unique_value_limit:
300
+
301
+ st.warning(f"""
302
+ The column you selected has more than {unique_value_limit}.
303
+ Are you sure it's the right column? If it is, please note that
304
+ this will impact __Wordify__ performance.
305
+ """)
306
+
307
+ return
308
 
309
+ source = data[label_column].value_counts().reset_index().rename(columns={"index": "Labels", label_column: "Counts"})
310
+ source["Props"] = source["Counts"] / source["Counts"].sum()
311
+ source["Proportions"] = (source["Props"].round(3) * 100).map("{:,.2f}".format) + "%"
312
 
313
  bars = (
314
  alt.Chart(source)
315
  .mark_bar()
316
  .encode(
317
+ x=alt.X("Labels:O", sort="-y"),
318
  y="Counts:Q",
319
  )
320
  )