zjowowen commited on
Commit
6aae17d
1 Parent(s): e370e35

update LightZero RAG

Browse files
Files changed (10) hide show
  1. .gitignore +0 -2
  2. README.md +3 -4
  3. README_zh.md +2 -3
  4. app.py +90 -58
  5. app_mqa.py +132 -0
  6. app_qa.py +106 -0
  7. assets/avatar.png +0 -0
  8. rag_demo.py +199 -37
  9. rag_demo_v0.py +0 -136
  10. requirements.txt +1 -0
.gitignore DELETED
@@ -1,2 +0,0 @@
1
- .env
2
- *bkp.py
 
 
 
README.md CHANGED
@@ -56,7 +56,6 @@ QUESTION_LANG='cn' # The language of the question, currently available option is
56
 
57
  ```python
58
 
59
- # The difference between rag_demo.py and rag_demo_v0.py is that it can output the retrieved document chunks.
60
  if __name__ == "__main__":
61
  # Assuming documents are already present locally
62
  file_path = './documents/LightZero_README.zh.md'
@@ -91,9 +90,9 @@ if __name__ == "__main__":
91
  ```
92
  RAG/
93
 
94
- ├── rag_demo_v0.py # RAG demonstration script without support for outputting retrieved document chunks.
95
  ├── rag_demo.py # RAG demonstration script with support for outputting retrieved document chunks.
96
- ├── app.py # Web-based interactive application built with Gradio and rag_demo.py.
 
97
  ├── .env # Environment variable configuration file
98
  └── documents/ # Documents folder
99
  └── your_document.txt # Context document
@@ -114,4 +113,4 @@ If you encounter any issues or require assistance, please submit a problem throu
114
 
115
  ## License
116
 
117
- All code in this repository is compliant with [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
 
56
 
57
  ```python
58
 
 
59
  if __name__ == "__main__":
60
  # Assuming documents are already present locally
61
  file_path = './documents/LightZero_README.zh.md'
 
90
  ```
91
  RAG/
92
 
 
93
  ├── rag_demo.py # RAG demonstration script with support for outputting retrieved document chunks.
94
+ ├── app_qa.py # Web-based interactive application built with Gradio and rag_demo.py.
95
+ ├── app_mqa.py # Web-based interactive application built with Gradio and rag_demo.py. Supports maintaining conversation history.
96
  ├── .env # Environment variable configuration file
97
  └── documents/ # Documents folder
98
  └── your_document.txt # Context document
 
113
 
114
  ## License
115
 
116
+ All code in this repository is compliant with [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
README_zh.md CHANGED
@@ -43,7 +43,6 @@ QUESTION_LANG='cn' # 问题语言,目前可选值为 'cn'
43
 
44
  ```python
45
 
46
- # rag_demo.py 相对 rag_demo_v0.py 的不同之处在于可以输出检索到的文档块。
47
  if __name__ == "__main__":
48
  # 假设文档已存在于本地
49
  file_path = './documents/LightZero_README.zh.md'
@@ -78,9 +77,9 @@ if __name__ == "__main__":
78
  ```
79
  RAG/
80
 
81
- ├── rag_demo_v0.py # RAG 演示脚本,不支持输出检索到的文档块。
82
  ├── rag_demo.py # RAG 演示脚本,支持输出检索到的文档块。
83
- ├── app.py # 基于 Gradio 和 rag_demo.py 构建的网页交互式应用。
 
84
  ├── .env # 环境变量配置文件
85
  └── documents/ # 文档文件夹
86
  └── your_document.txt # 上下文文档
 
43
 
44
  ```python
45
 
 
46
  if __name__ == "__main__":
47
  # 假设文档已存在于本地
48
  file_path = './documents/LightZero_README.zh.md'
 
77
  ```
78
  RAG/
79
 
 
80
  ├── rag_demo.py # RAG 演示脚本,支持输出检索到的文档块。
81
+ ├── app_qa.py # 基于 Gradio 和 rag_demo.py 构建的网页交互式应用。
82
+ ├── app_mqa.py # 基于 Gradio 和 rag_demo.py 构建的网页交互式应用。支持保持对话历史。
83
  ├── .env # 环境变量配置文件
84
  └── documents/ # 文档文件夹
85
  └── your_document.txt # 上下文文档
app.py CHANGED
@@ -1,16 +1,3 @@
1
- """
2
- 这段代码的整体功能是创建一个Gradio应用,用户可以在其中输入问题,应用会使用Retrieval-Augmented Generation (RAG)模型来寻找答案并将结果显示在界面上。
3
- 其中,检索到的上下文会在Markdown文档中高亮显示,帮助用户理解答案的来源。应用界面分为两部分:顶部是问答区,底部展示了RAG模型参考的上下文。
4
-
5
- 结构概述:
6
- - 导入必要的库和函数。
7
- - 设置环境变量和全局变量。
8
- - 加载和处理Markdown文档。
9
- - 定义处理用户问题并返回答案和高亮显示上下文的函数。
10
- - 使用Gradio构建用户界面,包括Markdown、输入框、按钮和输出框。
11
- - 启动Gradio应用并设置为可以分享。
12
- """
13
-
14
  import os
15
 
16
  import gradio as gr
@@ -22,7 +9,6 @@ from rag_demo import load_and_split_document, create_vector_store, setup_rag_cha
22
  # 环境设置
23
  load_dotenv() # 加载环境变量
24
  QUESTION_LANG = os.getenv("QUESTION_LANG") # 从环境变量获取 QUESTION_LANG
25
-
26
  assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
27
 
28
  if QUESTION_LANG == "cn":
@@ -31,8 +17,8 @@ if QUESTION_LANG == "cn":
31
  <div align="center">
32
  <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
33
  </div>
34
- <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> 🎭LightZero RAG Demo</a></h2>
35
- <h4 align="center"> 📢说明:请您在下面的"问题"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答"框中会显示 RAG 模型给出的回答。在QA栏的下方会给出参考文档(检索得到的 context 用黄色高亮显示)。</h4>
36
  <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
37
  <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
38
  """
@@ -47,55 +33,101 @@ if QUESTION_LANG == "cn":
47
 
48
  # 路径变量,方便之后的文件使用
49
  file_path = './documents/LightZero_README.zh.md'
50
- chunks = load_and_split_document(file_path)
51
- retriever = create_vector_store(chunks)
52
- # rag_chain = setup_rag_chain(model_name="gpt-4")
53
- rag_chain = setup_rag_chain(model_name="gpt-3.5-turbo")
54
 
55
  # 加载原始Markdown文档
56
  loader = TextLoader(file_path)
57
  orig_documents = loader.load()
58
 
 
 
 
59
 
60
- def rag_answer(question):
61
- retrieved_documents, answer = execute_query(retriever, rag_chain, question)
62
- # Highlight the context in the document
63
- context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
64
- highlighted_document = orig_documents[0].page_content
65
- for i in range(len(context)):
66
- highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  return answer, highlighted_document
68
 
69
- """
70
- 在下面的代码中,gr.Blocks构建了Gradio的界面布局,gr.Textbox用于创建文本输入框,gr.Button创建了一个按钮,gr.Markdown则用于显示Markdown格式的内容。
71
- gr_submit.click是一个事件处理器,当用户点击提交按钮时,它会调用rag_answer函数,并将输入和输出的组件关联起来。
72
- 代码中的rag_answer函数负责接收用户的问题,使用RAG模型检索和生成答案,并将检索到的文本段落在Markdown原文中高亮显示。
73
- 该函数返回模型生成的答案和高亮显示上下文的Markdown文本。
74
- """
75
- with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
76
- gr.Markdown(title_markdown)
77
-
78
- with gr.Row():
79
- with gr.Column():
80
- inputs = gr.Textbox(
81
- placeholder="请您输入任何关于 LightZero 的问题。",
82
- label="问题 (Q)") # 设置输出框,包括答案和高亮显示参考文档
83
- gr_submit = gr.Button('提交')
84
-
85
- outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
86
- label="回答 (A)")
87
- with gr.Row():
88
- # placeholder="当你点击提交按钮后,这里会显示参考的文档,其中检索得到的与问题最相关的 context 用高亮显示。"
89
- outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
90
-
91
- gr.Markdown(tos_markdown)
92
-
93
- gr_submit.click(
94
- rag_answer,
95
- inputs=inputs,
96
- outputs=[outputs_answer, outputs_context],
97
- )
98
 
99
  if __name__ == "__main__":
100
- # 启动界面,设置为可以分享。如果分享公网链接失败,可以在本地执行 ngrok http 7860 将本地端口映射到公网
101
- rag_demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
  import gradio as gr
 
9
  # 环境设置
10
  load_dotenv() # 加载环境变量
11
  QUESTION_LANG = os.getenv("QUESTION_LANG") # 从环境变量获取 QUESTION_LANG
 
12
  assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
13
 
14
  if QUESTION_LANG == "cn":
 
17
  <div align="center">
18
  <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
19
  </div>
20
+ <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
21
+ <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
22
  <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
23
  <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
24
  """
 
33
 
34
  # 路径变量,方便之后的文件使用
35
  file_path = './documents/LightZero_README.zh.md'
 
 
 
 
36
 
37
  # 加载原始Markdown文档
38
  loader = TextLoader(file_path)
39
  orig_documents = loader.load()
40
 
41
+ # 存储对话历史
42
+ conversation_history = []
43
+
44
 
45
+ def rag_answer(question, model_name, temperature, embedding_model, k):
46
+ """
47
+ 处理用户问题并返回答案和高亮显示的上下文
48
+
49
+ :param question: 用户输入的问题
50
+ :param model_name: 使用的语言模型名称
51
+ :param temperature: 生成答案时使用的温度参数
52
+ :param embedding_model: 使用的嵌入模型
53
+ :param k: 检索到的文档块数量
54
+ :return: 模型生成的答案和高亮显示上下文的Markdown文本
55
+ """
56
+ try:
57
+ chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
58
+ retriever = create_vector_store(chunks, model=embedding_model, k=k)
59
+ rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
60
+
61
+ # 将问题添加到对话历史中
62
+ conversation_history.append(("User", question))
63
+
64
+ # 将对话历史转换为字符串
65
+ history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
66
+
67
+ retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name=model_name,
68
+ temperature=temperature)
69
+ # 在文档中高亮显示上下文
70
+ context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
71
+ highlighted_document = orig_documents[0].page_content
72
+ for i in range(len(context)):
73
+ highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
74
+
75
+ # 将回答添加到���话历史中
76
+ conversation_history.append(("Assistant", answer))
77
+ except Exception as e:
78
+ print(f"An error occurred: {e}")
79
+ return "处理您的问题时出现错误,请稍后再试。", ""
80
  return answer, highlighted_document
81
 
82
+
83
+ def clear_context():
84
+ """
85
+ 清除对话历史
86
+ """
87
+ global conversation_history
88
+ conversation_history = []
89
+ return "", ""
90
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  if __name__ == "__main__":
93
+ with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
94
+ gr.Markdown(title_markdown)
95
+
96
+ with gr.Row():
97
+ with gr.Column():
98
+ inputs = gr.Textbox(
99
+ placeholder="请您输入任何关于 LightZero 的问题。",
100
+ label="问题 (Q)")
101
+ model_name = gr.Dropdown(
102
+ choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
103
+ # value='azure_gpt-4',
104
+ value='kimi',
105
+ label="选择语言模型")
106
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
107
+ embedding_model = gr.Dropdown(
108
+ choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
109
+ value='OpenAI',
110
+ label="选择嵌入模型")
111
+ k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
112
+ with gr.Row():
113
+ gr_submit = gr.Button('提交')
114
+ gr_clear = gr.Button('清除上下文')
115
+
116
+ outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
117
+ label="回答 (A)")
118
+ with gr.Row():
119
+ outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
120
+
121
+ gr.Markdown(tos_markdown)
122
+
123
+ gr_submit.click(
124
+ rag_answer,
125
+ inputs=[inputs, model_name, temperature, embedding_model, k],
126
+ outputs=[outputs_answer, outputs_context],
127
+ )
128
+ gr_clear.click(clear_context, outputs=[outputs_answer, outputs_context])
129
+
130
+ concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
131
+ favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
132
+ rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
133
+
app_mqa.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ from dotenv import load_dotenv
5
+ from langchain.document_loaders import TextLoader
6
+
7
+ from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query
8
+
9
+ # 环境设置
10
+ load_dotenv() # 加载环境变量
11
+ QUESTION_LANG = os.getenv("QUESTION_LANG") # 从环境变量获取 QUESTION_LANG
12
+ assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
13
+
14
+ if QUESTION_LANG == "cn":
15
+ title = "LightZero RAG Demo"
16
+ title_markdown = """
17
+ <div align="center">
18
+ <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
19
+ </div>
20
+ <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
21
+ <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
22
+ <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
23
+ <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
24
+ """
25
+ tos_markdown = """
26
+ ### 使用条款
27
+ 玩家使用本服务须同意以下条款:
28
+ 该服务是一项探索性研究预览版,仅供非商业用途。它仅提供有限的安全措施,并可能生成令人反感的内容。不得将其用于任何非法、有害、暴力、种族主义等目的。
29
+ 如果您的游玩体验有不佳之处,请发送邮件至 opendilab@pjlab.org.cn ! 我们将删除相关信息,并不断改进这个项目。
30
+ 为了获得最佳体验,请使用台式电脑,因为移动设备可能会影响可视化效果。
31
+ **版权所有 2024 OpenDILab。**
32
+ """
33
+
34
+ # 路径变量,方便之后的文件使用
35
+ file_path = './documents/LightZero_README.zh.md'
36
+
37
+ # 加载原始Markdown文档
38
+ loader = TextLoader(file_path)
39
+ orig_documents = loader.load()
40
+
41
+ # 存储对话历史
42
+ conversation_history = []
43
+
44
+
45
+ def rag_answer(question, model_name, temperature, embedding_model, k):
46
+ """
47
+ 处理用户问题并返回答案和高亮显示的上下文
48
+
49
+ :param question: 用户输入的问题
50
+ :param model_name: 使用的语言模型名称
51
+ :param temperature: 生成答案时使用的温度参数
52
+ :param embedding_model: 使用的嵌入模型
53
+ :param k: 检索到的文档块数量
54
+ :return: 模型生成的答案和高亮显示上下文的Markdown文本
55
+ """
56
+ try:
57
+ chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
58
+ retriever = create_vector_store(chunks, model=embedding_model, k=k)
59
+ rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
60
+
61
+ # 将问题添加到对话历史中
62
+ conversation_history.append(("User", question))
63
+
64
+ # 将对话历史转换为字符串
65
+ history_str = "\n".join([f"{role}: {text}" for role, text in conversation_history])
66
+
67
+ retrieved_documents, answer = execute_query(retriever, rag_chain, history_str, model_name=model_name,
68
+ temperature=temperature)
69
+ # 在文档中高亮显示上下文
70
+ context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
71
+ highlighted_document = orig_documents[0].page_content
72
+ for i in range(len(context)):
73
+ highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
74
+
75
+ # 将回答添加到对话历史中
76
+ conversation_history.append(("Assistant", answer))
77
+ except Exception as e:
78
+ print(f"An error occurred: {e}")
79
+ return "处理您的问题时出现错误,请稍后再试。", ""
80
+ return answer, highlighted_document
81
+
82
+
83
+ def clear_context():
84
+ """
85
+ 清除对话历史
86
+ """
87
+ global conversation_history
88
+ conversation_history = []
89
+ return "", ""
90
+
91
+
92
+ if __name__ == "__main__":
93
+ with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
94
+ gr.Markdown(title_markdown)
95
+
96
+ with gr.Row():
97
+ with gr.Column():
98
+ inputs = gr.Textbox(
99
+ placeholder="请您输入任何关于 LightZero 的问题。",
100
+ label="问题 (Q)")
101
+ model_name = gr.Dropdown(
102
+ choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
103
+ # value='azure_gpt-4',
104
+ value='kimi',
105
+ label="选择语言模型")
106
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
107
+ embedding_model = gr.Dropdown(
108
+ choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
109
+ value='OpenAI',
110
+ label="选择嵌入模型")
111
+ k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
112
+ with gr.Row():
113
+ gr_submit = gr.Button('提交')
114
+ gr_clear = gr.Button('清除上下文')
115
+
116
+ outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
117
+ label="回答 (A)")
118
+ with gr.Row():
119
+ outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
120
+
121
+ gr.Markdown(tos_markdown)
122
+
123
+ gr_submit.click(
124
+ rag_answer,
125
+ inputs=[inputs, model_name, temperature, embedding_model, k],
126
+ outputs=[outputs_answer, outputs_context],
127
+ )
128
+ gr_clear.click(clear_context, outputs=[outputs_answer, outputs_context])
129
+
130
+ concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
131
+ favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
132
+ rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
app_qa.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ from dotenv import load_dotenv
5
+ from langchain.document_loaders import TextLoader
6
+
7
+ from rag_demo import load_and_split_document, create_vector_store, setup_rag_chain, execute_query
8
+
9
+ # 环境设置
10
+ load_dotenv() # 加载环境变量
11
+ QUESTION_LANG = os.getenv("QUESTION_LANG") # 从环境变量获取 QUESTION_LANG
12
+ assert QUESTION_LANG in ['cn', 'en'], QUESTION_LANG
13
+
14
+ if QUESTION_LANG == "cn":
15
+ title = "LightZero RAG Demo"
16
+ title_markdown = """
17
+ <div align="center">
18
+ <img src="https://raw.githubusercontent.com/puyuan1996/RAG/main/assets/banner.svg" width="80%" height="20%" alt="Banner Image">
19
+ </div>
20
+ <h2 style="text-align: center; color: black;"><a href="https://github.com/puyuan1996/RAG"> LightZero RAG Demo</a></h2>
21
+ <h4 align="center"> 📢说明:请您在下面的"问题(Q)"框中输入任何关于 LightZero 的问题,然后点击"提交"按钮。右侧"回答(A)"框中会显示 RAG 模型给出的回答。在 QA 栏的下方会给出参考文档(其中检索得到的相关文段会用黄色高亮显示)。</h4>
22
+ <h4 align="center"> 如果你喜欢这个项目,请给我们在 GitHub 点个 star ✨ 。我们将会持续保持更新。 </h4>
23
+ <strong><h5 align="center">注意:算法模型的输出可能包含一定的随机性。相关结果不代表任何开发者和相关 AI 服务的态度和意见。本项目开发者不对生成结果作任何保证,仅供参考。<h5></strong>
24
+ """
25
+ tos_markdown = """
26
+ ### 使用条款
27
+ 玩家使用本服务须同意以下条款:
28
+ 该服务是一项探索性研究预览版,仅供非商业用途。它仅提供有限的安全措施,并可能生成令人反感的内容。不得将其用于任何非法、有害、暴力、种族主义等目的。
29
+ 如果您的游玩体验有不佳之处,请发送邮件至 opendilab@pjlab.org.cn ! 我们将删除相关信息,并不断改进这个项目。
30
+ 为了获得最佳体验,请使用台式电脑,因为移动设备可能会影响可视化效果。
31
+ **版权所有 2024 OpenDILab。**
32
+ """
33
+
34
+ # 路径变量,方便之后的文件使用
35
+ file_path = './documents/LightZero_README.zh.md'
36
+
37
+ # 加载原始Markdown文档
38
+ loader = TextLoader(file_path)
39
+ orig_documents = loader.load()
40
+
41
+ def rag_answer(question, model_name, temperature, embedding_model, k):
42
+ """
43
+ 处理用户问题并返回答案和高亮显示的上下文
44
+
45
+ :param question: 用户输入的问题
46
+ :param model_name: 使用的语言模型名称
47
+ :param temperature: 生成答案时使用的温度参数
48
+ :param embedding_model: 使用的嵌入模型
49
+ :param k: 检索到的文档块数量
50
+ :return: 模型生成的答案和高亮显示上下文的Markdown文本
51
+ """
52
+ try:
53
+ chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
54
+ retriever = create_vector_store(chunks, model=embedding_model, k=k)
55
+ rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
56
+
57
+ retrieved_documents, answer = execute_query(retriever, rag_chain, question, model_name=model_name, temperature=temperature)
58
+ # 在文档中高亮显示上下文
59
+ context = [retrieved_documents[i].page_content for i in range(len(retrieved_documents))]
60
+ highlighted_document = orig_documents[0].page_content
61
+ for i in range(len(context)):
62
+ highlighted_document = highlighted_document.replace(context[i], f"<mark>{context[i]}</mark>")
63
+ except Exception as e:
64
+ print(f"An error occurred: {e}")
65
+ return "处理您的问题时出现错误,请稍后再试。", ""
66
+ return answer, highlighted_document
67
+
68
+
69
+ if __name__ == "__main__":
70
+ with gr.Blocks(title=title, theme='ParityError/Interstellar') as rag_demo:
71
+ gr.Markdown(title_markdown)
72
+
73
+ with gr.Row():
74
+ with gr.Column():
75
+ inputs = gr.Textbox(
76
+ placeholder="请您输入任何关于 LightZero 的问题。",
77
+ label="问题 (Q)")
78
+ model_name = gr.Dropdown(
79
+ choices=['kimi', 'abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'],
80
+ # value='azure_gpt-4',
81
+ value='kimi',
82
+ label="选择语言模型")
83
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.01, step=0.01, label="温度参数")
84
+ embedding_model = gr.Dropdown(
85
+ choices=['HuggingFace', 'TensorflowHub', 'OpenAI'],
86
+ value='OpenAI',
87
+ label="选择嵌入模型")
88
+ k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="检索到的文档块数量")
89
+ gr_submit = gr.Button('提交')
90
+
91
+ outputs_answer = gr.Textbox(placeholder="当你点击提交按钮后,这里会显示 RAG 模型给出的回答。",
92
+ label="回答 (A)")
93
+ with gr.Row():
94
+ outputs_context = gr.Markdown(label="参考的文档,检索得到的 context 用高亮显示 (C)")
95
+
96
+ gr.Markdown(tos_markdown)
97
+
98
+ gr_submit.click(
99
+ rag_answer,
100
+ inputs=[inputs, model_name, temperature, embedding_model, k],
101
+ outputs=[outputs_answer, outputs_context],
102
+ )
103
+
104
+ concurrency = int(os.environ.get('CONCURRENCY', os.cpu_count()))
105
+ favicon_path = os.path.join(os.path.dirname(__file__), 'assets', 'avatar.png')
106
+ rag_demo.queue().launch(max_threads=concurrency, favicon_path=favicon_path, share=True)
assets/avatar.png ADDED
rag_demo.py CHANGED
@@ -2,24 +2,34 @@
2
  参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
3
  """
4
  # 导入必要的库与模块
 
5
  import os
6
  import textwrap
7
 
 
8
  from dotenv import load_dotenv
9
  from langchain.chat_models import ChatOpenAI
10
  from langchain.document_loaders import TextLoader
11
- from langchain.embeddings import OpenAIEmbeddings
12
  from langchain.prompts import ChatPromptTemplate
13
  from langchain.schema.output_parser import StrOutputParser
14
- from langchain.schema.runnable import RunnablePassthrough
15
  from langchain.text_splitter import CharacterTextSplitter
16
  from langchain.vectorstores import Weaviate
17
  from weaviate import Client
18
  from weaviate.embedded import EmbeddedOptions
 
 
19
 
20
  # 环境设置与文档下载
21
  load_dotenv() # 加载环境变量
22
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
 
 
 
 
 
 
 
23
 
24
  # 确保 OPENAI_API_KEY 被正确设置
25
  if not OPENAI_API_KEY:
@@ -37,79 +47,231 @@ def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
37
 
38
 
39
  # 向量存储建立
40
- def create_vector_store(chunks, model="OpenAI"):
41
  """将文档块转换为向量并存储到 Weaviate 中"""
42
  client = Client(embedded_options=EmbeddedOptions())
43
- embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None # 可以根据需要替换为其他嵌入模型
 
 
 
 
 
 
 
 
 
44
  vectorstore = Weaviate.from_documents(
45
  client=client,
46
  documents=chunks,
47
  embedding=embedding_model,
48
  by_text=False
49
  )
50
- return vectorstore.as_retriever()
51
 
52
 
53
- # 定义检索增强生成流程
54
  def setup_rag_chain(model_name="gpt-4", temperature=0):
55
  """设置检索增强生成流程"""
56
- prompt_template = """You are an assistant for question-answering tasks.
57
- Use your knowledge to answer the question if the provided context is not relevant.
58
- Otherwise, use the context to inform your answer.
59
- Question: {question}
60
- Context: {context}
61
- Answer:
62
- """
63
- prompt = ChatPromptTemplate.from_template(prompt_template)
64
- llm = ChatOpenAI(model_name=model_name, temperature=temperature)
65
- # 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
66
- rag_chain = (
67
- prompt
68
- | llm
69
- | StrOutputParser()
70
- )
 
 
 
 
 
 
71
  return rag_chain
72
 
73
 
74
  # 执行查询并打印结果
75
- def execute_query(retriever, rag_chain, query):
76
- """执行查询并返回结果及检索到的文档块"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  retrieved_documents = retriever.invoke(query)
78
- rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
79
- return retrieved_documents, rag_chain_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
- # 执行无 RAG 链的查询
83
  def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
84
  """执行无 RAG 链的查询"""
85
- llm = ChatOpenAI(model_name=model_name, temperature=temperature)
86
- response = llm.invoke(query)
87
- return response.content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
- # rag_demo.py 相对 rag_demo_v0.py 的不同之处在于可以输出检索到的文档块。
91
  if __name__ == "__main__":
92
  # 假设文档已存在于本地
93
  file_path = './documents/LightZero_README.zh.md'
 
 
 
 
 
94
 
95
  # 加载和分割文档
96
- chunks = load_and_split_document(file_path)
97
 
98
  # 创建向量存储
99
- retriever = create_vector_store(chunks)
100
 
101
  # 设置 RAG 流程
102
- rag_chain = setup_rag_chain()
103
 
104
  # 提出问题并获取答案
105
- query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因"
106
- # query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例"
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  # 使用 RAG 链获取参考的文档与答案
109
- retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query)
 
110
 
111
  # 不使用 RAG 链获取答案
112
- result_without_rag = execute_query_no_rag(query=query)
113
 
114
  # 打印并对比两种方法的结果
115
  # 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
 
2
  参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
3
  """
4
  # 导入必要的库与模块
5
+ import json
6
  import os
7
  import textwrap
8
 
9
+ import requests
10
  from dotenv import load_dotenv
11
  from langchain.chat_models import ChatOpenAI
12
  from langchain.document_loaders import TextLoader
13
+ from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, TensorflowHubEmbeddings
14
  from langchain.prompts import ChatPromptTemplate
15
  from langchain.schema.output_parser import StrOutputParser
 
16
  from langchain.text_splitter import CharacterTextSplitter
17
  from langchain.vectorstores import Weaviate
18
  from weaviate import Client
19
  from weaviate.embedded import EmbeddedOptions
20
+ from zhipuai import ZhipuAI
21
+ from openai import AzureOpenAI
22
 
23
  # 环境设置与文档下载
24
  load_dotenv() # 加载环境变量
25
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
26
+ MIMIMAX_API_KEY = os.getenv("MIMIMAX_API_KEY")
27
+ MIMIMAX_GROUP_ID = os.getenv("MIMIMAX_GROUP_ID")
28
+ ZHIPUAI_API_KEY = os.getenv("ZHIPUAI_API_KEY")
29
+ KIMI_OPENAI_API_KEY = os.getenv("KIMI_OPENAI_API_KEY")
30
+
31
+ AZURE_OPENAI_KEY = os.getenv("AZURE_OPENAI_KEY")
32
+ AZURE_ENDPOINT = os.getenv("AZURE_ENDPOINT")
33
 
34
  # 确保 OPENAI_API_KEY 被正确设置
35
  if not OPENAI_API_KEY:
 
47
 
48
 
49
  # 向量存储建立
50
+ def create_vector_store(chunks, model="OpenAI", k=4):
51
  """将文档块转换为向量并存储到 Weaviate 中"""
52
  client = Client(embedded_options=EmbeddedOptions())
53
+
54
+ if model == "OpenAI":
55
+ embedding_model = OpenAIEmbeddings()
56
+ elif model == "HuggingFace":
57
+ embedding_model = HuggingFaceEmbeddings()
58
+ elif model == "TensorflowHub":
59
+ embedding_model = TensorflowHubEmbeddings()
60
+ else:
61
+ raise ValueError(f"Unsupported embedding model: {model}")
62
+
63
  vectorstore = Weaviate.from_documents(
64
  client=client,
65
  documents=chunks,
66
  embedding=embedding_model,
67
  by_text=False
68
  )
69
+ return vectorstore.as_retriever(search_kwargs={'k': k})
70
 
71
 
 
72
  def setup_rag_chain(model_name="gpt-4", temperature=0):
73
  """设置检索增强生成流程"""
74
+ if model_name.startswith("gpt"):
75
+ # 如果是以gpt开头的模型,使用原来的逻辑
76
+ prompt_template = """您是一个用于问答任务的专业助手。
77
+ 在处理问答任务时,请根据所提供的[上下文信息]给出回答。
78
+ 如果[上下文信息]与[问题]不相关,那么请运用您的知识库为提问者提供准确的答复。
79
+ 请确保回答内容的质量, 包括相关性、准确性和可读性。
80
+ [问题]: {question}
81
+ [上下文信息]: {context}
82
+ [回答]:
83
+ """
84
+ prompt = ChatPromptTemplate.from_template(prompt_template)
85
+ llm = ChatOpenAI(model_name=model_name, temperature=temperature)
86
+ # 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
87
+ rag_chain = (
88
+ prompt
89
+ | llm
90
+ | StrOutputParser()
91
+ )
92
+ else:
93
+ # 如果不是以gpt开头的模型,返回None
94
+ rag_chain = None
95
  return rag_chain
96
 
97
 
98
  # 执行查询并打印结果
99
+ def execute_query(retriever, rag_chain, query, model_name="gpt-4", temperature=0):
100
+ """
101
+ 执行查询并返回结果及检索到的文档块
102
+
103
+ 参数:
104
+ retriever: 文档检索器对象
105
+ rag_chain: 检索增强生成链对象,如果为None则不使用RAG链
106
+ query: 查询问题
107
+ model_name: 使用的语言模型名称,默认为"gpt-4"
108
+ temperature: 生成温度,默认为0
109
+
110
+ 返回:
111
+ retrieved_documents: 检索到的文档块列表
112
+ response_text: 生成的回答文本
113
+ """
114
+ # 使用检索器检索相关文档块
115
  retrieved_documents = retriever.invoke(query)
116
+
117
+ if rag_chain is not None:
118
+ # 如果有RAG链,则使用RAG链生成回答
119
+ rag_chain_response = rag_chain.invoke({"context": retrieved_documents, "question": query})
120
+ response_text = rag_chain_response
121
+ else:
122
+ # 如果没有RAG链,则将检索到的文档块和查询问题按照指定格式输入给语言模型
123
+ if model_name == "kimi":
124
+ # 对于有检索能力的模型,使用不同的模板
125
+ prompt_template = """您是一个用于问答任务的专业助手。
126
+ 在处理问答任务时,请根据所提供的【上下文信息】和【你的知识库和检索到的相关文档】给出回答。
127
+ 请确保回答内容的质量,包括相关性、准确性和可读性。
128
+ 【问题】: {question}
129
+ 【上下文信息】: {context}
130
+ 【回答】:
131
+ """
132
+ else:
133
+ prompt_template = """您是一个用于问答任务的专业助手。
134
+ 在处理问答任务时,请根据所提供的【上下文信息】给出回答。
135
+ 如果【上下文信息】与【问题】不相关,那么请运用您的知识库为提问者提供准确的答复。
136
+ 请确保回答内容的质量,包括相关性、准确性和可读性。
137
+ 【问题】: {question}
138
+ 【上下文信息】: {context}
139
+ 【回答】:
140
+ """
141
+
142
+ context = '\n'.join(
143
+ [f'**Document {i}**: ' + retrieved_documents[i].page_content for i in range(len(retrieved_documents))])
144
+ prompt = prompt_template.format(question=query, context=context)
145
+ response_text = execute_query_no_rag(model_name=model_name, temperature=temperature, query=prompt)
146
+ return retrieved_documents, response_text
147
 
148
 
 
149
  def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
150
  """执行无 RAG 链的查询"""
151
+ if model_name.startswith("gpt"):
152
+ # 如果是以gpt开头的模型,使用原来的逻辑
153
+ llm = ChatOpenAI(model_name=model_name, temperature=temperature)
154
+ response = llm.invoke(query)
155
+ return response.content
156
+ elif model_name.startswith("azure_gpt"):
157
+ client = AzureOpenAI(
158
+ azure_endpoint=AZURE_ENDPOINT,
159
+ api_key=AZURE_OPENAI_KEY,
160
+ api_version="2024-02-15-preview"
161
+ )
162
+ message_text = [{"role": "user", "content": query}, ]
163
+ completion = client.chat.completions.create(
164
+ model=model_name[6:], # model_name = 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo'
165
+ messages=message_text,
166
+ temperature=temperature,
167
+ top_p=0.95,
168
+ frequency_penalty=0,
169
+ presence_penalty=0,
170
+ stop=None
171
+ )
172
+ return completion.choices[0].message.content
173
+ elif model_name == 'abab6-chat':
174
+ # 如果是'abab6-chat'模型,使用专门的API调用方式
175
+ url = "https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId=" + MIMIMAX_GROUP_ID
176
+ headers = {"Content-Type": "application/json", "Authorization": "Bearer " + MIMIMAX_API_KEY}
177
+ payload = {
178
+ "bot_setting": [
179
+ {
180
+ "bot_name": "MM智能助理",
181
+ "content": "MM智能助理是一款由MiniMax自研的,没有调用其他产品的接口的大型语言模型。MiniMax是一家中国科技公司,一直致力于进行大模型相关的研究。",
182
+ }
183
+ ],
184
+ "messages": [{"sender_type": "USER", "sender_name": "小明", "text": query}],
185
+ "reply_constraints": {"sender_type": "BOT", "sender_name": "MM智能助理"},
186
+ "model": model_name,
187
+ "tokens_to_generate": 1034,
188
+ "temperature": temperature,
189
+ "top_p": 0.9,
190
+ }
191
+
192
+ response = requests.request("POST", url, headers=headers, json=payload)
193
+ # 将 JSON 字符串解析为字典
194
+ response_dict = json.loads(response.text)
195
+ # 提取 'reply' 键对应的值
196
+ return response_dict['reply']
197
+
198
+ elif model_name == 'glm-4':
199
+ # 如果是'glm-4'模型,使用专门的API调用方式
200
+ client = ZhipuAI(api_key=ZHIPUAI_API_KEY) # 填写您自己的APIKey
201
+ response = client.chat.completions.create(
202
+ model=model_name, # 填写需要调用的模型名称
203
+ messages=[{"role": "user", "content": query}]
204
+ )
205
+ return response.choices[0].message.content
206
+ elif model_name == 'kimi':
207
+ # 如果是'kimi'模型,使用专门的API调用方式
208
+ from openai import OpenAI
209
+ client = OpenAI(
210
+ api_key=KIMI_OPENAI_API_KEY,
211
+ base_url="https://api.moonshot.cn/v1",
212
+ )
213
+ messages = [
214
+ {
215
+ "role": "system",
216
+ "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
217
+ },
218
+ {"role": "user",
219
+ "content": query},
220
+ ]
221
+ completion = client.chat.completions.create(
222
+ model="moonshot-v1-128k",
223
+ messages=messages,
224
+ temperature=0.01,
225
+ top_p=1.0,
226
+ n=1, # 为每条输入消息生成多少个结果
227
+ stream=False # 流式输出
228
+ )
229
+ return completion.choices[0].message.content
230
+ else:
231
+ # 如果模型不支持,抛出异常
232
+ raise ValueError(f"Unsupported model: {model_name}")
233
 
234
 
 
235
  if __name__ == "__main__":
236
  # 假设文档已存在于本地
237
  file_path = './documents/LightZero_README.zh.md'
238
+ # model_name = "glm-4" # model_name=['abab6-chat', 'glm-4', 'gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo', 'azure_gpt-4', 'azure_gpt-35-turbo-16k', 'azure_gpt-35-turbo']
239
+ model_name = 'azure_gpt-4'
240
+ temperature = 0.01
241
+ # embedding_model = 'HuggingFace' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI']
242
+ embedding_model = 'OpenAI' # embedding_model=['HuggingFace', 'TensorflowHub', 'OpenAI']
243
 
244
  # 加载和分割文档
245
+ chunks = load_and_split_document(file_path, chunk_size=5000, chunk_overlap=500)
246
 
247
  # 创建向量存储
248
+ retriever = create_vector_store(chunks, model=embedding_model, k=5)
249
 
250
  # 设置 RAG 流程
251
+ rag_chain = setup_rag_chain(model_name=model_name, temperature=temperature)
252
 
253
  # 提出问题并获取答案
254
+ query = ("GitHub - opendilab/LightZero: [NeurIPS 2023 Spotlight] LightZero: A Unified Benchmark for Monte Carl 请根据这个仓库回答下面的问题:(1)请简要介绍一下 LightZero (2)请详细介绍 LightZero 的框架结构。 (3)请给出安装 LightZero,运行他们的示例代码的详细步骤 (4)- 请问 LightZero 具体支持什么任务(tasks/environments)? (5)请问 LightZero 具体支持什么算法?(6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行? (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?(8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?(9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?(10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。(11)请问对这个仓库提出详细的改进建议")
255
+ """
256
+ (1)请简要介绍一下 LightZero
257
+ (2)请详细介绍 LightZero 的框架结构。
258
+ (3)请给出安装 LightZero,运行他们的示例代码的详细步骤
259
+ (4)请问 LightZero 具体支持什么任务(tasks/environments)?
260
+ (5)请问 LightZero 具体支持什么算法?
261
+ (6)请问 LightZero 具体支持什么算法,各自支持在哪些任务上运行?
262
+ (7)请问 LightZero 里面实现的 MuZero 算法支持在 Atari 任务上运行吗?
263
+ (8)请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 任务上运行吗?
264
+ (9)LightZero 支持哪些算法? 各自的优缺点是什么? 我应该如何根据任务特点进行选择呢?
265
+ (10)请结合 LightZero 中的代码介绍他们是如何实现 MCTS 的。
266
+ (11)请问对这个仓库提出详细的改进建议。
267
+ """
268
 
269
  # 使用 RAG 链获取参考的文档与答案
270
+ retrieved_documents, result_with_rag = execute_query(retriever, rag_chain, query, model_name=model_name,
271
+ temperature=temperature)
272
 
273
  # 不使用 RAG 链获取答案
274
+ result_without_rag = execute_query_no_rag(model_name=model_name, query=query, temperature=temperature)
275
 
276
  # 打印并对比两种方法的结果
277
  # 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
rag_demo_v0.py DELETED
@@ -1,136 +0,0 @@
1
- """
2
- 参考博客:https://mp.weixin.qq.com/s/RUdZjQMSlVOfHfhErSNXnA
3
- """
4
- # 导入必要的库与模块
5
- import os
6
- import textwrap
7
-
8
- from dotenv import load_dotenv
9
- from langchain.chat_models import ChatOpenAI
10
- from langchain.document_loaders import TextLoader
11
- from langchain.embeddings import OpenAIEmbeddings
12
- from langchain.prompts import ChatPromptTemplate
13
- from langchain.schema.output_parser import StrOutputParser
14
- from langchain.schema.runnable import RunnablePassthrough
15
- from langchain.text_splitter import CharacterTextSplitter
16
- from langchain.vectorstores import Weaviate
17
- from weaviate import Client
18
- from weaviate.embedded import EmbeddedOptions
19
-
20
- # 环境设置与文档下载
21
- load_dotenv() # 加载环境变量
22
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # 从环境变量获取 OpenAI API 密钥
23
-
24
- # 确保 OPENAI_API_KEY 被正确设置
25
- if not OPENAI_API_KEY:
26
- raise ValueError("OpenAI API Key not found in the environment variables.")
27
-
28
-
29
- # 文档加载与分割
30
- def load_and_split_document(file_path, chunk_size=500, chunk_overlap=50):
31
- """加载文档并分割成小块"""
32
- loader = TextLoader(file_path)
33
- documents = loader.load()
34
- text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
35
- chunks = text_splitter.split_documents(documents)
36
- return chunks
37
-
38
-
39
- # 向量存储建立
40
- def create_vector_store(chunks, model="OpenAI"):
41
- """将文档块转换为向量并存储到 Weaviate 中"""
42
- client = Client(embedded_options=EmbeddedOptions())
43
- embedding_model = OpenAIEmbeddings() if model == "OpenAI" else None # 可以根据需要替换为其他嵌入模型
44
- vectorstore = Weaviate.from_documents(
45
- client=client,
46
- documents=chunks,
47
- embedding=embedding_model,
48
- by_text=False
49
- )
50
- return vectorstore.as_retriever()
51
-
52
-
53
- # 定义检索增强生成流程
54
- def setup_rag_chain_v0(retriever, model_name="gpt-4", temperature=0):
55
- """设置检索增强生成流程"""
56
- prompt_template = """You are an assistant for question-answering tasks.
57
- Use your knowledge to answer the question if the provided context is not relevant.
58
- Otherwise, use the context to inform your answer.
59
- Question: {question}
60
- Context: {context}
61
- Answer:
62
- """
63
- prompt = ChatPromptTemplate.from_template(prompt_template)
64
- llm = ChatOpenAI(model_name=model_name, temperature=temperature)
65
- # 创建 RAG 链,参考 https://python.langchain.com/docs/expression_language/
66
- rag_chain = (
67
- {"context": retriever, "question": RunnablePassthrough()}
68
- | prompt
69
- | llm
70
- | StrOutputParser()
71
- )
72
- return rag_chain
73
-
74
-
75
- # 执行查询并打印结果
76
- def execute_query_v0(rag_chain, query):
77
- """执行查询并返回结果"""
78
- return rag_chain.invoke(query)
79
-
80
-
81
- # 执行无 RAG 链的查询
82
- def execute_query_no_rag(model_name="gpt-4", temperature=0, query=""):
83
- """执行无 RAG 链的查询"""
84
- llm = ChatOpenAI(model_name=model_name, temperature=temperature)
85
- response = llm.invoke(query)
86
- return response.content
87
-
88
-
89
- # rag_demo.py 相对 rag_demo_v0.py 的不同之处在于可以输出检索到的文档块。
90
- if __name__ == "__main__":
91
- # 下载并保存文档到本地(这里被注释掉了,因为已经假设文档存在于本地)
92
- # url = "https://raw.githubusercontent.com/langchain-ai/langchain/master/docs/docs/modules/state_of_the_union.txt"
93
- # res = requests.get(url)
94
- # with open("state_of_the_union.txt", "w") as f:
95
- # f.write(res.text)
96
-
97
- # 假设文档已存在于本地
98
- # file_path = './documents/state_of_the_union.txt'
99
- file_path = './documents/LightZero_README.zh.md'
100
-
101
- # 加载和分割文档
102
- chunks = load_and_split_document(file_path)
103
-
104
- # 创建向量存储
105
- retriever = create_vector_store(chunks)
106
-
107
- # 设置 RAG 流程
108
- rag_chain = setup_rag_chain_v0(retriever)
109
-
110
- # 提出问题并获取答案
111
- # query = "请你分别用中英文简介 LightZero"
112
- # query = "请你用英文简介 LightZero"
113
- query = "请你用中文简介 LightZero"
114
- # query = "请问 LightZero 支持哪些环境和算法,应该如何快速上手使用?"
115
- # query = "请问 LightZero 里面实现的 MuZero 算法支持在 Atari 环境上运行吗?"
116
- # query = "请问 LightZero 里面实现的 AlphaZero 算法支持在 Atari 环境上运行吗?请详细解释原因"
117
- # query = "请详细解释 MCTS 算法的原理,并给出带有详细中文注释的 Python 代码示例"
118
-
119
- # 使用 RAG 链获取答案
120
- result_with_rag = execute_query_v0(rag_chain, query)
121
-
122
- # 不使用 RAG 链获取答案
123
- result_without_rag = execute_query_no_rag(query=query)
124
-
125
- # 打印并对比两种方法的结果
126
- # 使用textwrap.fill来自动分段文本,width参数可以根据你的屏幕宽度进行调整
127
- wrapped_result_with_rag = textwrap.fill(result_with_rag, width=80)
128
- wrapped_result_without_rag = textwrap.fill(result_without_rag, width=80)
129
-
130
- # 打印自动分段后的文本
131
- print("="*40)
132
- print(f"我的问题是:\n{query}")
133
- print("="*40)
134
- print(f"Result with RAG:\n{wrapped_result_with_rag}")
135
- print("="*40)
136
- print(f"Result without RAG:\n{wrapped_result_without_rag}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,3 +5,4 @@ weaviate-client
5
  requests
6
  python-dotenv
7
  tiktoken
 
 
5
  requests
6
  python-dotenv
7
  tiktoken
8
+ sentence-transformers