datasets-explorer / clarin_datasets / punctuation_restoration_dataset.py
Mariusz Kossakowski
Add tSNE projection
77405f7
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset
from sklearn.manifold import TSNE
import streamlit as st
from clarin_datasets.dataset_to_show import DatasetToShow
from clarin_datasets.utils import embed_sentence, PLOT_COLOR_PALETTE
class PunctuationRestorationDataset(DatasetToShow):
    """Streamlit explorer page for ``clarin-pl/2021-punctuation-restoration``.

    Loads each subset listed in ``self.subsets`` (presumably provided by
    ``DatasetToShow`` -- not visible here) with ``datasets.load_dataset``.
    It then renders:

    * a description;
    * a preview table;
    * the per-subset distribution of non-"O" tags;
    * a t-SNE projection of the tagged tokens.
    """

    def __init__(self):
        super().__init__()
        # Same rows as ``self.data_dict`` but with integer tag ids replaced by
        # their class names; populated by ``load_data``.
        self.data_dict_named = None
        self.dataset_name = "clarin-pl/2021-punctuation-restoration"
        # Layout:
        #   [0] long description (markdown)
        #   [1] task section header
        #   [2] task text
        #   [3] task image path
        self.description = [
            f"""
Dataset link: https://huggingface.co/datasets/{self.dataset_name}
Speech transcripts generated by Automatic Speech Recognition (ASR) systems typically do
not contain any punctuation or capitalization. In longer stretches of automatically recognized speech,
the lack of punctuation affects the general clarity of the output text [1]. The primary purpose of
punctuation (PR) and capitalization restoration (CR) as a distinct natural language processing (NLP) task is
to improve the legibility of ASR-generated text, and possibly other types of texts without punctuation. Aside
from their intrinsic value, PR and CR may improve the performance of other NLP aspects such as Named Entity
Recognition (NER), part-of-speech (POS) and semantic parsing or spoken dialog segmentation [2, 3]. As useful
as it seems, It is hard to systematically evaluate PR on transcripts of conversational language; mainly
because punctuation rules can be ambiguous even for originally written texts, and the very nature of
naturally-occurring spoken language makes it difficult to identify clear phrase and sentence boundaries [4,
5]. Given these requirements and limitations, a PR task based on a redistributable corpus of read speech was
suggested. 1200 texts included in this collection (totaling over 240,000 words) were selected from two
distinct sources: WikiNews and WikiTalks. Punctuation found in these sources should be approached with some
reservation when used for evaluation: these are original texts and may contain some user-induced errors and
bias. The texts were read out by over a hundred different speakers. Original texts with punctuation were
forced-aligned with recordings and used as the ideal ASR output. The goal of the task is to provide a
solution for restoring punctuation in the test set collated for this task. The test set consists of
time-aligned ASR transcriptions of read texts from the two sources. Participants are encouraged to use both
text-based and speech-derived features to identify punctuation symbols (e.g. multimodal framework [6]). In
addition, the train set is accompanied by reference text corpora of WikiNews and WikiTalks data that can be
used in training and fine-tuning punctuation models.
""",
            "Task description",
            "The purpose of this task is to restore punctuation in the ASR recognition of texts read out loud.",
            "clarin_datasets/punctuation_restoration_task.png",
        ]

    def load_data(self):
        """Load every subset and build raw and tag-named DataFrames.

        Sets two attributes:

        * ``self.data_dict``: maps each subset to its ``to_pandas()`` frame.
        * ``self.data_dict_named``: maps each subset to a frame with a
          ``tokens`` column and a ``tags`` column, where every integer tag
          id is converted to its class name via the dataset's ``tags``
          feature.
        """
        raw_dataset = load_dataset(self.dataset_name)
        self.data_dict = {
            subset: raw_dataset[subset].to_pandas() for subset in self.subsets
        }
        self.data_dict_named = {}
        for subset in self.subsets:
            # The id -> name table is the same for every label in the subset,
            # so look it up once instead of once per token.
            tag_names = raw_dataset[subset].features["tags"].feature.names
            references_named = [
                [tag_names[label] for label in labels]
                for labels in raw_dataset[subset]["tags"]
            ]
            self.data_dict_named[subset] = pd.DataFrame(
                {
                    "tokens": self.data_dict[subset]["tokens"],
                    "tags": references_named,
                }
            )

    def _class_distribution_table(self):
        """Return the relative frequency of each non-"O" tag per subset.

        The result has one row per tag class (column ``class``) and one
        column per subset. Each subset column holds the class's share of that
        subset's non-"O" tags. A class missing from a subset gets 0.0 there
        instead of being dropped from the table.
        """
        merged = None
        for subset in self.subsets:
            labels = pd.Series(
                [
                    tag
                    for tags in self.data_dict_named[subset]["tags"]
                    for tag in tags
                    if tag != "O"
                ],
                dtype=object,
            )
            # Name the index and the values explicitly. pandas >= 2.0 names
            # the value_counts result "proportion", so renaming a column
            # called 0 no longer works there.
            distribution = (
                labels.value_counts(normalize=True)
                .sort_index()
                .rename_axis("class")
                .rename(subset)
                .reset_index()
            )
            if merged is None:
                merged = distribution
            else:
                merged = pd.merge(merged, distribution, on="class", how="outer")
        if merged is None:
            return pd.DataFrame(columns=["class"])
        return merged.fillna(0.0)

    def _tsne_figure(self, subset):
        """Embed every non-"O" token of ``subset`` and plot a 2-D t-SNE.

        Points are coloured per tag class and a legend is drawn. Returns the
        matplotlib figure, or ``None`` when the subset has fewer than two
        tagged tokens, since t-SNE cannot be fitted on fewer samples.
        """
        named = self.data_dict_named[subset]
        tokens = []
        labels = []
        for row_tokens, row_tags in zip(named["tokens"], named["tags"]):
            for token, tag in zip(row_tokens, row_tags):
                if tag != "O":
                    tokens.append(token)
                    labels.append(tag)
        if len(tokens) < 2:
            return None

        # Sorted so the class -> colour assignment is the same on every rerun.
        # Iteration order of a set of strings depends on hash randomisation.
        classes = sorted(set(labels))
        embedded_tokens = np.array([embed_sentence(token) for token in tokens])
        # scikit-learn requires perplexity < n_samples; 30.0 is its default.
        reducer = TSNE(n_components=2, perplexity=min(30.0, len(tokens) - 1))
        transformed = reducer.fit_transform(embedded_tokens)

        label_array = np.array(labels)
        fig, ax = plt.subplots()
        for index, class_name in enumerate(classes):
            mask = label_array == class_name
            ax.scatter(
                x=transformed[mask, 0],
                y=transformed[mask, 1],
                # Wrap around instead of raising IndexError when there are
                # more classes than palette colours.
                color=PLOT_COLOR_PALETTE[index % len(PLOT_COLOR_PALETTE)],
                label=class_name,
            )
        ax.legend()
        return fig

    def show_dataset(self):
        """Render the page.

        Sections, in order: description, subset preview with LaTeX export,
        class distribution with LaTeX export, and t-SNE projection.
        """
        header = st.container()
        description = st.container()
        dataframe_head = st.container()
        class_distribution = st.container()
        tsne_projection = st.container()

        with header:
            st.title(self.dataset_name)

        with description:
            st.header("Dataset description")
            st.write(self.description[0])
            st.subheader(self.description[1])
            st.write(self.description[2])
            st.image(self.description[3])

        with dataframe_head:
            st.header("First 10 observations of the chosen subset")
            subset_to_show = st.selectbox(
                label="Select subset to see", options=self.subsets
            )
            df_to_show = self.data_dict[subset_to_show].head(10)
            st.dataframe(df_to_show)
            # Explicit keys: both text areas share the label "LaTeX code".
            st.text_area(
                label="LaTeX code",
                value=df_to_show.style.to_latex(),
                key="latex_dataframe_head",
            )

        class_distribution_df = self._class_distribution_table()
        with class_distribution:
            st.header("Class distribution in each subset (without 'O')")
            st.dataframe(class_distribution_df)
            st.text_area(
                label="LaTeX code",
                value=class_distribution_df.style.to_latex(),
                key="latex_class_distribution",
            )

        with tsne_projection:
            st.header("t-SNE projection of the dataset")
            subset_to_project = st.selectbox(
                label="Select subset to project", options=self.subsets
            )
            fig = self._tsne_figure(subset_to_project)
            if fig is None:
                st.write(
                    "Not enough tagged tokens to compute a t-SNE projection."
                )
            else:
                st.pyplot(fig)
                # Free the figure; otherwise every Streamlit rerun leaves
                # another open figure in pyplot's global state.
                plt.close(fig)