KevinHuSh committed
Commit 9fe9fc4 · 1 Parent(s): f1ced48

add dockerfile for cuda environment. Refine table search strategy (#123)
Dockerfile CHANGED
@@ -14,6 +14,7 @@ ADD ./rag ./rag
 ENV PYTHONPATH=/ragflow/
 ENV HF_ENDPOINT=https://hf-mirror.com
 
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
 ADD docker/entrypoint.sh ./entrypoint.sh
 RUN chmod +x ./entrypoint.sh
 
Dockerfile.cuda ADDED
@@ -0,0 +1,26 @@
+FROM swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow-base:v1.0
+USER root
+
+WORKDIR /ragflow
+
+## for cuda > 12.0
+RUN /root/miniconda3/envs/py11/bin/pip uninstall -y onnxruntime-gpu
+RUN /root/miniconda3/envs/py11/bin/pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
+
+
+ADD ./web ./web
+RUN cd ./web && npm i && npm run build
+
+ADD ./api ./api
+ADD ./conf ./conf
+ADD ./deepdoc ./deepdoc
+ADD ./rag ./rag
+
+ENV PYTHONPATH=/ragflow/
+ENV HF_ENDPOINT=https://hf-mirror.com
+
+RUN /root/miniconda3/envs/py11/bin/pip install peewee==3.17.1
+ADD docker/entrypoint.sh ./entrypoint.sh
+RUN chmod +x ./entrypoint.sh
+
+ENTRYPOINT ["./entrypoint.sh"]
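
A quick way to confirm the CUDA build of onnxruntime is actually active inside this image is to query the runtime itself (a minimal sketch; the interpreter path matches the conda env the Dockerfile installs into):

    # run with /root/miniconda3/envs/py11/bin/python
    import onnxruntime as ort

    print(ort.get_device())                # "GPU" when the CUDA build is loaded
    print(ort.get_available_providers())   # CUDAExecutionProvider should be listed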
api/apps/conversation_app.py CHANGED
@@ -21,7 +21,7 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger, retrievaler
+from api.settings import access_logger, stat_logger, retrievaler, chat_logger
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
@@ -183,10 +183,10 @@ def chat(dialog, messages, **kwargs):
     field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
     ## try to use sql if field mapping is good to go
     if field_map:
-        stat_logger.info("Use SQL to retrieval.")
-        markdown_tbl, chunks = use_sql("\n".join(questions), field_map, dialog.tenant_id, chat_mdl)
+        chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
+        markdown_tbl, chunks = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl)
         if markdown_tbl:
-            return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
+            return {"answer": markdown_tbl, "reference": {"chunks": chunks, "doc_aggs": []}}
 
     prompt_config = dialog.prompt_config
     for p in prompt_config["parameters"]:
@@ -201,6 +201,7 @@ def chat(dialog, messages, **kwargs):
                                         dialog.similarity_threshold,
                                         dialog.vector_similarity_weight, top=1024, aggs=False)
     knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
+    chat_logger.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
 
     if not knowledges and prompt_config.get("empty_response"):
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
@@ -212,7 +213,7 @@ def chat(dialog, messages, **kwargs):
     if "max_tokens" in gen_conf:
         gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
     answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
+    chat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
-    stat_logger.info("User: {}|Assistant: {}".format(msg[-1]["content"], answer))
 
     if knowledges:
         answer, idx = retrievaler.insert_citations(answer,
@@ -237,47 +238,83 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
 
    问题如下:
    {}
-    请写出SQL,且只要SQL,不要有其他说明及文字。
+    请写出SQL, 且只要SQL,不要有其他说明及文字。
    """.format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
-    sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
-    stat_logger.info(f"“{question}” get SQL: {sql}")
-    sql = re.sub(r"[\r\n]+", " ", sql.lower())
-    sql = re.sub(r".*?select ", "select ", sql.lower())
-    sql = re.sub(r" +", " ", sql)
-    sql = re.sub(r"([;;]|```).*", "", sql)
-    if sql[:len("select ")] != "select ":
-        return None, None
-    if sql[:len("select *")] != "select *":
-        sql = "select doc_id,docnm_kwd," + sql[6:]
-    else:
-        flds = []
-        for k in field_map.keys():
-            if k in forbidden_select_fields4resume:continue
-            if len(flds) > 11:break
-            flds.append(k)
-        sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
-
-    stat_logger.info(f"“{question}” get SQL(refined): {sql}")
-    tbl = retrievaler.sql_retrieval(sql, format="json")
-    if not tbl or len(tbl["rows"]) == 0: return None, None
+    tried_times = 0
+    def get_table():
+        nonlocal sys_prompt, user_promt, question, tried_times
+        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
+        print(user_promt, sql)
+        chat_logger.info(f"“{question}”==>{user_promt} get SQL: {sql}")
+        sql = re.sub(r"[\r\n]+", " ", sql.lower())
+        sql = re.sub(r".*select ", "select ", sql.lower())
+        sql = re.sub(r" +", " ", sql)
+        sql = re.sub(r"([;;]|```).*", "", sql)
+        if sql[:len("select ")] != "select ":
+            return None, None
+        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
+            if sql[:len("select *")] != "select *":
+                sql = "select doc_id,docnm_kwd," + sql[6:]
+            else:
+                flds = []
+                for k in field_map.keys():
+                    if k in forbidden_select_fields4resume:continue
+                    if len(flds) > 11:break
+                    flds.append(k)
+                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]
+
+        print(f"“{question}” get SQL(refined): {sql}")
+
+        chat_logger.info(f"“{question}” get SQL(refined): {sql}")
+        tried_times += 1
+        return retrievaler.sql_retrieval(sql, format="json"), sql
+
+    tbl, sql = get_table()
+    if tbl.get("error") and tried_times <= 2:
+        user_promt = """
+        表名:{};
+        数据库表字段说明如下:
+        {}
+
+        问题如下:
+        {}
+
+        你上一次给出的错误SQL如下:
+        {}
+
+        后台报错如下:
+        {}
+
+        请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
+        """.format(
+            index_name(tenant_id),
+            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
+            question, sql, tbl["error"]
+        )
+        tbl, sql = get_table()
+        chat_logger.info("TRY it again: {}".format(sql))
+
+    chat_logger.info("GET table: {}".format(tbl))
+    print(tbl)
+    if tbl.get("error") or len(tbl["rows"]) == 0: return None, None
 
     docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
     docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
     clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]
 
     # compose markdown table
-    clmns = "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
-    line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
-    rows = ["|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
+    clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
+    line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
+    rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
     if not docid_idx or not docnm_idx:
-        access_logger.error("SQL missing field: " + sql)
+        chat_logger.warning("SQL missing field: " + sql)
         return "\n".join([clmns, line, "\n".join(rows)]), []
 
-    rows = "\n".join([r + f"##{ii}$$" for ii, r in enumerate(rows)])
+    rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
     docid_idx = list(docid_idx)[0]
    docnm_idx = list(docnm_idx)[0]
     return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
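
The refined table-search strategy above boils down to: have the LLM draft a SELECT, sanitize it, run it against ES, and on failure feed the error text back for one repair attempt. A standalone sketch of that loop (generate_sql and run_sql stand in for the chat_mdl.chat and retrievaler.sql_retrieval calls; names are illustrative):

    def sql_with_self_repair(question, generate_sql, run_sql, max_tries=2):
        """Generate SQL, execute it, and let the model repair its own errors."""
        prompt = question
        sql = generate_sql(prompt)
        for _ in range(max_tries):
            result = run_sql(sql)          # returns {"error": ...} on failure
            if not result.get("error"):
                return result, sql
            # Feed the failing statement and the backend error back to the model.
            prompt = f"{question}\nPrevious SQL: {sql}\nError: {result['error']}"
            sql = generate_sql(prompt)
        return run_sql(sql), sql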
api/db/db_models.py CHANGED
@@ -502,7 +502,7 @@ class Document(DataBaseModel):
     token_num = IntegerField(default=0)
     chunk_num = IntegerField(default=0)
     progress = FloatField(default=0)
-    progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")
     process_begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
@@ -520,7 +520,7 @@ class Task(DataBaseModel):
     begin_at = DateTimeField(null=True)
     process_duation = FloatField(default=0)
     progress = FloatField(default=0)
-    progress_msg = TextField(max_length=4096, null=True, help_text="process message", default="")
+    progress_msg = TextField(null=True, help_text="process message", default="")
 
 
 class Dialog(DataBaseModel):
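
Widening progress_msg to TextField only changes the peewee model; an existing MySQL table needs a matching ALTER. A sketch using peewee's playhouse migrator (the database name and connection settings are assumptions):

    from peewee import MySQLDatabase, TextField
    from playhouse.migrate import MySQLMigrator, migrate

    db = MySQLDatabase("rag_flow")      # assumed connection settings
    migrator = MySQLMigrator(db)
    migrate(
        # alter_column_type is a standard playhouse.migrate operation
        migrator.alter_column_type("document", "progress_msg", TextField(null=True)),
        migrator.alter_column_type("task", "progress_msg", TextField(null=True)),
    )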
api/db/init_data.py CHANGED
@@ -90,6 +90,17 @@ def init_llm_factory():
         "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
         "status": "1",
     },
+    {
+        "name": "Local",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
+        "status": "0",
+    },{
+        "name": "Moonshot",
+        "logo": "",
+        "tags": "LLM,TEXT EMBEDDING",
+        "status": "1",
+    }
     # {
     #     "name": "文心一言",
     #     "logo": "",
@@ -155,6 +166,12 @@ def init_llm_factory():
         "tags": "LLM,CHAT,32K",
         "max_tokens": 32768,
         "model_type": LLMType.CHAT.value
+    },{
+        "fid": factory_infos[1]["name"],
+        "llm_name": "qwen-max-1201",
+        "tags": "LLM,CHAT,6K",
+        "max_tokens": 5899,
+        "model_type": LLMType.CHAT.value
     },{
         "fid": factory_infos[1]["name"],
         "llm_name": "text-embedding-v2",
@@ -201,6 +218,46 @@ def init_llm_factory():
         "max_tokens": 512,
         "model_type": LLMType.EMBEDDING.value
     },
+    # ---------------------- 本地 ----------------------
+    {
+        "fid": factory_infos[3]["name"],
+        "llm_name": "qwen-14B-chat",
+        "tags": "LLM,CHAT,",
+        "max_tokens": 8191,
+        "model_type": LLMType.CHAT.value
+    }, {
+        "fid": factory_infos[3]["name"],
+        "llm_name": "flag-enbedding",
+        "tags": "TEXT EMBEDDING,",
+        "max_tokens": 128 * 1000,
+        "model_type": LLMType.EMBEDDING.value
+    },
+    # ------------------------ Moonshot -----------------------
+    {
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-8k",
+        "tags": "LLM,CHAT,",
+        "max_tokens": 7900,
+        "model_type": LLMType.CHAT.value
+    }, {
+        "fid": factory_infos[4]["name"],
+        "llm_name": "flag-enbedding",
+        "tags": "TEXT EMBEDDING,",
+        "max_tokens": 128 * 1000,
+        "model_type": LLMType.EMBEDDING.value
+    },{
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-32k",
+        "tags": "LLM,CHAT,",
+        "max_tokens": 32768,
+        "model_type": LLMType.CHAT.value
+    },{
+        "fid": factory_infos[4]["name"],
+        "llm_name": "moonshot-v1-128k",
+        "tags": "LLM,CHAT",
+        "max_tokens": 128 * 1000,
+        "model_type": LLMType.CHAT.value
+    },
     ]
     for info in factory_infos:
         LLMFactoriesService.save(**info)
api/settings.py CHANGED
@@ -29,6 +29,7 @@ LoggerFactory.LEVEL = 10
 stat_logger = getLogger("stat")
 access_logger = getLogger("access")
 database_logger = getLogger("database")
+chat_logger = getLogger("chat")
 
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
@@ -69,9 +70,15 @@ default_llm = {
         "image2text_model": "glm-4v",
         "asr_model": "",
     },
-    "local": {
-        "chat_model": "",
-        "embedding_model": "",
+    "Local": {
+        "chat_model": "qwen-14B-chat",
+        "embedding_model": "flag-enbedding",
+        "image2text_model": "",
+        "asr_model": "",
+    },
+    "Moonshot": {
+        "chat_model": "moonshot-v1-8k",
+        "embedding_model": "flag-enbedding",
         "image2text_model": "",
         "asr_model": "",
     }
@@ -86,7 +93,7 @@ EMBEDDING_MDL = default_llm[LLM_FACTORY]["embedding_model"]
 ASR_MDL = default_llm[LLM_FACTORY]["asr_model"]
 IMAGE2TEXT_MDL = default_llm[LLM_FACTORY]["image2text_model"]
 
-API_KEY = LLM.get("api_key", "infiniflow API Key")
+API_KEY = LLM.get("api_key", "")
 PARSERS = LLM.get("parsers", "naive:General,qa:Q&A,resume:Resume,table:Table,laws:Laws,manual:Manual,book:Book,paper:Paper,presentation:Presentation,picture:Picture")
 
 # distribution
deepdoc/parser/excel_parser.py CHANGED
@@ -34,7 +34,7 @@ class HuExcelParser:
             total = 0
             for sheetname in wb.sheetnames:
                 ws = wb[sheetname]
-                total += len(ws.rows)
+                total += len(list(ws.rows))
             return total
 
         if fnm.split(".")[-1].lower() in ["csv", "txt"]:
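
The fix works because openpyxl's Worksheet.rows is a generator, so len(ws.rows) raises TypeError; list() materializes it. If only the count is needed, the sheet's max_row attribute avoids iterating at all (a sketch, not what the commit does):

    from openpyxl import load_workbook

    def row_number(path):
        wb = load_workbook(path, read_only=True)  # read_only skips cell values
        # max_row reads the sheet dimensions instead of walking every row.
        return sum(wb[name].max_row for name in wb.sheetnames)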
deepdoc/parser/pdf_parser.py CHANGED
@@ -655,14 +655,14 @@ class HuParser:
                 #if min(tv, fv) > 2000:
                 #    i += 1
                 #    continue
-                if tv < fv:
+                if tv < fv and tk:
                     tables[tk].insert(0, c)
                     logging.debug(
                         "TABLE:" +
                         self.boxes[i]["text"] +
                         "; Cap: " +
                         tk)
-                else:
+                elif fk:
                     figures[fk].insert(0, c)
                     logging.debug(
                         "FIGURE:" +
deepdoc/parser/ppt_parser.py CHANGED
@@ -31,7 +31,7 @@ class HuPptParser(object):
 
         if shape.shape_type == 6:
             texts = []
-            for p in shape.shapes:
+            for p in sorted(shape.shapes, key=lambda x: (x.top//10, x.left)):
                 t = self.__extract(p)
                 if t: texts.append(t)
             return "\n".join(texts)
@@ -46,7 +46,7 @@ class HuPptParser(object):
             if i < from_page: continue
             if i >= to_page:break
             texts = []
-            for shape in slide.shapes:
+            for shape in sorted(slide.shapes, key=lambda x: (x.top//10, x.left)):
                 txt = self.__extract(shape)
                 if txt: texts.append(txt)
             txts.append("\n".join(texts))
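
The sort key (x.top//10, x.left) buckets shape tops into 10-EMU bands so shapes on roughly the same line are emitted left to right rather than in slide insertion order. A small illustration with made-up coordinates:

    # Each tuple is (top, left) for a shape on one slide.
    shapes = [(1003, 5000), (1000, 100), (2000, 100)]

    # top//10 treats near-equal tops (1000 vs 1003) as the same row,
    # so the row is read left to right before moving down the slide.
    ordered = sorted(shapes, key=lambda s: (s[0] // 10, s[1]))
    print(ordered)  # [(1000, 100), (1003, 5000), (2000, 100)]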
deepdoc/vision/ocr.py CHANGED
@@ -64,10 +64,15 @@ def load_model(model_dir, nm):
         raise ValueError("not find model file path {}".format(
             model_file_path))
 
+    options = ort.SessionOptions()
+    options.enable_cpu_mem_arena = False
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    options.intra_op_num_threads = 2
+    options.inter_op_num_threads = 2
     if ort.get_device() == "GPU":
-        sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CUDAExecutionProvider'])
     else:
-        sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
+        sess = ort.InferenceSession(model_file_path, options=options, providers=['CPUExecutionProvider'])
     return sess, sess.get_inputs()[0]
 
 
@@ -325,7 +330,13 @@ class TextRecognizer(object):
 
             input_dict = {}
             input_dict[self.input_tensor.name] = norm_img_batch
-            outputs = self.predictor.run(None, input_dict)
+            for i in range(100000):
+                try:
+                    outputs = self.predictor.run(None, input_dict)
+                    break
+                except Exception as e:
+                    if i >= 3: raise e
+                    time.sleep(5)
             preds = outputs[0]
             rec_result = self.postprocess_op(preds)
             for rno in range(len(rec_result)):
@@ -430,7 +441,13 @@ class TextDetector(object):
         img = img.copy()
         input_dict = {}
         input_dict[self.input_tensor.name] = img
-        outputs = self.predictor.run(None, input_dict)
+        for i in range(100000):
+            try:
+                outputs = self.predictor.run(None, input_dict)
+                break
+            except Exception as e:
+                if i >= 3: raise e
+                time.sleep(5)
 
         post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
         dt_boxes = post_result[0]['points']
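
Both new loops implement the same policy: retry predictor.run up to three times with a five-second pause, then re-raise. They could share one helper (a sketch; it assumes time is imported in ocr.py, which the loops themselves require):

    import time

    def run_with_retry(session, input_dict, retries=3, delay=5):
        """Run an onnxruntime session, retrying transient failures (e.g. GPU OOM)."""
        for attempt in range(retries + 1):
            try:
                return session.run(None, input_dict)
            except Exception:
                if attempt >= retries:
                    raise
                time.sleep(delay)

    # usage inside TextRecognizer / TextDetector:
    #   outputs = run_with_retry(self.predictor, input_dict)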
deepdoc/vision/recognizer.py CHANGED
@@ -42,7 +42,9 @@ class Recognizer(object):
             raise ValueError("not find model file path {}".format(
                 model_file_path))
         if ort.get_device() == "GPU":
-            self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
+            options = ort.SessionOptions()
+            options.enable_cpu_mem_arena = False
+            self.ort_sess = ort.InferenceSession(model_file_path, options=options, providers=[('CUDAExecutionProvider')])
         else:
             self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
         self.input_names = [node.name for node in self.ort_sess.get_inputs()]
rag/app/table.py CHANGED
@@ -67,7 +67,7 @@ class Excel(ExcelParser):
 
 def trans_datatime(s):
     try:
-        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
+        return datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S")
     except Exception as e:
         pass
 
@@ -80,6 +80,7 @@ def trans_bool(s):
 
 
 def column_data_type(arr):
+    arr = list(arr)
     uni = len(set([a for a in arr if a is not None]))
     counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
     trans = {t: f for f, t in
@@ -130,7 +131,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        dfs = excel_parser(filename, binary, callback)
+        dfs = excel_parser(filename, binary, from_page=from_page, to_page=to_page, callback=callback)
     elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         txt = ""
@@ -188,7 +189,7 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
             df[clmns[j]] = cln
             if ty == "text":
                 txts.extend([str(c) for c in cln if c])
-        clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i])
+        clmns_map = [(py_clmns[i] + fieds_map[clmn_tys[i]], clmns[i].replace("_", " "))
                      for i in range(len(clmns))]
 
         eng = lang.lower() == "english"#is_english(txts)
@@ -201,6 +202,8 @@ def chunk(filename, binary=None, from_page=0, to_page=10000000000, lang="Chinese
         for j in range(len(clmns)):
             if row[clmns[j]] is None:
                 continue
+            if not str(row[clmns[j]]):
+                continue
             fld = clmns_map[j][0]
             d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(
                 row[clmns[j]])
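
The timestamp change swaps ISO's T separator for a space in the normalized value. Assuming datetime_parse is dateutil.parser.parse, as the alias suggests, the two renderings compare like this:

    from dateutil.parser import parse as datetime_parse

    s = "2024/03/05 14:30"
    print(datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S"))  # 2024-03-05T14:30:00 (old)
    print(datetime_parse(s.strip()).strftime("%Y-%m-%d %H:%M:%S"))  # 2024-03-05 14:30:00 (new)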
rag/llm/__init__.py CHANGED
@@ -19,18 +19,20 @@ from .cv_model import *
 
 
 EmbeddingModel = {
-    "local": HuEmbedding,
+    "Local": HuEmbedding,
     "OpenAI": OpenAIEmbed,
     "通义千问": HuEmbedding, #QWenEmbed,
-    "智谱AI": ZhipuEmbed
+    "智谱AI": ZhipuEmbed,
+    "Moonshot": HuEmbedding
 }
 
 
 CvModel = {
     "OpenAI": GptV4,
-    "local": LocalCV,
+    "Local": LocalCV,
     "通义千问": QWenCV,
-    "智谱AI": Zhipu4V
+    "智谱AI": Zhipu4V,
+    "Moonshot": LocalCV
 }
 
 
@@ -38,6 +40,7 @@ ChatModel = {
     "OpenAI": GptTurbo,
     "智谱AI": ZhipuChat,
     "通义千问": QWenChat,
-    "local": LocalLLM
+    "Local": LocalLLM,
+    "Moonshot": MoonshotChat
 }
 
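These dicts are dispatch tables keyed by the factory name, which is why the rename from "local" to "Local" must match the names registered in settings.py and init_data.py. Typical lookup-and-call, sketched with placeholder credentials:

    from rag.llm import ChatModel

    factory = "Moonshot"            # must match a dict key exactly ("Local", not "local")
    mdl_cls = ChatModel[factory]    # -> MoonshotChat
    chat_mdl = mdl_cls("sk-...", "moonshot-v1-8k")   # key and model name are placeholders
    answer, tokens = chat_mdl.chat("You are a helpful assistant.",
                                   [{"role": "user", "content": "hi"}],
                                   {"temperature": 0.1})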
rag/llm/chat_model.py CHANGED
@@ -14,11 +14,8 @@
 # limitations under the License.
 #
 from abc import ABC
-from copy import deepcopy
-
 from openai import OpenAI
 import openai
-
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 
@@ -52,6 +49,12 @@ class GptTurbo(Base):
             return "**ERROR**: "+str(e), 0
 
 
+class MoonshotChat(GptTurbo):
+    def __init__(self, key, model_name="moonshot-v1-8k"):
+        self.client = OpenAI(api_key=key, base_url="https://api.moonshot.cn/v1",)
+        self.model_name = model_name
+
+
 from dashscope import Generation
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo):
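
MoonshotChat inherits GptTurbo's chat() unchanged; only the constructor differs, pointing the OpenAI client at Moonshot's OpenAI-compatible endpoint. The same few lines would wrap any OpenAI-compatible provider (hypothetical endpoint below):

    class AnyOpenAICompatChat(GptTurbo):
        """Hypothetical provider exposing the OpenAI chat-completions API."""
        def __init__(self, key, model_name="some-model"):
            # Only base_url and the default model differ from GptTurbo.
            self.client = OpenAI(api_key=key, base_url="https://api.example.com/v1")
            self.model_name = model_name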
rag/llm/rpc_server.py CHANGED
@@ -4,7 +4,7 @@ import random
 import time
 from multiprocessing.connection import Listener
 from threading import Thread
-import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 class RPCHandler:
@@ -47,14 +47,27 @@ tokenizer = None
 def chat(messages, gen_conf):
     global tokenizer
     model = Model()
-    roles = {"system":"System", "user": "User", "assistant": "Assistant"}
-    line = ["{}: {}".format(roles[m["role"].lower()], m["content"]) for m in messages]
-    line = "\n".join(line) + "\nAssistant: "
-    tokens = tokenizer([line], return_tensors='pt')
-    tokens = {k: tokens[k].to(model.device) if isinstance(tokens[k], torch.Tensor) else tokens[k] for k in
-              tokens.keys()}
-    res = [tokenizer.decode(t) for t in model.generate(**tokens, **gen_conf)][0]
-    return res.split("Assistant: ")[-1]
+    try:
+        conf = {"max_new_tokens": int(gen_conf.get("max_tokens", 256)), "temperature": float(gen_conf.get("temperature", 0.1))}
+        print(messages, conf)
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+        generated_ids = model.generate(
+            model_inputs.input_ids,
+            **conf
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    except Exception as e:
+        return str(e)
 
 
 def Model():
@@ -71,20 +84,13 @@ if __name__ == "__main__":
     handler = RPCHandler()
     handler.register_function(chat)
 
-    from transformers import AutoModelForCausalLM, AutoTokenizer
-    from transformers.generation.utils import GenerationConfig
-
     models = []
-    for _ in range(2):
+    for _ in range(1):
         m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                  device_map="auto",
-                                                 torch_dtype='auto',
-                                                 trust_remote_code=True)
-        m.generation_config = GenerationConfig.from_pretrained(args.model_name)
-        m.generation_config.pad_token_id = m.generation_config.eos_token_id
+                                                 torch_dtype='auto')
         models.append(m)
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False,
-                                              trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
 
     # Run the server
     rpc_server(handler, ('0.0.0.0', args.port), authkey=b'infiniflow-token4kevinhu')
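
The server side listens on a multiprocessing Listener with a fixed authkey. Assuming RPCHandler speaks the common pickled (function_name, args, kwargs) request format (the handler body is not shown in this diff, so this is an assumption), a client would look roughly like:

    from multiprocessing.connection import Client

    # Host, port, and the wire format are assumptions; authkey matches the server.
    conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
    conn.send(("chat", ([{"role": "user", "content": "你好"}], {"max_tokens": 128}), {}))
    print(conn.recv())
    conn.close()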
rag/nlp/search.py CHANGED
@@ -7,6 +7,7 @@ from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
+from api.settings import chat_logger
 from rag.settings import es_logger
 from rag.utils import rmSpace
 from rag.nlp import huqie, query
@@ -333,15 +334,16 @@ class Dealer:
         replaces = []
         for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
             fld, v = r.group(1), r.group(3)
-            match = " MATCH({}, '{}', 'operator=OR;fuzziness=AUTO:1,3;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
+            match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(fld, huqie.qieqie(huqie.qie(v)))
             replaces.append(("{}{}'{}'".format(r.group(1), r.group(2), r.group(3)), match))
 
         for p, r in replaces: sql = sql.replace(p, r, 1)
-        es_logger.info(f"To es: {sql}")
+        chat_logger.info(f"To es: {sql}")
 
         try:
             tbl = self.es.sql(sql, fetch_size, format)
             return tbl
         except Exception as e:
-            es_logger.error(f"SQL failure: {sql} =>" + str(e))
+            chat_logger.error(f"SQL failure: {sql} =>" + str(e))
+            return {"error": str(e)}
 
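This rewrite converts equality/LIKE predicates on tokenized *_tks columns into ES full-text MATCH calls; dropping the fuzziness=AUTO:1,3 option tightens matching, and returning {"error": ...} instead of None is what lets use_sql feed the backend error back to the LLM. Roughly, for one predicate (tokens shown are placeholders for the huqie output):

    before = "select doc_id from idx where title_tks = '燃气轮机'"
    after  = "select doc_id from idx where MATCH(title_tks, '燃气 轮机', 'operator=OR;minimum_should_match=30%') "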
rag/svr/task_executor.py CHANGED
@@ -169,16 +169,25 @@ def init_kb(row):
 
 
 def embedding(docs, mdl, parser_config={}, callback=None):
+    batch_size = 32
     tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
         d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
-        tts, c = mdl.encode(tts)
-        tk_count += c
+        tts_ = np.array([])
+        for i in range(0, len(tts), batch_size):
+            vts, c = mdl.encode(tts[i: i + batch_size])
+            if len(tts_) == 0:
+                tts_ = vts
+            else:
+                tts_ = np.concatenate((tts_, vts), axis=0)
+            tk_count += c
+            callback(prog=0.6 + 0.1 * (i + 1) / len(tts), msg="")
+        tts = tts_
 
     cnts_ = np.array([])
-    for i in range(0, len(cnts), 8):
-        vts, c = mdl.encode(cnts[i: i+8])
+    for i in range(0, len(cnts), batch_size):
+        vts, c = mdl.encode(cnts[i: i+batch_size])
         if len(cnts_) == 0: cnts_ = vts
         else: cnts_ = np.concatenate((cnts_, vts), axis=0)
         tk_count += c
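
Title and content embeddings now use the same batch_size of 32, with progress reported via callback. The duplicated accumulate-and-concatenate loop could be factored out (a sketch; mdl.encode returning (vectors, token_count) follows the call sites above):

    import numpy as np

    def encode_batched(mdl, texts, batch_size=32, on_progress=None):
        """Embed texts in fixed-size batches, returning (matrix, total_token_count)."""
        vecs, tk_count = np.array([]), 0
        for i in range(0, len(texts), batch_size):
            vts, c = mdl.encode(texts[i: i + batch_size])
            vecs = vts if len(vecs) == 0 else np.concatenate((vecs, vts), axis=0)
            tk_count += c
            if on_progress:
                on_progress((i + 1) / len(texts))
        return vecs, tk_count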
rag/utils/es_conn.py CHANGED
@@ -249,6 +249,8 @@ class HuEs:
             except ConnectionTimeout as e:
                 es_logger.error("Timeout【Q】:" + sql)
                 continue
+            except Exception as e:
+                raise e
         es_logger.error("ES search timeout for 3 times!")
         raise ConnectionTimeout()
 