Spaces:
Runtime error
Runtime error
import pandas as pd | |
import streamlit as st | |
import streamlit.components.v1 as components | |
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline | |
from src import GBRT, wikipedia_search, google_search | |
# Suffix appended to a mention when building the web-search query, keyed by
# NER entity group. 'MISC' contributes no extra search term.
TYPE = dict(
    LOC=' location',
    PER=' person',
    ORG=' organization',
    MISC='',
)
# Background colour (CSS hex) of the rendered tag, keyed by NER entity group.
COLOR = dict(
    LOC='#40E0D0',
    PER='#6495ED',
    ORG='#CCCCFF',
    MISC='#FF7F50',
)
# --------------------------------------------------------------------------- | |
# Loading models | |
# --------------------------------------------------------------------------- | |
def load_models():
    """Build the NER tagger and the entity-disambiguation model.

    Returns:
        A ``(model, tagger)`` tuple. ``model`` is a ``GBRT`` instance and
        ``tagger`` is a Hugging Face ``token-classification`` pipeline built
        on the ``dslim/bert-base-NER`` checkpoint with
        ``aggregation_strategy="average"``.
    """
    # Imported locally so this function does not rely on a file-level import.
    import torch

    # NER
    checkpoint = "dslim/bert-base-NER"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    bert_ner = AutoModelForTokenClassification.from_pretrained(checkpoint)
    # A hard-coded device=0 requires a CUDA GPU and fails on CPU-only hosts.
    # Pipeline device indices: 0 is the first GPU, -1 is the CPU.
    device = 0 if torch.cuda.is_available() else -1
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=device, aggregation_strategy="average")
    # NED
    model = GBRT()
    return model, tagger
# --------------------------------------------------------------------------- | |
# Page setup | |
# --------------------------------------------------------------------------- | |
# Page-wide configuration and heading.
st.set_page_config(page_title='Named Entity Disambiguation', layout="wide")
st.write("## Named Entity Disambiguation")
# Two side-by-side columns: col1 receives the input widget, col2 the results.
col1, col2 = st.columns(2)
# --------------------------------------------------------------------------- | |
# Candidate Generation | |
# --------------------------------------------------------------------------- | |
def get_candidates(mentions_tags):
    """Collect candidate entities for every ``(mention, tag)`` pair.

    Each distinct pair is searched once: the Google query is the mention plus
    the ``TYPE`` suffix for its tag, and the Wikipedia query is the bare
    mention limited to 10 results. The two result lists are merged and
    de-duplicated.

    Args:
        mentions_tags: Iterable of ``(mention, tag)`` pairs. ``tag`` must be
            a key of ``TYPE``.

    Returns:
        A list of ``(mention, candidates)`` tuples, one per input pair and in
        input order. Repeated pairs reuse the same candidate list.

    Raises:
        KeyError: If a tag is not a key of ``TYPE``.
    """
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            res1 = google_search(mention + TYPE[tag])
            res2 = wikipedia_search(mention, limit=10)
            # dict.fromkeys drops duplicates while keeping first-seen order.
            # A set would make the candidate order vary between runs because
            # string hashing is randomised per process.
            cache[key] = list(dict.fromkeys(res1 + res2))
        candidates.append((mention, cache[key]))
    return candidates
# --------------------------------------------------------------------------- | |
# Rendering Setup | |
# --------------------------------------------------------------------------- | |
def display_tag(text, typ, label):
    """Render one mention as a coloured HTML tag, linked to Wikipedia if resolved.

    Args:
        text: Mention text shown inside the tag; HTML-escaped before output.
        typ: Entity group; must be a key of ``COLOR``.
        label: Predicted label, appended to the English Wikipedia URL prefix
            to form the link target, or ``'NIL'`` when there is no entity.

    Returns:
        An HTML string containing a single ``<a>`` element. The element has
        no ``href`` when ``label`` is ``'NIL'``.

    Raises:
        KeyError: If ``typ`` is not a key of ``COLOR``.
    """
    # Imported locally so this function does not rely on a file-level import.
    import html

    if label != 'NIL':
        url = "https://en.wikipedia.org/wiki/" + label
        # The attribute must be quoted and escaped: labels may contain spaces
        # or quotes, which would otherwise truncate or break the href.
        link_attrs = f' href="{html.escape(url, quote=True)}" target="_blank"'
    else:
        # An unresolved mention previously linked to the relative URL "NIL".
        link_attrs = ''
    return f"""
    <a style="margin: 0 5px; padding: 2px 4px; border-radius: 4px; text-decoration:none;
    background-color:{COLOR[typ]}; color: white; cursor:pointer"{link_attrs}>
    <span style="margin-right:3px">{html.escape(text)}</span>
    <span style="border: 1px solid white; padding: 2px;">{typ}</span>
    </a>"""
# --------------------------------------------------------------------------- | |
# Full Pipeline | |
# --------------------------------------------------------------------------- | |
def main(text):
    """Run the full pipeline on *text* and render the results in ``col2``.

    Steps: tag entities with the module-level ``tagger``, generate candidates
    for each mention, link them with the module-level ``model``, then rebuild
    the text as HTML with every mention replaced by its rendered tag.

    Args:
        text: Raw input text to analyse.
    """
    # Imported locally so this function does not rely on a file-level import.
    import html

    ner_results = tagger(text)
    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates(
            [(res['word'], res['entity_group']) for res in ner_results])
    with st.spinner('Disambiguating Mentions'):
        predictions = model.link(mentions_cands, text)
    with st.spinner('Rendering Results'):
        pieces, last_pos = [], 0
        for i, res in enumerate(ner_results):
            # Plain text between entities is user input, so it is escaped
            # before being spliced into the HTML document.
            pieces.append(html.escape(text[last_pos:res['start']], quote=False))
            pieces.append(display_tag(res['word'], res['entity_group'], predictions[i][1]))
            last_pos = res['end']
        pieces.append(html.escape(text[last_pos:], quote=False))
        tagged = ''.join(pieces)
    with col2:
        st.write("### Disambiguated Text")
        components.html(f'<p style="line-height: 1.8; margin-top:30px; font-family: sans-serif">{tagged}</p>',
                        scrolling=True, height=500)
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.dataframe(df)
if __name__ == '__main__':
    # Models are bound at module level because main() reads them as globals.
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area(
            label='Press Ctrl + Enter to update results',
            value='George Washington went to Washington.',
            height=350,
        )
    # NOTE(review): original indentation was lost; confirm whether this call
    # belongs inside the `with col1:` block.
    if user_input:
        main(user_input)