ZeroPal / rag_demo_v0.py
zjowowen's picture
init space
643bd7e
raw
history blame
No virus
5.25 kB
"""
参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
"""
# 导入必要的库与模块
import os
import textwrap
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
# 环境设置与文档下载
load_dotenv() # 加载环境变量
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
# 确保 OPENAI_API_KEY 被正确设置
if not OPENAI_API_KEY:
raise ValueError("OpenAI API Key not found in the environment variables.")
# 文档加载与分割
def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
"""加载文档并分割成小块"""
loader = TextLoader(file_path)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = text_splitter.split_documents(documents)
return chunks
# 向量存储建立
def create_vector_store(chunks, model="OpenAI"):
"""将文档块转换为向量并存储到 Weaviate 中"""
client = Client(embedded_options=EmbeddedOptions())
embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None # 可以根据需要替换为其他嵌入模型
vectorstore = Weaviate.from_documents(
client=client,
documents=chunks,
embedding=embedding_model,
by_text=False
)
return vectorstore.as_retriever()
# 定义检索增强生成流程
def setup_rag_chain_v0(retriever, model_name="gpt-4", temperature=0):
"""设置检索增强生成流程"""
prompt_template = """You are an assistant for question-answering tasks.
Use your knowledge to answer the question if the provided context is not relevant.
Otherwise, use the context to inform your answer.
Question: {question}
Context: {context}
Answer:
"""
prompt = ChatPromptTemplate.from_template(prompt_template)
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
# 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
rag_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| prompt
| llm
| StrOutputParser()
)
return rag_chain
# 执行查询并打印结果
def execute_query_v0(rag_chain, query):
"""执行查询并返回结果"""
return rag_chain.invoke(query)
# 执行无 RAG 链的查询
def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
"""执行无 RAG 链的查询"""
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
response = llm.invoke(query)
return response.content
# rag_demo.py 相对 rag_demo_v0.py 的不同之处在于可以输出检索到的文档块。
if __name__ == "__main__":
# 下载并保存文档到本地(这里被注释掉了,因为已经假设文档存在于本地)
# url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt"
# res = requests.get(url)
# with open("state_of_the_union.txt", "w") as f:
# f.write(res.text)
# 假设文档已存在于本地
# file_path = './documents/state_of_the_union.txt'
file_path = './documents/LightZero_README.zh.md'
# 加载和分割文档
chunks = load_and_split_document(file_path)
# 创建向量存储
retriever = create_vector_store(chunks)
# 设置 RAG 流程
rag_chain = setup_rag_chain_v0(retriever)
# 提出问题并获取答案
# query = "请你分别用中英文简介 LightZero"
# query = "请你用英文简介 LightZero"
query = "请你用中文简介 LightZero"
# query = "请问 LightZero 支持哪些环境和算法,应该如何快速上手使用?"
# query = "请问 LightZero 里面实现的 MuZero 算法支持在 Atari 环境上运行吗?"
# query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因"
# query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例"
# 使用 RAG 链获取答案
result_with_rag = execute_query_v0(rag_chain, query)
# 不使用 RAG 链获取答案
result_without_rag = execute_query_no_rag(query=query)
# 打印并对比两种方法的结果
# 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)
# 打印自动分段后的文本
print("="*40)
print(f"我的问题是:\n{query}")
print("="*40)
print(f"Result with RAG:\n{wrapped_result_with_rag}")
print("="*40)
print(f"Result without RAG:\n{wrapped_result_without_rag}")