souljoy commited on
Commit
3c22f61
1 Parent(s): b22f057

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -72
app.py CHANGED
@@ -27,102 +27,100 @@ def doc_emb(doc): # 文档向量化
27
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
28
 
29
 
30
- def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings):
31
- now_len = len(msg)
32
- his_bg = -1
33
- for i in range(len(bot) - 1, -1, -1):
34
- if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len:
35
  break
36
- now_len += len(bot[i][0]) + len(bot[i][1])
37
- his_bg = i
38
- history = [] if his_bg == -1 else bot[his_bg:]
39
- query_embedding = embedder.encode([msg])
40
- cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
41
- score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])]
42
- score_index.sort(key=lambda x: x[0], reverse=True)
43
  print('score_index:\n', score_index)
44
- index_set, sub_doc_list = set(), []
45
- for s_i in score_index:
46
- doc = doc_text_list[s_i[1]]
47
- if now_len + len(doc) > all_max_len:
48
  break
49
- index_set.add(s_i[1])
50
- now_len += len(doc)
51
  # 可能段落截断错误,所以把上下段也加入进来
52
- if s_i[1] > 0 and s_i[1] - 1 not in index_set:
53
- doc = doc_text_list[s_i[1] - 1]
54
- if now_len + len(doc) > all_max_len:
55
  break
56
- index_set.add(s_i[1] - 1)
57
- now_len += len(doc)
58
- if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set:
59
- doc = doc_text_list[s_i[1] + 1]
60
- if now_len + len(doc) > all_max_len:
61
  break
62
- index_set.add(s_i[1] + 1)
63
- now_len += len(doc)
64
 
65
- index_list = list(index_set)
66
- index_list.sort()
67
- for i in index_list:
68
- sub_doc_list.append(doc_text_list[i])
69
- document = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list)
70
  messages = [{
71
  "role": "system",
72
  "content": "你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。"
73
- }, {"role": "system", "content": "文章内容:\n" + document}]
74
- for his in history:
75
- messages.append({"role": "user", "content": his[0]})
76
- messages.append({"role": "assistant", "content": his[1]})
77
- messages.append({"role": "user", "content": msg})
78
- req_json = {'messages': messages, 'key': open_ai_key, 'model': "gpt-3.5-turbo"}
79
- data = {"content": json.dumps(req_json)}
80
  print('data:\n', req_json)
81
  result = requests.post(url=chat_url,
82
  data=json.dumps(data),
83
  headers=headers
84
- )
85
- res = result.json()['content']
86
- bot.append([msg, res])
87
- return bot[max(0, len(bot) - 3):]
88
 
89
 
90
- def up_file(files):
91
- doc_text_list = []
92
- for idx, file in enumerate(files):
93
  print(file.name)
94
- with pdfplumber.open(file.name) as pdf:
95
- for i in range(len(pdf.pages)):
96
  # 读取PDF文档第i+1页
97
  page = pdf.pages[i]
98
- res_list = page.extract_text().split('\n')[:-1]
99
-
100
- for j in range(len(page.images)):
101
  # 获取图片的二进制流
102
  img = page.images[j]
103
- file_name = '{}-{}-{}.png'.format(str(time.time()), str(i), str(j))
104
- with open(file_name, mode='wb') as f:
105
  f.write(img['stream'].get_data())
106
  try:
107
- res = ocr.ocr(file_name)
108
  except Exception as e:
109
- res = []
110
- if len(res) > 0:
111
- res_list.append(' '.join([re['text'] for re in res]))
112
 
113
- tables = page.extract_tables()
114
- for table in tables:
115
  # 第一列当成表头:
116
  df = pd.DataFrame(table[1:], columns=table[0])
117
  try:
118
- records = json.loads(df.to_json(orient="records", force_ascii=False))
119
- for rec in records:
120
- res_list.append(json.dumps(rec, ensure_ascii=False))
121
  except Exception as e:
122
- res_list.append(str(df))
123
-
124
- doc_text_list += res_list
125
- doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0]
126
  print(doc_text_list)
127
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
128
  visible=True), gr.Markdown.update(
@@ -144,11 +142,11 @@ with gr.Blocks() as demo:
144
  chat_bot = gr.Chatbot(visible=False) # 聊天机器人
145
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False) # 消息框
146
  with gr.Row():
147
- chat_bu = gr.Button(value='发送', visible=False)
148
 
149
- file.change(up_file, [file], [txt, doc_bu, md])
150
- doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot])
151
- chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot])
152
 
153
  if __name__ == "__main__":
154
- demo.queue().launch()
 
27
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
28
 
29
 
30
+ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获取机器人回复
31
+ now_len = len(msg) # 当前输入的长度
32
+ his_bg = -1 # 历史记录的起始位置
33
+ for i in range(len(bot) - 1, -1, -1): # 从后往前遍历历史记录
34
+ if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len: # 如果超过了历史记录的最大长度,就不再加入
35
  break
36
+ now_len += len(bot[i][0]) + len(bot[i][1]) # 更新当前长度
37
+ his_bg = i # 更新历史记录的起始位置
38
+ history = [] if his_bg == -1 else bot[his_bg:] # 获取历史记录
39
+ query_embedding = embedder.encode([msg]) # 输入向量化
40
+ cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0] # 计算相似度
41
+ score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])] # 相似度和索引对应
42
+ score_index.sort(key=lambda x: x[0], reverse=True) # 按相似度排序
43
  print('score_index:\n', score_index)
44
+ index_set, sub_doc_list = set(), [] # 用于存储最终的索引和文档
45
+ for s_i in score_index: # 遍历相似度和索引对应
46
+ doc = doc_text_list[s_i[1]] # 获取文档
47
+ if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
48
  break
49
+ index_set.add(s_i[1]) # 加入索引
50
+ now_len += len(doc) # 更新当前长度
51
  # 可能段落截断错误,所以把上下段也加入进来
52
+ if s_i[1] > 0 and s_i[1] - 1 not in index_set: # 如果上一段没有加入
53
+ doc = doc_text_list[s_i[1] - 1] # 获取上一段
54
+ if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
55
  break
56
+ index_set.add(s_i[1] - 1) # 加入索引
57
+ now_len += len(doc) # 更新当前长度
58
+ if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set: # 如果下一段没有加入
59
+ doc = doc_text_list[s_i[1] + 1] # 获取下一段
60
+ if now_len + len(doc) > all_max_len: # 如果超过了最大长度,就不再加入
61
  break
62
+ index_set.add(s_i[1] + 1) # 加入索引
63
+ now_len += len(doc) # 更新当前长度
64
 
65
+ index_list = list(index_set) # 转换成list
66
+ index_list.sort() # 排序
67
+ for i in index_list: # 遍历索引
68
+ sub_doc_list.append(doc_text_list[i]) # 加入文档
69
+ document = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list) # 拼接文档
70
  messages = [{
71
  "role": "system",
72
  "content": "你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。"
73
+ }, {"role": "system", "content": "文章内容:\n" + document}] # 角色人物定义
74
+ for his in history: # ���历历史记录
75
+ messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
76
+ messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
77
+ messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
78
+ req_json = {'messages': messages, 'key': open_ai_key, 'model': "gpt-3.5-turbo"} # 请求json
79
+ data = {"content": json.dumps(req_json)} # 请求data
80
  print('data:\n', req_json)
81
  result = requests.post(url=chat_url,
82
  data=json.dumps(data),
83
  headers=headers
84
+ ) # 请求
85
+ res = result.json()['content'] # 获取回复
86
+ bot.append([msg, res]) # 加入历史记录
87
+ return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
88
 
89
 
90
+ def up_file(files): # 上传文件
91
+ doc_text_list = [] # 用于存储文档
92
+ for idx, file in enumerate(files): # 遍历文件
93
  print(file.name)
94
+ with pdfplumber.open(file.name) as pdf: # 打开pdf
95
+ for i in range(len(pdf.pages)): # 遍历pdf的每一页
96
  # 读取PDF文档第i+1页
97
  page = pdf.pages[i]
98
+ res_list = page.extract_text().split('\n')[:-1] # 提取文本
99
+ for j in range(len(page.images)): # 遍历图片
 
100
  # 获取图片的二进制流
101
  img = page.images[j]
102
+ file_name = '{}-{}-{}.png'.format(str(time.time()), str(i), str(j)) # 生成文件名
103
+ with open(file_name, mode='wb') as f: # 保存图片
104
  f.write(img['stream'].get_data())
105
  try:
106
+ res = ocr.ocr(file_name) # 识别图片
107
  except Exception as e:
108
+ res = [] # 识别失败
109
+ if len(res) > 0: # 如果识别成功
110
+ res_list.append(' '.join([re['text'] for re in res])) # 加入识别结果
111
 
112
+ tables = page.extract_tables() # 提取表格
113
+ for table in tables: # 遍历表格
114
  # 第一列当成表头:
115
  df = pd.DataFrame(table[1:], columns=table[0])
116
  try:
117
+ records = json.loads(df.to_json(orient="records", force_ascii=False)) # 转换成json
118
+ for rec in records: # 遍历json
119
+ res_list.append(json.dumps(rec, ensure_ascii=False)) # 加入json
120
  except Exception as e:
121
+ res_list.append(str(df)) # 如果转换识别,直接把表格转为str
122
+ doc_text_list += res_list # 加入文档
123
+ doc_text_list = [str(text).strip() for text in doc_text_list if len(str(text).strip()) > 0] # 去除空格
 
124
  print(doc_text_list)
125
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
126
  visible=True), gr.Markdown.update(
 
142
  chat_bot = gr.Chatbot(visible=False) # 聊天机器人
143
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False) # 消息框
144
  with gr.Row():
145
+ chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
146
 
147
+ file.change(up_file, [file], [txt, doc_bu, md]) # 上传文件
148
+ doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot]) # 提交解析结果
149
+ chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot]) # 发送消息
150
 
151
  if __name__ == "__main__":
152
+ demo.queue().launch()