import evaluate
import logging
import os
import pandas as pd
import plotly.express as px
import utils
import utils.dataset_utils as ds_utils
from collections import Counter
from os.path import exists, isdir
from os.path import join as pjoin

LABEL_FIELD = "labels"
LABEL_NAMES = "label_names"
LABEL_LIST = "label_list"
LABEL_MEASUREMENT = "label_measurement"

# Specific to the evaluate library
EVAL_LABEL_MEASURE = "label_distribution"
EVAL_LABEL_ID = "labels"
EVAL_LABEL_FRAC = "fractions"
# TODO: This should ideally be in what's returned from the evaluate library
EVAL_LABEL_SUM = "sums"

logs = utils.prepare_logging(__file__)


def map_labels(label_field, ds_name_to_dict, ds_name, config_name):
    try:
        label_field, label_names = (
            ds_name_to_dict[ds_name][config_name]["features"][label_field][0]
            if len(ds_name_to_dict[ds_name][config_name]["features"][label_field]) > 0
            else ((), [])
        )
    except KeyError as e:
        logs.exception(e)
        logs.warning("Not returning a label-name mapping")
        return []
    return label_names


def make_label_results_dict(label_measurement, label_names):
    label_dict = {LABEL_MEASUREMENT: label_measurement,
                  LABEL_NAMES: label_names}
    return label_dict


def make_label_fig(label_results, chart_type="pie"):
    try:
        label_names = label_results[LABEL_NAMES]
        label_measurement = label_results[LABEL_MEASUREMENT]
        label_sums = label_measurement[EVAL_LABEL_SUM]
        if chart_type == "bar":
            fig_labels = px.bar(
                x=label_measurement[EVAL_LABEL_MEASURE][EVAL_LABEL_ID],
                y=label_measurement[EVAL_LABEL_MEASURE][EVAL_LABEL_FRAC])
        else:
            if chart_type != "pie":
                logs.info("Oops! Don't have that chart type implemented.")
                logs.info("Making the default pie chart")
            # IMDB-unsupervised has a labels column where all values are -1,
            # which breaks the assumption that the number of label_names
            # equals the number of label_sums.
            # This handles that case, assuming it will happen in other datasets.
            if len(label_names) != len(label_sums):
                logs.warning("Can't make a figure with the given label names: "
                             "We don't have the right amount of label types "
                             "to apply them to!")
                return False
            fig_labels = px.pie(names=label_names, values=label_sums)
    except KeyError:
        logs.info("Input label data missing required key(s).")
        logs.info("We require %s, %s" % (LABEL_NAMES, LABEL_MEASUREMENT))
        logs.info("We found: %s" % ",".join(label_results.keys()))
        return False
    return fig_labels


def extract_label_names(label_field, ds_name, config_name):
    ds_name_to_dict = ds_utils.get_dataset_info_dicts(ds_name)
    label_names = map_labels(label_field, ds_name_to_dict, ds_name, config_name)
    return label_names


class DMTHelper:
    """Helper class for the Data Measurements Tool.

    This allows us to keep all variables and functions related to labels
    in one file.
""" def __init__(self, dstats, load_only, save): logs.info("Initializing labels.") # -- Data Measurements Tool variables self.label_results = dstats.label_results self.fig_labels = dstats.fig_labels self.use_cache = dstats.use_cache self.cache_dir = dstats.dataset_cache_dir self.load_only = load_only self.save = save # -- Hugging Face Dataset variables self.label_field = dstats.label_field # Input HuggingFace dataset self.dset = dstats.dset self.dset_name = dstats.dset_name self.dset_config = dstats.dset_config self.label_names = dstats.label_names # -- Filenames self.label_dir = "labels" label_json = "labels.json" label_fig_json = "labels_fig.json" label_fig_html = "labels_fig.html" self.labels_json_fid = pjoin(self.cache_dir, self.label_dir, label_json) self.labels_fig_json_fid = pjoin(self.cache_dir, self.label_dir, label_fig_json) self.labels_fig_html_fid = pjoin(self.cache_dir, self.label_dir, label_fig_html) def run_DMT_processing(self): """ Loads or prepares the Labels measurements and figure as specified by the DMT options. """ # First look to see what we can load from cache. if self.use_cache: logs.info("Trying to load labels.") self.fig_labels, self.label_results = self._load_label_cache() if self.fig_labels: logs.info("Loaded cached label figure.") if self.label_results: logs.info("Loaded cached label results.") # If we can prepare the results afresh... if not self.load_only: # If we didn't load them already, compute label statistics. if not self.label_results: logs.info("Preparing labels.") self.label_results = self._prepare_labels() # If we didn't load it already, create figure. if not self.fig_labels: logs.info("Creating label figure.") self.fig_labels = \ make_label_fig(self.label_results) # Finish if self.save: self._write_label_cache() def _load_label_cache(self): fig_labels = {} label_results = {} # Measurements exist. Load them. if exists(self.labels_json_fid): # Loads the label list, names, and results label_results = ds_utils.read_json(self.labels_json_fid) # Image exists. Load it. if exists(self.labels_fig_json_fid): fig_labels = ds_utils.read_plotly(self.labels_fig_json_fid) return fig_labels, label_results def _prepare_labels(self): """Loads a Labels object and computes label statistics""" # Label object for the dataset label_obj = Labels(dataset=self.dset, dataset_name=self.dset_name, config_name=self.dset_config) # TODO: Handle the case where there are multiple label columns. # The logic throughout the code assumes only one. if type(self.label_field) == tuple: label_field = self.label_field[0] elif type(self.label_field) == str: label_field = self.label_field else: logs.warning("Unexpected format %s for label column name(s). " "Not computing label statistics." % type(self.label_field)) return {} label_results = label_obj.prepare_labels(label_field, self.label_names) return label_results def _write_label_cache(self): ds_utils.make_path(pjoin(self.cache_dir, self.label_dir)) if self.label_results: ds_utils.write_json(self.label_results, self.labels_json_fid) if self.fig_labels: ds_utils.write_plotly(self.fig_labels, self.labels_fig_json_fid) self.fig_labels.write_html(self.labels_fig_html_fid) def get_label_filenames(self): label_fid_dict = {"statistics": self.labels_json_fid, "figure json": self.labels_fig_json_fid, "figure html": self.labels_fig_html_fid} return label_fid_dict class Labels: """Generic class for label processing. Uses the Dataset to extract the label column and compute label measurements. 
""" def __init__(self, dataset, dataset_name=None, config_name=None): # Input HuggingFace Dataset. self.dset = dataset # These are used to extract label names, when the label names # are stored in the Dataset object but not in the "label" column # we are working with, which may instead just be ints corresponding to # the names self.ds_name = dataset_name self.config_name = config_name # For measurement data and additional metadata. self.label_results_dict = {} def prepare_labels(self, label_field, label_names=[]): """ Uses the evaluate library to return the label distribution. """ logs.info("Inside main label calculation function.") logs.debug("Looking for label field called '%s'" % label_field) # The input Dataset object # When the label field is not found, an error will be thrown. if label_field in self.dset.features: label_list = self.dset[label_field] else: logs.warning("No label column found -- nothing to do. Returning.") logs.debug(self.dset.features) return {} # Get the evaluate library's measurement for label distro. label_distribution = evaluate.load(EVAL_LABEL_MEASURE) # Measure the label distro. label_measurement = label_distribution.compute(data=label_list) # TODO: Incorporate this summation into what the evaluate library returns. label_sum_dict = Counter(label_list) label_sums = [label_sum_dict[key] for key in sorted(label_sum_dict)] label_measurement["sums"] = label_sums if not label_names: # Have to extract the label names from the Dataset object when the # actual dataset columns are just ints representing the label names. label_names = extract_label_names(label_field, self.ds_name, self.config_name) label_results = make_label_results_dict(label_measurement, label_names) return label_results