import statistics
from os.path import exists
from os.path import join as pjoin

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from PIL import Image

import utils
from utils import dataset_utils as ds_utils

TEXT_FIELD = "text"
TOKENIZED_FIELD = "tokenized_text"
LENGTH_FIELD = "length"

UNIQ = "num_instance_lengths"
AVG = "average_instance_length"
STD = "standard_dev_instance_length"

logs = utils.prepare_logging(__file__)


def make_fig_lengths(lengths_df):
    # Seaborn draws onto whatever matplotlib Axes is passed via `ax`, so the
    # histogram and rug plot below both land on the figure created here.
    logs.info("Creating lengths figure.")
    plt.switch_backend("Agg")
    fig_tok_lengths, axs = plt.subplots(figsize=(15, 6), dpi=150)
    axs.set_xlabel("Number of tokens")
    axs.set_title("Binned counts of text lengths, with kernel density estimate and ticks for each instance.")
    sns.histplot(data=lengths_df, kde=True, ax=axs, x=LENGTH_FIELD, legend=False)
    sns.rugplot(data=lengths_df, x=LENGTH_FIELD, ax=axs)
    return fig_tok_lengths


class DMTHelper:
    def __init__(self, dstats, load_only=False, save=True):
        self.tokenized_df = dstats.tokenized_df
        # Whether to only use the cache (no new computation)
        self.load_only = load_only
        # Whether to try using the cache first.
        # Must be True when self.load_only is True.
        self.use_cache = dstats.use_cache
        self.cache_dir = dstats.dataset_cache_dir
        self.save = save
        # Lengths class object
        self.lengths_obj = None
        # Content shared in the DMT:
        # the figure, the table, and the sufficient statistics (measurements)
        self.fig_lengths = None
        self.lengths_df = None
        self.avg_length = None
        self.std_length = None
        self.uniq_counts = None
        # Dict of the measurements, used in caching
        self.length_stats_dict = {}
        # Filenames, used in caching
        self.lengths_dir = "lengths"
        length_meas_json = "length_measurements.json"
        lengths_fig_png = "lengths_fig.png"
        lengths_df_json = "lengths_table.json"
        self.length_stats_json_fid = pjoin(self.cache_dir, self.lengths_dir, length_meas_json)
        self.lengths_fig_png_fid = pjoin(self.cache_dir, self.lengths_dir, lengths_fig_png)
        self.lengths_df_json_fid = pjoin(self.cache_dir, self.lengths_dir, lengths_df_json)

    def run_DMT_processing(self):
        """Gets the data structures for the figure, table, and measurements."""
        # First look to see what we can load from cache.
        if self.use_cache:
            logs.info("Trying to load from cache...")
            # Defines self.lengths_df, self.length_stats_dict, self.fig_lengths:
            # the table, the dict of measurements, and the figure.
            self.load_lengths_cache()
            # Sets the measurements as attributes of the DMT object.
            self.set_attributes()
        # If we do not have measurements loaded from cache...
        if not self.length_stats_dict and not self.load_only:
            logs.info("Preparing length results")
            # Computes length statistics using the Lengths class.
            self.lengths_obj = self._prepare_lengths()
            # Dict of measurements
            self.length_stats_dict = self.lengths_obj.length_stats_dict
            # Table of text and lengths
            self.lengths_df = self.lengths_obj.lengths_df
            # Sets the measurements from length_stats_dict as attributes.
            self.set_attributes()
            # Makes the figure
            self.fig_lengths = make_fig_lengths(self.lengths_df)
        # Finish
        if self.save:
            logs.info("Saving results.")
            self._write_lengths_cache()
        if exists(self.lengths_fig_png_fid):
            # As soon as we have a figure, we redefine it as an image.
            # This is a hack to handle a UI display error (TODO: file bug).
            self.fig_lengths = Image.open(self.lengths_fig_png_fid)

    def set_attributes(self):
        if self.length_stats_dict:
            self.avg_length = self.length_stats_dict[AVG]
            self.std_length = self.length_stats_dict[STD]
            self.uniq_counts = self.length_stats_dict[UNIQ]
        else:
            logs.info("No length stats found. =(")

    def load_lengths_cache(self):
        # Dataframe exists. Load it.
        if exists(self.lengths_df_json_fid):
            self.lengths_df = ds_utils.read_df(self.lengths_df_json_fid)
        # Image exists. Load it.
        if exists(self.lengths_fig_png_fid):
            self.fig_lengths = Image.open(self.lengths_fig_png_fid)
        # Measurements exist. Load them.
        if exists(self.length_stats_json_fid):
            self.length_stats_dict = ds_utils.read_json(self.length_stats_json_fid)

    def _write_lengths_cache(self):
        # Writes the data structures using the corresponding filetypes.
        ds_utils.make_path(pjoin(self.cache_dir, self.lengths_dir))
        if self.length_stats_dict:
            ds_utils.write_json(self.length_stats_dict, self.length_stats_json_fid)
        if isinstance(self.fig_lengths, Figure):
            self.fig_lengths.savefig(self.lengths_fig_png_fid)
        if isinstance(self.lengths_df, pd.DataFrame):
            ds_utils.write_df(self.lengths_df, self.lengths_df_json_fid)

    def _prepare_lengths(self):
        """Loads a Lengths object and computes length statistics."""
        # Length object for the dataset
        lengths_obj = Lengths(dataset=self.tokenized_df)
        lengths_obj.prepare_lengths()
        return lengths_obj

    def get_filenames(self):
        lengths_fid_dict = {
            "statistics": self.length_stats_json_fid,
            "figure png": self.lengths_fig_png_fid,
            "table": self.lengths_df_json_fid,
        }
        return lengths_fid_dict


class Lengths:
    """Generic class for text length processing.

    Uses DataFrames for faster processing.
    Given a dataframe with tokenized text in the column named by
    TOKENIZED_FIELD and the text instances in the column named by TEXT_FIELD,
    computes length statistics.
    """

    def __init__(self, dataset):
        self.dset_df = dataset
        # Dict of measurements
        self.length_stats_dict = {}
        # Measurements
        self.avg_length = None
        self.std_length = None
        self.num_uniq_lengths = None
        # Table of lengths and sentences
        self.lengths_df = None

    def prepare_lengths(self):
        self.lengths_df = pd.DataFrame(self.dset_df[TEXT_FIELD])
        # Length of an instance = number of tokens in its tokenized text.
        self.lengths_df[LENGTH_FIELD] = self.dset_df[TOKENIZED_FIELD].apply(len)
        lengths_array = self.lengths_df[LENGTH_FIELD]
        self.avg_length = statistics.mean(lengths_array)
        # Note: statistics.stdev requires at least two instances.
        self.std_length = statistics.stdev(lengths_array)
        self.num_uniq_lengths = len(lengths_array.unique())
        self.length_stats_dict = {
            AVG: self.avg_length,
            STD: self.std_length,
            UNIQ: self.num_uniq_lengths,
        }
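

# ---------------------------------------------------------------------------
# A minimal usage sketch (illustrative only, not part of the DMT pipeline):
# it builds a toy tokenized DataFrame, computes the length statistics via
# Lengths, and renders the figure. The column names reuse TEXT_FIELD and
# TOKENIZED_FIELD from above; the toy data and output filename are
# hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    toy_df = pd.DataFrame(
        {
            TEXT_FIELD: [
                "a short text",
                "a slightly longer text instance",
                "the longest toy text instance of the three",
            ],
            TOKENIZED_FIELD: [
                ["a", "short", "text"],
                ["a", "slightly", "longer", "text", "instance"],
                ["the", "longest", "toy", "text", "instance", "of", "the", "three"],
            ],
        }
    )
    lengths = Lengths(dataset=toy_df)
    lengths.prepare_lengths()
    # Prints the three sufficient statistics; e.g. the average length is 16/3.
    print(lengths.length_stats_dict)
    fig = make_fig_lengths(lengths.lengths_df)
    fig.savefig("toy_lengths_fig.png")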