sadickam's picture
Update app.py
a57b5bc verified
import gradio as gr
import regex as re
import torch
import nltk
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize
import plotly.express as px
import time
import tqdm
# Sentence-tokenizer data used by nltk's sent_tokenize (called in prep_text).
# NLTK 3.8.2+ loads the pickle-free "punkt_tab" resource instead of the
# legacy pickled "punkt" one, so both are fetched to work with old and new
# NLTK releases.
nltk.download('punkt')
nltk.download('punkt_tab')

# Define the model and tokenizer. Both are loaded once at import time from
# the Hugging Face Hub checkpoint below and shared by every prediction call.
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
# Define the function for preprocessing text
def prep_text(text):
    """Normalise raw text before it is passed to the tokenizer.

    The text is sentence-tokenized with nltk. Every whitespace-separated word
    is lower-cased and re-joined with single spaces. Backticks (`) and double
    quotes (") are then removed.

    Args:
        text: Input text. Non-string values are converted with ``str()``.
            ``None`` and float NaN (what pandas produces for an empty CSV
            cell) are treated as missing. They yield ``""`` rather than the
            literal words "none"/"nan".

    Returns:
        The cleaned, lower-cased text, or ``""`` when there is nothing to
        classify.
    """
    import math

    # Missing values must map to "". Otherwise str() turns them into real
    # words that the model would classify and that slip past callers'
    # empty-input checks.
    if text is None or (isinstance(text, float) and math.isnan(text)):
        return ""

    clean_sents = [
        " ".join(word.lower() for word in sentence.split())
        for sentence in sent_tokenize(str(text))
    ]
    joined = " ".join(clean_sents).strip(" ")
    # Quote characters are dropped outright, not replaced with spaces.
    return joined.replace("`", "").replace('"', "")
# APP INFO
def app_info():
    """Return the hint text shown on the "General-App-Info" tab.

    That tab takes no inputs. The hint simply directs users to the two
    analysis tabs.
    """
    message = '\nPlease go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.\n'
    return message
# Create Gradio interface for the general information tab. It has no inputs;
# it only displays app_info's hint and the description below.
iface1 = gr.Interface(
fn=app_info, inputs=None, outputs=['text'], title="General-Information",
description= '''
This app, powered by the sdgBERT model (sadickam/sdg-classification-bert), is for automatic classification of text with respect to
the UN Sustainable Development Goals (SDGs). Note that 16 of the 17 SDG labels are covered (SDG 17 is not). This app is for sustainability
assessment and benchmarking and is not limited to a specific industry. The model powering this app was developed using the
OSDG Community Dataset (OSDG-CD) [Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5].
This app has two analysis modules summarised below:
- Single-Text-Prediction - Analyses text pasted in a text box and returns an SDG prediction.
- Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with an SDG prediction for each row of text.
This app runs on a free server and may therefore not be suitable for analysing large CSV and PDF files.
If you need assistance with analysing large CSV or PDF files, do get in touch using the contact information in the Contact section.
<h3>Contact</h3>
<p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
powering this app, we are happy to hear from you.</p>
<p>App contact: s.sadick@deakin.edu.au</p>
''')
# SINGLE TEXT
# Define the prediction function
def predict_sdg(text):
    """Predict the SDG of a single text and chart the likelihood of every SDG.

    Args:
        text: Raw text entered in the "Single-Text-Prediction" textbox.

    Returns:
        A ``(top, fig)`` tuple:

        - ``top`` is a one-entry dict mapping the highest-scoring SDG label
          to its softmax probability, rounded to 3 decimals (shown by
          ``gr.Label``).
        - ``fig`` is a plotly horizontal bar chart of the probabilities of
          all 16 SDG labels.

    Raises:
        gr.Error: If the text is empty after preprocessing.
    """
    # Preprocess the input text
    cleaned_text = prep_text(text)
    if cleaned_text == "":
        raise gr.Error('This model needs some text input to return a prediction')

    # Tokenize the preprocessed text (anything beyond 512 tokens is truncated)
    tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    # Predict. This is inference only, so skip building the autograd graph.
    with torch.no_grad():
        text_logits = model(**tokenized_text).logits
    predictions = torch.softmax(text_logits, dim=1).tolist()[0]

    # SDG labels; position i is assumed to match the model's output i
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # (label, probability) pairs, highest first, so sorted_preds[0] is the top SDG
    sorted_preds = sorted(zip(label_list, predictions), key=lambda pair: pair[1], reverse=True)

    # Make dataframe for plotly bar chart
    df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])

    # plot graph of predictions
    fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    # Sets the font of layout annotations (e.g. subplot titles). This chart
    # presumably has none, so this is likely a no-op.
    fig.update_annotations(font_size=12)

    # Return the top prediction and the chart
    top_label, top_score = sorted_preds[0]
    return {top_label: round(top_score, 3)}, fig
# Create Gradio interface for single text. It takes a 7-line textbox and
# outputs the top SDG label plus a bar chart of every SDG's likelihood.
iface2 = gr.Interface(
    fn=predict_sdg,
    title="Single Text Prediction",
    inputs=gr.Textbox(lines=7, label="Paste or type text here"),
    outputs=[
        gr.Label(label="Top SDG Predicted", show_label=True),
        gr.Plot(label="Likelihood of all SDG", show_label=True),
    ],
    article="**Note:** The quality of model predictions may depend on the quality of information provided.",
)
# UPLOAD CSV
# Define the prediction function
def predict_sdg_from_csv(file, progress=gr.Progress()):
    """Predict the top SDG for every row of an uploaded CSV file.

    Args:
        file: The uploaded CSV, as passed by ``gr.File``. It is handed
            straight to ``pd.read_csv`` and must contain a ``text_inputs``
            column.
        progress: Gradio progress tracker (Gradio injects it); drives the
            "Analysing data" progress bar.

    Returns:
        A ``(fig, csv_file)`` tuple:

        - ``fig`` is a plotly histogram of how often each SDG was predicted.
        - ``csv_file`` is a ``gr.File`` for ``sdg_predictions.csv``. That file
          holds the input table plus ``SDG_predicted`` and
          ``prediction_score`` columns.

    Raises:
        gr.Error: If the CSV has no ``text_inputs`` column.
    """
    # Read the CSV file
    df_docs = pd.read_csv(file)
    if "text_inputs" not in df_docs.columns:
        raise gr.Error('The uploaded CSV must have a column titled "text_inputs"')
    text_list = df_docs["text_inputs"].tolist()

    # SDG labels list; position i is assumed to match the model's output i
    label_list = [
        'GOAL 1: No Poverty',
        'GOAL 2: Zero Hunger',
        'GOAL 3: Good Health and Well-being',
        'GOAL 4: Quality Education',
        'GOAL 5: Gender Equality',
        'GOAL 6: Clean Water and Sanitation',
        'GOAL 7: Affordable and Clean Energy',
        'GOAL 8: Decent Work and Economic Growth',
        'GOAL 9: Industry, Innovation and Infrastructure',
        'GOAL 10: Reduced Inequality',
        'GOAL 11: Sustainable Cities and Communities',
        'GOAL 12: Responsible Consumption and Production',
        'GOAL 13: Climate Action',
        'GOAL 14: Life Below Water',
        'GOAL 15: Life on Land',
        'GOAL 16: Peace, Justice and Strong Institutions'
    ]

    # Lists for appending predictions (one entry per CSV row, in row order)
    predicted_labels = []
    prediction_score = []
    # Preprocess text and make predictions. The model is called directly in
    # this process, so there is no external rate limit to wait on.
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        cleaned_text = prep_text(text_input)
        tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            text_logits = model(**tokenized_text).logits
        predictions = torch.softmax(text_logits, dim=1).tolist()[0]
        # max() returns the first highest-scoring pair, which matches the
        # head of a stable descending sort.
        top_label, top_score = max(zip(label_list, predictions), key=lambda pair: pair[1])
        predicted_labels.append(top_label)
        prediction_score.append(top_score)

    # Append predictions to the DataFrame
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score
    # Written to the working directory with the DataFrame index as the first
    # column. NOTE(review): the fixed filename is shared by every request.
    df_docs.to_csv('sdg_predictions.csv')
    output_csv = gr.File(value='sdg_predictions.csv', visible=True)

    # Create the histogram of predicted SDG counts
    fig = px.histogram(df_docs, y="SDG_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)
    return fig, output_csv
# Define the input component (restricted to .csv uploads)
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])
# Create the Gradio interface for batch prediction over an uploaded CSV. It
# outputs a histogram of predicted SDGs and the annotated CSV for download.
iface3 = gr.Interface(fn=predict_sdg_from_csv,
                      inputs=file_input,
                      outputs=[gr.Plot(label='Frequency of SDGs', show_label=True), gr.File(label='Download output CSV', show_label=True)],
                      title="Multi-text Prediction (CSV)",
                      description='**NOTE:** The column to be analysed must be titled ***text_inputs***')
# Combine the three interfaces into one tabbed app. tab_names is listed in
# the same order as interface_list.
demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
                          tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
                          title="Sustainable Development Goals (SDG) Text Classifier App",
                          theme='soft'
                          )
# Run the interface with request queuing enabled
demo.queue().launch()