souljoy commited on
Commit
296b38a
1 Parent(s): 3c22f61

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -24
app.py CHANGED
@@ -5,24 +5,35 @@ import pdfplumber
5
  import pandas as pd
6
  import time
7
  from cnocr import CnOcr
8
- from sentence_transformers import SentenceTransformer, models, util
9
 
10
- word_embedding_model = models.Transformer('uer/sbert-base-chinese-nli', do_lower_case=True) # BERT模型
11
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls') # 取cls向量作为句向量
12
- embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model]) # 定义模型
13
  ocr = CnOcr() # 初始化ocr模型
14
- chat_url = 'https://souljoy-my-api.hf.space/chatgpt' # 你的url
15
- headers = {
16
- 'Content-Type': 'application/json',
17
- } # 你的headers
18
  history_max_len = 500 # 机器人记忆的最大长度
19
  all_max_len = 3000 # 输入的最大长度
20
 
21
 
22
- def doc_emb(doc): # 文档向量化
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  texts = doc.split('\n') # 按行切分
24
- emb_list = embedder.encode(texts) # 句向量化
25
- print('\n'.join(texts))
 
26
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
27
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
28
 
@@ -36,9 +47,17 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获
36
  now_len += len(bot[i][0]) + len(bot[i][1]) # 更新当前长度
37
  his_bg = i # 更新历史记录的起始位置
38
  history = [] if his_bg == -1 else bot[his_bg:] # 获取历史记录
39
- query_embedding = embedder.encode([msg]) # 输入向量化
40
- cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0] # 计算相似度
41
- score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])] # 相似度和索引对应
 
 
 
 
 
 
 
 
42
  score_index.sort(key=lambda x: x[0], reverse=True) # 按相似度排序
43
  print('score_index:\n', score_index)
44
  index_set, sub_doc_list = set(), [] # 用于存储最终的索引和文档
@@ -75,14 +94,24 @@ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings): # 获
75
  messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
76
  messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
77
  messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
78
- req_json = {'messages': messages, 'key': open_ai_key, 'model': "gpt-3.5-turbo"} # 请求json
79
- data = {"content": json.dumps(req_json)} # 请求data
80
- print('data:\n', req_json)
81
- result = requests.post(url=chat_url,
 
 
 
 
 
 
 
 
 
 
82
  data=json.dumps(data),
83
  headers=headers
84
- ) # 请求
85
- res = result.json()['content'] # 获取回复
86
  bot.append([msg, res]) # 加入历史记录
87
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
88
 
@@ -124,7 +153,7 @@ def up_file(files): # 上传文件
124
  print(doc_text_list)
125
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
126
  visible=True), gr.Markdown.update(
127
- value="操作说明 step 2:确认PDF解析结果(可修正),点击“提交解析结果”,随后进行对话")
128
 
129
 
130
  with gr.Blocks() as demo:
@@ -134,7 +163,8 @@ with gr.Blocks() as demo:
134
  file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
135
  file_count='multiple') # 支持多文档、表格、OCR
136
  txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
137
- doc_bu = gr.Button(value='提交解析结果', visible=False) # 提交解析结果
 
138
  doc_text_state = gr.State([]) # 存储PDF解析结果
139
  doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
140
  with gr.Column():
@@ -144,8 +174,9 @@ with gr.Blocks() as demo:
144
  with gr.Row():
145
  chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
146
 
147
- file.change(up_file, [file], [txt, doc_bu, md]) # 上传文件
148
- doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot]) # 提交解析结果
 
149
  chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot]) # 发送消息
150
 
151
  if __name__ == "__main__":
 
5
  import pandas as pd
6
  import time
7
  from cnocr import CnOcr
8
+ import numpy as np
9
 
 
 
 
10
  ocr = CnOcr() # 初始化ocr模型
 
 
 
 
11
  history_max_len = 500 # 机器人记忆的最大长度
12
  all_max_len = 3000 # 输入的最大长度
13
 
14
 
15
+ def get_text_emb(open_ai_key, text):
16
+ url = 'https://api.openai.com/v1/embeddings'
17
+ headers = {
18
+ 'Content-Type': 'application/json',
19
+ 'Authorization': 'Bearer ' + open_ai_key
20
+ }
21
+ data = {
22
+ "model": "text-embedding-ada-002",
23
+ "input": text
24
+ }
25
+ result = requests.post(url=url,
26
+ data=json.dumps(data),
27
+ headers=headers
28
+ )
29
+ return result.json()['data'][0]['embedding']
30
+
31
+
32
+ def doc_index_self(open_ai_key, doc): # 文档向量化
33
  texts = doc.split('\n') # 按行切分
34
+ emb_list = []
35
+ for text in texts:
36
+ emb_list.append(get_text_emb(open_ai_key, text))
37
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
38
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
39
 
 
47
  now_len += len(bot[i][0]) + len(bot[i][1]) # 更新当前长度
48
  his_bg = i # 更新历史记录的起始位置
49
  history = [] if his_bg == -1 else bot[his_bg:] # 获取历史记录
50
+ query_embedding = get_text_emb(open_ai_key, msg) # 获取输入的向量
51
+ cos_scores = [] # 用于存储相似度
52
+
53
+ def cos_sim(a, b):
54
+ return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
55
+
56
+ for doc_embedding in doc_embeddings: # 遍历文档向量
57
+ cos_scores.append(cos_sim(query_embedding, doc_embedding)) # 计算相似度
58
+ score_index = [] # 用于存储相似度和索引对应
59
+ for i in range(len(cos_scores)): # 遍历相似度
60
+ score_index.append((cos_scores[i], i)) # 加入相似度和索引对应
61
  score_index.sort(key=lambda x: x[0], reverse=True) # 按相似度排序
62
  print('score_index:\n', score_index)
63
  index_set, sub_doc_list = set(), [] # 用于存储最终的索引和文档
 
94
  messages.append({"role": "user", "content": his[0]}) # 加入用户的历史记录
95
  messages.append({"role": "assistant", "content": his[1]}) # 加入机器人的历史记录
96
  messages.append({"role": "user", "content": msg}) # 加入用户的当前输入
97
+
98
+ url = 'https://api.openai.com/v1/chat/completions'
99
+
100
+ data = {
101
+ "model": "gpt-3.5-turbo",
102
+ "messages": messages
103
+ }
104
+ print("data = \n", data)
105
+
106
+ headers = {
107
+ 'Content-Type': 'application/json',
108
+ 'Authorization': 'Bearer ' + open_ai_key
109
+ }
110
+ result = requests.post(url=url,
111
  data=json.dumps(data),
112
  headers=headers
113
+ )
114
+ res = str(result.json()['choices'][0]['message']['content']).strip()
115
  bot.append([msg, res]) # 加入历史记录
116
  return bot[max(0, len(bot) - 3):] # 返回最近3轮的历史记录
117
 
 
153
  print(doc_text_list)
154
  return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
155
  visible=True), gr.Markdown.update(
156
+ value="操作说明 step 2:确认PDF解析结果(可修正),点击“建立索引”,随后进行对话")
157
 
158
 
159
  with gr.Blocks() as demo:
 
163
  file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
164
  file_count='multiple') # 支持多文档、表格、OCR
165
  txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
166
+ index_self_bu = gr.Button(value='建立索引(by self)', visible=False) #
167
+ index_llama_bu = gr.Button(value='建立索引(by llama_index)', visible=False) #
168
  doc_text_state = gr.State([]) # 存储PDF解析结果
169
  doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
170
  with gr.Column():
 
174
  with gr.Row():
175
  chat_bu = gr.Button(value='发送', visible=False) # 发送按钮
176
 
177
+ file.change(up_file, [file], [txt, index_self_bu, md]) # 上传文件
178
+ index_self_bu.click(doc_index_self, [txt],
179
+ [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot]) # 提交解析结果
180
  chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot]) # 发送消息
181
 
182
  if __name__ == "__main__":