# Author: Thomas De Decker
# Commit: 0f23c4b ("First version")
# Recovered from a repository web view (raw / history / blame); file size 1.62 kB.
import streamlit as st
import pandas as pd
from extraction.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from extraction.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
import orjson
# Load the app configuration (the "models" list and "example_text" used
# below) only if it is not already in session state; Streamlit re-executes
# this script on every interaction, so the parsed config is kept there
# instead of re-reading the file on each rerun.
if "config" not in st.session_state:
    # Read raw bytes: orjson.loads accepts bytes directly, and this avoids
    # decoding with the platform's default text encoding (which may not be
    # UTF-8) when the config contains non-ASCII text.
    with open("config.json", "rb") as f:
        content = f.read()
    st.session_state.config = orjson.loads(content)
# Page-level settings: browser tab icon/title, wide layout, default sidebar.
# Streamlit requires set_page_config to run before other Streamlit commands
# that draw elements on the page.
st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
    initial_sidebar_state="auto",
)
@st.cache(allow_output_mutation=True)
def load_pipeline(chosen_model):
    """Build the keyphrase pipeline matching ``chosen_model``.

    Wrapped in ``st.cache`` so the pipeline for a given model name is
    memoized across script reruns; ``allow_output_mutation=True`` turns off
    Streamlit's check that the cached return value was not mutated.

    Args:
        chosen_model: Model identifier (taken from the config's ``models``
            list). Its name selects the pipeline type: it must contain
            ``"keyphrase-extraction"`` or ``"keyphrase-generation"``.

    Returns:
        A ``KeyphraseExtractionPipeline`` or ``KeyphraseGenerationPipeline``
        constructed with ``chosen_model``.

    Raises:
        ValueError: If ``chosen_model`` matches neither pipeline type.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    # Fail here with a clear message rather than returning None, which would
    # only surface later as "'NoneType' object is not callable" when the
    # Extract button's callback tries to run the pipeline.
    raise ValueError(
        f"Unsupported model {chosen_model!r}: expected a name containing "
        "'keyphrase-extraction' or 'keyphrase-generation'."
    )
def extract_keyphrases():
    """Extract-button callback: run the active pipeline on the input text.

    Passes ``st.session_state.input_text`` to the module-level ``pipe``
    (bound further down in this script) and stores whatever it returns in
    ``st.session_state.keyphrases`` so the output table can display it.
    """
    model_pipeline = pipe
    source_text = st.session_state.input_text
    st.session_state.keyphrases = model_pipeline(source_text)
# Page title shown above the two-column layout.
st.header("🔑 Keyphrase extraction/generation with Transformers")
# Two-column layout: a narrow left column for the model picker and a wider
# right column (3x) for input and output.
col1, col2 = st.columns([1, 3])

# Model picker, populated from the config's "models" entry.
col1.subheader("Select model")
available_models = st.session_state.config.get("models")
chosen_model = col1.selectbox("Choose your model:", available_models)

# Mirror the choice into session state, then fetch the (cached) pipeline.
st.session_state.chosen_model = chosen_model
pipe = load_pipeline(chosen_model)
# Right column: free-text input, pre-filled with the config's example text.
col2.subheader("Input your text")
# Bind the widget to the "input_text" session-state key. Streamlit then
# updates st.session_state.input_text from the browser *before* on_click
# callbacks run, so extract_keyphrases always sees the latest text.
# Assigning the widget's return value to session state after the widget is
# drawn would leave the callback reading the previous run's text. Streamlit
# also forbids that assignment once a widget owns the key, so none is made.
col2.text_area(
    "Input",
    st.session_state.config.get("example_text"),
    height=150,
    key="input_text",
)
# Extract button. Its on_click callback runs before the script reruns, so by
# the time `pressed` is True, st.session_state.keyphrases has been filled in.
pressed = col2.button("Extract", on_click=extract_keyphrases)
if pressed:
    col2.subheader("🐧 Output")
    # One-column table listing each extracted/generated keyphrase.
    extracted = st.session_state.keyphrases
    df = pd.DataFrame(extracted, columns=["Keyphrases"])
    col2.table(df)