File size: 4,170 Bytes
f74445c
 
 
 
5230da2
f74445c
 
 
 
 
 
 
 
 
d347304
 
 
 
f74445c
 
 
 
 
 
 
 
 
 
fe0139a
 
f74445c
f28ff95
f74445c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f441aa5
5230da2
f74445c
 
 
 
 
 
 
 
 
 
 
8a5068f
d347304
cebe0ce
f74445c
 
 
 
 
 
d347304
f74445c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7711cb9
cebe0ce
f74445c
 
 
d347304
 
f74445c
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import pandas as pd
import streamlit as st
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, google_search, wikidata_search

# Query suffix appended to a mention when searching for candidate entities.
# The leading space separates it from the mention text; MISC adds nothing
# because there is no useful type hint for miscellaneous entities.
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}

# Tag color (hex) used when rendering a recognized mention in the output HTML.
COLOR = {
    'LOC': '#10B981',
    'PER': '#0EA5E9',
    'ORG': '#A855F7',
    'MISC': '#F97316'
}

# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------


@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Load and cache the two pipeline stages.

    Returns:
        tuple: ``(ned_model, ner_tagger)`` — the GBRT disambiguation model
        and a Hugging Face token-classification pipeline (CPU, averaged
        sub-word aggregation) built on ``dslim/bert-base-NER``.
    """
    # NER stage: BERT fine-tuned for named-entity recognition.
    ner_checkpoint = "dslim/bert-base-NER"
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_checkpoint)
    ner_model = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
    ner_tagger = pipeline(
        "token-classification",
        model=ner_model,
        tokenizer=ner_tokenizer,
        device=-1,
        aggregation_strategy="average",
    )
    # NED stage: gradient-boosted ranking model for entity linking.
    ned_model = GBRT()
    return ned_model, ner_tagger


# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
# Wide layout with two side-by-side columns: col1 takes user input,
# col2 shows the disambiguated output (see main()).
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
col1, col2 = st.columns(2)


# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
    """Generate candidate entities for each (mention, tag) pair.

    Args:
        mentions_tags: iterable of ``(mention, tag)`` tuples, where ``tag``
            is one of the keys of ``TYPE`` (LOC/PER/ORG/MISC).

    Returns:
        list: one ``(mention, candidates)`` pair per input item, in input
        order; ``candidates`` is a de-duplicated list of entity titles.
    """
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            # Combine web search (type-hinted query) with Wikidata search.
            web_results = google_search(mention + TYPE[tag], limit=3)
            wiki_results = wikidata_search(mention, limit=3)
            # dict.fromkeys de-duplicates while preserving result order,
            # unlike set(), which shuffled candidates between runs.
            cache[key] = list(dict.fromkeys(web_results + wiki_results))
        candidates.append((mention, cache[key]))
    return candidates


# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Render a mention as an inline HTML tag for st.markdown.

    Linked mentions (``label != 'NIL'``) are wrapped in a Markdown link to
    the English Wikipedia article and colored by entity type; unlinkable
    mentions ('NIL') are rendered in red with no link.
    """
    if label == 'NIL':
        # No entity found: red tag, no hyperlink.
        return f'<span style="color: #EF4444" class="tag">**{text}** <small>{typ}</small></span>'
    wiki_url = f'https://en.wikipedia.org/wiki/{label}'
    markup = f'<span style="color: {COLOR[typ]}" class="tag">**{text}** <small>{typ}</small></span>'
    return f'[{markup}]({wiki_url})'


# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def main(text):
    """Run the full pipeline on *text* and render the results in the page.

    Steps: NER tagging -> candidate generation -> disambiguation (linking)
    -> HTML rendering into ``col2`` plus a summary table below.

    Relies on the module-level ``tagger`` and ``model`` created in the
    ``__main__`` guard, and on the Streamlit columns ``col1``/``col2``.
    """
    ner_results = tagger(text)
    tagged, last_pos = '', 0

    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates([(res['word'], res['entity_group']) for res in ner_results])

    with st.spinner('Disambiguating Mentions'):
        predictions = model.link(mentions_cands, text)

    with st.spinner('Rendering Results'):
        # Stitch the original text back together, replacing each mention
        # span (res['start']..res['end']) with its rendered tag.
        for i, res in enumerate(ner_results):
            tag = display_tag(res['word'], res['entity_group'], predictions[i][1])
            tagged += text[last_pos:res['start']] + tag
            last_pos = res['end']
        tagged += text[last_pos:]

    with col2:
        st.write("### Disambiguated Text")
        st.markdown("<style>a {text-decoration: none;} small {color: inherit !important} .tag {border: 1px solid; border-radius: 5px; padding: 0px 3px; white-space: nowrap}</style>", unsafe_allow_html=True)
        st.write(tagged, unsafe_allow_html=True)

    # Tabular view of the predictions; CSS hides the index column.
    df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
    st.write("**Additional Information**")
    st.markdown("<style>tbody th {display:none} .blank {display:none}</style>", unsafe_allow_html=True)
    st.table(df)


if __name__ == '__main__': 
    # Load (or fetch from Streamlit's cache) the NED model and NER tagger;
    # both are module-level globals read by main().
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        # Streamlit text_area re-runs the script on Ctrl+Enter.
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
        if user_input:
            main(user_input)