File size: 3,796 Bytes
f09ca43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from transformers import pipeline
import streamlit as st
from spacy import displacy
from typing import List, Tuple
import json
import random


def ner_prediction(model, sentence):
    """Run a NER pipeline on a sentence and merge its token-level tags into entities.

    A token whose tag starts with "B" opens a new entity; every other token is
    appended to the entity that is currently open. A leading "Ġ" on a token's
    word is treated as a word boundary (presumably the byte-level BPE space
    marker of the model's tokenizer - confirm against the model used).

    Args:
        model: Callable NER pipeline. ``model(sentence)`` must return an iterable
            of dicts carrying the keys 'entity', 'word', 'start' and 'end'.
        sentence: Text to run the pipeline on.

    Returns:
        A list of ("LABEL", "TEXT", "START_IDX", "END_IDX") tuples, one per
        merged entity, in the order the pipeline emitted them.

    Raises:
        KeyError: If a token that opens an entity carries a tag whose type is
            not present in the label map.
    """
    entity_map = {
        "B-ORG": "ORG",
        "B-SEG": "SEGMENT",
        "B-SEGNUM": "NUM_SEGMENT"
    }
    predictions = []

    accumulate = ""
    current_class = None
    start = 0
    end = 0
    for item in model(sentence):
        tag = item['entity']
        word = item['word']
        begins_entity = tag.startswith("B")
        # A continuation tag seen before any entity has been opened (the model
        # emitted I-X without a preceding B-X) has nothing to attach to. Let it
        # open the entity itself, otherwise the result would carry label None
        # and a start offset of 0.
        orphan = current_class is None and not begins_entity

        if begins_entity or orphan:
            if len(accumulate) > 0:
                predictions.append((current_class, accumulate, start, end))
            accumulate = word.lstrip("Ġ")
            # Orphans are looked up under the B- form of their own tag type.
            label_key = "B-" + tag.split("-", 1)[-1] if orphan else tag
            current_class = entity_map[label_key]
            start = item['start']
            end = item['end']
        else:
            if word.startswith("Ġ"):
                # New word inside the same entity: separate it with a space.
                accumulate += " " + word.lstrip("Ġ")
            else:
                # Sub-word piece: glue it onto the previous piece.
                accumulate += word
            end = item['end']

    # Flush the entity that was still open when the tokens ran out.
    if len(accumulate) > 0:
        predictions.append((current_class, accumulate, start, end))

    return predictions


def generate_displacy_html(predictions: List[Tuple[str, str, int, int]], sentence) -> str:
    """Build displacy entity HTML for a sentence, ready to be shown in streamlit.

    Args:
        predictions: Entities in ("LABEL", "TEXT", "START_IDX", "END_IDX") format.
        sentence: The text the predictions were made on.

    Returns:
        The markup produced by displacy's manual 'ent' rendering.
    """
    # One background gradient per entity label.
    gradient_by_label = {
        "SEGMENT": "linear-gradient(90deg, #DBE575, #C3D32C)",
        "NUM_SEGMENT": "linear-gradient(90deg, #3AD8E8, #1AA7B6)",
        "ORG": "linear-gradient(90deg, #aa9cfc, #fc9ce7)",
    }
    render_options = {
        "ents": ["SEGMENT", "NUM_SEGMENT", "ORG"],
        "colors": gradient_by_label,
    }

    # Only label and offsets are passed on; the entity text (index 1) is unused.
    spans = []
    for entry in predictions:
        spans.append({'start': entry[2], 'end': entry[3], 'label': entry[0]})

    document = {
        'text': sentence,
        'ents': spans,
        'title': "Name entity recognition",
    }

    return displacy.render(
        [document], style='ent', manual=True, options=render_options)


# loading in the model in cache
@st.cache(allow_output_mutation=True)
def load_model_and_data():
    """Load the sample articles and the NER pipeline.

    The function is wrapped in ``st.cache`` so that streamlit can reuse the
    result instead of reloading it on every script run.

    Returns:
        A ``(sample_text, model)`` tuple: the parsed content of
        ``sample_articles.json`` and the 'ner' pipeline built from
        ``wolfrage89/company_segment_ner``.
    """
    # JSON text is UTF-8 by specification; pin the encoding so non-ASCII
    # article text does not depend on the platform's default locale encoding.
    with open("sample_articles.json", "r", encoding="utf-8") as json_file:
        sample_text = json.load(json_file)

    # Build the token-classification pipeline from the model identifier.
    model_path = "wolfrage89/company_segment_ner"
    model = pipeline('ner', model_path)

    return sample_text, model


# Fetch the sample articles and the NER pipeline.
sample_texts, model = load_model_and_data()

# Make sure both session-state slots exist before anything reads them.
for state_key in ("article_text", "displacy_html"):
    if state_key not in st.session_state:
        st.session_state[state_key] = ""


# Sidebar: title, random-article button and predict button.
st.sidebar.title("Welcome To Company Segment Name Entity Recognition App")

random_button = st.sidebar.button("RANDOM")
st.sidebar.write("Randomly generates an article for testing")
st.sidebar.markdown("---")
predict_button = st.sidebar.button("PREDICT!")

# RANDOM swaps in a sample article and drops any previous rendering.
if random_button:
    st.session_state["article_text"] = random.choice(sample_texts)
    st.session_state["displacy_html"] = ""

# PREDICT renders the entities of the stored article, or clears the output
# when there is no article text.
if predict_button:
    article = st.session_state["article_text"]
    if len(article) > 0:
        rendered = generate_displacy_html(
            ner_prediction(model, article), article)
    else:
        rendered = ""
    st.session_state["displacy_html"] = rendered

# The text area is seeded from session state and its value is stored back.
article_box = st.text_area(
    label="Insert article here", value=st.session_state["article_text"], height=200)
st.session_state["article_text"] = article_box

st.markdown(st.session_state["displacy_html"], unsafe_allow_html=True)