thefish1 committed on
Commit
00fc4b2
·
1 Parent(s): 9bff0c8
Files changed (1) hide show
  1. app.py +551 -309
app.py CHANGED
@@ -190,9 +190,6 @@
190
 
191
 
192
 
193
-
194
-
195
-
196
  import gradio as gr
197
  from huggingface_hub import InferenceClient
198
  import json
@@ -202,61 +199,47 @@ from load_data import load_data
202
  from openai import OpenAI
203
  from transformers import AutoTokenizer, AutoModel
204
  import weaviate
205
- import os
206
- import subprocess
207
  import torch
208
  from tqdm import tqdm
209
  import numpy as np
210
  import time
211
 
212
- # 设置 Matplotlib 的缓存目录
213
- os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
214
- # 设置 Hugging Face Transformers 的缓存目录
215
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
216
- # 确保这些目录存在
217
- os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
218
- os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
219
-
220
- auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
221
-
222
- URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
223
-
224
- # Connect to a WCS instance
225
- db_client = weaviate.Client(
226
- url=URL,
227
- auth_client_secret=auth_config
228
- )
229
-
230
-
231
- class_name="ad_DB02"
232
-
233
- device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
234
- tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
235
- model = AutoModel.from_pretrained("bert-base-chinese")
236
-
237
-
238
- global_api_key = None
239
- client = None
240
-
241
- def initialize_clients(api_key):
242
- global client
243
- client = OpenAI(api_key=api_key)
244
-
245
- def get_keywords(message):
246
- system_message = """
247
- # 角色
248
- 你是一个关键词提取机器人
249
- # 指令
250
- 你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。关键词中应该尽可能注意那些名词和形容词
251
- # 输出格式
252
- 你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
253
- # 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
254
  """
 
 
 
 
255
 
256
- messages = [{"role": "system", "content": system_message}]
257
- messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
258
-
259
- response = client.chat.completions.create(
260
  model="gpt-3.5-turbo",
261
  messages=messages,
262
  max_tokens=100,
@@ -267,321 +250,580 @@ def get_keywords(message):
267
  keywords = response.choices[0].message.content.split(' ')
268
  return ','.join(keywords)
269
 
 
 
 
270
 
271
- #字符串匹配模块
272
- def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
273
- distance = 0
274
- most_matching_list = None
275
- index = 0
 
 
 
 
 
276
 
277
- # query_keywords = query_keywords.split(',')
278
- # query_keywords = [keyword for keyword in query_keywords if keyword]
279
-
280
- #匹配模块
281
- query_keywords= list(query_keywords_dict.keys())
282
-
283
- for i, lst in enumerate(ad_keywords_lists):
284
- lst = lst.split(',')
285
- matches = sum(
286
- any(
287
- ad_keyword in keyword and
288
- (
289
- keyword not in triggered_keywords or
290
- triggered_keywords.get(keyword) is None or
291
- current_turn - triggered_keywords.get(keyword, 0) > window_size
292
- ) * query_keywords_dict.get(keyword, 1) #计数乘以权重
293
- for keyword in query_keywords
294
- )
295
- for ad_keyword in lst
296
- )
297
- if matches > distance:
298
- distance = matches
299
- most_matching_list = lst
300
- index = i
301
-
302
- #更新对distance 有贡献的关键词
303
- if distance >= distance_threshold:
304
- for keyword in query_keywords:
305
- if any(
306
- ad_keyword in keyword for ad_keyword in most_matching_list
307
- ):
308
- triggered_keywords[keyword] = current_turn
309
-
310
- return distance, index
311
 
 
312
 
313
- def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
314
- if torch.cuda.is_available():
315
- print('Using GPU')
316
- print(device)
317
- else:
318
- print('Using CPU')
319
- print(device)
320
-
321
- avg_embeddings = []
322
- for keywords in tqdm(keywords_list_list):
323
- keywords_lst=[]
324
- # keywords.split(',')
325
- for keyword in keywords:
326
- inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
327
- inputs.to(device)
328
- with torch.no_grad():
329
- outputs = model(**inputs)
330
- embeddings = outputs.last_hidden_state.mean(dim=1)
331
- keywords_lst.append(embeddings)
332
- avg_embedding = sum(keywords_lst) / len(keywords_lst)
333
- avg_embeddings.append(avg_embedding)
334
-
335
- return avg_embeddings
336
-
337
-
338
- def encode_to_avg(keywords_dict, model, tokenizer, device):
339
- if torch.cuda.is_available():
340
- print('Using GPU')
341
- print(device)
342
- else:
343
- print('Using CPU')
344
- print(device)
345
-
346
-
347
- keyword_embeddings=[]
348
- for keyword, weight in keywords_dict.items():
349
  inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
350
  inputs.to(device)
351
  with torch.no_grad():
352
  outputs = model(**inputs)
353
- embedding = outputs.last_hidden_state.mean(dim=1)
354
-
355
- keyword_embedding=embedding * weight
356
-
357
- keyword_embeddings.append(keyword_embedding * weight)
358
-
359
- avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
360
 
361
- return avg_embedding.tolist()
362
-
363
-
364
- def fetch_response_from_db(query_keywords_dict,class_name):
365
-
366
- start_time = time.time()
367
- avg_vec=np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
368
- end_time = time.time()
369
- print(f"Time taken to encode to avg: {end_time - start_time}")
370
 
371
 
372
- nearVector = {
373
- 'vector': avg_vec
374
- }
375
- start_time = time.time()
376
  response = (
377
- db_client.query
378
  .get(class_name, ['keywords', 'summary'])
379
- .with_near_vector(nearVector)
380
  .with_limit(1)
381
  .with_additional(['distance'])
382
  .do()
383
  )
384
- end_time = time.time()
385
- print(f"Time taken to search in DB: {end_time - start_time}")
386
-
387
 
388
- print(response)
389
- class_name=class_name[0].upper()+class_name[1:]
390
-
391
- if class_name in response['data']['Get']:
392
- results = response['data']['Get'][class_name]
393
- print(results[0]['keywords'])
394
- return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
395
-
396
  else:
397
- print(f"Class name {class_name} not found in response")
398
- return None
399
-
400
-
401
-
402
- def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key):
403
- initialize_clients(api_key)
404
- return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered)
405
-
406
-
407
- #触发词及触发回合字典
408
- triggered_keywords = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
- def respond(
411
- message,
412
- history: list[tuple[str, str]],
413
- max_tokens,
414
- temperature,
415
- top_p,
416
- window_size,
417
- distance_threshold,
418
- weight_keywords_users,
419
- weight_keywords_triggered,
420
 
421
- ):
422
-
423
- system_message_with_ad = """
424
- # 角色
425
- 你是一个热情的聊天机器人
426
- # 指令
427
- 你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
428
- 注意要在商品的描述前面加上是来自哪个品牌的广告。
429
- 注意在推荐中不要脑补用户的身份,只是进行简单推荐。
430
- 注意要热情但是语气只要适度热情
431
- # 输入格式
432
- 用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
433
- 例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
434
- 注意: 当没有<sep>时,正常回复用户,不插入广告。
435
- # 输出格式
436
- 始终使用中文,只输出聊天内容,不输出任何自我分析的信息
437
- """
438
 
439
- system_message_without_ad = """
440
- 你是一个热情的聊天机器人
441
- """
442
- print(f"triggered_keywords{triggered_keywords}")
443
- # 更新当前轮次
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  current_turn = len(history) + 1
446
- print(f"\ncurrent_turn: {current_turn}\n")
447
 
448
- # 检查历史记录的长度
449
- if len(history) >= window_size:
450
- combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
451
- combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
452
- else:
453
- combined_message_user = message
454
- combined_message_assistant = ""
455
-
456
- start_time = time.time()
457
- key_words_users=get_keywords(combined_message_user).split(',')
458
- key_words_assistant=get_keywords(combined_message_assistant).split(',')
459
- end_time = time.time()
460
- print(f"Time taken to get keywords: {end_time - start_time}")
461
-
462
- print(f"Initial keywords_users: {key_words_users}")
463
- print(f"Initial keywords_assistant: {key_words_assistant}")
464
-
465
- keywords_dict = {}
466
- added_keywords = set()
467
-
468
- for keywords in key_words_users:
469
- if keywords not in added_keywords:
470
- if keywords in keywords_dict:
471
- keywords_dict[keywords] += weight_keywords_users
472
- else:
473
- keywords_dict[keywords] = weight_keywords_users
474
- added_keywords.add(keywords)
475
-
476
- for keywords in key_words_assistant:
477
- if keywords not in added_keywords:
478
- if keywords in keywords_dict:
479
- keywords_dict[keywords] += 1
480
- else:
481
- keywords_dict[keywords] = 1
482
- added_keywords.add(keywords)
483
-
484
- #窗口内触发过的关键词权重下调为0.5
485
- for keyword in list(keywords_dict.keys()):
486
- if keyword in triggered_keywords:
487
- if current_turn - triggered_keywords[keyword] < window_size:
488
- keywords_dict[keyword] = weight_keywords_triggered
489
-
490
- query_keywords = list(keywords_dict.keys())
491
- print(keywords_dict)
492
-
493
- start_time = time.time()
494
- distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
495
- end_time = time.time()
496
- print(f"Time taken to fetch response from db: {end_time - start_time}")
497
 
 
 
 
498
 
499
- print(f"distance: {distance}")
 
 
 
500
 
501
- if distance<distance_threshold:
502
- ad =top_summary
 
503
 
504
- messages = [{"role": "system", "content": system_message_with_ad}]
 
 
 
 
 
 
 
 
 
 
505
 
506
- for val in history:
507
- if val[0]:
508
- messages.append({"role": "user", "content": val[0]})
509
- if val[1]:
510
- messages.append({"role": "assistant", "content": val[1]})
511
 
512
- brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
513
- brand = random.choice(brands)
514
- messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
 
 
 
515
 
516
- #更新触发词
517
- for keyword in query_keywords:
518
- if any(
519
- ad_keyword in keyword for ad_keyword in top_keywords_list
520
- ):
521
  triggered_keywords[keyword] = current_turn
522
-
523
  else:
524
- messages = [{"role": "system", "content": system_message_without_ad}]
525
-
526
- for val in history:
527
- if val[0]:
528
- messages.append({"role": "user", "content": val[0]})
529
- if val[1]:
530
- messages.append({"role": "assistant", "content": val[1]})
531
-
532
  messages.append({"role": "user", "content": message})
533
 
534
- start_time = time.time()
535
- response = client.chat.completions.create(
536
  model="gpt-3.5-turbo",
537
  messages=messages,
538
  max_tokens=max_tokens,
539
  temperature=temperature,
540
  top_p=top_p,
541
  )
542
- end_time = time.time()
543
- print(f"Time taken to get response from GPT: {end_time - start_time}")
544
-
545
 
546
- return response.choices[0].message.content
547
-
548
-
549
- # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
550
- # global triggered_keywords
551
- # response, triggered_keywords = respond(
552
- # message,
553
- # history,
554
- # max_tokens,
555
- # temperature,
556
- # top_p,
557
- # window_size,
558
- # distance_threshold,
559
- # triggered_keywords
560
- # )
561
- # return response, history + [(message, response)]
562
 
 
563
  demo = gr.ChatInterface(
564
- wrapper,
565
  additional_inputs=[
566
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
567
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
568
- gr.Slider(
569
- minimum=0.1,
570
- maximum=1.0,
571
- value=0.95,
572
- step=0.05,
573
- label="Top-p (nucleus sampling)",
574
- ),
575
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
576
  gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
577
  gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
578
  gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
579
- gr.Textbox(label="api_key"),
580
  ],
581
  )
582
 
583
  if __name__ == "__main__":
584
  demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  # import gradio as gr
587
  # from huggingface_hub import InferenceClient
 
190
 
191
 
192
 
 
 
 
193
  import gradio as gr
194
  from huggingface_hub import InferenceClient
195
  import json
 
199
  from openai import OpenAI
200
  from transformers import AutoTokenizer, AutoModel
201
  import weaviate
202
+ import os
 
203
  import torch
204
  from tqdm import tqdm
205
  import numpy as np
206
  import time
207
 
208
# Cache directories for Matplotlib and Hugging Face Transformers.
# /tmp is used so the caches live in a writable location.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)

# Weaviate connection settings.
# SECURITY NOTE(review): the literal fallbacks below are credentials that were
# committed to the repository. They are kept only so existing deployments keep
# working; rotate the key and supply WEAVIATE_API_KEY / WEAVIATE_URL through
# the environment instead.
WEAVIATE_API_KEY = os.environ.get("WEAVIATE_API_KEY", "Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
WEAVIATE_URL = os.environ.get(
    "WEAVIATE_URL",
    "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud",
)
weaviate_auth_config = weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY)
weaviate_client = weaviate.Client(url=WEAVIATE_URL, auth_client_secret=weaviate_auth_config)

# Pre-trained encoder used to embed keywords.
MODEL_NAME = "bert-base-chinese"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# TRANSFORMERS_CACHE is assigned after `transformers` has already been
# imported, so it may be ignored by the library; pass the directory explicitly.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=os.environ['TRANSFORMERS_CACHE'])
model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=os.environ['TRANSFORMERS_CACHE'])
# The encode helpers move their inputs to `device`; the model has to live on
# the same device or the forward pass fails on CUDA hosts.
model = model.to(device)

# OpenAI client; created lazily by initialize_openai_client() once the user
# has supplied an API key.
openai_client = None
228
+
229
def initialize_openai_client(api_key):
    """Rebind the module-level ``openai_client`` to a client built from *api_key*."""
    global openai_client
    fresh_client = OpenAI(api_key=api_key)
    openai_client = fresh_client
232
+
233
+ def extract_keywords(text):
234
+ prompt = """
235
+ 你是一个关键词提取机器人。提取用户输入中的关键词,特别是名词和形容词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙。
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  """
237
+ messages = [
238
+ {"role": "system", "content": prompt},
239
+ {"role": "user", "content": f"从下面的文本中提取五个关键词,以空格分隔:{text}"}
240
+ ]
241
 
242
+ response = openai_client.chat.completions.create(
 
 
 
243
  model="gpt-3.5-turbo",
244
  messages=messages,
245
  max_tokens=100,
 
250
  keywords = response.choices[0].message.content.split(' ')
251
  return ','.join(keywords)
252
 
253
+ # def match_keywords(query_keywords, ad_keywords_list, triggered_keywords, current_turn, window_size, threshold):
254
+ # best_match_distance = 0
255
+ # best_match_index = -1
256
 
257
+ # for i, ad_keywords in enumerate(ad_keywords_list):
258
+ # match_count = sum(
259
+ # any(
260
+ # ad_keyword in keyword and
261
+ # (keyword not in triggered_keywords or current_turn - triggered_keywords[keyword] > window_size)
262
+ # ) for keyword in query_keywords
263
+ # )
264
+ # if match_count > best_match_distance:
265
+ # best_match_distance = match_count
266
+ # best_match_index = i
267
 
268
+ # if best_match_distance >= threshold:
269
+ # for keyword in query_keywords:
270
+ # if any(ad_keyword in keyword for ad_keyword in ad_keywords_list[best_match_index]):
271
+ # triggered_keywords[keyword] = current_turn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
+ # return best_match_distance, best_match_index
274
 
275
def encode_keywords_to_avg(keywords, model, tokenizer, device):
    """Embed every keyword and return the unweighted mean of the embeddings.

    Each keyword is tokenized on its own, run through ``model`` without
    gradient tracking, and reduced to one vector by averaging
    ``last_hidden_state`` over dim 1. The per-keyword vectors are then
    averaged with equal weight.

    Args:
        keywords: Iterable of keyword strings (e.g. ``dict.keys()``).
        model: Encoder whose output exposes ``last_hidden_state``.
        tokenizer: Callable producing the model inputs for one string.
        device: Device the tokenized inputs are moved to.

    Returns:
        A tensor holding the mean embedding.

    Raises:
        ValueError: If ``keywords`` is empty (previously this surfaced as an
            opaque ``ZeroDivisionError``).
    """
    # Materialise once so one-shot iterables can be both checked and iterated.
    keywords = list(keywords)
    if not keywords:
        raise ValueError("encode_keywords_to_avg() requires at least one keyword")

    embeddings = []
    for keyword in tqdm(keywords):
        inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
        inputs = inputs.to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        embeddings.append(outputs.last_hidden_state.mean(dim=1))
    return sum(embeddings) / len(embeddings)
 
 
 
 
285
 
286
def encode_keywords_to_list(keywords, model, tokenizer, device):
    """Return one embedding (as a plain Python list) per keyword.

    Every keyword is encoded separately: tokenized, moved to ``device``,
    passed through ``model`` without gradients, mean-pooled over dim 1 of
    ``last_hidden_state``, squeezed and converted with ``tolist()``.
    """

    def _embed(term):
        batch = tokenizer(term, return_tensors='pt', padding=True, truncation=True, max_length=512)
        batch.to(device)
        with torch.no_grad():
            hidden_states = model(**batch).last_hidden_state
        return hidden_states.mean(dim=1).squeeze().tolist()

    return [_embed(term) for term in tqdm(keywords)]
295
 
296
 
297
def get_response_from_db(keywords_dict, class_name):
    """Look up the single nearest ad for the averaged keyword embedding.

    Args:
        keywords_dict: Mapping whose keys are the query keywords (the weights
            are not used by this lookup).
        class_name: Weaviate class to query.

    Returns:
        ``(distance, summary, keywords)`` of the closest object, or
        ``(None, None, None)`` when the response holds no result for the class.
    """
    avg_embedding = encode_keywords_to_avg(keywords_dict.keys(), model, tokenizer, device)
    # .numpy() only works on CPU tensors; move first so CUDA hosts do not fail.
    avg_vec = avg_embedding.cpu().numpy()

    response = (
        weaviate_client.query
        .get(class_name, ['keywords', 'summary'])
        .with_near_vector({'vector': avg_vec})
        .with_limit(1)
        .with_additional(['distance'])
        .do()
    )

    # The response is keyed by the class name with only its FIRST letter
    # upper-cased. str.capitalize() also lower-cases the rest
    # ("ad_DB02" -> "Ad_db02"), which never matches.
    response_key = class_name[:1].upper() + class_name[1:]
    # Error responses may lack 'data' / 'Get' or carry None for the class.
    per_class = (response.get('data') or {}).get('Get') or {}
    results = per_class.get(response_key)
    if not results:
        return None, None, None

    best = results[0]
    return best['_additional']['distance'], best['summary'], best['keywords']
313
+
314
def get_candidates_from_db(keywords_dict, class_name, limit=3):
    """Collect the nearest ads for every query keyword individually.

    One vector search is issued per keyword embedding and all hits are pooled
    (duplicates across keywords are kept, as before).

    Args:
        keywords_dict: Mapping whose keys are the query keywords.
        class_name: Weaviate class to query.
        limit: Maximum number of hits requested per keyword.

    Returns:
        A list of dicts with the keys ``'distance'``, ``'summary'`` and
        ``'keywords'``; empty when nothing was found.
    """
    embeddings = encode_keywords_to_list(keywords_dict.keys(), model, tokenizer, device)

    # The response is keyed by the class name with only its FIRST letter
    # upper-cased. str.capitalize() also lower-cases the rest
    # ("ad_DB02" -> "Ad_db02"), so the old lookup never matched and the
    # candidate list was always empty.
    response_key = class_name[:1].upper() + class_name[1:]

    candidate_list = []
    for embedding in embeddings:
        response = (
            weaviate_client.query
            .get(class_name, ['keywords', 'summary'])
            .with_near_vector({'vector': embedding})
            .with_limit(limit)
            .with_additional(['distance'])
            .do()
        )

        # Error responses may lack 'data' / 'Get' or carry None for the class.
        per_class = (response.get('data') or {}).get('Get') or {}
        for result in per_class.get(response_key) or []:
            candidate_list.append({
                'distance': result['_additional']['distance'],
                'summary': result['summary'],
                'keywords': result['keywords'],
            })

    return candidate_list
337
 
 
 
 
 
 
 
 
 
 
 
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
+ triggered_keywords = {}
 
 
 
 
341
 
342
def keyword_match(keywords_dict, candidates):
    """Return the first candidate ad sharing a keyword with the query.

    A candidate's ``'keywords'`` string is split on ``'*'``; for each segment
    the text after the first ``'#'`` is taken as the ad keyword. A candidate
    matches when any of its ad keywords occurs as a substring of any query
    keyword. Candidates are examined in the order given.

    Args:
        keywords_dict: Mapping whose keys are the query keywords.
        candidates: Iterable of dicts with ``'distance'``, ``'summary'`` and
            ``'keywords'`` entries.

    Returns:
        ``(distance, summary, keywords)`` of the first matching candidate, or
        ``(None, None, None)`` when no candidate matches, so callers can
        always unpack the result.
    """
    for candidate in candidates:
        ad_terms = []
        for segment in candidate['keywords'].split('*'):
            parts = segment.split('#')
            # Segments without a '#' tag are used as-is instead of raising
            # IndexError. NOTE(review): confirm this fallback against the
            # stored keyword format.
            term = (parts[1] if len(parts) > 1 else parts[0]).strip()
            # An empty term is a substring of everything; never let it match.
            if term:
                ad_terms.append(term)

        # Compare whole ad keywords. The previous code joined them into one
        # string and iterated it character by character, so any shared
        # character (or the ',' separator) counted as a match.
        for keyword in keywords_dict:
            if any(term in keyword for term in ad_terms):
                return candidate['distance'], candidate['summary'], candidate['keywords']

    return None, None, None
351
+
352
def _split_ad_keywords(raw_keywords):
    """Split a stored ad keyword string into its individual keyword terms."""
    terms = []
    # Stored keywords are '*'-separated with a '#'-prefixed tag per segment
    # (the format keyword_match parses); ',' is accepted too because this
    # function used to split on it.
    for segment in raw_keywords.replace(',', '*').split('*'):
        parts = segment.split('#')
        term = (parts[1] if len(parts) > 1 else parts[0]).strip()
        if term:
            terms.append(term)
    return terms


def chatbot_response(message, history, max_tokens, temperature, top_p, window_size, threshold, user_weight, triggered_weight, api_key):
    """Answer one chat turn, embedding an ad when a close enough one is found.

    Args:
        message: The user's new message.
        history: Previous turns as ``(user_text, assistant_text)`` pairs.
        max_tokens, temperature, top_p: Sampling settings forwarded to the
            chat completion call.
        window_size: Number of most recent turns used for keyword extraction
            and for the down-weighting window of triggered keywords.
        threshold: Maximum vector distance for an ad to be considered.
        user_weight: Weight assigned to keywords from user messages.
        triggered_weight: Weight assigned to keywords that triggered an ad
            within the last ``window_size`` turns.
        api_key: OpenAI API key used for this request.

    Returns:
        The assistant's reply text.
    """
    # Initialise the OpenAI client with the key supplied in the UI.
    initialize_openai_client(api_key)

    # Current turn number and the history window used for keyword extraction.
    current_turn = len(history) + 1
    recent_turns = history[-int(window_size):]  # slider values may arrive as floats

    # Skip empty/None entries so " ".join never sees a non-string.
    combined_user_message = " ".join([turn[0] for turn in recent_turns if turn[0]] + [message])
    combined_assistant_message = " ".join(turn[1] for turn in recent_turns if turn[1])

    # Extract keywords; do not spend an API call on an empty assistant window.
    user_keywords = extract_keywords(combined_user_message).split(',')
    if combined_assistant_message:
        assistant_keywords = extract_keywords(combined_assistant_message).split(',')
    else:
        assistant_keywords = []

    # Build the weighted keyword dictionary, ignoring blank keywords produced
    # by consecutive separators.
    keywords_dict = {keyword: user_weight for keyword in user_keywords if keyword.strip()}
    for keyword in assistant_keywords:
        if keyword.strip():
            keywords_dict[keyword] = keywords_dict.get(keyword, 0) + 1

    # Keywords that triggered an ad inside the window get the reduced weight.
    for keyword in list(keywords_dict.keys()):
        if keyword in triggered_keywords and current_turn - triggered_keywords[keyword] < window_size:
            keywords_dict[keyword] = triggered_weight

    # Retrieve candidates per keyword, then keep those under the distance
    # threshold, nearest first.
    candidates = get_candidates_from_db(keywords_dict, class_name="ad_DB02", limit=3)
    candidates.sort(key=lambda x: x['distance'])
    candidates = [candidate for candidate in candidates if candidate['distance'] < threshold]

    # Defaults cover both "no candidate survived" and "no keyword matched";
    # previously either case raised before any reply was produced.
    distance = ad_summary = ad_keywords = None
    if candidates:
        match = keyword_match(keywords_dict, candidates)
        if match:
            distance, ad_summary, ad_keywords = match

    # Flatten history into individual message dicts. The old
    # `messages.extend([...] for msg in history)` appended two-element lists,
    # which is not a valid chat message.
    history_messages = []
    for turn in history:
        if turn[0]:
            history_messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            history_messages.append({"role": "assistant", "content": turn[1]})

    # `is not None` rather than truthiness so a distance of 0.0 still counts.
    if distance is not None and distance < threshold:
        ad_message = f"{message} <sep>品牌<sep>{ad_summary}"
        messages = [{"role": "system", "content": "你是一个热情的聊天机器人,应微妙地嵌入广告内容。"}]
        messages.extend(history_messages)
        messages.append({"role": "user", "content": ad_message})

        # Remember which query keywords triggered this ad and on which turn.
        ad_terms = _split_ad_keywords(ad_keywords)
        for keyword in keywords_dict:
            if any(ad_term in keyword for ad_term in ad_terms):
                triggered_keywords[keyword] = current_turn
    else:
        messages = [{"role": "system", "content": "你是一个热情的聊天机器人。"}]
        messages.extend(history_messages)
        messages.append({"role": "user", "content": message})

    # Ask the chat model for the reply.
    response = openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    print(f"triggered_keywords: {triggered_keywords}")
    return response.choices[0].message.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
415
# Gradio chat UI. The additional inputs are passed to chatbot_response after
# (message, history), in the order listed here.
demo = gr.ChatInterface(
    fn=chatbot_response,
    additional_inputs=[
        gr.Slider(
            minimum=1, maximum=2048, value=512, step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1, maximum=4.0, value=0.7, step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(
            minimum=1, maximum=5, value=2, step=1,
            label="Window size",
        ),
        gr.Slider(
            minimum=0.01, maximum=0.20, value=0.08, step=0.01,
            label="Distance threshold",
        ),
        gr.Slider(
            minimum=1, maximum=5, value=2, step=1,
            label="Weight of keywords from users",
        ),
        gr.Slider(
            minimum=0, maximum=2, value=0.5, step=0.5,
            label="Weight of triggered keywords",
        ),
        gr.Textbox(label="API Key"),
    ],
)

# Start the app (with a public share link) only when run as a script.
if __name__ == "__main__":
    demo.launch(share=True)
432
+ print("happyhappyhappy")
433
+
434
+
435
+
436
+
437
+
438
+ # import gradio as gr
439
+ # from huggingface_hub import InferenceClient
440
+ # import json
441
+ # import random
442
+ # import re
443
+ # from load_data import load_data
444
+ # from openai import OpenAI
445
+ # from transformers import AutoTokenizer, AutoModel
446
+ # import weaviate
447
+ # import os
448
+ # import subprocess
449
+ # import torch
450
+ # from tqdm import tqdm
451
+ # import numpy as np
452
+ # import time
453
+
454
+ # # 设置 Matplotlib 的缓存目录
455
+ # os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
456
+ # # 设置 Hugging Face Transformers 的缓存目录
457
+ # os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
458
+ # # 确保这些目录存在
459
+ # os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
460
+ # os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
461
+
462
+ # auth_config = weaviate.AuthApiKey(api_key="Y7c8DRmcxZ4nP5IJLwkznIsK84l6EdwfXwcH")
463
+
464
+ # URL = "https://39nlafviqvard82k6y8btq.c0.asia-southeast1.gcp.weaviate.cloud"
465
+
466
+ # # Connect to a WCS instance
467
+ # db_client = weaviate.Client(
468
+ # url=URL,
469
+ # auth_client_secret=auth_config
470
+ # )
471
+
472
+
473
+ # class_name="ad_DB02"
474
+
475
+ # device = torch.device(device='cuda' if torch.cuda.is_available() else 'cpu')
476
+ # tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
477
+ # model = AutoModel.from_pretrained("bert-base-chinese")
478
+
479
+
480
+ # global_api_key = None
481
+ # client = None
482
+
483
+ # def initialize_clients(api_key):
484
+ # global client
485
+ # client = OpenAI(api_key=api_key)
486
+
487
+ # def get_keywords(message):
488
+ # system_message = """
489
+ # # 角色
490
+ # 你是一个关键词提取机器人
491
+ # # 指令
492
+ # 你的目标是从用户的输入中提取关键词,这些关键词应该尽可能是购买意图相关的。关键词中应该尽可能注意那些名词和形容词
493
+ # # 输出格式
494
+ # 你应该直接输出关键词,关键词之间用空格分隔。例如:苹果 电脑 裤子 蓝色 裙
495
+ # # 注意:如果输入文本过短可以重复输出关键词,例如对输入“你好”可以输出:你好 你好 你好 你好 你好
496
+ # """
497
+
498
+ # messages = [{"role": "system", "content": system_message}]
499
+ # messages.append({"role": "user", "content": f"从下面的文本中给我提取五个关键词,只输出这五个关键词,以空格分隔{message}"})
500
+
501
+ # response = client.chat.completions.create(
502
+ # model="gpt-3.5-turbo",
503
+ # messages=messages,
504
+ # max_tokens=100,
505
+ # temperature=0.7,
506
+ # top_p=0.9,
507
+ # )
508
+
509
+ # keywords = response.choices[0].message.content.split(' ')
510
+ # return ','.join(keywords)
511
+
512
+
513
+ # #字符串匹配模块
514
+ # def keyword_match(query_keywords_dict, ad_keywords_lists, triggered_keywords, current_turn, window_size,distance_threshold):
515
+ # distance = 0
516
+ # most_matching_list = None
517
+ # index = 0
518
+
519
+ # # query_keywords = query_keywords.split(',')
520
+ # # query_keywords = [keyword for keyword in query_keywords if keyword]
521
+
522
+ # #匹配模块
523
+ # query_keywords= list(query_keywords_dict.keys())
524
+
525
+ # for i, lst in enumerate(ad_keywords_lists):
526
+ # lst = lst.split(',')
527
+ # matches = sum(
528
+ # any(
529
+ # ad_keyword in keyword and
530
+ # (
531
+ # keyword not in triggered_keywords or
532
+ # triggered_keywords.get(keyword) is None or
533
+ # current_turn - triggered_keywords.get(keyword, 0) > window_size
534
+ # ) * query_keywords_dict.get(keyword, 1) #计数乘以权重
535
+ # for keyword in query_keywords
536
+ # )
537
+ # for ad_keyword in lst
538
+ # )
539
+ # if matches > distance:
540
+ # distance = matches
541
+ # most_matching_list = lst
542
+ # index = i
543
+
544
+ # #更新对distance 有贡献的关键词
545
+ # if distance >= distance_threshold:
546
+ # for keyword in query_keywords:
547
+ # if any(
548
+ # ad_keyword in keyword for ad_keyword in most_matching_list
549
+ # ):
550
+ # triggered_keywords[keyword] = current_turn
551
+
552
+ # return distance, index
553
+
554
+
555
+ # def encode_list_to_avg(keywords_list_list, model, tokenizer, device):
556
+ # if torch.cuda.is_available():
557
+ # print('Using GPU')
558
+ # print(device)
559
+ # else:
560
+ # print('Using CPU')
561
+ # print(device)
562
+
563
+ # avg_embeddings = []
564
+ # for keywords in tqdm(keywords_list_list):
565
+ # keywords_lst=[]
566
+ # # keywords.split(',')
567
+ # for keyword in keywords:
568
+ # inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
569
+ # inputs.to(device)
570
+ # with torch.no_grad():
571
+ # outputs = model(**inputs)
572
+ # embeddings = outputs.last_hidden_state.mean(dim=1)
573
+ # keywords_lst.append(embeddings)
574
+ # avg_embedding = sum(keywords_lst) / len(keywords_lst)
575
+ # avg_embeddings.append(avg_embedding)
576
+
577
+ # return avg_embeddings
578
+
579
+
580
+ # def encode_to_avg(keywords_dict, model, tokenizer, device):
581
+ # if torch.cuda.is_available():
582
+ # print('Using GPU')
583
+ # print(device)
584
+ # else:
585
+ # print('Using CPU')
586
+ # print(device)
587
+
588
+
589
+ # keyword_embeddings=[]
590
+ # for keyword, weight in keywords_dict.items():
591
+ # inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
592
+ # inputs.to(device)
593
+ # with torch.no_grad():
594
+ # outputs = model(**inputs)
595
+ # embedding = outputs.last_hidden_state.mean(dim=1)
596
+
597
+ # keyword_embedding=embedding * weight
598
+
599
+ # keyword_embeddings.append(keyword_embedding * weight)
600
+
601
+ # avg_embedding = sum(keyword_embeddings) / sum(keywords_dict.values())
602
+
603
+ # return avg_embedding.tolist()
604
+
605
+
606
+ # def fetch_response_from_db(query_keywords_dict,class_name):
607
+
608
+ # start_time = time.time()
609
+ # avg_vec=np.array(encode_to_avg(query_keywords_dict, model, tokenizer, device))
610
+ # end_time = time.time()
611
+ # print(f"Time taken to encode to avg: {end_time - start_time}")
612
+
613
+
614
+ # nearVector = {
615
+ # 'vector': avg_vec
616
+ # }
617
+ # start_time = time.time()
618
+ # response = (
619
+ # db_client.query
620
+ # .get(class_name, ['keywords', 'summary'])
621
+ # .with_near_vector(nearVector)
622
+ # .with_limit(1)
623
+ # .with_additional(['distance'])
624
+ # .do()
625
+ # )
626
+ # end_time = time.time()
627
+ # print(f"Time taken to search in DB: {end_time - start_time}")
628
+
629
+
630
+ # print(response)
631
+ # class_name=class_name[0].upper()+class_name[1:]
632
+
633
+ # if class_name in response['data']['Get']:
634
+ # results = response['data']['Get'][class_name]
635
+ # print(results[0]['keywords'])
636
+ # return results[0]['_additional']['distance'],results[0]['summary'], results[0]['keywords']
637
+
638
+ # else:
639
+ # print(f"Class name {class_name} not found in response")
640
+ # return None
641
+
642
+
643
+
644
+ # def wrapper(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered, api_key):
645
+ # initialize_clients(api_key)
646
+ # return respond(message, history, max_tokens, temperature, top_p, window_size, distance_threshold, weight_keywords_users, weight_keywords_triggered)
647
+
648
+
649
+ # #触发词及触发回合字典
650
+ # triggered_keywords = {}
651
+
652
+ # def respond(
653
+ # message,
654
+ # history: list[tuple[str, str]],
655
+ # max_tokens,
656
+ # temperature,
657
+ # top_p,
658
+ # window_size,
659
+ # distance_threshold,
660
+ # weight_keywords_users,
661
+ # weight_keywords_triggered,
662
+
663
+ # ):
664
+
665
+ # system_message_with_ad = """
666
+ # # 角色
667
+ # 你是一个热情的聊天机器人
668
+ # # 指令
669
+ # 你的目标是以一种微妙的方式将给定的广告内容嵌入对话中。插入广告时应尽量自然,避免突兀。优先正常回答用户问题,然后逐步以推荐方式引入广告内容。
670
+ # 注意要在商品的描述前面加上是来自哪个品牌的广告。
671
+ # 注意在推荐中不要脑补用户的身份,只是进行简单推荐。
672
+ # 注意要热情但是语气只要适度热情
673
+ # # 输入格式
674
+ # 用户查询后跟随广告品牌,用<sep>分隔,广告品牌后跟随广告描述,再用<sep>分隔。
675
+ # 例如:我想买一条阔腿裤 <sep> 腾讯 <sep> 宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。
676
+ # 注意: 当没有<sep>时,正常回复用户,不插入广告。
677
+ # # 输出格式
678
+ # 始终使用中文,只输出聊天内容,不输出任何自我分析的信息
679
+ # """
680
+
681
+ # system_message_without_ad = """
682
+ # 你是一个热情的聊天机器人
683
+ # """
684
+ # print(f"triggered_keywords{triggered_keywords}")
685
+ # # 更新当前轮次
686
+
687
+ # current_turn = len(history) + 1
688
+ # print(f"\ncurrent_turn: {current_turn}\n")
689
+
690
+ # # 检查历史记录的长度
691
+ # if len(history) >= window_size:
692
+ # combined_message_user = " ".join([h[0] for h in history[-window_size:] if h[0]] + [message])
693
+ # combined_message_assistant=" ".join(h[1] for h in history[-window_size:] if h[1])
694
+ # else:
695
+ # combined_message_user = message
696
+ # combined_message_assistant = ""
697
+
698
+ # start_time = time.time()
699
+ # key_words_users=get_keywords(combined_message_user).split(',')
700
+ # key_words_assistant=get_keywords(combined_message_assistant).split(',')
701
+ # end_time = time.time()
702
+ # print(f"Time taken to get keywords: {end_time - start_time}")
703
+
704
+ # print(f"Initial keywords_users: {key_words_users}")
705
+ # print(f"Initial keywords_assistant: {key_words_assistant}")
706
+
707
+ # keywords_dict = {}
708
+ # added_keywords = set()
709
+
710
+ # for keywords in key_words_users:
711
+ # if keywords not in added_keywords:
712
+ # if keywords in keywords_dict:
713
+ # keywords_dict[keywords] += weight_keywords_users
714
+ # else:
715
+ # keywords_dict[keywords] = weight_keywords_users
716
+ # added_keywords.add(keywords)
717
+
718
+ # for keywords in key_words_assistant:
719
+ # if keywords not in added_keywords:
720
+ # if keywords in keywords_dict:
721
+ # keywords_dict[keywords] += 1
722
+ # else:
723
+ # keywords_dict[keywords] = 1
724
+ # added_keywords.add(keywords)
725
+
726
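+ # # Keywords that already triggered an ad within the last window_size turns have their weight set to weight_keywords_triggered (UI slider default: 0.5)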
727
+ # for keyword in list(keywords_dict.keys()):
728
+ # if keyword in triggered_keywords:
729
+ # if current_turn - triggered_keywords[keyword] < window_size:
730
+ # keywords_dict[keyword] = weight_keywords_triggered
731
+
732
+ # query_keywords = list(keywords_dict.keys())
733
+ # print(keywords_dict)
734
+
735
+ # start_time = time.time()
736
+ # distance,top_keywords_list,top_summary = fetch_response_from_db(keywords_dict,class_name)
737
+ # end_time = time.time()
738
+ # print(f"Time taken to fetch response from db: {end_time - start_time}")
739
+
740
+
741
+ # print(f"distance: {distance}")
742
+
743
+ # if distance<distance_threshold:
744
+ # ad =top_summary
745
+
746
+ # messages = [{"role": "system", "content": system_message_with_ad}]
747
+
748
+ # for val in history:
749
+ # if val[0]:
750
+ # messages.append({"role": "user", "content": val[0]})
751
+ # if val[1]:
752
+ # messages.append({"role": "assistant", "content": val[1]})
753
+
754
+ # brands = ['腾讯', '百度', '京东', '华为', '小米', '苹果', '微软', '谷歌', '亚马逊']
755
+ # brand = random.choice(brands)
756
+ # messages.append({"role": "user", "content": f"{message} <sep>{brand}的 <sep> {ad}"})
757
+
758
+ # #更新触发词
759
+ # for keyword in query_keywords:
760
+ # if any(
761
+ # ad_keyword in keyword for ad_keyword in top_keywords_list
762
+ # ):
763
+ # triggered_keywords[keyword] = current_turn
764
+
765
+ # else:
766
+ # messages = [{"role": "system", "content": system_message_without_ad}]
767
+
768
+ # for val in history:
769
+ # if val[0]:
770
+ # messages.append({"role": "user", "content": val[0]})
771
+ # if val[1]:
772
+ # messages.append({"role": "assistant", "content": val[1]})
773
+
774
+ # messages.append({"role": "user", "content": message})
775
+
776
+ # start_time = time.time()
777
+ # response = client.chat.completions.create(
778
+ # model="gpt-3.5-turbo",
779
+ # messages=messages,
780
+ # max_tokens=max_tokens,
781
+ # temperature=temperature,
782
+ # top_p=top_p,
783
+ # )
784
+ # end_time = time.time()
785
+ # print(f"Time taken to get response from GPT: {end_time - start_time}")
786
+
787
+
788
+ # return response.choices[0].message.content
789
+
790
+
791
+ # # def chat_interface(message, history, max_tokens, temperature, top_p, window_size, distance_threshold):
792
+ # # global triggered_keywords
793
+ # # response, triggered_keywords = respond(
794
+ # # message,
795
+ # # history,
796
+ # # max_tokens,
797
+ # # temperature,
798
+ # # top_p,
799
+ # # window_size,
800
+ # # distance_threshold,
801
+ # # triggered_keywords
802
+ # # )
803
+ # # return response, history + [(message, response)]
804
+
805
+ # demo = gr.ChatInterface(
806
+ # wrapper,
807
+ # additional_inputs=[
808
+ # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
809
+ # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
810
+ # gr.Slider(
811
+ # minimum=0.1,
812
+ # maximum=1.0,
813
+ # value=0.95,
814
+ # step=0.05,
815
+ # label="Top-p (nucleus sampling)",
816
+ # ),
817
+ # gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Window size"),
818
+ # gr.Slider(minimum=0.01, maximum=0.20, value=0.08, step=0.01, label="Distance threshold"),
819
+ # gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Weight of keywords from users"),
820
+ # gr.Slider(minimum=0, maximum=2, value=0.5, step=0.5, label="Weight of triggered keywords"),
821
+ # gr.Textbox(label="api_key"),
822
+ # ],
823
+ # )
824
+
825
+ # if __name__ == "__main__":
826
+ # demo.launch(share=True)
827
 
828
  # import gradio as gr
829
  # from huggingface_hub import InferenceClient