souljoy commited on
Commit
9d36857
1 Parent(s): cfab94c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -43
app.py CHANGED
@@ -6,6 +6,8 @@ import pandas as pd
6
  import time
7
  from cnocr import CnOcr
8
  import numpy as np
 
 
9
 
10
  ocr = CnOcr() # 初始化ocr模型
11
  history_max_len = 500 # 机器人记忆的最大长度
@@ -13,22 +15,12 @@ all_max_len = 2000 # 输入的最大长度
13
 
14
 
15
  def get_text_emb(open_ai_key, text):
16
- url = 'https://api.openai.com/v1/embeddings'
17
- headers = {
18
- 'Content-Type': 'application/json',
19
- 'Authorization': 'Bearer ' + open_ai_key
20
- }
21
- data = {
22
- "model": "text-embedding-ada-002",
23
- "input": text
24
- }
25
- result = requests.post(url=url,
26
- data=json.dumps(data),
27
- headers=headers
28
- )
29
- if result.status_code != 200:
30
- raise Exception('API请求出错,状态码为:' + str(result.status_code) + ',错误信息为:' + result.json())
31
- return result.json()['data'][0]['embedding']
32
 
33
 
34
  def doc_index_self(open_ai_key, doc): # 文档向量化
@@ -37,10 +29,10 @@ def doc_index_self(open_ai_key, doc): # 文档向量化
37
  for text in texts:
38
  emb_list.append(get_text_emb(open_ai_key, text))
39
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
40
- value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
41
 
42
 
43
- def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获取机器人回复
44
  now_len = len(msg) # 当前输入的长度
45
  his_bg = -1 # 历史记录的起始位置
46
  for i in range(len(bot) - 1, -1, -1): # 从后往前遍历历史记录
@@ -96,29 +88,43 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获
96
  messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
97
  messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
98
  messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
 
 
 
 
 
 
99
 
100
- url = 'https://api.openai.com/v1/chat/completions'
101
-
102
- data = {
103
- "model": "gpt-3.5-turbo",
104
- "messages": messages
105
- }
106
- print("data = \n", data)
107
-
108
- headers = {
109
- 'Content-Type': 'application/json',
110
- 'Authorization': 'Bearer ' + open_ai_key
111
- }
112
- result = requests.post(url=url,
113
- data=json.dumps(data),
114
- headers=headers
115
- )
116
- print("result = \n", result.json())
117
- res = str(result.json()['choices'][0]['message']['content']).strip()
 
118
  bot.append([msg, res]) # 加入历史记录
119
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
120
 
121
 
 
 
 
 
 
 
 
122
  def up_file(files): # 上传文件
123
  doc_text_list = [] # 用于存储文档
124
  for idx, file in enumerate(files): # 遍历文件
@@ -155,10 +161,26 @@ def up_file(files): # 上传文件
155
  doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0] # 去除空格
156
  print(doc_text_list)
157
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
 
158
  visible=True), gr.Markdown.update(
159
  value="操作说明 step 2:确认PDF解析结果(可修正),点击“建立索引”,随后进行对话")
160
 
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  with gr.Blocks() as demo:
163
  with gr.Row():
164
  with gr.Column():
@@ -166,21 +188,27 @@ with gr.Blocks() as demo:
166
  file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
167
  file_count='multiple') # 支持多文档、表格、OCR
168
  txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
169
- index_self_bu = gr.Button(value='建立索引(by self)', visible=False) #
170
- index_llama_bu = gr.Button(value='建立索引(by llama_index)', visible=False) #
 
171
  doc_text_state = gr.State([]) # 存储PDF解析结果
172
  doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
 
 
173
  with gr.Column():
174
  md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
175
  chat_bot = gr.Chatbot(visible=False) # 聊天机器人
176
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False) # 消息框
177
- with gr.Row():
178
- chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
179
 
180
- file.change(up_file, [file], [txt, index_self_bu, md]) # 上传文件
181
  index_self_bu.click(doc_index_self, [open_ai_key, txt],
182
- [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot]) # 提交解析结果
183
- chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot]) # 发送消息
 
 
 
 
184
 
185
  if __name__ == "__main__":
186
  demo.queue().launch()
 
6
  import time
7
  from cnocr import CnOcr
8
  import numpy as np
9
+ import openai
10
+ from llama_index import GPTVectorStoreIndex, SimpleDirectoryReader, Prompt
11
 
12
  ocr = CnOcr() # 初始化ocr模型
13
  history_max_len = 500 # 机器人记忆的最大长度
 
15
 
16
 
17
  def get_text_emb(open_ai_key, text):
18
+ openai.api_key = open_ai_key
19
+ response = openai.Embedding.create(
20
+ input=text,
21
+ model="text-embedding-ada-002"
22
+ )
23
+ return response['data'][0]['embedding']
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def doc_index_self(open_ai_key, doc): # 文档向量化
 
29
  for text in texts:
30
  emb_list.append(get_text_emb(open_ai_key, text))
31
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
32
+ value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True), 1
33
 
34
 
35
+ def get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获取机器人回复
36
  now_len = len(msg) # 当前输入的长度
37
  his_bg = -1 # 历史记录的起始位置
38
  for i in range(len(bot) - 1, -1, -1): # 从后往前遍历历史记录
 
88
  messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
89
  messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
90
  messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
91
+ openai.api_key = open_ai_key
92
+ chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages) # 获取机器人的回复
93
+ res = chat_completion.choices[0].message.content # 获取机器人的回复
94
+ bot.append([msg, res]) # 加入历史记录
95
+ return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
96
+
97
 
98
+ def get_response_by_llama_index(open_ai_key, msg, bot, query_engine): # 获取机器人回复
99
+ openai.api_key = open_ai_key
100
+ template = (
101
+ "你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。文章内容如下: \n"
102
+ "---------------------\n"
103
+ "{context_str}"
104
+ "\n---------------------\n"
105
+ "{query_str}\n"
106
+ "请基于文章内容回答用户的问题。\n"
107
+ ) # 定义模板
108
+ query_str = "历史对话如下:\n"
109
+ for his in bot: # 遍历历史记录
110
+ query_str += "用户:" + his[0] + "\n" # 加入用户的历史记录
111
+ query_str += "机器人:" + his[1] + "\n" # 加入机器人的历史记录
112
+ query_str += "用户:" + msg + "\n" # 加入用户的当前输入
113
+ qa_template = Prompt(template) # 将模板转换成Prompt对象
114
+ query_engine = query_engine.as_query_engine(text_qa_template=qa_template) # 建立查询引擎
115
+ res = query_engine.query(msg) # 获取回答
116
+ print(res) # 显示回答
117
  bot.append([msg, res]) # 加入历史记录
118
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
119
 
120
 
121
+ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings, query_engine, index_type): # 获取机器人回复
122
+ if index_type == 1:
123
+ return get_response_by_self(open_ai_key, msg, bot, doc_text_list, doc_embeddings)
124
+ else:
125
+ return get_response_by_llama_index(open_ai_key, msg, bot, query_engine)
126
+
127
+
128
  def up_file(files): # 上传文件
129
  doc_text_list = [] # 用于存储文档
130
  for idx, file in enumerate(files): # 遍历文件
 
161
  doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0] # 去除空格
162
  print(doc_text_list)
163
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
164
+ visible=True), gr.Button.update(
165
  visible=True), gr.Markdown.update(
166
  value="操作说明 step 2:确认PDF解析结果(可修正),点击“建立索引”,随后进行对话")
167
 
168
 
169
+ def doc_index_llama(open_ai_key, txt): # 建立索引
170
+ # 根据时间戳新建目录,保存txt文件
171
+ path = str(time.time())
172
+ import os
173
+ os.mkdir(path)
174
+ with open(path + '/doc.txt', mode='w', encoding='utf-8') as f:
175
+ f.write(txt)
176
+ openai.api_key = open_ai_key # 设置OpenAI API Key
177
+ documents = SimpleDirectoryReader(path).load_data() # 读取文档
178
+ index = GPTVectorStoreIndex.from_documents(documents) # 建立索引
179
+ query_engine = index.as_query_engine() # 建立查询引擎
180
+ return query_engine, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
181
+ value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True), 0
182
+
183
+
184
  with gr.Blocks() as demo:
185
  with gr.Row():
186
  with gr.Column():
 
188
  file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
189
  file_count='multiple') # 支持多文档、表格、OCR
190
  txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
191
+ with gr.Row():
192
+ index_llama_bu = gr.Button(value='建立索引(by llama_index)', visible=False) # 建立索引(by llama_index)
193
+ index_self_bu = gr.Button(value='建立索引(by self)', visible=False) # 建立索引(by self)
194
  doc_text_state = gr.State([]) # 存储PDF解析结果
195
  doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
196
+ query_engine = gr.State([]) # 存储查询引擎
197
+ index_type = gr.State([]) # 存储索引类型
198
  with gr.Column():
199
  md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
200
  chat_bot = gr.Chatbot(visible=False) # 聊天机器人
201
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False) # 消息框
202
+ chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
 
203
 
204
+ file.change(up_file, [file], [txt, index_self_bu, index_llama_bu, md]) # 上传文件
205
  index_self_bu.click(doc_index_self, [open_ai_key, txt],
206
+ [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot, index_type]) # 提交解析结果
207
+ index_llama_bu.click(doc_index_llama, [open_ai_key, txt],
208
+ [query_engine, msg_txt, chat_bu, md, chat_bot, index_type]) # 提交解析结果
209
+ chat_bu.click(get_response,
210
+ [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state, query_engine, index_type],
211
+ [chat_bot]) # 发送消息
212
 
213
  if __name__ == "__main__":
214
  demo.queue().launch()