hongaik commited on
Commit
deb200f
β€’
1 Parent(s): a5995da
w2v_ovr_svc.sav β†’ models/w2v_ovr_svc.sav RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit==1.4.0
+ # NOTE: 're' and 'pickle' are Python standard-library modules — removed
+ # (pip cannot install them; 'pip install re==2.2.1' fails outright)
+ gensim==4.1.2
+ transformers==4.16.1
+ pandas
+ numpy
+ scikit-learn
text_class_app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import utils
3
+
4
+ ########## Title for the Web App ##########
5
+ st.title("Text Classification for Service Feedback")
6
+
7
+ ########## Create Input field ##########
8
+ feedback = st.text_input('Type your text here', 'The staff were extremely polite and helpful!')
9
+
10
+ if st.button('Click for predictions!'):
11
+ with st.spinner('Generating predictions...'):
12
+
13
+ result = get_single_prediction(feedback)
14
+
15
+ st.success(f'Your text has been predicted to fall under the following labels: {result[:-1]}. This text is {result[-1]}.')
16
+
17
+ st.text('Or... Upload a csv file if you have many texts')
18
+
19
+ uploaded_file = st.file_uploader("Please upload a csv file with only 1 column of texts.")
20
+
21
+ if uploaded_file is not None:
22
+
23
+ with st.spinner('Generating predictions...'):
24
+ results = get_multiple_predictions(uploaded_file)
25
+
26
+ st.download_button(
27
+ label="Download results as CSV",
28
+ data=results,
29
+ file_name='results.csv',
30
+ mime='text/csv',
31
+ )
32
+
33
+
utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from gensim.models.keyedvectors import KeyedVectors
3
+ from transformers import pipeline
4
+ import pickle
5
+
6
+ w2v = KeyedVectors.load('models/word2vec')
7
+ w2v_vocab = set(sorted(w2v.index_to_key))
8
+ model = pickle.load(open('models/w2v_ovr_svc.sav', 'rb'))
9
+ classifier = pipeline("zero-shot-classification",
10
+ model="facebook/bart-large-mnli", device=0, framework='pt'
11
+ )
12
+
13
+ labels = [
14
+ 'communication', 'waiting time',
15
+ 'information', 'user interface',
16
+ 'facilities', 'location', 'price'
17
+ ]
18
+
19
+ def get_sentiment_label_facebook(list_of_sent_dicts):
20
+ if list_of_sent_dicts['labels'][0] == 'negative':
21
+ return 'negative'
22
+ else:
23
+ return 'positive'
24
+
25
+ def get_single_prediction(text):
26
+
27
+ # manipulate data into a format that we pass to our model
28
+ text = text.lower() #lower case
29
+ text = re.sub('[^0-9a-zA-Z\s]', '', text) #remove special char, punctuation
30
+
31
+ # Remove OOV words
32
+ text = ' '.join([i for i in text.split() if i in w2v_vocab])
33
+
34
+ # Vectorise text and store in new dataframe. Sentence vector = average of word vectors
35
+ text_vectors = np.mean([w2v[i] for i in text.split()], axis=0)
36
+
37
+ # Make predictions
38
+ results = model.predict(text_vectors)
39
+
40
+ # Get sentiment
41
+ sentiment = get_sentiment_label_facebook(classifier(text,
42
+ candidate_labels=['positive', 'negative'],
43
+ hypothesis_template='The sentiment of this is {}'))
44
+
45
+ # Consolidate results
46
+ pred_labels = [labels[idx] for idx, tag in enumerate(results) if tag == 1]
47
+ pred_labels.append(sentiment)
48
+
49
+ return pred_labels
50
+
51
+ def get_multiple_predictions(csv):
52
+
53
+ df = pd.read_csv(csv)
54
+ df.columns = ['sequence']
55
+
56
+ df['sequence'] = df['sequence'].str.lower() #lower case
57
+ df['sequence'] = df['sequence'].str.replace('[^0-9a-zA-Z\s]','') #remove special char, punctuation
58
+
59
+ # Remove OOV words
60
+ df['sequence'] = df['sequence'].apply(lambda x: ' '.join([i for i in x.split() if i in w2v_vocab]))
61
+
62
+ # Remove rows with blank string
63
+ invalid = df[(pd.isna(df['sequence'])) | (df['sequence'] == '')]
64
+
65
+ df.dropna(inplace=True)
66
+ df = df[df['sequence'] != ''].reset_index(drop=True)
67
+
68
+ # Vectorise text and store in new dataframe. Sentence vector = average of word vectors
69
+ series_text_vectors = pd.DataFrame(df['sequence'].apply(lambda x: np.mean([w2v[i] for i in x.split()], axis=0)).values.tolist())
70
+
71
+ # Get predictions
72
+ pred_results = pd.DataFrame(model.predict(series_text_vectors), columns = labels)
73
+
74
+ # Join back to original sequence
75
+ final_results = df.join(series_text_vectors)
76
+
77
+ # Get sentiment labels
78
+ final_results['sentiment'] = final_results['sequence'].apply(lambda x: get_sentiment_label_facebook(classifier(x,
79
+ candidate_labels=['positive', 'negative'],
80
+ hypothesis_template='The sentiment of this is {}'))
81
+ )
82
+
83
+ # Append invalid rows
84
+ if len(invalid) == 0:
85
+ return final_results
86
+ else:
87
+ return pd.concat([final_results, invalid]).reset_index(drop=True)