decula commited on
Commit ·
4c45781
1
Parent(s): 4f26428
Added agent head
Browse files
7b_rag.py
CHANGED
|
@@ -5,6 +5,7 @@ from huggingface_hub import hf_hub_download
|
|
| 5 |
from pynvml import *
|
| 6 |
from duckduckgo_search import DDGS
|
| 7 |
import re
|
|
|
|
| 8 |
|
| 9 |
# Flag to check if GPU is present
|
| 10 |
HAS_GPU = False
|
|
@@ -43,10 +44,37 @@ model = RWKV(model=model_path, strategy=MODEL_STRAT)
|
|
| 43 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 44 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 45 |
|
| 46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
def web_search(query, max_results=3):
|
| 48 |
try:
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
results = list(ddgs.text(query, max_results=max_results))
|
| 51 |
if not results:
|
| 52 |
return "No search results found."
|
|
|
|
| 5 |
from pynvml import *
|
| 6 |
from duckduckgo_search import DDGS
|
| 7 |
import re
|
| 8 |
+
import asyncio
|
| 9 |
|
| 10 |
# Flag to check if GPU is present
|
| 11 |
HAS_GPU = False
|
|
|
|
| 44 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 45 |
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
|
| 46 |
|
| 47 |
+
# 理解问题并提取关键词的函数
|
| 48 |
+
async def understanding_question(question: str):
|
| 49 |
+
# 简单处理:移除常见的问题词,保留关键内容
|
| 50 |
+
question = question.lower()
|
| 51 |
+
question = re.sub(r'^(can you|could you|please|tell me about|what is|who is|how to|why is|when did)\s+', '', question)
|
| 52 |
+
# 返回处理后的问题作为关键词
|
| 53 |
+
return question
|
| 54 |
+
|
| 55 |
+
# Web search function for RAG with browser agent HTTP headers
|
| 56 |
+
async def run_duckduckgo_search_tool(question: str):
|
| 57 |
+
text = await understanding_question(question)
|
| 58 |
+
|
| 59 |
+
keywords = text.split(",")
|
| 60 |
+
headers = {
|
| 61 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
results = DDGS(headers=headers).text(keywords[0], max_results=5)
|
| 65 |
+
print(results)
|
| 66 |
+
|
| 67 |
+
return text
|
| 68 |
+
|
| 69 |
+
# 修改后的web_search函数,使用run_duckduckgo_search_tool
|
| 70 |
def web_search(query, max_results=3):
|
| 71 |
try:
|
| 72 |
+
# 设置浏览器代理HTTP头部
|
| 73 |
+
headers = {
|
| 74 |
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:124.0) Gecko/20100101 Firefox/124.0"
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
with DDGS(headers=headers) as ddgs:
|
| 78 |
results = list(ddgs.text(query, max_results=max_results))
|
| 79 |
if not results:
|
| 80 |
return "No search results found."
|