rollback
Browse files
app.py
CHANGED
@@ -190,9 +190,6 @@
|
|
190 |
|
191 |
|
192 |
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
import gradio as gr
|
197 |
from huggingface_hub import InferenceClient
|
198 |
import json
|
@@ -202,61 +199,47 @@ from load_data import load_data
|
|
202 |
from openai import OpenAI
|
203 |
from transformers import AutoTokenizer, AutoModel
|
204 |
import weaviate
|
205 |
-
import os
|
206 |
-
import subprocess
|
207 |
import torch
|
208 |
from tqdm import tqdm
|
209 |
import numpy as np
|
210 |
import time
|
211 |
|
212 |
-
|
213 |
-
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
214 |
-
|
215 |
-
os.environ['
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
#
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
)
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
def initialize_clients(api_key):
|
242 |
-
global client
|
243 |
-
client = OpenAI(api_key=api_key)
|
244 |
-
|
245 |
-
def get_keywords(message):
|
246 |
-
system_message = """
|
247 |
-
# 角色
|
248 |
-
你是一个关键词提取机器人
|
249 |
-
# 指令
|
250 |
-
你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。关键词中应该尽可能注意那些名词和形容词
|
251 |
-
# 输出格式
|
252 |
-
你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
|
253 |
-
# 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
|
254 |
"""
|
|
|
|
|
|
|
|
|
255 |
|
256 |
-
|
257 |
-
messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
|
258 |
-
|
259 |
-
response = client.chat.completions.create(
|
260 |
model="gpt-3.5-turbo",
|
261 |
messages=messages,
|
262 |
max_tokens=100,
|
@@ -267,321 +250,580 @@ def get_keywords(message):
|
|
267 |
keywords = response.choices[0].message.content.split(' ')
|
268 |
return ','.join(keywords)
|
269 |
|
|
|
|
|
|
|
270 |
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
query_keywords= list(query_keywords_dict.keys())
|
282 |
-
|
283 |
-
for i, lst in enumerate(ad_keywords_lists):
|
284 |
-
lst = lst.split(',')
|
285 |
-
matches = sum(
|
286 |
-
any(
|
287 |
-
ad_keyword in keyword and
|
288 |
-
(
|
289 |
-
keyword not in triggered_keywords or
|
290 |
-
triggered_keywords.get(keyword) is None or
|
291 |
-
current_turn - triggered_keywords.get(keyword, 0) > window_size
|
292 |
-
) * query_keywords_dict.get(keyword, 1) #计数乘以权重
|
293 |
-
for keyword in query_keywords
|
294 |
-
)
|
295 |
-
for ad_keyword in lst
|
296 |
-
)
|
297 |
-
if matches > distance:
|
298 |
-
distance = matches
|
299 |
-
most_matching_list = lst
|
300 |
-
index = i
|
301 |
-
|
302 |
-
#更新对distance 有贡献的关键词
|
303 |
-
if distance >= distance_threshold:
|
304 |
-
for keyword in query_keywords:
|
305 |
-
if any(
|
306 |
-
ad_keyword in keyword for ad_keyword in most_matching_list
|
307 |
-
):
|
308 |
-
triggered_keywords[keyword] = current_turn
|
309 |
-
|
310 |
-
return distance, index
|
311 |
|
|
|
312 |
|
313 |
-
def
|
314 |
-
|
315 |
-
|
316 |
-
print(device)
|
317 |
-
else:
|
318 |
-
print('Using CPU')
|
319 |
-
print(device)
|
320 |
-
|
321 |
-
avg_embeddings = []
|
322 |
-
for keywords in tqdm(keywords_list_list):
|
323 |
-
keywords_lst=[]
|
324 |
-
# keywords.split(',')
|
325 |
-
for keyword in keywords:
|
326 |
-
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
327 |
-
inputs.to(device)
|
328 |
-
with torch.no_grad():
|
329 |
-
outputs = model(**inputs)
|
330 |
-
embeddings = outputs.last_hidden_state.mean(dim=1)
|
331 |
-
keywords_lst.append(embeddings)
|
332 |
-
avg_embedding = sum(keywords_lst) / len(keywords_lst)
|
333 |
-
avg_embeddings.append(avg_embedding)
|
334 |
-
|
335 |
-
return avg_embeddings
|
336 |
-
|
337 |
-
|
338 |
-
def encode_to_avg(keywords_dict, model, tokenizer, device):
|
339 |
-
if torch.cuda.is_available():
|
340 |
-
print('Using GPU')
|
341 |
-
print(device)
|
342 |
-
else:
|
343 |
-
print('Using CPU')
|
344 |
-
print(device)
|
345 |
-
|
346 |
-
|
347 |
-
keyword_embeddings=[]
|
348 |
-
for keyword, weight in keywords_dict.items():
|
349 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
350 |
inputs.to(device)
|
351 |
with torch.no_grad():
|
352 |
outputs = model(**inputs)
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
keyword_embeddings.append(keyword_embedding * weight)
|
358 |
-
|
359 |
-
avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
|
360 |
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
|
371 |
|
372 |
-
|
373 |
-
|
374 |
-
}
|
375 |
-
start_time = time.time()
|
376 |
response = (
|
377 |
-
|
378 |
.get(class_name, ['keywords', 'summary'])
|
379 |
-
.with_near_vector(
|
380 |
.with_limit(1)
|
381 |
.with_additional(['distance'])
|
382 |
.do()
|
383 |
)
|
384 |
-
end_time = time.time()
|
385 |
-
print(f"Time taken to search in DB: {end_time - start_time}")
|
386 |
-
|
387 |
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
if class_name in response['data']['Get']:
|
392 |
-
results = response['data']['Get'][class_name]
|
393 |
-
print(results[0]['keywords'])
|
394 |
-
return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
|
395 |
-
|
396 |
else:
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
-
def respond(
|
411 |
-
message,
|
412 |
-
history: list[tuple[str, str]],
|
413 |
-
max_tokens,
|
414 |
-
temperature,
|
415 |
-
top_p,
|
416 |
-
window_size,
|
417 |
-
distance_threshold,
|
418 |
-
weight_keywords_users,
|
419 |
-
weight_keywords_triggered,
|
420 |
|
421 |
-
):
|
422 |
-
|
423 |
-
system_message_with_ad = """
|
424 |
-
# 角色
|
425 |
-
你是一个热情的聊天机器人
|
426 |
-
# 指令
|
427 |
-
你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
|
428 |
-
注意要在商品的描述前面加上是来自哪个品牌的广告。
|
429 |
-
注意在推荐中不要脑补用户的身份,只是进行简单推荐。
|
430 |
-
注意要热情但是语气只要适度热情
|
431 |
-
# 输入格式
|
432 |
-
用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
|
433 |
-
例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
|
434 |
-
注意: 当没有<sep>时,正常回复用户,不插入广告。
|
435 |
-
# 输出格式
|
436 |
-
始终使用中文,只输出聊天内容,不输出任何自我分析的信息
|
437 |
-
"""
|
438 |
|
439 |
-
|
440 |
-
你是一个热情的聊天机器人
|
441 |
-
"""
|
442 |
-
print(f"triggered_keywords{triggered_keywords}")
|
443 |
-
# 更新当前轮次
|
444 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
current_turn = len(history) + 1
|
446 |
-
print(f"\ncurrent_turn: {current_turn}\n")
|
447 |
|
448 |
-
|
449 |
-
|
450 |
-
combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
|
451 |
-
combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
|
452 |
-
else:
|
453 |
-
combined_message_user = message
|
454 |
-
combined_message_assistant = ""
|
455 |
-
|
456 |
-
start_time = time.time()
|
457 |
-
key_words_users=get_keywords(combined_message_user).split(',')
|
458 |
-
key_words_assistant=get_keywords(combined_message_assistant).split(',')
|
459 |
-
end_time = time.time()
|
460 |
-
print(f"Time taken to get keywords: {end_time - start_time}")
|
461 |
-
|
462 |
-
print(f"Initial keywords_users: {key_words_users}")
|
463 |
-
print(f"Initial keywords_assistant: {key_words_assistant}")
|
464 |
-
|
465 |
-
keywords_dict = {}
|
466 |
-
added_keywords = set()
|
467 |
-
|
468 |
-
for keywords in key_words_users:
|
469 |
-
if keywords not in added_keywords:
|
470 |
-
if keywords in keywords_dict:
|
471 |
-
keywords_dict[keywords] += weight_keywords_users
|
472 |
-
else:
|
473 |
-
keywords_dict[keywords] = weight_keywords_users
|
474 |
-
added_keywords.add(keywords)
|
475 |
-
|
476 |
-
for keywords in key_words_assistant:
|
477 |
-
if keywords not in added_keywords:
|
478 |
-
if keywords in keywords_dict:
|
479 |
-
keywords_dict[keywords] += 1
|
480 |
-
else:
|
481 |
-
keywords_dict[keywords] = 1
|
482 |
-
added_keywords.add(keywords)
|
483 |
-
|
484 |
-
#窗口内触发过的关键词权重下调为0.5
|
485 |
-
for keyword in list(keywords_dict.keys()):
|
486 |
-
if keyword in triggered_keywords:
|
487 |
-
if current_turn - triggered_keywords[keyword] < window_size:
|
488 |
-
keywords_dict[keyword] = weight_keywords_triggered
|
489 |
-
|
490 |
-
query_keywords = list(keywords_dict.keys())
|
491 |
-
print(keywords_dict)
|
492 |
-
|
493 |
-
start_time = time.time()
|
494 |
-
distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
|
495 |
-
end_time = time.time()
|
496 |
-
print(f"Time taken to fetch response from db: {end_time - start_time}")
|
497 |
|
|
|
|
|
|
|
498 |
|
499 |
-
|
|
|
|
|
|
|
500 |
|
501 |
-
|
502 |
-
|
|
|
503 |
|
504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
|
506 |
-
for val in history:
|
507 |
-
if val[0]:
|
508 |
-
messages.append({"role": "user", "content": val[0]})
|
509 |
-
if val[1]:
|
510 |
-
messages.append({"role": "assistant", "content": val[1]})
|
511 |
|
512 |
-
|
513 |
-
|
514 |
-
|
|
|
|
|
|
|
515 |
|
516 |
-
|
517 |
-
|
518 |
-
if any(
|
519 |
-
ad_keyword in keyword for ad_keyword in top_keywords_list
|
520 |
-
):
|
521 |
triggered_keywords[keyword] = current_turn
|
522 |
-
|
523 |
else:
|
524 |
-
messages = [{"role": "system", "content":
|
525 |
-
|
526 |
-
for val in history:
|
527 |
-
if val[0]:
|
528 |
-
messages.append({"role": "user", "content": val[0]})
|
529 |
-
if val[1]:
|
530 |
-
messages.append({"role": "assistant", "content": val[1]})
|
531 |
-
|
532 |
messages.append({"role": "user", "content": message})
|
533 |
|
534 |
-
|
535 |
-
response =
|
536 |
model="gpt-3.5-turbo",
|
537 |
messages=messages,
|
538 |
max_tokens=max_tokens,
|
539 |
temperature=temperature,
|
540 |
top_p=top_p,
|
541 |
)
|
542 |
-
end_time = time.time()
|
543 |
-
print(f"Time taken to get response from GPT: {end_time - start_time}")
|
544 |
-
|
545 |
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
# def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
550 |
-
# global triggered_keywords
|
551 |
-
# response, triggered_keywords = respond(
|
552 |
-
# message,
|
553 |
-
# history,
|
554 |
-
# max_tokens,
|
555 |
-
# temperature,
|
556 |
-
# top_p,
|
557 |
-
# window_size,
|
558 |
-
# distance_threshold,
|
559 |
-
# triggered_keywords
|
560 |
-
# )
|
561 |
-
# return response, history + [(message, response)]
|
562 |
|
|
|
563 |
demo = gr.ChatInterface(
|
564 |
-
|
565 |
additional_inputs=[
|
566 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
567 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
568 |
-
gr.Slider(
|
569 |
-
minimum=0.1,
|
570 |
-
maximum=1.0,
|
571 |
-
value=0.95,
|
572 |
-
step=0.05,
|
573 |
-
label="Top-p (nucleus sampling)",
|
574 |
-
),
|
575 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
576 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
577 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
578 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
579 |
-
gr.Textbox(label="
|
580 |
],
|
581 |
)
|
582 |
|
583 |
if __name__ == "__main__":
|
584 |
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
|
586 |
# import gradio as gr
|
587 |
# from huggingface_hub import InferenceClient
|
|
|
190 |
|
191 |
|
192 |
|
|
|
|
|
|
|
193 |
import gradio as gr
|
194 |
from huggingface_hub import InferenceClient
|
195 |
import json
|
|
|
199 |
from openai import OpenAI
|
200 |
from transformers import AutoTokenizer, AutoModel
|
201 |
import weaviate
|
202 |
+
import os
|
|
|
203 |
import torch
|
204 |
from tqdm import tqdm
|
205 |
import numpy as np
|
206 |
import time
|
207 |
|
208 |
+
# 设置缓存目录
|
209 |
+
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
210 |
+
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
211 |
+
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
212 |
+
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
213 |
+
|
214 |
+
# Weaviate 连接配置
|
215 |
+
WEAVIATE_API_KEY = "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH"
|
216 |
+
WEAVIATE_URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
217 |
+
weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
|
218 |
+
weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)
|
219 |
+
|
220 |
+
# 预训练模型配置
|
221 |
+
MODEL_NAME = "bert-base-chinese"
|
222 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
223 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
224 |
+
model = AutoModel.from_pretrained(MODEL_NAME)
|
225 |
+
|
226 |
+
# OpenAI 客户端
|
227 |
+
openai_client = None
|
228 |
+
|
229 |
+
def initialize_openai_client(api_key):
|
230 |
+
global openai_client
|
231 |
+
openai_client = OpenAI(api_key=api_key)
|
232 |
+
|
233 |
+
def extract_keywords(text):
|
234 |
+
prompt = """
|
235 |
+
你是一个关键词提取机器人。提取用户输入中的关键词,特别是名词和形容词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
"""
|
237 |
+
messages = [
|
238 |
+
{"role": "system", "content": prompt},
|
239 |
+
{"role": "user", "content": f"从下面的文本中提取五个关键词,以空格分隔:{text}"}
|
240 |
+
]
|
241 |
|
242 |
+
response = openai_client.chat.completions.create(
|
|
|
|
|
|
|
243 |
model="gpt-3.5-turbo",
|
244 |
messages=messages,
|
245 |
max_tokens=100,
|
|
|
250 |
keywords = response.choices[0].message.content.split(' ')
|
251 |
return ','.join(keywords)
|
252 |
|
253 |
+
# def match_keywords(query_keywords, ad_keywords_list, triggered_keywords, current_turn, window_size, threshold):
|
254 |
+
# best_match_distance = 0
|
255 |
+
# best_match_index = -1
|
256 |
|
257 |
+
# for i, ad_keywords in enumerate(ad_keywords_list):
|
258 |
+
# match_count = sum(
|
259 |
+
# any(
|
260 |
+
# ad_keyword in keyword and
|
261 |
+
# (keyword not in triggered_keywords or current_turn - triggered_keywords[keyword] > window_size)
|
262 |
+
# ) for keyword in query_keywords
|
263 |
+
# )
|
264 |
+
# if match_count > best_match_distance:
|
265 |
+
# best_match_distance = match_count
|
266 |
+
# best_match_index = i
|
267 |
|
268 |
+
# if best_match_distance >= threshold:
|
269 |
+
# for keyword in query_keywords:
|
270 |
+
# if any(ad_keyword in keyword for ad_keyword in ad_keywords_list[best_match_index]):
|
271 |
+
# triggered_keywords[keyword] = current_turn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
|
273 |
+
# return best_match_distance, best_match_index
|
274 |
|
275 |
+
def encode_keywords_to_avg(keywords, model, tokenizer, device):
|
276 |
+
embeddings = []
|
277 |
+
for keyword in tqdm(keywords):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
279 |
inputs.to(device)
|
280 |
with torch.no_grad():
|
281 |
outputs = model(**inputs)
|
282 |
+
embeddings.append(outputs.last_hidden_state.mean(dim=1))
|
283 |
+
avg_embedding = sum(embeddings) / len(embeddings)
|
284 |
+
return avg_embedding
|
|
|
|
|
|
|
|
|
285 |
|
286 |
+
def encode_keywords_to_list(keywords, model, tokenizer, device):
|
287 |
+
embeddings = []
|
288 |
+
for keyword in tqdm(keywords):
|
289 |
+
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
290 |
+
inputs.to(device)
|
291 |
+
with torch.no_grad():
|
292 |
+
outputs = model(**inputs)
|
293 |
+
embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().tolist())
|
294 |
+
return embeddings
|
295 |
|
296 |
|
297 |
+
def get_response_from_db(keywords_dict, class_name):
|
298 |
+
avg_vec = encode_keywords_to_avg(keywords_dict.keys(), model, tokenizer, device).numpy()
|
|
|
|
|
299 |
response = (
|
300 |
+
weaviate_client.query
|
301 |
.get(class_name, ['keywords', 'summary'])
|
302 |
+
.with_near_vector({'vector': avg_vec})
|
303 |
.with_limit(1)
|
304 |
.with_additional(['distance'])
|
305 |
.do()
|
306 |
)
|
|
|
|
|
|
|
307 |
|
308 |
+
if class_name.capitalize() in response['data']['Get']:
|
309 |
+
result = response['data']['Get'][class_name.capitalize()][0]
|
310 |
+
return result['_additional']['distance'], result['summary'], result['keywords']
|
|
|
|
|
|
|
|
|
|
|
311 |
else:
|
312 |
+
return None, None, None
|
313 |
+
|
314 |
+
def get_candidates_from_db(keywords_dict, class_name,limit=3):
|
315 |
+
embeddings= encode_keywords_to_list(keywords_dict.keys(), model, tokenizer, device)
|
316 |
+
candidate_list=[]
|
317 |
+
for embedding in embeddings:
|
318 |
+
response = (
|
319 |
+
weaviate_client.query
|
320 |
+
.get(class_name, ['keywords', 'summary'])
|
321 |
+
.with_near_vector({'vector': embedding})
|
322 |
+
.with_limit(limit)
|
323 |
+
.with_additional(['distance'])
|
324 |
+
.do()
|
325 |
+
)
|
326 |
+
|
327 |
+
if class_name.capitalize() in response['data']['Get']:
|
328 |
+
results = response['data']['Get'][class_name.capitalize()]
|
329 |
+
for result in results:
|
330 |
+
candidate_list.append({
|
331 |
+
'distance': result['_additional']['distance'],
|
332 |
+
'summary': result['summary'],
|
333 |
+
'keywords': result['keywords']
|
334 |
+
})
|
335 |
+
|
336 |
+
return candidate_list
|
337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
|
340 |
+
triggered_keywords = {}
|
|
|
|
|
|
|
|
|
341 |
|
342 |
+
def keyword_match(keywords_dict,candidates):
|
343 |
+
for candidate in candidates:
|
344 |
+
keywords=candidate['keywords'].split('*')
|
345 |
+
processed_keywords=[keyword.split('#')[1] for keyword in keywords]
|
346 |
+
candidate_keywords_list=','.join(processed_keywords)
|
347 |
+
for keyword in keywords_dict.keys():
|
348 |
+
if any(candidate_keyword in keyword for candidate_keyword in candidate_keywords_list):
|
349 |
+
# triggered_keywords[keyword]=True
|
350 |
+
return candidate['distance'],candidate['summary'],candidate['keywords']
|
351 |
+
|
352 |
+
def chatbot_response(message, history, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key):
|
353 |
+
#初始化openai client
|
354 |
+
initialize_openai_client(api_key)
|
355 |
+
|
356 |
+
#更新轮次,获取窗口历史
|
357 |
current_turn = len(history) + 1
|
|
|
358 |
|
359 |
+
combined_user_message = " ".join([h[0] for h in history[-window_size:]] + [message])
|
360 |
+
combined_assistant_message = " ".join([h[1] for h in history[-window_size:]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
|
362 |
+
#提取关键词
|
363 |
+
user_keywords = extract_keywords(combined_user_message).split(',')
|
364 |
+
assistant_keywords = extract_keywords(combined_assistant_message).split(',')
|
365 |
|
366 |
+
#获取关键词字典
|
367 |
+
keywords_dict = {keyword: user_weight for keyword in user_keywords}
|
368 |
+
for keyword in assistant_keywords:
|
369 |
+
keywords_dict[keyword] = keywords_dict.get(keyword, 0) + 1
|
370 |
|
371 |
+
for keyword in list(keywords_dict.keys()):
|
372 |
+
if keyword in triggered_keywords and current_turn - triggered_keywords[keyword] < window_size:
|
373 |
+
keywords_dict[keyword] = triggered_weight
|
374 |
|
375 |
+
#数据库检索,双方平均方式
|
376 |
+
# distance, ad_summary, ad_keywords = get_response_from_db(keywords_dict, class_name="ad_DB02")
|
377 |
+
#数据库索引,数据库关键词平均方式
|
378 |
+
candidates=get_candidates_from_db(keywords_dict, class_name="ad_DB02",limit=3)
|
379 |
+
|
380 |
+
#先对候选集的distance进行筛选,保留小于threshold的
|
381 |
+
candidates.sort(key=lambda x:x['distance'])
|
382 |
+
candidates=[candidate for candidate in candidates if candidate['distance']<threshold]
|
383 |
+
|
384 |
+
if(candidates):
|
385 |
+
distance, ad_summary, ad_keywords=keyword_match(keywords_dict,candidates)
|
386 |
|
|
|
|
|
|
|
|
|
|
|
387 |
|
388 |
+
#判断相似度
|
389 |
+
if distance and distance < threshold:
|
390 |
+
ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
|
391 |
+
messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
|
392 |
+
messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
|
393 |
+
messages.append({"role": "user", "content": ad_message})
|
394 |
|
395 |
+
for keyword in keywords_dict.keys():
|
396 |
+
if any(ad_keyword in keyword for ad_keyword in ad_keywords.split(',')):
|
|
|
|
|
|
|
397 |
triggered_keywords[keyword] = current_turn
|
|
|
398 |
else:
|
399 |
+
messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
|
400 |
+
messages.extend([{"role": "user", "content": msg[0]}, {"role": "assistant", "content": msg[1]}] for msg in history)
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
messages.append({"role": "user", "content": message})
|
402 |
|
403 |
+
#获取回复
|
404 |
+
response = openai_client.chat.completions.create(
|
405 |
model="gpt-3.5-turbo",
|
406 |
messages=messages,
|
407 |
max_tokens=max_tokens,
|
408 |
temperature=temperature,
|
409 |
top_p=top_p,
|
410 |
)
|
|
|
|
|
|
|
411 |
|
412 |
+
print(f"triggered_keywords: {triggered_keywords}")
|
413 |
+
return response.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
|
415 |
+
# Gradio UI
|
416 |
demo = gr.ChatInterface(
|
417 |
+
chatbot_response,
|
418 |
additional_inputs=[
|
419 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
420 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
421 |
+
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
|
|
|
|
|
|
|
|
|
|
|
|
|
422 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
423 |
gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
424 |
gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
425 |
gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
426 |
+
gr.Textbox(label="API Key"),
|
427 |
],
|
428 |
)
|
429 |
|
430 |
if __name__ == "__main__":
|
431 |
demo.launch(share=True)
|
432 |
+
print("happyhappyhappy")
|
433 |
+
|
434 |
+
|
435 |
+
|
436 |
+
|
437 |
+
|
438 |
+
# import gradio as gr
|
439 |
+
# from huggingface_hub import InferenceClient
|
440 |
+
# import json
|
441 |
+
# import random
|
442 |
+
# import re
|
443 |
+
# from load_data import load_data
|
444 |
+
# from openai import OpenAI
|
445 |
+
# from transformers import AutoTokenizer, AutoModel
|
446 |
+
# import weaviate
|
447 |
+
# import os
|
448 |
+
# import subprocess
|
449 |
+
# import torch
|
450 |
+
# from tqdm import tqdm
|
451 |
+
# import numpy as np
|
452 |
+
# import time
|
453 |
+
|
454 |
+
# # 设置 Matplotlib 的缓存目录
|
455 |
+
# os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
456 |
+
# # 设置 Hugging Face Transformers 的缓存目录
|
457 |
+
# os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
458 |
+
# # 确保这些目录存在
|
459 |
+
# os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
|
460 |
+
# os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
461 |
+
|
462 |
+
# auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
|
463 |
+
|
464 |
+
# URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
|
465 |
+
|
466 |
+
# # Connect to a WCS instance
|
467 |
+
# db_client = weaviate.Client(
|
468 |
+
# url=URL,
|
469 |
+
# auth_client_secret=auth_config
|
470 |
+
# )
|
471 |
+
|
472 |
+
|
473 |
+
# class_name="ad_DB02"
|
474 |
+
|
475 |
+
# device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
|
476 |
+
# tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
|
477 |
+
# model = AutoModel.from_pretrained("bert-base-chinese")
|
478 |
+
|
479 |
+
|
480 |
+
# global_api_key = None
|
481 |
+
# client = None
|
482 |
+
|
483 |
+
# def initialize_clients(api_key):
|
484 |
+
# global client
|
485 |
+
# client = OpenAI(api_key=api_key)
|
486 |
+
|
487 |
+
# def get_keywords(message):
|
488 |
+
# system_message = """
|
489 |
+
# # 角色
|
490 |
+
# 你是一个关键词提取机器人
|
491 |
+
# # 指令
|
492 |
+
# 你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。关键词中应该尽可能注意那些名词和形容词
|
493 |
+
# # 输出格式
|
494 |
+
# 你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
|
495 |
+
# # 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
|
496 |
+
# """
|
497 |
+
|
498 |
+
# messages = [{"role": "system", "content": system_message}]
|
499 |
+
# messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
|
500 |
+
|
501 |
+
# response = client.chat.completions.create(
|
502 |
+
# model="gpt-3.5-turbo",
|
503 |
+
# messages=messages,
|
504 |
+
# max_tokens=100,
|
505 |
+
# temperature=0.7,
|
506 |
+
# top_p=0.9,
|
507 |
+
# )
|
508 |
+
|
509 |
+
# keywords = response.choices[0].message.content.split(' ')
|
510 |
+
# return ','.join(keywords)
|
511 |
+
|
512 |
+
|
513 |
+
# #字符串匹配模块
|
514 |
+
# def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
|
515 |
+
# distance = 0
|
516 |
+
# most_matching_list = None
|
517 |
+
# index = 0
|
518 |
+
|
519 |
+
# # query_keywords = query_keywords.split(',')
|
520 |
+
# # query_keywords = [keyword for keyword in query_keywords if keyword]
|
521 |
+
|
522 |
+
# #匹配模块
|
523 |
+
# query_keywords= list(query_keywords_dict.keys())
|
524 |
+
|
525 |
+
# for i, lst in enumerate(ad_keywords_lists):
|
526 |
+
# lst = lst.split(',')
|
527 |
+
# matches = sum(
|
528 |
+
# any(
|
529 |
+
# ad_keyword in keyword and
|
530 |
+
# (
|
531 |
+
# keyword not in triggered_keywords or
|
532 |
+
# triggered_keywords.get(keyword) is None or
|
533 |
+
# current_turn - triggered_keywords.get(keyword, 0) > window_size
|
534 |
+
# ) * query_keywords_dict.get(keyword, 1) #计数乘以权重
|
535 |
+
# for keyword in query_keywords
|
536 |
+
# )
|
537 |
+
# for ad_keyword in lst
|
538 |
+
# )
|
539 |
+
# if matches > distance:
|
540 |
+
# distance = matches
|
541 |
+
# most_matching_list = lst
|
542 |
+
# index = i
|
543 |
+
|
544 |
+
# #更新对distance 有贡献的关键词
|
545 |
+
# if distance >= distance_threshold:
|
546 |
+
# for keyword in query_keywords:
|
547 |
+
# if any(
|
548 |
+
# ad_keyword in keyword for ad_keyword in most_matching_list
|
549 |
+
# ):
|
550 |
+
# triggered_keywords[keyword] = current_turn
|
551 |
+
|
552 |
+
# return distance, index
|
553 |
+
|
554 |
+
|
555 |
+
# def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
|
556 |
+
# if torch.cuda.is_available():
|
557 |
+
# print('Using GPU')
|
558 |
+
# print(device)
|
559 |
+
# else:
|
560 |
+
# print('Using CPU')
|
561 |
+
# print(device)
|
562 |
+
|
563 |
+
# avg_embeddings = []
|
564 |
+
# for keywords in tqdm(keywords_list_list):
|
565 |
+
# keywords_lst=[]
|
566 |
+
# # keywords.split(',')
|
567 |
+
# for keyword in keywords:
|
568 |
+
# inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
569 |
+
# inputs.to(device)
|
570 |
+
# with torch.no_grad():
|
571 |
+
# outputs = model(**inputs)
|
572 |
+
# embeddings = outputs.last_hidden_state.mean(dim=1)
|
573 |
+
# keywords_lst.append(embeddings)
|
574 |
+
# avg_embedding = sum(keywords_lst) / len(keywords_lst)
|
575 |
+
# avg_embeddings.append(avg_embedding)
|
576 |
+
|
577 |
+
# return avg_embeddings
|
578 |
+
|
579 |
+
|
580 |
+
# def encode_to_avg(keywords_dict, model, tokenizer, device):
|
581 |
+
# if torch.cuda.is_available():
|
582 |
+
# print('Using GPU')
|
583 |
+
# print(device)
|
584 |
+
# else:
|
585 |
+
# print('Using CPU')
|
586 |
+
# print(device)
|
587 |
+
|
588 |
+
|
589 |
+
# keyword_embeddings=[]
|
590 |
+
# for keyword, weight in keywords_dict.items():
|
591 |
+
# inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
592 |
+
# inputs.to(device)
|
593 |
+
# with torch.no_grad():
|
594 |
+
# outputs = model(**inputs)
|
595 |
+
# embedding = outputs.last_hidden_state.mean(dim=1)
|
596 |
+
|
597 |
+
# keyword_embedding=embedding * weight
|
598 |
+
|
599 |
+
# keyword_embeddings.append(keyword_embedding * weight)
|
600 |
+
|
601 |
+
# avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
|
602 |
+
|
603 |
+
# return avg_embedding.tolist()
|
604 |
+
|
605 |
+
|
606 |
+
# def fetch_response_from_db(query_keywords_dict,class_name):
|
607 |
+
|
608 |
+
# start_time = time.time()
|
609 |
+
# avg_vec=np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
|
610 |
+
# end_time = time.time()
|
611 |
+
# print(f"Time taken to encode to avg: {end_time - start_time}")
|
612 |
+
|
613 |
+
|
614 |
+
# nearVector = {
|
615 |
+
# 'vector': avg_vec
|
616 |
+
# }
|
617 |
+
# start_time = time.time()
|
618 |
+
# response = (
|
619 |
+
# db_client.query
|
620 |
+
# .get(class_name, ['keywords', 'summary'])
|
621 |
+
# .with_near_vector(nearVector)
|
622 |
+
# .with_limit(1)
|
623 |
+
# .with_additional(['distance'])
|
624 |
+
# .do()
|
625 |
+
# )
|
626 |
+
# end_time = time.time()
|
627 |
+
# print(f"Time taken to search in DB: {end_time - start_time}")
|
628 |
+
|
629 |
+
|
630 |
+
# print(response)
|
631 |
+
# class_name=class_name[0].upper()+class_name[1:]
|
632 |
+
|
633 |
+
# if class_name in response['data']['Get']:
|
634 |
+
# results = response['data']['Get'][class_name]
|
635 |
+
# print(results[0]['keywords'])
|
636 |
+
# return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
|
637 |
+
|
638 |
+
# else:
|
639 |
+
# print(f"Class name {class_name} not found in response")
|
640 |
+
# return None
|
641 |
+
|
642 |
+
|
643 |
+
|
644 |
+
# def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key):
|
645 |
+
# initialize_clients(api_key)
|
646 |
+
# return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered)
|
647 |
+
|
648 |
+
|
649 |
+
# #触发词及触发回合字典
|
650 |
+
# triggered_keywords = {}
|
651 |
+
|
652 |
+
# def respond(
|
653 |
+
# message,
|
654 |
+
# history: list[tuple[str, str]],
|
655 |
+
# max_tokens,
|
656 |
+
# temperature,
|
657 |
+
# top_p,
|
658 |
+
# window_size,
|
659 |
+
# distance_threshold,
|
660 |
+
# weight_keywords_users,
|
661 |
+
# weight_keywords_triggered,
|
662 |
+
|
663 |
+
# ):
|
664 |
+
|
665 |
+
# system_message_with_ad = """
|
666 |
+
# # 角色
|
667 |
+
# 你是一个热情的聊天机器人
|
668 |
+
# # 指令
|
669 |
+
# 你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
|
670 |
+
# 注意要在商品的描述前面加上是来自哪个品牌的广告。
|
671 |
+
# 注意在推荐中不要脑补用户的身份,只是进行简单推荐。
|
672 |
+
# 注意要热情但是语气只要适度热情
|
673 |
+
# # 输入格式
|
674 |
+
# 用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
|
675 |
+
# 例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
|
676 |
+
# 注意: 当没有<sep>时,正常回复用户,不插入广告。
|
677 |
+
# # 输出格式
|
678 |
+
# 始终使用中文,只输出聊天内容,不输出任何自我分析的信息
|
679 |
+
# """
|
680 |
+
|
681 |
+
# system_message_without_ad = """
|
682 |
+
# 你是一个热情的聊天机器人
|
683 |
+
# """
|
684 |
+
# print(f"triggered_keywords{triggered_keywords}")
|
685 |
+
# # 更新当前轮次
|
686 |
+
|
687 |
+
# current_turn = len(history) + 1
|
688 |
+
# print(f"\ncurrent_turn: {current_turn}\n")
|
689 |
+
|
690 |
+
# # 检查历史记录的长度
|
691 |
+
# if len(history) >= window_size:
|
692 |
+
# combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
|
693 |
+
# combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
|
694 |
+
# else:
|
695 |
+
# combined_message_user = message
|
696 |
+
# combined_message_assistant = ""
|
697 |
+
|
698 |
+
# start_time = time.time()
|
699 |
+
# key_words_users=get_keywords(combined_message_user).split(',')
|
700 |
+
# key_words_assistant=get_keywords(combined_message_assistant).split(',')
|
701 |
+
# end_time = time.time()
|
702 |
+
# print(f"Time taken to get keywords: {end_time - start_time}")
|
703 |
+
|
704 |
+
# print(f"Initial keywords_users: {key_words_users}")
|
705 |
+
# print(f"Initial keywords_assistant: {key_words_assistant}")
|
706 |
+
|
707 |
+
# keywords_dict = {}
|
708 |
+
# added_keywords = set()
|
709 |
+
|
710 |
+
# for keywords in key_words_users:
|
711 |
+
# if keywords not in added_keywords:
|
712 |
+
# if keywords in keywords_dict:
|
713 |
+
# keywords_dict[keywords] += weight_keywords_users
|
714 |
+
# else:
|
715 |
+
# keywords_dict[keywords] = weight_keywords_users
|
716 |
+
# added_keywords.add(keywords)
|
717 |
+
|
718 |
+
# for keywords in key_words_assistant:
|
719 |
+
# if keywords not in added_keywords:
|
720 |
+
# if keywords in keywords_dict:
|
721 |
+
# keywords_dict[keywords] += 1
|
722 |
+
# else:
|
723 |
+
# keywords_dict[keywords] = 1
|
724 |
+
# added_keywords.add(keywords)
|
725 |
+
|
726 |
+
# #窗口内触发过的关键词权重下调为0.5
|
727 |
+
# for keyword in list(keywords_dict.keys()):
|
728 |
+
# if keyword in triggered_keywords:
|
729 |
+
# if current_turn - triggered_keywords[keyword] < window_size:
|
730 |
+
# keywords_dict[keyword] = weight_keywords_triggered
|
731 |
+
|
732 |
+
# query_keywords = list(keywords_dict.keys())
|
733 |
+
# print(keywords_dict)
|
734 |
+
|
735 |
+
# start_time = time.time()
|
736 |
+
# distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
|
737 |
+
# end_time = time.time()
|
738 |
+
# print(f"Time taken to fetch response from db: {end_time - start_time}")
|
739 |
+
|
740 |
+
|
741 |
+
# print(f"distance: {distance}")
|
742 |
+
|
743 |
+
# if distance<distance_threshold:
|
744 |
+
# ad =top_summary
|
745 |
+
|
746 |
+
# messages = [{"role": "system", "content": system_message_with_ad}]
|
747 |
+
|
748 |
+
# for val in history:
|
749 |
+
# if val[0]:
|
750 |
+
# messages.append({"role": "user", "content": val[0]})
|
751 |
+
# if val[1]:
|
752 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
753 |
+
|
754 |
+
# brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
|
755 |
+
# brand = random.choice(brands)
|
756 |
+
# messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
|
757 |
+
|
758 |
+
# #更新触发词
|
759 |
+
# for keyword in query_keywords:
|
760 |
+
# if any(
|
761 |
+
# ad_keyword in keyword for ad_keyword in top_keywords_list
|
762 |
+
# ):
|
763 |
+
# triggered_keywords[keyword] = current_turn
|
764 |
+
|
765 |
+
# else:
|
766 |
+
# messages = [{"role": "system", "content": system_message_without_ad}]
|
767 |
+
|
768 |
+
# for val in history:
|
769 |
+
# if val[0]:
|
770 |
+
# messages.append({"role": "user", "content": val[0]})
|
771 |
+
# if val[1]:
|
772 |
+
# messages.append({"role": "assistant", "content": val[1]})
|
773 |
+
|
774 |
+
# messages.append({"role": "user", "content": message})
|
775 |
+
|
776 |
+
# start_time = time.time()
|
777 |
+
# response = client.chat.completions.create(
|
778 |
+
# model="gpt-3.5-turbo",
|
779 |
+
# messages=messages,
|
780 |
+
# max_tokens=max_tokens,
|
781 |
+
# temperature=temperature,
|
782 |
+
# top_p=top_p,
|
783 |
+
# )
|
784 |
+
# end_time = time.time()
|
785 |
+
# print(f"Time taken to get response from GPT: {end_time - start_time}")
|
786 |
+
|
787 |
+
|
788 |
+
# return response.choices[0].message.content
|
789 |
+
|
790 |
+
|
791 |
+
# # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
|
792 |
+
# # global triggered_keywords
|
793 |
+
# # response, triggered_keywords = respond(
|
794 |
+
# # message,
|
795 |
+
# # history,
|
796 |
+
# # max_tokens,
|
797 |
+
# # temperature,
|
798 |
+
# # top_p,
|
799 |
+
# # window_size,
|
800 |
+
# # distance_threshold,
|
801 |
+
# # triggered_keywords
|
802 |
+
# # )
|
803 |
+
# # return response, history + [(message, response)]
|
804 |
+
|
805 |
+
# demo = gr.ChatInterface(
|
806 |
+
# wrapper,
|
807 |
+
# additional_inputs=[
|
808 |
+
# gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
809 |
+
# gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
810 |
+
# gr.Slider(
|
811 |
+
# minimum=0.1,
|
812 |
+
# maximum=1.0,
|
813 |
+
# value=0.95,
|
814 |
+
# step=0.05,
|
815 |
+
# label="Top-p (nucleus sampling)",
|
816 |
+
# ),
|
817 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
|
818 |
+
# gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
|
819 |
+
# gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
|
820 |
+
# gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
|
821 |
+
# gr.Textbox(label="api_key"),
|
822 |
+
# ],
|
823 |
+
# )
|
824 |
+
|
825 |
+
# if __name__ == "__main__":
|
826 |
+
# demo.launch(share=True)
|
827 |
|
828 |
# import gradio as gr
|
829 |
# from huggingface_hub import InferenceClient
|