souljoy commited on
Commit
122b2e0
1 Parent(s): 69afcbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -30
app.py CHANGED
@@ -1,44 +1,51 @@
1
  import requests
2
  import json
3
  import gradio as gr
4
- from concurrent.futures import ThreadPoolExecutor
5
- from sentence_transformers import util
 
 
 
 
 
6
 
7
  url = 'https://souljoy-my-api.hf.space/qa_maker'
8
  headers = {
9
  'Content-Type': 'application/json',
10
  }
11
- thread_pool_executor = ThreadPoolExecutor(max_workers=16)
12
  history_max_len = 500
13
- all_max_len = 2000
14
 
15
 
16
  def get_emb(text):
17
  emb_url = 'https://souljoy-my-api.hf.space/embeddings'
18
  data = {"content": text}
19
- result = requests.post(url=emb_url,
20
- data=json.dumps(data),
21
- headers=headers
22
- )
23
-
24
- return result.json()['data'][0]['embedding']
 
 
25
 
26
 
27
  def doc_emb(doc: str):
28
  texts = doc.split('\n')
29
- futures = []
30
- for text in texts:
31
- futures.append(thread_pool_executor.submit(get_emb, text))
32
- emb_list = []
33
- for f in futures:
34
- emb_list.append(f.result())
35
  print('\n'.join(texts))
36
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
37
- visible=True)
38
 
39
 
40
  def get_response(msg, bot, doc_text_list, doc_embeddings):
41
- future = thread_pool_executor.submit(get_emb, msg)
42
  now_len = len(msg)
43
  req_json = {'question': msg}
44
  his_bg = -1
@@ -48,18 +55,34 @@ def get_response(msg, bot, doc_text_list, doc_embeddings):
48
  now_len += len(bot[i][0]) + len(bot[i][1])
49
  his_bg = i
50
  req_json['history'] = [] if his_bg == -1 else bot[his_bg:]
51
- query_embedding = future.result()
 
52
  cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
53
  score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])]
54
  score_index.sort(key=lambda x: x[0], reverse=True)
55
  print('score_index:\n', score_index)
56
- index_list, sub_doc_list = [], []
57
  for s_i in score_index:
58
  doc = doc_text_list[s_i[1]]
59
  if now_len + len(doc) > all_max_len:
60
  break
61
- index_list.append(s_i[1])
62
  now_len += len(doc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  index_list.sort()
64
  for i in index_list:
65
  sub_doc_list.append(doc_text_list[i])
@@ -72,32 +95,56 @@ def get_response(msg, bot, doc_text_list, doc_embeddings):
72
  )
73
  res = result.json()['content']
74
  bot.append([msg, res])
75
- return bot[max(0, len(bot) - 3):], gr.Markdown.update(visible=False)
76
 
77
 
78
  def up_file(files):
 
79
  for idx, file in enumerate(files):
80
  print(file.name)
81
- return gr.Button.update(visible=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  with gr.Blocks() as demo:
85
  with gr.Row():
86
  with gr.Column():
87
- file = gr.File(file_types=['.pdf'], label='上传PDF')
 
88
  txt = gr.Textbox(label='PDF解析结果', visible=False)
89
- doc_bu = gr.Button(value='提交', visible=False)
90
- md = gr.Markdown("""#### 文档提交成功 🙋 """, visible=False)
91
  doc_text_state = gr.State([])
92
  doc_emb_state = gr.State([])
93
  with gr.Column():
94
- chat_bot = gr.Chatbot()
 
95
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
96
  chat_bu = gr.Button(value='发送', visible=False)
97
 
98
- doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md])
99
- chat_bu.click(get_response, [msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot, md])
100
- file.change(up_file, [file], [doc_bu])
 
101
  if __name__ == "__main__":
102
  demo.queue().launch()
103
  # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)
 
1
  import requests
2
  import json
3
  import gradio as gr
4
+ # from concurrent.futures import ThreadPoolExecutor
5
+ import pdfplumber
6
+ import pandas as pd
7
+ from sentence_transformers import SentenceTransformer, models, util
8
+ word_embedding_model = models.Transformer('uer/sbert-base-chinese-nli', do_lower_case=True)
9
+ pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='cls')
10
+ embedder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
11
 
12
  url = 'https://souljoy-my-api.hf.space/qa_maker'
13
  headers = {
14
  'Content-Type': 'application/json',
15
  }
16
+ # thread_pool_executor = ThreadPoolExecutor(max_workers=4)
17
  history_max_len = 500
18
+ all_max_len = 3000
19
 
20
 
21
  def get_emb(text):
22
  emb_url = 'https://souljoy-my-api.hf.space/embeddings'
23
  data = {"content": text}
24
+ try:
25
+ result = requests.post(url=emb_url,
26
+ data=json.dumps(data),
27
+ headers=headers
28
+ )
29
+ return result.json()['data'][0]['embedding']
30
+ except Exception as e:
31
+ print('data', data, 'result json', result.json())
32
 
33
 
34
  def doc_emb(doc: str):
35
  texts = doc.split('\n')
36
+ # futures = []
37
+ emb_list = embedder.encode(texts)
38
+ # for text in texts:
39
+ # futures.append(thread_pool_executor.submit(get_emb, text))
40
+ # for f in futures:
41
+ # emb_list.append(f.result())
42
  print('\n'.join(texts))
43
  return texts, emb_list, gr.Textbox.update(visible=True), gr.Button.update(visible=True), gr.Markdown.update(
44
+ value="""操作说明 step 3:PDF解析提交成功 🙋 可以开始对话啦~"""), gr.Chatbot.update(visible=True)
45
 
46
 
47
  def get_response(msg, bot, doc_text_list, doc_embeddings):
48
+ # future = thread_pool_executor.submit(get_emb, msg)
49
  now_len = len(msg)
50
  req_json = {'question': msg}
51
  his_bg = -1
 
55
  now_len += len(bot[i][0]) + len(bot[i][1])
56
  his_bg = i
57
  req_json['history'] = [] if his_bg == -1 else bot[his_bg:]
58
+ # query_embedding = future.result()
59
+ query_embedding = embedder.encode([msg])
60
  cos_scores = util.cos_sim(query_embedding, doc_embeddings)[0]
61
  score_index = [[score, index] for score, index in zip(cos_scores, [i for i in range(len(cos_scores))])]
62
  score_index.sort(key=lambda x: x[0], reverse=True)
63
  print('score_index:\n', score_index)
64
+ index_set, sub_doc_list = set(), []
65
  for s_i in score_index:
66
  doc = doc_text_list[s_i[1]]
67
  if now_len + len(doc) > all_max_len:
68
  break
69
+ index_set.add(s_i[1])
70
  now_len += len(doc)
71
+ # 可能段落截断错误,所以把上下段也加入进来
72
+ if s_i[1] > 0 and s_i[1] -1 not in index_set:
73
+ doc = doc_text_list[s_i[1]-1]
74
+ if now_len + len(doc) > all_max_len:
75
+ break
76
+ index_set.add(s_i[1]-1)
77
+ now_len += len(doc)
78
+ if s_i[1] + 1 < len(doc_text_list) and s_i[1] + 1 not in index_set:
79
+ doc = doc_text_list[s_i[1]+1]
80
+ if now_len + len(doc) > all_max_len:
81
+ break
82
+ index_set.add(s_i[1]+1)
83
+ now_len += len(doc)
84
+
85
+ index_list = list(index_set)
86
  index_list.sort()
87
  for i in index_list:
88
  sub_doc_list.append(doc_text_list[i])
 
95
  )
96
  res = result.json()['content']
97
  bot.append([msg, res])
98
+ return bot[max(0, len(bot) - 3):]
99
 
100
 
101
  def up_file(files):
102
+ doc_text_list = []
103
  for idx, file in enumerate(files):
104
  print(file.name)
105
+ with pdfplumber.open(file.name) as pdf:
106
+ for i in range(len(pdf.pages)):
107
+ # 读取PDF文档第i+1页
108
+ page = pdf.pages[i]
109
+ res_list = page.extract_text().split('\n')[:-1]
110
+ tables = page.extract_tables()
111
+ for table in tables:
112
+ # 第一列当成表头:
113
+ df = pd.DataFrame(table[1:], columns=table[0])
114
+ try:
115
+ records = json.loads(df.to_json(orient="records", force_ascii=False))
116
+ for rec in records:
117
+ res_list.append(json.dumps(rec, ensure_ascii=False))
118
+ except Exception as e:
119
+ res_list.append(str(df))
120
+
121
+ doc_text_list += res_list
122
+
123
+ for i in doc_text_list:
124
+ print(i)
125
+ return gr.Textbox.update(value='\n'.join(doc_text_list), visible=True), gr.Button.update(
126
+ visible=True), gr.Markdown.update(
127
+ value="操作说明 step 2:确认PDF解析结果(可修正),点击“提交结果”,进行对话")
128
 
129
 
130
  with gr.Blocks() as demo:
131
  with gr.Row():
132
  with gr.Column():
133
+ file = gr.File(file_types=['.pdf'], label='点击上传PDF,进行解析', file_count='multiple')
134
+ doc_bu = gr.Button(value='提交解析结果', visible=False)
135
  txt = gr.Textbox(label='PDF解析结果', visible=False)
 
 
136
  doc_text_state = gr.State([])
137
  doc_emb_state = gr.State([])
138
  with gr.Column():
139
+ md = gr.Markdown("""操作说明 step 1:点击左侧区域,上传PDF,进行解析""")
140
+ chat_bot = gr.Chatbot(visible=False)
141
  msg_txt = gr.Textbox(label='消息框', placeholder='输入消息,点击发送', visible=False)
142
  chat_bu = gr.Button(value='发送', visible=False)
143
 
144
+ file.change(up_file, [file], [txt, doc_bu, md])
145
+ doc_bu.click(doc_emb, [txt], [doc_text_state, doc_emb_state, msg_txt, chat_bu, md, chat_bot])
146
+ chat_bu.click(get_response, [msg_txt, chat_bot, doc_text_state, doc_emb_state], [chat_bot])
147
+
148
  if __name__ == "__main__":
149
  demo.queue().launch()
150
  # demo.queue().launch(share=False, server_name='172.22.2.54', server_port=9191)