# Streamlit demo app: Named Entity Disambiguation.
# NOTE(review): removed non-Python artifacts left over from a web-page copy
# (author "Atharva", commit message "initial commit", hash f74445c, and the
# "raw / history blame / 4.18 kB" viewer links) — they broke the file's syntax.
import html

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikipedia_search, google_search
# Suffix appended to a mention when querying Google, biasing the web search
# toward the predicted entity type. 'MISC' intentionally adds no hint.
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}
# Background color of the rendered HTML tag for each entity type.
COLOR = {
    'LOC': '#40E0D0',
    'PER': '#6495ED',
    'ORG': '#CCCCFF',
    'MISC': '#FF7F50'
}
# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------
@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Load and cache the NER tagger and the NED ranking model.

    Returns:
        tuple: ``(model, tagger)`` where ``model`` is the GBRT entity-linking
        model and ``tagger`` is a HuggingFace token-classification pipeline
        with sub-word predictions merged via ``aggregation_strategy="average"``.
    """
    # NER: BERT fine-tuned for CoNLL-style token classification.
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    # The original hard-coded device=0, which crashes on CPU-only hosts;
    # fall back to CPU (-1) when no CUDA device is available.
    device = 0 if torch.cuda.is_available() else -1
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=device, aggregation_strategy="average")
    # NED: gradient-boosted ranking model from the project package.
    model = GBRT()
    return model, tagger
# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
# Wide two-column layout: input text goes in col1, results render into col2.
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
st.write("## Named Entity Disambiguation")
col1, col2 = st.columns(2)
# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
    """Generate Wikipedia candidate entities for each (mention, tag) pair.

    Combines a type-hinted Google search with a Wikipedia title search,
    memoizing per ``(mention, tag)`` so a repeated mention hits the web
    services only once.

    Args:
        mentions_tags: Iterable of ``(mention, tag)`` tuples, where ``tag``
            is a key of ``TYPE`` ('LOC', 'PER', 'ORG', 'MISC').

    Returns:
        list: One ``(mention, candidate_titles)`` tuple per input pair, in
        input order.
    """
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        if (mention, tag) not in cache:
            res1 = google_search(mention + TYPE[tag])
            res2 = wikipedia_search(mention, limit=10)
            # Deduplicate while preserving search ranking; the original
            # list(set(...)) produced a nondeterministic candidate order.
            cache[(mention, tag)] = list(dict.fromkeys(res1 + res2))
        candidates.append((mention, cache[(mention, tag)]))
    return candidates
# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Render a tagged mention as an HTML anchor linking to its Wikipedia page.

    Args:
        text: Surface form of the mention as it appears in the input text.
        typ: Entity type key into ``COLOR`` ('LOC', 'PER', 'ORG', 'MISC').
        label: Predicted Wikipedia page title, or ``'NIL'`` when unlinked
            (a NIL label is used verbatim as the href).

    Returns:
        str: An HTML snippet — a colored pill containing the mention text
        and a small entity-type badge.
    """
    if label != 'NIL':
        label = "https://en.wikipedia.org/wiki/" + label
    # Escape the user-derived mention so markup characters cannot break the
    # page, and quote the href attribute (the original emitted href={label}
    # unquoted, which breaks for titles containing spaces or quotes).
    safe_text = html.escape(text)
    safe_href = html.escape(label, quote=True)
    # Fixed invalid CSS on the badge: 'border-style:1px white solid' is not a
    # valid border-style value; the intended shorthand is 'border'.
    return f"""
    <a style="margin: 0 5px; padding: 2px 4px; border-radius: 4px; text-decoration:none;
    background-color:{COLOR[typ]}; color: white; cursor:pointer" href="{safe_href}" target="_blank">
    <span style="margin-right:3px">{safe_text}</span>
    <span style="border: 1px solid white; padding: 2px;">{typ}</span>
    </a>"""
# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def main(text):
    """Run the full NER → candidate generation → disambiguation pipeline.

    Tags ``text`` with the cached NER pipeline, links each mention to a
    Wikipedia entity via the GBRT model, and renders the annotated HTML plus
    a results table into the right-hand column (``col2``).

    Args:
        text: Raw input text to annotate.
    """
    ner_results = tagger(text)
    tagged, last_pos = '', 0
    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates([(res['word'], res['entity_group']) for res in ner_results])
    with st.spinner('Disambiguating Mentions'):
        # Fixed local-variable typo: 'preditions' -> 'predictions'.
        # Each prediction is expected to be (mention, entity, confidence),
        # matching the DataFrame columns below.
        predictions = model.link(mentions_cands, text)
    with st.spinner('Rendering Results'):
        # Stitch the original text back together, replacing each mention span
        # [start, end) with its rendered, hyperlinked HTML tag.
        for i, res in enumerate(ner_results):
            tag = display_tag(res['word'], res['entity_group'], predictions[i][1])
            tagged += text[last_pos:res['start']] + tag
            last_pos = res['end']
        tagged += text[last_pos:]
    with col2:
        st.write("### Disambiguated Text")
        components.html(f'<p style="line-height: 1.8; margin-top:30px; font-family: sans-serif">{tagged}</p>',
                        scrolling=True, height=500)
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.dataframe(df)
if __name__ == '__main__':
    # Models are cached by @st.cache, so widget-triggered reruns of this
    # script do not reload them.
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        # text_area submits on Ctrl+Enter; ships with a demo sentence.
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
        if user_input:
            main(user_input)