meg (HF staff) committed
Commit c24f881
Parent(s): e8ac901

merging dataset statistics file

Files changed (1)
  1. data_measurements/dataset_statistics.py +112 -217
data_measurements/dataset_statistics.py CHANGED
@@ -15,11 +15,12 @@
 import json
 import logging
 import statistics
-import torch
 from os import mkdir
 from os.path import exists, isdir
 from os.path import join as pjoin

+import matplotlib.pyplot as plt
+import matplotlib.image as mpimg
 import nltk
 import numpy as np
 import pandas as pd
@@ -28,31 +29,17 @@ import plotly.express as px
 import plotly.figure_factory as ff
 import plotly.graph_objects as go
 import pyarrow.feather as feather
-import matplotlib.pyplot as plt
-import matplotlib.image as mpimg
 import seaborn as sns
+import torch
 from datasets import load_from_disk
 from nltk.corpus import stopwords
 from sklearn.feature_extraction.text import CountVectorizer

-from .dataset_utils import (
-    TOT_WORDS,
-    TOT_OPEN_WORDS,
-    CNT,
-    DEDUP_TOT,
-    EMBEDDING_FIELD,
-    LENGTH_FIELD,
-    OUR_LABEL_FIELD,
-    OUR_TEXT_FIELD,
-    PROP,
-    TEXT_NAN_CNT,
-    TOKENIZED_FIELD,
-    TXT_LEN,
-    VOCAB,
-    WORD,
-    extract_field,
-    load_truncated_dataset,
-)
+from .dataset_utils import (CNT, DEDUP_TOT, EMBEDDING_FIELD, LENGTH_FIELD,
+                            OUR_LABEL_FIELD, OUR_TEXT_FIELD, PROP,
+                            TEXT_NAN_CNT, TOKENIZED_FIELD, TOT_OPEN_WORDS,
+                            TOT_WORDS, TXT_LEN, VOCAB, WORD, extract_field,
+                            load_truncated_dataset)
 from .embeddings import Embeddings
 from .npmi import nPMI
 from .zipf import Zipf
@@ -151,6 +138,7 @@ _NUM_VOCAB_BATCHES = 2000
 _TOP_N = 100
 _CVEC = CountVectorizer(token_pattern="(?u)\\b\\w+\\b", lowercase=True)

+
 class DatasetStatisticsCacheClass:
     def __init__(
         self,
@@ -249,13 +237,13 @@ class DatasetStatisticsCacheClass:
         # path to the directory used for caching
         if not isinstance(text_field, str):
             text_field = "-".join(text_field)
-        #if isinstance(label_field, str):
+        # if isinstance(label_field, str):
         #    label_field = label_field
-        #else:
+        # else:
         #    label_field = "-".join(label_field)
         self.cache_path = pjoin(
             self.cache_dir,
-            f"{dset_name}_{dset_config}_{split_name}_{text_field}", #{label_field},
+            f"{dset_name}_{dset_config}_{split_name}_{text_field}",  # {label_field},
         )
         if not isdir(self.cache_path):
             logs.warning("Creating cache directory %s." % self.cache_path)
@@ -284,14 +272,15 @@ class DatasetStatisticsCacheClass:
         # Needed for UI
         self.dup_counts_df_fid = pjoin(self.cache_path, "dup_counts_df.feather")
         # Needed for UI
-        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
+        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.png")

         ## General text stats
         # Needed for UI
         self.general_stats_json_fid = pjoin(self.cache_path, "general_stats_dict.json")
         # Needed for UI
-        self.sorted_top_vocab_df_fid = pjoin(self.cache_path,
-                                             "sorted_top_vocab.feather")
+        self.sorted_top_vocab_df_fid = pjoin(
+            self.cache_path, "sorted_top_vocab.feather"
+        )
         ## Zipf cache files
         # Needed for UI
         self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
@@ -303,7 +292,6 @@ class DatasetStatisticsCacheClass:
         self.node_list_fid = pjoin(self.cache_path, "node_list.th")
         # Needed for UI
         self.fig_tree_json_fid = pjoin(self.cache_path, "fig_tree.json")
-        self.zipf_counts = None

         self.live = False

@@ -343,18 +331,17 @@ class DatasetStatisticsCacheClass:
             and exists(self.dup_counts_df_fid)
             and exists(self.sorted_top_vocab_df_fid)
         ):
-            logs.info('Loading cached general stats')
+            logs.info("Loading cached general stats")
             self.load_general_stats()
         else:
             if not self.live:
-                logs.info('Preparing general stats')
+                logs.info("Preparing general stats")
                 self.prepare_general_stats()
                 if save:
                     write_df(self.sorted_top_vocab_df, self.sorted_top_vocab_df_fid)
                     write_df(self.dup_counts_df, self.dup_counts_df_fid)
                     write_json(self.general_stats_dict, self.general_stats_json_fid)

-
     def load_or_prepare_text_lengths(self, save=True):
         """
         The text length widget relies on this function, which provides
@@ -366,15 +353,13 @@ class DatasetStatisticsCacheClass:

         """
         # Text length figure
-        if (self.use_cache and exists(self.fig_tok_length_fid)):
+        if self.use_cache and exists(self.fig_tok_length_fid):
             self.fig_tok_length_png = mpimg.imread(self.fig_tok_length_fid)
-            self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         else:
             if not self.live:
                 self.prepare_fig_text_lengths()
                 if save:
-                    write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
-
+                    self.fig_tok_length.savefig(self.fig_tok_length_fid)
         # Text length dataframe
         if self.use_cache and exists(self.length_df_fid):
             self.length_df = feather.read_feather(self.length_df_fid)
@@ -401,51 +386,48 @@ class DatasetStatisticsCacheClass:
         if not self.live:
             if self.tokenized_df is None:
                 self.tokenized_df = self.do_tokenization()
-            self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[
-                TOKENIZED_FIELD].apply(len)
-            self.length_df = self.tokenized_df[
-                [LENGTH_FIELD, OUR_TEXT_FIELD]].sort_values(
-                    by=[LENGTH_FIELD], ascending=True
+            self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[TOKENIZED_FIELD].apply(
+                len
             )
+            self.length_df = self.tokenized_df[
+                [LENGTH_FIELD, OUR_TEXT_FIELD]
+            ].sort_values(by=[LENGTH_FIELD], ascending=True)

     def prepare_text_length_stats(self):
         if not self.live:
-            if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns or self.length_df is None:
+            if (
+                self.tokenized_df is None
+                or LENGTH_FIELD not in self.tokenized_df.columns
+                or self.length_df is None
+            ):
                 self.prepare_length_df()
-            avg_length = sum(self.tokenized_df[LENGTH_FIELD])/len(self.tokenized_df[LENGTH_FIELD])
+            avg_length = sum(self.tokenized_df[LENGTH_FIELD]) / len(
+                self.tokenized_df[LENGTH_FIELD]
+            )
             self.avg_length = round(avg_length, 1)
             std_length = statistics.stdev(self.tokenized_df[LENGTH_FIELD])
             self.std_length = round(std_length, 1)
             self.num_uniq_lengths = len(self.length_df["length"].unique())
-            self.length_stats_dict = {"avg length": self.avg_length,
-                                      "std length": self.std_length,
-                                      "num lengths": self.num_uniq_lengths}
+            self.length_stats_dict = {
+                "avg length": self.avg_length,
+                "std length": self.std_length,
+                "num lengths": self.num_uniq_lengths,
+            }

     def prepare_fig_text_lengths(self):
         if not self.live:
-            if self.tokenized_df is None or LENGTH_FIELD not in self.tokenized_df.columns:
+            if (
+                self.tokenized_df is None
+                or LENGTH_FIELD not in self.tokenized_df.columns
+            ):
                 self.prepare_length_df()
             self.fig_tok_length = make_fig_lengths(self.tokenized_df, LENGTH_FIELD)

-    def load_or_prepare_embeddings(self, save=True):
-        if self.use_cache and exists(self.node_list_fid) and exists(self.fig_tree_json_fid):
-            self.node_list = torch.load(self.node_list_fid)
-            self.fig_tree = read_plotly(self.fig_tree_json_fid)
-        elif self.use_cache and exists(self.node_list_fid):
-            self.node_list = torch.load(self.node_list_fid)
-            self.fig_tree = make_tree_plot(self.node_list,
-                                           self.text_dset)
-            if save:
-                write_plotly(self.fig_tree, self.fig_tree_json_fid)
-        else:
-            self.embeddings = Embeddings(self, use_cache=self.use_cache)
-            self.embeddings.make_hierarchical_clustering()
-            self.node_list = self.embeddings.node_list
-            self.fig_tree = make_tree_plot(self.node_list,
-                                           self.text_dset)
-            if save:
-                torch.save(self.node_list, self.node_list_fid)
-                write_plotly(self.fig_tree, self.fig_tree_json_fid)
+    def load_or_prepare_embeddings(self):
+        self.embeddings = Embeddings(self, use_cache=self.use_cache)
+        self.embeddings.make_hierarchical_clustering()
+        self.node_list = self.embeddings.node_list
+        self.fig_tree = self.embeddings.fig_tree

     # get vocab with word counts
     def load_or_prepare_vocab(self, save=True):
@@ -455,10 +437,7 @@ class DatasetStatisticsCacheClass:
         :param
         :return:
         """
-        if (
-            self.use_cache
-            and exists(self.vocab_counts_df_fid)
-        ):
+        if self.use_cache and exists(self.vocab_counts_df_fid):
             logs.info("Reading vocab from cache")
             self.load_vocab()
             self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
@@ -505,7 +484,9 @@ class DatasetStatisticsCacheClass:
             write_df(self.dup_counts_df, self.dup_counts_df_fid)

     def load_general_stats(self):
-        self.general_stats_dict = json.load(open(self.general_stats_json_fid, encoding="utf-8"))
+        self.general_stats_dict = json.load(
+            open(self.general_stats_json_fid, encoding="utf-8")
+        )
         with open(self.sorted_top_vocab_df_fid, "rb") as f:
             self.sorted_top_vocab_df = feather.read_feather(f)
         self.text_nan_count = self.general_stats_dict[TEXT_NAN_CNT]
@@ -540,8 +521,7 @@ class DatasetStatisticsCacheClass:
         if not self.live:
             if self.tokenized_df is None:
                 self.load_or_prepare_tokenized_df()
-            dup_df = self.tokenized_df[
-                self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
+            dup_df = self.tokenized_df[self.tokenized_df.duplicated([OUR_TEXT_FIELD])]
             self.dup_counts_df = pd.DataFrame(
                 dup_df.pivot_table(
                     columns=[OUR_TEXT_FIELD], aggfunc="size"
@@ -581,7 +561,7 @@ class DatasetStatisticsCacheClass:
             write_json({"dset peek": self.dset_peek}, self.dset_peek_json_fid)

     def load_or_prepare_tokenized_df(self, save=True):
-        if (self.use_cache and exists(self.tokenized_df_fid)):
+        if self.use_cache and exists(self.tokenized_df_fid):
             self.tokenized_df = feather.read_feather(self.tokenized_df_fid)
         else:
             if not self.live:
@@ -593,7 +573,7 @@ class DatasetStatisticsCacheClass:
                 write_df(self.tokenized_df, self.tokenized_df_fid)

     def load_or_prepare_text_dset(self, save=True):
-        if (self.use_cache and exists(self.text_dset_fid)):
+        if self.use_cache and exists(self.text_dset_fid):
             # load extracted text
             self.text_dset = load_from_disk(self.text_dset_fid)
             logs.warning("Loaded dataset from disk")
@@ -711,8 +691,6 @@ class DatasetStatisticsCacheClass:
                 zipf_dict = json.load(f)
             self.z = Zipf()
             self.z.load(zipf_dict)
-            # TODO: Should this be cached?
-            self.zipf_counts = self.z.calc_zipf_counts(self.vocab_counts_df)
             self.zipf_fig = read_plotly(self.zipf_fig_fid)
         elif self.use_cache and exists(self.zipf_fid):
             # TODO: Read zipf data so that the vocab is there.
@@ -775,30 +753,26 @@ class nPMIStatisticsCacheClass:
             and exists(self.npmi_terms_fid)
             and json.load(open(self.npmi_terms_fid))["available terms"] != []
         ):
-            self.available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
+            available_terms = json.load(open(self.npmi_terms_fid))["available terms"]
         else:
-            if not self.live:
-                if self.dstats.vocab_counts_df is None:
-                    self.dstats.load_or_prepare_vocab()
-
-                true_false = [
-                    term in self.dstats.vocab_counts_df.index for term in self.termlist
-                ]
-                word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
-                true_false_counts = [
-                    self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
-                    for word in word_list_tmp
-                ]
-                available_terms = [
-                    word for word, y in zip(word_list_tmp, true_false_counts) if y
-                ]
-                logs.info(available_terms)
-                with open(self.npmi_terms_fid, "w+") as f:
-                    json.dump({"available terms": available_terms}, f)
-                self.available_terms = available_terms
-        return self.available_terms
-
-    def load_or_prepare_joint_npmi(self, subgroup_pair, save=True):
+            true_false = [
+                term in self.dstats.vocab_counts_df.index for term in self.termlist
+            ]
+            word_list_tmp = [x for x, y in zip(self.termlist, true_false) if y]
+            true_false_counts = [
+                self.dstats.vocab_counts_df.loc[word, CNT] >= self.min_vocab_count
+                for word in word_list_tmp
+            ]
+            available_terms = [
+                word for word, y in zip(word_list_tmp, true_false_counts) if y
+            ]
+            logs.info(available_terms)
+            with open(self.npmi_terms_fid, "w+") as f:
+                json.dump({"available terms": available_terms}, f)
+            self.available_terms = available_terms
+        return available_terms
+
+    def load_or_prepare_joint_npmi(self, subgroup_pair):
         """
         Run on-the fly, while the app is already open,
         as it depends on the subgroup terms that the user chooses
@@ -823,7 +797,13 @@ class nPMIStatisticsCacheClass:
             # When everything is already computed for the selected subgroups.
             logs.info("Loading cached joint npmi")
             joint_npmi_df = self.load_joint_npmi_df(joint_npmi_fid)
-            npmi_display_cols = ['npmi-bias', subgroup1 + '-npmi', subgroup2 + '-npmi', subgroup1 + '-count', subgroup2 + '-count']
+            npmi_display_cols = [
+                "npmi-bias",
+                subgroup1 + "-npmi",
+                subgroup2 + "-npmi",
+                subgroup1 + "-count",
+                subgroup2 + "-count",
+            ]
             joint_npmi_df = joint_npmi_df[npmi_display_cols]
         # When maybe some things have been computed for the selected subgroups.
         else:
@@ -832,14 +812,12 @@ class nPMIStatisticsCacheClass:
                 joint_npmi_df, subgroup_dict = self.prepare_joint_npmi_df(
                     subgroup_pair, subgroup_files
                 )
-                if save:
-                    if joint_npmi_df is not None:
-                        # Cache new results
-                        logs.info("Writing out.")
-                        for subgroup in subgroup_pair:
-                            write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
-                        with open(joint_npmi_fid, "w+") as f:
-                            joint_npmi_df.to_csv(f)
+                # Cache new results
+                logs.info("Writing out.")
+                for subgroup in subgroup_pair:
+                    write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files)
+                with open(joint_npmi_fid, "w+") as f:
+                    joint_npmi_df.to_csv(f)
             else:
                 joint_npmi_df = pd.DataFrame()
         logs.info("The joint npmi df is")
@@ -881,7 +859,7 @@ class nPMIStatisticsCacheClass:
                 subgroup_dict[subgroup] = cached_results
         logs.info("Calculating for subgroup list")
         joint_npmi_df, subgroup_dict = self.do_npmi(subgroup_pair, subgroup_dict)
-        return joint_npmi_df, subgroup_dict
+        return joint_npmi_df.dropna(), subgroup_dict

     # TODO: Update pairwise assumption
     def do_npmi(self, subgroup_pair, subgroup_dict):
@@ -892,7 +870,6 @@ class nPMIStatisticsCacheClass:
         :return: Selected identity term's co-occurrence counts with
                  other words, pmi per word, and nPMI per word.
         """
-        no_results = False
         logs.info("Initializing npmi class")
         npmi_obj = self.set_npmi_obj()
         # Canonical ordering used
@@ -900,26 +877,18 @@ class nPMIStatisticsCacheClass:
         # Calculating nPMI statistics
         for subgroup in subgroup_pair:
             # If the subgroup data is already computed, grab it.
-            # TODO: Should we set idx and column names similarly to
-            # how we set them for cached files?
+            # TODO: Should we set idx and column names similarly to how we set them for cached files?
             if subgroup not in subgroup_dict:
                 logs.info("Calculating statistics for %s" % subgroup)
                 vocab_cooc_df, pmi_df, npmi_df = npmi_obj.calc_metrics(subgroup)
-                if vocab_cooc_df is None:
-                    no_results = True
-                else:
-                    # Store the nPMI information for the current subgroups
-                    subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
-        if no_results:
-            logs.warning("Couldn't grap the npmi files -- Under construction")
-            return None, None
-        else:
-            # Pair the subgroups together, indexed by all words that
-            # co-occur between them.
-            logs.info("Computing pairwise npmi bias")
-            paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
-            UI_results = make_npmi_fig(paired_results, subgroup_pair)
-            return UI_results.dropna(), subgroup_dict
+                # Store the nPMI information for the current subgroups
+                subgroup_dict[subgroup] = (vocab_cooc_df, pmi_df, npmi_df)
+        # Pair the subgroups together, indexed by all words that
+        # co-occur between them.
+        logs.info("Computing pairwise npmi bias")
+        paired_results = npmi_obj.calc_paired_metrics(subgroup_pair, subgroup_dict)
+        UI_results = make_npmi_fig(paired_results, subgroup_pair)
+        return UI_results, subgroup_dict

     def set_npmi_obj(self):
         """
@@ -993,9 +962,11 @@ class nPMIStatisticsCacheClass:
     def get_available_terms(self):
         return self.load_or_prepare_npmi_terms()

+
 def dummy(doc):
     return doc

+
 def count_vocab_frequencies(tokenized_df):
     """
     Based on an input pandas DataFrame with a 'text' column,
@@ -1010,7 +981,9 @@ def count_vocab_frequencies(tokenized_df):
     )
     # We do this to calculate per-word statistics
     # Fast calculation of single word counts
-    logs.info("Fitting dummy tokenization to make matrix using the previous tokenization")
+    logs.info(
+        "Fitting dummy tokenization to make matrix using the previous tokenization"
+    )
     cvec.fit(tokenized_df[TOKENIZED_FIELD])
     document_matrix = cvec.transform(tokenized_df[TOKENIZED_FIELD])
     batches = np.linspace(0, tokenized_df.shape[0], _NUM_VOCAB_BATCHES).astype(int)
@@ -1031,6 +1004,7 @@ def count_vocab_frequencies(tokenized_df):
     word_count_df.index.name = WORD
     return word_count_df

+
 def calc_p_word(word_count_df):
     # p(word)
     word_count_df[PROP] = word_count_df[CNT] / float(sum(word_count_df[CNT]))
@@ -1041,8 +1015,7 @@ def calc_p_word(word_count_df):

 def filter_vocab(vocab_counts_df):
     # TODO: Add warnings (which words are missing) to log file?
-    filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
-                                                    errors="ignore")
+    filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS, errors="ignore")
     filtered_count = filtered_vocab_counts_df[CNT]
     filtered_count_denom = float(sum(filtered_vocab_counts_df[CNT]))
     filtered_vocab_counts_df[PROP] = filtered_count / filtered_count_denom
@@ -1051,19 +1024,23 @@ def calc_p_word(word_count_df):

 ## Figures ##

+
 def write_plotly(fig, fid):
     write_json(plotly.io.to_json(fig), fid)

+
 def read_plotly(fid):
     fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
     return fig

+
 def make_fig_lengths(tokenized_df, length_field):
-    fig_tok_length = px.histogram(
-        tokenized_df, x=length_field, marginal="rug", hover_data=[length_field]
-    )
+    fig_tok_length, axs = plt.subplots(figsize=(15, 6), dpi=150)
+    sns.histplot(data=tokenized_df[length_field], kde=True, bins=100, ax=axs)
+    sns.rugplot(data=tokenized_df[length_field], ax=axs)
     return fig_tok_length

+
 def make_fig_labels(label_df, label_names, label_field):
     labels = label_df[label_field].unique()
     label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
@@ -1144,89 +1121,6 @@ def make_zipf_fig(vocab_counts_df, z):
     return fig


-def make_tree_plot(node_list, text_dset):
-    nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
-
-    for nid, node in enumerate(node_list):
-        node["label"] = node.get(
-            "label",
-            f"{nid:2d} - {node['weight']:5d} items <br>"
-            + "<br>".join(
-                [
-                    "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
-                    for txt in list(
-                        set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
-                    )[:5]
-                ]
-            ),
-        )
-
-    # make plot nodes
-    # TODO: something more efficient than set to remove duplicates
-    labels = [node["label"] for node in node_list]
-
-    root = node_list[0]
-    root["X"] = 0
-    root["Y"] = 0
-
-    def rec_make_coordinates(node):
-        total_weight = 0
-        add_weight = len(node["example_ids"]) - sum(
-            [child["weight"] for child in node["children"]]
-        )
-        for child in node["children"]:
-            child["X"] = node["X"] + total_weight
-            child["Y"] = node["Y"] - 1
-            total_weight += child["weight"] + add_weight / len(node["children"])
-            rec_make_coordinates(child)
-
-    rec_make_coordinates(root)
-
-    E = []  # list of edges
-    Xn = []
-    Yn = []
-    Xe = []
-    Ye = []
-    for nid, node in enumerate(node_list):
-        Xn += [node["X"]]
-        Yn += [node["Y"]]
-        for child in node["children"]:
-            E += [(nid, nid_map[child["nid"]])]
-            Xe += [node["X"], child["X"], None]
-            Ye += [node["Y"], child["Y"], None]
-
-    # make figure
-    fig = go.Figure()
-    fig.add_trace(
-        go.Scatter(
-            x=Xe,
-            y=Ye,
-            mode="lines",
-            line=dict(color="rgb(210,210,210)", width=1),
-            hoverinfo="none",
-        )
-    )
-    fig.add_trace(
-        go.Scatter(
-            x=Xn,
-            y=Yn,
-            mode="markers",
-            name="nodes",
-            marker=dict(
-                symbol="circle-dot",
-                size=18,
-                color="#6175c1",
-                line=dict(color="rgb(50,50,50)", width=1)
-                # '#DB4551',
-            ),
-            text=labels,
-            hoverinfo="text",
-            opacity=0.8,
-        )
-    )
-    return fig
-
-
 ## Input/Output ###

@@ -1280,6 +1174,7 @@ def write_json(json_dict, json_fid):
     with open(json_fid, "w", encoding="utf-8") as f:
         json.dump(json_dict, f)

+
 def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     """
     Saves the calculated nPMI statistics to their output files.
@@ -1299,6 +1194,7 @@ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     with open(subgroup_cooc_fid, "w+") as f:
         subgroup_cooc_df.to_csv(f)

+
 def write_zipf_data(z, zipf_fid):
     zipf_dict = {}
     zipf_dict["xmin"] = int(z.xmin)
@@ -1310,4 +1206,3 @@ def write_zipf_data(z, zipf_fid):
     zipf_dict["uniq_ranks"] = [int(rank) for rank in z.uniq_ranks]
     with open(zipf_fid, "w+", encoding="utf-8") as f:
         json.dump(zipf_dict, f)
-
 
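One substantive change in this commit is that the token-length figure is now drawn with matplotlib/seaborn and cached as a PNG ("fig_tok_length.png") rather than as plotly JSON. A minimal sketch of that save/load round trip, using a hypothetical lengths_df and a local cache path that are not part of the diff:

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-ins for the tokenized dataframe and cache path used in the diff.
lengths_df = pd.DataFrame({"length": [3, 5, 5, 8, 13, 21]})
fig_tok_length_fid = "fig_tok_length.png"

# Build the figure the way the new make_fig_lengths() does: a seaborn histogram
# with a KDE plus a rug plot, drawn onto a matplotlib Axes.
fig_tok_length, axs = plt.subplots(figsize=(15, 6), dpi=150)
sns.histplot(data=lengths_df["length"], kde=True, bins=100, ax=axs)
sns.rugplot(data=lengths_df["length"], ax=axs)

# Cache the figure as a PNG, as load_or_prepare_text_lengths() now does on save ...
fig_tok_length.savefig(fig_tok_length_fid)

# ... and read it back as an image array for the UI when the cache file exists.
fig_tok_length_png = mpimg.imread(fig_tok_length_fid)
print(fig_tok_length_png.shape)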