File size: 4,422 Bytes
dea7dd8
e2cb08c
 
dea7dd8
 
 
 
 
e2cb08c
dea7dd8
 
 
547ad6f
 
e76df46
 
 
0c8e9d4
e76df46
 
 
 
 
 
e2cb08c
e76df46
dea7dd8
 
 
e76df46
 
 
dea7dd8
 
 
 
 
 
 
 
 
 
 
 
547ad6f
 
 
 
6be4a27
547ad6f
 
 
dea7dd8
b06e7d2
 
6be4a27
 
b06e7d2
 
6be4a27
e2cb08c
 
 
 
dea7dd8
547ad6f
 
 
 
dea7dd8
 
6be4a27
547ad6f
 
 
 
 
dea7dd8
 
 
 
 
547ad6f
 
 
 
ac23fd7
 
547ad6f
 
ac23fd7
547ad6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b107783
 
 
 
 
 
 
 
 
 
 
 
 
 
dea7dd8
ac23fd7
 
0388aa3
dea7dd8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Created by Hansi at 30/08/2023
import os

import nltk
def _ensure_nltk_resource(resource_id, resource_path):
    """Download an NLTK resource only if it is not already installed.

    Streamlit re-executes this script on every user interaction, so calling
    ``nltk.download`` unconditionally would make the NLTK downloader
    re-verify the package and print status messages on every rerun.

    Args:
        resource_id: identifier passed to ``nltk.download`` (e.g. 'punkt').
        resource_path: path looked up with ``nltk.data.find``
            (e.g. 'tokenizers/punkt').
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # Not found in any nltk_data directory: fetch it once.
        nltk.download(resource_id)


_ensure_nltk_resource('punkt', 'tokenizers/punkt')
_ensure_nltk_resource('averaged_perceptron_tagger', 'taggers/averaged_perceptron_tagger')

import streamlit as st
from PIL import Image
from accord_nlp.information_extraction.convertor import entity_pairing, graph_building
from accord_nlp.information_extraction.ie_pipeline import InformationExtractor

from trubrics.integrations.streamlit import FeedbackCollector

# Settings shared by both model configurations: run prediction in a single
# process.
_SHARED_MODEL_ARGS = dict(use_multiprocessing=False, process_count=1)

# Named-entity recognition model configuration. The label order is kept
# exactly as-is, since a model's label list is usually index-aligned with
# its output classes.
ner_args = dict(
    labels_list=[
        "O",
        "B-quality",
        "B-property",
        "I-property",
        "I-quality",
        "B-object",
        "I-object",
        "B-value",
        "I-value",
    ],
    **_SHARED_MODEL_ARGS,
)

# Relation-extraction model configuration (label order kept as-is).
re_args = dict(
    labels_list=[
        "selection",
        "necessity",
        "none",
        "greater",
        "part-of",
        "equal",
        "greater-equal",
        "less-equal",
        "not-part-of",
        "less",
    ],
    # Markers around the entities of each pair. Original note: "Should be
    # either begin_tag or end_tag". NOTE(review): presumably these must match
    # the tags inserted by entity_pairing — confirm.
    special_tags=["<e1>", "<e2>"],
    **_SHARED_MODEL_ARGS,
)

@st.cache_resource
def init():
    """Create the ACCORD-NLP information-extraction pipeline.

    Wrapped in ``st.cache_resource`` so the (large) RoBERTa NER and RE models
    are built once and the same object is reused across Streamlit reruns.

    Returns:
        An ``InformationExtractor`` configured with ``ner_args`` / ``re_args``.
    """
    ner_info = ('roberta', 'ACCORD-NLP/ner-roberta-large', ner_args)
    re_info = ('roberta', 'ACCORD-NLP/re-roberta-large', re_args)
    return InformationExtractor(ner_model_info=ner_info, re_model_info=re_info)


# Page-level configuration. Streamlit requires set_page_config to be called
# before other st.* rendering commands, so keep it ahead of the spinner below.
st.set_page_config(
    page_title='ACCORD NLP Demo',
    initial_sidebar_state='expanded',
    layout='wide',
)

# Build the information-extraction pipeline. init() is wrapped in
# st.cache_resource, so the heavy model construction should happen once and
# later reruns reuse the cached object.
with st.spinner(text="Initialising..."):
    ie = init()


# Trubrics feedback collector, configured from Streamlit secrets; the
# st.secrets[...] lookups fail if TRUBRICS_EMAIL / TRUBRICS_PASSWORD are not
# configured.
# NOTE(review): `collector` is currently only referenced from commented-out
# code in main(). Unlike init(), this construction is not cached and runs on
# every Streamlit rerun — consider st.cache_resource if it is expensive.
collector = FeedbackCollector(
    # component_name="default",
    email=st.secrets["TRUBRICS_EMAIL"],
    password=st.secrets["TRUBRICS_PASSWORD"],
    project="accord-nlp-ie"
)


def _render_sidebar():
    """Render the sidebar: logo, title, short description and project links."""
    logo_path = os.path.join(os.path.dirname(__file__), 'accord_logo.png')
    # Close the file handle PIL opens; st.sidebar.image serialises the image
    # when called, so it is safe to close it afterwards.
    with Image.open(logo_path) as image:
        st.sidebar.image(image)

    st.sidebar.header("Information Extractor")
    st.sidebar.markdown("Extract entities and their relations from textual data")
    st.sidebar.markdown(
        "[codebase](https://github.com/Accord-Project/NLP-Framework)"
    )
    st.sidebar.markdown(
        "[models](https://huggingface.co/ACCORD-NLP)"
    )


def _extract_graph(txt):
    """Run preprocessing, NER, entity pairing, RE and graph building on *txt*.

    Returns:
        The object produced by ``graph_building``; it is passed directly to
        ``st.graphviz_chart``.
    """
    sentence = ie.preprocess(txt)

    # NER
    with st.spinner(text="Recognising entities..."):
        ner_predictions, _ = ie.ner_model.predict([sentence])

    with st.spinner(text="Extracting relations..."):
        # Pair entities to predict their relations.
        # NOTE(review): with fewer than two entities entity_pair_df may be
        # empty — confirm re_model.predict([]) and graph_building handle that.
        entity_pair_df = entity_pairing(sentence, ner_predictions[0])

        # Relation extraction
        re_predictions, _ = ie.re_model.predict(entity_pair_df['output'].tolist())
        entity_pair_df['prediction'] = re_predictions

    with st.spinner(text="Building graph..."):
        return graph_building(entity_pair_df, view=False)


def main():
    """Render the demo page and show the entity-relation graph for the input.

    The last processed sentence and its graph are kept in
    ``st.session_state`` so that Streamlit reruns with unchanged input reuse
    the graph instead of re-running the models.
    """
    _render_sidebar()

    if 'text' not in st.session_state:
        st.session_state['text'] = ''
    if 'graph' not in st.session_state:
        st.session_state['graph'] = None

    st.header("Input a sentence")

    txt = st.text_area('Sentence')

    if txt:
        # Reuse the stored graph only when it belongs to this exact text and
        # was actually built.
        cached = (txt == st.session_state['text']
                  and st.session_state['graph'] is not None)
        if cached:
            graph = st.session_state['graph']
        else:
            # Drop the previous result, and commit the new text only after the
            # pipeline succeeds. Otherwise an exception mid-pipeline would leave
            # text set with graph=None, and the next rerun would pass None to
            # st.graphviz_chart.
            st.session_state['graph'] = None
            graph = _extract_graph(txt)
            st.session_state['text'] = txt
            st.session_state['graph'] = graph

        st.header('Entity-Relation Representation')
        st.graphviz_chart(graph, use_container_width=True)

    # NOTE: feedback collection is disabled; it would use the module-level
    # `collector`.
    # if st.session_state['graph'] is not None:
        # st.divider()
        # st.write("Does this prediction look correct?")
        # collector.st_feedback(
        #     component="default",
        #     feedback_type="thumbs",
        #     model="v1-test",
        #     align="flex-start",
        #     metadata={
        #         "sentence": txt
        #     },
        #     open_feedback_label="[Optional] Provide additional feedback",
        #     # single_submit=False
        # )

        # st.session_state['text'] = ''
        # st.session_state['graph'] = None


# Entry point. Launch with `streamlit run <this file>`; Streamlit executes the
# script with __name__ == '__main__', so main() renders the page on each rerun.
if __name__ == '__main__':
    main()