"""Streamlit app that highlights named entities found in user-supplied text."""

import streamlit as st
from annotated_text import annotated_text
import transformers

# Background colour passed to annotated_text for each entity group label.
ENTITY_TO_COLOR = {
    'PER': '#8ef',
    'LOC': '#faa',
    'ORG': '#afa',
    'MISC': '#fea',
}
# Neutral colour for any entity group that has no entry in ENTITY_TO_COLOR,
# so an unexpected label is still rendered instead of raising KeyError.
DEFAULT_ENTITY_COLOR = '#ddd'


def _cache_model(func):
    """Cache ``func``'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it and
    falls back to the legacy ``st.cache(allow_output_mutation=True)`` call
    otherwise, so the script works on both old and new Streamlit releases.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_model
def get_pipe():
    """Build the token-classification pipeline for ``dslim/bert-base-NER``.

    Returns:
        A ``transformers`` pipeline created with
        ``aggregation_strategy="simple"``.
    """
    model_name = "dslim/bert-base-NER"
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    pipe = transformers.pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )
    return pipe


def parse_text(text, prediction):
    """Split ``text`` into plain segments and annotated entity tuples.

    Args:
        text: The string that was passed to the pipeline.
        prediction: Iterable of dicts, each with ``"start"``, ``"end"`` and
            ``"entity_group"`` keys. The spans are consumed in the order
            given; presumably they are sorted and non-overlapping (verify
            against the pipeline output if the model changes).

    Returns:
        A list alternating plain strings with
        ``(entity_text, entity_group, color)`` tuples, ending with the text
        that follows the last entity. It is unpacked into ``annotated_text``.
    """
    cursor = 0
    parsed_text = []
    for p in prediction:
        begin, end = p["start"], p["end"]
        group = p["entity_group"]
        parsed_text.append(text[cursor:begin])
        # Slice the original text rather than using p["word"]: "word" is
        # rebuilt from tokens and may not match the input character-for-
        # character, which would make the rendered sentence differ from
        # what the user typed.
        parsed_text.append(
            (text[begin:end], group, ENTITY_TO_COLOR.get(group, DEFAULT_ENTITY_COLOR))
        )
        cursor = end
    parsed_text.append(text[cursor:])
    return parsed_text


st.set_page_config(page_title="Named Entity Recognition")
st.title("Named Entity Recognition")
st.write("Type text into the text box and then press 'Predict' to get the named entities.")

default_text = "My name is John Smith. I work at Microsoft. I live in Paris. My favorite painting is the Mona Lisa."

text = st.text_area('Enter text here:', value=default_text)
# The button's return value is not consulted below: the prediction is shown
# whenever the text box is non-empty, so the button only prompts a rerun.
submit = st.button('Predict')

with st.spinner("Loading model..."):
    pipe = get_pipe()

# Equivalent to the former "(submit and non-empty) or non-empty" condition,
# which reduces to "non-empty".
if text.strip():
    prediction = pipe(text)

    parsed_text = parse_text(text, prediction)

    st.header("Prediction:")
    annotated_text(*parsed_text)

    st.header('Raw values:')
    st.json(prediction)