DeDeckerThomas
Second version
e4f39c4
raw
history blame
No virus
4.84 kB
import streamlit as st
import pandas as pd
from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
import orjson
from annotated_text.util import get_annotated_html
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
import re
import numpy as np
# Streamlit requires set_page_config to be the first Streamlit command of the script,
# so it runs before anything else touches st.*.
st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
)

# One-time initialisation per browser session (session_state survives reruns):
# load the app configuration and start with an empty history table and no results.
if "config" not in st.session_state:
    # orjson.loads accepts bytes directly, so no text decoding step is needed.
    with open("config.json", "rb") as f:
        st.session_state.config = orjson.loads(f.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []

# Rows currently selected in the history grid; starts empty.
if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []

st.header("🔑 Keyphrase extraction/generation with Transformers")
# Left column: model picker, input and button; right column: annotated output.
col1, col2 = st.empty().columns(2)
# allow_output_mutation=True turns off st.cache's hash-based check that the cached
# return value (the pipeline object) is not mutated between reruns.
@st.cache(allow_output_mutation=True)
def load_pipeline(chosen_model):
    """Create the keyphrase pipeline for ``chosen_model`` (memoised by ``st.cache``).

    The pipeline class is picked from the model identifier. Identifiers containing
    "keyphrase-extraction" get a KeyphraseExtractionPipeline. Identifiers containing
    "keyphrase-generation" get a KeyphraseGenerationPipeline. Extraction wins if both
    substrings appear.

    Args:
        chosen_model: Model identifier; the caller builds it as
            "<model_author>/<model name>".

    Returns:
        The constructed pipeline object.

    Raises:
        ValueError: If the identifier names neither kind of model. Failing here gives
            a clear message instead of returning None and crashing later when the
            pipeline is called.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    raise ValueError(
        f"Cannot determine pipeline type for model {chosen_model!r}: expected "
        "'keyphrase-extraction' or 'keyphrase-generation' in its name."
    )
def extract_keyphrases():
    """Run the current pipeline on the input text and append the result to the history.

    Used as the ``on_click`` callback of the Extract button. Reads ``input_text`` and
    ``chosen_model`` from session state and calls the module-level ``pipe``.

    Side effects:
        * ``st.session_state.keyphrases`` is set to the pipeline output.
        * One row is appended to ``st.session_state.data_frame``. It has the columns
          "model", "text" and one column per keyphrase named "0", "1", ...
          Cells missing for shorter rows are filled with "".
    """
    keyphrases = pipe(st.session_state.input_text)
    st.session_state.keyphrases = keyphrases

    # Build the row as a plain list. np.concatenate would coerce every piece to a
    # single NumPy dtype. With an empty keyphrase result, that means promoting a
    # string array with an empty float array, which depends on NumPy's
    # string/number promotion rules.
    row = [st.session_state.chosen_model, st.session_state.input_text, *keyphrases]
    new_entry = pd.DataFrame(
        data=[row],
        columns=["model", "text"] + [str(i) for i in range(len(keyphrases))],
    )

    st.session_state.data_frame = pd.concat(
        [st.session_state.data_frame, new_entry],
        ignore_index=True,
        axis=0,
    ).fillna("")
def get_annotated_text(text, keyphrases):
    """Split ``text`` into pieces for ``annotated_text.util.get_annotated_html``.

    Occurrences of each keyphrase (case-insensitive, not inside a longer word) are
    returned as ``(phrase, "KEY", "#21c354")`` tuples. The phrase keeps the casing
    it has in ``text``. Every other word is returned as a plain string, padded with
    the spaces that separated it from its neighbours.

    Args:
        text: The text to annotate.
        keyphrases: Iterable of keyphrase strings; empty entries are ignored.

    Returns:
        A list of plain strings and ``(text, label, colour)`` tuples, in text order.
    """
    # Inside a matched keyphrase, spaces are replaced by this marker so that the
    # whole phrase survives the split on " " below as a single token.
    marker = "$K"
    phrases = [str(k) for k in keyphrases if k]

    def join_words(match):
        # Keep the text's own casing; only glue the words of the phrase together.
        # (A function replacement also avoids re.sub interpreting backslashes.)
        return match.group(0).replace(" ", marker)

    for phrase in phrases:
        # re.escape: keyphrases such as "c++" must be matched literally.
        # Word boundaries are applied only where the phrase itself starts or ends
        # with a word character. So "new york" does not match inside
        # "renew yorkshire", while "c++" still matches before a space.
        start = r"\b" if re.match(r"\w", phrase[0]) else ""
        end = r"\b" if re.match(r"\w", phrase[-1]) else ""
        pattern = start + re.escape(phrase) + end
        text = re.sub(pattern, join_words, text, flags=re.IGNORECASE)

    lowered = {p.casefold() for p in phrases}
    words = text.split(" ")
    # Spacing is based on the words actually being rendered (after multi-word
    # phrases were merged), not on any other copy of the input.
    last = len(words) - 1

    result = []
    for i, word in enumerate(words):
        bare = re.sub(r"[^\w\s]", "", word)
        if bare.casefold() in lowered or word.casefold() in lowered:
            # Single-word keyphrase, possibly with surrounding punctuation.
            result.append((word, "KEY", "#21c354"))
        elif marker in word:
            # Multi-word keyphrase merged above; restore its spaces.
            result.append((word.replace(marker, " "), "KEY", "#21c354"))
        elif i == last:
            result.append(f" {word}")
        elif i == 0:
            result.append(f"{word} ")
        else:
            result.append(f" {word} ")
    return result
def rerender_output(layout):
    """Render the annotated output panel into ``layout``.

    If the latest extraction produced keyphrases and no history row is selected,
    the latest input text and its keyphrases are shown. Otherwise the first selected
    history row is shown. The caller only invokes this when at least one of the two
    is non-empty; with neither, the selected-row branch has no row to read.

    Args:
        layout: The Streamlit container (a column here) to write into.
    """
    layout.subheader("🐧 Output")
    selected = st.session_state.selected_rows
    if len(st.session_state.keyphrases) > 0 and len(selected) == 0:
        text, keyphrases = st.session_state.input_text, st.session_state.keyphrases
    else:
        text, keyphrases = _selected_row_content(selected)
    result = get_annotated_text(text, keyphrases)
    layout.markdown(
        get_annotated_html(*result),
        unsafe_allow_html=True,
    )


def _selected_row_content(selected_rows):
    """Return ``(text, keyphrases)`` from the first row of the grid selection DataFrame."""
    row = selected_rows.iloc[0]
    # Keyphrase columns are named "0", "1", ... Sort them numerically so that "10"
    # does not come before "2". Selecting only digit-named columns also skips
    # "model", "text" and any other bookkeeping column the grid might add.
    keyphrase_columns = sorted(
        (c for c in selected_rows.columns if str(c).isdigit()), key=int
    )
    # Shorter rows are padded with "" in the history table; drop that padding.
    keyphrases = [str(row[c]) for c in keyphrase_columns]
    return row["text"], [k for k in keyphrases if k != ""]
# Model picker. The options come from config.json, and the chosen name is stored in
# session state so the Extract callback can record it in the history.
chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
)
st.session_state.chosen_model = chosen_model
# Pipeline identifier: "<model_author>/<model name>", both taken from config.json.
pipe = load_pipeline(
    f"{st.session_state.config.get('model_author')}/{st.session_state.chosen_model}"
)
st.session_state.input_text = col1.text_area(
    "Input", st.session_state.config.get("example_text"), height=300
)
# extract_keyphrases is an on_click callback. It runs before the script reruns, so
# the code below already sees the new history row and keyphrases.
pressed = col1.button("Extract", on_click=extract_keyphrases)

# Show the history grid only once at least one extraction has been recorded. The
# table is created with a "model" column, so testing for columns was always true
# and rendered an empty grid on first load.
if not st.session_state.data_frame.empty:
    st.subheader("📜 History")
    builder = GridOptionsBuilder.from_dataframe(
        st.session_state.data_frame, sortable=False
    )
    # Single-row selection via checkbox. The full input text stays in the data,
    # so a selected row can be re-rendered, but its column is hidden in the grid.
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    builder.configure_column("text", hide=True)
    grid_options = builder.build()
    data = AgGrid(
        st.session_state.data_frame,
        gridOptions=grid_options,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(col2)