gemiline / rag_manager.py
motaer0206's picture
Update rag_manager.py
0f34f56 verified
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_huggingface import HuggingFaceEmbeddings
from datetime import datetime
import json
class RAGManager:
"""
負責管理網路檢索增強生成 (RAG) 功能和對話記憶的類別
"""
def __init__(self, google_search_api_key=None, google_search_cse_id=None,
cache_dir="/tmp/rag_cache", memory_dir="/tmp/rag_memory"):
"""
初始化 RAG Manager 實例。
Args:
google_search_api_key (str): Google Search API 金鑰。
google_search_cse_id (str): Google Custom Search Engine ID。
cache_dir (str): 檢索結果的快取目錄。
memory_dir (str): 對話記憶的儲存目錄。
"""
self.cache_dir = cache_dir
self.memory_dir = memory_dir
os.makedirs(self.cache_dir, exist_ok=True)
os.makedirs(self.memory_dir, exist_ok=True)
# 設定 Google Search API
self.has_google_search = False
if google_search_api_key and google_search_cse_id:
try:
self.search = GoogleSearchAPIWrapper(
google_api_key=google_search_api_key,
google_cse_id=google_search_cse_id
)
self.has_google_search = True
except Exception as e:
print(f"Google Search API 初始化失敗: {e}")
# 初始化 HuggingFace Embeddings
try:
self.embeddings = HuggingFaceEmbeddings(
model_name="shibing624/text2vec-base-chinese", # 使用中文向量模型
cache_folder="/tmp/hf_models"
)
print("已成功載入 HuggingFace Embeddings 模型")
except Exception as e:
print(f"載入 HuggingFace Embeddings 失敗: {e}")
self.embeddings = None
# 初始化文本分割器
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
separators=["\n\n", "\n", "。", "!", "?", ",", " ", ""]
)
# 儲存已處理的向量庫
self.vector_stores = {}
# 添加 RAG 控制指令和狀態
self.rag_status = {} # 儲存每個聊天室的 RAG 狀態
self.rag_enable_command = "開啟查詢模式"
self.rag_disable_command = "關閉查詢模式"
# 初始化對話記憶字典
self.conversation_memories = {} # 格式: {user_id: [message1, message2, ...]}
self._load_all_memories() # 載入已存在的記憶
def _load_all_memories(self):
"""載入所有已儲存的對話記憶"""
try:
for filename in os.listdir(self.memory_dir):
if filename.endswith('.json'):
user_id = filename.split('.')[0]
memory_path = os.path.join(self.memory_dir, filename)
with open(memory_path, 'r', encoding='utf-8') as f:
self.conversation_memories[user_id] = json.load(f)
print(f"已載入使用者 {user_id} 的對話記憶")
except Exception as e:
print(f"載入對話記憶時出錯: {e}")
def _save_memory(self, user_id):
"""儲存特定使用者的對話記憶"""
try:
if user_id in self.conversation_memories:
memory_path = os.path.join(self.memory_dir, f"{user_id}.json")
with open(memory_path, 'w', encoding='utf-8') as f:
json.dump(self.conversation_memories[user_id], f, ensure_ascii=False, indent=2)
print(f"已儲存使用者 {user_id} 的對話記憶")
except Exception as e:
print(f"儲存對話記憶時出錯: {e}")
def add_message(self, user_id, role, content, max_memory=100):
"""
添加一條對話消息到記憶中
Args:
user_id (str): 使用者ID
role (str): 消息角色 ('user' 或 'assistant')
content (str): 消息內容
max_memory (int): 保留的最大消息數量
"""
if user_id not in self.conversation_memories:
self.conversation_memories[user_id] = []
# 添加新消息,包含時間戳
message = {
"role": role,
"content": content,
"timestamp": datetime.now().isoformat()
}
self.conversation_memories[user_id].append(message)
# 如果超過最大記憶數量,移除最舊的消息
if len(self.conversation_memories[user_id]) > max_memory:
self.conversation_memories[user_id] = self.conversation_memories[user_id][-max_memory:]
# 儲存更新後的記憶
self._save_memory(user_id)
def get_conversation_history(self, user_id, limit=10, include_timestamps=False):
"""
獲取對話歷史
Args:
user_id (str): 使用者ID
limit (int): 要返回的最近消息數量
include_timestamps (bool): 是否包含時間戳
Returns:
list: 對話歷史列表
"""
if user_id not in self.conversation_memories:
return []
# 獲取最近的對話
recent_messages = self.conversation_memories[user_id][-limit:]
if include_timestamps:
return recent_messages
else:
# 移除時間戳
return [{k: v for k, v in msg.items() if k != 'timestamp'} for msg in recent_messages]
def clear_memory(self, user_id):
"""
清除特定使用者的對話記憶
Args:
user_id (str): 使用者ID
"""
if user_id in self.conversation_memories:
self.conversation_memories[user_id] = []
self._save_memory(user_id)
print(f"已清除使用者 {user_id} 的對話記憶")
def get_memory_summary(self, user_id, max_length=500):
"""
生成對話記憶摘要,用於提供給 LLM
Args:
user_id (str): 使用者ID
max_length (int): 摘要的最大長度
Returns:
str: 對話記憶摘要
"""
history = self.get_conversation_history(user_id)
if not history:
return ""
# 簡單格式化對話歷史
formatted_history = []
for msg in history:
role = "使用者" if msg["role"] == "user" else "助手"
formatted_history.append(f"{role}: {msg['content']}")
summary = "\n".join(formatted_history)
# 如果摘要太長,截斷並添加省略號
if len(summary) > max_length:
return summary[:max_length] + "..."
return summary
def search_and_process_urls(self, query, num_results=3):
"""
使用 Google Search API 搜尋相關網頁,並處理結果。
Args:
query (str): 搜尋關鍵字。
num_results (int): 要處理的搜尋結果數量。
Returns:
list: 處理後的文檔。
"""
if not self.has_google_search:
return []
try:
# 搜尋網頁
search_results = self.search.results(query, num_results)
all_docs = []
# 處理每個搜尋結果
for result in search_results:
url = result.get("link")
if not url:
continue
try:
# 使用 WebBaseLoader 載入網頁內容
loader = WebBaseLoader(url)
documents = loader.load()
# 為文檔添加來源 URL 資訊
for doc in documents:
doc.metadata["source"] = url
# 分割文檔為較小的塊
split_docs = self.text_splitter.split_documents(documents)
all_docs.extend(split_docs)
except Exception as e:
print(f"處理 URL {url} 時出錯: {e}")
return all_docs
except Exception as e:
print(f"搜尋處理時出錯: {e}")
return []
def create_vector_store(self, documents, query):
"""
從文檔創建向量存儲。
Args:
documents (list): 文檔列表。
query (str): 原始搜尋查詢,用於識別和快取向量庫。
Returns:
FAISS: FAISS 向量存儲。
"""
if not documents or not self.embeddings:
return None
try:
vector_store = FAISS.from_documents(documents, self.embeddings)
# 快取向量庫,以便將來重複使用
self.vector_stores[query] = vector_store
return vector_store
except Exception as e:
print(f"創建向量存儲時出錯: {e}")
return None
def retrieve_relevant_documents(self, query, top_k=5):
"""
檢索與查詢相關的文檔。
Args:
query (str): 搜尋查詢。
top_k (int): 要檢索的文檔數量。
Returns:
list: 相關文檔列表。
"""
# 檢查快取中是否有相關查詢
vector_store = None
for cached_query, stored_vs in self.vector_stores.items():
# 簡單檢查是否有相似查詢
if all(keyword in cached_query for keyword in query.split()[:3]):
vector_store = stored_vs
print(f"使用快取向量庫: {cached_query}")
break
# 如果沒有快取,則創建新的向量庫
if not vector_store:
documents = self.search_and_process_urls(query)
if not documents:
return []
vector_store = self.create_vector_store(documents, query)
if not vector_store:
return []
# 使用向量庫檢索相關文檔
try:
relevant_docs = vector_store.similarity_search(query, k=top_k)
return relevant_docs
except Exception as e:
print(f"檢索相關文檔時出錯: {e}")
return []
def get_web_context_for_query(self, query, user_id=None):
"""
獲取與查詢相關的網頁內容上下文,並考慮對話歷史。
Args:
query (str): 使用者查詢。
user_id (str, optional): 使用者ID,用於獲取對話歷史。
Returns:
str: 相關網頁的上下文,格式化為提示增強。
"""
# 如果有對話歷史,可以結合歷史來改進查詢
improved_query = query
if user_id and user_id in self.conversation_memories:
# 獲取最後幾條對話,用於上下文理解
recent_history = self.get_conversation_history(user_id, limit=3)
if recent_history:
# 簡單結合最近的用戶查詢來改進搜索
user_queries = [msg["content"] for msg in recent_history if msg["role"] == "user"]
if user_queries:
# 將最近的查詢和當前查詢結合
improved_query = f"{' '.join(user_queries[-1:])} {query}"
docs = self.retrieve_relevant_documents(improved_query)
if not docs:
return None
# 構建上下文字符串
context_parts = []
for i, doc in enumerate(docs):
source = doc.metadata.get("source", "未知來源")
content = doc.page_content.strip()
if content:
context_parts.append(f"資料 {i+1} (來源: {source}):\n{content}")
if not context_parts:
return None
web_context = "\n\n".join(context_parts)
return web_context
def generate_enhanced_prompt(self, user_query, user_id, include_memory=True):
"""
生成包含記憶和網頁上下文的增強提示
Args:
user_query (str): 使用者查詢
user_id (str): 使用者ID
include_memory (bool): 是否包含對話記憶
Returns:
str: 增強的提示
"""
prompt_parts = []
# 1. 添加對話記憶摘要
if include_memory and user_id in self.conversation_memories:
memory_summary = self.get_memory_summary(user_id)
if memory_summary:
prompt_parts.append(f"### 對話歷史:\n{memory_summary}\n")
# 2. 添加使用者當前查詢
prompt_parts.append(f"### 當前查詢:\n{user_query}\n")
# 3. 添加網頁上下文
web_context = self.get_web_context_for_query(user_query, user_id)
if web_context:
prompt_parts.append(f"### 參考資料:\n{web_context}\n")
# 4. 添加指示
prompt_parts.append("請根據上述資訊提供準確、相關的回答。如果參考資料不包含足夠資訊,請清楚說明。")
# 合併所有部分
full_prompt = "\n".join(prompt_parts)
return full_prompt