import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
import torch
from datasets import load_dataset
from scipy.stats import entropy
from sentence_transformers import SentenceTransformer
from spectral_metric.estimator import CumulativeGradientEstimator

from utils import show_most_confused

AVAILABLE_DATASETS = [
    ("clinc_oos", "small"),
    ("clinc_oos", "imbalanced"),
    ("banking77",),
    ("tweet_eval", "emoji"),
    ("tweet_eval", "stance_climate"),
]

# Name of the label column for each supported dataset.
label_column_mapping = {
    "clinc_oos": "intent",
    "banking77": "label",
    "tweet_eval": "label",
}

st.title("Perform a data-driven analysis using `spectral-metric`")
st.markdown(
    """Today, I would like to perform a data-driven analysis of a dataset,
using `sentence-transformers` to extract features and `spectral_metric` to run a spectral analysis.

For support, please submit an issue on [our repo](https://github.com/Dref360/spectral-metric)
or [contact me directly](https://github.com/Dref360).
"""
)

st.markdown(
    """
Let's load your dataset; we will run our analysis on the train set.
"""
)
dataset_name = st.selectbox("Select your dataset", AVAILABLE_DATASETS)

if st.button("Start the analysis"):
    label_column = label_column_mapping[dataset_name[0]]
    # We perform the analysis on the train set.
    ds = load_dataset(*dataset_name)["train"]
    class_names = ds.features[label_column].names
    ds  # Streamlit "magic": displays the dataset in the app.

    # all-MiniLM-L12-v2 is a good compromise between speed and performance.
    embedder = SentenceTransformer("all-MiniLM-L12-v2")

    # Get **normalized** features for the dataset using our embedder.
    with st.spinner(text="Computing embeddings..."):
        features = embedder.encode(
            ds["text"],
            device="cuda" if torch.cuda.is_available() else "cpu",
            normalize_embeddings=True,
        )

    st.markdown(
        """
### Running the spectral analysis

Now that our sentence embedder has extracted the embeddings, we can run an in-depth analysis of these features.

To do so, we will use CSG (Branchaud-Charron et al., 2019), a technique that combines
Probability Product Kernels (Jebara et al., 2004) and spectral clustering
to analyze a dataset without training a model.

We won't use the actual CSG metric here, but we will use the $W$ matrix, which is computed as follows:

1. Run a probabilistic k-NN on the dataset (optionally done via Monte Carlo sampling).
2. Compute the average prediction per class, which gives the $S$ matrix.
3. Symmetrize this matrix using the Bray-Curtis distance, a metric designed to compare samplings from a distribution.

All of these steps are handled by `spectral_metric.estimator.CumulativeGradientEstimator`.
"""
    )

    X, y = features, np.array(ds[label_column])  # Dataset with shapes [N, ?] and [N].
    estimator = CumulativeGradientEstimator(M_sample=250, k_nearest=9, distance="cosine")
    estimator.fit(data=X, target=y)

    fig, ax = plt.subplots(figsize=(10, 5))
    sns.heatmap(estimator.W, ax=ax, cmap="rocket_r")
    ax.set_title(f"Similarity between classes in {dataset_name[0]}")
    st.pyplot(fig)

    st.markdown(
        """
This figure will be hard to read on most datasets, so we need to go deeper.
Let's do the following analysis:

1. Find the class with the highest entropy, i.e. the class that is the most confused with others.
2. Find the 10 pairs of classes that are the most confused.
3. Find the items in these pairs that contribute to the confusion.
""" ) entropy_per_class = entropy(estimator.W / estimator.W.sum(-1)[:, None], axis=-1) st.markdown( f"Most confused class (highest entropy): {class_names[np.argmax(entropy_per_class)]}", ) st.markdown( f"Least confused class (lowest entropy): {class_names[np.argmin(entropy_per_class)]}", ) pairs = list(zip(*np.unravel_index(np.argsort(estimator.W, axis=None), estimator.W.shape)))[::-1] pairs = [(i,j) for i,j in pairs if i != j] lst = [] for idx, (i,j) in enumerate(pairs[::2][:10]): lst.append({"Intent A" : class_names[i], "Intent B": class_names[j], "Similarity": estimator.W[i,j]}) st.title("Most similar pairs") st.dataframe(pd.DataFrame(lst).sort_values("Similarity", ascending=False)) st.markdown(""" ## Analysis By looking at the top-10 most similar pairs, we get some good insights on the dataset. While this does not 100% indicates that the classifier trained downstream will have issues with these pairs, we know that these intents are similar. In consequence, the classifier might not be able to separate them easily. Let's now look at which utterance is contributing the most to the confusion. """) first_pair = pairs[0] second_pair = pairs[2] st.dataframe(pd.DataFrame({**show_most_confused(ds,first_pair[0], first_pair[1], estimator, class_names), **show_most_confused(ds, first_pair[1], first_pair[0], estimator, class_names)}), width=1000) st.markdown("### We can do the same for the second pair") st.dataframe(pd.DataFrame({**show_most_confused(ds, second_pair[0], second_pair[1], estimator, class_names), **show_most_confused(ds, second_pair[1], second_pair[0], estimator, class_names)}), width=1000) st.markdown(f""" From the top-5 most confused examples per pair, we can see that the sentences are quite similar. While a human could easily separate the two intents, we see that the sentences are made of the same words which might confuse the classifier. Some sentences could be seen as mislabelled. Of course, these features come from a model that was not trained to separate these classes, they come from a general-purpose language model. The goal of this analysis is to give insights to the data scientist before they train an expensive model. If we were to train a model on this dataset, the model could probably handle the confusion between `{class_names[first_pair[0]]}` and `{class_names[first_pair[1]]}`, but maybe not easily. ## Conclusion In this tutorial, we covered how to conduct a data-driven analysis for on a text classification dataset. By using sentence embedding and the `spectral_metric` library, we found the intents that would be the most likely to be confused and which utterances caused this confusion. Following our analysis, we could take the following actions: 1. Upweight the classes that are confused during training for the model to better learn to separate them. 2. Merge similar classes together. 3. Analyse sentences that are confusing to find mislabelled sentences. If you have any questions, suggestions or ideas for this library please reach out: 1. frederic.branchaud.charron@gmail.com 2. [@Dref360 on Github](https://github.com/Dref360) If you have a dataset that you think would be a good fit for this analysis let me know too! """)