update 0722
Browse files
app.py
CHANGED
@@ -39,13 +39,6 @@ tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
|
39 |
model = AutoModel.from_pretrained("bert-base-chinese")
|
40 |
|
41 |
|
42 |
-
|
43 |
-
# 本地加载数据
|
44 |
-
dataset = load_data(file_path='train_300.json', num_samples=300)
|
45 |
-
keyword_lists = [item['content'] for item in dataset if 'content' in item]
|
46 |
-
summary_lists = [item['summary'] for item in dataset if 'summary' in item]
|
47 |
-
|
48 |
-
|
49 |
global_api_key = None
|
50 |
client = None
|
51 |
|
@@ -53,10 +46,6 @@ def initialize_clients(api_key):
|
|
53 |
global client
|
54 |
client = OpenAI(api_key=api_key)
|
55 |
|
56 |
-
|
57 |
-
for item in keyword_lists:
|
58 |
-
item = item.split(',')
|
59 |
-
|
60 |
def get_keywords(message):
|
61 |
system_message = """
|
62 |
# 角色
|
@@ -83,8 +72,7 @@ def get_keywords(message):
|
|
83 |
return ','.join(keywords)
|
84 |
|
85 |
|
86 |
-
|
87 |
-
|
88 |
def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
|
89 |
distance = 0
|
90 |
most_matching_list = None
|
@@ -93,7 +81,6 @@ def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, cu
|
|
93 |
# query_keywords = query_keywords.split(',')
|
94 |
# query_keywords = [keyword for keyword in query_keywords if keyword]
|
95 |
|
96 |
-
|
97 |
#匹配模块
|
98 |
query_keywords= list(query_keywords_dict.keys())
|
99 |
|
@@ -126,6 +113,7 @@ def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, cu
|
|
126 |
|
127 |
return distance, index
|
128 |
|
|
|
129 |
def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
|
130 |
if torch.cuda.is_available():
|
131 |
print('Using GPU')
|
@@ -150,6 +138,7 @@ def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
|
|
150 |
|
151 |
return avg_embeddings
|
152 |
|
|
|
153 |
def encode_to_avg(keywords_dict, model, tokenizer, device):
|
154 |
if torch.cuda.is_available():
|
155 |
print('Using GPU')
|
@@ -176,33 +165,6 @@ def encode_to_avg(keywords_dict, model, tokenizer, device):
|
|
176 |
return avg_embedding.tolist()
|
177 |
|
178 |
|
179 |
-
# def fetch_response_from_db(query_keywords,class_name):
|
180 |
-
|
181 |
-
# avg_vec=np.array(encode_list_to_avg([query_keywords], model, tokenizer, device)[0])
|
182 |
-
# nearVector = {
|
183 |
-
# 'vector': avg_vec
|
184 |
-
# }
|
185 |
-
|
186 |
-
# response = (
|
187 |
-
# db_client.query
|
188 |
-
# .get(class_name, ['keywords', 'summary'])
|
189 |
-
# .with_near_vector(nearVector)
|
190 |
-
# .with_limit(1)
|
191 |
-
# .with_additional(['distance'])
|
192 |
-
# .do()
|
193 |
-
# )
|
194 |
-
|
195 |
-
# print(response)
|
196 |
-
# class_name=class_name[0].upper()+class_name[1:]
|
197 |
-
|
198 |
-
# if class_name in response['data']['Get']:
|
199 |
-
# results = response['data']['Get'][class_name]
|
200 |
-
# return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
|
201 |
-
# else:
|
202 |
-
# print(f"Class name {class_name} not found in response")
|
203 |
-
# return None
|
204 |
-
|
205 |
-
|
206 |
def fetch_response_from_db(query_keywords_dict,class_name):
|
207 |
|
208 |
|
@@ -230,13 +192,15 @@ def fetch_response_from_db(query_keywords_dict,class_name):
|
|
230 |
print(f"Class name {class_name} not found in response")
|
231 |
return None
|
232 |
|
233 |
-
|
|
|
|
|
234 |
initialize_clients(api_key)
|
235 |
-
return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered)
|
236 |
|
237 |
|
238 |
#触发词及触发回合字典
|
239 |
-
|
240 |
|
241 |
def respond(
|
242 |
message,
|
@@ -247,10 +211,11 @@ def respond(
|
|
247 |
window_size,
|
248 |
distance_threshold,
|
249 |
weight_keywords_users,
|
250 |
-
weight_keywords_triggered
|
|
|
251 |
):
|
252 |
|
253 |
-
|
254 |
system_message_with_ad = """
|
255 |
# 角色
|
256 |
你是一个热情的聊天机器人
|
@@ -291,9 +256,9 @@ def respond(
|
|
291 |
keywords_dict={}
|
292 |
for keywords in key_words_users:
|
293 |
if keywords in keywords_dict:
|
294 |
-
keywords_dict[keywords]+=
|
295 |
else:
|
296 |
-
keywords_dict[keywords]=
|
297 |
for keywords in key_words_assistant:
|
298 |
if keywords in keywords_dict:
|
299 |
keywords_dict[keywords]+=1
|
@@ -355,7 +320,7 @@ def respond(
|
|
355 |
top_p=top_p,
|
356 |
)
|
357 |
|
358 |
-
return response.choices[0].message.content
|
359 |
|
360 |
|
361 |
# def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
@@ -388,7 +353,8 @@ demo = gr.ChatInterface(
|
|
388 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
389 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
390 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
391 |
-
gr.Textbox(label="api_key")
|
|
|
392 |
],
|
393 |
)
|
394 |
|
|
|
39 |
model = AutoModel.from_pretrained("bert-base-chinese")
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
global_api_key = None
|
43 |
client = None
|
44 |
|
|
|
46 |
global client
|
47 |
client = OpenAI(api_key=api_key)
|
48 |
|
|
|
|
|
|
|
|
|
49 |
def get_keywords(message):
|
50 |
system_message = """
|
51 |
# 角色
|
|
|
72 |
return ','.join(keywords)
|
73 |
|
74 |
|
75 |
+
#字符串匹配模块
|
|
|
76 |
def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
|
77 |
distance = 0
|
78 |
most_matching_list = None
|
|
|
81 |
# query_keywords = query_keywords.split(',')
|
82 |
# query_keywords = [keyword for keyword in query_keywords if keyword]
|
83 |
|
|
|
84 |
#匹配模块
|
85 |
query_keywords= list(query_keywords_dict.keys())
|
86 |
|
|
|
113 |
|
114 |
return distance, index
|
115 |
|
116 |
+
|
117 |
def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
|
118 |
if torch.cuda.is_available():
|
119 |
print('Using GPU')
|
|
|
138 |
|
139 |
return avg_embeddings
|
140 |
|
141 |
+
|
142 |
def encode_to_avg(keywords_dict, model, tokenizer, device):
|
143 |
if torch.cuda.is_available():
|
144 |
print('Using GPU')
|
|
|
165 |
return avg_embedding.tolist()
|
166 |
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
def fetch_response_from_db(query_keywords_dict,class_name):
|
169 |
|
170 |
|
|
|
192 |
print(f"Class name {class_name} not found in response")
|
193 |
return None
|
194 |
|
195 |
+
|
196 |
+
|
197 |
+
def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key, state):
|
198 |
initialize_clients(api_key)
|
199 |
+
return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, state)
|
200 |
|
201 |
|
202 |
#触发词及触发回合字典
|
203 |
+
|
204 |
|
205 |
def respond(
|
206 |
message,
|
|
|
211 |
window_size,
|
212 |
distance_threshold,
|
213 |
weight_keywords_users,
|
214 |
+
weight_keywords_triggered,
|
215 |
+
triggered_keywords
|
216 |
):
|
217 |
|
218 |
+
triggered_keywords=triggered_keywords or {}
|
219 |
system_message_with_ad = """
|
220 |
# 角色
|
221 |
你是一个热情的聊天机器人
|
|
|
256 |
keywords_dict={}
|
257 |
for keywords in key_words_users:
|
258 |
if keywords in keywords_dict:
|
259 |
+
keywords_dict[keywords]+=weight_keywords_users
|
260 |
else:
|
261 |
+
keywords_dict[keywords]=weight_keywords_users
|
262 |
for keywords in key_words_assistant:
|
263 |
if keywords in keywords_dict:
|
264 |
keywords_dict[keywords]+=1
|
|
|
320 |
top_p=top_p,
|
321 |
)
|
322 |
|
323 |
+
return response.choices[0].message.content , triggered_keywords
|
324 |
|
325 |
|
326 |
# def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
|
|
353 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
354 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
355 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
356 |
+
gr.Textbox(label="api_key"),
|
357 |
+
'state'
|
358 |
],
|
359 |
)
|
360 |
|