ZeroPal / rag_demo.py
zjowowen's picture
update LightZero RAG
6aae17d
raw
history blame
13.8 kB
"""
参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
"""
# 导入必要的库与模块
import json
import os
import textwrap
import requests
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, TensorflowHubEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions
from zhipuai import ZhipuAI
from openai import AzureOpenAI
# 环境设置与文档下载
load_dotenv() # 加载环境变量
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
MIMIMAX_API_KEY = os.getenv("MIMIMAX_API_KEY")
MIMIMAX_GROUP_ID = os.getenv("MIMIMAX_GROUP_ID")
ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
KIMI_OPENAI_API_KEY = os.getenv("KIMI_OPENAI_API_KEY")
AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")
# 确保 OPENAI_API_KEY 被正确设置
if not OPENAI_API_KEY:
raise ValueError("OpenAI API Key not found in the environment variables.")
# 文档加载与分割
def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
"""加载文档并分割成小块"""
loader = TextLoader(file_path)
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunks = text_splitter.split_documents(documents)
return chunks
# 向量存储建立
def create_vector_store(chunks, model="OpenAI", k=4):
"""将文档块转换为向量并存储到 Weaviate 中"""
client = Client(embedded_options=EmbeddedOptions())
if model == "OpenAI":
embedding_model = OpenAIEmbeddings()
elif model == "HuggingFace":
embedding_model = HuggingFaceEmbeddings()
elif model == "TensorflowHub":
embedding_model = TensorflowHubEmbeddings()
else:
raise ValueError(f"Unsupported embedding model: {model}")
vectorstore = Weaviate.from_documents(
client=client,
documents=chunks,
embedding=embedding_model,
by_text=False
)
return vectorstore.as_retriever(search_kwargs={'k': k})
def setup_rag_chain(model_name="gpt-4", temperature=0):
"""设置检索增强生成流程"""
if model_name.startswith("gpt"):
# 如果是以gpt开头的模型,使用原来的逻辑
prompt_template = """您是一个用于问答任务的专业助手。
在处理问答任务时,请根据所提供的[上下文信息]给出回答。
如果[上下文信息]与[问题]不相关,那么请运用您的知识库为提问者提供准确的答复。
请确保回答内容的质量, 包括相关性、准确性和可读性。
[问题]: {question}
[上下文信息]: {context}
[回答]:
"""
prompt = ChatPromptTemplate.from_template(prompt_template)
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
# 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
rag_chain = (
prompt
| llm
| StrOutputParser()
)
else:
# 如果不是以gpt开头的模型,返回None
rag_chain = None
return rag_chain
# 执行查询并打印结果
def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0):
"""
执行查询并返回结果及检索到的文档块
参数:
retriever: 文档检索器对象
rag_chain: 检索增强生成链对象,如果为None则不使用RAG链
query: 查询问题
model_name: 使用的语言模型名称,默认为"gpt-4"
temperature: 生成温度,默认为0
返回:
retrieved_documents: 检索到的文档块列表
response_text: 生成的回答文本
"""
# 使用检索器检索相关文档块
retrieved_documents = retriever.invoke(query)
if rag_chain is not None:
# 如果有RAG链,则使用RAG链生成回答
rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
response_text = rag_chain_response
else:
# 如果没有RAG链,则将检索到的文档块和查询问题按照指定格式输入给语言模型
if model_name == "kimi":
# 对于有检索能力的模型,使用不同的模板
prompt_template = """您是一个用于问答任务的专业助手。
在处理问答任务时,请根据所提供的【上下文信息】和【你的知识库和检索到的相关文档】给出回答。
请确保回答内容的质量,包括相关性、准确性和可读性。
【问题】: {question}
【上下文信息】: {context}
【回答】:
"""
else:
prompt_template = """您是一个用于问答任务的专业助手。
在处理问答任务时,请根据所提供的【上下文信息】给出回答。
如果【上下文信息】与【问题】不相关,那么请运用您的知识库为提问者提供准确的答复。
请确保回答内容的质量,包括相关性、准确性和可读性。
【问题】: {question}
【上下文信息】: {context}
【回答】:
"""
context = '\n'.join(
[f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
prompt = prompt_template.format(question=query, context=context)
response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
return retrieved_documents, response_text
def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
"""执行无 RAG 链的查询"""
if model_name.startswith("gpt"):
# 如果是以gpt开头的模型,使用原来的逻辑
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
response = llm.invoke(query)
return response.content
elif model_name.startswith("azure_gpt"):
client = AzureOpenAI(
azure_endpoint=AZURE_ENDPOINT,
api_key=AZURE_OPENAI_KEY,
api_version="2024-02-15-preview"
)
message_text = [{"role": "user", "content": query}, ]
completion = client.chat.completions.create(
model=model_name[6:], # model_name = 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'
messages=message_text,
temperature=temperature,
top_p=0.95,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
return completion.choices[0].message.content
elif model_name == 'abab6-chat':
# 如果是'abab6-chat'模型,使用专门的API调用方式
url = "https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=" + MIMIMAX_GROUP_ID
headers = {"Content-Type": "application/json", "Authorization": "Bearer " + MIMIMAX_API_KEY}
payload = {
"bot_setting": [
{
"bot_name": "MM智能助理",
"content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
}
],
"messages": [{"sender_type": "USER", "sender_name": "小明", "text": query}],
"reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
"model": model_name,
"tokens_to_generate": 1034,
"temperature": temperature,
"top_p": 0.9,
}
response = requests.request("POST", url, headers=headers, json=payload)
# 将 JSON 字符串解析为字典
response_dict = json.loads(response.text)
# 提取 'reply' 键对应的值
return response_dict['reply']
elif model_name == 'glm-4':
# 如果是'glm-4'模型,使用专门的API调用方式
client = ZhipuAI(api_key=ZHIPUAI_API_KEY) # 填写您自己的APIKey
response = client.chat.completions.create(
model=model_name, # 填写需要调用的模型名称
messages=[{"role": "user", "content": query}]
)
return response.choices[0].message.content
elif model_name == 'kimi':
# 如果是'kimi'模型,使用专门的API调用方式
from openai import OpenAI
client = OpenAI(
api_key=KIMI_OPENAI_API_KEY,
base_url="https://api.moonshot.cn/v1",
)
messages = [
{
"role": "system",
"content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
},
{"role": "user",
"content": query},
]
completion = client.chat.completions.create(
model="moonshot-v1-128k",
messages=messages,
temperature=0.01,
top_p=1.0,
n=1, # 为每条输入消息生成多少个结果
stream=False # 流式输出
)
return completion.choices[0].message.content
else:
# 如果模型不支持,抛出异常
raise ValueError(f"Unsupported model: {model_name}")
if __name__ == "__main__":
# 假设文档已存在于本地
file_path = './documents/LightZero_README.zh.md'
# model_name = "glm-4" # model_name=['abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo']
model_name = 'azure_gpt-4'
temperature = 0.01
# embedding_model = 'HuggingFace' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI']
embedding_model = 'OpenAI' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI']
# 加载和分割文档
chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
# 创建向量存储
retriever = create_vector_store(chunks, model=embedding_model, k=5)
# 设置 RAG 流程
rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
# 提出问题并获取答案
query = ("GitHub - opendilab/LightZero: [NeurIPS 2023 Spotlight] LightZero: A Unified Benchmark for Monte Carl 请根据这个仓库回答下面的问题:(1)请简要介绍一下 LightZero (2)请详细介绍 LightZero 的框架结构。 (3)请给出安装 LightZero,运行他们的示例代码的详细步骤 (4)- 请问 LightZero 具体支持什么任务(tasks/environments)? (5)请问 LightZero 具体支持什么算法?(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。(11)请问对这个仓库提出详细的改进建议")
"""
(1)请简要介绍一下 LightZero
(2)请详细介绍 LightZero 的框架结构。
(3)请给出安装 LightZero,运行他们的示例代码的详细步骤
(4)请问 LightZero 具体支持什么任务(tasks/environments)?
(5)请问 LightZero 具体支持什么算法?
(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行?
(7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?
(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?
(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?
(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。
(11)请问对这个仓库提出详细的改进建议。
"""
# 使用 RAG 链获取参考的文档与答案
retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query, model_name=model_name,
temperature=temperature)
# 不使用 RAG 链获取答案
result_without_rag = execute_query_no_rag(model_name=model_name, query=query, temperature=temperature)
# 打印并对比两种方法的结果
# 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)
context = '\n'.join(
[f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
# 打印自动分段后的文本
print("=" * 40)
print(f"我的问题是:\n{query}")
print("=" * 40)
print(f"Result with RAG:\n{wrapped_result_with_rag}\n检索得到的context是: \n{context}")
print("=" * 40)
print(f"Result without RAG:\n{wrapped_result_without_rag}")
print("=" * 40)