sadickam commited on
Commit
cb806bb
1 Parent(s): dbde8c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -69
app.py CHANGED
@@ -62,80 +62,77 @@ iface1 = gr.Interface(
62
  def predict_sdg(text):
63
  # Preprocess the input text
64
  cleaned_text = prep_text(text)
65
- # Tokenize the preprocessed text
66
- tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
67
- # Predict
68
- text_logits = model(**tokenized_text).logits
69
- predictions = torch.softmax(text_logits, dim=1).tolist()[0]
70
- # SDG labels
71
- label_list = [
72
- 'GOAL 1: No Poverty',
73
- 'GOAL 2: Zero Hunger',
74
- 'GOAL 3: Good Health and Well-being',
75
- 'GOAL 4: Quality Education',
76
- 'GOAL 5: Gender Equality',
77
- 'GOAL 6: Clean Water and Sanitation',
78
- 'GOAL 7: Affordable and Clean Energy',
79
- 'GOAL 8: Decent Work and Economic Growth',
80
- 'GOAL 9: Industry, Innovation and Infrastructure',
81
- 'GOAL 10: Reduced Inequality',
82
- 'GOAL 11: Sustainable Cities and Communities',
83
- 'GOAL 12: Responsible Consumption and Production',
84
- 'GOAL 13: Climate Action',
85
- 'GOAL 14: Life Below Water',
86
- 'GOAL 15: Life on Land',
87
- 'GOAL 16: Peace, Justice and Strong Institutions'
88
- ]
89
- # dictionary with label as key and percentage as value
90
- pred_dict = dict(zip(label_list, predictions))
91
-
92
- # sort 'pred_dict' by value and index the highest at [0]
93
- sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)
94
-
95
- # Make dataframe for plotly bar chart
96
- u, v = zip(*sorted_preds)
97
- m = list(u)
98
- n = list(v)
99
- df2 = pd.DataFrame()
100
- df2['SDG'] = m
101
- df2['Likelihood'] = n
102
-
103
- # plot graph of predictions
104
- fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
105
-
106
- fig.update_layout(
107
- # barmode='stack',
108
- template='seaborn', font=dict(family="Arial", size=12, color="black"),
109
- autosize=True,
110
- #width=800,
111
- #height=500,
112
- xaxis_title="Likelihood of SDG",
113
- yaxis_title="Sustainable development goals (SDG)",
114
- # legend_title="Topics"
115
- )
116
-
117
- fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
118
- fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
119
- fig.update_annotations(font_size=12) # this changes y_axis, x_axis and subplot title font sizes
120
-
121
- # Make dataframe for plotly bar chart
122
- #df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])
123
-
124
- # Return the top prediction
125
- top_prediction = sorted_preds[0]
 
 
 
126
 
127
  # Return result
128
  return {top_prediction[0]: round(top_prediction[1], 3)}, fig
129
 
130
- # Define input and warning
131
- if gr.Textbox(lines=7, label="Paste or type text here") != '':
132
- single_text = gr.Textbox(lines=7, label="Paste or type text here")
133
- elif r.Textbox(lines=7, label="Paste or type text here") == '':
134
- single_text = gr.Warning('This model need some text to return a prediction')
135
-
136
  # Create Gradio interface for single text
137
  iface2 = gr.Interface(fn=predict_sdg,
138
- inputs=single_text,
139
  outputs=[gr.Label(label="Top SDG Predicted", show_label=True), gr.Plot(label="Likelihood of all SDG", show_label=True)],
140
  title="Single Text Prediction")
141
 
@@ -219,4 +216,4 @@ iface3 = gr.Interface(fn=predict_sdg_from_csv,
219
  demo = gr.TabbedInterface([iface1, iface2, iface3], ["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"])
220
 
221
  # Run the interface
222
- demo.queue().launch()
 
62
  def predict_sdg(text):
63
  # Preprocess the input text
64
  cleaned_text = prep_text(text)
65
+ if cleaned_text == "":
66
+ raise gr.Error('This model needs some text input to return a prediction')
67
+ elif cleaned_text != ""
68
+ # Tokenize the preprocessed text
69
+ tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True)
70
+ # Predict
71
+ text_logits = model(**tokenized_text).logits
72
+ predictions = torch.softmax(text_logits, dim=1).tolist()[0]
73
+ # SDG labels
74
+ label_list = [
75
+ 'GOAL 1: No Poverty',
76
+ 'GOAL 2: Zero Hunger',
77
+ 'GOAL 3: Good Health and Well-being',
78
+ 'GOAL 4: Quality Education',
79
+ 'GOAL 5: Gender Equality',
80
+ 'GOAL 6: Clean Water and Sanitation',
81
+ 'GOAL 7: Affordable and Clean Energy',
82
+ 'GOAL 8: Decent Work and Economic Growth',
83
+ 'GOAL 9: Industry, Innovation and Infrastructure',
84
+ 'GOAL 10: Reduced Inequality',
85
+ 'GOAL 11: Sustainable Cities and Communities',
86
+ 'GOAL 12: Responsible Consumption and Production',
87
+ 'GOAL 13: Climate Action',
88
+ 'GOAL 14: Life Below Water',
89
+ 'GOAL 15: Life on Land',
90
+ 'GOAL 16: Peace, Justice and Strong Institutions'
91
+ ]
92
+ # dictionary with label as key and percentage as value
93
+ pred_dict = dict(zip(label_list, predictions))
94
+
95
+ # sort 'pred_dict' by value and index the highest at [0]
96
+ sorted_preds = sorted(pred_dict.items(), key=lambda x: x[1], reverse=True)
97
+
98
+ # Make dataframe for plotly bar chart
99
+ u, v = zip(*sorted_preds)
100
+ m = list(u)
101
+ n = list(v)
102
+ df2 = pd.DataFrame()
103
+ df2['SDG'] = m
104
+ df2['Likelihood'] = n
105
+
106
+ # plot graph of predictions
107
+ fig = px.bar(df2, x="Likelihood", y="SDG", orientation="h")
108
+
109
+ fig.update_layout(
110
+ # barmode='stack',
111
+ template='seaborn', font=dict(family="Arial", size=12, color="black"),
112
+ autosize=True,
113
+ #width=800,
114
+ #height=500,
115
+ xaxis_title="Likelihood of SDG",
116
+ yaxis_title="Sustainable development goals (SDG)",
117
+ # legend_title="Topics"
118
+ )
119
+
120
+ fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
121
+ fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
122
+ fig.update_annotations(font_size=12) # this changes y_axis, x_axis and subplot title font sizes
123
+
124
+ # Make dataframe for plotly bar chart
125
+ #df2 = pd.DataFrame(sorted_preds, columns=['SDG', 'Likelihood'])
126
+
127
+ # Return the top prediction
128
+ top_prediction = sorted_preds[0]
129
 
130
  # Return result
131
  return {top_prediction[0]: round(top_prediction[1], 3)}, fig
132
 
 
 
 
 
 
 
133
  # Create Gradio interface for single text
134
  iface2 = gr.Interface(fn=predict_sdg,
135
+ inputs=gr.Textbox(lines=7, label="Paste or type text here"),
136
  outputs=[gr.Label(label="Top SDG Predicted", show_label=True), gr.Plot(label="Likelihood of all SDG", show_label=True)],
137
  title="Single Text Prediction")
138
 
 
216
  demo = gr.TabbedInterface([iface1, iface2, iface3], ["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"])
217
 
218
  # Run the interface
219
+ demo.launch()