# Streamlit demo app: named-entity recognition + disambiguation.
# (Removed non-Python scrape residue: "Spaces: / Runtime error / Runtime error".)
import pandas as pd | |
import streamlit as st | |
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline | |
from src import GBRT, wikidata_search, google_search | |
# Human-readable suffix appended when describing an entity of a given NER
# tag; MISC intentionally gets no suffix.
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}

# Highlight colour per NER tag, used for the rendered HTML tag spans.
COLOR = {
    'LOC': '#10B981',   # green
    'PER': '#0EA5E9',   # blue
    'ORG': '#A855F7',   # purple
    'MISC': '#F97316'   # orange
}
# --------------------------------------------------------------------------- | |
# Loading models | |
# --------------------------------------------------------------------------- | |
def load_models():
    """Load the NER tagger and the NED ranking model.

    Returns:
        tuple: ``(model, tagger)`` where ``model`` is the project's GBRT
        entity-disambiguation ranker and ``tagger`` is a Hugging Face
        token-classification pipeline running on CPU (``device=-1``) that
        merges sub-word pieces with the "average" aggregation strategy.
    """
    # NER: BERT fine-tuned for NER (dslim/bert-base-NER).
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=-1, aggregation_strategy="average")
    # NED: gradient-boosted ranking model (project-local).
    model = GBRT()
    # NOTE(review): consider decorating with @st.cache_resource so Streamlit
    # reruns do not reload both models on every interaction — confirm the
    # deployed streamlit version supports it before adding.
    return model, tagger
# --------------------------------------------------------------------------- | |
# Page setup | |
# --------------------------------------------------------------------------- | |
# Page setup: wide two-column layout — input text on the left (col1),
# disambiguated results on the right (col2).
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
col1, col2 = st.columns(2)
# --------------------------------------------------------------------------- | |
# Candidate Generation | |
# --------------------------------------------------------------------------- | |
def get_candidates(mentions_tags):
    """Generate candidate entities for each detected mention.

    Args:
        mentions_tags: list of ``(mention, tag)`` pairs as produced by the
            NER pipeline, where ``tag`` is one of LOC/PER/ORG/MISC.

    Returns:
        list: one ``(mention, candidates)`` pair per input pair, in input
        order. Repeated ``(mention, tag)`` pairs reuse cached results so
        each unique pair triggers at most one search.
    """
    candidates = []
    cache = {}  # (mention, tag) -> candidate list; avoids duplicate lookups
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            cands = wikidata_search(mention, limit=3)
            # Fall back to Google search only for non-person mentions —
            # presumably Wikidata's person coverage is considered sufficient.
            if not cands and tag != 'PER':
                cands = google_search(mention, limit=3)
            cache[key] = cands
        candidates.append((mention, cache[key]))
    return candidates
# --------------------------------------------------------------------------- | |
# Rendering Setup | |
# --------------------------------------------------------------------------- | |
def display_tag(text, typ, label):
    """Render one mention as an inline HTML "tag" span.

    Args:
        text: the mention's surface text.
        typ: NER tag (LOC/PER/ORG/MISC) — selects the span colour.
        label: predicted Wikipedia page title, or ``'NIL'`` when the
            disambiguator could not link the mention.

    Returns:
        str: markdown/HTML — a coloured span wrapped in a link to the
        English Wikipedia page when linked, or a red unlinked span for NIL.
    """
    if label != 'NIL':
        return f'[<span style="color: {COLOR[typ]}" class="tag">**{text}** <small>{typ}</small></span>](https://en.wikipedia.org/wiki/{label})'
    # Unlinked mention: red, no hyperlink.
    return f'<span style="color: #EF4444" class="tag">**{text}** <small>{typ}</small></span>'
# --------------------------------------------------------------------------- | |
# Full Pipeline | |
# --------------------------------------------------------------------------- | |
def main(text):
    """Run the full pipeline on *text* and render results into ``col2``.

    Pipeline: NER tagging -> candidate generation -> NED linking ->
    re-render the input with each mention replaced by a coloured,
    optionally hyperlinked HTML span, plus a summary table.
    """
    ner_results = tagger(text)
    tagged, last_pos = '', 0
    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates(
            [(res['word'], res['entity_group']) for res in ner_results])
    with st.spinner('Disambiguating Mentions'):
        # predictions: one entry per mention; index [1] is the predicted
        # label ('NIL' when unlinked) — presumably (mention, label,
        # confidence) given the table columns below; confirm GBRT.link.
        predictions = model.link(mentions_cands, text)
    with st.spinner('Rendering Results'):
        # Stitch the original text back together, replacing each mention
        # span [start, end) with its rendered HTML tag.
        for i, res in enumerate(ner_results):
            tag = display_tag(res['word'], res['entity_group'], predictions[i][1])
            tagged += text[last_pos:res['start']] + tag
            last_pos = res['end']
        tagged += text[last_pos:]
    with col2:
        st.write("### Disambiguated Text")
        st.markdown("<style>a {text-decoration: none;} small {color: inherit !important} .tag {border: 1px solid; border-radius: 5px; padding: 0px 3px; white-space: nowrap}</style>", unsafe_allow_html=True)
        st.write(tagged, unsafe_allow_html=True)
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        # Hide the index column of the rendered table.
        st.markdown("<style>tbody th {display:none} .blank {display:none}</style>", unsafe_allow_html=True)
        st.table(df)
if __name__ == '__main__':
    # Load models once at startup, then draw the input widget; Streamlit
    # reruns this script on every text change (Ctrl+Enter submits).
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
    if user_input:
        main(user_input)