File size: 4,838 Bytes
0f23c4b
 
e4f39c4
 
0f23c4b
 
e4f39c4
 
 
 
 
0f23c4b
 
 
 
 
e4f39c4
 
0f23c4b
 
 
 
 
 
 
e4f39c4
 
 
 
 
 
0f23c4b
 
 
 
 
 
 
 
 
 
 
e4f39c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f23c4b
 
e4f39c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f23c4b
 
 
 
 
 
 
e4f39c4
 
 
0f23c4b
e4f39c4
 
0f23c4b
e4f39c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f23c4b
e4f39c4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import streamlit as st
import pandas as pd
from pipelines.keyphrase_extraction_pipeline import KeyphraseExtractionPipeline
from pipelines.keyphrase_generation_pipeline import KeyphraseGenerationPipeline
import orjson

from annotated_text.util import get_annotated_html
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
import re
import numpy as np


# st.set_page_config must be the first Streamlit command the script executes.
st.set_page_config(
    page_icon="🔑",
    page_title="Keyphrase extraction/generation with Transformers",
    layout="wide",
)

# One-time, per-session initialisation: load config.json and start an empty history.
if "config" not in st.session_state:
    # orjson.loads accepts bytes, so read in binary mode instead of depending on
    # the platform's default text encoding.
    with open("config.json", "rb") as f:
        st.session_state.config = orjson.loads(f.read())
    st.session_state.data_frame = pd.DataFrame(columns=["model"])
    st.session_state.keyphrases = []

# The key tested must be the key assigned; the old check used "select_rows",
# which reset the selection on every rerun.
if "selected_rows" not in st.session_state:
    st.session_state.selected_rows = []

st.header("🔑 Keyphrase extraction/generation with Transformers")
# Left column: model picker + input; right column: annotated output.
col1, col2 = st.empty().columns(2)


@st.cache(allow_output_mutation=True)
def load_pipeline(chosen_model):
    """Build the pipeline that matches *chosen_model* (cached by ``st.cache``).

    Args:
        chosen_model: Model id; its name selects the pipeline kind.

    Returns:
        A ``KeyphraseExtractionPipeline`` when the id contains
        "keyphrase-extraction", otherwise a ``KeyphraseGenerationPipeline`` when it
        contains "keyphrase-generation".

    Raises:
        ValueError: If the id contains neither marker. Previously this returned
            ``None`` and the app only failed later, when the pipeline was called.
    """
    if "keyphrase-extraction" in chosen_model:
        return KeyphraseExtractionPipeline(chosen_model)
    if "keyphrase-generation" in chosen_model:
        return KeyphraseGenerationPipeline(chosen_model)
    raise ValueError(
        f"Cannot infer pipeline type for model {chosen_model!r}: expected "
        "'keyphrase-extraction' or 'keyphrase-generation' in its name."
    )


def extract_keyphrases():
    """Run the loaded pipeline on the current input and record the result.

    Used as the "Extract" button callback. Stores the pipeline output in
    ``st.session_state.keyphrases`` and appends one row to the history frame
    ``st.session_state.data_frame``: model, input text, then one column per
    keyphrase named "0", "1", ...
    """
    state = st.session_state
    state.keyphrases = pipe(state.input_text)
    found = state.keyphrases

    # Build the single new history row.
    row_values = np.concatenate(([state.chosen_model, state.input_text], found))
    header = ["model", "text"] + [str(idx) for idx in range(len(found))]
    new_row = pd.DataFrame(data=[row_values], columns=header)

    # Rows with fewer keyphrases than other rows get NaN in the extra columns;
    # fillna("") turns those gaps into empty strings.
    state.data_frame = pd.concat(
        [state.data_frame, new_row],
        ignore_index=True,
        axis=0,
    ).fillna("")


def get_annotated_text(text, keyphrases):
    """Split *text* into the segments expected by ``get_annotated_html``.

    Every occurrence of a keyphrase (matched case-insensitively and only on word
    boundaries) becomes a ``(phrase, "KEY", "#21c354")`` annotation tuple that
    keeps the text's own casing. All other words are returned as plain strings,
    padded with a space on each side that has a neighbouring word.

    Args:
        text: The text to annotate.
        keyphrases: Keyphrases to highlight; empty strings are ignored.

    Returns:
        A list mixing plain ``str`` segments and annotation tuples, in text order.
    """
    joiner = "$K"  # temporary marker gluing the words of a multi-word match together
    wanted = [keyphrase for keyphrase in keyphrases if keyphrase]

    for keyphrase in wanted:
        # re.escape: keyphrases are model output and may contain regex
        # metacharacters ("C++", "("). The lookarounds reject matches that start
        # or end inside another word.
        text = re.sub(
            rf"(?<!\w){re.escape(keyphrase)}(?!\w)",
            # A callable replacement keeps the matched text's casing and is not
            # subject to backslash-escape processing.
            lambda match: match.group(0).replace(" ", joiner),
            text,
            flags=re.I,
        )

    lowered = {keyphrase.lower() for keyphrase in wanted}
    words = text.split(" ")
    # Last index of *this* text's words (after multi-word merges), not of some
    # other string, so the trailing-space logic is right for any text.
    last = len(words) - 1
    result = []
    for i, word in enumerate(words):
        bare = re.sub(r"[^\w\s]", "", word)
        if word.lower() in lowered or bare.lower() in lowered:
            result.append((word, "KEY", "#21c354"))
        elif joiner in word:
            result.append((word.replace(joiner, " "), "KEY", "#21c354"))
        else:
            result.append(f"{' ' if i > 0 else ''}{word}{' ' if i < last else ''}")
    return result


def _selected_row_keyphrases(rows):
    """Return the non-empty keyphrases of the first row of *rows*, in column order."""
    # Keyphrase columns are named "0", "1", ... by extract_keyphrases. Keep only
    # those (other columns of the selection frame are ignored) and order them
    # numerically; a lexicographic sort would put "10" before "2".
    keyphrase_columns = sorted(
        (column for column in rows.columns if str(column).isdigit()),
        key=lambda column: int(column),
    )
    first_row = (
        rows.loc[:, keyphrase_columns].fillna("").astype(str).values.tolist()[0]
    )
    return [keyphrase for keyphrase in first_row if keyphrase != ""]


def rerender_output(layout):
    """Render the input text with its keyphrases highlighted into *layout*.

    Shows the latest extraction (``input_text`` / ``keyphrases`` from session
    state) when no history row is selected; otherwise shows the text and
    keyphrases of the selected history row (``selected_rows``).
    """
    layout.subheader("🐧 Output")
    if (
        len(st.session_state.keyphrases) > 0
        and len(st.session_state.selected_rows) == 0
    ):
        text, keyphrases = st.session_state.input_text, st.session_state.keyphrases
    else:
        selected = st.session_state.selected_rows
        text = selected["text"].values[0]
        keyphrases = _selected_row_keyphrases(selected)

    result = get_annotated_text(text, keyphrases)

    layout.markdown(
        get_annotated_html(*result),
        unsafe_allow_html=True,
    )


# Model picker; the options are the "models" list from config.json.
chosen_model = col1.selectbox(
    "Choose your model:",
    st.session_state.config.get("models"),
)
st.session_state["chosen_model"] = chosen_model

# Full model id is "<model_author>/<model name>"; load_pipeline is cached.
pipe = load_pipeline(
    "{}/{}".format(
        st.session_state.config.get("model_author"),
        st.session_state["chosen_model"],
    )
)

# Input box pre-filled with the configured example text, plus the button
# whose on_click callback runs extract_keyphrases.
st.session_state["input_text"] = col1.text_area(
    "Input", st.session_state.config.get("example_text"), height=300
)
pressed = col1.button("Extract", on_click=extract_keyphrases)


# History grid: one row per extraction; selecting a row re-renders its output.
# data_frame is created with a "model" column, so a column-count check is always
# true; show the grid only once at least one extraction has been recorded.
if not st.session_state.data_frame.empty:
    st.subheader("📜 History")
    builder = GridOptionsBuilder.from_dataframe(
        st.session_state.data_frame, sortable=False
    )
    builder.configure_selection(selection_mode="single", use_checkbox=True)
    # The full input text stays in the data but is hidden from the grid.
    builder.configure_column("text", hide=True)
    go = builder.build()
    data = AgGrid(
        st.session_state.data_frame,
        gridOptions=go,
        update_mode=GridUpdateMode.SELECTION_CHANGED,
    )
    st.session_state.selected_rows = pd.DataFrame(data["selected_rows"])

# Render the output column whenever there is something to show.
if len(st.session_state.selected_rows) > 0 or len(st.session_state.keyphrases) > 0:
    rerender_output(col2)