import gradio as gr
import regex as re
import torch
import nltk
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
import plotly.express as px
import time
nltk.download('punkt')
nltk.download('punkt_tab')  # newer NLTK releases keep the sentence tokenizer data under 'punkt_tab'

# Define the model and tokenizer
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
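
# Put the model in evaluation mode so dropout is disabled during inference
model.eval()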

# Define the function for preprocessing text
def prep_text(text):
    clean_sents = []
    sent_tokens = sent_tokenize(str(text))
    for sent_token in sent_tokens:
        word_tokens = [str(word_token).strip().lower() for word_token in sent_token.split()]
        clean_sents.append(' '.join(word_tokens))
    joined = ' '.join(clean_sents).strip(' ')
    joined = re.sub(r'`', "", joined)
    joined = re.sub(r'"', "", joined)
    return joined
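
# Illustrative example (hypothetical input), showing the lowercasing, whitespace normalisation
# and quote/backtick removal performed above:
#   prep_text(' Clean   "Energy"  access! ')  ->  'clean energy access!'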

# APP INFO
def app_info():
    check = """
    Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text. 
    """
    
    return check

# Create Gradio interface for single text
iface1 = gr.Interface(
    fn=app_info, inputs=None, outputs=['text'], title="General-Information",
    description= '''
    This app, powered by the sdgBERT model (sadickam/sdg-classification-bert), automatically classifies text with respect to 
    the UN Sustainable Development Goals (SDGs). Note that 16 of the 17 SDG labels are covered. The app is intended for sustainability 
    assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the 
    OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
    
    This app has two analysis modules, summarised below:
    - Single-Text-Prediction - Analyses text pasted in a text box and returns the SDG prediction.
    - Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with an SDG prediction for each row.
    
    This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files. 
    If you need assistance with analysing large CSV or PDF files, please get in touch using the contact information in the Contact section.
    
    <h3>Contact</h3>
    <p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model 
    powering this app, we are happy to hear from you.</p>
    
    <p>App contact: s.sadick@deakin.edu.au</p>
    ''')

# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    # Preprocess the input text
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')
    else:
        # Tokenize the preprocessed text
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        # Predict (no gradient tracking needed at inference time)
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        # SDG labels
        label_list = [
            'GOAL 1: No Poverty',
            'GOAL 2: Zero Hunger',
            'GOAL 3: Good Health and Well-being',
            'GOAL 4: Quality Education',
            'GOAL 5: Gender Equality',
            'GOAL 6: Clean Water and Sanitation',
            'GOAL 7: Affordable and Clean Energy',
            'GOAL 8: Decent Work and Economic Growth',
            'GOAL 9: Industry, Innovation and Infrastructure',
            'GOAL 10: Reduced Inequality',
            'GOAL 11: Sustainable Cities and Communities',
            'GOAL 12: Responsible Consumption and Production',
            'GOAL 13: Climate Action',
            'GOAL 14: Life Below Water',
            'GOAL 15: Life on Land',
            'GOAL 16: Peace, Justice and Strong Institutions'
        ]
        # dictionary with label as key and percentage as value
        pred_dict = dict(zip(label_list, predictions))
    
        # sort 'pred_dict' by value and index the highest at [0]
        sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)
    
        # Make dataframe for plotly bar chart
        df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])
    
        # plot graph of predictions
        fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    
        fig.update_layout(
            template='seaborn', font=dict(family="Arial", size=12, color="black"),
            autosize=True,
            xaxis_title="Likelihood of SDG",
            yaxis_title="Sustainable development goals (SDG)",
        )
    
        fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
        fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
        fig.update_annotations(font_size=12)  # this changes y_axis, x_axis and subplot title font sizes
    
        # Return the top prediction
        top_prediction = sorted_preds[0]

    # Return result
    return {top_prediction[0]: round(top_prediction[1], 3)}, fig
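
# Note: predict_sdg returns a one-entry dict {top SDG label: rounded probability} together with a Plotly
# bar chart of likelihoods across all 16 covered SDGs; these map to the Label and Plot outputs of iface2 below.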

# Create Gradio interface for single text
iface2 = gr.Interface(fn=predict_sdg,
                      inputs=gr.Textbox(lines=7, label="Paste or type text here"), 
                      outputs=[gr.Label(label="Top SDG Predicted", show_label=True), gr.Plot(label="Likelihood of all SDG", show_label=True)], 
                      title="Single Text Prediction",
                      article="**Note:** The quality of model predictions may depend on the quality of information provided."
                     )

# UPLOAD CSV
# Define the prediction function
def predict_sdg_from_csv(file, progress=gr.Progress()):
    # Read the CSV file
    df_docs = pd.read_csv(file)
    text_list = df_docs["text_inputs"].tolist()

    # SDG labels list
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # Preprocess text and make predictions
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        time.sleep(0.02)  # Brief pause so the progress bar has time to refresh
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        pred_dict = dict(zip(label_list, predictions))
        sorted_preds = sorted(pred_dict.items(), key=lambda g: g[1], reverse=True)
        predicted_labels.append(sorted_preds[0][0])
        prediction_score.append(sorted_preds[0][1])

    # Append predictions to the DataFrame
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score

    df_docs.to_csv('sdg_predictions.csv', index=False)  # index=False avoids writing a spurious index column
    output_csv = gr.File(value='sdg_predictions.csv', visible=True)

    # Create the histogram
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    return fig, output_csv

# Define the input component
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface
iface3 = gr.Interface(fn=predict_sdg_from_csv, 
                      inputs= file_input, 
                      outputs=[gr.Plot(label='Frequency of SDGs', show_label=True), gr.File(label='Download output CSV', show_label=True)], 
                      title="Multi-text Prediction (CVS)",
                      description='**NOTE:** The column to be analysed must be titled ***text_inputs***')
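
# Illustrative CSV layout expected by the Multi-Text tab (hypothetical content):
#   text_inputs
#   "Access to clean water remains a challenge in rural districts."
#   "The project expanded solar generation capacity in remote communities."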

demo = gr.TabbedInterface(interface_list = [iface1, iface2, iface3], 
                          tab_names = ["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
                          title = "Sustainble Development Goals (SDG) Text Classifier App",
                          theme = 'soft'
                         )

# Run the interface
demo.queue().launch()