import streamlit as st
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import torch
from spectral_metric.estimator import CumulativeGradientEstimator
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from spectral_metric.visualize import make_graph
from scipy.stats import entropy
import pandas as pd
from utils import show_most_confused
AVAILABLE_DATASETS = [
    ("clinc_oos", "small"),
    ("clinc_oos", "imbalanced"),
    ("banking77",),
    ("tweet_eval", "emoji"),
    ("tweet_eval", "stance_climate"),
]
label_column_mapping = {
    "clinc_oos": "intent",
    "banking77": "label",
    "tweet_eval": "label",
}
st.title("Perform a data-driven analysis using `spectral-metric`")
st.markdown(
"""Today, I would like to perform a data-driven analysis of a text classification dataset,
using `sentence-transformers` to extract features
and `spectral_metric` to perform a spectral analysis of these features.

For support, please submit an issue on [our repo](https://github.com/Dref360/spectral-metric) or [contact me directly](https://github.com/Dref360).
"""
)
st.markdown(
"""
Let's load your dataset. We will run our analysis on the train set.
"""
)
dataset_name = st.selectbox("Select your dataset", AVAILABLE_DATASETS)
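# Stop the script until the user clicks the button; the analysis below only runs after a click.
if not st.button("Start the analysis"):
    st.stop()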
label_column = label_column_mapping[dataset_name[0]]
# We perform the analysis on the train set.
ds = load_dataset(*dataset_name)["train"]
class_names = ds.features[label_column].names
ds  # Streamlit "magic": a bare variable on its own line is rendered in the app
# I use all-MiniLM-L12-v2 as it is a good compromise between speed and performance.
embedder = SentenceTransformer("all-MiniLM-L12-v2")
# We will get **normalized** features for the dataset using our embedder.
with st.spinner(text="Computing embeddings..."):
    features = embedder.encode(
        ds["text"],
        device="cuda" if torch.cuda.is_available() else "cpu",
        normalize_embeddings=True,
    )
st.markdown(
"""
### Running the spectral analysis
Now that our sentence embedder has extracted the embeddings, we can analyze these features in depth.
To do so, we will use CSG, the Cumulative Spectral Gradient (Branchaud-Charron et al., 2019), a technique that combines Probability Product Kernels (Jebara et al., 2004) and spectral clustering to analyze a dataset without training a model.
In this tutorial, we won't use the CSG metric itself; we will only use its $W$ matrix.
This matrix is computed as follows:
* Run a probabilistic k-NN on the dataset (optionally via Monte-Carlo sampling).
* Compute the average prediction per class, which gives the $S$ matrix.
* Symmetrize this matrix using the Bray-Curtis distance, a dissimilarity measure designed to compare samples drawn from distributions.

These steps are all done by `spectral_metric.estimator.CumulativeGradientEstimator`.
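
As a rough illustration of these three steps, here is a minimal NumPy sketch on toy 2-D data. It is not the `spectral_metric` implementation: it runs an exact k-NN over all points instead of Monte-Carlo sampling, and it assumes the symmetrization compares rows of `S` as `W[i, j] = 1 - BC(S[i], S[j])`, where `BC` is the Bray-Curtis distance. The helper `similarity_matrix` and the toy data are hypothetical.

```python
import numpy as np


def similarity_matrix(X, y, k=5):
    # Hypothetical helper, not the spectral_metric API.
    # X: [N, D] L2-normalized features, y: [N] integer labels in 0..C-1.
    n_class = int(y.max()) + 1
    sim = X @ X.T  # cosine similarity, since the rows of X are unit-norm
    np.fill_diagonal(sim, -np.inf)  # a point is not its own neighbor
    neighbors = np.argsort(-sim, axis=1)[:, :k]
    # Probabilistic k-NN: p(c | x) = share of the k neighbors of x labelled c.
    probs = np.stack([(y[neighbors] == c).mean(axis=1) for c in range(n_class)], axis=1)
    # S[i, c]: average of p(c | x) over the samples x of class i.
    S = np.stack([probs[y == i].mean(axis=0) for i in range(n_class)])
    # Symmetrize with Bray-Curtis: BC(u, v) = sum|u - v| / sum|u + v|.
    W = np.empty((n_class, n_class))
    for i in range(n_class):
        for j in range(n_class):
            W[i, j] = 1 - np.abs(S[i] - S[j]).sum() / np.abs(S[i] + S[j]).sum()
    return S, W


rng = np.random.default_rng(0)
# Toy data on the unit circle: classes 0 and 1 overlap, class 2 sits far away.
angles = np.concatenate([rng.normal(0.0, 0.3, 50), rng.normal(0.2, 0.3, 50), rng.normal(3.0, 0.3, 50)])
X = np.stack([np.cos(angles), np.sin(angles)], axis=1)
y = np.repeat([0, 1, 2], 50)
S, W = similarity_matrix(X, y, k=5)
print(np.round(W, 2))
```

Because classes 0 and 1 overlap, `W[0, 1]` should come out high, while the isolated class 2 should be almost fully dissimilar from both.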
"""
)
X, y = features, np.array(ds[label_column])  # embeddings with shape [N, D], labels with shape [N]
estimator = CumulativeGradientEstimator(M_sample=250, k_nearest=9, distance="cosine")
estimator.fit(data=X, target=y)
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(estimator.W, ax=ax, cmap="rocket_r")
ax.set_title(f"Similarity between classes in {dataset_name[0]}")
st.pyplot(fig)
st.markdown(
"""
This figure will be hard to read on most datasets, so we need to go deeper.
Let's do the following analysis:
1. Find the class with the highest entropy, i.e., the class that is the most confused with others.
2. Find the 10 pairs of classes that are the most confused.
3. Find the items in these pairs that contribute to the confusion.
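
For intuition on steps 1 and 2, here is a toy run on a made-up 3-class `W` using plain NumPy. The entropy line computes the same quantity as the `scipy.stats.entropy` call used below (natural log, each row normalized to sum to 1, zero entries contributing 0). For step 2, this sketch keeps each pair once by reading only the upper triangle (`i < j`) of the symmetric matrix, whereas the app sorts the full matrix and keeps every other entry.

```python
import numpy as np

# Made-up similarity matrix for 3 hypothetical classes.
W = np.array([
    [1.0, 0.8, 0.3],
    [0.8, 1.0, 0.1],
    [0.3, 0.1, 1.0],
])

# Step 1: turn each row into a distribution and take its Shannon entropy (natural log).
P = W / W.sum(axis=-1, keepdims=True)
# Zero entries contribute 0 (log(1) = 0), as in scipy.stats.entropy.
entropy_per_class = -(P * np.log(np.where(P > 0, P, 1.0))).sum(axis=-1)
most_confused = int(np.argmax(entropy_per_class))
least_confused = int(np.argmin(entropy_per_class))

# Step 2: rank each unordered pair once, using the upper triangle (i < j) of W.
i_idx, j_idx = np.triu_indices(len(W), k=1)
order = np.argsort(W[i_idx, j_idx])[::-1]
top_pairs = [(int(i_idx[o]), int(j_idx[o])) for o in order]

print(most_confused, least_confused, top_pairs)  # prints: 0 2 [(0, 1), (0, 2), (1, 2)]
```

Class 0 spreads its similarity over both other classes, so its row has the highest entropy; class 2 is mostly similar to itself, so its row has the lowest.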
"""
)
# Normalize each row of W into a distribution, then take its entropy (natural log).
entropy_per_class = entropy(estimator.W / estimator.W.sum(-1)[:, None], axis=-1)
st.markdown(
f"Most confused class (highest entropy): {class_names[np.argmax(entropy_per_class)]}",
)
st.markdown(
f"Least confused class (lowest entropy): {class_names[np.argmin(entropy_per_class)]}",
)
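# Every (i, j) index pair of W, sorted from most to least similar.
pairs = list(zip(*np.unravel_index(np.argsort(estimator.W, axis=None), estimator.W.shape)))[::-1]
# Drop the diagonal, which compares each class with itself.
pairs = [(i, j) for i, j in pairs if i != j]
lst = []
# W is symmetric, so each pair normally appears twice in a row, as (i, j) and (j, i);
# taking every other entry keeps one copy of each of the 10 most similar pairs.
for i, j in pairs[::2][:10]:
    lst.append({"Intent A": class_names[i], "Intent B": class_names[j], "Similarity": estimator.W[i, j]})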
st.title("Most similar pairs")
st.dataframe(pd.DataFrame(lst).sort_values("Similarity", ascending=False))
st.markdown("""
## Analysis
By looking at the top-10 most similar pairs, we get some good insights into the dataset.
While this does not guarantee that a classifier trained downstream will struggle with these pairs,
it tells us that these intents are similar.
As a consequence, the classifier might not be able to separate them easily.
Let's now look at which utterance is contributing the most to the confusion.
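
`show_most_confused` comes from the local `utils` module, which is not shown here. As a hypothetical stand-in, one simple way to score how much each utterance of class A contributes to the confusion with class B is the share of its k nearest neighbors (by cosine similarity on the normalized embeddings) that are labelled B. The helper `rank_confusing_samples` and the toy embeddings are illustrative only.

```python
import numpy as np


def rank_confusing_samples(X, y, class_a, class_b, k=5, top=5):
    # Hypothetical helper, not the utils.show_most_confused implementation.
    # X: [N, D] L2-normalized embeddings, y: [N] integer labels.
    sim = X @ X.T  # cosine similarity, since the rows of X are unit-norm
    np.fill_diagonal(sim, -np.inf)  # ignore self-matches
    neighbors = np.argsort(-sim, axis=1)[:, :k]
    idx_a = np.flatnonzero(y == class_a)
    # Score: share of each class-A sample's k neighbors that are labelled class_b.
    score = (y[neighbors[idx_a]] == class_b).mean(axis=1)
    order = np.argsort(-score, kind="stable")[:top]
    return idx_a[order], score[order]


# Toy embeddings on the unit circle: the class-0 point at 85 degrees sits among class 1.
angles = np.deg2rad([0, 5, 85, 90, 88, 92, 95])
X = np.stack([np.cos(angles), np.sin(angles)], axis=1)
y = np.array([0, 0, 0, 1, 1, 1, 1])
idx, scores = rank_confusing_samples(X, y, class_a=0, class_b=1, k=2, top=2)
print(idx.tolist(), scores.tolist())  # prints: [2, 0] [1.0, 0.0]
```

In the app, the returned indices could be used to look up the corresponding utterances in `ds["text"]`.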
""")
# `pairs` lists each pair twice in a row, so pairs[2] is the second distinct pair.
first_pair = pairs[0]
second_pair = pairs[2]
# Call the helper in both directions, (A, B) and (B, A), and combine the results into one table.
st.dataframe(pd.DataFrame({**show_most_confused(ds, first_pair[0], first_pair[1], estimator, class_names),
                           **show_most_confused(ds, first_pair[1], first_pair[0], estimator, class_names)}),
             width=1000)
st.markdown("### We can do the same for the second pair")
st.dataframe(pd.DataFrame({**show_most_confused(ds, second_pair[0], second_pair[1], estimator, class_names),
                           **show_most_confused(ds, second_pair[1], second_pair[0], estimator, class_names)}),
             width=1000)
st.markdown(f"""
From the top-5 most confused examples per pair, we can see that the sentences are quite similar.
While a human could easily separate the two intents, the sentences are made of the same words, which might confuse the classifier.
Some sentences could even be seen as mislabelled.
Of course, these features come from a general-purpose language model that was not trained to separate these classes.
The goal of this analysis is to give insights to the data scientist before they train an expensive model.
If we were to train a model on this dataset, the model could probably handle the confusion between `{class_names[first_pair[0]]}`
and `{class_names[first_pair[1]]}`,
but maybe not easily.
## Conclusion
In this tutorial, we covered how to conduct a data-driven analysis of a text classification dataset.
By using sentence embeddings and the `spectral_metric` library, we found the intents that are most likely to be confused and the utterances that cause this confusion.
Following our analysis, we could take the following actions:
1. Upweight the confused classes during training so that the model learns to separate them better.
2. Merge similar classes together.
3. Analyse sentences that are confusing to find mislabelled sentences.
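
For action 2, one hypothetical way to decide which classes to merge is to group every pair whose similarity in `W` reaches a threshold and relabel each class with its group's representative. The helper `merge_similar_classes`, the 0.8 threshold, and the toy matrix are illustrative assumptions, not part of `spectral_metric`.

```python
import numpy as np


def merge_similar_classes(W, threshold):
    # Hypothetical helper (not part of spectral_metric): union-find over the
    # class pairs whose similarity W[i, j] reaches the threshold.
    parent = list(range(len(W)))

    def find(c):
        # Follow parent links to the group's representative (path halving).
        while parent[c] != c:
            parent[c] = parent[parent[c]]
            c = parent[c]
        return c

    for i in range(len(W)):
        for j in range(i + 1, len(W)):
            if W[i, j] >= threshold:
                root_i, root_j = find(i), find(j)
                # Keep the smallest class index as the group's representative.
                parent[max(root_i, root_j)] = min(root_i, root_j)
    return np.array([find(c) for c in range(len(W))])


# Made-up similarities: classes 0/1 and 2/3 are near-duplicates.
W = np.array([
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.3, 0.1],
    [0.2, 0.3, 1.0, 0.85],
    [0.1, 0.1, 0.85, 1.0],
])
mapping = merge_similar_classes(W, threshold=0.8)
y = np.array([0, 1, 2, 3, 1])  # original labels
# Relabel, then renumber the merged classes as 0..K-1.
_, new_y = np.unique(mapping[y], return_inverse=True)
print(mapping.tolist(), new_y.tolist())  # prints: [0, 0, 2, 2] [0, 0, 1, 1, 0]
```

Because union-find merges transitively, a low threshold can chain many classes together through intermediate pairs, so the resulting groups are worth inspecting before retraining.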

If you have any questions, suggestions, or ideas for this library, please reach out:
1. frederic.branchaud.charron@gmail.com
2. [@Dref360 on GitHub](https://github.com/Dref360)

If you have a dataset that you think would be a good fit for this analysis, let me know too!
""")