import json
import random
from typing import List, Tuple

import streamlit as st
from spacy import displacy
from transformers import pipeline


def ner_prediction(model, sentence):
    """Run the NER pipeline on *sentence* and merge tokens into entities.

    Args:
        model: Callable NER pipeline. Each item it returns must provide the
            keys ``entity``, ``word``, ``start`` and ``end``.
        sentence: Text to run the pipeline on.

    Returns:
        A list of ``(label, text, start_idx, end_idx)`` tuples, one per
        merged entity, in the order the pipeline produced them.
    """
    entity_map = {
        "B-ORG": "ORG",
        "B-SEG": "SEGMENT",
        "B-SEGNUM": "NUM_SEGMENT",
    }

    predictions = []
    accumulate = ""
    current_class = None
    start = 0
    end = 0

    for item in model(sentence):
        tag = item["entity"]
        word = item["word"]

        # A "B" tag always opens a new entity. A continuation tag seen while
        # no entity is open is also treated as an opener; otherwise it would
        # yield a span labelled None that starts at offset 0.
        if tag.startswith("B") or current_class is None:
            if accumulate:
                predictions.append((current_class, accumulate, start, end))
            suffix = tag.split("-", 1)[-1]
            accumulate = word.lstrip("Ġ")
            # Unknown tags fall back to their raw suffix instead of raising.
            current_class = entity_map.get("B-" + suffix, suffix)
            start = item["start"]
        elif word.startswith("Ġ"):
            # Tokens prefixed with "Ġ" are joined with a separating space.
            accumulate += " " + word.lstrip("Ġ")
        else:
            accumulate += word
        end = item["end"]

    # Flush the entity that was still open when the tokens ran out.
    if accumulate:
        predictions.append((current_class, accumulate, start, end))

    return predictions


def generate_displacy_html(
    predictions: List[Tuple[str, str, int, int]], sentence: str
) -> str:
    """Render entity predictions as displaCy entity markup.

    Args:
        predictions: ``(label, text, start_idx, end_idx)`` tuples; only the
            label and the two offsets are used.
        sentence: The text the offsets refer to.

    Returns:
        The value returned by ``displacy.render`` in manual ``ent`` mode.
    """
    colors = {
        "SEGMENT": "linear-gradient(90deg, #DBE575, #C3D32C)",
        "NUM_SEGMENT": "linear-gradient(90deg, #3AD8E8, #1AA7B6)",
        "ORG": "linear-gradient(90deg, #aa9cfc, #fc9ce7)",
    }
    options = {"ents": ["SEGMENT", "NUM_SEGMENT", "ORG"], "colors": colors}

    ents = [
        {"start": ent_start, "end": ent_end, "label": label}
        for label, _text, ent_start, ent_end in predictions
    ]
    payload = [
        {
            "text": sentence,
            "ents": ents,
            "title": "Name entity recognition",
        }
    ]

    return displacy.render(payload, style="ent", manual=True, options=options)


# Load the model and sample data once and keep them cached.
# NOTE(review): st.cache is deprecated in newer Streamlit releases in favour
# of st.cache_resource / st.cache_data — confirm against the pinned version.
@st.cache(allow_output_mutation=True)
def load_model_and_data():
    """Load the sample articles and build the NER pipeline.

    Returns:
        A ``(sample_text, model)`` tuple: the parsed content of
        ``sample_articles.json`` and the ``transformers`` NER pipeline.
    """
    # Explicit encoding so the file is not decoded with the platform default.
    with open("sample_articles.json", "r", encoding="utf-8") as json_file:
        sample_text = json.load(json_file)

    model_path = "wolfrage89/company_segment_ner"
    model = pipeline('ner', model_path)

    return sample_text, model


sample_texts, model = load_model_and_data()

# Session state: the current article and its rendered entity markup.
if "article_text" not in st.session_state:
    st.session_state["article_text"] = ""

if "displacy_html" not in st.session_state:
    st.session_state["displacy_html"] = ""

# Sidebar controls.
st.sidebar.title("Welcome To Company Segment Name Entity Recognition App")
random_button = st.sidebar.button("RANDOM")
st.sidebar.write("Randomly generates an article for testing")
st.sidebar.markdown("---")
predict_button = st.sidebar.button("PREDICT!")

if random_button:
    # A new article invalidates any markup rendered for the previous one.
    st.session_state["article_text"] = random.choice(sample_texts)
    st.session_state["displacy_html"] = ""

if predict_button:
    if st.session_state["article_text"]:
        predictions = ner_prediction(model, st.session_state["article_text"])
        st.session_state["displacy_html"] = generate_displacy_html(
            predictions, st.session_state["article_text"]
        )
    else:
        st.session_state["displacy_html"] = ""

# NOTE(review): the PREDICT handler above reads the text stored before this
# text area is evaluated in the current run — confirm the ordering is intended.
st.session_state["article_text"] = st.text_area(
    label="Insert article here",
    value=st.session_state["article_text"],
    height=200,
)

st.markdown(st.session_state["displacy_html"], unsafe_allow_html=True)