Spaces:
Running
Running
File size: 13,735 Bytes
f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 0f34f56 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc f771b5a 2f06ccc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 |
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_huggingface import HuggingFaceEmbeddings
from datetime import datetime
import json
class RAGManager:
"""
負責管理網路檢索增強生成 (RAG) 功能和對話記憶的類別
"""
def __init__(self, google_search_api_key=None, google_search_cse_id=None,
cache_dir="/tmp/rag_cache", memory_dir="/tmp/rag_memory"):
"""
初始化 RAG Manager 實例。
Args:
google_search_api_key (str): Google Search API 金鑰。
google_search_cse_id (str): Google Custom Search Engine ID。
cache_dir (str): 檢索結果的快取目錄。
memory_dir (str): 對話記憶的儲存目錄。
"""
self.cache_dir = cache_dir
self.memory_dir = memory_dir
os.makedirs(self.cache_dir, exist_ok=True)
os.makedirs(self.memory_dir, exist_ok=True)
# 設定 Google Search API
self.has_google_search = False
if google_search_api_key and google_search_cse_id:
try:
self.search = GoogleSearchAPIWrapper(
google_api_key=google_search_api_key,
google_cse_id=google_search_cse_id
)
self.has_google_search = True
except Exception as e:
print(f"Google Search API 初始化失敗: {e}")
# 初始化 HuggingFace Embeddings
try:
self.embeddings = HuggingFaceEmbeddings(
model_name="shibing624/text2vec-base-chinese", # 使用中文向量模型
cache_folder="/tmp/hf_models"
)
print("已成功載入 HuggingFace Embeddings 模型")
except Exception as e:
print(f"載入 HuggingFace Embeddings 失敗: {e}")
self.embeddings = None
# 初始化文本分割器
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", "。", "!", "?", ",", " ", ""]
)
# 儲存已處理的向量庫
self.vector_stores = {}
# 添加 RAG 控制指令和狀態
self.rag_status = {} # 儲存每個聊天室的 RAG 狀態
self.rag_enable_command = "開啟查詢模式"
self.rag_disable_command = "關閉查詢模式"
# 初始化對話記憶字典
self.conversation_memories = {} # 格式: {user_id: [message1, message2, ...]}
self._load_all_memories() # 載入已存在的記憶
def _load_all_memories(self):
"""載入所有已儲存的對話記憶"""
try:
for filename in os.listdir(self.memory_dir):
if filename.endswith('.json'):
user_id = filename.split('.')[0]
memory_path = os.path.join(self.memory_dir, filename)
with open(memory_path, 'r', encoding='utf-8') as f:
self.conversation_memories[user_id] = json.load(f)
print(f"已載入使用者 {user_id} 的對話記憶")
except Exception as e:
print(f"載入對話記憶時出錯: {e}")
def _save_memory(self, user_id):
"""儲存特定使用者的對話記憶"""
try:
if user_id in self.conversation_memories:
memory_path = os.path.join(self.memory_dir, f"{user_id}.json")
with open(memory_path, 'w', encoding='utf-8') as f:
json.dump(self.conversation_memories[user_id], f, ensure_ascii=False, indent=2)
print(f"已儲存使用者 {user_id} 的對話記憶")
except Exception as e:
print(f"儲存對話記憶時出錯: {e}")
def add_message(self, user_id, role, content, max_memory=100):
"""
添加一條對話消息到記憶中
Args:
user_id (str): 使用者ID
role (str): 消息角色 ('user' 或 'assistant')
content (str): 消息內容
max_memory (int): 保留的最大消息數量
"""
if user_id not in self.conversation_memories:
self.conversation_memories[user_id] = []
# 添加新消息,包含時間戳
message = {
"role": role,
"content": content,
"timestamp": datetime.now().isoformat()
}
self.conversation_memories[user_id].append(message)
# 如果超過最大記憶數量,移除最舊的消息
if len(self.conversation_memories[user_id]) > max_memory:
self.conversation_memories[user_id] = self.conversation_memories[user_id][-max_memory:]
# 儲存更新後的記憶
self._save_memory(user_id)
def get_conversation_history(self, user_id, limit=10, include_timestamps=False):
"""
獲取對話歷史
Args:
user_id (str): 使用者ID
limit (int): 要返回的最近消息數量
include_timestamps (bool): 是否包含時間戳
Returns:
list: 對話歷史列表
"""
if user_id not in self.conversation_memories:
return []
# 獲取最近的對話
recent_messages = self.conversation_memories[user_id][-limit:]
if include_timestamps:
return recent_messages
else:
# 移除時間戳
return [{k: v for k, v in msg.items() if k != 'timestamp'} for msg in recent_messages]
def clear_memory(self, user_id):
"""
清除特定使用者的對話記憶
Args:
user_id (str): 使用者ID
"""
if user_id in self.conversation_memories:
self.conversation_memories[user_id] = []
self._save_memory(user_id)
print(f"已清除使用者 {user_id} 的對話記憶")
def get_memory_summary(self, user_id, max_length=500):
"""
生成對話記憶摘要,用於提供給 LLM
Args:
user_id (str): 使用者ID
max_length (int): 摘要的最大長度
Returns:
str: 對話記憶摘要
"""
history = self.get_conversation_history(user_id)
if not history:
return ""
# 簡單格式化對話歷史
formatted_history = []
for msg in history:
role = "使用者" if msg["role"] == "user" else "助手"
formatted_history.append(f"{role}: {msg['content']}")
summary = "\n".join(formatted_history)
# 如果摘要太長,截斷並添加省略號
if len(summary) > max_length:
return summary[:max_length] + "..."
return summary
def search_and_process_urls(self, query, num_results=3):
"""
使用 Google Search API 搜尋相關網頁,並處理結果。
Args:
query (str): 搜尋關鍵字。
num_results (int): 要處理的搜尋結果數量。
Returns:
list: 處理後的文檔。
"""
if not self.has_google_search:
return []
try:
# 搜尋網頁
search_results = self.search.results(query, num_results)
all_docs = []
# 處理每個搜尋結果
for result in search_results:
url = result.get("link")
if not url:
continue
try:
# 使用 WebBaseLoader 載入網頁內容
loader = WebBaseLoader(url)
documents = loader.load()
# 為文檔添加來源 URL 資訊
for doc in documents:
doc.metadata["source"] = url
# 分割文檔為較小的塊
split_docs = self.text_splitter.split_documents(documents)
all_docs.extend(split_docs)
except Exception as e:
print(f"處理 URL {url} 時出錯: {e}")
return all_docs
except Exception as e:
print(f"搜尋處理時出錯: {e}")
return []
def create_vector_store(self, documents, query):
"""
從文檔創建向量存儲。
Args:
documents (list): 文檔列表。
query (str): 原始搜尋查詢,用於識別和快取向量庫。
Returns:
FAISS: FAISS 向量存儲。
"""
if not documents or not self.embeddings:
return None
try:
vector_store = FAISS.from_documents(documents, self.embeddings)
# 快取向量庫,以便將來重複使用
self.vector_stores[query] = vector_store
return vector_store
except Exception as e:
print(f"創建向量存儲時出錯: {e}")
return None
def retrieve_relevant_documents(self, query, top_k=5):
"""
檢索與查詢相關的文檔。
Args:
query (str): 搜尋查詢。
top_k (int): 要檢索的文檔數量。
Returns:
list: 相關文檔列表。
"""
# 檢查快取中是否有相關查詢
vector_store = None
for cached_query, stored_vs in self.vector_stores.items():
# 簡單檢查是否有相似查詢
if all(keyword in cached_query for keyword in query.split()[:3]):
vector_store = stored_vs
print(f"使用快取向量庫: {cached_query}")
break
# 如果沒有快取,則創建新的向量庫
if not vector_store:
documents = self.search_and_process_urls(query)
if not documents:
return []
vector_store = self.create_vector_store(documents, query)
if not vector_store:
return []
# 使用向量庫檢索相關文檔
try:
relevant_docs = vector_store.similarity_search(query, k=top_k)
return relevant_docs
except Exception as e:
print(f"檢索相關文檔時出錯: {e}")
return []
def get_web_context_for_query(self, query, user_id=None):
"""
獲取與查詢相關的網頁內容上下文,並考慮對話歷史。
Args:
query (str): 使用者查詢。
user_id (str, optional): 使用者ID,用於獲取對話歷史。
Returns:
str: 相關網頁的上下文,格式化為提示增強。
"""
# 如果有對話歷史,可以結合歷史來改進查詢
improved_query = query
if user_id and user_id in self.conversation_memories:
# 獲取最後幾條對話,用於上下文理解
recent_history = self.get_conversation_history(user_id, limit=3)
if recent_history:
# 簡單結合最近的用戶查詢來改進搜索
user_queries = [msg["content"] for msg in recent_history if msg["role"] == "user"]
if user_queries:
# 將最近的查詢和當前查詢結合
improved_query = f"{' '.join(user_queries[-1:])} {query}"
docs = self.retrieve_relevant_documents(improved_query)
if not docs:
return None
# 構建上下文字符串
context_parts = []
for i, doc in enumerate(docs):
source = doc.metadata.get("source", "未知來源")
content = doc.page_content.strip()
if content:
context_parts.append(f"資料 {i+1} (來源: {source}):\n{content}")
if not context_parts:
return None
web_context = "\n\n".join(context_parts)
return web_context
def generate_enhanced_prompt(self, user_query, user_id, include_memory=True):
"""
生成包含記憶和網頁上下文的增強提示
Args:
user_query (str): 使用者查詢
user_id (str): 使用者ID
include_memory (bool): 是否包含對話記憶
Returns:
str: 增強的提示
"""
prompt_parts = []
# 1. 添加對話記憶摘要
if include_memory and user_id in self.conversation_memories:
memory_summary = self.get_memory_summary(user_id)
if memory_summary:
prompt_parts.append(f"### 對話歷史:\n{memory_summary}\n")
# 2. 添加使用者當前查詢
prompt_parts.append(f"### 當前查詢:\n{user_query}\n")
# 3. 添加網頁上下文
web_context = self.get_web_context_for_query(user_query, user_id)
if web_context:
prompt_parts.append(f"### 參考資料:\n{web_context}\n")
# 4. 添加指示
prompt_parts.append("請根據上述資訊提供準確、相關的回答。如果參考資料不包含足夠資訊,請清楚說明。")
# 合併所有部分
full_prompt = "\n".join(prompt_parts)
return full_prompt |