"""Gradio app that classifies text against the UN Sustainable Development Goals.

The app exposes three tabs: general information, single-text prediction and
batch prediction from an uploaded CSV file. Predictions come from the
``sadickam/sdg-classification-bert`` sequence-classification checkpoint.
"""

import tempfile
import time
from pathlib import Path

import gradio as gr
import nltk
import pandas as pd
import plotly.express as px
import regex as re
import torch
import tqdm
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSequenceClassification

nltk.download('punkt')

# Define the model and tokenizer
checkpoint = "sadickam/sdg-classification-bert"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# Human-readable labels, zipped positionally with the model's output scores.
SDG_LABELS = [
    'GOAL 1: No Poverty',
    'GOAL 2: Zero Hunger',
    'GOAL 3: Good Health and Well-being',
    'GOAL 4: Quality Education',
    'GOAL 5: Gender Equality',
    'GOAL 6: Clean Water and Sanitation',
    'GOAL 7: Affordable and Clean Energy',
    'GOAL 8: Decent Work and Economic Growth',
    'GOAL 9: Industry, Innovation and Infrastructure',
    'GOAL 10: Reduced Inequality',
    'GOAL 11: Sustainable Cities and Communities',
    'GOAL 12: Responsible Consumption and Production',
    'GOAL 13: Climate Action',
    'GOAL 14: Life Below Water',
    'GOAL 15: Life on Land',
    'GOAL 16: Peace, Justice and Strong Institutions',
]

# Name of the CSV column that holds the texts to classify.
TEXT_COLUMN = "text_inputs"

# Description shown on the information tab (text kept exactly as authored).
APP_DESCRIPTION = (
    " This app powered by the sgdBERT model (sadickam/sdg-classification-bert) "
    "is for automatic classification of text with respect to the UN Sustainable "
    "Development Goals (SDG). Note that 16 out of the 17 SDGs labels are covered. "
    "This app is for sustainability assessment and benchmarking and is not "
    "limited to a specific industry. The model powering this app was developed "
    "using the OSDG Community Dataset (OSDG-CD) "
    "[Link - https://zenodo.org/record/5550238#.Y8Sd5f5ByF5]. "
    "This app has two analysis modules summarised below: "
    "- Single-Text-Prediction - Analyses text pasted in a text box and return "
    "SDG prediction. "
    "- Multi-Text-Prediction - Analyses multiple rows of texts in an uploaded "
    "CSV file and returns a downloadable CSV file with SDG prediction for each "
    "row of text. \n"
    "This app runs on a free server and may therefore not be suitable for "
    "analysing large CSV and PDF files. If you need assistance with analysing "
    "large CSV or PDF files, do get in touch using the contact information in "
    "the Contact section.\n"
    "\n"
    "Contact\n"
    "\n"
    "We would be happy to receive your feedback regarding this app. If you "
    "would also like to collaborate with us to explore some use cases for the "
    "model powering this app, we are happy to hear from you.\n"
    "\n"
    "App contact: s.sadick@deakin.edu.au\n"
    "\n"
)


def prep_text(text):
    """Normalise raw text before it is passed to the tokenizer.

    The input is coerced to ``str``, split into sentences, lower-cased and
    has its whitespace collapsed to single spaces between words. Backticks
    and double quotes are then removed.

    Args:
        text: Any object; it is converted with ``str()`` first.

    Returns:
        The cleaned text. An input containing nothing but whitespace,
        backticks or double quotes yields ``""``.
    """
    sentences = sent_tokenize(str(text))
    # str.split() with no argument already strips and collapses whitespace.
    normalised = ' '.join(
        ' '.join(sentence.split()).lower() for sentence in sentences
    )
    # Strip *after* removing quotes so input such as '" "' becomes empty
    # instead of a lone space that would slip past empty-input checks.
    return normalised.replace('`', '').replace('"', '').strip()


def _predict_probabilities(cleaned_text):
    """Return the softmax scores (one float per model class) for a text."""
    encoded = tokenizer(
        cleaned_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True,
    )
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.softmax(logits, dim=1).tolist()[0]


def _rank_predictions(probabilities):
    """Pair scores with ``SDG_LABELS`` and sort them, highest score first."""
    return sorted(
        zip(SDG_LABELS, probabilities),
        key=lambda pair: pair[1],
        reverse=True,
    )


def _style_figure(fig, xaxis_title, yaxis_title):
    """Apply the app's shared layout and font settings to a plotly figure."""
    axis_font = dict(family='Arial', color='black', size=12)
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
    )
    fig.update_xaxes(tickangle=0, tickfont=axis_font)
    fig.update_yaxes(tickangle=0, tickfont=axis_font)
    # Controls the font size of any annotations (e.g. subplot titles).
    fig.update_annotations(font_size=12)
    return fig


# APP INFO
def app_info():
    """Return the short hint displayed in the information tab's output box."""
    check = (
        ' Please go to either the "Single-Text-Prediction" or '
        '"Multi-Text-Prediction" tab to analyse your text. '
    )
    return check


# Create Gradio interface for the general information tab
iface1 = gr.Interface(
    fn=app_info,
    inputs=None,
    outputs=['text'],
    title="General-Information",
    description=APP_DESCRIPTION,
)


# SINGLE TEXT
def predict_sdg(text):
    """Classify one piece of text and chart the score of every SDG label.

    Args:
        text: Text typed or pasted by the user.

    Returns:
        A tuple ``(top, fig)`` where ``top`` is a one-item dict mapping the
        best label to its score rounded to 3 decimals, and ``fig`` is a
        horizontal plotly bar chart of all label scores.

    Raises:
        gr.Error: If the text is missing or empty after cleaning.
    """
    # Guard against None so it is not classified as the literal word "none".
    cleaned_text = prep_text(text) if text is not None else ""
    if not cleaned_text:
        raise gr.Error('This model needs some text input to return a prediction')

    ranked = _rank_predictions(_predict_probabilities(cleaned_text))

    # Dataframe for the plotly bar chart, already sorted by likelihood.
    scores = pd.DataFrame(ranked, columns=['SDG', 'Likelihood'])
    fig = px.bar(scores, x="Likelihood", y="SDG", orientation="h")
    _style_figure(
        fig,
        xaxis_title="Likelihood of SDG",
        yaxis_title="Sustainable development goals (SDG)",
    )

    top_label, top_score = ranked[0]
    return {top_label: round(top_score, 3)}, fig


# Create Gradio interface for single text
iface2 = gr.Interface(
    fn=predict_sdg,
    inputs=gr.Textbox(lines=7, label="Paste or type text here"),
    outputs=[
        gr.Label(label="Top SDG Predicted", show_label=True),
        gr.Plot(label="Likelihood of all SDG", show_label=True),
    ],
    title="Single Text Prediction",
    article="**Note:** The quality of model predictions may depend on the quality of information provided.",
)


# UPLOAD CSV
def predict_sdg_from_csv(file, progress=gr.Progress()):
    """Classify every row of an uploaded CSV and summarise the results.

    Args:
        file: The uploaded CSV, either a filesystem path or an object with a
            ``name`` attribute holding the path. It must contain a
            ``text_inputs`` column.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            how Gradio is told to inject a tracker; it is not mutated here.

    Returns:
        A tuple ``(fig, output_csv)``: a histogram of the predicted labels
        and a ``gr.File`` pointing at a CSV holding the input rows plus the
        ``SDG_predicted`` and ``prediction_score`` columns.

    Raises:
        gr.Error: If no file was uploaded or the required column is missing.
    """
    if file is None:
        raise gr.Error('Please upload a CSV file to analyse')

    # Accept both a plain path and a tempfile-like wrapper exposing `.name`.
    csv_path = getattr(file, "name", file)
    df_docs = pd.read_csv(csv_path)
    if TEXT_COLUMN not in df_docs.columns:
        raise gr.Error(
            f'The uploaded CSV file must contain a column titled "{TEXT_COLUMN}"'
        )
    text_list = df_docs[TEXT_COLUMN].tolist()

    # Lists for appending predictions
    predicted_labels = []
    prediction_score = []

    # The model runs locally, so no artificial delay is needed between rows.
    for text_input in progress.tqdm(text_list, desc="Analysing data"):
        ranked = _rank_predictions(_predict_probabilities(prep_text(text_input)))
        best_label, best_score = ranked[0]
        predicted_labels.append(best_label)
        prediction_score.append(best_score)

    # Append predictions to the DataFrame
    df_docs['SDG_predicted'] = predicted_labels
    df_docs['prediction_score'] = prediction_score

    # Write into a fresh temporary directory so that concurrent requests do
    # not overwrite each other's results; the download name is unchanged.
    output_dir = Path(tempfile.mkdtemp(prefix="sdg_predictions_"))
    output_path = output_dir / 'sdg_predictions.csv'
    df_docs.to_csv(output_path)
    output_csv = gr.File(value=str(output_path), visible=True)

    # Create the histogram
    fig = px.histogram(df_docs, y="SDG_predicted")
    _style_figure(
        fig,
        xaxis_title="SDG counts",
        yaxis_title="Sustainable development goals (SDG)",
    )

    return fig, output_csv


# Define the input component
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])

# Create the Gradio interface
iface3 = gr.Interface(
    fn=predict_sdg_from_csv,
    inputs=file_input,
    outputs=[
        gr.Plot(label='Frequency of SDGs', show_label=True),
        gr.File(label='Download output CSV', show_label=True),
    ],
    title="Multi-text Prediction (CSV)",
    description='**NOTE:** The column to be analysed must be titled ***text_inputs***',
)

demo = gr.TabbedInterface(
    interface_list=[iface1, iface2, iface3],
    tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
    title="Sustainable Development Goals (SDG) Text Classifier App",
    theme='soft',
)

# Run the interface
demo.queue().launch()