import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikipedia_search, google_search


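# Entity-type hints appended to the search query during candidate generation.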
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}

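# Highlight colors for each entity type in the rendered output.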
COLOR = {
    'LOC': '#40E0D0',
    'PER': '#6495ED',
    'ORG': '#CCCCFF',
    'MISC': '#FF7F50'
}

# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------


@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
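    """Load the BERT NER tagger and the GBRT entity-linking model (cached across reruns)."""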
    # NER
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=0, aggregation_strategy="average")  # device=0 expects a GPU
    # NED
    model = GBRT()
    return model, tagger


# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
st.write("## Named Entity Disambiguation")
col1, col2 = st.columns(2)


# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
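    """Collect candidate Wikipedia entities for each (mention, tag) pair.

    Merges Google and Wikipedia search results, deduplicates them, and caches
    the list so repeated mentions are only looked up once.
    """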
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        if (mention, tag) in cache:
            candidates.append((mention, cache[(mention, tag)]))
        else:
            res1 = google_search(mention + TYPE[tag])
            res2 = wikipedia_search(mention, limit=10)
            cands = list(set(res1 + res2))
            cache[(mention, tag)] = cands
            candidates.append((mention, cands))
    return candidates


# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
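    """Render a mention as a colored HTML tag linking to its predicted Wikipedia page."""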
    if label != 'NIL':
        label = "https://en.wikipedia.org/wiki/" + label
    return f"""
    <a style="margin: 0 5px; padding: 2px 4px; border-radius: 4px; text-decoration:none;
              background-color:{COLOR[typ]}; color: white; cursor:pointer" href="{label}" target="_blank">
        <span style="margin-right:3px">{text}</span>
        <span style="border-style:1px white solid; padding: 2px;">{typ}</span>
    </a>"""


# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def main(text):
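    """Run the full pipeline: NER tagging, candidate generation, disambiguation, and rendering."""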
    ner_results = tagger(text)
    tagged, last_pos = '', 0

    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates([(res['word'], res['entity_group']) for res in ner_results])

    with st.spinner('Disambiguating Mentions'):
        predictions = model.link(mentions_cands, text)

    with st.spinner('Rendering Results'):
        for i, res in enumerate(ner_results):
            tag = display_tag(res['word'], res['entity_group'], predictions[i][1])
            tagged += text[last_pos:res['start']] + tag
            last_pos = res['end']
        tagged += text[last_pos:]

    with col2:
        st.write("### Disambiguated Text")
        components.html(f'<p style="line-height: 1.8; margin-top:30px; font-family: sans-serif">{tagged}</p>',
                        scrolling=True, height=500)

    df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
    st.write("**Additional Information**")
    st.dataframe(df)


if __name__ == '__main__':
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
        if user_input:
            main(user_input)