"""Gradio app that classifies text against the UN Sustainable Development Goals."""

import time

import gradio as gr
import nltk
import pandas as pd
import plotly.express as px
import regex as re
import torch
import tqdm
from nltk.tokenize import sent_tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Sentence-tokenizer data required by sent_tokenize. Recent NLTK releases load
# the "punkt_tab" resource rather than the pickled "punkt" one, so both are
# fetched to keep the app working across NLTK versions.
nltk.download('punkt')
nltk.download('punkt_tab')

# Hugging Face Hub checkpoint of the SDG classifier. The tokenizer and model
# are loaded once at import time and reused by both prediction functions.
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
def prep_text(text):
    """Normalise raw input text before it is tokenised for the model.

    The value is coerced to ``str`` and split into sentences. Within each
    sentence, runs of whitespace collapse to single spaces and the words are
    lower-cased. Sentences are re-joined with single spaces, every backtick
    and double-quote character is removed, and surrounding whitespace is
    stripped.

    Args:
        text: Value to clean; anything accepted by ``str()``.

    Returns:
        The cleaned, lower-cased string. It is empty when the input holds
        nothing but whitespace, backticks and double quotes.
    """
    sentences = sent_tokenize(str(text))
    normalised = [' '.join(sentence.lower().split()) for sentence in sentences]
    joined = ' '.join(normalised)

    joined = joined.replace('`', '').replace('"', '')

    # Strip last: removing quote characters can expose leading/trailing spaces
    # (e.g. input '" "'), and callers treat "" as "no text supplied".
    return joined.strip()
def app_info():
    """Return the static hint that points users to the prediction tabs.

    Returns:
        A short message naming the "Single-Text-Prediction" and
        "Multi-Text-Prediction" tabs.
    """
    return (
        '\n'
        'Please go to either the "Single-Text-Prediction" or '
        '"Multi-Text-Prediction" tab to analyse your text.\n'
    )
# General-information tab: a static description of the app. It takes no
# inputs and its single text output shows the hint returned by app_info.
iface1 = gr.Interface(
    fn=app_info,
    inputs=None,
    outputs=['text'],
    title="General-Information",
    description='''
This app powered by the sdgBERT model (sadickam/sdg-classification-bert) is for automatic classification of text with respect to
the UN Sustainable Development Goals (SDG). Note that 16 out of the 17 SDGs labels are covered. This app is for sustainability
assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].

This app has two analysis modules summarised below:
- Single-Text-Prediction - Analyses text pasted in a text box and returns SDG prediction.
- Multi-Text-Prediction - Analyses multiple rows of texts in an uploaded CSV file and returns a downloadable CSV file with SDG prediction for each row of text.

This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.

<h3>Contact</h3>
<p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
powering this app, we are happy to hear from you.</p>

<p>App contact: s.sadick@deakin.edu.au</p>
''',
)
def predict_sdg(text):
    """Classify one piece of text against the 16 supported SDG labels.

    The text is cleaned with ``prep_text``, tokenised (truncated to at most
    512 tokens) and scored by the module-level model; the logits are turned
    into probabilities with a softmax.

    Args:
        text: Raw text typed or pasted by the user.

    Returns:
        A ``(top, fig)`` tuple. ``top`` is a one-item dict mapping the
        highest-scoring SDG label to its probability rounded to three
        decimals. ``fig`` is a horizontal Plotly bar chart with the
        probability of every label.

    Raises:
        gr.Error: If no text is left after cleaning.
    """
    cleaned_text = prep_text(text)
    # Guard clause; also rejects whitespace-only leftovers from cleaning.
    if not cleaned_text.strip():
        raise gr.Error('This model needs some text input to return a prediction')

    # Order must match the model's output indices.
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions',
    ]

    tokenized_text = tokenizer(
        cleaned_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )

    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        text_logits = model(**tokenized_text).logits
    predictions = torch.softmax(text_logits, dim=1).tolist()[0]

    # (label, probability) pairs, most likely first.
    sorted_preds = sorted(
        zip(label_list, predictions),
        key=lambda pair: pair[1],
        reverse=True,
    )

    df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])

    fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    top_label, top_score = sorted_preds[0]
    return {top_label: round(top_score, 3)}, fig
# Single-text tab: a free-text box feeds predict_sdg, whose two results are
# shown as a label (top SDG) and a plot (likelihood of every SDG).
iface2 = gr.Interface(
    fn=predict_sdg,
    inputs=gr.Textbox(lines=7, label="Paste or type text here"),
    outputs=[
        gr.Label(label="Top SDG Predicted", show_label=True),
        gr.Plot(label="Likelihood of all SDG", show_label=True),
    ],
    title="Single Text Prediction",
    article="**Note:** The quality of model predictions may depend on the quality of information provided.",
)
def predict_sdg_from_csv(file, progress=gr.Progress()):
    """Predict the top SDG label for every row of an uploaded CSV file.

    Each value of the ``text_inputs`` column is cleaned with ``prep_text``,
    tokenised (truncated to at most 512 tokens) and scored by the
    module-level model. The best label and its softmax probability are added
    to the table as ``SDG_predicted`` and ``prediction_score``, and the table
    is written to ``sdg_predictions.csv`` (a relative path).

    Args:
        file: The uploaded CSV as delivered by the ``gr.File`` input; it is
            passed straight to ``pandas.read_csv``.
        progress: Gradio progress tracker used to report per-row progress.

    Returns:
        A ``(fig, output_csv)`` tuple: a Plotly histogram counting how often
        each SDG was predicted, and a ``gr.File`` component pointing at the
        written CSV.

    Raises:
        gr.Error: If the CSV has no ``text_inputs`` column.
    """
    df_docs = pd.read_csv(file)
    # Report a missing column through the UI instead of a bare KeyError.
    if "text_inputs" not in df_docs.columns:
        raise gr.Error('The uploaded CSV file must contain a column titled "text_inputs"')
    text_list = df_docs["text_inputs"].tolist()

    # Order must match the model's output indices.
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions',
    ]

    predicted_labels = []
    prediction_score = []

    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        # NOTE(review): short pause kept from the original; presumably it
        # gives the progress bar time to refresh - confirm it is still needed.
        time.sleep(0.02)
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(
            cleaned_text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        # Inference only: no gradients are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]

        # Highest-probability label; on ties the earliest label wins.
        best_label, best_score = max(
            zip(label_list, predictions), key=lambda pair: pair[1]
        )
        predicted_labels.append(best_label)
        prediction_score.append(best_score)

    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score

    df_docs.to_csv('sdg_predictions.csv')
    output_csv = gr.File(value='sdg_predictions.csv', visible=True)

    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    return fig, output_csv
# Upload widget for the multi-text tab; only .csv files are offered.
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Multi-text tab: the uploaded CSV feeds predict_sdg_from_csv, whose results
# are shown as a frequency plot and a downloadable CSV file.
iface3 = gr.Interface(
    fn=predict_sdg_from_csv,
    inputs=file_input,
    outputs=[
        gr.Plot(label='Frequency of SDGs', show_label=True),
        gr.File(label='Download output CSV', show_label=True),
    ],
    title="Multi-text Prediction (CSV)",
    description='**NOTE:** The column to be analysed must be titled ***text_inputs***',
)
# Combine the three interfaces into one tabbed app; tab_names are given in the
# same order as interface_list.
demo = gr.TabbedInterface(
    interface_list=[iface1, iface2, iface3],
    tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
    title="Sustainable Development Goals (SDG) Text Classifier App",
    theme='soft',
)

# Enable request queuing, then start the server.
demo.queue().launch()