"""
参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
"""
# 导入必要的库与模块
import os
import textwrap

from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Weaviate
from weaviate import Client
from weaviate.embedded import EmbeddedOptions

# Environment setup
load_dotenv()  # Load environment variables from a local .env file
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # Read the OpenAI API key from the environment

# Make sure OPENAI_API_KEY is set
if not OPENAI_API_KEY:
    raise ValueError("OpenAI API Key not found in the environment variables.")
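# The key is expected in a local .env file next to this script, e.g.:
#   OPENAI_API_KEY=sk-...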


# Document loading and splitting
def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
    """Load a document and split it into chunks."""
    loader = TextLoader(file_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_documents(documents)
    return chunks
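# Note: CharacterTextSplitter splits the text on a separator ("\n\n" by default)
# and then packs the pieces into chunks of roughly chunk_size characters, with
# chunk_overlap characters shared between consecutive chunks so that context is
# preserved across chunk boundaries.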


# Vector store creation
def create_vector_store(chunks, model="OpenAI"):
    """Embed the document chunks and store them in a Weaviate vector store."""
    client = Client(embedded_options=EmbeddedOptions())
    if model == "OpenAI":
        embedding_model = OpenAIEmbeddings()
    else:
        # Swap in another embedding model here as needed
        raise ValueError(f"Unsupported embedding model: {model}")
    vectorstore = Weaviate.from_documents(
        client=client,
        documents=chunks,
        embedding=embedding_model,
        by_text=False
    )
    return vectorstore.as_retriever()
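# Note (a sketch, not used above): as_retriever() also accepts search parameters,
# e.g. vectorstore.as_retriever(search_kwargs={"k": 4}) to control how many
# chunks are retrieved per query.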


# Define the retrieval-augmented generation (RAG) pipeline
def setup_rag_chain_v0(retriever, model_name="gpt-4", temperature=0):
    """Set up the retrieval-augmented generation chain."""
    prompt_template = """You are an assistant for question-answering tasks.
    Use the provided context to inform your answer.
    If the context is not relevant, fall back on your own knowledge.
    Question: {question}
    Context: {context}
    Answer:
    """
    prompt = ChatPromptTemplate.from_template(prompt_template)
    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    # Build the RAG chain; see https://python.langchain.com/docs/expression_language/
    rag_chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
    )
    return rag_chain
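# Note on the chain above: the leading dict runs its values in parallel on the
# input query. The retriever maps the query to relevant chunks ("context"),
# while RunnablePassthrough() forwards the raw query unchanged ("question");
# both are then filled into the prompt template before the LLM is called.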


# Execute a query and return the result
def execute_query_v0(rag_chain, query):
    """Run the query through the RAG chain and return the answer."""
    return rag_chain.invoke(query)


# Execute a query without the RAG chain
def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
    """Query the LLM directly, without retrieval."""
    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    response = llm.invoke(query)
    return response.content


# rag_demo.py differs from rag_demo_v0.py in that it also prints the retrieved document chunks.
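# A minimal sketch (an assumed helper, not part of this file) of how those
# retrieved chunks could be printed for inspection:
def print_retrieved_chunks(retriever, query):
    """Fetch and print the chunks the retriever would pass to the LLM."""
    for i, doc in enumerate(retriever.get_relevant_documents(query)):
        print(f"--- Chunk {i + 1} ---")
        print(doc.page_content)

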
if __name__ == "__main__":
    # Download and save the document locally (commented out: the file is assumed to already exist locally)
    # url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt"
    # res = requests.get(url)
    # with open("state_of_the_union.txt", "w") as f:
    #     f.write(res.text)

    # The document is assumed to already exist locally
    # file_path = './documents/state_of_the_union.txt'
    file_path = './documents/LightZero_README.zh.md'

    # Load and split the document
    chunks = load_and_split_document(file_path)

    # Create the vector store
    retriever = create_vector_store(chunks)

    # Set up the RAG pipeline
    rag_chain = setup_rag_chain_v0(retriever)

    # Ask a question and get the answer. The queries are kept in Chinese to match
    # the Chinese README being retrieved; English glosses are given in the comments.
    # query = "请你分别用中英文简介 LightZero"  # Introduce LightZero in both Chinese and English
    # query = "请你用英文简介 LightZero"  # Introduce LightZero in English
    query = "请你用中文简介 LightZero"  # Introduce LightZero in Chinese
    # query = "请问 LightZero 支持哪些环境和算法,应该如何快速上手使用?"  # Which environments and algorithms does LightZero support, and how do I get started quickly?
    # query = "请问 LightZero 里面实现的 MuZero 算法支持在 Atari 环境上运行吗?"  # Does the MuZero implementation in LightZero support running on Atari environments?
    # query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因"  # Does the AlphaZero implementation in LightZero support Atari? Explain why in detail
    # query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例"  # Explain how MCTS works and give a Python example with detailed Chinese comments

    # Get the answer with the RAG chain
    result_with_rag = execute_query_v0(rag_chain, query)

    # Get the answer without the RAG chain
    result_without_rag = execute_query_no_rag(query=query)

    # Print and compare the results of the two approaches.
    # textwrap.fill wraps the text automatically; adjust width to fit your screen.
    wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
    wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)

    # Print the wrapped text
    print("="*40)
    print(f"我的问题是:\n{query}")
    print("="*40)
    print(f"Result with RAG:\n{wrapped_result_with_rag}")
    print("="*40)
    print(f"Result without RAG:\n{wrapped_result_without_rag}")