souljoy commited on
Commit
0296cb0
1 Parent(s): 9bb4c69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -54
app.py CHANGED
@@ -1,64 +1,41 @@
1
  import requests
2
  import json
3
  import gradio as gr
4
- # from concurrent.futures import ThreadPoolExecutor
5
  import pdfplumber
6
  import pandas as pd
7
  import time
8
  from cnocr import CnOcr
9
  from sentence_transformers import SentenceTransformer, models, util
10
- word_embedding_model = models.Transformer('uer/sbert-base-chinese-nli', do_lower_case=True)
11
- pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
12
- embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
13
- ocr = CnOcr()
14
- # chat_url = 'https://souljoy-my-api.hf.space/sale'
15
- chat_url = 'https://souljoy-my-api.hf.space/chatpdf'
16
  headers = {
17
  'Content-Type': 'application/json',
18
- }
19
- # thread_pool_executor = ThreadPoolExecutor(max_workers=4)
20
- history_max_len = 500
21
- all_max_len = 3000
22
-
23
-
24
- def get_emb(text):
25
- emb_url = 'https://souljoy-my-api.hf.space/embeddings'
26
- data = {"content": text}
27
- try:
28
- result = requests.post(url=emb_url,
29
- data=json.dumps(data),
30
- headers=headers
31
- )
32
- return result.json()['data'][0]['embedding']
33
- except Exception as e:
34
- print('data', data, 'result json', result.json())
35
 
36
 
37
- def doc_emb(doc: str):
38
- texts = doc.split('\n')
39
- # futures = []
40
- emb_list = embedder.encode(texts)
41
- # for text in texts:
42
- # futures.append(thread_pool_executor.submit(get_emb, text))
43
- # for f in futures:
44
- # emb_list.append(f.result())
45
  print('\n'.join(texts))
46
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
47
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
48
 
49
 
50
- def get_response(msg, bot, doc_text_list, doc_embeddings):
51
- # future = thread_pool_executor.submit(get_emb, msg)
52
  now_len = len(msg)
53
- req_json = {'question': msg}
54
  his_bg = -1
55
  for i in range(len(bot) - 1, -1, -1):
56
  if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len:
57
  break
58
  now_len += len(bot[i][0]) + len(bot[i][1])
59
  his_bg = i
60
- req_json['history'] = [] if his_bg == -1 else bot[his_bg:]
61
- # query_embedding = future.result()
62
  query_embedding = embedder.encode([msg])
63
  cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
64
  score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])]
@@ -72,24 +49,33 @@ def get_response(msg, bot, doc_text_list, doc_embeddings):
72
  index_set.add(s_i[1])
73
  now_len += len(doc)
74
  # 可能段落截断错误,所以把上下段也加入进来
75
- if s_i[1] > 0 and s_i[1] -1 not in index_set:
76
- doc = doc_text_list[s_i[1]-1]
77
  if now_len + len(doc) > all_max_len:
78
  break
79
- index_set.add(s_i[1]-1)
80
  now_len += len(doc)
81
  if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set:
82
- doc = doc_text_list[s_i[1]+1]
83
  if now_len + len(doc) > all_max_len:
84
  break
85
- index_set.add(s_i[1]+1)
86
  now_len += len(doc)
87
 
88
  index_list = list(index_set)
89
  index_list.sort()
90
  for i in index_list:
91
  sub_doc_list.append(doc_text_list[i])
92
- req_json['doc'] = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list)
 
 
 
 
 
 
 
 
 
93
  data = {"content": json.dumps(req_json)}
94
  print('data:\n', req_json)
95
  result = requests.post(url=chat_url,
@@ -146,21 +132,23 @@ def up_file(files):
146
  with gr.Blocks() as demo:
147
  with gr.Row():
148
  with gr.Column():
149
- file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)', file_count='multiple')
150
- doc_bu = gr.Button(value='提交解析结果', visible=False)
151
- txt = gr.Textbox(label='PDF解析结果', visible=False)
152
- doc_text_state = gr.State([])
153
- doc_emb_state = gr.State([])
 
 
154
  with gr.Column():
155
- md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""")
156
- chat_bot = gr.Chatbot(visible=False)
157
- msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
158
- chat_bu = gr.Button(value='发送', visible=False)
 
159
 
160
  file.change(up_file, [file], [txt, doc_bu, md])
161
  doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot])
162
- chat_bu.click(get_response, [msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot])
163
 
164
  if __name__ == "__main__":
165
  demo.queue().launch()
166
- # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)
 
1
  import requests
2
  import json
3
  import gradio as gr
 
4
  import pdfplumber
5
  import pandas as pd
6
  import time
7
  from cnocr import CnOcr
8
  from sentence_transformers import SentenceTransformer, models, util
9
+
10
+ word_embedding_model = models.Transformer('uer/sbert-base-chinese-nli', do_lower_case=True) # BERT模型
11
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls') # 取cls向量作为句向量
12
+ embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model]) # 定义模型
13
+ ocr = CnOcr() # 初始化ocr模型
14
+ chat_url = 'https://souljoy-my-api.hf.space/chatgpt' # 你的url
15
  headers = {
16
  'Content-Type': 'application/json',
17
+ } # 你的headers
18
+ history_max_len = 500 # 机器人记忆的最大长度
19
+ all_max_len = 3000 # 输入的最大长度
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
+ def doc_emb(doc): # 文档向量化
23
+ texts = doc.split('\n') # 按行切分
24
+ emb_list = embedder.encode(texts) # 句向量化
 
 
 
 
 
25
  print('\n'.join(texts))
26
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
27
  value="""操作说明 step 3:PDF解析提交成功! 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
28
 
29
 
30
+ def get_response(open_ai_key, msg, bot, doc_text_list, doc_embeddings):
 
31
  now_len = len(msg)
 
32
  his_bg = -1
33
  for i in range(len(bot) - 1, -1, -1):
34
  if now_len + len(bot[i][0]) + len(bot[i][1]) > history_max_len:
35
  break
36
  now_len += len(bot[i][0]) + len(bot[i][1])
37
  his_bg = i
38
+ history = [] if his_bg == -1 else bot[his_bg:]
 
39
  query_embedding = embedder.encode([msg])
40
  cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
41
  score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])]
 
49
  index_set.add(s_i[1])
50
  now_len += len(doc)
51
  # 可能段落截断错误,所以把上下段也加入进来
52
+ if s_i[1] > 0 and s_i[1] - 1 not in index_set:
53
+ doc = doc_text_list[s_i[1] - 1]
54
  if now_len + len(doc) > all_max_len:
55
  break
56
+ index_set.add(s_i[1] - 1)
57
  now_len += len(doc)
58
  if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set:
59
+ doc = doc_text_list[s_i[1] + 1]
60
  if now_len + len(doc) > all_max_len:
61
  break
62
+ index_set.add(s_i[1] + 1)
63
  now_len += len(doc)
64
 
65
  index_list = list(index_set)
66
  index_list.sort()
67
  for i in index_list:
68
  sub_doc_list.append(doc_text_list[i])
69
+ document = '' if len(sub_doc_list) == 0 else '\n'.join(sub_doc_list)
70
+ messages = [{
71
+ "role": "system",
72
+ "content": "你是一个有用的助手,可以使用文章内容准确地回答问题。使用提供的文章来生成你的答案,但避免逐字复制文章。尽可能使用自己的话。准确、有用、简洁、清晰。"
73
+ }, {"role": "system", "content": "文章内容:\n" + document}]
74
+ for his in history:
75
+ messages.append({"role": "user", "content": his[0]})
76
+ messages.append({"role": "assistant", "content": his[1]})
77
+ messages.append({"role": "user", "content": msg})
78
+ req_json = {'messages': messages, 'key': open_ai_key, 'model': "gpt-3.5-turbo"}
79
  data = {"content": json.dumps(req_json)}
80
  print('data:\n', req_json)
81
  result = requests.post(url=chat_url,
 
132
  with gr.Blocks() as demo:
133
  with gr.Row():
134
  with gr.Column():
135
+ open_ai_key = gr.Textbox(label='OpenAI API Key', placeholder='输入你的OpenAI API Key') # 你的OpenAI API Key
136
+ file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析(支持多文档、表格、OCR)',
137
+ file_count='multiple') # 支持多文档、表格、OCR
138
+ doc_bu = gr.Button(value='开始PDF解析', visible=False) # 开始PDF解析
139
+ txt = gr.Textbox(label='PDF解析结果', visible=False) # PDF解析结果
140
+ doc_text_state = gr.State([]) # 存储PDF解析结果
141
+ doc_emb_state = gr.State([]) # 存储PDF解析结果的embedding
142
  with gr.Column():
143
+ md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""") # 操作说明
144
+ chat_bot = gr.Chatbot(visible=False) # 聊天机器人
145
+ msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False) # 消息框
146
+ with gr.Row():
147
+ chat_bu = gr.Button(value='发送', visible=False)
148
 
149
  file.change(up_file, [file], [txt, doc_bu, md])
150
  doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot])
151
+ chat_bu.click(get_response, [open_ai_key, msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot])
152
 
153
  if __name__ == "__main__":
154
  demo.queue().launch()