Spaces:
Runtime error
Runtime error
File size: 4,183 Bytes
f74445c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import html
from urllib.parse import quote

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikipedia_search, google_search
# Suffix appended to a mention before the web search, hinting at the entity
# type (e.g. "Washington" -> "Washington location"). MISC adds no hint.
TYPE = dict(
    LOC=' location',
    PER=' person',
    ORG=' organization',
    MISC='',
)
# Badge background colour used when rendering each entity group.
COLOR = dict(
    LOC='#40E0D0',
    PER='#6495ED',
    ORG='#CCCCFF',
    MISC='#FF7F50',
)
# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour of
# st.cache_resource; kept here for compatibility with the pinned version.
@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Build the NER tagger and the entity-disambiguation model.

    Wrapped in ``st.cache`` so Streamlit reuses the loaded models across
    script reruns instead of reloading them on every interaction.

    Returns:
        ``(model, tagger)`` where ``model`` is a ``GBRT`` linker and
        ``tagger`` is a Hugging Face token-classification pipeline over
        ``dslim/bert-base-NER``.
    """
    # NER
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    # Use the first GPU when available, otherwise the CPU (-1). A hard-coded
    # device=0 makes pipeline construction fail on CPU-only hosts.
    device = 0 if torch.cuda.is_available() else -1
    # "average" aggregation merges word pieces into whole-entity spans that
    # carry 'word', 'entity_group', 'start' and 'end' (consumed by main()).
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=device, aggregation_strategy="average")
    # NED
    model = GBRT()
    return model, tagger
# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
# Streamlit documents set_page_config as needing to run before other st.*
# rendering commands, so keep it as the first call in the script.
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
st.write("## Named Entity Disambiguation")
# Two side-by-side columns: col1 hosts the input text area (filled under the
# __main__ guard) and col2 receives the rendered output written by main().
col1, col2 = st.columns(2)
# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
    """Collect candidate entities for every detected mention.

    Args:
        mentions_tags: iterable of ``(mention, tag)`` pairs, where ``tag`` is an
            NER entity group (a key of ``TYPE``, e.g. ``'LOC'``).

    Returns:
        A list of ``(mention, candidates)`` pairs in input order. ``candidates``
        merges the Google and Wikipedia search results with duplicates removed,
        keeping first-seen order (Google results first).
    """
    candidates = []
    # Repeated (mention, tag) pairs within one text reuse the earlier lookup
    # instead of querying the search backends again.
    cache = {}
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            # Bias the web search toward the entity type; an unknown tag adds
            # no hint instead of raising KeyError.
            google_results = google_search(mention + TYPE.get(tag, '')) or []
            wiki_results = wikipedia_search(mention, limit=10) or []
            # dict.fromkeys de-duplicates while preserving order, so the
            # candidate order is reproducible (set() iteration order for str
            # varies between processes because of hash randomisation).
            cache[key] = list(dict.fromkeys(list(google_results) + list(wiki_results)))
        candidates.append((mention, cache[key]))
    return candidates
# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Return an HTML badge for one tagged mention.

    Args:
        text: surface text of the mention, shown inside the badge.
        typ: NER entity group; selects the background colour from ``COLOR``.
        label: predicted English Wikipedia page title, or the sentinel
            ``'NIL'`` (presumably: no entity was linked).

    Returns:
        An ``<a>`` element linking to ``https://en.wikipedia.org/wiki/<label>``
        in a new tab; for ``'NIL'`` the badge is rendered without a link.
    """
    if label != 'NIL':
        # Wikipedia titles use underscores for spaces; percent-encode the rest
        # so titles containing spaces, '&', quotes, '?' etc. form one valid URL.
        url = "https://en.wikipedia.org/wiki/" + quote(label.replace(' ', '_'))
        link_attrs = f'href="{html.escape(url)}" target="_blank"'
        cursor = 'pointer'
    else:
        # Nothing to link to (the previous code emitted a broken href=NIL).
        link_attrs = ''
        cursor = 'default'
    # Escape user-derived text so it renders literally instead of as markup.
    return f"""
    <a style="margin: 0 5px; padding: 2px 4px; border-radius: 4px; text-decoration:none;
    background-color:{COLOR[typ]}; color: white; cursor:{cursor}" {link_attrs}>
    <span style="margin-right:3px">{html.escape(text)}</span>
    <span style="border: 1px solid white; padding: 2px;">{html.escape(typ)}</span>
    </a>"""
# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def _render_tagged(text, ner_results, predictions):
    """Return ``text`` as HTML with each NER span replaced by its linked badge."""
    pieces, last_pos = [], 0
    for i, res in enumerate(ner_results):
        start, end = res['start'], res['end']
        # Plain text between entities is escaped so user input cannot inject markup.
        pieces.append(html.escape(text[last_pos:start]))
        # Show the user's own characters for the mention; the tokenizer's
        # reconstructed res['word'] can differ (e.g. spacing around punctuation).
        pieces.append(display_tag(text[start:end], res['entity_group'], predictions[i][1]))
        last_pos = end
    pieces.append(html.escape(text[last_pos:]))
    return ''.join(pieces)


def main(text):
    """Run NER, candidate generation and entity linking on ``text``; render it.

    Reads the module-level ``tagger`` and ``model`` created under the
    ``__main__`` guard and writes its output into the right column ``col2``.

    Args:
        text: raw user input from the text area.
    """
    ner_results = tagger(text)
    if ner_results:
        with st.spinner('Generating Candidates'):
            mentions_cands = get_candidates(
                [(res['word'], res['entity_group']) for res in ner_results])
        with st.spinner('Disambiguating Mentions'):
            # Expected: one (mention, prediction, confidence) row per NER result,
            # in the same order (see the DataFrame columns below).
            predictions = model.link(mentions_cands, text)
    else:
        # No entities: skip the search backends and the linker entirely.
        predictions = []
    with st.spinner('Rendering Results'):
        tagged = _render_tagged(text, ner_results, predictions)
    with col2:
        st.write("### Disambiguated Text")
        components.html(f'<p style="line-height: 1.8; margin-top:30px; font-family: sans-serif">{tagged}</p>',
                        scrolling=True, height=500)
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.dataframe(df)
if __name__ == '__main__':
    # load_models is cached by st.cache; `model` and `tagger` become module
    # globals that main() reads.
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
    # Whitespace-only input has nothing to tag; don't run the pipeline for it.
    if user_input and user_input.strip():
        main(user_input)
|