decula commited on
Commit
4c45781
·
1 Parent(s): 4f26428

Added agent head

Browse files
Files changed (1) hide show
  1. 7b_rag.py +30 -2
7b_rag.py CHANGED
@@ -5,6 +5,7 @@ from huggingface_hub import hf_hub_download
5
  from pynvml import *
6
  from duckduckgo_search import DDGS
7
  import re
 
8
 
9
  # Flag to check if GPU is present
10
  HAS_GPU = False
@@ -43,10 +44,37 @@ model = RWKV(model=model_path, strategy=MODEL_STRAT)
43
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
44
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
45
 
46
- # Web search function for RAG
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def web_search(query, max_results=3):
48
  try:
49
- with DDGS() as ddgs:
 
 
 
 
 
50
  results = list(ddgs.text(query, max_results=max_results))
51
  if not results:
52
  return "No search results found."
 
5
  from pynvml import *
6
  from duckduckgo_search import DDGS
7
  import re
8
+ import asyncio
9
 
10
  # Flag to check if GPU is present
11
  HAS_GPU = False
 
44
  from rwkv.utils import PIPELINE, PIPELINE_ARGS
45
  pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
46
 
47
+ # 理解问题并提取关键词的函数
48
+ async def understanding_question(question: str):
49
+ # 简单处理:移除常见的问题词,保留关键内容
50
+ question = question.lower()
51
+ question = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', question)
52
+ # 返回处理后的问题作为关键词
53
+ return question
54
+
55
+ # Web search function for RAG with browser agent HTTP headers
56
+ async def run_duckduckgo_search_tool(question: str):
57
+ text = await understanding_question(question)
58
+
59
+ keywords = text.split(",")
60
+ headers = {
61
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
62
+ }
63
+
64
+ results = DDGS(headers=headers).text(keywords[0], max_results=5)
65
+ print(results)
66
+
67
+ return text
68
+
69
+ # 修改后的web_search函数,使用run_duckduckgo_search_tool
70
  def web_search(query, max_results=3):
71
  try:
72
+ # 设置浏览器代理HTTP头部
73
+ headers = {
74
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
75
+ }
76
+
77
+ with DDGS(headers=headers) as ddgs:
78
  results = list(ddgs.text(query, max_results=max_results))
79
  if not results:
80
  return "No search results found."