Atharva
pipeline update
3befa47
import html
import logging
from urllib.parse import quote

import pandas as pd
import streamlit as st
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikidata_search, google_search
# Descriptive suffix per NER entity group. Note the leading space on each
# non-empty value; MISC maps to an empty string. (Not referenced elsewhere
# in the visible code.)
TYPE = dict(
    LOC=' location',
    PER=' person',
    ORG=' organization',
    MISC='',
)

# Hex display colour per NER entity group, used when rendering linked tags.
COLOR = dict(
    LOC='#10B981',
    PER='#0EA5E9',
    ORG='#A855F7',
    MISC='#F97316',
)
# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------
@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Build the disambiguation model and the NER tagger.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded objects
    instead of reloading the HuggingFace weights each time.

    Returns:
        ``(model, tagger)``: a ``GBRT`` instance for entity linking and a
        HuggingFace token-classification pipeline for NER.
    """
    ner_checkpoint = "dslim/bert-base-NER"

    # --- NER: tokenizer first, then the token-classification weights ---
    ner_tokenizer = AutoTokenizer.from_pretrained(ner_checkpoint)
    ner_weights = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
    ner_tagger = pipeline(
        "token-classification",
        model=ner_weights,
        tokenizer=ner_tokenizer,
        device=-1,  # -1 selects CPU in HuggingFace pipelines
        aggregation_strategy="average",
    )

    # --- NED ---
    linker = GBRT()
    return linker, ner_tagger
# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
# Wide layout so the input and output columns sit side by side.
# NOTE: Streamlit requires set_page_config to run before other st.* calls.
st.set_page_config(page_title='Named Entity Disambiguation', layout="wide")
# col1 holds the input text area; col2 receives the rendered results.
col1, col2 = st.columns(2)
# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
    """Look up knowledge-base candidates for each detected mention.

    Wikidata is queried first (``limit=3``). If it returns nothing and the
    mention is not tagged ``'PER'``, Google search (``limit=3``) is tried
    instead. Repeated ``(mention, tag)`` pairs within one call are answered
    from a local cache, so each distinct pair is searched only once.

    Args:
        mentions_tags: iterable of ``(mention, tag)`` pairs, where ``tag`` is
            the NER entity group of the mention.

    Returns:
        A list of ``(mention, candidates)`` pairs in input order, where
        ``candidates`` is whatever the search helper returned.
    """
    logger = logging.getLogger(__name__)
    cache = {}
    candidates = []
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            cands = wikidata_search(mention, limit=3)
            # `not cands` also treats a None/empty-tuple result as "no hits",
            # not only a literal empty list.
            if not cands and tag != 'PER':
                cands = google_search(mention, limit=3)
            cache[key] = cands
            # Debug trace instead of an unconditional print to stdout.
            logger.debug("candidates for %r: %r", mention, cands)
        candidates.append((mention, cache[key]))
    return candidates
# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Render one entity mention as an inline HTML/markdown tag.

    Args:
        text: surface text of the mention to display.
        typ: NER entity group. It is shown as a small type label and, for
            linked mentions, used as the key into ``COLOR``.
        label: predicted Wikipedia title, or ``'NIL'`` when the mention was
            not linked.

    Returns:
        For ``'NIL'``, a red, unlinked span. Otherwise, a coloured span
        wrapped in a markdown link to the English Wikipedia page for *label*.
    """
    # The caller renders this with unsafe_allow_html=True, so escape the
    # mention text to stop it being interpreted as markup.
    safe_text = html.escape(text, quote=False)
    if label == 'NIL':
        return f'<span style="color: #EF4444" class="tag">**{safe_text}** <small>{typ}</small></span>'
    # Normalise spaces to underscores (Wikipedia's URL form) and percent-
    # encode the title. Raw spaces, '?', '#' or parentheses would otherwise
    # break or truncate the markdown link target.
    url = 'https://en.wikipedia.org/wiki/' + quote(str(label).replace(' ', '_'))
    return f'[<span style="color: {COLOR[typ]}" class="tag">**{safe_text}** <small>{typ}</small></span>]({url})'
# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def main(text):
    """Run NER, candidate generation and disambiguation on *text*, then render.

    Uses the module-level ``tagger`` and ``model`` (assigned from
    ``load_models()`` under ``__main__``) and writes its output into ``col2``.

    Args:
        text: raw user input from the text area.
    """
    ner_results = tagger(text)
    # Use the exact input span instead of ``res['word']``. The pipeline's
    # ``word`` is decoded from tokens and can differ from the original text
    # (e.g. extra spaces around punctuation).
    mentions = [text[res['start']:res['end']] for res in ner_results]

    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates(
            [(mention, res['entity_group']) for mention, res in zip(mentions, ner_results)])
    with st.spinner('Disambiguating Mentions'):
        # Nothing to link when the tagger found no entities.
        predictions = model.link(mentions_cands, text) if mentions_cands else []
    with st.spinner('Rendering Results'):
        pieces, last_pos = [], 0
        for i, res in enumerate(ner_results):
            # Plain text is rendered with unsafe_allow_html=True below, so
            # escape it before splicing in the generated tags.
            pieces.append(html.escape(text[last_pos:res['start']], quote=False))
            pieces.append(display_tag(mentions[i], res['entity_group'], predictions[i][1]))
            last_pos = res['end']
        pieces.append(html.escape(text[last_pos:], quote=False))
        tagged = ''.join(pieces)

    with col2:
        st.write("### Disambiguated Text")
        st.markdown("<style>a {text-decoration: none;} small {color: inherit !important} .tag {border: 1px solid; border-radius: 5px; padding: 0px 3px; white-space: nowrap}</style>", unsafe_allow_html=True)
        st.write(tagged, unsafe_allow_html=True)
        # Each prediction is expected to be (mention, prediction, confidence).
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.markdown("<style>tbody th {display:none} .blank {display:none}</style>", unsafe_allow_html=True)
        st.table(df)
if __name__ == '__main__':
    # Bind the globals that main() reads.
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area(
            label='Press Ctrl + Enter to update results',
            value='George Washington went to Washington.',
            height=350,
        )
        # Only run the pipeline when the text area is non-empty.
        if user_input:
            main(user_input)