import html
import logging

import pandas as pd
import streamlit as st
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

from src import GBRT, wikipedia_search, wikidata_search

logger = logging.getLogger(__name__)

# Keyed by NER entity group (the same keys as the tagger's ``entity_group``).
# NOTE(review): TYPE and COLOR are not referenced by the rendering code below;
# they look like leftovers from an HTML-styled renderer — confirm before removing.
TYPE = {
    'LOC': ' location',
    'PER': ' person',
    'ORG': ' organization',
    'MISC': ''
}

COLOR = {
    'LOC': '#10B981',
    'PER': '#0EA5E9',
    'ORG': '#A855F7',
    'MISC': '#F97316'
}


# ---------------------------------------------------------------------------
# Loading models
# ---------------------------------------------------------------------------
@st.cache(allow_output_mutation=True, show_spinner=True)
def load_models():
    """Load the NER tagger and the NED model.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded models instead
    of reloading them. ``allow_output_mutation=True`` skips hashing of the
    returned (mutable) model objects.

    Returns:
        ``(model, tagger)``: the ``GBRT`` disambiguation model and a Hugging
        Face token-classification pipeline over ``dslim/bert-base-NER``.
    """
    # NOTE(review): ``st.cache`` is deprecated in recent Streamlit releases in
    # favour of ``st.cache_resource``; kept as-is for compatibility with the
    # installed version — confirm before migrating.

    # NER: device=-1 runs on CPU; "average" aggregation merges sub-word tokens
    # into word-level entities with character offsets into the input text.
    tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
    bert_ner = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
    tagger = pipeline("token-classification", model=bert_ner, tokenizer=tokenizer,
                      device=-1, aggregation_strategy="average")

    # NED
    model = GBRT()
    return model, tagger


# ---------------------------------------------------------------------------
# Page setup
# ---------------------------------------------------------------------------
st.set_page_config(layout="wide", page_title='Named Entity Disambiguation')
col1, col2 = st.columns(2)


# ---------------------------------------------------------------------------
# Candidate Generation
# ---------------------------------------------------------------------------
def get_candidates(mentions_tags):
    """Look up candidate entities for each ``(mention, tag)`` pair.

    Wikidata is queried first (top 3). If that yields nothing, Wikipedia search
    is used as a fallback. Lookups are memoised within a single call, so a
    repeated mention is only searched once.

    Args:
        mentions_tags: iterable of ``(mention, tag)`` tuples.

    Returns:
        List of ``(mention, candidates)`` tuples, in input order.
    """
    candidates = []
    cache = {}
    for mention, tag in mentions_tags:
        key = (mention, tag)
        if key not in cache:
            cands = wikidata_search(mention, limit=3)
            # ``not cands`` treats any empty/falsy result (e.g. None) as
            # "no candidates", not just a literal ``[]``.
            if not cands:
                cands = wikipedia_search(mention, limit=3)
            cache[key] = cands
            logger.debug('%s %s', mention, cands)
        candidates.append((mention, cache[key]))
    return candidates


# ---------------------------------------------------------------------------
# Rendering Setup
# ---------------------------------------------------------------------------
def display_tag(text, typ, label):
    """Render one entity mention as Markdown.

    Args:
        text: surface form of the mention. It is HTML-escaped because the
            result is rendered with ``unsafe_allow_html=True``.
        typ: text appended after the bold mention (``main`` passes the raw
            entity group, e.g. ``'LOC'``).
        label: predicted entity, or ``'NIL'`` when nothing was linked.
            Presumably a Wikipedia page title, since it is appended to an
            en.wikipedia.org/wiki/ URL.

    Returns:
        A Markdown link to the Wikipedia page, or bold text only for ``'NIL'``.
    """
    text = html.escape(text, quote=False)
    if label == 'NIL':
        return f'**{text}** {typ}'
    # Wikipedia URLs use underscores for spaces. A literal space would also
    # terminate the Markdown link destination and break the link.
    target = str(label).replace(' ', '_')
    return f'[**{text}** {typ}](https://en.wikipedia.org/wiki/{target})'


# ---------------------------------------------------------------------------
# Full Pipeline
# ---------------------------------------------------------------------------
def main(text):
    """Run NER -> candidate generation -> disambiguation and render the result.

    Relies on the module-level ``tagger``, ``model`` and ``col2`` created in
    the ``__main__`` section and page setup.

    Args:
        text: raw user input.
    """
    ner_results = tagger(text)

    with st.spinner('Generating Candidates'):
        mentions_cands = get_candidates(
            [(res['word'], res['entity_group']) for res in ner_results])

    with st.spinner('Disambiguating Mentions'):
        # Nothing to disambiguate when NER found no mentions.
        predictions = model.link(mentions_cands, text) if mentions_cands else []

    with st.spinner('Rendering Results'):
        pieces, last_pos = [], 0
        for i, res in enumerate(ner_results):
            start, end = res['start'], res['end']
            # Text between entities is user input rendered with
            # unsafe_allow_html=True, so escape it to keep it inert.
            pieces.append(html.escape(text[last_pos:start], quote=False))
            # Show the mention exactly as typed. The tokenizer-decoded
            # ``res['word']`` can differ from the original span.
            pieces.append(display_tag(text[start:end], res['entity_group'], predictions[i][1]))
            last_pos = end
        pieces.append(html.escape(text[last_pos:], quote=False))
        tagged = ''.join(pieces)

    with col2:
        st.write("### Disambiguated Text")
        # NOTE(review): empty markdown call; possibly a placeholder for styling
        # that was removed — confirm whether it is still needed.
        st.markdown("", unsafe_allow_html=True)
        st.write(tagged, unsafe_allow_html=True)

        df = pd.DataFrame(data=predictions, columns=['Mention', 'Prediction', 'Confidence'])
        st.write("**Additional Information**")
        st.markdown("", unsafe_allow_html=True)
        st.table(df)


if __name__ == '__main__':
    model, tagger = load_models()

    with col1:
        st.write("### Input Text")
        user_input = st.text_area('Press Ctrl + Enter to update results',
                                  'George Washington went to Washington.', height=350)
        if user_input:
            main(user_input)