File size: 4,376 Bytes
3ec6971
 
41386c7
3ec6971
65e4324
 
 
 
3ec6971
 
29de8f2
 
3ec6971
 
65e4324
3ec6971
 
29de8f2
3ec6971
e4016f5
3ec6971
99fb2bf
65e4324
3ec6971
 
 
 
 
 
 
 
 
 
 
e4016f5
9fd37ff
 
 
3ec6971
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer, DebertaV2Tokenizer, DebertaV2Model
import sentencepiece
import streamlit as st
import pandas as pd
import spacy 
from spacy import displacy
import plotly.express as px
import numpy as np

# Sample passages offered by the "Select Text from List" dropdown when the
# user chooses to run the model on a built-in example.
example_list = [
     """Hong Kong’s two-week flight ban has dashed the hopes of those planning family reunions as well as disrupted plans for incoming domestic helpers, with the Philippines, Britain and the United States among eight countries hit with tightened rules aimed at containing a Covid-19 surge.""",
     """From Friday (Jan 7), all bars and entertainment venues will close for two weeks, and restaurants have to stop dine-in after 6pm, Chief Executive Carrie Lam Cheng Yuet-ngor announced on Wednesday.  """
]

# --- Page chrome -----------------------------------------------------------
st.set_page_config(layout="wide", page_title="Vocabulary Categorizer")

st.title("Vocabulary Categorizer")
st.write("This application identifies, highlights and categorizes nouns.")

# --- Sidebar: model selection ----------------------------------------------
# Checkpoints the user can pick from; the chosen name is passed verbatim to
# `from_pretrained` when the pipeline is built.
# NOTE(review): 'xlm-roberta-large' looks like the base checkpoint rather than
# one fine-tuned for NER -- confirm it is meant to be offered here.
model_list = ['xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large']

st.sidebar.header("Vocabulary categorizer")
model_checkpoint = st.sidebar.radio("", model_list)

st.sidebar.write("Which model highlights the most vocabulary words? Which model highlights the most accurately?")
st.sidebar.write("")

# --- Sidebar: aggregation strategy -----------------------------------------
xlm_agg_strategy_info = "'aggregation_strategy' can be selected as 'simple' or 'none' for 'xlm-roberta'."

st.sidebar.header("Select Aggregation Strategy Type")
# Every checkpoint in `model_list` offers the same two strategies, so the
# widgets are created once instead of in identical per-model branches. This
# also guarantees `aggregation` is always bound, even if a new checkpoint is
# appended to `model_list` later.
aggregation = st.sidebar.radio("", ('simple', 'none'))
st.sidebar.write(xlm_agg_strategy_info)
st.sidebar.write("")

# --- Main panel: pick the text that will be sent to the model --------------
st.subheader("Select Text Input Method")
input_method = st.radio("", ('Select from Examples', 'Write or Paste New Text'))

# Decide the text area's label and initial content per input method; the
# example dropdown is only drawn when examples are requested.
if input_method == 'Select from Examples':
    selected_text = st.selectbox('Select Text from List', example_list, index=0, key=1)
    text_area_label, text_area_default = "Selected Text", selected_text
elif input_method == "Write or Paste New Text":
    text_area_label, text_area_default = 'Write or Paste Text Below', ""

# In both modes the (editable) text area is what ultimately supplies
# `input_text` to the model.
st.subheader("Text to Run")
input_text = st.text_area(text_area_label, text_area_default, height=128, max_chars=None, key=2)

@st.cache(allow_output_mutation=True)
def setModel(model_checkpoint, aggregation):
    """Build a Hugging Face ``'ner'`` pipeline for the given checkpoint.

    Args:
        model_checkpoint: Checkpoint name passed to ``from_pretrained`` for
            both the token-classification model and its tokenizer.
        aggregation: Value forwarded as the pipeline's
            ``aggregation_strategy``.

    Returns:
        The pipeline object returned by ``transformers.pipeline``. The call
        is wrapped in ``st.cache`` so the result can be reused across reruns.
    """
    # The dict literal evaluates in order: model first, then tokenizer.
    components = {
        "model": AutoModelForTokenClassification.from_pretrained(model_checkpoint),
        "tokenizer": AutoTokenizer.from_pretrained(model_checkpoint),
    }
    return pipeline('ner', aggregation_strategy=aggregation, **components)

@st.cache(allow_output_mutation=True)
def get_html(html: str):
    """Wrap markup in a bordered, horizontally scrollable ``<div>``.

    Args:
        html: Markup to embed; every newline in it is replaced by a space.

    Returns:
        The wrapped markup as a single string.
    """
    opening_tag = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">"""
    flattened = " ".join(html.split("\n"))
    return opening_tag + flattened + "</div>"
    
# --- Run the NER pipeline on the chosen text and render its predictions ----
Run_Button = st.button("Run", key=None)
if Run_Button and not input_text.strip():
    # Nothing to analyse; skip loading the model altogether.
    st.warning("Please enter some text before running the model.")
elif Run_Button:
    ner_pipeline = setModel(model_checkpoint, aggregation)
    output = ner_pipeline(input_text)

    if not output:
        # An empty prediction list would produce a column-less DataFrame and
        # the column selection below would raise KeyError.
        st.info("No entities were recognized in the text.")
    else:
        # The pipeline reports the label under 'entity_group' when results
        # are aggregated and under 'entity' when they are not.
        label_key = "entity" if aggregation == "none" else "entity_group"

        df = pd.DataFrame(output)
        df_final = df[['word', label_key, 'score', 'start', 'end']]

        st.subheader("Recognized Entities")
        st.dataframe(df_final)

        st.subheader("Spacy Style Display")
        spacy_ents = []
        for entity in output:
            label = entity[label_key]
            # Un-aggregated predictions carry BIO prefixes (e.g. "I-PER").
            # displaCy's `ents` filter and `colors` map below are keyed on
            # the bare label, so drop the prefix for display purposes only;
            # the table above still shows the raw label.
            if label[:2] in ("B-", "I-"):
                label = label[2:]
            spacy_ents.append({"start": entity["start"], "end": entity["end"], "label": label})

        # Input format expected by displaCy when rendering with manual=True.
        spacy_display = {"text": input_text, "ents": spacy_ents, "title": None}

        entity_list = ["PER", "LOC", "ORG", "MISC"]
        colors = {'PER': '#85DCDF', 'LOC': '#DF85DC', 'ORG': '#DCDF85', 'MISC': '#85ABDF'}
        html = spacy.displacy.render(
            spacy_display,
            style="ent",
            minify=True,
            manual=True,
            options={"ents": entity_list, "colors": colors},
        )
        # Keep each highlighted entity from being split across lines.
        style = "<style>mark.entity { display: inline-block }</style>"
        st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)