sadickam committed on
Commit
bd3703e
•
1 Parent(s): 9e957a5

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +249 -0
  2. requirements.txt +6 -0
main.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ # import inflect
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+ import string
6
+ import plotly.express as px
7
+ import pandas as pd
8
+ import nltk
9
+ from nltk.tokenize import sent_tokenize
10
# Fetch the sentence-tokenizer data that ``sent_tokenize`` needs. Recent NLTK
# releases look up the "punkt_tab" resource rather than the older pickled
# "punkt" models, and the NLTK version is not pinned for this app, so both
# are requested.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)

# Note - USE "VBA_venv" environment in the local github folder

# ``string.punctuation`` is a single string of ASCII punctuation characters;
# ``prep_text`` tests tokens for membership in it.
punctuations = string.punctuation
15
+
16
def prep_text(text):
    """Normalise free text into a single lowercase, space-separated string.

    ``text`` is coerced to ``str`` and split into sentences with
    ``sent_tokenize``. Each sentence is split on whitespace, every token is
    lowercased, and tokens found in ``punctuations`` are dropped. Because
    ``punctuations`` is a string, that membership check is a substring test.
    The remaining tokens are joined with single spaces, the sentences are
    joined with single spaces, and leading/trailing spaces are removed.
    """

    def _clean_sentence(sentence):
        # ``split()`` already yields whitespace-free tokens, so only the
        # case needs normalising before the punctuation filter.
        lowered = (token.lower() for token in sentence.split())
        return ' '.join(token for token in lowered if token not in punctuations)

    cleaned_sentences = map(_clean_sentence, sent_tokenize(str(text)))
    return ' '.join(cleaned_sentences).strip(' ')
28
+
29
+
30
# Model names (or paths to saved models) handed to ``from_pretrained``:
# ``checkpoint_1`` feeds the first prediction (SubCatName) and
# ``checkpoint_2`` the second prediction (ExtraOver).
checkpoint_1, checkpoint_2 = "Highway/SubCat", "Highway/ExtraOver"
34
+
35
+
36
@st.cache(allow_output_mutation=True)
def load_model_1():
    """Return the sequence-classification model loaded from ``checkpoint_1``.

    The result is cached through ``st.cache`` with output mutation allowed.
    NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour
    of ``st.cache_resource`` - confirm the deployed version before migrating.
    """
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_1)
    return model
39
+
40
+
41
@st.cache(allow_output_mutation=True)
def load_tokenizer_1():
    """Return the tokenizer loaded from ``checkpoint_1``.

    The result is cached through ``st.cache`` with output mutation allowed.
    NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour
    of ``st.cache_resource`` - confirm the deployed version before migrating.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_1)
    return tokenizer
44
+
45
+
46
@st.cache(allow_output_mutation=True)
def load_model_2():
    """Return the sequence-classification model loaded from ``checkpoint_2``.

    The result is cached through ``st.cache`` with output mutation allowed.
    NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour
    of ``st.cache_resource`` - confirm the deployed version before migrating.
    """
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_2)
    return model
49
+
50
+
51
@st.cache(allow_output_mutation=True)
def load_tokenizer_2():
    """Return the tokenizer loaded from ``checkpoint_2``.

    The result is cached through ``st.cache`` with output mutation allowed.
    NOTE(review): newer Streamlit releases deprecate ``st.cache`` in favour
    of ``st.cache_resource`` - confirm the deployed version before migrating.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_2)
    return tokenizer
54
+
55
+
56
# Browser-tab title/icon and overall page layout. The icon and the title
# emoji were mis-decoded byte sequences in the source; they are restored to
# the emoji those bytes encode.
st.set_page_config(
    page_title="Cost Data Classifier",
    layout="wide",
    initial_sidebar_state="auto",
    page_icon="💷",
)

st.title("🚦 AI Infrastructure Cost Data Classifier")

# Collapsed-by-default panel describing what the app does.
with st.expander("About this app", expanded=False):
    st.write(
        """
        - Artificial Intelligence (AI) and Machine learning (ML) tool for automatic classification of infrastructure cost data for benchmarking
        - Classifies cost descriptions from documents such as Bills of Quantities (BOQs) and Schedule of Rates
        - Can be trained to classify granular and itemised cost descriptions into any predefined categories for benchmarking
        - Contact research team to discuss your data structures and suitability for the app
        - It is best to use this app on a laptop or desktop computer
        """
    )
73
+
74
+
75
st.markdown("##### Description")

# Input form: ``Text_entry`` holds the description typed by the user and
# ``submitted`` is the value returned by the form's submit button. The button
# emoji was a mis-decoded byte sequence in the source and is restored here.
with st.form(key="my_form"):
    Text_entry = st.text_area(
        "Paste or type infrastructure cost description in the text box below (i.e., input)"
    )
    submitted = st.form_submit_button(label="👉 Get SubCat and ExtraOver!")
81
+
82
def _rank_labels(text, tokenizer, model, labels):
    """Classify *text* and return ``(label, likelihood)`` pairs, best first.

    Args:
        text: Pre-processed description to classify.
        tokenizer: Tokenizer called on ``text`` to build PyTorch tensors.
        model: Classifier whose output exposes ``.logits``.
        labels: Class names paired positionally with the model's outputs
            (assumed to follow the model's class order - confirm against
            the training configuration).

    Returns:
        A list of ``(label, likelihood)`` tuples sorted by likelihood in
        descending order; likelihoods are softmax values rounded to three
        decimal places.
    """
    # truncation=True keeps an over-long description within the tokenizer's
    # configured maximum length instead of overflowing the model input.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(**encoded).logits
    likelihoods = [round(p, 3) for p in torch.softmax(logits, dim=1).tolist()[0]]
    return sorted(zip(labels, likelihoods), key=lambda pair: pair[1], reverse=True)


def _render_prediction(ranked, *, column, header, xaxis_title, yaxis_title,
                       height, metric_label, success_message):
    """Draw the likelihood bar chart and headline metrics for one classifier.

    Args:
        ranked: ``(label, likelihood)`` pairs sorted best first.
        column: Name of the label column in the chart's dataframe.
        header: Heading shown above the chart.
        xaxis_title: Title of the chart's x axis.
        yaxis_title: Title of the chart's y axis.
        height: Chart height in pixels.
        metric_label: Caption of the metric showing the top label.
        success_message: Text of the success banner.
    """
    names, likelihoods = zip(*ranked)
    frame = pd.DataFrame({column: list(names), "Likelihood": list(likelihoods)})
    best_label, best_likelihood = ranked[0]

    # The narrow middle column only acts as a spacer.
    chart_col, _spacer, metric_col = st.columns([1.5, 0.5, 1])

    with chart_col:
        st.header(header)
        fig = px.bar(frame, x="Likelihood", y=column, orientation="h")
        fig.update_layout(
            template='ggplot2',
            font=dict(family="Arial", size=14, color="black"),
            autosize=False,
            width=800,
            height=height,
            xaxis_title=xaxis_title,
            yaxis_title=yaxis_title,
        )
        tick_font = dict(family='Arial', color='black', size=14)
        fig.update_xaxes(tickangle=0, tickfont=tick_font)
        fig.update_yaxes(tickangle=0, tickfont=tick_font)
        # this changes y_axis, x_axis and subplot title font sizes
        fig.update_annotations(font_size=14)
        st.plotly_chart(fig, use_container_width=False)

    with metric_col:
        st.header("")
        st.metric(metric_label, best_label)
        st.metric("Prediction confidence", f"{round(best_likelihood * 100, 1)}%")
        # NOTE(review): indentation was lost in the source; the banner is
        # assumed to sit in the metrics column - confirm the intended layout.
        st.success(success_message, icon="✅")


if submitted:
    # Class names for the first (SubCatName) classifier.
    label_list_1 = [
        'Arrow, Triangle, Circle, Letter, Numeral, Symbol and Sundries',
        'Binder',
        'Cable',
        'Catman Other Adjustment',
        'Cold Milling',
        'Disposal of Acceptable/Unacceptable Material',
        'Drain/Service Duct In Trench',
        'Erection & Dismantling of Temporary Accommodation/Facilities (All Types)',
        'Excavate And Replace Filter Material/Recycle Filter Material',
        'Excavation',
        'General TM Item',
        'Information boards',
        'Joint/Termination',
        'Line, Ancillary Line, Solid Area',
        'Loop Detector Installation',
        'Minimum Lining Visit Charge',
        'Node Marker',
        'PCC Kerb',
        'Provision of Mobile Welfare Facilities',
        'Removal of Deformable Safety Fence',
        'Removal of Line, Ancillary Line, Solid Area',
        'Removal of Traffic Sign and post(s)',
        'Road Stud',
        'Safety Barrier Or Bifurcation (Non-Concrete)',
        'Servicing of Temporary Accommodation/Facilities (All Types) (day)',
        'Tack Coat',
        'Temporary Road Markings',
        'Thin Surface Course',
        'Traffic Sign - Unknown specification',
        'Vegetation Clearance/Weed Control (m2)',
        'Others',
    ]
    # Class names for the second (ExtraOver) classifier.
    label_list_2 = ["False", "True"]

    # Both classifiers read the same cleaned text, so pre-process it once.
    joined_clean_sents = prep_text(Text_entry)

    if not joined_clean_sents:
        # Nothing usable was entered; do not present a prediction for it.
        st.warning("Please enter a cost description before requesting a prediction.")
    else:
        # First prediction
        sorted_preds_1 = _rank_labels(
            joined_clean_sents, load_tokenizer_1(), load_model_1(), label_list_1
        )
        _render_prediction(
            sorted_preds_1,
            column='SubCatName',
            header="SubCatName",
            xaxis_title="Likelihood of SubCatName",
            yaxis_title="SubCatNames",
            height=500,
            metric_label="Predicted SubCatName",
            success_message="Great! SubCatName successfully predicted. ",
        )

        # Second prediction
        sorted_preds_2 = _rank_labels(
            joined_clean_sents, load_tokenizer_2(), load_model_2(), label_list_2
        )
        _render_prediction(
            sorted_preds_2,
            column='ExtraOver',
            header="ExtraOver",
            xaxis_title="Likelihood of ExtraOver",
            yaxis_title="ExtraOver",
            height=200,
            metric_label="Predicted ExtraOver",
            success_message="Great! ExtraOver successfully predicted. ",
        )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ plotly
4
+ pandas
5
+ nltk
6
+ streamlit