Spaces:
Runtime error
Runtime error
from transformers import pipeline | |
import streamlit as st | |
from spacy import displacy | |
from typing import List, Tuple | |
import json | |
import random | |
def ner_prediction(model, sentence):
    """Run a token-level NER model on *sentence* and merge its tokens into entities.

    Args:
        model: a callable (such as a transformers ``pipeline('ner', ...)``)
            returning one dict per tagged token with ``'entity'``, ``'word'``,
            ``'start'`` and ``'end'`` keys.
        sentence: the text to tag.

    Returns:
        A list of ``("LABEL", "TEXT", "START_IDX", "END_IDX")`` tuples in the
        order the model emitted the tokens. The indices are the character
        offsets of the first and last token merged into the entity.

    Merging rules: a tag starting with "B" opens a new entity, and any other tag
    continues the open entity. A continuation tag with no open entity (e.g. the
    very first token) opens one instead of producing a label-less span. Words
    prefixed with "Ġ" are treated as starting a new word and joined with a
    space; other words are appended as sub-word pieces. Tag suffixes ORG, SEG
    and SEGNUM map to the display labels ORG, SEGMENT and NUM_SEGMENT. Any other
    suffix is passed through unchanged instead of raising KeyError.
    """
    # Tag suffix (the part after "B-"/"I-") -> label used by the displaCy renderer.
    entity_map = {
        "ORG": "ORG",
        "SEG": "SEGMENT",
        "SEGNUM": "NUM_SEGMENT",
    }
    predictions = []
    accumulate = ""
    current_class = None  # None means no entity is currently open
    start = 0
    end = 0
    for item in model(sentence):
        tag = item['entity']
        word = item['word']
        if tag.startswith("B") or current_class is None:
            # Close the previous entity (if any) before opening a new one.
            if current_class is not None:
                predictions.append((current_class, accumulate, start, end))
            suffix = tag.split("-", 1)[-1]
            current_class = entity_map.get(suffix, suffix)
            accumulate = word.lstrip("Ġ")
            start = item['start']
        elif word.startswith("Ġ"):
            accumulate += " " + word.lstrip("Ġ")
        else:
            accumulate += word
        end = item['end']
    # Flush the entity still open after the last token.
    if current_class is not None:
        predictions.append((current_class, accumulate, start, end))
    return predictions
def generate_displacy_html(predictions: List[Tuple[str, str, int, int]], sentence) -> str:
    """Render entity predictions over *sentence* as displaCy entity HTML.

    Args:
        predictions: entity tuples in ("LABEL", "TEXT", "START_IDX", "END_IDX")
            format. Only the label and the two character offsets are used.
        sentence: the text the offsets index into.

    Returns:
        The markup returned by ``displacy.render`` in manual ``'ent'`` mode,
        meant to be shown in Streamlit via ``st.markdown(..., unsafe_allow_html=True)``.
    """
    label_colors = {
        "SEGMENT": "linear-gradient(90deg, #DBE575, #C3D32C)",
        "NUM_SEGMENT": "linear-gradient(90deg, #3AD8E8, #1AA7B6)",
        "ORG": "linear-gradient(90deg, #aa9cfc, #fc9ce7)",
    }
    # Highlight exactly the labels that have a color, in the same order.
    render_options = {"ents": list(label_colors), "colors": label_colors}

    entities = []
    for pred in predictions:
        label, begin, finish = pred[0], pred[2], pred[3]
        entities.append({'start': begin, 'end': finish, 'label': label})

    document = {
        'text': sentence,
        'ents': entities,
        'title': "Name entity recognition",
    }
    return displacy.render([document], style='ent', manual=True, options=render_options)
# Streamlit re-runs this whole script on every widget interaction. Without
# caching, the model would be re-instantiated and the JSON re-read on every
# click. `st.cache_resource` (Streamlit >= 1.18) keeps one shared copy across
# reruns; older releases only provide the legacy `st.cache`.
_cache_across_reruns = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_across_reruns
def load_model_and_data():
    """Load the sample articles and the NER pipeline, once per server process.

    Returns:
        A ``(sample_text, model)`` tuple:
        - ``sample_text`` is the parsed contents of ``sample_articles.json``,
          read from the current working directory. Presumably it is a list of
          article strings, given that the caller passes it to ``random.choice``;
          confirm against the data file.
        - ``model`` is a transformers ``ner`` pipeline for
          ``wolfrage89/company_segment_ner``.
    """
    # JSON text is UTF-8 by specification; don't rely on the platform default encoding.
    with open("sample_articles.json", "r", encoding="utf-8") as json_file:
        sample_text = json.load(json_file)
    model_path = "wolfrage89/company_segment_ner"
    model = pipeline('ner', model_path)
    return sample_text, model
sample_texts, model = load_model_and_data()

# Session state survives Streamlit reruns. "article_text" is also the key of the
# text-area widget below, so it always holds the text currently in the box.
if "article_text" not in st.session_state:
    st.session_state["article_text"] = ""
if "displacy_html" not in st.session_state:
    st.session_state['displacy_html'] = ""

# adding in the side bar
st.sidebar.title("Welcome To Company Segment Name Entity Recognition App")
random_button = st.sidebar.button("RANDOM")
st.sidebar.write("Randomly generates an article for testing")
st.sidebar.markdown("---")
predict_button = st.sidebar.button("PREDICT!")

if random_button:
    # Writing a widget's state is allowed here because the text area has not
    # been created yet in this run.
    st.session_state['article_text'] = random.choice(sample_texts)
    st.session_state["displacy_html"] = ""

if predict_button:
    article_text = st.session_state['article_text']
    if article_text:
        predictions = ner_prediction(model, article_text)
        st.session_state['displacy_html'] = generate_displacy_html(
            predictions, article_text)
    else:
        st.session_state['displacy_html'] = ""

# The widget is bound to the "article_text" key rather than having its return
# value copied into session state afterwards. As a result, text the user just
# typed is already in session state when PREDICT! is handled above in the same
# rerun, instead of the value left over from the previous run.
st.text_area(label="Insert article here", key="article_text", height=200)
st.markdown(st.session_state['displacy_html'], unsafe_allow_html=True)