# Spaces: Runtime error  (page banner text captured alongside the listing, not program code)
# Created by Hansi at 30/08/2023
import os

import nltk

# Fetch the NLTK tokenizer and POS-tagger data at start-up. The calls are kept
# ahead of the accord_nlp imports, matching the original statement order.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

import streamlit as st
from PIL import Image
from accord_nlp.information_extraction.convertor import entity_pairing, graph_building
from accord_nlp.information_extraction.ie_pipeline import InformationExtractor
from trubrics.integrations.streamlit import FeedbackCollector
# Arguments handed to the NER model (third element of `ner_model_info` in init()).
# The label set is "O" plus B-/I- prefixed tags for quality, property, object
# and value; multiprocessing is switched off and a single process is used.
ner_args = {
    "labels_list": [
        "O",
        "B-quality",
        "B-property",
        "I-property",
        "I-quality",
        "B-object",
        "I-object",
        "B-value",
        "I-value",
    ],
    "use_multiprocessing": False,
    "process_count": 1,
}
# Arguments handed to the relation-extraction model (third element of
# `re_model_info` in init()). Multiprocessing is switched off and a single
# process is used.
re_args = {
    "labels_list": [
        "selection",
        "necessity",
        "none",
        "greater",
        "part-of",
        "equal",
        "greater-equal",
        "less-equal",
        "not-part-of",
        "less",
    ],
    "special_tags": ["<e1>", "<e2>"],  # Should be either begin_tag or end_tag
    "use_multiprocessing": False,
    "process_count": 1,
}
@st.cache_resource(show_spinner=False)
def init():
    """Build the information-extraction pipeline from the ACCORD-NLP models.

    Streamlit re-executes the whole script on every user interaction, so the
    result is cached as a shared resource: the extractor is constructed on the
    first call and the same object is returned on later reruns. The built-in
    cache spinner is disabled because the caller already wraps this call in
    its own spinner.

    Returns:
        An ``InformationExtractor`` configured with the
        ``ACCORD-NLP/ner-roberta-large`` NER model (using ``ner_args``) and the
        ``ACCORD-NLP/re-roberta-large`` relation-extraction model (using
        ``re_args``).
    """
    return InformationExtractor(
        ner_model_info=('roberta', 'ACCORD-NLP/ner-roberta-large', ner_args),
        re_model_info=('roberta', 'ACCORD-NLP/re-roberta-large', re_args),
    )
# Page-level settings: browser tab title, wide layout, sidebar open on load.
st.set_page_config(
    page_title='ACCORD NLP Demo',
    initial_sidebar_state='expanded',
    layout='wide',
)
# Create the extractor used by main(), showing a spinner while init() runs.
with st.spinner(text="Initialising..."):
    ie = init()
# Trubrics feedback client; the credentials are read from Streamlit secrets.
# NOTE(review): in the visible code `collector` is referenced only by the
# commented-out feedback widget in main() — confirm whether it is still needed.
collector = FeedbackCollector(
    # component_name="default",
    email=st.secrets["TRUBRICS_EMAIL"],
    password=st.secrets["TRUBRICS_PASSWORD"],
    project="accord-nlp-ie",
)
def _render_sidebar():
    """Draw the sidebar: logo, title, short description and project links."""
    image = Image.open(os.path.join(os.path.dirname(__file__), 'accord_logo.png'))
    st.sidebar.image(image)
    # st.sidebar.markdown("[![image](upload://accordproject.eu/wp-content/uploads/2022/08/accord_logo-e1662800862179.png)](https://accordproject.eu/)")
    # st.sidebar.markdown(
    #     "[![image](os.path.join(os.path.dirname(__file__), 'accord_logo.png'))](https://accordproject.eu/)")
    # st.sidebar.title("ACCORD-NLP")
    st.sidebar.header("Information Extractor")
    st.sidebar.markdown("Extract entities and their relations from textual data")
    st.sidebar.markdown(
        "[codebase](https://github.com/Accord-Project/NLP-Framework)"
    )
    st.sidebar.markdown(
        "[models](https://huggingface.co/ACCORD-NLP)"
    )


def _build_graph(txt):
    """Run preprocessing, NER and relation extraction on *txt* and return the graph."""
    # preprocess
    sentence = ie.preprocess(txt)

    # NER
    with st.spinner(text="Recognising entities..."):
        ner_predictions, ner_raw_outputs = ie.ner_model.predict([sentence])

    with st.spinner(text="Extracting relations..."):
        # pair entities to predict their relations
        entity_pair_df = entity_pairing(sentence, ner_predictions[0])
        # relation extraction
        re_predictions, re_raw_outputs = ie.re_model.predict(entity_pair_df['output'].tolist())
        entity_pair_df['prediction'] = re_predictions

    with st.spinner(text="Building graph..."):
        # build graph
        return graph_building(entity_pair_df, view=False)


def main():
    """Render the demo page: sidebar, sentence input and entity-relation graph.

    The last processed sentence and its graph are kept in ``st.session_state``
    so that a rerun with unchanged text re-draws the stored graph instead of
    running the models again.
    """
    _render_sidebar()

    if 'text' not in st.session_state:
        st.session_state['text'] = ''
    if 'graph' not in st.session_state:
        st.session_state['graph'] = None

    st.header("Input a sentence")
    txt = st.text_area('Sentence')

    if txt:
        cached_graph = st.session_state['graph']
        # Reuse the stored graph only when one actually exists. The text is
        # recorded before extraction starts, so a run that failed part-way
        # leaves matching text with a graph of None; in that case the
        # extraction is retried instead of passing None to the chart.
        if txt == st.session_state['text'] and cached_graph is not None:
            st.header('Entity-Relation Representation')
            st.graphviz_chart(cached_graph, use_container_width=True)
        else:
            st.session_state['text'] = txt
            st.session_state['graph'] = None

            graph = _build_graph(txt)

            st.header('Entity-Relation Representation')
            # st.graphviz_chart(graph)
            st.graphviz_chart(graph, use_container_width=True)
            st.session_state['graph'] = graph

    # Feedback collection is currently disabled:
    # if st.session_state['graph'] is not None:
    #     st.divider()
    #     st.write("Does this prediction look correct?")
    #     collector.st_feedback(
    #         component="default",
    #         feedback_type="thumbs",
    #         model="v1-test",
    #         align="flex-start",
    #         metadata={
    #             "sentence": txt
    #         },
    #         open_feedback_label="[Optional] Provide additional feedback",
    #         # single_submit=False
    #     )
    #     st.session_state['text'] = ''
    #     st.session_state['graph'] = None
# Entry point: render the page when this file is executed as a script.
if __name__ == '__main__':
    main()