File size: 3,704 Bytes
fcb2250
 
 
 
 
 
 
 
 
 
 
d7e89e5
 
 
fcb2250
 
 
3fc071b
fcb2250
 
 
 
 
 
 
 
 
 
 
 
 
 
b3c2743
fcb2250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3c2743
fcb2250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import spacy.displacy
import streamlit as st
from flair.models import SequenceTagger
from flair.splitter import SegtokSentenceSplitter
from colorhash import ColorHash

# Page title is currently disabled:
# st.title("Flair NER Demo")
# Use Streamlit's "centered" page layout.
st.set_page_config(layout="centered")

# Models to choose from: maps the task description shown in the select box
# (key) to the model identifier handed to SequenceTagger.load in get_model
# (value).
model_map = {
    "find Entities (default)": "ner-large",
    "find Entities (18-class)": "ner-ontonotes-large",
    "find Parts-of-Speech": "pos-multi",
}

# Block 1: Users can select a model
st.subheader("Select a model")
# The options are the keys of model_map, so selected_model_id holds the
# task description, not the model identifier itself. The widget's own label
# is passed with label_visibility="collapsed"; the subheader above serves as
# the heading.
selected_model_id = st.selectbox("This is a check box",
                                 model_map.keys(),
                                 label_visibility="collapsed",
                                 )

# Block 2: Users can input text
st.subheader("Input your text here")
# Text area pre-filled with an example sentence; max_chars=None places no
# limit on the input length. As above, the widget label is collapsed.
input_text = st.text_area('Write or Paste Text Below',
                          value="George was born in Washington.",
                          height=128,
                          max_chars=None,
                          label_visibility="collapsed")


# NOTE(review): newer Streamlit releases appear to deprecate st.cache in
# favour of st.cache_resource -- confirm against the pinned Streamlit version.
@st.cache(allow_output_mutation=True)
def get_model(model_name):
    """Load the Flair tagger that belongs to a select-box entry.

    Args:
        model_name: A key of ``model_map`` (the task description chosen by
            the user).

    Returns:
        The object returned by ``SequenceTagger.load`` for the mapped model
        identifier.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_map``.
    """
    flair_identifier = model_map[model_name]
    tagger = SequenceTagger.load(flair_identifier)
    return tagger


def get_html(html: str):
    """Wrap markup in a bordered, horizontally scrollable ``<div>``.

    Every newline in ``html`` is replaced by a single space before the
    markup is placed inside the wrapper.

    Args:
        html: Markup to embed.

    Returns:
        The wrapper ``<div>`` containing the flattened markup.
    """
    template = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    # Splitting on "\n" and re-joining with spaces swaps each newline for
    # exactly one space.
    flattened = " ".join(html.split("\n"))
    return template.format(flattened)


def color_variant(hex_color, brightness_offset=1):
    """Return a lighter or darker variant of a ``#rrggbb`` color.

    Adapted from:
    https://chase-seibert.github.io/blog/2011/07/29/python-calculate-lighterdarker-rgb-colors.html

    Args:
        hex_color: Color string of exactly seven characters, e.g. ``#87c95f``.
        brightness_offset: Integer added to every RGB channel; positive
            values lighten the color, negative values darken it.

    Returns:
        A seven-character ``#rrggbb`` string; each channel is clamped to the
        range 0..255.

    Raises:
        Exception: If ``hex_color`` is not exactly seven characters long.
        ValueError: If a channel is not valid hexadecimal.
    """
    if len(hex_color) != 7:
        raise Exception("Passed %s into color_variant(), needs to be in #87c95f format." % hex_color)
    shifted = [int(hex_color[start:start + 2], 16) + brightness_offset for start in (1, 3, 5)]
    # make sure new values are between 0 and 255
    clamped = [min(255, max(0, channel)) for channel in shifted]
    # Zero-pad each channel to two hex digits: plain hex() would turn values
    # below 16 into a single digit and yield a malformed color such as "#555".
    return "#" + "".join(f"{channel:02x}" for channel in clamped)


# Block 3: Output is displayed
button_clicked = st.button("**Click here** to tag the input text", key=None)

if button_clicked:

    # split the input into sentences
    sentences = SegtokSentenceSplitter().split(input_text)

    # load the selected tagger and annotate the sentences
    tagger = get_model(selected_model_id)
    tagger.predict(sentences)

    # One displacy span per predicted label. The sentence's start position is
    # added to the span's own start/end positions so that the offsets refer
    # to the full input text passed as "text" below.
    entity_spans = [
        {
            "start": prediction.data_point.start_position + sentence.start_position,
            "end": prediction.data_point.end_position + sentence.start_position,
            "label": prediction.value,
        }
        for sentence in sentences
        for prediction in sentence.get_labels()
    ]
    spacy_display = {"ents": entity_spans, "text": input_text, "title": None}

    # one color per distinct label, derived from a hash of the label text
    predicted_labels = {span["label"] for span in entity_spans}
    colors = {
        label: color_variant(ColorHash(label).hex, brightness_offset=85)
        for label in predicted_labels
    }

    # render the pre-computed spans with displacy (manual mode)
    render_options = {"colors": colors}
    html = spacy.displacy.render(
        spacy_display,
        style="ent",
        minify=True,
        manual=True,
        options=render_options,
    )

    # prepend a CSS rule for the entity marks and show the wrapped markup
    style = "<style>mark.entity { display: inline-block }</style>"
    st.subheader("Found entities")
    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)