import gradio as gr
import regex as re
import torch
import nltk
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
import plotly.express as px
import time
import tqdm

# Sentence tokenizer data for nltk.sent_tokenize (newer NLTK releases may also need 'punkt_tab')
nltk.download('punkt')

# Define the model and tokenizer
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# Define the function for preprocessing text
def prep_text(text):
    clean_sents = []
    sent_tokens = sent_tokenize(str(text))
    for sent_token in sent_tokens:
        word_tokens = [str(word_token).strip().lower() for word_token in sent_token.split()]
        clean_sents.append(' '.join(word_tokens))
    joined = ' '.join(clean_sents).strip(' ')
    joined = re.sub(r'`', "", joined)
    joined = re.sub(r'"', "", joined)
    return joined
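# Illustrative example (assumption, not called by the app): prep_text lower-cases the text,
# collapses whitespace and strips backticks/double quotes, e.g.
#   prep_text('He said "Clean  Water" matters.')  ->  'he said clean water matters.'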
# APP INFO
def app_info():
    check = """
    Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.
    """
    return check
# Create Gradio interface for the app information tab
iface1 = gr.Interface(
    fn=app_info, inputs=None, outputs=['text'], title="General-Information",
    description='''
    This app, powered by the sdgBERT model (sadickam/sdg-classification-bert), automatically classifies text with respect to
    the UN Sustainable Development Goals (SDGs). Note that it covers 16 of the 17 SDG labels (SDG 17 is not covered). The app is intended for sustainability
    assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
    OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
    The app has two analysis modules, summarised below:
    - Single-Text-Prediction - Analyses text pasted into a text box and returns the SDG prediction.
    - Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with an SDG prediction for each row of text.
    This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
    If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.
    <h3>Contact</h3>
    <p>We would be happy to receive your feedback on this app. If you would like to collaborate with us to explore use cases for the model
    powering this app, we would also be happy to hear from you.</p>
    <p>App contact: s.sadick@deakin.edu.au</p>
    ''')
# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    # Preprocess the input text
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')

    # Tokenize the preprocessed text
    tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)

    # Predict (inference only, so gradients are not needed)
    with torch.no_grad():
        text_logits = model(**tokenized_text).logits
    predictions = torch.softmax(text_logits, dim=1).tolist()[0]
    # SDG labels
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]
    # Dictionary with label as key and probability as value
    pred_dict = dict(zip(label_list, predictions))

    # Sort 'pred_dict' by value so the highest probability is at index [0]
    sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)

    # Dataframe for the plotly bar chart
    df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])

    # Plot graph of predictions
    fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    fig.update_layout(
        template='seaborn', font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)  # changes y_axis, x_axis and subplot title font sizes

    # Return the top prediction and the figure
    top_prediction = sorted_preds[0]
    return {top_prediction[0]: round(top_prediction[1], 3)}, fig
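# Example (assumption, for quick local testing only; not used by the app): calling the function
# directly returns a {label: probability} dict for the top SDG plus a Plotly figure, e.g.
#   top_label, fig = predict_sdg("Access to affordable and clean energy is expanding in rural areas.")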
# Create Gradio interface for single text
iface2 = gr.Interface(fn=predict_sdg,
                      inputs=gr.Textbox(lines=7, label="Paste or type text here"),
                      outputs=[gr.Label(label="Top SDG Predicted", show_label=True),
                               gr.Plot(label="Likelihood of all SDGs", show_label=True)],
                      title="Single Text Prediction",
                      article="**Note:** The quality of model predictions may depend on the quality of information provided."
                      )
# UPLOAD CSV
# Define the prediction function
def predict_sdg_from_csv(file, progress=gr.Progress()):
    # Read the CSV file
    df_docs = pd.read_csv(file)
    text_list = df_docs["text_inputs"].tolist()
    # SDG labels list
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]
    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # Preprocess text and make predictions
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        time.sleep(0.02)  # brief pause so the progress bar stays responsive
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        pred_dict = dict(zip(label_list, predictions))
        sorted_preds = sorted(pred_dict.items(), key=lambda g: g[1], reverse=True)
        predicted_labels.append(sorted_preds[0][0])
        prediction_score.append(sorted_preds[0][1])

    # Append predictions to the DataFrame and write the downloadable CSV
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score
    df_docs.to_csv('sdg_predictions.csv', index=False)
    output_csv = gr.File(value='sdg_predictions.csv', visible=True)
    # Create the histogram of predicted SDGs
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)

    return fig, output_csv
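# Note (describes the code above, not extra behaviour): the downloadable CSV is the uploaded file
# with two added columns, 'SDG_predicted' (top SDG label) and 'prediction_score' (its softmax
# probability), one pair per row of 'text_inputs'.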
# Define the input component
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface for CSV prediction
iface3 = gr.Interface(fn=predict_sdg_from_csv,
                      inputs=file_input,
                      outputs=[gr.Plot(label='Frequency of SDGs', show_label=True),
                               gr.File(label='Download output CSV', show_label=True)],
                      title="Multi-text Prediction (CSV)",
                      description='**NOTE:** The column to be analysed must be titled ***text_inputs***')
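# A minimal sketch (illustrative assumption, not part of the app) of a CSV this tab accepts:
# a single column named "text_inputs" with one text per row, e.g. built with pandas as
#   pd.DataFrame({"text_inputs": ["Access to clean water remains limited in rural areas.",
#                                 "New policies aim to reduce income inequality."]}).to_csv("my_texts.csv", index=False)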
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
                          tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
                          title="Sustainable Development Goals (SDG) Text Classifier App",
                          theme='soft'
                          )
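# Queuing is what allows the gr.Progress bar in the CSV tab to update while rows are processed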
# Run the interface
demo.queue().launch()