sadickam commited on
Commit
fa37060
1 Parent(s): 03172bc

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +216 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import regex as re
3
+ import torch
4
+ import nltk
5
+ import pandas as pd
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from nltk.tokenize import sent_tokenize
8
+ import plotly.express as px
9
+ import time
10
+ nltk.download('punkt')
11
+
12
+ # Define the model and tokenizer
13
+ checkpoint = "sadickam/sdg-classification-bert"
14
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
15
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
16
+
17
+ # Define the function for preprocessing text
18
+ def prep_text(text):
19
+ clean_sents = []
20
+ sent_tokens = sent_tokenize(str(text))
21
+ for sent_token in sent_tokens:
22
+ word_tokens = [str(word_token).strip().lower() for word_token in sent_token.split()]
23
+ clean_sents.append(' '.join((word_tokens)))
24
+ joined = ' '.join(clean_sents).strip(' ')
25
+ joined = re.sub(r'`', "", joined)
26
+ joined = re.sub(r'"', "", joined)
27
+ return joined
28
+
29
+ # APP INFO
30
+ def app_info():
31
+ check = """
32
+ Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.
33
+ """
34
+
35
+ return check
36
+
37
+ # Create Gradio interface for single text
38
+ iface1 = gr.Interface(
39
+ fn=app_info, inputs=None, outputs=['text'], title="General-Infomation",
40
+ description= '''
41
+ This app powered by the sgdBERT model (sadickam/sdg-classification-bert) is for automatic classification of text with respect to
42
+ the UN Sustainable Development Goals (SDG). Note that 16 out of the 17 SDGs labels are covered. This app is for sustainability
43
+ assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
44
+ OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
45
+
46
+ This app has two analysis modules summarised below:
47
+ - Single-Text-Prediction - Analyses text pasted in a text box and return SDG prediction.
48
+ - Multi-Text-Prediction - Analyses multiple rows of texts in an uploaded CSV file and returns a downloadable CSV file with SDG prediction for each row of text.
49
+
50
+ This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
51
+ If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.
52
+
53
+ <h3>Contact</h3>
54
+ <p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
55
+ powering this app, we are happy to hear from you.</p>
56
+
57
+ <p>App contact: s.sadick@deakin.edu.au</p>
58
+ ''')
59
+
60
+ # SINGLE TEXT
61
+ # Define the prediction function
62
+ def predict_sdg(text):
63
+ # Preprocess the input text
64
+ cleaned_text = prep_text(text)
65
+ # Tokenize the preprocessed text
66
+ tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
67
+ # Predict
68
+ text_logits = model(**tokenized_text).logits
69
+ predictions = torch.softmax(text_logits, dim=1).tolist()[0]
70
+ # SDG labels
71
+ label_list = [
72
+ 'GOAL 1: No Poverty',
73
+ 'GOAL 2: Zero Hunger',
74
+ 'GOAL 3: Good Health and Well-being',
75
+ 'GOAL 4: Quality Education',
76
+ 'GOAL 5: Gender Equality',
77
+ 'GOAL 6: Clean Water and Sanitation',
78
+ 'GOAL 7: Affordable and Clean Energy',
79
+ 'GOAL 8: Decent Work and Economic Growth',
80
+ 'GOAL 9: Industry, Innovation and Infrastructure',
81
+ 'GOAL 10: Reduced Inequality',
82
+ 'GOAL 11: Sustainable Cities and Communities',
83
+ 'GOAL 12: Responsible Consumption and Production',
84
+ 'GOAL 13: Climate Action',
85
+ 'GOAL 14: Life Below Water',
86
+ 'GOAL 15: Life on Land',
87
+ 'GOAL 16: Peace, Justice and Strong Institutions'
88
+ ]
89
+ # dictionary with label as key and percentage as value
90
+ pred_dict = dict(zip(label_list, predictions))
91
+
92
+ # sort 'pred_dict' by value and index the highest at [0]
93
+ sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)
94
+
95
+ # Make dataframe for plotly bar chart
96
+ u, v = zip(*sorted_preds)
97
+ m = list(u)
98
+ n = list(v)
99
+ df2 = pd.DataFrame()
100
+ df2['SDG'] = m
101
+ df2['Likelihood'] = n
102
+
103
+ # plot graph of predictions
104
+ fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
105
+
106
+ fig.update_layout(
107
+ # barmode='stack',
108
+ template='seaborn', font=dict(family="Arial", size=12, color="black"),
109
+ autosize=True,
110
+ #width=800,
111
+ #height=500,
112
+ xaxis_title="Likelihood of SDG",
113
+ yaxis_title="Sustainable development goals (SDG)",
114
+ # legend_title="Topics"
115
+ )
116
+
117
+ fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
118
+ fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
119
+ fig.update_annotations(font_size=12) # this changes y_axis, x_axis and subplot title font sizes
120
+
121
+ # Make dataframe for plotly bar chart
122
+ #df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])
123
+
124
+ # Return the top prediction
125
+ top_prediction = sorted_preds[0]
126
+
127
+ # Return result
128
+ return {top_prediction[0]: round(top_prediction[1], 3)}, fig
129
+
130
+ # Create Gradio interface for single text
131
+ iface2 = gr.Interface(fn=predict_sdg,
132
+ inputs=gr.Textbox(lines=7, label="Paste or type text here"),
133
+ outputs=[gr.Label(label="Top SDG Predicted", show_label=True), gr.Plot(label="Likelihood of all SDG", show_label=True)],
134
+ title="Single Text Prediction")
135
+
136
+ # UPLOAD CSV
137
+ # Define the prediction function
138
+ def predict_sdg_from_csv(file):
139
+ # Read the CSV file
140
+ df_docs = pd.read_csv(file)
141
+ text_list = df_docs["text_inputs"].tolist()
142
+
143
+ # SDG labels list
144
+ label_list = [
145
+ 'GOAL 1: No Poverty',
146
+ 'GOAL 2: Zero Hunger',
147
+ 'GOAL 3: Good Health and Well-being',
148
+ 'GOAL 4: Quality Education',
149
+ 'GOAL 5: Gender Equality',
150
+ 'GOAL 6: Clean Water and Sanitation',
151
+ 'GOAL 7: Affordable and Clean Energy',
152
+ 'GOAL 8: Decent Work and Economic Growth',
153
+ 'GOAL 9: Industry, Innovation and Infrastructure',
154
+ 'GOAL 10: Reduced Inequality',
155
+ 'GOAL 11: Sustainable Cities and Communities',
156
+ 'GOAL 12: Responsible Consumption and Production',
157
+ 'GOAL 13: Climate Action',
158
+ 'GOAL 14: Life Below Water',
159
+ 'GOAL 15: Life on Land',
160
+ 'GOAL 16: Peace, Justice and Strong Institutions'
161
+ ]
162
+
163
+ # Lists for appending predictions
164
+ predicted_labels = []
165
+ prediction_score = []
166
+
167
+ # Preprocess text and make predictions
168
+ for text_input in text_list:
169
+ time.sleep(0.02) # Sleep to avoid rate limiting
170
+ cleaned_text = prep_text(text_input)
171
+ tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
172
+ text_logits = model(**tokenized_text).logits
173
+ predictions = torch.softmax(text_logits, dim=1).tolist()[0]
174
+ pred_dict = dict(zip(label_list, predictions))
175
+ sorted_preds = sorted(pred_dict.items(), key=lambda g: g[1], reverse=True)
176
+ predicted_labels.append(sorted_preds[0][0])
177
+ prediction_score.append(sorted_preds[0][1])
178
+
179
+ # Append predictions to the DataFrame
180
+ df_docs['SDG_predicted'] = predicted_labels
181
+ df_docs['prediction_score'] = prediction_score
182
+
183
+ df_docs.to_csv('sdg_predictions.csv')
184
+ output_csv = gr.File(value='sdg_predictions.csv', visible=True)
185
+
186
+ # Create the histogram
187
+ fig = px.histogram(df_docs, y="SDG_predicted")
188
+ fig.update_layout(
189
+ template='seaborn',
190
+ font=dict(family="Arial", size=12, color="black"),
191
+ autosize=True,
192
+ #width=800,
193
+ #height=500,
194
+ xaxis_title="SDG counts",
195
+ yaxis_title="Sustainable development goals (SDG)",
196
+ )
197
+ fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
198
+ fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
199
+ fig.update_annotations(font_size=12)
200
+
201
+ return fig, output_csv
202
+
203
+ # Define the input component
204
+ file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])
205
+
206
+ # Create the Gradio interface
207
+ iface3 = gr.Interface(fn=predict_sdg_from_csv,
208
+ inputs= file_input,
209
+ outputs=[gr.Plot(label='Frequency of SDGs', show_label=True), gr.File(label='Download output CSV', show_label=True)],
210
+ title="Multi-text Prediction (CVS)",
211
+ description='NOTE: Column to be analysed must be titled ***text_inputs***')
212
+
213
+ demo = gr.TabbedInterface([iface1, iface2, iface3], ["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"])
214
+
215
+ # Run the interface
216
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ plotly
4
+ pandas
5
+ numpy
6
+ nltk
7
+ regex
8
+ gradio
9
+ pypdf