# acl-spectrum / app.py
# Last change: "Update app.py" (commit 4077dc3) by ehsk; file size 4.87 kB.
import os
import re
import pandas as pd
import plotly.express as px
import streamlit as st
# JSON dump of the papers (with their precomputed 2-D points) and the
# colour scale passed to plotly for each "Color" option.
DATA_FILE = "data/aclanthology2016-23_specter2_base.json"
THEMES = dict(cluster="fall", year="mint", source="phase")
st.set_page_config(layout="wide")
# Page-level assets: the Bootstrap stylesheet (provides the `text-justify`
# class used in the sidebar blurb) and the "Permanent Marker" web font that
# styles the `.title` class of the sidebar heading.
_PAGE_HEAD_HTML = """
<link href="https://cdn.jsdelivr.net/npm/bootstrap@4.6.1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha256-DF7Zhf293AJxJNTmh5zhoYYIMs2oXitRfBjY+9L//AY=" crossorigin="anonymous">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Permanent+Marker&display=swap" rel="stylesheet">
<style>
.title {
font-family: 'Permanent Marker', cursive;
font-size: 2.0rem;
}
</style>"""
st.markdown(_PAGE_HEAD_HTML, unsafe_allow_html=True)
# Sidebar header: the app name (rendered with the `.title` font), a short
# description with links to the model, data source, README and notebook,
# and a sign-off line.
st.sidebar.write(
    """<center><p class="title">
acl-spectrum
</p></center>""",
    unsafe_allow_html=True,
)
# The link text names the embedding model; it must match the linked
# allenai/specter2_base checkpoint ("specter2", not "spectre2").
st.sidebar.write(
    """<p class="text-justify">
An interactive t-SNE visualization of <a href="https://huggingface.co/allenai/specter2_base">specter2</a> embeddings
featuring over 12K papers (titles and abstracts) from the <a href="https://aclanthology.org/">ACL Anthology</a>
spanning 2016 to 2023.
For more details, check out our <a href="https://huggingface.co/spaces/gwf-uwaterloo/acl-spectrum/blob/main/README.md">README</a>
and our step-by-step guide <a href="https://huggingface.co/spaces/gwf-uwaterloo/acl-spectrum/blob/main/scipapers_scatter.ipynb">here</a>.
</p>""",
    unsafe_allow_html=True,
)
st.sidebar.markdown("Happy exploring! :rocket::rocket:")
def to_string_authors(list_of_authors):
    """Join author names into a human-readable byline.

    * zero to two names are joined with " and " ("A and B");
    * three to five names form an Oxford-comma list ("A, B, and C");
    * more than five names keep the first five followed by ", et al.".
    """
    n_authors = len(list_of_authors)
    if n_authors <= 2:
        return " and ".join(list_of_authors)
    if n_authors <= 5:
        leading, final = list_of_authors[:-1], list_of_authors[-1]
        return ", ".join(leading) + ", and " + final
    return ", ".join(list_of_authors[:5]) + ", et al."
def _first_last(name):
    """Rewrite a "Last, First" author name as "First Last".

    Only the first comma splits the name; names without a comma are
    returned unchanged.
    """
    if "," not in name:
        return name
    last, _, first = name.partition(",")
    return first.strip() + " " + last.strip()


def load_df(data_file: "str | os.PathLike[str]"):
    """Read the paper records from a JSON file and prepare them for plotting.

    The file is parsed with ``orient="records"``. The 2-D point stored in
    ``point2d`` is split into ``x``/``y`` columns, a display-ready byline is
    added as ``authors_trimmed``, and ``publication_type`` (when present) is
    moved to a ``type`` column.

    Args:
        data_file: Path to the JSON records file.

    Returns:
        A DataFrame without the ``point2d`` and ``publication_type`` columns.
    """
    df = pd.read_json(data_file, orient="records")
    df["x"] = df["point2d"].apply(lambda point: point[0])
    df["y"] = df["point2d"].apply(lambda point: point[1])
    df["authors_trimmed"] = df["authors"].apply(
        lambda names: to_string_authors([_first_last(name) for name in names])
    )
    # pop() removes "publication_type" and returns it, so the rename and the
    # drop happen in one step and only "point2d" is left to drop below.
    if "publication_type" in df.columns:
        df["type"] = df.pop("publication_type")
    return df.drop(columns=["point2d"])
@st.cache_data
def load_dataframe():
    """Return the prepared paper table read from ``DATA_FILE``.

    Decorated with ``st.cache_data`` so the JSON is parsed once and reused
    across Streamlit reruns instead of on every interaction.
    """
    papers = load_df(DATA_FILE)
    return papers
# Fetch the paper table and start every point faded; points that pass the
# sidebar filters are raised to full opacity further down.
DF = load_dataframe()
DF["opacity"] = 0.04
years = DF["year"]
min_year = years.min()
max_year = years.max()
# Sidebar filter widgets: venues, a publication-year range, and free-text
# author / title queries.
with st.sidebar:
    venue_choices = ["ACL", "EMNLP", "NAACL", "TACL"]
    venues = st.multiselect("Venues", venue_choices, list(venue_choices))
    year_labels = [str(year) for year in range(min_year, max_year + 1)]
    start_year, end_year = st.select_slider(
        "Publication year",
        options=year_labels,
        value=(str(min_year), str(max_year)),
    )
    author_names = st.text_input("Author names (separated by comma)")
    title = st.text_input("Title")

# The slider operates on string labels; turn the chosen bounds back into ints.
start_year, end_year = int(start_year), int(end_year)
# Combine the sidebar filters into one boolean row mask. Matching papers are
# drawn fully opaque; everything else keeps the faded default opacity.
df_mask = (DF["year"] >= start_year) & (DF["year"] <= end_year)
if 0 < len(venues) < 4:
    # Filter only when a strict subset of the four venue options is chosen.
    # The "source" column is compared against lowercase venue names.
    selected_venues = [v.lower() for v in venues]
    df_mask = df_mask & DF["source"].isin(selected_venues)
elif not venues:
    st.write(":red[Please select a venue]")
if author_names:
    # Every comma-separated query must occur (case-insensitively) in at least
    # one author name of the paper. Queries are escaped so user text such as
    # "(" or "C++" is matched literally instead of raising re.error; blank
    # entries (e.g. from a trailing comma) are ignored.
    author_patterns = [
        re.compile(re.escape(query), re.IGNORECASE)
        for query in (part.strip() for part in author_names.split(","))
        if query
    ]
    author_mask = DF["authors"].apply(
        lambda row: all(any(p.search(name) for name in row) for p in author_patterns)
    )
    df_mask = df_mask & author_mask
if title:
    # Literal, case-insensitive substring match; missing titles never match.
    df_mask = df_mask & DF["title"].str.contains(title, case=False, regex=False, na=False)
DF.loc[df_mask, "opacity"] = 1.0
st.write(f"Number of points: {DF[df_mask].shape[0]}")
color = st.selectbox("Color", ("cluster", "year", "source"))

# Columns shown in the hover label, in the order the hovertemplate indexes
# them. ("type" is not included: it is never displayed and only exists when
# the data carries a publication_type column.)
hover_columns = ["title", "authors_trimmed", "year", "source"]
# Position of the per-point opacity, appended after the hover columns.
opacity_col = len(hover_columns)

fig = px.scatter(
    DF,
    x="x",
    y="y",
    color=color,
    width=1000,
    height=800,
    custom_data=hover_columns + ["opacity"],
    color_continuous_scale=THEMES[color],
)
# px.scatter copies its `opacity` argument unchanged onto every trace's
# marker. A discrete colour column (e.g. "source") produces one trace per
# value, so a full-length opacity array would be misaligned with each trace's
# subset of points. Apply each trace's own opacities, read back from its
# customdata rows.
fig.for_each_trace(
    lambda trace: trace.update(
        marker_opacity=[float(row[opacity_col]) for row in trace.customdata]
    )
)
fig.update_traces(
    hovertemplate="<b>%{customdata[0]}</b><br>%{customdata[1]}<br>%{customdata[2]}<br><i>%{customdata[3]}</i>"
)
fig.update_layout(
    showlegend=False,
    font=dict(
        family="Times New Roman",
        size=30,
    ),
    hoverlabel=dict(
        align="left",
        font_size=14,
        font_family="Rockwell",
        namelength=-1,
    ),
)
fig.update_xaxes(title="")
fig.update_yaxes(title="")
st.plotly_chart(fig, use_container_width=True)