""" 参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA """ # 导入必要的库与模块 import os import textwrap from dotenv import load_dotenv from langchain.chat_models import ChatOpenAI from langchain.document_loaders import TextLoader from langchain.embeddings import OpenAIEmbeddings from langchain.prompts import ChatPromptTemplate from langchain.schema.output_parser import StrOutputParser from langchain.schema.runnable import RunnablePassthrough from langchain.text_splitter import CharacterTextSplitter from langchain.vectorstores import Weaviate from weaviate import Client from weaviate.embedded import EmbeddedOptions # 环境设置与文档下载 load_dotenv() # 加载环境变量 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥 # 确保 OPENAI_API_KEY 被正确设置 if not OPENAI_API_KEY: raise ValueError("OpenAI API Key not found in the environment variables.") # 文档加载与分割 def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50): """加载文档并分割成小块""" loader = TextLoader(file_path) documents = loader.load() text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap) chunks = text_splitter.split_documents(documents) return chunks # 向量存储建立 def create_vector_store(chunks, model="OpenAI"): """将文档块转换为向量并存储到 Weaviate 中""" client = Client(embedded_options=EmbeddedOptions()) embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None # 可以根据需要替换为其他嵌入模型 vectorstore = Weaviate.from_documents( client=client, documents=chunks, embedding=embedding_model, by_text=False ) return vectorstore.as_retriever() # 定义检索增强生成流程 def setup_rag_chain_v0(retriever, model_name="gpt-4", temperature=0): """设置检索增强生成流程""" prompt_template = """You are an assistant for question-answering tasks. Use your knowledge to answer the question if the provided context is not relevant. Otherwise, use the context to inform your answer. Question: {question} Context: {context} Answer: """ prompt = ChatPromptTemplate.from_template(prompt_template) llm = ChatOpenAI(model_name=model_name, temperature=temperature) # 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/ rag_chain = ( {"context": retriever, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser() ) return rag_chain # 执行查询并打印结果 def execute_query_v0(rag_chain, query): """执行查询并返回结果""" return rag_chain.invoke(query) # 执行无 RAG 链的查询 def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""): """执行无 RAG 链的查询""" llm = ChatOpenAI(model_name=model_name, temperature=temperature) response = llm.invoke(query) return response.content # rag_demo.py 相对 rag_demo_v0.py 的不同之处在于可以输出检索到的文档块。 if __name__ == "__main__": # 下载并保存文档到本地(这里被注释掉了,因为已经假设文档存在于本地) # url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt" # res = requests.get(url) # with open("state_of_the_union.txt", "w") as f: # f.write(res.text) # 假设文档已存在于本地 # file_path = './documents/state_of_the_union.txt' file_path = './documents/LightZero_README.zh.md' # 加载和分割文档 chunks = load_and_split_document(file_path) # 创建向量存储 retriever = create_vector_store(chunks) # 设置 RAG 流程 rag_chain = setup_rag_chain_v0(retriever) # 提出问题并获取答案 # query = "请你分别用中英文简介 LightZero" # query = "请你用英文简介 LightZero" query = "请你用中文简介 LightZero" # query = "请问 LightZero 支持哪些环境和算法,应该如何快速上手使用?" # query = "请问 LightZero 里面实现的 MuZero 算法支持在 Atari 环境上运行吗?" # query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因" # query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例" # 使用 RAG 链获取答案 result_with_rag = execute_query_v0(rag_chain, query) # 不使用 RAG 链获取答案 result_without_rag = execute_query_no_rag(query=query) # 打印并对比两种方法的结果 # 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整 wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80) wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80) # 打印自动分段后的文本 print("="*40) print(f"我的问题是:\n{query}") print("="*40) print(f"Result with RAG:\n{wrapped_result_with_rag}") print("="*40) print(f"Result without RAG:\n{wrapped_result_without_rag}")