wolfrage89's picture
first commit
f09ca43
raw
history blame
No virus
3.8 kB
from transformers import pipeline
import streamlit as st
from spacy import displacy
from typing import List, Tuple
import json
import random
def ner_prediction(model, sentence):
    """Run a NER pipeline on ``sentence`` and merge its tokens into entities.

    The model is expected to emit one dict per tagged token. A token whose
    tag starts with "B" opens a new entity; any other token extends the
    entity that is currently open. A leading "Ġ" on a token's word is treated
    as a word boundary (presumably the tokenizer's space marker) and is
    rendered as a single space when the token extends an entity.

    Args:
        model: Callable invoked as ``model(sentence)``. It must return an
            iterable of dicts with the keys ``'entity'``, ``'word'``,
            ``'start'`` and ``'end'``.
        sentence: Text to run the model on.

    Returns:
        A list of ``(label, text, start_idx, end_idx)`` tuples in the order
        the model emitted them. Empty when the model returns no tokens.

    Raises:
        KeyError: If a token that opens an entity carries a tag whose entity
            type is not one of ORG, SEG or SEGNUM.
    """
    # Keyed by the tag without its "B-"/"I-" prefix so that a continuation
    # tag can also be resolved when it has to open an entity (see below).
    entity_map = {
        "ORG": "ORG",
        "SEG": "SEGMENT",
        "SEGNUM": "NUM_SEGMENT",
    }

    predictions = []
    accumulate = ""
    current_class = None
    start = 0
    end = 0

    for item in model(sentence):
        tag = item['entity']
        word = item['word']

        # A continuation token with no entity open yet (the model skipped the
        # "B-" tag) must start its own entity; otherwise it would be reported
        # with label None and a start offset of 0.
        if tag.startswith("B") or current_class is None:
            if accumulate:
                predictions.append((current_class, accumulate, start, end))
            accumulate = word.lstrip("Ġ")
            current_class = entity_map[tag.split("-", 1)[-1]]
            start = item['start']
        elif word.startswith("Ġ"):
            # New word inside the same entity: separate it with a space.
            accumulate += " " + word.lstrip("Ġ")
        else:
            # Sub-word piece: glue it directly onto the previous piece.
            accumulate += word
        end = item['end']

    # Flush the entity that was still open when the tokens ran out.
    if accumulate:
        predictions.append((current_class, accumulate, start, end))
    return predictions
def generate_displacy_html(predictions: List[Tuple[str, str, int, int]], sentence) -> str:
    """Render entity predictions over ``sentence`` as displaCy entity HTML.

    Args:
        predictions: Entities as ("LABEL", "TEXT", "START_IDX", "END_IDX")
            tuples; only the label and the two offsets are used here.
        sentence: The text the offsets refer to.

    Returns:
        Whatever ``displacy.render`` returns for the manually built payload
        (annotated as ``str``), intended to be shown in Streamlit.
    """
    # One background gradient per entity label.
    label_colors = {
        "SEGMENT": "linear-gradient(90deg, #DBE575, #C3D32C)",
        "NUM_SEGMENT": "linear-gradient(90deg, #3AD8E8, #1AA7B6)",
        "ORG": "linear-gradient(90deg, #aa9cfc, #fc9ce7)",
    }
    render_options = {
        "ents": ["SEGMENT", "NUM_SEGMENT", "ORG"],
        "colors": label_colors,
    }

    # Convert each prediction tuple into the span dict used in manual mode.
    spans = []
    for pred in predictions:
        spans.append({'start': pred[2], 'end': pred[3], 'label': pred[0]})

    document = {
        'text': sentence,
        'ents': spans,
        'title': "Name entity recognition",
    }
    return displacy.render(
        [document], style='ent', manual=True, options=render_options)
# Load the model and sample data once and let Streamlit cache the result.
@st.cache(allow_output_mutation=True)
def load_model_and_data():
    """Load the sample articles and the NER pipeline.

    Returns:
        A ``(sample_text, model)`` tuple: the parsed contents of
        ``sample_articles.json`` and the ``'ner'`` pipeline built from the
        ``wolfrage89/company_segment_ner`` model.

    Raises:
        OSError: If ``sample_articles.json`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification, so decode it explicitly instead of
    # relying on the platform's locale-dependent default encoding.
    with open("sample_articles.json", "r", encoding="utf-8") as json_file:
        sample_text = json.load(json_file)

    model_path = "wolfrage89/company_segment_ner"
    model = pipeline('ner', model_path)
    return sample_text, model
# Fetch the sample articles and the NER pipeline.
sample_texts, model = load_model_and_data()

# Make sure both session keys exist before anything reads them.
for state_key in ("article_text", "displacy_html"):
    if state_key not in st.session_state:
        st.session_state[state_key] = ""

# Sidebar: title, random-sample button and the predict trigger.
st.sidebar.title("Welcome To Company Segment Name Entity Recognition App")
random_button = st.sidebar.button("RANDOM")
st.sidebar.write("Randomly generates an article for testing")
st.sidebar.markdown("---")
predict_button = st.sidebar.button("PREDICT!")

if random_button:
    # Swap in a random sample and drop the rendering of the previous text.
    st.session_state['article_text'] = random.choice(sample_texts)
    st.session_state["displacy_html"] = ""

if predict_button:
    article = st.session_state['article_text']
    if len(article) == 0:
        # Nothing to predict on: clear any earlier rendering.
        rendered = ""
    else:
        rendered = generate_displacy_html(
            ner_prediction(model, article), article)
    st.session_state['displacy_html'] = rendered

# Main area: the editable article followed by the rendered entities.
st.session_state["article_text"] = st.text_area(
    label="Insert article here",
    value=st.session_state["article_text"],
    height=200,
)
st.markdown(st.session_state['displacy_html'], unsafe_allow_html=True)