import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikipedia_search, google_search

# Suffix appended to a mention before the Google query, used to bias the
# search toward the entity type predicted by the NER tagger.
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}

# CSS hex color associated with each NER entity group.
COLOR = {
    'LOC': '#40E0D0',
    'PER': '#6495ED',
    'ORG': '#CCCCFF',
    'MISC': '#FF7F50'
}

# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------

@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Build the NER tagger and the NED model (cached by ``st.cache`` across reruns).

    Returns:
        A ``(model, tagger)`` tuple. ``model`` is the ``GBRT`` linker from
        ``src``. ``tagger`` is a Hugging Face token-classification pipeline
        over ``dslim/bert-base-NER`` with ``aggregation_strategy="average"``.
    """
    # NER
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    # Use the first GPU when present; otherwise fall back to CPU (-1) so the
    # app still starts on machines without CUDA instead of raising.
    device = 0 if torch.cuda.is_available() else -1
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=device, aggregation_strategy="average")

    # NED
    model = GBRT()
    return model, tagger

# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------

st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
st.write("## Named Entity Disambiguation")
col1, col2 = st.columns(2)

# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------

def get_candidates(mentions_tags):
    """Collect candidate entities for each ``(mention, tag)`` pair.

    Candidates are the union of a Google search on the mention plus a type hint
    from ``TYPE`` and a Wikipedia search (``limit=10``). Repeated pairs within
    one call are searched only once.

    Args:
        mentions_tags: iterable of ``(mention, entity_group)`` tuples.

    Returns:
        A list of ``(mention, candidates)`` tuples in input order. Each
        ``candidates`` is a de-duplicated list that keeps first-seen order
        (Google results first, then Wikipedia).
    """
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            google_results = google_search(mention + TYPE.get(tag, ''))
            wiki_results = wikipedia_search(mention, limit=10)
            # dict.fromkeys de-duplicates while preserving order. The former
            # list(set(...)) produced an order that varied between runs
            # because str hashing is randomized per process.
            cache[key] = list(dict.fromkeys(google_results + wiki_results))
        candidates.append((mention, cache[key]))
    return candidates

# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Return an HTML snippet highlighting a single entity mention.

    Args:
        text: surface form of the mention (escaped here).
        typ: NER entity group, e.g. ``'PER'``; selects the color from ``COLOR``.
        label: predicted Wikipedia page title, or ``'NIL'`` when unlinked.

    Returns:
        The mention highlighted in its type color with the type name appended.
        It is wrapped in a link to the English Wikipedia page unless ``label``
        is ``'NIL'``.
    """
    import html
    from urllib.parse import quote

    color = COLOR.get(typ, '#DDDDDD')
    # Escape model/user-derived text so it cannot inject markup into the page.
    mark = (
        f'<mark style="background: {color}; padding: 0.2em 0.4em; '
        'margin: 0 0.2em; border-radius: 0.35em;">'
        f'{html.escape(text)}'
        '<span style="font-size: 0.75em; font-weight: bold; margin-left: 0.4em;">'
        f'{html.escape(typ)}</span></mark>'
    )
    if label == 'NIL':
        # No page predicted: highlight only, rather than a broken "NIL" link.
        return mark
    # Wikipedia paths use underscores for spaces. quote() percent-encodes the
    # rest so titles containing '?', '#' or quotes stay valid inside the href.
    url = "https://en.wikipedia.org/wiki/" + quote(label.replace(' ', '_'))
    return (f'<a href="{url}" target="_blank" '
            f'style="color: inherit; text-decoration: none;">{mark}</a>')

# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------

def main(text):
    """Tag, link and render the entities found in *text*.

    Relies on the module-level ``tagger``, ``model`` (from ``load_models``) and
    ``col2``. Writes the highlighted text as an HTML component and a table of
    predictions into ``col2``.
    """
    import html

    ner_results = tagger(text)

    if ner_results:
        with st.spinner('Generating Candidates'):
            mentions_cands = get_candidates(
                [(res['word'], res['entity_group']) for res in ner_results])
        with st.spinner('Disambiguating Mentions'):
            predictions = model.link(mentions_cands, text)
    else:
        # No entities: skip the external searches and the linker entirely.
        predictions = []

    with st.spinner('Rendering Results'):
        pieces, last_pos = [], 0
        for res, prediction in zip(ner_results, predictions):
            start, end = res['start'], res['end']
            # Text between entities is raw user input and must be escaped.
            pieces.append(html.escape(text[last_pos:start]))
            # Show the original span rather than res['word'], which the
            # tokenizer re-detokenizes and may differ in spacing.
            pieces.append(display_tag(text[start:end], res['entity_group'], prediction[1]))
            last_pos = end
        pieces.append(html.escape(text[last_pos:]))
        tagged = ''.join(pieces)

    with col2:
        st.write("### Disambiguated Text")
        # pre-wrap keeps the user's line breaks, which HTML would otherwise collapse.
        components.html(
            f'<div style="white-space: pre-wrap; line-height: 2;">{tagged}</div>',
            scrolling=True, height=500)
        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.dataframe(df)


if __name__ == '__main__':
    model, tagger = load_models()
    with col1:
        st.write("### Input Text")
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
    if user_input:
        main(user_input)