meg-huggingface committed
Commit d8ab532
1 Parent(s): 7c5b4e0

Continuing cache minimizing in new repository. Please see https://github.com/huggingface/DataMeasurements for full history
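For context on the change below: the new cache entries store plotly figures as JSON via `plotly.io` (the added `write_plotly` / `read_plotly` helpers), alongside a `torch`-saved clustering node list, so cached figures can be redrawn without re-tokenizing or re-embedding the dataset. A minimal sketch of that figure round trip, with toy data and a hypothetical cache path `fig_fid` (the module itself routes the write through its own `write_json` helper, which is why the figure JSON string is itself JSON-encoded):

import json
import plotly
import plotly.express as px

fig_fid = "fig_tok_length.json"  # hypothetical cache path

# any plotly figure; the real code caches token-length, label, tree, and Zipf figs
fig = px.histogram(x=[1, 2, 2, 3, 3, 3])

# write: serialize the figure to a JSON string, then dump that string to disk
with open(fig_fid, "w", encoding="utf-8") as f:
    json.dump(plotly.io.to_json(fig), f)

# read: load the JSON string back and rebuild the figure object
with open(fig_fid, encoding="utf-8") as f:
    fig_cached = plotly.io.from_json(json.load(f))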

data_measurements/dataset_statistics.py CHANGED
@@ -15,14 +15,15 @@
 import json
 import logging
 import statistics
+import torch
 from os import mkdir
 from os.path import exists, isdir
 from os.path import join as pjoin
-from pathlib import Path
 
 import nltk
 import numpy as np
 import pandas as pd
+import plotly
 import plotly.express as px
 import plotly.figure_factory as ff
 import plotly.graph_objects as go
@@ -59,8 +60,6 @@ logs.propagate = False
 
 if not logs.handlers:
 
-    Path('./log_files').mkdir(exist_ok=True)
-
     # Logging info to log file
     file = logging.FileHandler("./log_files/dataset_statistics.log")
     fileformat = logging.Formatter("%(asctime)s:%(message)s")
@@ -263,7 +262,12 @@ class DatasetStatisticsCacheClass:
         self.text_duplicate_counts_df_fid = pjoin(
             self.cache_path, "text_dup_counts_df.feather"
         )
+        self.fig_tok_length_fid = pjoin(self.cache_path, "fig_tok_length.json")
+        self.fig_labels_fid = pjoin(self.cache_path, "fig_labels.json")
+        self.node_list_fid = pjoin(self.cache_path, "node_list.th")
+        self.fig_tree_fid = pjoin(self.cache_path, "fig_tree.json")
         self.zipf_fid = pjoin(self.cache_path, "zipf_basic_stats.json")
+        self.zipf_fig_fid = pjoin(self.cache_path, "zipf_fig.json")
 
     def get_base_dataset(self):
         """Gets a pointer to the truncated base dataset object."""
@@ -307,7 +311,11 @@ class DatasetStatisticsCacheClass:
         write_df(self.text_dup_counts_df, self.text_duplicate_counts_df_fid)
         write_json(self.general_stats_dict, self.general_stats_fid)
 
-    def load_or_prepare_text_lengths(self, use_cache=False):
+    def load_or_prepare_text_lengths(self, use_cache=False, save=True):
+        # TODO: Everything here can be read from cache; it's in a transitory
+        # state atm where just the fig is cached. Clean up.
+        if use_cache and exists(self.fig_tok_length_fid):
+            self.fig_tok_length = read_plotly(self.fig_tok_length_fid)
         if len(self.tokenized_df) == 0:
             self.tokenized_df = self.do_tokenization()
         self.tokenized_df[LENGTH_FIELD] = self.tokenized_df[TOKENIZED_FIELD].apply(len)
@@ -320,12 +328,28 @@ class DatasetStatisticsCacheClass:
             statistics.stdev(self.tokenized_df[self.our_length_field]), 1
         )
         self.fig_tok_length = make_fig_lengths(self.tokenized_df, self.our_length_field)
-
-    def load_or_prepare_embeddings(self, use_cache=False):
-        self.embeddings = Embeddings(self, use_cache=use_cache)
-        self.embeddings.make_hierarchical_clustering()
-        self.fig_tree = self.embeddings.fig_tree
-        self.node_list = self.embeddings.node_list
+        if save:
+            write_plotly(self.fig_tok_length, self.fig_tok_length_fid)
+
+    def load_or_prepare_embeddings(self, use_cache=False, save=True):
+        if use_cache and exists(self.node_list_fid) and exists(self.fig_tree_fid):
+            self.node_list = torch.load(self.node_list_fid)
+            self.fig_tree = read_plotly(self.fig_tree_fid)
+        elif use_cache and exists(self.node_list_fid):
+            self.node_list = torch.load(self.node_list_fid)
+            self.fig_tree = make_tree_plot(self.node_list,
+                                           self.text_dset)
+            if save:
+                write_plotly(self.fig_tree, self.fig_tree_fid)
+        else:
+            self.embeddings = Embeddings(self, use_cache=use_cache)
+            self.embeddings.make_hierarchical_clustering()
+            self.node_list = self.embeddings.node_list
+            self.fig_tree = make_tree_plot(self.node_list,
+                                           self.text_dset)
+            if save:
+                torch.save(self.node_list, self.node_list_fid)
+                write_plotly(self.fig_tree, self.fig_tree_fid)
 
     # get vocab with word counts
     def load_or_prepare_vocab(self, use_cache=True, save=True):
@@ -341,7 +365,7 @@
         ):
             logs.info("Reading vocab from cache")
             self.load_vocab()
-            self.vocab_counts_filtered_df = filter_words(self.vocab_counts_df)
+            self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
         else:
             logs.info("Calculating vocab afresh")
             if len(self.tokenized_df) == 0:
@@ -352,7 +376,7 @@ class DatasetStatisticsCacheClass:
             word_count_df = count_vocab_frequencies(self.tokenized_df)
             logs.info("Making dfs with proportion.")
             self.vocab_counts_df = calc_p_word(word_count_df)
-            self.vocab_counts_filtered_df = filter_words(self.vocab_counts_df)
+            self.vocab_counts_filtered_df = filter_vocab(self.vocab_counts_df)
             if save:
                 logs.info("Writing out.")
                 write_df(self.vocab_counts_df, self.vocab_counts_df_fid)
@@ -365,17 +389,31 @@ class DatasetStatisticsCacheClass:
         self.npmi_stats = nPMIStatisticsCacheClass(self, use_cache=use_cache)
         self.npmi_stats.load_or_prepare_npmi_terms()
 
-    def load_or_prepare_zipf(self, use_cache=False):
-        if use_cache and exists(self.zipf_fid):
+    def load_or_prepare_zipf(self, use_cache=False, save=True):
+        # TODO: Current UI only uses the fig, meaning the self.z here is irrelevant
+        # when only reading from cache. Either the UI should use it, or it should
+        # be removed when reading in cache
+        if use_cache and exists(self.zipf_fig_fid) and exists(self.zipf_fid):
+            with open(self.zipf_fid, "r") as f:
+                zipf_dict = json.load(f)
+            self.z = Zipf()
+            self.z.load(zipf_dict)
+            self.zipf_fig = read_plotly(self.zipf_fig_fid)
+        elif use_cache and exists(self.zipf_fid):
             # TODO: Read zipf data so that the vocab is there.
             with open(self.zipf_fid, "r") as f:
                 zipf_dict = json.load(f)
             self.z = Zipf()
             self.z.load(zipf_dict)
+            self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
+            if save:
+                write_plotly(self.zipf_fig, self.zipf_fig_fid)
         else:
             self.z = Zipf(self.vocab_counts_df)
-            write_zipf_data(self.z, self.zipf_fid)
-        self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
+            self.zipf_fig = make_zipf_fig(self.vocab_counts_df, self.z)
+            if save:
+                write_zipf_data(self.z, self.zipf_fid)
+                write_plotly(self.zipf_fig, self.zipf_fig_fid)
 
     def prepare_general_text_stats(self):
         text_nan_count = int(self.tokenized_df.isnull().sum().sum())
@@ -476,6 +514,8 @@ class DatasetStatisticsCacheClass:
         self.label_field = label_field
 
     def load_or_prepare_labels(self, use_cache=False, save=True):
+        # TODO: This is in a transitory state for creating fig cache.
+        # Clean up to be caching and reading everything correctly.
        """
         Extracts labels from the Dataset
         :param use_cache:
@@ -483,9 +523,17 @@
         """
         # extracted labels
         if len(self.label_field) > 0:
-            if use_cache and exists(self.label_dset_fid):
+            if use_cache and exists(self.fig_labels_fid):
+                self.fig_labels = read_plotly(self.fig_labels_fid)
+            elif use_cache and exists(self.label_dset_fid):
                 # load extracted labels
                 self.label_dset = load_from_disk(self.label_dset_fid)
+                self.label_df = self.label_dset.to_pandas()
+                self.fig_labels = make_fig_labels(
+                    self.label_df, self.label_names, OUR_LABEL_FIELD
+                )
+                if save:
+                    write_plotly(self.fig_labels, self.fig_labels_fid)
             else:
                 self.get_base_dataset()
                 self.label_dset = self.dset.map(
@@ -495,14 +543,14 @@ class DatasetStatisticsCacheClass:
                     batched=True,
                     remove_columns=list(self.dset.features),
                 )
+                self.label_df = self.label_dset.to_pandas()
+                self.fig_labels = make_fig_labels(
+                    self.label_df, self.label_names, OUR_LABEL_FIELD
+                )
                 if save:
                     # save extracted label instances
                     self.label_dset.save_to_disk(self.label_dset_fid)
-            self.label_df = self.label_dset.to_pandas()
-
-            self.fig_labels = make_fig_labels(
-                self.label_df, self.label_names, OUR_LABEL_FIELD
-            )
+                    write_plotly(self.fig_labels, self.fig_labels_fid)
 
     def load_vocab(self):
         with open(self.vocab_counts_df_fid, "rb") as f:
@@ -796,7 +844,7 @@ def calc_p_word(word_count_df):
     return vocab_counts_df
 
 
-def filter_words(vocab_counts_df):
+def filter_vocab(vocab_counts_df):
     # TODO: Add warnings (which words are missing) to log file?
     filtered_vocab_counts_df = vocab_counts_df.drop(_CLOSED_CLASS,
                                                     errors="ignore")
@@ -808,6 +856,12 @@ def filter_words(vocab_counts_df):
 
 ## Figures ##
 
+def write_plotly(fig, fid):
+    write_json(plotly.io.to_json(fig), fid)
+
+def read_plotly(fid):
+    fig = plotly.io.from_json(json.load(open(fid, encoding="utf-8")))
+    return fig
 
 def make_fig_lengths(tokenized_df, length_field):
     fig_tok_length = px.histogram(
@@ -815,7 +869,6 @@ def make_fig_lengths(tokenized_df, length_field):
     )
     return fig_tok_length
 
-
 def make_fig_labels(label_df, label_names, label_field):
     labels = label_df[label_field].unique()
     label_sums = [len(label_df[label_df[label_field] == label]) for label in labels]
@@ -896,6 +949,89 @@ def make_zipf_fig(vocab_counts_df, z):
     return fig
 
 
+def make_tree_plot(node_list, text_dset):
+    nid_map = dict([(node["nid"], nid) for nid, node in enumerate(node_list)])
+
+    for nid, node in enumerate(node_list):
+        node["label"] = node.get(
+            "label",
+            f"{nid:2d} - {node['weight']:5d} items <br>"
+            + "<br>".join(
+                [
+                    "> " + txt[:64] + ("..." if len(txt) >= 63 else "")
+                    for txt in list(
+                        set(text_dset.select(node["example_ids"])[OUR_TEXT_FIELD])
+                    )[:5]
+                ]
+            ),
+        )
+
+    # make plot nodes
+    # TODO: something more efficient than set to remove duplicates
+    labels = [node["label"] for node in node_list]
+
+    root = node_list[0]
+    root["X"] = 0
+    root["Y"] = 0
+
+    def rec_make_coordinates(node):
+        total_weight = 0
+        add_weight = len(node["example_ids"]) - sum(
+            [child["weight"] for child in node["children"]]
+        )
+        for child in node["children"]:
+            child["X"] = node["X"] + total_weight
+            child["Y"] = node["Y"] - 1
+            total_weight += child["weight"] + add_weight / len(node["children"])
+            rec_make_coordinates(child)
+
+    rec_make_coordinates(root)
+
+    E = []  # list of edges
+    Xn = []
+    Yn = []
+    Xe = []
+    Ye = []
+    for nid, node in enumerate(node_list):
+        Xn += [node["X"]]
+        Yn += [node["Y"]]
+        for child in node["children"]:
+            E += [(nid, nid_map[child["nid"]])]
+            Xe += [node["X"], child["X"], None]
+            Ye += [node["Y"], child["Y"], None]
+
+    # make figure
+    fig = go.Figure()
+    fig.add_trace(
+        go.Scatter(
+            x=Xe,
+            y=Ye,
+            mode="lines",
+            line=dict(color="rgb(210,210,210)", width=1),
+            hoverinfo="none",
+        )
+    )
+    fig.add_trace(
+        go.Scatter(
+            x=Xn,
+            y=Yn,
+            mode="markers",
+            name="nodes",
+            marker=dict(
+                symbol="circle-dot",
+                size=18,
+                color="#6175c1",
+                line=dict(color="rgb(50,50,50)", width=1)
+                # '#DB4551',
+            ),
+            text=labels,
+            hoverinfo="text",
+            opacity=0.8,
+        )
+    )
+    return fig
+
+
 ## Input/Output ###
 
 
@@ -949,7 +1085,6 @@ def write_json(json_dict, json_fid):
     with open(json_fid, "w", encoding="utf-8") as f:
         json.dump(json_dict, f)
 
-
 def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     """
     Saves the calculated nPMI statistics to their output files.
@@ -969,7 +1104,6 @@ def write_subgroup_npmi_data(subgroup, subgroup_dict, subgroup_files):
     with open(subgroup_cooc_fid, "w+") as f:
         subgroup_cooc_df.to_csv(f)
 
-
 def write_zipf_data(z, zipf_fid):
     zipf_dict = {}
     zipf_dict["xmin"] = int(z.xmin)