import os
import warnings
from dotenv import load_dotenv
import numpy as np
from sklearn.preprocessing import normalize
# Silence the tokenizers parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", category=UserWarning, module="tokenizers")
# Document loading
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Embeddings & vector store
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# Prompt & Chains
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# LLM
from langchain_community.chat_models import ChatOpenAI
# Gradio
import gradio as gr
# -----------------------------
# Configuration
# -----------------------------
PDF_PATH = "data/pdfs/Stream-Processing-with-Apache-Flink.pdf"
CHUNK_SIZE = 512
CHUNK_OVERLAP = 50
TOP_K = 3
# -----------------------------
# 1️⃣ Load environment variables
# -----------------------------
load_dotenv()
print("✅ Environment ready")
# -----------------------------
# 2️⃣ Load the PDF document
# -----------------------------
loader = PyPDFLoader(PDF_PATH)
documents = loader.load()
print(f"✅ Loaded {len(documents)} pages")
# -----------------------------
# 3️⃣ Split the text into chunks
# -----------------------------
text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
texts = text_splitter.split_documents(documents)
print(f"✅ Split into {len(texts)} chunks")
# -----------------------------
# 4️⃣ Build embeddings & the vector store
# -----------------------------
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Compute the document embeddings up front
vectors = embedding_model.embed_documents([doc.page_content for doc in texts])
# L2-normalize so inner products correspond to cosine similarity
vectors = normalize(np.array(vectors))
# Build the FAISS vector store (note: from_texts embeds the corpus a second time)
vector_store = FAISS.from_texts(
    [doc.page_content for doc in texts],
    embedding_model,
    metadatas=[doc.metadata for doc in texts]
)
# Swap in the normalized vectors; rows are re-added in the original order,
# so the index-to-docstore mapping stays valid
vector_store.index.reset()
vector_store.index.add(vectors.astype(np.float32))
print("✅ Embeddings created, normalized and FAISS index ready")
# -----------------------------
# 5️⃣ Retriever
# -----------------------------
retriever = vector_store.as_retriever(search_kwargs={"k": TOP_K})
print("✅ Retriever ready")
# -----------------------------
# 6️⃣ LLM
# -----------------------------
# Reads OPENAI_API_KEY (and, for DeepSeek, the API base URL) from the .env file
llm = ChatOpenAI(
    model_name="deepseek-chat",  # or "gpt-3.5-turbo"
    temperature=0.7,
    max_tokens=512
)
print("✅ LLM ready")
# -----------------------------
# 7️⃣ Prompt template
# -----------------------------
template = """
Use the following context to answer the question. If unsure, say "I don't know."
Context:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
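# Example of what the rendered prompt looks like (placeholder values):
# print(prompt.format(context="<retrieved chunks>", question="<user question>"))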
# -----------------------------
# 8️⃣ Build the RetrievalQA chain
# -----------------------------
rag_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)
# -----------------------------
# 9️⃣ Conversation memory
# -----------------------------
memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True,
    output_key="answer"
)
# -----------------------------
# 10️⃣ Conversational RAG chain
# -----------------------------
# Built for multi-turn use; the Gradio UI below currently only calls rag_chain
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    verbose=False
)
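# qa_chain keeps chat history in `memory`, so follow-up questions can refer
# back to earlier turns. A minimal multi-turn sketch (example questions only):
# qa_chain.invoke({"question": "What is a watermark in Flink?"})
# qa_chain.invoke({"question": "How does it interact with event time?"})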
# -----------------------------
# 11️⃣ Gradio Q&A function
# -----------------------------
def answer_question(query, threshold=0.4):
    # FAISS has no built-in score-threshold filter here, so retrieve TOP_K
    # documents and filter them manually afterwards
    result = rag_chain.invoke({"query": query})
    answer = result["result"]
    sources = result.get("source_documents", [])
    # Embed the query once, outside the loop
    query_emb = np.asarray(embedding_model.embed_query(query))
    query_emb = query_emb / np.linalg.norm(query_emb)
    # Compute the cosine similarity per source document and apply the threshold
    filtered_sources = []
    for doc in sources:
        emb = np.asarray(embedding_model.embed_documents([doc.page_content])[0])
        emb = emb / np.linalg.norm(emb)
        score = float(np.dot(emb, query_emb))
        if score >= threshold:
            filtered_sources.append((doc.page_content, score))
    # Render the retrieved source documents
    context = "\n\n".join(f"Score: {score:.4f}\n{text[:400]}..." for text, score in filtered_sources)
    return answer, context
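# Example call (the threshold is the cosine-similarity cutoff defined above):
# answer, context = answer_question("What is exactly-once processing?", threshold=0.4)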
# -----------------------------
# 12️⃣ Gradio UI
# -----------------------------
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="🔎 Enter your question"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="Similarity threshold")
    ],
    outputs=[
        gr.Textbox(label="💬 Model answer"),
        gr.Textbox(label="📄 Retrieved documents")
    ],
    title="📘 Multi-PDF RAG System"
)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)