Johnathan committed on
Commit
c7186c4
1 Parent(s): 55c6df8

add app.py file

Browse files
Files changed (1) hide show
  1. app.py +1374 -0
app.py ADDED
@@ -0,0 +1,1374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datetime import datetime
3
+ from urllib.parse import ParseResultBytes
4
+ from file_setting import leaf_idf_dict, leaf_IDF_dict, leafConv_dict
5
+ from PIL import Image
6
+ from segmentation import segmentation
7
+ from seg_file import userdict
8
+ from tqdm import tqdm
9
+
10
+ import copy
11
+ import faiss
12
+ import json
13
+ import math
14
+ import numpy as np
15
+ import os
16
+ import pandas as pd
17
+ import re
18
+ from sentence_transformers import SentenceTransformer, util
19
+ import streamlit as st
20
+ from streamlit_option_menu import option_menu
21
+ from streamlit_chat import message
22
+ import sys
23
+ import time
24
+ import unicodedata as uni
25
+
26
# Locate the "data" directory that sits next to this module; every bundled
# asset (icon, Excel sheets, FAISS indexes) is loaded relative to it.
module_dir = os.path.dirname(__file__)
data_dir = os.path.join(module_dir, "data")


# The image is opened once and passed to Streamlit as the page icon.
im = Image.open(os.path.join(data_dir, "MetaEdge.png"))
st.set_page_config(
    page_title="ChatBot Prototype testing",
    page_icon=im,
    layout="wide",
)
36
+
37
@st.experimental_singleton
def prepare_model1():
    """Load the multilingual paraphrase sentence-embedding model once.

    The function is wrapped only in ``st.experimental_singleton``. The
    previously stacked ``@st.cache`` was redundant: the singleton already
    guarantees a single load, and the inner cache only added output hashing
    of the model object on that one call.

    Returns:
        SentenceTransformer: the "paraphrase-multilingual-mpnet-base-v2" model.
    """
    return SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
41
+ # model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
42
@st.experimental_singleton
def prepare_model2():
    """Load the English all-mpnet sentence-embedding model once.

    The function is wrapped only in ``st.experimental_singleton``. The
    previously stacked ``@st.cache(...)`` was redundant because the singleton
    already guarantees the model is loaded a single time.

    Returns:
        SentenceTransformer: the "sentence-transformers/all-mpnet-base-v2" model.
    """
    return SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
46
# NOTE(review): these statements appear to sit at module scope, where
# `global` has no effect; they are kept as written.
global model1
global model2
# Both embedding models are loaded up front through the cached loaders.
model1 = prepare_model1()
model2 = prepare_model2()
# model2 = model1

# Shared word segmenter used by PreProcess.all_preprocess.
seg = segmentation()
# Diagnostics filled by PreProcess.word_to_leaf:
#   noidf       - leaves that have no entry in leaf_idf_dict
#   no_userdict - words that have no entry in userdict
noidf = []
no_userdict = []
55
class PreProcess():
    """Text clean-up pipeline that turns a sentence into a list of leaves.

    Steps provided: bracket removal, word segmentation, word-to-leaf
    mapping and leaf conversion. (Translated from the original note: during
    testing a whole Excel folder is fed in for matching, but the core idea
    is that one external regulation arrives through an API call and the
    most related internal regulation is looked up.)
    """

    def __init__(self):
        pass

    @staticmethod
    def _ngram(lst, n):
        """Return every window of ``n`` consecutive items of ``lst``.

        ``n`` is clamped to ``len(lst)``, so a list shorter than ``n`` yields
        one window holding the whole list; an empty list yields ``[]``.
        Single-space tokens are dropped from each window. The earlier
        ``[ele.remove(" ") for ele in nLst]`` produced ``[None, ...]`` when
        every window contained a space and partially mutated windows
        otherwise.
        """
        if len(lst) < n:
            n = len(lst)
        if n <= 0:
            return []
        windows = [lst[i:i + n] for i in range(len(lst) - n + 1)]
        return [[token for token in window if token != " "] for window in windows]

    def remove_parenthesis(self, text):
        """NFKC-normalise ``text`` and delete each separator+bracket pair.

        A match is one of ``─``, ``│`` or a space immediately followed by one
        of ``( ) { } [ ]``. The pattern is unchanged; it is now a raw string
        so ``\\[`` is not treated as an invalid string escape.
        """
        normalized = uni.normalize('NFKC', str(text))
        return ''.join(re.split(r'[─│ ][(){}\[\]]', normalized))

    def word_to_leaf(self, text):
        """Map each space-separated word of ``text`` to its leaf.

        Words missing from ``userdict`` are recorded in ``no_userdict``;
        leaves missing from ``leaf_idf_dict`` are recorded in ``noidf``.
        Only leaves that have an IDF entry are returned.
        """
        leaf_result = []
        for word in text.split(" "):
            # Only the lookup is guarded, so unrelated errors are no longer
            # silently swallowed by a bare ``except``.
            try:
                tmpLeaf = userdict[word]
            except KeyError:
                no_userdict.append(word)
                continue
            if tmpLeaf in leaf_idf_dict:
                leaf_result.append(tmpLeaf)
            else:
                noidf.append(tmpLeaf)

        return leaf_result

    def leaf_conversion(self, segLeaves_lst):
        """Rewrite leaf n-grams through ``leafConv_dict`` and filter by IDF.

        N-grams of size 4 down to 1 are looked up in ``leafConv_dict`` and
        replaced inside the space-joined leaf string. Afterwards leaves whose
        ``leaf_idf_dict`` value is below 10, or equal to 19 or 20, are
        dropped.
        """
        segLeaves_str = " ".join(segLeaves_lst)

        for size in range(4, 0, -1):
            for gram in self._ngram(segLeaves_lst, size):
                before_leaf = " ".join(gram)
                if before_leaf not in leafConv_dict:
                    continue
                converted = str(leafConv_dict[before_leaf])

                if (" " + before_leaf + " ") in segLeaves_str:
                    # Match in the middle of the string.
                    segLeaves_str = segLeaves_str.replace(
                        " " + before_leaf + " ", " " + converted + " ")
                elif (before_leaf + " ") in segLeaves_str:
                    # Match at the very start of the string.
                    if segLeaves_str.index(before_leaf) == 0:
                        segLeaves_str = segLeaves_str.replace(
                            before_leaf + " ", converted + " ")
                elif (" " + before_leaf) in segLeaves_str:
                    # Match at the very end of the string.
                    if segLeaves_str.index(before_leaf) == (len(segLeaves_str) - len(before_leaf)):
                        segLeaves_str = segLeaves_str.replace(
                            " " + before_leaf, " " + converted)

        # NOTE(review): the source lost its indentation; this clean-up is
        # assumed to run once after all replacements - confirm.
        # str(NaN) conversions show up as the text "nan" and are removed.
        segLeaves_str = segLeaves_str.replace("nan", "")
        # Collapse runs of spaces. The rendered source showed a single space
        # on both sides, which would never terminate; double-space to
        # single-space is the reconstructed intent.
        while "  " in segLeaves_str:
            segLeaves_str = segLeaves_str.replace("  ", " ")

        segLeaves_lst = []
        for leaf in segLeaves_str.split(" "):
            if leaf == '':
                continue
            try:
                idf = leaf_idf_dict[leaf]
                rejected = idf < 10 or idf == 19 or idf == 20
            except (KeyError, TypeError) as e:
                # Unknown leaves are reported and skipped, as before.
                print(e)
                continue
            if not rejected:
                segLeaves_lst.append(leaf)

        return segLeaves_lst

    def all_preprocess(self, text):
        """Run bracket removal, segmentation and word-to-leaf mapping.

        ``leaf_conversion`` is intentionally not part of this pipeline (it
        was commented out in the original).
        """
        text = self.remove_parenthesis(text)
        text = seg.seg_one(text)
        text = self.word_to_leaf(text)

        return text
188
+
189
class PairingRule():
    """Rule set that labels a candidate/query leaf pair on top of cosine.

    ``outlaw`` and ``inlaw`` are lists of leaf strings. Two tables are used:
    ``leaf_IDF_dict`` (values used for ranking leaves) and ``leaf_idf_dict``
    (values compared against the thresholds 10, 19 and 20).

    NOTE(review): the source this class was recovered from had lost its
    indentation. Nesting that could not be determined with certainty is
    marked with NOTE(review) below.
    """

    def __init__(self, leaf_IDF_dict, lower, middle, upper):
        """Store the IDF tables and the three cosine thresholds."""
        self.leaf_IDF_dict = leaf_IDF_dict
        # Taken from the module-level ``leaf_idf_dict``, not from a parameter.
        self.leaf_idf_dict = leaf_idf_dict

        self.lower_thres = lower
        self.middle_thres = middle
        self.upper_thres = upper

    @staticmethod
    def _ngram(lst, n):
        """Return every window of ``n`` consecutive items of ``lst``.

        ``n`` is clamped to ``len(lst)``; an empty list yields ``[]``.
        Single-space tokens are dropped from each window (the earlier
        ``[ele.remove(" ") ...]`` produced ``None`` entries instead).
        """
        if len(lst) < n:
            n = len(lst)
        if n <= 0:
            return []
        windows = [lst[i:i + n] for i in range(len(lst) - n + 1)]
        return [[token for token in window if token != " "] for window in windows]

    @staticmethod
    def _shared(top_leaves, other):
        """Count how many of ``top_leaves`` also occur in ``other``."""
        return len(set(top_leaves).intersection(other))

    def top(self, law_leafIDF, rank):
        """Return the set of leaf names in the first ``rank`` ranked pairs."""
        return {leaf[0] for leaf in law_leafIDF[:rank]}

    def sortingIDF_leaf(self, law):
        """Return ``(leaf, IDF)`` pairs of ``law`` sorted by IDF, descending.

        Side effect (kept for callers): leaves wrapped in ``()`` or ``{}``
        are unwrapped in place inside ``law``.

        Only leaves present in both IDF tables are returned. Each leaf is
        paired with its own IDF; previously the unfiltered leaf list was
        zipped with the filtered IDF list, which shifted IDFs onto the wrong
        leaves whenever a leaf was missing from a table.
        """
        for i, leaf in enumerate(law):
            if leaf.startswith(('(', '{')) and leaf.endswith((')', '}')):
                law[i] = leaf[1:-1]

        law_leafIDF = {
            leaf: self.leaf_IDF_dict[leaf]
            for leaf in law
            if (leaf in self.leaf_IDF_dict) and (leaf in self.leaf_idf_dict)
        }

        return sorted(law_leafIDF.items(), key=lambda item: item[1], reverse=True)

    def IDFCA(self, outlaw, inlaw):
        """Compare the top-IDF leaves of both sides.

        Returns "P1".."P5" or "P6+" (the number follows ``len(outlaw)``)
        when the overlap rule for the given pair of lengths holds,
        otherwise "N".
        """
        if not outlaw: return "N"
        if not inlaw: return "N"

        # Leaf lists ranked by IDF for both sides.
        outlaw_leafIDF = self.sortingIDF_leaf(outlaw)
        inlaw_leafIDF = self.sortingIDF_leaf(inlaw)

        if not outlaw_leafIDF: return "N"
        if not inlaw_leafIDF: return "N"

        out_first = outlaw_leafIDF[0][0]
        in_first = inlaw_leafIDF[0][0]
        # Sets of the 2, 3 and 4 highest-IDF leaves on each side.
        out_top = {k: self.top(outlaw_leafIDF, k) for k in (2, 3, 4)}
        in_top = {k: self.top(inlaw_leafIDF, k) for k in (2, 3, 4)}
        shared = self._shared

        n_out = len(outlaw)
        n_in = len(inlaw)
        matched = False

        if n_out == 1:
            label = "P1"
            if n_in in {1, 2}:
                matched = out_first == in_first
            elif n_in in {3, 4, 5, 6}:
                matched = out_first in in_top[2]

        elif n_out == 2:
            label = "P2"
            if n_in == 1:
                matched = out_first == in_first
            elif n_in == 2:
                matched = (out_first in in_top[2]) and (in_first in out_top[2])
            elif n_in >= 3:
                matched = (shared(out_top[2], inlaw) >= 2) and (shared(in_top[3], outlaw) >= 2)

        elif n_out == 3:
            label = "P3"
            if n_in == 1:
                matched = in_first in out_top[2]
            elif n_in == 2:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[2], outlaw) >= 2)
            elif 3 <= n_in <= 5:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[3], outlaw) >= 2)
            elif n_in >= 6:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[4], outlaw) >= 2)

        elif n_out in (4, 5):
            # Lengths 4 and 5 share identical rules; only the label differs.
            label = f"P{n_out}"
            if n_in == 1:
                matched = in_first in out_top[2]
            elif n_in == 2:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[2], outlaw) == 2)
            elif 3 <= n_in <= 5:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[3], outlaw) >= 2)
            elif n_in >= 6:
                matched = (shared(out_top[4], inlaw) >= 2) and (shared(in_top[4], outlaw) >= 2)

        else:  # n_out >= 6
            label = "P6+"
            if n_in == 1:
                matched = in_first in out_top[2]
            elif n_in == 2:
                matched = (shared(out_top[3], inlaw) >= 2) and (shared(in_top[2], outlaw) == 2)
            elif n_in == 3:
                matched = (shared(out_top[4], inlaw) >= 2) and (shared(in_top[3], outlaw) >= 2)
            elif n_in >= 4:
                matched = (shared(out_top[4], inlaw) >= 2) and (shared(in_top[4], outlaw) >= 2)

        return label if matched else "N"

    def COLA(self, outlaw, inlaw):
        """Check consecutive-leaf and subset relations between both sides.

        The rules run in the order P4, P3, P2, P1 and a later match
        overwrites an earlier one. Returns the final label or "N1".
        """
        PN = ""

        outlaw_leafIDF = self.sortingIDF_leaf(outlaw)
        inlaw_leafIDF = self.sortingIDF_leaf(inlaw)

        outlaw_top3IDF = self.top(outlaw_leafIDF, 3)
        inlaw_top3IDF = self.top(inlaw_leafIDF, 3)

        outlaw_3gram = self._ngram(outlaw, 3)
        inlaw_3gram = self._ngram(inlaw, 3)

        outlaw_4gram = self._ngram(outlaw, 4)
        inlaw_4gram = self._ngram(inlaw, 4)

        # (P4) four consecutive outlaw leaves also appear consecutively and
        # in the same order in inlaw.
        if any(out_grams in inlaw_4gram for out_grams in outlaw_4gram):
            PN = "P4"

        # (P3) inlaw leaves are a subset of outlaw leaves and at least one of
        # them is among outlaw's top-3 IDF leaves with idf > 10.
        if set(inlaw).issubset(outlaw):
            check = set(inlaw).intersection(outlaw_top3IDF)
            if any(self.leaf_idf_dict[leaf] > 10 for leaf in check):
                PN = "P3"

        # (P2) every outlaw leaf appears in inlaw.
        if set(outlaw).issubset(inlaw):
            PN = "P2"

        # (P1) three consecutive outlaw leaves are contained in three
        # consecutive inlaw leaves (any order), and each 3-gram holds at
        # least one top-3 IDF leaf of its own side.
        # NOTE(review): as in the original, when an earlier rule already set
        # PN only the first outlaw 3-gram is examined before the loop stops.
        for out_grams in outlaw_3gram:
            for in_grams in inlaw_3gram:
                if set(out_grams).issubset(in_grams):
                    if (len(set(out_grams).intersection(outlaw_top3IDF)) >= 1) and (len(set(in_grams).intersection(inlaw_top3IDF)) >= 1):
                        PN = "P1"
                        break
            if PN != "":
                break

        return PN if PN != "" else "N1"

    def COSR(self, outlaw, inlaw, cos):
        """Adjust or label the cosine by shared/unshared leaf counts.

        X is the number of shared leaves with idf >= 10. Y is the larger of
        the two counts of unshared leaves with idf > 10. Depending on (X, Y)
        the result is a "P*"/"N*"/"Z" label or a rescaled cosine
        (``cos``, ``cos ** 2`` or ``sqrt(cos)``, rounded to 6 places).
        The value is always returned as a string; "N" when no rule applies.
        """
        PN = ""

        outlaw_leafIDF = self.sortingIDF_leaf(outlaw)
        inlaw_leafIDF = self.sortingIDF_leaf(inlaw)

        law_intersection = set(outlaw).intersection(inlaw)
        X = len([leaf for leaf in law_intersection if self.leaf_idf_dict[leaf] >= 10])

        out_difference = set(outlaw).difference(inlaw)
        in_difference = set(inlaw).difference(outlaw)

        out_idf10 = len([leaf for leaf in out_difference if self.leaf_idf_dict[leaf] > 10])
        in_idf10 = len([leaf for leaf in in_difference if self.leaf_idf_dict[leaf] > 10])

        Y = max(out_idf10, in_idf10)

        def boosted():
            # A negative cosine is clamped to 0 so math.sqrt cannot raise.
            return round(math.sqrt(max(cos, 0.0)), 6)

        if Y == 1:
            if X >= 5:
                PN = "P1"
            elif X == 1:
                PN = round(cos, 6)
            elif 2 <= X <= 4:
                PN = boosted()

        elif Y == 2:
            if X >= 5:
                PN = "P2"
            elif X == 1:
                PN = round(cos ** 2, 6)
            elif X == 2:
                PN = round(cos, 6)
            elif 3 <= X <= 4:
                PN = boosted()

        elif Y == 3:
            if X >= 5:
                PN = "P3"
            elif 1 <= X <= 2:
                PN = round(cos ** 2, 6)
            elif X == 3:
                PN = round(cos, 6)
            elif X == 4:
                # Was ``PN == ...`` (a comparison), so this case fell
                # through to "N".
                PN = boosted()

        elif Y == 4:
            if X >= 5:
                PN = "P4"
            elif 1 <= X <= 3:
                PN = round(cos ** 2, 6)
            elif X == 4:
                PN = round(cos, 6)

        elif Y >= 5:
            if X >= 5:
                # Guarded so an empty ranking cannot raise IndexError.
                if outlaw_leafIDF and outlaw_leafIDF[0][0] in law_intersection:
                    PN = "P5"
                elif inlaw_leafIDF and inlaw_leafIDF[0][0] in law_intersection:
                    PN = "P5"
                else:
                    PN = "Z"
            elif X == 1:
                PN = "N1"
            elif X == 2:
                PN = "N2"
            elif X == 3:
                PN = "N3"
            elif X == 4:
                PN = "N4"

        if PN != "":
            return str(PN)
        return "N"

    def TOPN(self, outlaw, inlaw):
        """Compare leading leaves after dropping leaves with idf < 10.

        Returns "P1", "P2", "P3" or "N". The filter now reads
        ``self.leaf_idf_dict`` like every other method instead of the module
        global.
        """
        PN = "N"
        outlaw = [ele for ele in outlaw if self.leaf_idf_dict[ele] >= 10]
        inlaw = [ele for ele in inlaw if self.leaf_idf_dict[ele] >= 10]

        if not outlaw: return "N"
        if not inlaw: return "N"

        outlaw_leafIDF = self.sortingIDF_leaf(outlaw)
        inlaw_leafIDF = self.sortingIDF_leaf(inlaw)

        if not outlaw_leafIDF: return "N"
        if not inlaw_leafIDF: return "N"

        outlaw_top3IDF = self.top(outlaw_leafIDF, 3)
        inlaw_top3IDF = self.top(inlaw_leafIDF, 3)

        # NOTE(review): both checks are assumed to sit under the length
        # guard; the source indentation was lost - confirm.
        if (len(outlaw) >= 3) and (len(inlaw) >= 3):
            if (set(outlaw_top3IDF).issubset(inlaw)) and (set(inlaw_top3IDF).issubset(outlaw)):
                PN = "P1"
            if (set(outlaw[:3]).issubset(inlaw)) and (set(inlaw[:3]).issubset(outlaw)):
                PN = "P2"

        # NOTE(review): the two "swapped top-2" checks below are assumed to
        # be independent of the short-outlaw/long-inlaw check - confirm.
        check = 0
        if (len(outlaw) < 3) and (len(inlaw) > 3):
            if outlaw_leafIDF[0][0] == inlaw_leafIDF[0][0]:
                PN = "P3"

        if len(inlaw_leafIDF) > 1:
            if outlaw_leafIDF[0][0] == inlaw_leafIDF[1][0]:
                if self.leaf_IDF_dict[inlaw_leafIDF[1][0]] > 2.5:
                    check += 1
        if len(outlaw_leafIDF) > 1:
            if inlaw_leafIDF[0][0] == outlaw_leafIDF[1][0]:
                if self.leaf_IDF_dict[outlaw_leafIDF[1][0]] > 2.5:
                    check += 1

        # Each side's best leaf is the other side's second-best leaf.
        if check == 2:
            PN = "P3"

        if PN[0] == "P":
            return PN
        return "N"

    def scoring(self, cos, IDFCA, COLA, COSR, TOPN):
        """Combine cosine and the four rule results into a final verdict.

        Args:
            cos: cosine similarity of the pair.
            IDFCA, COLA, COSR, TOPN: string results of the matching methods.

        Returns:
            tuple: ``(PN, score)`` where PN is "P" or "N" and score is an
            int in the middle cosine band or "N/A" otherwise.
        """
        PN = ""
        score = 0

        if cos < self.lower_thres:
            PN = "N"
            score = "N/A"

        elif self.lower_thres <= cos < self.upper_thres:

            if IDFCA[0] == "P":
                score += 1
            else:
                score -= 1

            if COLA[0] == "P":
                score += 2
            else:
                score += -1

            # NOTE(review): the ``else`` is assumed to belong to the outer
            # chain (labels N/Z/E cost one point) - confirm.
            if COSR[0] == "P":
                score += 1
            elif (COSR[0] not in {"P", "N", "Z", "E"}):
                # COSR is a numeric string here.
                if (float(COSR) > cos):
                    score += 1
            else:
                score -= 1

            if TOPN[0] == "P":
                score += 2

            if self.lower_thres <= cos < self.middle_thres:
                if score >= 4:
                    PN = "P"
            elif self.middle_thres <= cos < self.upper_thres:
                if score >= 0:
                    PN = "P"

        elif cos >= self.upper_thres:

            # High cosine is positive unless COSR vetoes it.
            PN = "P"
            if COSR[0] not in {"P", "N", "Z", "E"}:
                if (float(COSR) < 0.9):
                    PN = "N"
            elif COSR[0] in {"N"}:
                PN = "N"

            score = "N/A"

        if PN == "":
            PN = "N"
            score = "N/A"

        return PN, score

    def level(self, outlaw, inlaw, cos, PN, score):
        """Map a scored pair to a level label "L1".."L8" or "no_level"."""
        level = ""

        if (cos > 0.98):
            if (len(outlaw) >= 2) or ((len(outlaw) == 1) and (self.leaf_idf_dict[outlaw[0]] > 10)):
                level = "L1"
            # NOTE(review): unreachable - this condition is already covered
            # by the branch above, so "L5" is never produced. Kept unchanged
            # because the intended condition is unknown.
            elif (len(outlaw) == 1) and (self.leaf_idf_dict[outlaw[0]] > 10):
                level = "L5"

        elif PN[0] == "P":
            if (len(outlaw) > 1) and (len(inlaw) > 1):
                level = "L2"
            elif (len(outlaw) == 1) or (len(inlaw) == 1):
                level = "L3"

        elif PN[0] == "N":
            if cos >= 0.8:
                level = "L6"
            elif 0.7 <= cos < 0.8:
                level = "L7"
            elif cos < 0.7:
                level = "L8"

        if level != "":
            return level
        return "no_level"

    def display_leaves(self, law_leaf):
        """Render leaves as one string, marking each by its idf value.

        idf == 10 gives ``(leaf)``, idf > 10 gives ``[leaf]``. Leaves already
        wrapped in brackets, leaves with idf 19 or 20, and leaves with
        idf < 10 are left as they are. The input list is not modified.
        """
        leaves = copy.deepcopy(law_leaf)
        for i, leaf in enumerate(leaves):
            if leaf[0] in {'(', '{'} and leaf[-1] in {')', '}'}:
                continue
            idf = self.leaf_idf_dict[leaf]
            if idf in {19.0, 20.0}:
                continue
            if idf == 10.0:
                leaves[i] = f"({leaf})"
            elif idf > 10.0:
                leaves[i] = "[" + str(leaf) + "]"

        return " ".join(leaves)
658
+
659
class FaissProcess():
    """Thin wrapper around a FAISS inner-product index.

    (Translated from the original note: by default external regulations are
    matched against internal ones, so the internal regulations are what gets
    built into the FAISS index.)
    """

    def __init__(self):
        # Number of indexes built through get_faissIndex.
        self.ct = 0

    def get_faissIndex(self, searched_vecs):
        """Build an ``IndexFlatIP`` over L2-normalised copies of the vectors.

        Args:
            searched_vecs: 2-D array-like of sentence vectors.

        Returns:
            The populated FAISS index. With normalised vectors the
            inner-product scores it returns are cosine similarities.
        """
        # FAISS works on float32, C-contiguous arrays; copying also keeps the
        # in-place normalisation away from the caller's data.
        searched_sentenceVec = np.array(searched_vecs, dtype="float32", order="C")
        faiss.normalize_L2(searched_sentenceVec)

        # Use the real vector width; 768 stays as the fallback default.
        dimension = searched_sentenceVec.shape[1] if searched_sentenceVec.ndim == 2 else 768
        faissIndex = faiss.IndexFlatIP(dimension)
        faissIndex.add(searched_sentenceVec)
        self.ct += 1

        return faissIndex

    def get_faissResult(self, searching_vec, faissIndex, diffQ, k=10):
        """Search ``faissIndex`` for the ``k`` nearest neighbours.

        Args:
            searching_vec: 2-D array-like of query vectors.
            faissIndex: index produced by ``get_faissIndex`` (or loaded).
            diffQ: unused; kept so existing callers keep working.
            k: number of neighbours to return (default 10, as before).

        Returns:
            tuple: ``(D, I)`` - score and index arrays from ``search``.
        """
        searching_vec = np.array(searching_vec, dtype="float32", order="C")
        faiss.normalize_L2(searching_vec)

        D, I = faissIndex.search(searching_vec, k)

        return D, I
702
+
703
+
704
@st.cache(suppress_st_warning=True)
def prepare_diffQ_content():
    """Load the bank Q&A sheet and the ordered variant-question list.

    Returns:
        tuple: ``(diff_brief_dict, diffQ, qa_dict)`` where
        ``diff_brief_dict`` maps a variant question to its brief question,
        ``diffQ`` is the NFKC-normalised line list of ``order_diffQ.txt``
        and ``qa_dict`` maps a brief question to its answer (newlines
        removed).
    """
    bank_df = pd.read_excel(os.path.join(data_dir, "PQA_bank_20220801.xlsx"))

    def normalized_column(column):
        # Every cell is stringified before NFKC normalisation.
        return [uni.normalize("NFKC", str(cell)) for cell in bank_df[column]]

    variant_questions = normalized_column("變化問題")
    brief_questions = normalized_column("問題簡述")
    answers = [answer.replace("\n", "") for answer in normalized_column("回答")]

    diff_brief_dict = dict(zip(variant_questions, brief_questions))
    qa_dict = dict(zip(brief_questions, answers))

    with open(os.path.join(data_dir, "order_diffQ.txt"), mode="r", encoding="utf-8") as handle:
        diffQ = [uni.normalize("NFKC", line.strip()) for line in handle]

    return diff_brief_dict, diffQ, qa_dict
731
+
732
@st.cache(suppress_st_warning=True)
def prepare_diffQ_leaf(diffQ):
    """Convert every variant question into its leaf list.

    Args:
        diffQ: sequence of question strings.

    Returns:
        list: one ``PreProc.all_preprocess`` result per question, in order.
    """
    return [
        PreProc.all_preprocess(question)
        for question in tqdm(diffQ, total=len(diffQ))
    ]
753
+
754
@st.cache(suppress_st_warning=True)
def prepare_diffQ_faissIndex(mode):
    """Read the prebuilt bank-QA FAISS index for the chosen model.

    Args:
        mode: ``1`` selects the first index file; any other value selects
            the second one.
    """
    index_file = "faissIndex1_20220805.index" if mode == 1 else "faissIndex2_20220805.index"
    return faiss.read_index(os.path.join(data_dir, index_file))
760
+
761
+
762
+
763
@st.cache(suppress_st_warning=True)
def prepare_loan_diffQ_content():
    """Load the loan FAQ sheet.

    Both the question list and the brief-question list come from the same
    column, so the first mapping sends every question to itself.

    Returns:
        tuple: ``(loan_diff_brief_dict, loan_diffQ_lst, loan_qa_dict)``.
    """
    loan_df = pd.read_excel(os.path.join(data_dir, "信貸FAQ總表.xlsx"))

    def normalized_column(column):
        # Every cell is stringified before NFKC normalisation.
        return [uni.normalize("NFKC", str(cell)) for cell in loan_df[column]]

    loan_diffQ_lst = normalized_column("標準問題")
    loan_briefQ_lst = normalized_column("標準問題")
    loan_answer_lst = [answer.replace("\n", "") for answer in normalized_column("回答")]

    loan_diff_brief_dict = dict(zip(loan_diffQ_lst, loan_briefQ_lst))
    loan_qa_dict = dict(zip(loan_briefQ_lst, loan_answer_lst))

    return loan_diff_brief_dict, loan_diffQ_lst, loan_qa_dict
790
+
791
@st.cache(suppress_st_warning=True)
def prepare_loan_diffQ_leaf(loan_diffQ):
    """Convert every loan question into its leaf list.

    Args:
        loan_diffQ: sequence of question strings.

    Returns:
        list: one ``PreProc.all_preprocess`` result per question, in order.
    """
    return [
        PreProc.all_preprocess(question)
        for question in tqdm(loan_diffQ, total=len(loan_diffQ))
    ]
802
+
803
@st.cache(suppress_st_warning=True)
def prepare_loan_diffQ_faissIndex(mode):
    """Read the prebuilt loan-QA FAISS index for the chosen model.

    Args:
        mode: ``1`` selects the first index file; any other value selects
            the second one.
    """
    index_file = "chatbot_loan_faissIndex1.index" if mode == 1 else "chatbot_loan_faissIndex2.index"
    return faiss.read_index(os.path.join(data_dir, index_file))
809
+
810
+
811
+
812
def pairing_search(userQ):
    """Find stored questions similar to ``userQ`` and score each match.

    Reads the module-level selections ``model_choose`` and ``data_choose``
    and the prepared indexes, question lists and dictionaries.

    Args:
        userQ: the user's question text.

    Returns:
        tuple: ``(bot_answer, result_df)`` - the answers of every hit judged
        "P", and a string-typed DataFrame with one row per inspected hit.
    """

    userQ_leaf = PreProc.all_preprocess(userQ)

    # NOTE(review): if model_choose matches neither name, `embedding` stays
    # unbound and the next statement raises.
    if model_choose == "paraphrase-multilingual-mpnet-base-v2":
        embedding = model1.encode(userQ, convert_to_numpy = True)
    elif model_choose == "all-mpnet-base-v2":
        embedding = model2.encode(userQ, convert_to_numpy = True)

    userQ_vec = np.array(embedding).astype("float32")

    # Pick the index that matches both the chosen model and the data set.
    if model_choose == "paraphrase-multilingual-mpnet-base-v2":
        if data_choose == "銀行QA":
            simDistance, simIndex = FaissProc.get_faissResult([userQ_vec], diffQ_faissIndex_model1, diffQ)
        elif data_choose == "信貸QA":
            simDistance, simIndex = FaissProc.get_faissResult([userQ_vec], loan_diffQ_faissIndex_model1, loan_diffQ)
    elif model_choose == "all-mpnet-base-v2":
        if data_choose == "銀行QA":
            simDistance, simIndex = FaissProc.get_faissResult([userQ_vec], diffQ_faissIndex_model2, diffQ)
        elif data_choose == "信貸QA":
            simDistance, simIndex = FaissProc.get_faissResult([userQ_vec], loan_diffQ_faissIndex_model2, loan_diffQ)

    whole_result = []
    bot_answer = []
    for i, cos in enumerate(simDistance[0]):

        # Only the first five hits are evaluated.
        if i == 5: break

        cos = round(float(cos), 4)


        if data_choose == "銀行QA":
            tmp_diffQ = diffQ[simIndex[0][i]]
            tmp_diffQ_leaf = diffQ_leaf[simIndex[0][i]]
        elif data_choose == "信貸QA":
            tmp_diffQ = loan_diffQ[simIndex[0][i]]
            tmp_diffQ_leaf = loan_diffQ_leaf[simIndex[0][i]]

        # The stored question is passed as `outlaw`, the user question as
        # `inlaw`.
        idfca = rule.IDFCA(tmp_diffQ_leaf, userQ_leaf)
        cola = rule.COLA(tmp_diffQ_leaf, userQ_leaf)
        cosr = rule.COSR(tmp_diffQ_leaf, userQ_leaf, cos)
        topn = rule.TOPN(tmp_diffQ_leaf, userQ_leaf)

        PN, score = rule.scoring(cos, idfca, cola, cosr, topn)
        # A positive verdict contributes the answer of the matched question.
        if PN == "P":
            if data_choose == "銀行QA":
                bot_answer.append(qa_dict[diff_brief_dict[tmp_diffQ]])
            elif data_choose == "信貸QA":
                bot_answer.append(loan_qa_dict[loan_diff_brief_dict[tmp_diffQ]])


        # One result row per hit, in the order of `result_columns` below.
        tmp_result = []
        if data_choose == "銀行QA":
            tmp_result.extend([userQ, rule.display_leaves(userQ_leaf), diffQ[simIndex[0][i]], rule.display_leaves(tmp_diffQ_leaf)])
            tmp_result.append(diff_brief_dict[diffQ[simIndex[0][i]]])
        elif data_choose == "信貸QA":
            tmp_result.extend([userQ, rule.display_leaves(userQ_leaf), loan_diffQ[simIndex[0][i]], rule.display_leaves(tmp_diffQ_leaf)])
            tmp_result.append(loan_diff_brief_dict[loan_diffQ[simIndex[0][i]]])
        tmp_result.extend([cos, (idfca), (cola), (cosr), (topn), (score), (PN)])

        whole_result.append(tmp_result)

    result_columns = ["user問題", "user問題(leaf)", "變化問題", "變化問題(leaf)", "問題簡述", "cosine", "IDFCA", "COLA", "COSR", "TOPN", "score", "P/N"]
    result_df = pd.DataFrame(columns = result_columns, data = whole_result)

    # Every cell is cast to str; a failure is only printed and the
    # uncast frame is returned.
    try:
        result_df = result_df.astype(str)
    except Exception as e:
        print(f"e: {e}")
        print(f"type of result_df: {type(result_df)}")

    return bot_answer, result_df
893
+
894
def pairing_two_sentence(q1_input, q2_input):
    """Score two sentences against each other with cosine and the rules.

    Args:
        q1_input: first sentence (passed to the rules as ``outlaw``).
        q2_input: second sentence (passed to the rules as ``inlaw``).

    Returns:
        tuple: ``(pairing_df, data_df)`` - a frame with both rendered leaf
        strings and a one-row, string-typed frame holding cosine, the four
        rule results, the score and the verdict.
    """
    q1_leaf = PreProc.all_preprocess(q1_input)
    q2_leaf = PreProc.all_preprocess(q2_input)

    # Choose the encoder named by the module-level model selection.
    if model_choose == "paraphrase-multilingual-mpnet-base-v2":
        encoder = model1
    elif model_choose == "all-mpnet-base-v2":
        encoder = model2
    q1_vec = encoder.encode(q1_input, convert_to_tensor = True)
    q2_vec = encoder.encode(q2_input, convert_to_tensor = True)

    similarity = util.cos_sim(q1_vec, q2_vec)
    cosine = round(float(similarity[0][0]), 4)

    # Leaves are rendered before the rules run, as in the original order.
    pairing_df = pd.DataFrame(
        {"Leaves": [rule.display_leaves(q1_leaf), rule.display_leaves(q2_leaf)]},
        index = ["配對語句1-leaf", "配對語句2-leaf"],
    )

    idfca = str(rule.IDFCA(q1_leaf, q2_leaf))
    cola = str(rule.COLA(q1_leaf, q2_leaf))
    cosr = str(rule.COSR(q1_leaf, q2_leaf, cosine))
    topn = str(rule.TOPN(q1_leaf, q2_leaf))

    verdict, points = rule.scoring(cosine, idfca, cola, cosr, topn)

    summary_row = [cosine, idfca, cola, cosr, topn, str(points), str(verdict)]
    data_df = pd.DataFrame(
        data = [summary_row],
        columns = ["Cosine", "IDFCA", "COLA", "COSR", "TOPN", "score", "PN"],
        index = ["數據結果"],
    ).astype(str)

    return pairing_df, data_df
938
+
939
def clear_input():
    """Reset the ``"text"`` entry of the Streamlit session state to ""."""

    st.session_state["text"] = ""
942
+
943
def isChinese(word):
    """Return True when any character of ``word`` lies in U+4E00..U+9FFF."""
    return any('\u4e00' <= ch <= '\u9fff' for ch in word)
948
+
949
def format_chat(mode, sentence):
    """Split ``sentence`` into chat lines and pad the opposite column.

    A character in U+4E00..U+9FFF counts as width 1, any other as 0.5; a
    line is cut as soon as its width reaches 22. The pieces go to
    ``st.session_state.user_message`` when ``mode == "user"``, otherwise to
    ``st.session_state.bot_message``. The other list receives one ``"|"``
    per piece so both columns stay the same length.
    """
    speaker_attr = "user_message" if mode == "user" else "bot_message"
    filler_attr = "bot_message" if mode == "user" else "user_message"

    pieces = 0
    line_start = 0
    line_width = 0
    for pos, char in enumerate(sentence):
        line_width += 1 if isChinese(char) else 0.5

        if line_width >= 22:
            getattr(st.session_state, speaker_attr).append(sentence[line_start:pos + 1])
            pieces += 1
            line_start = pos + 1
            line_width = 0

    # Whatever is left after the last full line becomes a final piece.
    remainder = sentence[line_start:]
    if remainder != "":
        getattr(st.session_state, speaker_attr).append(remainder)
        pieces += 1

    for _ in range(pieces):
        getattr(st.session_state, filler_attr).append("|")
986
+
987
def _reset_multi_chat():
    """Put the multi-round dialogue state back to "no dialogue running"."""
    st.session_state.multi_question_num = -1
    st.session_state.multi_answer = dict()
    st.session_state.multi_mode = ""


def _ask_question(template, q_num):
    """Send question ``Q<q_num>`` as bot lines, one per ``<br>`` segment."""
    for part in template[f"Q{q_num}"]["question"].split("<br>"):
        format_chat("bot", part)


def _reject_answer(template, q_num):
    """Send the warning of ``Q<q_num>`` and ask the same question again."""
    format_chat("bot", template[f"Q{q_num}"]["ans_warning"])
    _ask_question(template, q_num)


def _accept_answer(template, q_num, value, step=1):
    """Store *value* for ``Q<q_num>``, advance *step* questions, ask the next."""
    st.session_state.multi_answer[f"Q{q_num}"] = value
    st.session_state.multi_question_num += step
    _ask_question(template, q_num + step)


def _split_options(raw):
    """Split a multi-choice answer on ASCII or full-width commas, trimmed."""
    return [part.strip() for part in raw.replace(",", ",").split(",")]


def _leave_dialogue(template, pre_answer, q_num):
    """Advance the "請假" (leave request) dialogue by one step."""
    answers = st.session_state.multi_answer

    if q_num == 0:
        format_chat("bot", template["opening"])
        _ask_question(template, 1)
        st.session_state.multi_question_num += 1

    elif q_num == 1:
        if pre_answer in template["Q1"]["ans_choice"]:
            _accept_answer(template, q_num, pre_answer)
        else:
            _reject_answer(template, q_num)

    elif q_num in (2, 3):
        time_format = template[f"Q{q_num}"]["ans_format"]
        # Only the parse is guarded: a malformed date is the user's mistake,
        # anything else (e.g. a missing template key) must not be reported
        # to the user as a "wrong format" warning.
        try:
            parsed = datetime.strptime(pre_answer, time_format)
        except ValueError:
            _reject_answer(template, q_num)
        else:
            _accept_answer(template, q_num, parsed)

    elif q_num == 4:
        _accept_answer(template, q_num, pre_answer)

    elif q_num == 5:
        answers["Q5"] = pre_answer

        format_chat("bot", template["done"])
        format_chat("bot", "請確認請假資訊")

        format_chat("bot", f"請假類別:\t{answers['Q1']}")
        format_chat("bot", f"請假開始時間:\t{answers['Q2']}")
        format_chat("bot", f"請假結束時間:\t{answers['Q3']}")
        format_chat("bot", f"職務代理人:\t{answers['Q4']}")
        format_chat("bot", f"請假事由:\t{answers['Q5']}")

        _reset_multi_chat()


def _fund_source_dialogue(template, pre_answer, q_num):
    """Advance the "客戶資金及資產來源" (source of funds) dialogue by one step."""
    answers = st.session_state.multi_answer

    if q_num == 0:
        format_chat("bot", template["opening"])
        _ask_question(template, 1)
        st.session_state.multi_question_num += 1

    elif q_num in (1, 5):
        # Single choice; the stored answer is the text the choice maps to.
        choices = template[f"Q{q_num}"]["ans_choice"]
        if pre_answer in choices:
            _accept_answer(template, q_num, f"{choices[pre_answer]}")
        else:
            _reject_answer(template, q_num)

    elif q_num == 2:
        if pre_answer not in template["Q2"]["ans_choice"]:
            _reject_answer(template, q_num)
        elif pre_answer == "有":
            _accept_answer(template, q_num, pre_answer)
        else:
            # No other source of funds: the form ends here.
            answers["Q2"] = pre_answer

            format_chat("bot", template["done"])
            format_chat("bot", "請確認表單資訊")
            format_chat("bot", f"年收入區間:\t{answers['Q1']}")
            format_chat("bot", "無其他資金來源")

            _reset_multi_chat()

    elif q_num == 3:
        choices = template["Q3"]["ans_choice"]
        picked = _split_options(pre_answer)
        if all(option in choices for option in picked):
            joined = "、".join(choices[option] for option in picked)
            # Q4 is only asked when option "G" was picked; otherwise it is
            # skipped and the dialogue continues with Q5.
            _accept_answer(template, q_num, joined,
                           step=1 if "G" in picked else 2)
        else:
            _reject_answer(template, q_num)

    elif q_num == 4:
        _accept_answer(template, q_num, pre_answer)

    elif q_num == 6:
        answers["Q6"] = pre_answer

        format_chat("bot", template["done"])
        format_chat("bot", "請確認表單資訊")
        format_chat("bot", f"年收入區間:\t{answers['Q1']}")

        # Q4 exists only when the user went through the optional question.
        if "Q4" in answers:
            format_chat("bot", f"其他資金來源:\t{answers['Q3']}({answers['Q4']})")
        else:
            format_chat("bot", f"其他資金來源:\t{answers['Q3']}")
        format_chat("bot", f"其他資金來源總金額:\t{answers['Q5']}")
        format_chat("bot", f"其他資金來源相關資訊:\t{answers['Q6']}")

        _reset_multi_chat()


def multi_chat(template, pre_answer):
    """Run one step of the active multi-round (form filling) dialogue.

    The dialogue state lives in ``st.session_state``:
    ``multi_mode`` selects the dialogue, ``multi_question_num`` is the
    question the user has just answered (0 means the dialogue has only been
    triggered), and ``multi_answer`` collects accepted answers as ``Q<n>``.
    Bot output is produced through ``format_chat("bot", ...)``.

    Args:
        template: Dialogue definition for the active mode.  The code reads
            the keys ``opening``, ``closing``, ``done`` and ``Q<n>`` entries
            holding ``question`` plus, depending on the question,
            ``ans_choice``, ``ans_warning`` or ``ans_format``.
        pre_answer: The text the user just submitted.  ``"退出"`` aborts the
            dialogue, sends ``template["closing"]`` and resets the state.

    An invalid answer sends the question's warning and repeats the question
    without advancing.  Completing the last question sends a summary and
    resets the state.  A ``multi_mode`` other than the two known dialogues
    is ignored.
    """
    if pre_answer == "退出":
        _reset_multi_chat()
        format_chat("bot", template["closing"])
        return

    q_num = st.session_state.multi_question_num
    mode = st.session_state.multi_mode

    if mode == "請假":
        _leave_dialogue(template, pre_answer, q_num)
    elif mode == "客戶資金及資產來源":
        _fund_source_dialogue(template, pre_answer, q_num)
1170
+
1171
+
1172
+
1173
if __name__ == "__main__":
    # Everything assigned here is a module-level name, so the helper
    # functions of this file can read it as a global (``rule`` is used that
    # way by the pairing helpers).  No ``global`` statement is needed at
    # module scope.

    PreProc = PreProcess()
    FaissProc = FaissProcess()

    # Bank QA resources.
    diff_brief_dict, diffQ, qa_dict = prepare_diffQ_content()
    diffQ_leaf = prepare_diffQ_leaf(diffQ)
    diffQ_faissIndex_model1 = prepare_diffQ_faissIndex(1)
    diffQ_faissIndex_model2 = prepare_diffQ_faissIndex(2)

    # Loan QA resources.
    loan_diff_brief_dict, loan_diffQ, loan_qa_dict = prepare_loan_diffQ_content()
    loan_diffQ_leaf = prepare_loan_diffQ_leaf(loan_diffQ)
    loan_diffQ_faissIndex_model1 = prepare_loan_diffQ_faissIndex(1)
    loan_diffQ_faissIndex_model2 = prepare_loan_diffQ_faissIndex(2)

    # Templates of the multi-round dialogues, keyed by dialogue mode.
    json_path = os.path.join(data_dir, "chatpot_multi-round.json")
    with open(json_path, mode="r", encoding="utf-8") as r:
        multi_chat_template = json.load(r)

    with st.sidebar:
        function_choose = option_menu(
            "功能選擇",
            ["搜尋測試", "配對測試", "聊天測試"],
            icons=['question-circle', 'search', 'chat-dots'],
            menu_icon="list",
            default_index=0,
            styles={
                "container": {"padding": "5!important", "background-color": "#fafafa"},
                "icon": {"color": "black", "font-size": "25px"},
                "nav-link": {"font-size": "20px", "text-align": "left", "margin": "0px", "--hover-color": "#eee"},
                "nav-link-selected": {"background-color": "#272ba8"},
            },
        )

        st.subheader("標準問答 資料")
        data_choose = st.selectbox(
            "QA 選單",
            (
                "銀行QA",
                "信貸QA",
            ),
        )

        st.subheader("SBERT 模型")
        model_choose = st.selectbox(
            "模型選單",
            (
                "paraphrase-multilingual-mpnet-base-v2",
                "all-mpnet-base-v2",
            ),
        )

        st.subheader("Cosine 調整")
        # Each slider starts at the value of the previous one, so the three
        # thresholds can never be put out of order; the default is raised
        # whenever it would fall below the slider's minimum.
        lower_thres = st.slider('LOWER_threshold', 0.0, 1.0, 0.7)
        default_mid = max(0.75, lower_thres)
        middle_thres = st.slider('MIDDLE_threshold', lower_thres, 1.0, default_mid)
        default_up = max(0.85, middle_thres)
        upper_thres = st.slider('UPPER_threshold', middle_thres, 1.0, default_up)

    rule = PairingRule(leaf_IDF_dict, lower_thres, middle_thres, upper_thres)

    if function_choose == "搜尋測試":

        st.header("搜尋測試")
        form = st.form(key='Question pairing')

        userQ_input = form.text_input(
            label='輸入的問題將會與PQA的"變化問題"做匹配',
            placeholder="請輸入要搜尋的問題",
        )
        submit_button = form.form_submit_button(label='Submit')

        if submit_button:
            bot_answer, result_df = pairing_search(str(userQ_input))

            st.subheader("回答")
            if bot_answer:
                for rank, ans in enumerate(bot_answer, start=1):
                    st.text(f"{rank}.\t{ans}")
            else:
                st.text("**沒有適合的回答**")
            st.subheader("搜尋結果")
            st.dataframe(result_df)

    elif function_choose == "配對測試":

        st.header("配對測試")
        pairing_form = st.form(key='input_pairing')

        q1_input = pairing_form.text_input(label="配對語句1")
        q2_input = pairing_form.text_input(label="配對語句2")

        pair_button = pairing_form.form_submit_button(label="Submit")
        if pair_button:
            pairing_df, data_df = pairing_two_sentence(q1_input, q2_input)

            st.subheader("文字 轉 Leaf")
            st.dataframe(pairing_df)

            st.subheader("配對結果")
            st.dataframe(data_df)

    elif function_choose == "聊天測試":

        st.header("聊天測試")

        # Chat transcript: two lists rendered side by side, row by row.
        if "user_message" not in st.session_state:
            st.session_state.user_message = []
        if "bot_message" not in st.session_state:
            st.session_state.bot_message = []

        # Multi-round dialogue state; -1 / "" mean "no dialogue running".
        if "multi_mode" not in st.session_state:
            st.session_state.multi_mode = ""
        if "multi_question_num" not in st.session_state:
            st.session_state.multi_question_num = -1
        if "multi_answer" not in st.session_state:
            st.session_state.multi_answer = dict()

        reset_button = st.button(label="清空聊天室")

        col1, col2 = st.columns(2)

        form = st.form(key='Chatting', clear_on_submit=True)

        input_ = form.text_input("USER:")
        submit_button = form.form_submit_button(label='Submit')

        # A trigger keyword starts the matching dialogue, but only when no
        # dialogue is running; the first matching keyword wins.
        for keyword in ("請假", "客戶資金及資產來源"):
            if keyword in input_:
                if st.session_state.multi_question_num == -1:
                    st.session_state.multi_question_num = 0
                if st.session_state.multi_mode == "":
                    st.session_state.multi_mode = keyword
                break

        if submit_button:
            if not input_:
                st.session_state.bot_message.append("請輸入文字後再送出")
                st.session_state.user_message.append("|")
            else:
                format_chat("user", input_)
                if st.session_state.multi_question_num == -1:
                    # Plain question answering: reply with the best match.
                    bot_answer, result_df = pairing_search(str(input_))
                    if bot_answer:
                        bot_reply = bot_answer[0]
                    else:
                        bot_reply = "抱歉,我不清楚你的問題(沒有匹配的變化問題)"
                    format_chat("bot", bot_reply)
                else:
                    multi_chat(multi_chat_template[st.session_state.multi_mode], input_)

        if reset_button:
            st.session_state.user_message = []
            st.session_state.bot_message = []
            # Also abandon a dialogue in progress: otherwise the next input
            # after clearing the room would still be treated as the answer
            # to a question the user can no longer see.
            st.session_state.multi_question_num = -1
            st.session_state.multi_answer = dict()
            st.session_state.multi_mode = ""

        if st.session_state.user_message:
            user_lines = st.session_state.user_message
            bot_lines = st.session_state.bot_message

            # Explicit bounds checks instead of a bare ``except: pass``: a
            # bare except would hide every rendering error and also trap
            # BaseException subclasses such as KeyboardInterrupt.
            for row in range(max(len(user_lines), len(bot_lines))):
                if row < len(user_lines):
                    col2.write(f"{user_lines[row]}")
                if row < len(bot_lines):
                    col1.write(f"{bot_lines[row]}")
1368
+
1369
+
1370
+
1371
+
1372
+
1373
+
1374
+