strongpear committed on
Commit
c2b0a49
1 Parent(s): 175db61

update infer

Files changed (1)
  1. app.py +269 -1
app.py CHANGED
@@ -1,3 +1,271 @@
  import streamlit as st
 
- st.title('Hello World!')
+ import numpy as np
+
+ import torch
+
+ import argparse
+ import os
+ import re
+
+ from data_preprocessing import remove_xem_them, remove_emojis, remove_stopwords, format_punctuation, remove_punctuation, clean_text, normalize_format, word_segment, format_price, format_price_v2
+
+
+ class inferSSCL():
+     def __init__(self, args=None):
+         self.args = args
+         self.base_models = {}
+         self.batch_data = {}
+         self.test_data = []
+
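+     # load_vocab_pretrain assumes each line of the vocab file is
+     # "<index> <word>" (only the second field is kept) and that the rows
+     # of the saved vector file are aligned with that order.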
+     def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
+         # Reserve index 0 for the <pad> token.
+         vocab2id = {'<pad>': 0}
+         id2vocab = {0: '<pad>'}
+
+         cnt = len(id2vocab)
+         with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
+             for line in fp:
+                 arr = re.split(' ', line[:-1])
+                 vocab2id[arr[1]] = cnt
+                 id2vocab[cnt] = arr[1]
+                 cnt += 1
+
+         # Word embeddings: prepend a zero vector for <pad>.
+         pretrain_vec = np.load(file_pretrain_vec)
+         pad_vec = np.zeros([1, pretrain_vec.shape[1]])
+         pretrain_vec = np.vstack((pad_vec, pretrain_vec))
+
+         return vocab2id, id2vocab, pretrain_vec
+
+     def load_vocabulary(self):
+         cluster_dir = './'
+         file_wordvec = 'vectors.npy'
+         file_vocab = 'vocab.txt'
+         file_kmeans_centroid = 'aspect_centroid.txt'
+         file_aspect_mapping = 'aspect_mapping.txt'
+
+         vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
+             os.path.join(cluster_dir, file_vocab),
+             os.path.join(cluster_dir, file_wordvec))
+         vocab_size = len(vocab2id)
+
+         self.batch_data['vocab2id'] = vocab2id
+         self.batch_data['id2vocab'] = id2vocab
+         self.batch_data['pretrain_emb'] = pretrain_vec
+         self.batch_data['vocab_size'] = vocab_size
+
+         # K-means centroids of the aspect clusters.
+         aspect_vec = np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)
+
+         # Zero out centroids whose cluster is mapped to "none" in aspect_mapping.txt.
+         tmp = []
+         with open(os.path.join(cluster_dir, file_aspect_mapping), 'r') as fp:
+             for line in fp:
+                 line = re.sub(r'[0-9]+', '', line)
+                 line = line.replace(' ', '').replace('\n', '')
+                 if line == 'none':
+                     tmp.append([0.] * 256)
+                 else:
+                     tmp.append([1.] * 256)
+
+         aspect_vec = aspect_vec * tmp
+         aspect_vec = torch.FloatTensor(aspect_vec).to(device)
+         self.batch_data['aspect_centroid'] = aspect_vec
+         self.batch_data['n_aspects'] = aspect_vec.shape[0]
+
+     def load_models(self):
+         # Embedding layer initialized from the pretrained vectors.
+         self.base_models['embedding'] = torch.nn.Embedding(self.batch_data['vocab_size'], emb_size).to(device)
+         emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
+         self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)
+
+         self.base_models['asp_weight'] = torch.nn.Linear(emb_size, self.batch_data['n_aspects']).to(device)
+         self.base_models['asp_weight'].load_state_dict(torch.load('./asp_weight.model', map_location=device))
+
+         self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
+         self.base_models['attn_kernel'].load_state_dict(torch.load('./attn_kernel.model', map_location=device), strict=False)
+
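+     # build_pipe mirrors calculate_atten_weight but also extracts the
+     # rounded per-token attention weights for the current comment.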
+     def build_pipe(self):
+         attn_pos, lbl_pos = self.encoder(
+             self.batch_data['pos_sen_var'],
+             self.batch_data['pos_pad_mask'])
+
+         outw = np.around(attn_pos.data.cpu().numpy(), 4).tolist()
+         outw = outw[0][:len(self.batch_data['comment'][0].split())]
+
+         asp_weight = self.base_models['asp_weight'](lbl_pos)
+         # Attention weight over aspects.
+         asp_weight = torch.softmax(asp_weight, dim=1)
+
+         return asp_weight
+
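+     # encoder implements the attention step: the mean word embedding serves
+     # as the query, each token is scored against it through the attn_kernel
+     # projection, and the attention-weighted sum gives the sentence vector.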
+     def encoder(self, input_, mask_):
+         with torch.no_grad():
+             emb_ = self.base_models['embedding'](input_)
+
+         # Zero out embeddings at <pad> positions.
+         emb_ = emb_ * mask_.unsqueeze(2)
+
+         # Query vector: average of the word embeddings.
+         emb_avg = torch.sum(emb_, dim=1)
+         norm = torch.sum(mask_, dim=1, keepdim=True) + 1e-20
+         enc_ = emb_avg.div(norm.expand_as(emb_avg))
+
+         # W_e E_x + b_e
+         emb_trn = self.base_models['attn_kernel'](emb_)
+
+         # query vector * (W_e E_x + b_e)
+         attn_ = enc_.unsqueeze(1) @ emb_trn.transpose(1, 2)
+         attn_ = attn_.squeeze(1)
+
+         # Alignment score; padding positions are masked out before softmax.
+         attn_ = self.args.smooth_factor * torch.tanh(attn_)
+         attn_ = attn_.masked_fill(mask_ == 0, -1e20)
+
+         # Attention weights.
+         attn_ = torch.softmax(attn_, dim=1)
+
+         # s x E: attention-weighted sentence embedding.
+         lbl_ = attn_.unsqueeze(1) @ emb_
+         lbl_ = lbl_.squeeze(1)
+
+         return attn_, lbl_
+
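+     # build_batch turns one review into id tensors: tokens are looked up in
+     # vocab2id, the sequence is truncated or padded to sen_text_len, and a
+     # 0/1 mask separates real tokens from <pad>.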
+     def build_batch(self, review):
+         vocab2id = self.batch_data['vocab2id']
+
+         sen_text = []
+         cmt = []
+         # Cap the sequence length at emb_size tokens.
+         sen_text_len = emb_size
+
+         senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
+         sen_text.append(senid)
+         cmt.append(review)
+
+         sen_text_len = min(len(senid), sen_text_len)
+         sen_text = [itm[:sen_text_len] + [vocab2id['<pad>'] for _ in range(sen_text_len - len(itm))] for itm in sen_text]
+
+         sen_text_var = torch.LongTensor(sen_text).to(device)
+         # Non-pad ids -> -1, pad -> 0, then negate to get the 0/1 mask.
+         sen_pad_mask = torch.LongTensor(sen_text).to(device)
+         sen_pad_mask[sen_pad_mask != vocab2id['<pad>']] = -1
+         sen_pad_mask[sen_pad_mask == vocab2id['<pad>']] = 0
+         sen_pad_mask = -sen_pad_mask
+
+         self.batch_data['comment'] = cmt
+         self.batch_data['pos_sen_var'] = sen_text_var
+         self.batch_data['pos_pad_mask'] = sen_pad_mask
+
+     def calculate_atten_weight(self):
+         attn_pos, lbl_pos = self.encoder(
+             self.batch_data['pos_sen_var'],
+             self.batch_data['pos_pad_mask'])
+
+         # Map the sentence embedding to softmax-normalized aspect weights.
+         asp_weight = self.base_models['asp_weight'](lbl_pos)
+         asp_weight = torch.softmax(asp_weight, dim=1)
+
+         return asp_weight
+
+     def get_test_data(self):
+         asp_weight = self.calculate_atten_weight()
+         asp_weight = asp_weight.data.cpu().numpy().tolist()
+
+         output = {}
+         output['comment'] = self.batch_data['comment']
+         output['aspect_weight'] = asp_weight[0]
+         self.test_data.append(output)
+
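+     # select_top scores each aspect weight by its distance from the median,
+     # normalized by the median absolute deviation (MAD); large scores mark
+     # aspects that stand out from the rest of the distribution.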
+     def select_top(self, data):
+         d = np.abs(data - np.median(data))
+         mdev = np.median(d)
+         # With a zero MAD, return all-zero scores so the caller can still
+         # index the result.
+         s = d / mdev if mdev else np.zeros_like(d)
+
+         return s
+
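+     # get_predict marks an aspect as present when its MAD score exceeds the
+     # threshold; labels outside the five known aspects are skipped.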
+     def get_predict(self, top_pred, aspect_label, threshold=1):
+         pred = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}
+         for i in range(len(top_pred)):
+             if top_pred[i] > threshold and aspect_label[i] in pred:
+                 pred[aspect_label[i]] = 1
+         return pred
+
+     def get_evaluate_result(self, input_):
+         # One aspect name per cluster, read from aspect_mapping.txt.
+         aspect_label = []
+         with open('./aspect_mapping.txt', 'r', encoding='utf8') as fp:
+             for line in fp:
+                 aspect_label.append(line.split()[1])
+
+         top_score = self.select_top(input_['aspect_weight'])
+         curr_pred = self.get_predict(top_score, aspect_label)
+
+         aspect_key = [key for key, value in curr_pred.items() if int(value) == 1]
+
+         return self.get_aspect(aspect_key)
+
+     def get_aspect(self, pred, ignore='none'):
+         # Drop the leading entry (the ignored 'none' slot); with nothing
+         # else predicted, report ['None'].
+         if len(pred) > 1:
+             return pred[1:]
+         else:
+             return ['None']
+
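+     # infer runs the full pipeline: Vietnamese review preprocessing, batch
+     # construction, the attention encoder, and MAD-based aspect selection.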
+     def infer(self, text=''):
+         self.args.task = 'sscl-infer'
+
+         text = remove_xem_them(text)
+         text = remove_emojis(text)
+         text = format_punctuation(text)
+         text = remove_punctuation(text)
+         text = clean_text(text)
+         text = normalize_format(text)
+         text = word_segment(text)
+         text = remove_stopwords(text)
+         text = format_price(text)
+         input_ = format_price_v2(text)
+
+         self.load_vocabulary()
+         self.load_models()
+
+         self.build_batch(input_)
+         self.get_test_data()
+
+         val_result = self.test_data
+
+         # Return the predicted aspects so the app can display them.
+         return self.get_evaluate_result(val_result[0])
+
+
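+ # The parser is invoked with an empty argument list because Streamlit does
+ # not forward CLI flags; only the defaults declared here are used.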
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', default='infer')
+ parser.add_argument('--smooth_factor', type=float, default=0.9)
+
+ # Fall back to CPU when no GPU is available.
+ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+ emb_size = 256
+
+ args = parser.parse_args(args=[])
+ model = inferSSCL(args)
+
+ cmt = st.text_area('Enter some text: ')
+ if cmt:
+     output = model.infer(cmt)
+     if output:
+         st.title(output)