HHansi's picture
Update app.py
b107783 verified
raw
history blame contribute delete
No virus
4.42 kB
# Created by Hansi at 30/08/2023
import os
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
import streamlit as st
from PIL import Image
from accord_nlp.information_extraction.convertor import entity_pairing, graph_building
from accord_nlp.information_extraction.ie_pipeline import InformationExtractor
from trubrics.integrations.streamlit import FeedbackCollector
# Settings shared by both models: multiprocessing off, one process.
_SINGLE_PROCESS = {
    "use_multiprocessing": False,
    "process_count": 1,
}

# Arguments handed to the NER model (see init()).
ner_args = {
    "labels_list": [
        "O",
        "B-quality",
        "B-property",
        "I-property",
        "I-quality",
        "B-object",
        "I-object",
        "B-value",
        "I-value",
    ],
    **_SINGLE_PROCESS,
}

# Arguments handed to the relation-extraction model (see init()).
re_args = {
    "labels_list": [
        "selection",
        "necessity",
        "none",
        "greater",
        "part-of",
        "equal",
        "greater-equal",
        "less-equal",
        "not-part-of",
        "less",
    ],
    "special_tags": ["<e1>", "<e2>"],  # Should be either begin_tag or end_tag
    **_SINGLE_PROCESS,
}
@st.cache_resource
def init():
    """Build the information extractor used by the app.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned object instead of constructing it again on each script run.

    Returns:
        An ``InformationExtractor`` configured with the ACCORD-NLP RoBERTa
        NER and RE checkpoints and the module-level ``ner_args`` / ``re_args``.
    """
    ner_model_info = ('roberta', 'ACCORD-NLP/ner-roberta-large', ner_args)
    re_model_info = ('roberta', 'ACCORD-NLP/re-roberta-large', re_args)
    return InformationExtractor(
        ner_model_info=ner_model_info,
        re_model_info=re_model_info,
    )
# Page-level settings for the demo.
st.set_page_config(
    layout='wide',
    initial_sidebar_state='expanded',
    page_title='ACCORD NLP Demo',
)

# Obtain the extractor from init(), showing a spinner while the call runs.
with st.spinner(text="Initialising..."):
    ie = init()

# Trubrics feedback collector; its credentials are read from Streamlit secrets.
collector = FeedbackCollector(
    project="accord-nlp-ie",
    email=st.secrets["TRUBRICS_EMAIL"],
    password=st.secrets["TRUBRICS_PASSWORD"],
    # component_name="default",
)
def main():
    """Render the demo page: sidebar, sentence input and the extracted graph.

    Uses the module-level ``ie`` extractor. The last submitted sentence and
    its graph are kept in ``st.session_state`` under the keys ``'text'`` and
    ``'graph'`` so that a rerun with an unchanged sentence redraws the stored
    graph instead of running the models again.
    """
    logo = Image.open(os.path.join(os.path.dirname(__file__), 'accord_logo.png'))
    st.sidebar.image(logo)
    st.sidebar.header("Information Extractor")
    st.sidebar.markdown("Extract entities and their relations from textual data")
    st.sidebar.markdown(
        "[codebase](https://github.com/Accord-Project/NLP-Framework)"
    )
    st.sidebar.markdown(
        "[models](https://huggingface.co/ACCORD-NLP)"
    )

    # Seed the per-session cache the first time the script runs.
    if 'text' not in st.session_state:
        st.session_state['text'] = ''
    if 'graph' not in st.session_state:
        st.session_state['graph'] = None

    st.header("Input a sentence")
    txt = st.text_area('Sentence')
    if not txt:
        return

    cached_graph = st.session_state['graph']
    # Reuse the cache only when a graph was actually stored for this sentence.
    # 'text' is written before the models run and 'graph' only after the chart
    # is drawn, so a run that stopped in between leaves 'graph' as None; in
    # that case fall through and recompute rather than passing None to
    # st.graphviz_chart.
    if txt == st.session_state['text'] and cached_graph is not None:
        st.header('Entity-Relation Representation')
        st.graphviz_chart(cached_graph, use_container_width=True)
    else:
        # Record the new sentence and drop the graph of the previous one.
        st.session_state['text'] = txt
        st.session_state['graph'] = None

        # preprocess
        sentence = ie.preprocess(txt)

        # NER (a single-sentence batch; only the predictions are used)
        with st.spinner(text="Recognising entities..."):
            ner_predictions, _ = ie.ner_model.predict([sentence])

        with st.spinner(text="Extracting relations..."):
            # pair entities to predict their relations
            entity_pair_df = entity_pairing(sentence, ner_predictions[0])
            # relation extraction
            re_predictions, _ = ie.re_model.predict(entity_pair_df['output'].tolist())
            entity_pair_df['prediction'] = re_predictions

        with st.spinner(text="Building graph..."):
            # build graph
            graph = graph_building(entity_pair_df, view=False)

        st.header('Entity-Relation Representation')
        st.graphviz_chart(graph, use_container_width=True)
        st.session_state['graph'] = graph

    # Feedback collection through the module-level `collector` is currently
    # disabled; the original snippet is kept for reference.
    # if st.session_state['graph'] is not None:
    #     st.divider()
    #     st.write("Does this prediction look correct?")
    #     collector.st_feedback(
    #         component="default",
    #         feedback_type="thumbs",
    #         model="v1-test",
    #         align="flex-start",
    #         metadata={
    #             "sentence": txt
    #         },
    #         open_feedback_label="[Optional] Provide additional feedback",
    #         # single_submit=False
    #     )
    #     st.session_state['text'] = ''
    #     st.session_state['graph'] = None
# Run the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()