RAG_Gen / app.py
smartTranscend's picture
Update app.py
76b5019 verified
import os
import tempfile
from pathlib import Path
import streamlit as st
import re
from typing import List, Dict, Any, Optional
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.document_loaders import PyPDFLoader, UnstructuredPDFLoader
from langchain.schema import Document
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.llm import LLMChain
# ==============================================
# 基本設定
# ==============================================
print("✅ LangChain 0.2.x environment OK!")
MODEL_NAME = os.getenv("LLM_MODEL", "gpt-4o") # 使用更強的模型
##### multi-user
def main():
# 初始化用戶 session
init_user_session()
st.set_page_config(
page_title="專業PDF閱讀助手",
page_icon="📚",
layout="wide"
)
# 確保 messages 已初始化
if "messages" not in st.session_state:
st.session_state.messages = []
# ==============================================
# 用戶 API Key 驗證
# ==============================================
def validate_api_key(api_key: str) -> bool:
"""驗證 OpenAI API Key 是否有效"""
if not api_key or not api_key.startswith("sk-"):
return False
try:
# 簡單測試 API key
from openai import OpenAI
client = OpenAI(api_key=api_key)
client.models.list()
return True
except Exception as e:
st.error(f"API Key 驗證失敗: {str(e)}")
return False
# ==============================================
# 初始化用戶 Session
# ==============================================
def init_user_session():
"""初始化用戶專屬的 session state"""
if "user_initialized" not in st.session_state:
st.session_state.user_initialized = True
st.session_state.api_key = ""
st.session_state.api_key_valid = False
st.session_state.custom_settings = {
"topic_name": "通用知識庫",
"expert_role": "知識專家",
"domain": "通用",
"custom_system_prompt": ""
}
st.session_state.messages = []
st.session_state.retriever = None
st.session_state.chain = None
st.session_state.current_sources = []
# ==============================================
# 通用查詢類型識別
# ==============================================
def identify_query_type(query: str) -> str:
"""識別查詢的類型,以優化回答策略"""
# 檢測是否為繼續或完整內容請求
if re.search(
r'後面是什麼|後面要接什麼|下一句是什麼|接著是什麼|如何接龍|怎麼接下去|全部內容|完整的|繼續|接續|更多|全文',
query):
return "completion"
# 檢測是否為解釋意思的請求
if re.search(r'是什麼意思|什麼意思|的意思|啥意思|解釋|說明|代表什麼|定義|含義|詮釋', query):
return "explanation"
# 檢測是否為摘要請求
if re.search(r'摘要|總結|概述|簡述|大意|重點|主旨', query):
return "summary"
# 檢測是否為比較請求
if re.search(r'比較|差別|區別|不同|相似|共同點|差異', query):
return "comparison"
# 默認為一般問題
return "general"
def extract_keyword_from_query(query: str) -> str:
"""從查詢中提取關鍵詞"""
# 處理完整內容類查詢
completion_pattern = r'(.+?)(?:後面是什麼|下一句是什麼|接著是什麼|如何接龍|怎麼接下去|繼續|接續|更多|全文)'
match = re.search(completion_pattern, query)
if match:
return match.group(1).strip()
# 處理解釋類查詢
explanation_pattern = r'(.+?)(?:是什麼意思|什麼意思|的意思|啥意思|解釋|說明|代表什麼|定義|含義|詮釋)'
match = re.search(explanation_pattern, query)
if match:
return match.group(1).strip()
# 處理摘要類查詢
summary_pattern = r'(.+?)(?:摘要|總結|概述|簡述|大意|重點|主旨)'
match = re.search(summary_pattern, query)
if match:
return match.group(1).strip()
# 處理比較類查詢
comparison_pattern = r'(.+?)(?:和|與|跟|同)(.+?)(?:比較|差別|區別|不同|相似|共同點|差異)'
match = re.search(comparison_pattern, query)
if match:
return f"{match.group(1).strip()} {match.group(2).strip()}"
# 如果都沒匹配到,返回整個查詢作為關鍵詞
return query
def preprocess_query(query: str, topic_name: str = "") -> str:
"""預處理查詢,優化檢索效果"""
query_type = identify_query_type(query)
keyword = extract_keyword_from_query(query)
# 根據查詢類型調整查詢字串,提升檢索效果
if query_type == "completion":
return f"完整內容 {keyword}"
elif query_type == "explanation":
return f"解釋 {keyword}"
elif query_type == "summary":
return f"摘要 {keyword}"
elif query_type == "comparison":
return f"比較 {keyword}"
else:
return query
# ==============================================
# 通用文本處理
# ==============================================
def process_text(text: str) -> str:
"""處理文本,標準化格式"""
# 移除多餘空格
text = re.sub(r'\s+', ' ', text).strip()
# 將常見的章節標記轉換為標準格式,便於分割和檢索
text = re.sub(r'第[一二三四五六七八九十\d]+章', '## ', text)
text = re.sub(r'第[一二三四五六七八九十\d]+節', '### ', text)
return text
# ==============================================
# 優化的PDF加載函數
# ==============================================
def load_pdf_with_fallback(file_path: str) -> List[Document]:
"""嘗試使用多種PDF加載器來確保中文內容被正確提取"""
try:
# 先嘗試 PyPDFLoader
docs = PyPDFLoader(file_path).load()
# 檢查提取的文本是否有效
if docs and any(len(doc.page_content.strip()) > 50 for doc in docs):
return docs
except Exception as e:
st.error(f"PyPDFLoader 載入失敗: {str(e)}")
try:
# 嘗試使用 UnstructuredPDFLoader 作為後備
docs = UnstructuredPDFLoader(file_path).load()
return docs
except Exception as e:
st.error(f"UnstructuredPDFLoader 載入失敗: {str(e)}")
raise ValueError(f"無法載入PDF: {str(e)}")
# ==============================================
# multi-user 建立向量索引(移除 cache,改為 session 級別)
# ==============================================
def build_retriever_from_files(uploaded_files, api_key: str):
"""從上傳的PDF文件建立優化的檢索器"""
tmpdir = tempfile.TemporaryDirectory()
tmp_paths = []
for f in uploaded_files:
p = Path(tmpdir.name) / f.name
with open(p, "wb") as out:
out.write(f.read())
tmp_paths.append(str(p))
all_docs = []
for p in tmp_paths:
docs = load_pdf_with_fallback(p)
for doc in docs:
doc.page_content = process_text(doc.page_content)
all_docs.extend(docs)
splitter = RecursiveCharacterTextSplitter(
chunk_size=800,
chunk_overlap=150,
separators=["##", "###", "\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
)
chunks = splitter.split_documents(all_docs)
# 使用用戶的 API Key
embeddings = OpenAIEmbeddings(
model="text-embedding-3-small",
openai_api_key=api_key
)
vs = FAISS.from_documents(chunks, embeddings)
return vs.as_retriever(search_kwargs={"k": 4})
# ==============================================
# 預設的領域提示模板
# ==============================================
# 通用提示模板
DEFAULT_SYSTEM_PROMPT = """你是一位專業的知識助手,專精於{topic}相關知識。請根據以下文檔內容回答問題。
請注意以下特殊情況:
1. 如果問題是關於完整內容或繼續閱讀,請提供完整信息。
2. 如果問題是關於解釋意思,請詳細解釋含義、用法和背景。
3. 如果問題是關於摘要,請提煉出最重要的概念和要點。
4. 如果問題是關於比較,請清晰列出相似點和不同點。
5. 即使問題可能不完整或模糊,也請盡量根據提供的信息回答。
請用繁體中文回答,語氣親切專業。如果找不到確切答案,請說明無法找到相關資訊,而不是簡單回答「我不知道」。"""
# 醫學領域的系統提示
MEDICAL_SYSTEM_PROMPT = """你是一位醫學專家,專精於{topic}相關醫學知識。請根據以下文檔內容回答醫學相關問題。
在回答時請遵循以下原則:
1. 保持專業準確性,使用醫學術語但同時確保病人或普通人可以理解。
2. 區分已確立的醫學共識與仍有爭議的領域。
3. 對於病症解釋,提供清晰的定義、可能的原因、典型症狀和標準治療方法。
4. 對於藥物相關問題,說明作用機制、常見副作用和用藥注意事項。
5. 你是基於專業個人化建議,但病人病況嚴重時得適時提醒本建議不能替代專業醫生的診斷和建議。
請用繁體中文回答,語氣專業但平易近人。如果找不到確切答案,請明確說明而不是猜測。"""
# 法律領域的系統提示
LEGAL_SYSTEM_PROMPT = """你是一位法律專家,專精於{topic}相關法律知識。請根據以下文檔內容回答法律相關問題。
在回答時請遵循以下原則:
1. 清晰界定法律概念和術語。
2. 引述相關法條或判例時,提供準確的出處。
3. 區分法律事實與法律見解。
4. 說明不同情況下可能適用的不同法律解釋。
5. 提醒使用者,你的回答僅供參考,不構成法律建議,複雜法律問題應諮詢執業律師。
請用繁體中文回答,語氣專業嚴謹。如果找不到確切答案,請明確說明而不是提供不確定的法律意見。"""
# 技術領域的系統提示
TECH_SYSTEM_PROMPT = """你是一位技術專家,專精於{topic}相關技術知識。請根據以下文檔內容回答技術相關問題。
在回答時請遵循以下原則:
1. 提供簡潔明確的技術解釋,適當包含代碼示例。
2. 區分核心概念和實現細節。
3. 說明技術選擇的優缺點和適用場景。
4. 關注最佳實踐和常見的錯誤模式。
5. 適當引入類比或視覺化解釋,幫助理解複雜概念。
請用繁體中文回答,語氣專業但平易近人。如果找不到確切答案,請明確說明而不是提供猜測性的技術建議。"""
# 教育領域的系統提示
EDUCATION_SYSTEM_PROMPT = """你是一位教育專家,專精於{topic}相關教育知識。請根據以下文檔內容回答教育相關問題。
在回答時請遵循以下原則:
1. 適應不同學習階段學生的需求,提供適齡的解釋。
2. 結合具體例子和情境來解釋抽象概念。
3. 關注學習目標、評估方法和教學策略。
4. 鼓勵批判性思考和自主學習。
5. 考慮不同學習風格和能力的差異化教學方法。
請用繁體中文回答,語氣親切鼓勵。如果找不到確切答案,請建議可能的替代學習資源或方向。"""
# 金融領域的系統提示
FINANCE_SYSTEM_PROMPT = """你是一位金融專家,專精於{topic}相關金融知識。請根據以下文檔內容回答金融相關問題。
在回答時請遵循以下原則:
1. 清晰解釋金融概念、產品和服務的特點。
2. 討論金融決策時,考慮風險、收益和時間因素。
3. 使用適當的數據和指標來支持分析。
4. 關注市場趨勢和經濟環境的影響。
5. 提醒使用者,金融決策應基於個人情況,並可能需要專業顧問的建議。
請用繁體中文回答,語氣專業客觀。如果找不到確切答案,請明確說明而不是提供可能誤導的金融建議。"""
# 歷史領域的系統提示
HISTORY_SYSTEM_PROMPT = """你是一位歷史專家,專精於{topic}相關歷史知識。請根據以下文檔內容回答歷史相關問題。
在回答時請遵循以下原則:
1. 提供準確的年代、人物和事件資訊。
2. 區分歷史事實與歷史解釋或觀點。
3. 考慮歷史事件的背景、原因和影響。
4. 呈現多元的歷史視角,尤其是不同文化或群體的觀點。
5. 適當引用歷史資料來源或學術研究。
請用繁體中文回答,語氣學術但生動。如果找不到確切答案,請明確說明歷史記錄的局限性。"""
# ==============================================
# 獲取領域提示
# ==============================================
def get_domain_prompt(domain: str, topic: str) -> str:
"""根據領域返回適當的系統提示模板"""
domain_prompts = {
"通用": DEFAULT_SYSTEM_PROMPT,
"醫學": MEDICAL_SYSTEM_PROMPT,
"法律": LEGAL_SYSTEM_PROMPT,
"技術": TECH_SYSTEM_PROMPT,
"教育": EDUCATION_SYSTEM_PROMPT,
"金融": FINANCE_SYSTEM_PROMPT,
"歷史": HISTORY_SYSTEM_PROMPT
}
prompt_template = domain_prompts.get(domain, DEFAULT_SYSTEM_PROMPT)
return prompt_template.format(topic=topic)
# ==============================================
# 優化的提示模板
# ==============================================
def create_qa_prompt(system_prompt: str):
"""根據系統提示創建QA提示模板"""
template = f"""{system_prompt}
文檔內容:
{{context}}
問題: {{question}}
回答:"""
return PromptTemplate(
template=template,
input_variables=["context", "question"]
)
# ==============================================
# multi-user 建立對話鏈(移除 cache,使用用戶的 API Key)
# ==============================================
def make_chain(_retriever, system_prompt: str, topic: str, api_key: str):
"""建立針對特定主題和系統提示優化的對話鏈"""
llm = ChatOpenAI(
model=MODEL_NAME,
temperature=0,
openai_api_key=api_key
)
qa_prompt = create_qa_prompt(system_prompt)
condense_prompt = PromptTemplate.from_template(f"""
根據以下對話歷史和最新的問題,生成一個獨立的問題,該問題應包含所有必要的上下文信息,以便於找到與{topic}相關的答案。
對話歷史:
{{chat_history}}
最新問題: {{question}}
獨立問題:
""")
# 記憶模組設定
memory = ConversationBufferMemory(
memory_key="chat_history",
input_key="question",
output_key="answer",
return_messages=True
)
# 創建問題生成器
question_generator = LLMChain(
llm=llm,
prompt=condense_prompt
)
# 使用自定義問答鏈
qa_chain = load_qa_chain(
llm=llm,
chain_type="stuff",
prompt=qa_prompt,
verbose=True
)
# 建立對話檢索鏈
chain = ConversationalRetrievalChain(
retriever=_retriever,
combine_docs_chain=qa_chain,
question_generator=question_generator,
memory=memory,
return_source_documents=True,
verbose=True,
output_key="answer"
)
return chain
# ==============================================
# 處理查詢的主函數
# ==============================================
def process_query(query: str, chain, topic_name: str) -> Dict[str, Any]:
"""處理用戶查詢,返回結果和來源"""
# 預處理查詢以優化檢索效果
processed_query = preprocess_query(query, topic_name)
# 使用對話鏈處理查詢
result = chain({"question": processed_query})
return result
# ==============================================
# Streamlit App UI
# ==============================================
def main():
########multi-user start############
# 初始化用戶 session
init_user_session()
st.set_page_config(
page_title="專業PDF閱讀助手",
page_icon="📚",
layout="wide"
)
st.title("📚 專業PDF閱讀助手")
st.caption("🔑 輸入你的 API Key → 上傳PDF → 建立索引 → 與文件對話")
with st.sidebar:
st.markdown("### 🔐 API Key 設定")
# API Key 輸入
api_key_input = st.text_input(
"OpenAI API Key",
type="password",
value=st.session_state.api_key,
help="請輸入你的 OpenAI API Key (sk-...)"
)
# 驗證 API Key
if api_key_input and api_key_input != st.session_state.api_key:
with st.spinner("驗證 API Key..."):
if validate_api_key(api_key_input):
st.session_state.api_key = api_key_input
st.session_state.api_key_valid = True
st.success("✅ API Key 驗證成功!")
# 清除舊的 chain 和 retriever
st.session_state.chain = None
st.session_state.retriever = None
else:
st.session_state.api_key_valid = False
if not st.session_state.api_key_valid:
st.warning("⚠️ 請先輸入有效的 OpenAI API Key")
st.info("💡 如何取得 API Key:\n1. 前往 platform.openai.com\n2. 登入你的帳號\n3. 進入 API Keys 頁面\n4. 創建新的 API Key")
return
st.markdown("---")
st.markdown("### ⚙️ 系統設定")
st.write(f"使用的模型: {MODEL_NAME}")
########multi-user end############
# 自定義主題名稱
topic_name = st.text_input(
"知識庫主題名稱",
value=st.session_state.custom_settings["topic_name"]
)
# 專家角色
expert_role = st.text_input(
"AI 助手角色",
value=st.session_state.custom_settings["expert_role"]
)
# 領域選擇
domain_options = ["通用", "醫學", "法律", "技術", "教育", "金融", "歷史", "自訂"]
selected_domain = st.selectbox(
"專業領域",
options=domain_options,
index=domain_options.index(st.session_state.custom_settings.get("domain", "通用"))
)
# 自訂系統提示
custom_system_prompt = st.session_state.custom_settings.get("custom_system_prompt", "")
if selected_domain == "自訂":
st.markdown("### 自訂系統提示 (System Prompt)")
custom_system_prompt = st.text_area(
"請輸入自訂的系統提示",
value=custom_system_prompt,
height=300,
help="這裡的提示將指導AI如何回答問題,可以定義角色、回答風格和專業領域特性"
)
else:
# 顯示預設提示供參考
with st.expander(f"查看{selected_domain}領域的系統提示"):
domain_prompt = get_domain_prompt(selected_domain, topic_name)
st.code(domain_prompt)
# 保存設定
settings_changed = (
topic_name != st.session_state.custom_settings["topic_name"] or
expert_role != st.session_state.custom_settings["expert_role"] or
selected_domain != st.session_state.custom_settings.get("domain", "通用") or
(selected_domain == "自訂" and custom_system_prompt != st.session_state.custom_settings.get(
"custom_system_prompt", ""))
)
if settings_changed:
st.session_state.custom_settings["topic_name"] = topic_name
st.session_state.custom_settings["expert_role"] = expert_role
st.session_state.custom_settings["domain"] = selected_domain
if selected_domain == "自訂":
st.session_state.custom_settings["custom_system_prompt"] = custom_system_prompt
####multi user start################
# 清除對話鏈和歷史
if st.session_state.chain:
st.session_state.chain = None
st.session_state.messages = []
st.session_state.retriever = None # 加上這行
st.info("設定已更新,請重新上傳文件建立索引")
# 檔案上傳區
st.markdown("---")
uploaded = st.file_uploader(
"上傳PDF",
type=["pdf"],
accept_multiple_files=True,
help="上傳PDF文件"
)
# 清空對話按鈕
if st.session_state.messages:
if st.button("🧹 清空對話"):
st.session_state.messages = []
st.rerun()
####multi user end###################
# 確定要使用的系統提示
if st.session_state.custom_settings["domain"] == "自訂":
system_prompt = st.session_state.custom_settings.get("custom_system_prompt", DEFAULT_SYSTEM_PROMPT)
else:
system_prompt = get_domain_prompt(
st.session_state.custom_settings["domain"],
st.session_state.custom_settings["topic_name"]
)
# multi user 當有文件上傳且尚未建立對話鏈時
if uploaded and not st.session_state.chain and st.session_state.api_key_valid:
with st.spinner(f"正在為「{topic_name}」建立文檔索引,這可能需要一點時間..."):
try:
# 建立檢索器(使用用戶的 API Key)
retriever = build_retriever_from_files(uploaded, st.session_state.api_key)
st.session_state.retriever = retriever
# 建立對話鏈(使用用戶的 API Key)
st.session_state.chain = make_chain(
retriever,
system_prompt,
st.session_state.custom_settings["topic_name"],
st.session_state.api_key
)
# 清空舊對話歷史 (加上這行)
st.session_state.messages = []
st.success(f"「{topic_name}」知識庫索引建立成功!您現在可以開始提問了。")
except Exception as e:
st.error(f"索引建立失敗: {str(e)}")
####multi-user end
# multi-user 提供範例問題,根據主題和領域動態生成
if st.session_state.chain and not st.session_state.messages: # 改用 st.session_state.chain
domain = st.session_state.custom_settings.get("domain", "通用")
examples = {
"通用": [
f"{topic_name}中的重要概念是什麼?",
f"請解釋{topic_name}中的關鍵術語是什麼意思?",
f"請提供{topic_name}的摘要。",
f"{topic_name}的第一章內容是什麼?"
],
"醫學": [
f"{topic_name}的病因是什麼?",
f"{topic_name}的診斷標準是什麼?",
f"{topic_name}和相似疾病的區別是什麼?",
f"請解釋{topic_name}的治療方法和注意事項。"
],
"法律": [
f"{topic_name}的法律定義是什麼?",
f"{topic_name}相關法條的適用範圍?",
f"{topic_name}在不同情況下如何適用?",
f"{topic_name}涉及的法律責任有哪些?"
],
"技術": [
f"{topic_name}的核心原理是什麼?",
f"{topic_name}在實際應用中的最佳實踐?",
f"{topic_name}相比其他技術有什麼優缺點?",
f"如何解決{topic_name}中常見的技術問題?"
]
}
# 如果沒有特定領域的範例,使用通用範例
domain_examples = examples.get(domain, examples["通用"])
with st.expander("💡 範例問題", expanded=True):
examples_md = f"嘗試提問以下關於「{topic_name}」的問題:\n\n"
for example in domain_examples:
examples_md += f"- {example}\n"
st.markdown(examples_md)
# 建立雙欄佈局
col1, col2 = st.columns([2, 1])
with col1:
# 顯示對話歷史
if "messages" in st.session_state:
for role, content in st.session_state.messages:
with st.chat_message(role):
st.markdown(content)
# multi user 使用者輸入
placeholder_text = f"請問關於{topic_name}的問題..." if st.session_state.chain else "請先輸入 API Key 並上傳PDF文件..."
prompt = st.chat_input(placeholder_text, disabled=not st.session_state.chain)
# multi user 處理使用者提問
if prompt and st.session_state.chain:
# 添加使用者訊息到對話歷史
st.session_state.messages.append(("user", prompt))
# 顯示使用者訊息
with st.chat_message("user"):
st.markdown(prompt)
# 顯示助手回應
with st.chat_message("assistant"):
with st.spinner(f"思考{topic_name}相關問題中..."):
try:
# 處理查詢
result = process_query(
prompt,
st.session_state.chain,
st.session_state.custom_settings["topic_name"]
)
answer = result["answer"]
# 顯示答案
st.markdown(answer)
# 保存源文檔以供右側面板顯示
st.session_state.current_sources = result.get("source_documents", [])
except Exception as e:
error_msg = f"處理問題時出錯: {str(e)}"
st.error(error_msg)
answer = "抱歉,處理您的問題時出現了錯誤。請再試一次。"
st.session_state.current_sources = []
# 添加助手回應到對話歷史
st.session_state.messages.append(("assistant", answer))
# 右側面板顯示源文檔和設置
with col2:
# 顯示當前系統提示概要
st.markdown(f"### ⚙️ 當前配置")
st.info(f"""
**主題**: {topic_name}
**角色**: {expert_role}
**領域**: {st.session_state.custom_settings.get('domain', '通用')}
""")
st.markdown("### 💡 參考來源")
if st.session_state.chain: # 改這裡
if st.session_state.current_sources:
for i, doc in enumerate(st.session_state.current_sources, 1):
with st.expander(f"來源 {i}"):
src = Path(doc.metadata.get("source", "?")).name
page = doc.metadata.get("page", None)
st.write(f"**文件**: {src}" + (f" · 第{page + 1}頁" if page is not None else ""))
st.markdown(f"**內容片段**:")
st.markdown(f"```\n{doc.page_content[:300]}...\n```")
else:
st.info("提問後將顯示參考來源")
else:
st.info("請先輸入 API Key 並上傳PDF文件")
# 如果尚未上傳任何文件
if not uploaded:
st.info("👆 請先在側邊欄輸入你的 OpenAI API Key,然後上傳PDF文件")
if __name__ == "__main__":
main()