# wolfrage89's picture
# first commit
# f09ca43
from transformers import pipeline
import streamlit as st
from spacy import displacy
from typing import List, Tuple
import json
import random
def ner_prediction(model, sentence):
    """Run an NER pipeline on ``sentence`` and merge its tokens into entities.

    ``model`` is called as ``model(sentence)`` and must return an iterable of
    dicts carrying the keys ``entity``, ``word``, ``start`` and ``end``.
    A token whose tag starts with ``B`` opens a new entity; any other token
    extends the entity that is currently open.  Word pieces are glued back
    together, a leading ``Ġ`` marker meaning "preceded by a space".

    Args:
        model: Callable NER pipeline.
        sentence: Text to analyse.

    Returns:
        A list of ``("LABEL", "TEXT", "START_IDX", "END_IDX")`` tuples, in the
        order the model produced them.  ``START_IDX`` is the ``start`` of the
        first token of the entity and ``END_IDX`` the ``end`` of its last.
    """
    entity_map = {
        "B-ORG": "ORG",
        "B-SEG": "SEGMENT",
        "B-SEGNUM": "NUM_SEGMENT",
    }

    def resolve_label(tag):
        # "B-ORG" and "I-ORG" both resolve through the "B-" entry; a tag that
        # is not in the map falls back to its bare name instead of raising.
        name = tag.partition("-")[2]
        return entity_map.get("B-" + name, name or tag)

    predictions = []
    accumulate = ""
    current_class = None
    start = 0
    end = 0

    for item in model(sentence):
        word = item['word']
        if item['entity'].startswith("B") or current_class is None:
            # A continuation token with no open entity starts a new one, so
            # every prediction carries a real label and its own start offset.
            if accumulate:
                predictions.append((current_class, accumulate, start, end))
            accumulate = word.lstrip("Ġ")
            current_class = resolve_label(item['entity'])
            start = item['start']
        elif word.startswith("Ġ"):
            # New word inside the same entity: restore the separating space.
            accumulate += " " + word.lstrip("Ġ")
        else:
            # Sub-word piece: append directly to the previous piece.
            accumulate += word
        end = item['end']

    # Flush the entity that was still open when the tokens ran out.
    if accumulate:
        predictions.append((current_class, accumulate, start, end))
    return predictions
def generate_displacy_html(predictions: List[Tuple[str, str, int, int]], sentence) -> str:
    """Render entity predictions over ``sentence`` as displaCy entity HTML.

    Args:
        predictions: ("LABEL", "TEXT", "START_IDX", "END_IDX") tuples; only the
            label and the two indices are read here.
        sentence: The text the indices refer to.

    Returns:
        The value returned by ``displacy.render`` for a manual 'ent' payload,
        intended to be shown in streamlit.
    """
    # One background gradient per entity label.
    gradient_by_label = {
        "SEGMENT": "linear-gradient(90deg, #DBE575, #C3D32C)",
        "NUM_SEGMENT": "linear-gradient(90deg, #3AD8E8, #1AA7B6)",
        "ORG": "linear-gradient(90deg, #aa9cfc, #fc9ce7)",
    }
    render_options = {
        "ents": ["SEGMENT", "NUM_SEGMENT", "ORG"],
        "colors": gradient_by_label,
    }

    # Convert each prediction tuple into the span dict used in manual mode.
    spans = []
    for prediction in predictions:
        spans.append({
            'start': prediction[2],
            'end': prediction[3],
            'label': prediction[0],
        })

    document = {
        'text': sentence,
        'ents': spans,
        'title': "Name entity recognition",
    }
    return displacy.render(
        [document], style='ent', manual=True, options=render_options)
# loading in the model in cache
@st.cache(allow_output_mutation=True)
def load_model_and_data():
    """Load the sample articles and the NER pipeline.

    The function is wrapped in ``st.cache(allow_output_mutation=True)``.

    Returns:
        A ``(sample_text, model)`` pair: the parsed content of
        ``sample_articles.json`` and the ``pipeline('ner', ...)`` object built
        from ``wolfrage89/company_segment_ner``.
    """
    # Pin the encoding so the read does not depend on the platform's locale
    # default; JSON text is UTF-8.
    with open("sample_articles.json", "r", encoding="utf-8") as json_file:
        sample_text = json.load(json_file)

    # loading in the model
    model_path = "wolfrage89/company_segment_ner"
    model = pipeline('ner', model_path)
    return sample_text, model
# Sample articles and NER pipeline, obtained through the cached loader.
sample_texts, model = load_model_and_data()

# Make sure both session-state slots exist before anything reads them.
for state_key in ("article_text", "displacy_html"):
    if state_key not in st.session_state:
        st.session_state[state_key] = ""

# Sidebar: title, the RANDOM button with its caption, a divider, then PREDICT!.
sidebar = st.sidebar
sidebar.title("Welcome To Company Segment Name Entity Recognition App")
random_button = sidebar.button("RANDOM")
sidebar.write("Randomly generates an article for testing")
sidebar.markdown("---")
predict_button = sidebar.button("PREDICT!")

# RANDOM: pick one of the sample articles and drop any previous rendering.
if random_button:
    st.session_state['article_text'] = random.choice(sample_texts)
    st.session_state["displacy_html"] = ""

# PREDICT!: render entities for the stored article, or clear the rendering
# when there is no article text.
if predict_button:
    article = st.session_state['article_text']
    if len(article) > 0:
        predictions = ner_prediction(model, article)
        rendered = generate_displacy_html(predictions, article)
    else:
        rendered = ""
    st.session_state['displacy_html'] = rendered

# Main area: the editable article, followed by the stored HTML rendering.
st.session_state["article_text"] = st.text_area(
    label="Insert article here",
    value=st.session_state["article_text"],
    height=200,
)
st.markdown(st.session_state['displacy_html'], unsafe_allow_html=True)