Spaces:
Runtime error
Runtime error
| import os | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_google_community import GoogleSearchAPIWrapper | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from datetime import datetime | |
| import json | |
class RAGManager:
    """
    Manage web retrieval-augmented generation (RAG) and per-user conversation memory.

    Conversation memories live in ``self.conversation_memories`` (a dict of
    user_id -> list of message dicts) and are mirrored to one JSON file per user
    in ``memory_dir``. Web retrieval searches Google Custom Search, loads and
    splits the result pages, embeds the chunks into a FAISS store, and caches
    each store in ``self.vector_stores`` keyed by the query that produced it.
    """

    def __init__(self, google_search_api_key=None, google_search_cse_id=None,
                 cache_dir="/tmp/rag_cache", memory_dir="/tmp/rag_memory"):
        """
        Create a RAGManager.

        Args:
            google_search_api_key (str): Google Search API key. Web search is
                enabled only when both this and ``google_search_cse_id`` are given
                and the search wrapper initializes successfully.
            google_search_cse_id (str): Google Custom Search Engine ID.
            cache_dir (str): Directory for cached retrieval results; it is created
                if missing but is not read anywhere else in this class.
            memory_dir (str): Directory where conversation memories are stored.
        """
        self.cache_dir = cache_dir
        self.memory_dir = memory_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        os.makedirs(self.memory_dir, exist_ok=True)

        # Google Search is optional: without it, web retrieval yields no documents.
        self.has_google_search = False
        if google_search_api_key and google_search_cse_id:
            try:
                self.search = GoogleSearchAPIWrapper(
                    google_api_key=google_search_api_key,
                    google_cse_id=google_search_cse_id
                )
                self.has_google_search = True
            except Exception as e:
                print(f"Google Search API 初始化失敗: {e}")

        # Chinese sentence-embedding model. If loading fails, embeddings stay None
        # and create_vector_store() returns None.
        try:
            self.embeddings = HuggingFaceEmbeddings(
                model_name="shibing624/text2vec-base-chinese",
                cache_folder="/tmp/hf_models"
            )
            print("已成功載入 HuggingFace Embeddings 模型")
        except Exception as e:
            print(f"載入 HuggingFace Embeddings 失敗: {e}")
            self.embeddings = None

        # Splits loaded web pages into ~500-character chunks with 50 characters of overlap.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            separators=["\n\n", "\n", "。", "!", "?", ",", " ", ""]
        )

        # Query string -> FAISS store built from that query's search results.
        # NOTE(review): entries are never evicted, so this grows for the lifetime
        # of the process.
        self.vector_stores = {}

        # Per-chat-room RAG on/off state and the commands that toggle it
        # (not read elsewhere in this class).
        self.rag_status = {}
        self.rag_enable_command = "開啟查詢模式"
        self.rag_disable_command = "關閉查詢模式"

        # {user_id: [message, ...]}, populated from the JSON files on disk.
        self.conversation_memories = {}
        self._load_all_memories()

    def _load_all_memories(self):
        """
        Load every ``<user_id>.json`` file in ``memory_dir`` into memory.

        A file that cannot be read or parsed, or whose content is not a JSON list,
        is reported and skipped so that one bad file does not prevent the other
        users' memories from loading.
        """
        try:
            filenames = os.listdir(self.memory_dir)
        except Exception as e:
            print(f"載入對話記憶時出錯: {e}")
            return

        for filename in filenames:
            if not filename.endswith('.json'):
                continue
            # splitext strips only the final ".json", so user IDs containing dots
            # (e.g. "u.1") round-trip correctly with _save_memory.
            user_id = os.path.splitext(filename)[0]
            memory_path = os.path.join(self.memory_dir, filename)
            try:
                with open(memory_path, 'r', encoding='utf-8') as f:
                    messages = json.load(f)
            except Exception as e:
                print(f"載入對話記憶時出錯: {e}")
                continue
            if not isinstance(messages, list):
                print(f"載入對話記憶時出錯: {filename} 格式不正確")
                continue
            self.conversation_memories[user_id] = messages
            print(f"已載入使用者 {user_id} 的對話記憶")

    def _save_memory(self, user_id):
        """
        Persist one user's conversation memory to ``<memory_dir>/<user_id>.json``.

        The data is first written to a ``.tmp`` sibling file and then moved into
        place with ``os.replace``. A crash mid-write therefore cannot leave a
        truncated JSON file; ``.tmp`` files are ignored by _load_all_memories.
        Errors are printed, not raised.

        NOTE(review): ``user_id`` is used verbatim as a filename; if it can
        contain path separators it must be sanitized before reaching here.
        """
        if user_id not in self.conversation_memories:
            return
        try:
            memory_path = os.path.join(self.memory_dir, f"{user_id}.json")
            tmp_path = memory_path + ".tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.conversation_memories[user_id], f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, memory_path)
            print(f"已儲存使用者 {user_id} 的對話記憶")
        except Exception as e:
            print(f"儲存對話記憶時出錯: {e}")

    def add_message(self, user_id, role, content, max_memory=100):
        """
        Append one message to a user's memory, trim it, and persist it.

        Args:
            user_id (str): User ID.
            role (str): Message role ('user' or 'assistant').
            content (str): Message text.
            max_memory (int): Maximum number of most-recent messages to keep;
                a value <= 0 keeps none.
        """
        message = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat()
        }
        memories = self.conversation_memories.setdefault(user_id, [])
        memories.append(message)

        # Drop the oldest messages beyond the limit. max_memory <= 0 is handled
        # explicitly because lst[-0:] would return the whole list.
        if len(memories) > max_memory:
            self.conversation_memories[user_id] = memories[-max_memory:] if max_memory > 0 else []

        self._save_memory(user_id)

    def get_conversation_history(self, user_id, limit=10, include_timestamps=False):
        """
        Return the most recent messages of a user's conversation.

        Args:
            user_id (str): User ID.
            limit (int): Number of most-recent messages to return; <= 0 returns [].
            include_timestamps (bool): Whether to keep each message's 'timestamp' key.

        Returns:
            list: New message dicts (oldest first), safe for the caller to mutate.
        """
        messages = self.conversation_memories.get(user_id)
        if not messages or limit <= 0:
            # limit <= 0 is handled explicitly because lst[-0:] is the whole list.
            return []
        recent_messages = messages[-limit:]
        if include_timestamps:
            return [dict(msg) for msg in recent_messages]
        return [{k: v for k, v in msg.items() if k != 'timestamp'} for msg in recent_messages]

    def clear_memory(self, user_id):
        """
        Clear a user's conversation memory (in memory and on disk).

        Args:
            user_id (str): User ID. Unknown users are ignored.
        """
        if user_id in self.conversation_memories:
            self.conversation_memories[user_id] = []
            self._save_memory(user_id)
            print(f"已清除使用者 {user_id} 的對話記憶")

    def get_memory_summary(self, user_id, max_length=500):
        """
        Format the recent conversation as a plain-text summary for the LLM.

        Args:
            user_id (str): User ID.
            max_length (int): Maximum length of the returned string; when the
                summary is longer it is cut and "..." is appended, with the
                ellipsis counted toward the limit.

        Returns:
            str: One "<speaker>: <content>" line per message, or "" if there is
            no history.
        """
        history = self.get_conversation_history(user_id)
        if not history:
            return ""

        formatted_history = []
        for msg in history:
            speaker = "使用者" if msg.get("role") == "user" else "助手"
            formatted_history.append(f"{speaker}: {msg.get('content', '')}")
        summary = "\n".join(formatted_history)

        if len(summary) <= max_length:
            return summary
        # NOTE(review): truncation keeps the oldest lines and drops the newest;
        # confirm this is the intended priority for the prompt.
        ellipsis = "..."
        if max_length <= len(ellipsis):
            return ellipsis[:max(max_length, 0)]
        return summary[:max_length - len(ellipsis)] + ellipsis

    def search_and_process_urls(self, query, num_results=3):
        """
        Search Google for the query, load each result page, and split it into chunks.

        Args:
            query (str): Search keywords.
            num_results (int): Number of search results to process.

        Returns:
            list: Document chunks, each with metadata["source"] set to its URL;
            [] when search is unavailable or fails.
        """
        if not self.has_google_search:
            return []
        try:
            search_results = self.search.results(query, num_results)
            all_docs = []
            for result in search_results:
                url = result.get("link")
                if not url:
                    continue
                # Each URL is handled independently so one bad page doesn't
                # discard the others.
                try:
                    loader = WebBaseLoader(url)
                    documents = loader.load()
                    for doc in documents:
                        doc.metadata["source"] = url
                    split_docs = self.text_splitter.split_documents(documents)
                    all_docs.extend(split_docs)
                except Exception as e:
                    print(f"處理 URL {url} 時出錯: {e}")
            return all_docs
        except Exception as e:
            print(f"搜尋處理時出錯: {e}")
            return []

    def create_vector_store(self, documents, query):
        """
        Build a FAISS store from documents and cache it under the query.

        Args:
            documents (list): Document chunks to embed.
            query (str): The originating search query, used as the cache key.

        Returns:
            FAISS | None: The store, or None if there is nothing to embed, no
            embedding model, or construction fails.
        """
        if not documents or self.embeddings is None:
            return None
        try:
            vector_store = FAISS.from_documents(documents, self.embeddings)
            self.vector_stores[query] = vector_store
            return vector_store
        except Exception as e:
            print(f"創建向量存儲時出錯: {e}")
            return None

    def retrieve_relevant_documents(self, query, top_k=5):
        """
        Return the documents most similar to the query.

        A cached store is reused when the first three whitespace-separated
        keywords of the query all occur (as substrings) in a cached query;
        otherwise a new store is built from a fresh web search.

        Args:
            query (str): Search query. Empty/whitespace-only queries return [].
            top_k (int): Number of documents to retrieve.

        Returns:
            list: Relevant documents, or [] on any failure.
        """
        keywords = query.split()[:3]
        if not keywords:
            # all([]) is True, so without this guard an empty query would
            # "match" whichever cached store happens to come first.
            return []

        vector_store = None
        for cached_query, stored_vs in self.vector_stores.items():
            if all(keyword in cached_query for keyword in keywords):
                vector_store = stored_vs
                print(f"使用快取向量庫: {cached_query}")
                break

        if vector_store is None:
            documents = self.search_and_process_urls(query)
            if not documents:
                return []
            vector_store = self.create_vector_store(documents, query)
            if vector_store is None:
                return []

        try:
            return vector_store.similarity_search(query, k=top_k)
        except Exception as e:
            print(f"檢索相關文檔時出錯: {e}")
            return []

    def get_web_context_for_query(self, query, user_id=None):
        """
        Build a web-context string for the query, optionally using conversation history.

        When ``user_id`` has history, the most recent earlier user message (other
        than the current query itself) is prepended to the search query.

        Args:
            query (str): The user's query.
            user_id (str, optional): User ID whose history refines the search.

        Returns:
            str | None: Numbered, source-annotated excerpts, or None if nothing
            relevant was found.
        """
        improved_query = query
        if user_id and user_id in self.conversation_memories:
            recent_history = self.get_conversation_history(user_id, limit=3)
            # Skip messages identical to the current query so that a caller who
            # already stored it doesn't produce "query query".
            previous_queries = []
            for msg in recent_history:
                content = msg.get("content")
                if msg.get("role") == "user" and content and content != query:
                    previous_queries.append(content)
            if previous_queries:
                improved_query = f"{previous_queries[-1]} {query}"

        docs = self.retrieve_relevant_documents(improved_query)
        if not docs:
            return None

        context_parts = []
        for doc in docs:
            content = doc.page_content.strip()
            if not content:
                continue
            source = doc.metadata.get("source", "未知來源")
            # Number only the entries actually emitted so there are no gaps.
            context_parts.append(f"資料 {len(context_parts) + 1} (來源: {source}):\n{content}")

        if not context_parts:
            return None
        return "\n\n".join(context_parts)

    def generate_enhanced_prompt(self, user_query, user_id, include_memory=True):
        """
        Assemble a prompt from conversation memory, the query, and web context.

        Args:
            user_query (str): The user's query.
            user_id (str): User ID.
            include_memory (bool): Whether to include the conversation summary.

        Returns:
            str: Sections for history (optional), current query, reference
            material (if any), and a closing instruction, joined by newlines.
        """
        prompt_parts = []

        if include_memory and user_id in self.conversation_memories:
            memory_summary = self.get_memory_summary(user_id)
            if memory_summary:
                prompt_parts.append(f"### 對話歷史:\n{memory_summary}\n")

        prompt_parts.append(f"### 當前查詢:\n{user_query}\n")

        web_context = self.get_web_context_for_query(user_query, user_id)
        if web_context:
            prompt_parts.append(f"### 參考資料:\n{web_context}\n")

        prompt_parts.append("請根據上述資訊提供準確、相關的回答。如果參考資料不包含足夠資訊,請清楚說明。")

        return "\n".join(prompt_parts)