Spaces:
Paused
Paused
File size: 2,418 Bytes
df321c6 6af31c0 df321c6 6af31c0 e18eaf4 366eb95 df321c6 e4c8a16 e18eaf4 df321c6 6af31c0 e18eaf4 df321c6 e18eaf4 6af31c0 e18eaf4 366eb95 e18eaf4 6af31c0 e18eaf4 6af31c0 e18eaf4 6af31c0 e18eaf4 df321c6 366eb95 e18eaf4 df321c6 366eb95 df321c6 e18eaf4 df321c6 366eb95 6af31c0 df321c6 366eb95 df321c6 2f6de14 df321c6 366eb95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import re
import pandas as pd
import plotly.express as px
import streamlit as st
# Select Streamlit's "wide" page layout for the app.
st.set_page_config(layout="wide")
# Path of the JSON dataset that load_dataframe() hands to load_df().
DATA_FILE = "data/anthology-2020-23_specter2_base.json"
# Maps each option of the "Color" selector to the colour-scale name that is
# passed to the scatter plot as `color_continuous_scale`.
THEMES = {"cluster": "fall", "year": "mint", "source": "phase"}
def load_df(data_file: os.PathLike) -> pd.DataFrame:
    """Read the paper records and flatten their 2-D coordinates.

    Args:
        data_file: Path of a JSON file in ``records`` orientation. Every
            record must carry a ``point2d`` entry that can be indexed with
            ``0`` and ``1``.

    Returns:
        The loaded frame with two extra columns, ``x`` (``point2d[0]``) and
        ``y`` (``point2d[1]``), and without the ``point2d`` column. When a
        ``publication_type`` column is present it is copied to ``type`` and
        then removed as well.
    """
    frame = pd.read_json(data_file, orient="records")

    # Split the coordinate pairs into one plain column per axis.
    points = frame["point2d"]
    frame["x"] = points.map(lambda point: point[0])
    frame["y"] = points.map(lambda point: point[1])

    # Columns that must not survive in the returned frame.
    obsolete = ["point2d"]
    if "publication_type" in frame.columns:
        frame["type"] = frame["publication_type"]
        obsolete.append("publication_type")

    return frame.drop(columns=obsolete)
@st.cache_data
def load_dataframe() -> pd.DataFrame:
    """Return the frame that ``load_df`` builds from ``DATA_FILE``.

    The function is wrapped in ``st.cache_data``.
    """
    frame = load_df(DATA_FILE)
    return frame
# Full dataset; the filters further down only ever derive views from it.
DF = load_dataframe()

# Filter widgets shown in the sidebar.
with st.sidebar:
    venues = st.multiselect(
        label="Venues",
        options=["ACL", "EMNLP", "NAACL", "TACL"],
        default=["ACL", "EMNLP", "NAACL", "TACL"],
    )
    start_year, end_year = st.select_slider(
        label="Publication year",
        options=("2020", "2021", "2022", "2023"),
        value=("2020", "2023"),
    )
    author_names = st.text_input(label="Author names (separated by comma)")
    title = st.text_input(label="Title")

# The slider options are strings; the year filter compares against integers.
start_year, end_year = int(start_year), int(end_year)
# --- Apply the sidebar filters to the full dataset ---------------------------

# Publication year: keep papers inside the inclusive [start_year, end_year] range.
in_year_range = (DF["year"] >= start_year) & (DF["year"] <= end_year)
df = DF[in_year_range]

# Venue: with nothing selected the user is warned and the data is left
# unfiltered by venue; with every option selected (4 = number of choices in
# the "Venues" multiselect) no venue filter is applied either.
if not venues:
    st.write(":red[Please select a venue]")
elif len(venues) < 4:
    # The "source" column is matched against the lower-cased venue names.
    selected_venues = [venue.lower() for venue in venues]
    df = df[df["source"].isin(selected_venues)]

# Authors: every comma-separated fragment must occur (case-insensitively) in
# at least one author name of the paper.
# - re.escape() makes the fragment match literally, so input such as "(" or
#   "C++" can no longer raise re.error and "." is no longer a wildcard.
# - Empty fragments (stray commas, blank input) are ignored.
author_patterns = [
    re.compile(re.escape(fragment), re.IGNORECASE)
    for fragment in (part.strip() for part in author_names.split(","))
    if fragment
]
if author_patterns:
    author_mask = df["authors"].apply(
        lambda names: all(
            any(pattern.search(name) for name in names)
            for pattern in author_patterns
        )
    )
    # On an empty frame apply() hands back an object-dtype Series, which
    # pandas would read as a list of column labels instead of a boolean mask
    # and thereby drop every column; the cast keeps it a row mask.
    df = df[author_mask.astype(bool)]

# Title: case-insensitive literal substring match. Missing titles count as
# "no match" instead of raising, and the result is boolean even when empty.
if title:
    title_mask = (
        df["title"].str.lower().str.contains(title.lower(), regex=False, na=False)
    )
    df = df[title_mask.astype(bool)]

st.write(f"Number of points: {df.shape[0]}")
# Column that drives the point colours; it also selects the entry of THEMES.
color = st.selectbox("Color", ("cluster", "year", "source"))

# One point per remaining paper, placed at its (x, y) coordinates.
# NOTE(review): the "type" hover column is only created by load_df() when the
# data has a "publication_type" column - confirm the dataset always provides
# either that column or "type" itself.
fig = px.scatter(
    data_frame=df,
    x="x",
    y="y",
    color=color,
    hover_data=["title", "authors", "year", "source", "type"],
    color_continuous_scale=THEMES[color],
    width=1000,
    height=800,
)

# Hide the legend and set the figure font.
fig.update_layout(
    showlegend=False,
    font={"family": "Times New Roman", "size": 30},
)

# Blank out both axis titles.
fig.update_xaxes(title="").update_yaxes(title="")

st.plotly_chart(fig, use_container_width=True)
|