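"""RAG manager for a chat bot.

Combines Google Search web retrieval, FAISS vector-store caching, and JSON-backed
per-user conversation memory to build enhanced prompts for an LLM.
"""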

import os
import json
from datetime import datetime

from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_huggingface import HuggingFaceEmbeddings


class RAGManager:
    """
    Manages web retrieval-augmented generation (RAG) and per-user conversation memory.
    """

    def __init__(self, google_search_api_key=None, google_search_cse_id=None,
                 cache_dir="/tmp/rag_cache", memory_dir="/tmp/rag_memory"):
        """
        Initialize a RAGManager instance.

        Args:
            google_search_api_key (str): Google Search API key.
            google_search_cse_id (str): Google Custom Search Engine ID.
            cache_dir (str): Cache directory for retrieval results.
            memory_dir (str): Storage directory for conversation memory.
        """
        self.cache_dir = cache_dir
        self.memory_dir = memory_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.memory_dir, exist_ok=True)

        # Set up the Google Search API
        self.has_google_search = False
        if google_search_api_key and google_search_cse_id:
            try:
                self.search = GoogleSearchAPIWrapper(
                    google_api_key=google_search_api_key,
                    google_cse_id=google_search_cse_id
                )
                self.has_google_search = True
            except Exception as e:
                print(f"Failed to initialize the Google Search API: {e}")

        # Initialize the HuggingFace embeddings
        try:
            self.embeddings = HuggingFaceEmbeddings(
                model_name="shibing624/text2vec-base-chinese",  # Chinese embedding model
                cache_folder="/tmp/hf_models"
            )
            print("HuggingFace embeddings model loaded successfully")
        except Exception as e:
            print(f"Failed to load HuggingFace embeddings: {e}")
            self.embeddings = None

        # Initialize the text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", "。", "!", "?", ",", " ", ""]
        )

        # Cache of vector stores that have already been built
        self.vector_stores = {}

        # RAG control commands and per-chat status
        self.rag_status = {}  # RAG on/off state for each chat room
        self.rag_enable_command = "開啟查詢模式"   # "enable query mode"
        self.rag_disable_command = "關閉查詢模式"  # "disable query mode"

        # Conversation memory
        self.conversation_memories = {}  # format: {user_id: [message1, message2, ...]}
        self._load_all_memories()  # load any previously saved memories

    def _load_all_memories(self):
        """Load all previously saved conversation memories."""
        try:
            for filename in os.listdir(self.memory_dir):
                if filename.endswith('.json'):
                    user_id = filename.split('.')[0]
                    memory_path = os.path.join(self.memory_dir, filename)
                    with open(memory_path, 'r', encoding='utf-8') as f:
                        self.conversation_memories[user_id] = json.load(f)
                    print(f"Loaded conversation memory for user {user_id}")
        except Exception as e:
            print(f"Error while loading conversation memories: {e}")

    def _save_memory(self, user_id):
        """Persist the conversation memory of a specific user."""
        try:
            if user_id in self.conversation_memories:
                memory_path = os.path.join(self.memory_dir, f"{user_id}.json")
                with open(memory_path, 'w', encoding='utf-8') as f:
                    json.dump(self.conversation_memories[user_id], f, ensure_ascii=False, indent=2)
                print(f"Saved conversation memory for user {user_id}")
        except Exception as e:
            print(f"Error while saving conversation memory: {e}")

    def add_message(self, user_id, role, content, max_memory=100):
        """
        Append a message to a user's conversation memory.

        Args:
            user_id (str): User ID.
            role (str): Message role ('user' or 'assistant').
            content (str): Message content.
            max_memory (int): Maximum number of messages to keep.
        """
        if user_id not in self.conversation_memories:
            self.conversation_memories[user_id] = []

        # Add the new message together with a timestamp
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        self.conversation_memories[user_id].append(message)

        # Drop the oldest messages once the limit is exceeded
        if len(self.conversation_memories[user_id]) > max_memory:
            self.conversation_memories[user_id] = self.conversation_memories[user_id][-max_memory:]

        # Persist the updated memory
        self._save_memory(user_id)
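
    # A stored memory entry produced by add_message has this shape
    # (the timestamp value here is illustrative):
    #   {"role": "user", "content": "...", "timestamp": "2025-01-01T12:00:00.000000"}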

    def get_conversation_history(self, user_id, limit=10, include_timestamps=False):
        """
        Return the conversation history for a user.

        Args:
            user_id (str): User ID.
            limit (int): Number of most recent messages to return.
            include_timestamps (bool): Whether to include timestamps.

        Returns:
            list: List of conversation messages.
        """
        if user_id not in self.conversation_memories:
            return []

        # Keep only the most recent messages
        recent_messages = self.conversation_memories[user_id][-limit:]

        if include_timestamps:
            return recent_messages
        # Strip the timestamps
        return [{k: v for k, v in msg.items() if k != 'timestamp'} for msg in recent_messages]

    def clear_memory(self, user_id):
        """
        Clear the conversation memory of a specific user.

        Args:
            user_id (str): User ID.
        """
        if user_id in self.conversation_memories:
            self.conversation_memories[user_id] = []
            self._save_memory(user_id)
            print(f"Cleared conversation memory for user {user_id}")

    def get_memory_summary(self, user_id, max_length=500):
        """
        Build a summary of the conversation memory to pass to the LLM.

        Args:
            user_id (str): User ID.
            max_length (int): Maximum length of the summary.

        Returns:
            str: Conversation memory summary.
        """
        history = self.get_conversation_history(user_id)
        if not history:
            return ""

        # Simple formatting of the conversation history
        formatted_history = []
        for msg in history:
            role = "使用者" if msg["role"] == "user" else "助手"  # Chinese labels for "user" / "assistant"
            formatted_history.append(f"{role}: {msg['content']}")
        summary = "\n".join(formatted_history)

        # Truncate overly long summaries and append an ellipsis
        if len(summary) > max_length:
            return summary[:max_length] + "..."
        return summary

    def search_and_process_urls(self, query, num_results=3):
        """
        Search for relevant web pages with the Google Search API and process the results.

        Args:
            query (str): Search query.
            num_results (int): Number of search results to process.

        Returns:
            list: Processed documents.
        """
        if not self.has_google_search:
            return []

        try:
            # Search the web
            search_results = self.search.results(query, num_results)
            all_docs = []

            # Process each search result
            for result in search_results:
                url = result.get("link")
                if not url:
                    continue
                try:
                    # Load the page content with WebBaseLoader
                    loader = WebBaseLoader(url)
                    documents = loader.load()

                    # Attach the source URL to each document
                    for doc in documents:
                        doc.metadata["source"] = url

                    # Split the documents into smaller chunks
                    split_docs = self.text_splitter.split_documents(documents)
                    all_docs.extend(split_docs)
                except Exception as e:
                    print(f"Error while processing URL {url}: {e}")

            return all_docs
        except Exception as e:
            print(f"Error during search processing: {e}")
            return []

    def create_vector_store(self, documents, query):
        """
        Create a vector store from documents.

        Args:
            documents (list): List of documents.
            query (str): The original search query, used to identify and cache the vector store.

        Returns:
            FAISS: FAISS vector store.
        """
        if not documents or not self.embeddings:
            return None

        try:
            vector_store = FAISS.from_documents(documents, self.embeddings)
            # Cache the vector store for later reuse
            self.vector_stores[query] = vector_store
            return vector_store
        except Exception as e:
            print(f"Error while creating the vector store: {e}")
            return None

    def retrieve_relevant_documents(self, query, top_k=5):
        """
        Retrieve documents relevant to a query.

        Args:
            query (str): Search query.
            top_k (int): Number of documents to retrieve.

        Returns:
            list: List of relevant documents.
        """
        # Check whether a cached vector store matches a similar query
        vector_store = None
        for cached_query, stored_vs in self.vector_stores.items():
            # Naive similarity check: the first few keywords of the query appear in the cached query
            if all(keyword in cached_query for keyword in query.split()[:3]):
                vector_store = stored_vs
                print(f"Using cached vector store: {cached_query}")
                break

        # Build a new vector store if none is cached
        if not vector_store:
            documents = self.search_and_process_urls(query)
            if not documents:
                return []
            vector_store = self.create_vector_store(documents, query)
            if not vector_store:
                return []

        # Retrieve the relevant documents from the vector store
        try:
            relevant_docs = vector_store.similarity_search(query, k=top_k)
            return relevant_docs
        except Exception as e:
            print(f"Error while retrieving relevant documents: {e}")
            return []

    def get_web_context_for_query(self, query, user_id=None):
        """
        Build a web-content context for a query, taking the conversation history into account.

        Args:
            query (str): User query.
            user_id (str, optional): User ID, used to fetch the conversation history.

        Returns:
            str: Context from relevant web pages, formatted for prompt augmentation.
        """
        # If conversation history is available, use it to refine the query
        improved_query = query
        if user_id and user_id in self.conversation_memories:
            # Fetch the last few messages for contextual understanding
            recent_history = self.get_conversation_history(user_id, limit=3)
            if recent_history:
                # Combine recent user queries with the current one to improve the search
                user_queries = [msg["content"] for msg in recent_history if msg["role"] == "user"]
                if user_queries:
                    # Prepend the most recent user query to the current query
                    improved_query = f"{' '.join(user_queries[-1:])} {query}"

        docs = self.retrieve_relevant_documents(improved_query)
        if not docs:
            return None

        # Build the context string
        context_parts = []
        for i, doc in enumerate(docs):
            source = doc.metadata.get("source", "未知來源")  # default label: "unknown source"
            content = doc.page_content.strip()
            if content:
                # Chinese section label: "Document {i+1} (source: ...)"
                context_parts.append(f"資料 {i+1} (來源: {source}):\n{content}")

        if not context_parts:
            return None

        web_context = "\n\n".join(context_parts)
        return web_context

    def generate_enhanced_prompt(self, user_query, user_id, include_memory=True):
        """
        Build an enhanced prompt that combines conversation memory and web context.

        Args:
            user_query (str): User query.
            user_id (str): User ID.
            include_memory (bool): Whether to include the conversation memory.

        Returns:
            str: The enhanced prompt.
        """
        prompt_parts = []

        # 1. Conversation-memory summary ("對話歷史" = conversation history)
        if include_memory and user_id in self.conversation_memories:
            memory_summary = self.get_memory_summary(user_id)
            if memory_summary:
                prompt_parts.append(f"### 對話歷史:\n{memory_summary}\n")

        # 2. The user's current query ("當前查詢" = current query)
        prompt_parts.append(f"### 當前查詢:\n{user_query}\n")

        # 3. Web context ("參考資料" = reference material)
        web_context = self.get_web_context_for_query(user_query, user_id)
        if web_context:
            prompt_parts.append(f"### 參考資料:\n{web_context}\n")

        # 4. Final instruction (Chinese: "Answer accurately based on the information above;
        #    state clearly if the reference material does not contain enough information.")
        prompt_parts.append("請根據上述資訊提供準確、相關的回答。如果參考資料不包含足夠資訊,請清楚說明。")

        # Join all parts
        full_prompt = "\n".join(prompt_parts)
        return full_prompt
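

# Minimal usage sketch. The environment-variable names, the "user_123" ID, and the sample
# query are illustrative assumptions, not part of this module; real Google Search
# credentials are needed for the web-retrieval path, otherwise only memory is used.
if __name__ == "__main__":
    manager = RAGManager(
        google_search_api_key=os.environ.get("GOOGLE_SEARCH_API_KEY"),
        google_search_cse_id=os.environ.get("GOOGLE_SEARCH_CSE_ID"),
    )

    demo_user = "user_123"
    question = "請介紹台北的夜市"  # sample query: "Introduce Taipei's night markets"

    # Record the user's message, then build the enhanced prompt for the LLM
    manager.add_message(demo_user, "user", question)
    print(manager.generate_enhanced_prompt(question, demo_user))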