File size: 4,809 Bytes
fcb2250
 
 
 
 
 
 
 
 
 
 
d7e89e5
 
c80b30b
d7e89e5
fcb2250
 
 
3fc071b
fcb2250
 
 
 
 
 
 
 
5937336
c80b30b
fcb2250
 
 
 
 
6b3ace2
fcb2250
 
 
 
3f6a24b
 
 
 
 
 
 
 
c80b30b
 
fcb2250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3c2743
fcb2250
 
 
3f6a24b
 
c80b30b
fcb2250
 
df6fb26
fcb2250
 
 
 
 
 
 
 
 
 
 
c80b30b
 
 
 
 
 
 
 
 
 
 
 
 
 
fcb2250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c80b30b
fcb2250
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import spacy.displacy
import streamlit as st
from flair.models import SequenceTagger
from flair.splitter import SegtokSentenceSplitter
from colorhash import ColorHash

# Page layout; kept as the first Streamlit call in the script.
st.set_page_config(layout="centered")

# Menu entries shown to the user -> Flair model identifiers passed to
# SequenceTagger.load() in get_model().
model_map = {
    "find Entities (default)": "ner-large",
    "find Entities (18-class)": "ner-ontonotes-large",
    "find Frames": "frame-large",
    "find Parts-of-Speech": "pos-multi",
}

# Block 1: Users can select a model
st.subheader("Select a model")
# The label is visually collapsed (the subheader above acts as the caption), but
# Streamlit still exposes it for accessibility, so it must describe the widget.
selected_model_id = st.selectbox(
    "Select a model",
    list(model_map),
    label_visibility="collapsed",
)

# Block 2: Users can input text
st.subheader("Input your text here")
input_text = st.text_area(
    'Write or Paste Text Below',
    value='May visited the Eiffel Tower in Paris last May.\n\n'
          'There she ran across a sign in German that read: "Dirk liebt den Eiffelturm"',
    height=128,
    max_chars=None,
    label_visibility="collapsed",
)


@st.cache_resource
def get_model(model_name):
    """Load the Flair ``SequenceTagger`` behind a menu entry.

    ``model_name`` is one of the display keys of ``model_map``; it is translated
    to the Flair model identifier before loading. ``st.cache_resource`` keeps
    the loaded tagger so later script reruns with the same entry reuse it.
    """
    flair_model_id = model_map[model_name]
    tagger = SequenceTagger.load(flair_model_id)
    return tagger


# @st.cache(allow_output_mutation=True)
# def get_frame_definitions():
#     frame_definition_map = {}
#     with open('propbank_frames_3.1.txt') as infile:
#         for line in infile:
#             frame_definition_map[line.split('\t')[0]] = line.split('\t')[1]
#
#     return frame_definition_map


def get_html(html: str):
    """Wrap rendered markup in a bordered, horizontally scrollable ``<div>``.

    Every newline in ``html`` is replaced by a single space first, so the
    wrapped markup is emitted as one line (presumably so ``st.write`` treats
    it as a single HTML chunk — confirm).
    """
    opening_tag = '<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">'
    flattened = " ".join(html.split("\n"))
    return opening_tag + flattened + "</div>"


def color_variant(hex_color, brightness_offset=1):
    """Return a lighter or darker variant of a ``#rrggbb`` color.

    Adapted from:
    https://chase-seibert.github.io/blog/2011/07/29/python-calculate-lighterdarker-rgb-colors.html

    Args:
        hex_color: Color in ``#rrggbb`` form, e.g. ``"#87c95f"``.
        brightness_offset: Amount added to each of the red, green and blue
            channels; positive values lighten, negative values darken. Each
            resulting channel is clamped to ``0..255``.

    Returns:
        The adjusted color as a lowercase ``#rrggbb`` string (always 7 chars).

    Raises:
        ValueError: if ``hex_color`` is not 7 characters long, or if its
            channel digits are not valid hexadecimal.
    """
    if len(hex_color) != 7:
        raise ValueError(
            f"Passed {hex_color} into color_variant(), needs to be in #87c95f format."
        )
    channels = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    clamped = (min(255, max(0, c + brightness_offset)) for c in channels)
    # Zero-pad each channel to two hex digits: hex(c)[2:] would drop the leading
    # zero for values below 16 and produce a malformed color string.
    return "#" + "".join(f"{c:02x}" for c in clamped)


def _frame_label_html(frame_label):
    """Render a frame label such as ``visit.01`` as a link to its PropBank page.

    Everything before the last ``.`` is taken as the lemma and everything after
    it as the sense id; together they form the PropBank v3.4.0 frame-file URL.
    Splitting on the *last* dot keeps any dots inside the lemma intact.
    """
    verb, _, sense_id = frame_label.rpartition(".")
    kb_url = f"https://propbank.github.io/v3.4.0/frames/{verb}.html#{verb}.{sense_id}"
    return f'<a style="text-decoration: underline; text-decoration-style: dotted; color: inherit; font-weight: bold" href="{kb_url}">{frame_label}</a>'


# Block 3: Output is displayed
button_clicked = st.button("**Click here** to tag the input text", key=None)

if button_clicked:
    # get a sentence splitter and split text into sentences
    splitter = SegtokSentenceSplitter()
    # TODO: perhaps truncate input_text
    sentences = splitter.split(input_text)

    # get the (cached) model and predict; the predictions are read back from
    # the sentences below via get_labels()
    model = get_model(selected_model_id)
    model.predict(sentences)

    # frame-model labels are rendered as PropBank links instead of plain text
    is_frame_model = "frame" in selected_model_id.lower()

    # displaCy "manual" input: entity spans are character offsets into input_text
    spacy_display = {"ents": [], "text": input_text, "title": None}

    predicted_labels = set()
    for sentence in sentences:
        for prediction in sentence.get_labels():
            label = prediction.value
            if is_frame_model:
                label = _frame_label_html(label)

            # shift the data point's positions by the sentence's start offset so
            # the span indexes into the full input_text
            spacy_display["ents"].append({
                "start": prediction.data_point.start_position + sentence.start_position,
                "end": prediction.data_point.end_position + sentence.start_position,
                "label": label,
            })
            predicted_labels.add(label)

    # one color per label: derive a base color from the label string via
    # ColorHash, then lighten every channel by 85
    colors = {
        label: color_variant(ColorHash(label).hex, brightness_offset=85)
        for label in predicted_labels
    }

    # use displacy to render
    html = spacy.displacy.render(
        spacy_display,
        style="ent",
        minify=True,
        manual=True,
        options={"colors": colors},
    )
    style = "<style>mark.entity { display: inline-block }</style>"
    st.subheader("Tagged text")
    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)