KevinHuSh committed on
Commit 407b252 · Parent: f305776

Remove unused code, separate layout detection out as a new API, and add new RAG method 'table' (#55)

api/apps/__init__.py CHANGED
@@ -28,8 +28,6 @@ from api.utils import CustomJSONEncoder
 from flask_session import Session
 from flask_login import LoginManager
 from api.settings import RetCode, SECRET_KEY, stat_logger
-from api.hook import HookManager
-from api.hook.common.parameters import AuthenticationParameters, ClientAuthenticationParameters
 from api.settings import API_VERSION, CLIENT_AUTHENTICATION, SITE_AUTHENTICATION, access_logger
 from api.utils.api_utils import get_json_result, server_error_response
 from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
@@ -96,37 +94,7 @@ client_urls_prefix = [
 ]
 
 
-def client_authentication_before_request():
-    result = HookManager.client_authentication(ClientAuthenticationParameters(
-        request.full_path, request.headers,
-        request.form, request.data, request.json,
-    ))
-
-    if result.code != RetCode.SUCCESS:
-        return get_json_result(result.code, result.message)
-
-
-def site_authentication_before_request():
-    for url_prefix in client_urls_prefix:
-        if request.path.startswith(url_prefix):
-            return
-
-    result = HookManager.site_authentication(AuthenticationParameters(
-        request.headers.get('site_signature'),
-        request.json,
-    ))
-
-    if result.code != RetCode.SUCCESS:
-        return get_json_result(result.code, result.message)
-
-
-@app.before_request
-def authentication_before_request():
-    if CLIENT_AUTHENTICATION:
-        return client_authentication_before_request()
-
-    if SITE_AUTHENTICATION:
-        return site_authentication_before_request()
 
 @login_manager.request_loader
 def load_user(web_request):
api/apps/chunk_app.py CHANGED
@@ -57,7 +57,7 @@ def list():
         for id in sres.ids:
             d = {
                 "chunk_id": id,
-                "content_ltks": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_ltks"],
+                "content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
                 "doc_id": sres.field[id]["doc_id"],
                 "docnm_kwd": sres.field[id]["docnm_kwd"],
                 "important_kwd": sres.field[id].get("important_kwd", []),
@@ -134,7 +134,7 @@ def set():
             q, a = rmPrefix(arr[0]), rmPrefix[arr[1]]
             d = beAdoc(d, arr[0], arr[1], not any([huqie.is_chinese(t) for t in q+a]))
 
-        v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
+        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
         ELASTICSEARCH.upsert([d], search.index_name(tenant_id))
@@ -175,13 +175,13 @@ def rm():
 
 @manager.route('/create', methods=['POST'])
 @login_required
-@validate_request("doc_id", "content_ltks")
+@validate_request("doc_id", "content_with_weight")
 def create():
     req = request.json
     md5 = hashlib.md5()
-    md5.update((req["content_ltks"] + req["doc_id"]).encode("utf-8"))
+    md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
     chunck_id = md5.hexdigest()
-    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_ltks"])}
+    d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
     d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
     d["important_kwd"] = req.get("important_kwd", [])
     d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
@@ -201,7 +201,7 @@ def create():
 
         embd_mdl = TenantLLMService.model_instance(
             tenant_id, LLMType.EMBEDDING.value)
-        v, c = embd_mdl.encode([doc.name, req["content_ltks"]])
+        v, c = embd_mdl.encode([doc.name, req["content_with_weight"]])
         DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0)
         v = 0.1 * v[0] + 0.9 * v[1]
         d["q_%d_vec" % len(v)] = v.tolist()
api/apps/conversation_app.py CHANGED
@@ -175,7 +175,7 @@ def chat(dialog, messages, **kwargs):
     chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
     kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
                                     dialog.vector_similarity_weight, top=1024, aggs=False)
-    knowledges = [ck["content_ltks"] for ck in kbinfos["chunks"]]
+    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
 
     if not knowledges and prompt_config["empty_response"]:
         return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
api/apps/document_app.py CHANGED
@@ -73,6 +73,7 @@ def upload():
         "id": get_uuid(),
         "kb_id": kb.id,
         "parser_id": kb.parser_id,
+        "parser_config": kb.parser_config,
         "created_by": current_user.id,
         "type": filename_type(filename),
         "name": filename,
@@ -108,6 +109,7 @@ def create():
         "id": get_uuid(),
         "kb_id": kb.id,
         "parser_id": kb.parser_id,
+        "parser_config": kb.parser_config,
         "created_by": current_user.id,
         "type": FileType.VIRTUAL,
         "name": req["name"],
@@ -128,8 +130,8 @@ def list():
             data=False, retmsg='Lack of "KB ID"', retcode=RetCode.ARGUMENT_ERROR)
     keywords = request.args.get("keywords", "")
 
-    page_number = request.args.get("page", 1)
-    items_per_page = request.args.get("page_size", 15)
+    page_number = int(request.args.get("page", 1))
+    items_per_page = int(request.args.get("page_size", 15))
     orderby = request.args.get("orderby", "create_time")
     desc = request.args.get("desc", True)
     try:
@@ -214,7 +216,9 @@ def run():
     req = request.json
     try:
         for id in req["doc_ids"]:
-            DocumentService.update_by_id(id, {"run": str(req["run"]), "progress": 0})
+            info = {"run": str(req["run"]), "progress": 0}
+            if str(req["run"]) == TaskStatus.RUNNING.value:info["progress_msg"] = ""
+            DocumentService.update_by_id(id, info)
             if str(req["run"]) == TaskStatus.CANCEL.value:
                 tenant_id = DocumentService.get_tenant_id(id)
                 if not tenant_id:
api/apps/kb_app.py CHANGED
@@ -29,7 +29,7 @@ from api.utils.api_utils import get_json_result
 
 @manager.route('/create', methods=['post'])
 @login_required
-@validate_request("name", "description", "permission", "parser_id")
+@validate_request("name")
 def create():
     req = request.json
     req["name"] = req["name"].strip()
api/db/__init__.py CHANGED
@@ -77,3 +77,4 @@ class ParserType(StrEnum):
     RESUME = "resume"
     BOOK = "book"
     QA = "qa"
+    TABLE = "table"
api/db/db_models.py CHANGED
@@ -29,7 +29,7 @@ from peewee import (
 )
 from playhouse.pool import PooledMySQLDatabase
 
-from api.db import SerializedType
+from api.db import SerializedType, ParserType
 from api.settings import DATABASE, stat_logger, SECRET_KEY
 from api.utils.log_utils import getLogger
 from api import utils
@@ -381,7 +381,8 @@ class Tenant(DataBaseModel):
     embd_id = CharField(max_length=128, null=False, help_text="default embedding model ID")
     asr_id = CharField(max_length=128, null=False, help_text="default ASR model ID")
     img2txt_id = CharField(max_length=128, null=False, help_text="default image to text model ID")
-    parser_ids = CharField(max_length=128, null=False, help_text="default image to text model ID")
+    parser_ids = CharField(max_length=128, null=False, help_text="document processors")
+    credit = IntegerField(default=512)
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
     class Meta:
@@ -472,7 +473,8 @@ class Knowledgebase(DataBaseModel):
     similarity_threshold = FloatField(default=0.2)
     vector_similarity_weight = FloatField(default=0.3)
 
-    parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
+    parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
+    parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
     status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
 
     def __str__(self):
@@ -487,6 +489,7 @@ class Document(DataBaseModel):
     thumbnail = TextField(null=True, help_text="thumbnail base64 string")
     kb_id = CharField(max_length=256, null=False, index=True)
     parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
+    parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
     source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
     type = CharField(max_length=32, null=False, help_text="file extension")
     created_by = CharField(max_length=32, null=False, help_text="who created it")
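
Note: the new parser_config JSON column is the thread tying these schema changes together; a document now snapshots its knowledge base's parser settings at creation time. A minimal sketch of the record shape produced by the upload()/create() changes in api/apps/document_app.py (field values illustrative):

# Sketch: how a new document record inherits the KB's parser settings.
doc = {
    "id": get_uuid(),
    "kb_id": kb.id,
    "parser_id": kb.parser_id,          # e.g. "table" for the new RAG method
    "parser_config": kb.parser_config,  # defaults to {"from_page": 0, "to_page": 100000}
    "created_by": current_user.id,
}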
api/db/db_services.py DELETED
@@ -1,157 +0,0 @@
-#
-#  Copyright 2021 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-import abc
-import json
-import time
-from functools import wraps
-from shortuuid import ShortUUID
-
-from api.versions import get_rag_version
-
-from api.errors.error_services import *
-from api.settings import (
-    GRPC_PORT, HOST, HTTP_PORT,
-    RANDOM_INSTANCE_ID, stat_logger,
-)
-
-
-instance_id = ShortUUID().random(length=8) if RANDOM_INSTANCE_ID else f'flow-{HOST}-{HTTP_PORT}'
-server_instance = (
-    f'{HOST}:{GRPC_PORT}',
-    json.dumps({
-        'instance_id': instance_id,
-        'timestamp': round(time.time() * 1000),
-        'version': get_rag_version() or '',
-        'host': HOST,
-        'grpc_port': GRPC_PORT,
-        'http_port': HTTP_PORT,
-    }),
-)
-
-
-def check_service_supported(method):
-    """Decorator to check if `service_name` is supported.
-    The attribute `supported_services` MUST be defined in class.
-    The first and second arguments of `method` MUST be `self` and `service_name`.
-
-    :param Callable method: The class method.
-    :return: The inner wrapper function.
-    :rtype: Callable
-    """
-    @wraps(method)
-    def magic(self, service_name, *args, **kwargs):
-        if service_name not in self.supported_services:
-            raise ServiceNotSupported(service_name=service_name)
-        return method(self, service_name, *args, **kwargs)
-    return magic
-
-
-class ServicesDB(abc.ABC):
-    """Database for storage service urls.
-    Abstract base class for the real backends.
-
-    """
-    @property
-    @abc.abstractmethod
-    def supported_services(self):
-        """The names of supported services.
-        The returned list SHOULD contain `ragflow` (model download) and `servings` (RAG-Serving).
-
-        :return: The service names.
-        :rtype: list
-        """
-        pass
-
-    @abc.abstractmethod
-    def _get_serving(self):
-        pass
-
-    def get_serving(self):
-
-        try:
-            return self._get_serving()
-        except ServicesError as e:
-            stat_logger.exception(e)
-            return []
-
-    @abc.abstractmethod
-    def _insert(self, service_name, service_url, value=''):
-        pass
-
-    @check_service_supported
-    def insert(self, service_name, service_url, value=''):
-        """Insert a service url to database.
-
-        :param str service_name: The service name.
-        :param str service_url: The service url.
-        :return: None
-        """
-        try:
-            self._insert(service_name, service_url, value)
-        except ServicesError as e:
-            stat_logger.exception(e)
-
-    @abc.abstractmethod
-    def _delete(self, service_name, service_url):
-        pass
-
-    @check_service_supported
-    def delete(self, service_name, service_url):
-        """Delete a service url from database.
-
-        :param str service_name: The service name.
-        :param str service_url: The service url.
-        :return: None
-        """
-        try:
-            self._delete(service_name, service_url)
-        except ServicesError as e:
-            stat_logger.exception(e)
-
-    def register_flow(self):
-        """Call `self.insert` for insert the flow server address to databae.
-
-        :return: None
-        """
-        self.insert('flow-server', *server_instance)
-
-    def unregister_flow(self):
-        """Call `self.delete` for delete the flow server address from databae.
-
-        :return: None
-        """
-        self.delete('flow-server', server_instance[0])
-
-    @abc.abstractmethod
-    def _get_urls(self, service_name, with_values=False):
-        pass
-
-    @check_service_supported
-    def get_urls(self, service_name, with_values=False):
-        """Query service urls from database. The urls may belong to other nodes.
-        Currently, only `ragflow` (model download) urls and `servings` (RAG-Serving) urls are supported.
-        `ragflow` is a url containing scheme, host, port and path,
-        while `servings` only contains host and port.
-
-        :param str service_name: The service name.
-        :return: The service urls.
-        :rtype: list
-        """
-        try:
-            return self._get_urls(service_name, with_values)
-        except ServicesError as e:
-            stat_logger.exception(e)
-            return []
api/db/services/document_service.py CHANGED
@@ -63,7 +63,7 @@ class DocumentService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_newly_uploaded(cls, tm, mod=0, comm=1, items_per_page=64):
-        fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
+        fields = [cls.model.id, cls.model.kb_id, cls.model.parser_id, cls.model.parser_config, cls.model.name, cls.model.type, cls.model.location, cls.model.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Knowledgebase, on=(cls.model.kb_id == Knowledgebase.id)) \
             .join(Tenant, on=(Knowledgebase.tenant_id == Tenant.id))\
api/db/services/knowledgebase_service.py CHANGED
@@ -52,7 +52,8 @@ class KnowledgebaseService(CommonService):
             cls.model.doc_num,
             cls.model.token_num,
             cls.model.chunk_num,
-            cls.model.parser_id]
+            cls.model.parser_id,
+            cls.model.parser_config]
         kbs = cls.model.select(*fields).join(Tenant, on=((Tenant.id == cls.model.tenant_id)&(Tenant.status== StatusEnum.VALID.value))).where(
             (cls.model.id == kb_id),
             (cls.model.status == StatusEnum.VALID.value)
api/db/services/task_service.py CHANGED
@@ -27,7 +27,7 @@ class TaskService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_tasks(cls, tm, mod=0, comm=1, items_per_page=64):
-        fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
+        fields = [cls.model.id, cls.model.doc_id, cls.model.from_page,cls.model.to_page, Document.kb_id, Document.parser_id, Document.parser_config, Document.name, Document.type, Document.location, Document.size, Knowledgebase.tenant_id, Tenant.embd_id, Tenant.img2txt_id, Tenant.asr_id, cls.model.update_time]
         docs = cls.model.select(*fields) \
             .join(Document, on=(cls.model.doc_id == Document.id)) \
             .join(Knowledgebase, on=(Document.kb_id == Knowledgebase.id)) \
@@ -53,3 +53,13 @@ class TaskService(CommonService):
         except Exception as e:
             pass
         return True
+
+
+    @classmethod
+    @DB.connection_context()
+    def update_progress(cls, id, info):
+        cls.model.update(progress_msg=cls.model.progress_msg + "\n"+info["progress_msg"]).where(
+            cls.model.id == id).execute()
+        if "progress" in info:
+            cls.model.update(progress=info["progress"]).where(
+                cls.model.id == id).execute()
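
Note: TaskService.update_progress appends each progress_msg line to the task's accumulated log and only moves the progress value when one is supplied. A sketch of the intended call pattern (the task id and message here are illustrative):

# Append a status line and advance the bar; omit "progress" to log only.
TaskService.update_progress("task-id-123", {
    "progress": 0.5,
    "progress_msg": "Layout analysis finished.",
})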
api/db/services/user_service.py CHANGED
@@ -92,6 +92,12 @@ class TenantService(CommonService):
             .join(UserTenant, on=((cls.model.id == UserTenant.tenant_id) & (UserTenant.user_id==user_id) & (UserTenant.status == StatusEnum.VALID.value) & (UserTenant.role==UserTenantRole.NORMAL.value)))\
             .where(cls.model.status == StatusEnum.VALID.value).dicts())
 
+    @classmethod
+    @DB.connection_context()
+    def decrease(cls, user_id, num):
+        num = cls.model.update(credit=cls.model.credit - num).where(
+            cls.model.id == user_id).execute()
+        if num == 0: raise LookupError("Tenant not found which is supposed to be there")
 
 class UserTenantService(CommonService):
     model = UserTenant
api/errors/__init__.py DELETED
@@ -1,10 +0,0 @@
-from .general_error import *
-
-
-class RagFlowError(Exception):
-    message = 'Unknown Rag Flow Error'
-
-    def __init__(self, message=None, *args, **kwargs):
-        message = str(message) if message is not None else self.message
-        message = message.format(*args, **kwargs)
-        super().__init__(message)
api/errors/error_services.py DELETED
@@ -1,13 +0,0 @@
-from api.errors import RagFlowError
-
-__all__ = ['ServicesError', 'ServiceNotSupported', 'ZooKeeperNotConfigured',
-           'MissingZooKeeperUsernameOrPassword', 'ZooKeeperBackendError']
-
-
-class ServicesError(RagFlowError):
-    message = 'Unknown services error'
-
-
-class ServiceNotSupported(ServicesError):
-    message = 'The service {service_name} is not supported'
-
api/errors/general_error.py DELETED
@@ -1,21 +0,0 @@
-#
-#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-class ParameterError(Exception):
-    pass
-
-
-class PassError(Exception):
-    pass
api/hook/__init__.py DELETED
@@ -1,57 +0,0 @@
-import importlib
-
-from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, \
-    SignatureReturn, AuthenticationReturn, PermissionReturn, ClientAuthenticationReturn, ClientAuthenticationParameters
-from api.settings import HOOK_MODULE, stat_logger,RetCode
-
-
-class HookManager:
-    SITE_SIGNATURE = []
-    SITE_AUTHENTICATION = []
-    CLIENT_AUTHENTICATION = []
-    PERMISSION_CHECK = []
-
-    @staticmethod
-    def init():
-        if HOOK_MODULE is not None:
-            for modules in HOOK_MODULE.values():
-                for module in modules.split(";"):
-                    try:
-                        importlib.import_module(module)
-                    except Exception as e:
-                        stat_logger.exception(e)
-
-    @staticmethod
-    def register_site_signature_hook(func):
-        HookManager.SITE_SIGNATURE.append(func)
-
-    @staticmethod
-    def register_site_authentication_hook(func):
-        HookManager.SITE_AUTHENTICATION.append(func)
-
-    @staticmethod
-    def register_client_authentication_hook(func):
-        HookManager.CLIENT_AUTHENTICATION.append(func)
-
-    @staticmethod
-    def register_permission_check_hook(func):
-        HookManager.PERMISSION_CHECK.append(func)
-
-    @staticmethod
-    def client_authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
-        if HookManager.CLIENT_AUTHENTICATION:
-            return HookManager.CLIENT_AUTHENTICATION[0](parm)
-        return ClientAuthenticationReturn()
-
-    @staticmethod
-    def site_signature(parm: SignatureParameters) -> SignatureReturn:
-        if HookManager.SITE_SIGNATURE:
-            return HookManager.SITE_SIGNATURE[0](parm)
-        return SignatureReturn()
-
-    @staticmethod
-    def site_authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
-        if HookManager.SITE_AUTHENTICATION:
-            return HookManager.SITE_AUTHENTICATION[0](parm)
-        return AuthenticationReturn()
-
api/hook/api/client_authentication.py DELETED
@@ -1,29 +0,0 @@
-import requests
-
-from api.db.service_registry import ServiceRegistry
-from api.settings import RegistryServiceName
-from api.hook import HookManager
-from api.hook.common.parameters import ClientAuthenticationParameters, ClientAuthenticationReturn
-from api.settings import HOOK_SERVER_NAME
-
-
-@HookManager.register_client_authentication_hook
-def authentication(parm: ClientAuthenticationParameters) -> ClientAuthenticationReturn:
-    service_list = ServiceRegistry.load_service(
-        server_name=HOOK_SERVER_NAME,
-        service_name=RegistryServiceName.CLIENT_AUTHENTICATION.value
-    )
-    if not service_list:
-        raise Exception(f"client authentication error: no found server"
-                        f" {HOOK_SERVER_NAME} service client_authentication")
-    service = service_list[0]
-    response = getattr(requests, service.f_method.lower(), None)(
-        url=service.f_url,
-        json=parm.to_dict()
-    )
-    if response.status_code != 200:
-        raise Exception(
-            f"client authentication error: request authentication url failed, status code {response.status_code}")
-    elif response.json().get("code") != 0:
-        return ClientAuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
-    return ClientAuthenticationReturn()
api/hook/api/permission.py DELETED
@@ -1,25 +0,0 @@
-import requests
-
-from api.db.service_registry import ServiceRegistry
-from api.settings import RegistryServiceName
-from api.hook import HookManager
-from api.hook.common.parameters import PermissionCheckParameters, PermissionReturn
-from api.settings import HOOK_SERVER_NAME
-
-
-@HookManager.register_permission_check_hook
-def permission(parm: PermissionCheckParameters) -> PermissionReturn:
-    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.PERMISSION_CHECK.value)
-    if not service_list:
-        raise Exception(f"permission check error: no found server {HOOK_SERVER_NAME} service permission")
-    service = service_list[0]
-    response = getattr(requests, service.f_method.lower(), None)(
-        url=service.f_url,
-        json=parm.to_dict()
-    )
-    if response.status_code != 200:
-        raise Exception(
-            f"permission check error: request permission url failed, status code {response.status_code}")
-    elif response.json().get("code") != 0:
-        return PermissionReturn(code=response.json().get("code"), message=response.json().get("msg"))
-    return PermissionReturn()
api/hook/api/site_authentication.py DELETED
@@ -1,49 +0,0 @@
-import requests
-
-from api.db.service_registry import ServiceRegistry
-from api.settings import RegistryServiceName
-from api.hook import HookManager
-from api.hook.common.parameters import SignatureParameters, AuthenticationParameters, AuthenticationReturn,\
-    SignatureReturn
-from api.settings import HOOK_SERVER_NAME, PARTY_ID
-
-
-@HookManager.register_site_signature_hook
-def signature(parm: SignatureParameters) -> SignatureReturn:
-    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME, service_name=RegistryServiceName.SIGNATURE.value)
-    if not service_list:
-        raise Exception(f"signature error: no found server {HOOK_SERVER_NAME} service signature")
-    service = service_list[0]
-    response = getattr(requests, service.f_method.lower(), None)(
-        url=service.f_url,
-        json=parm.to_dict()
-    )
-    if response.status_code == 200:
-        if response.json().get("code") == 0:
-            return SignatureReturn(site_signature=response.json().get("data"))
-        else:
-            raise Exception(f"signature error: request signature url failed, result: {response.json()}")
-    else:
-        raise Exception(f"signature error: request signature url failed, status code {response.status_code}")
-
-
-@HookManager.register_site_authentication_hook
-def authentication(parm: AuthenticationParameters) -> AuthenticationReturn:
-    if not parm.src_party_id or str(parm.src_party_id) == "0":
-        parm.src_party_id = PARTY_ID
-    service_list = ServiceRegistry.load_service(server_name=HOOK_SERVER_NAME,
-                                                service_name=RegistryServiceName.SITE_AUTHENTICATION.value)
-    if not service_list:
-        raise Exception(
-            f"site authentication error: no found server {HOOK_SERVER_NAME} service site_authentication")
-    service = service_list[0]
-    response = getattr(requests, service.f_method.lower(), None)(
-        url=service.f_url,
-        json=parm.to_dict()
-    )
-    if response.status_code != 200:
-        raise Exception(
-            f"site authentication error: request site_authentication url failed, status code {response.status_code}")
-    elif response.json().get("code") != 0:
-        return AuthenticationReturn(code=response.json().get("code"), message=response.json().get("msg"))
-    return AuthenticationReturn()
api/hook/common/parameters.py DELETED
@@ -1,56 +0,0 @@
-from api.settings import RetCode
-
-
-class ParametersBase:
-    def to_dict(self):
-        d = {}
-        for k, v in self.__dict__.items():
-            d[k] = v
-        return d
-
-
-class ClientAuthenticationParameters(ParametersBase):
-    def __init__(self, full_path, headers, form, data, json):
-        self.full_path = full_path
-        self.headers = headers
-        self.form = form
-        self.data = data
-        self.json = json
-
-
-class ClientAuthenticationReturn(ParametersBase):
-    def __init__(self, code=RetCode.SUCCESS, message="success"):
-        self.code = code
-        self.message = message
-
-
-class SignatureParameters(ParametersBase):
-    def __init__(self, party_id, body):
-        self.party_id = party_id
-        self.body = body
-
-
-class SignatureReturn(ParametersBase):
-    def __init__(self, code=RetCode.SUCCESS, site_signature=None):
-        self.code = code
-        self.site_signature = site_signature
-
-
-class AuthenticationParameters(ParametersBase):
-    def __init__(self, site_signature, body):
-        self.site_signature = site_signature
-        self.body = body
-
-
-class AuthenticationReturn(ParametersBase):
-    def __init__(self, code=RetCode.SUCCESS, message="success"):
-        self.code = code
-        self.message = message
-
-
-class PermissionReturn(ParametersBase):
-    def __init__(self, code=RetCode.SUCCESS, message="success"):
-        self.code = code
-        self.message = message
-
api/ragflow_server.py CHANGED
@@ -20,12 +20,9 @@ import os
 import signal
 import sys
 import traceback
-
 from werkzeug.serving import run_simple
-
 from api.apps import app
 from api.db.runtime_config import RuntimeConfig
-from api.hook import HookManager
 from api.settings import (
     HOST, HTTP_PORT, access_logger, database_logger, stat_logger,
 )
@@ -60,8 +57,6 @@ if __name__ == '__main__':
     RuntimeConfig.init_env()
     RuntimeConfig.init_config(JOB_SERVER_HOST=HOST, HTTP_PORT=HTTP_PORT)
 
-    HookManager.init()
-
     peewee_logger = logging.getLogger('peewee')
     peewee_logger.propagate = False
     # rag_arch.common.log.ROpenHandler
api/settings.py CHANGED
@@ -47,7 +47,7 @@ LLM = get_base_config("llm", {})
 CHAT_MDL = LLM.get("chat_model", "gpt-3.5-turbo")
 EMBEDDING_MDL = LLM.get("embedding_model", "text-embedding-ada-002")
 ASR_MDL = LLM.get("asr_model", "whisper-1")
-PARSERS = LLM.get("parsers", "General,Resume,Laws,Product Instructions,Books,Paper,Q&A,Programming Code,Power Point,Research Report")
+PARSERS = LLM.get("parsers", "general:General,resume:esume,laws:Laws,manual:Manual,book:Book,paper:Paper,qa:Q&A,presentation:Presentation")
 IMAGE2TEXT_MDL = LLM.get("image2text_model", "gpt-4-vision-preview")
 
 # distribution
rag/app/book.py CHANGED
@@ -3,7 +3,7 @@ import random
 import re
 import numpy as np
 from rag.parser import bullets_category, BULLET_PATTERN, is_english, tokenize, remove_contents_table, \
-    hierarchical_merge, make_colon_as_title, naive_merge
+    hierarchical_merge, make_colon_as_title, naive_merge, random_choices
 from rag.nlp import huqie
 from rag.parser.docx_parser import HuDocxParser
 from rag.parser.pdf_parser import HuParser
@@ -51,7 +51,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
         doc_parser = HuDocxParser()
         # TODO: table of contents need to be removed
         sections, tbls = doc_parser(binary if binary else filename, from_page=from_page, to_page=to_page)
-        remove_contents_table(sections, eng=is_english(random.choices([t for t,_ in sections], k=200)))
+        remove_contents_table(sections, eng=is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     elif re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
@@ -67,20 +67,20 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
-        sections = txt.split("\n")
+        sections = txt.split("\n")
         sections = [(l,"") for l in sections if l]
-        remove_contents_table(sections, eng = is_english(random.choices([t for t,_ in sections], k=200)))
+        remove_contents_table(sections, eng = is_english(random_choices([t for t,_ in sections], k=200)))
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
 
     make_colon_as_title(sections)
-    bull = bullets_category([t for t in random.choices([t for t,_ in sections], k=100)])
+    bull = bullets_category([t for t in random_choices([t for t,_ in sections], k=100)])
     if bull >= 0: cks = hierarchical_merge(bull, sections, 3)
    else: cks = naive_merge(sections, kwargs.get("chunk_token_num", 256), kwargs.get("delimer", "\n。;!?"))
 
     sections = [t for t, _ in sections]
     # is it English
-    eng = is_english(random.choices(sections, k=218))
+    eng = is_english(random_choices(sections, k=218))
 
     res = []
     # add tables
rag/app/laws.py CHANGED
@@ -86,7 +86,8 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
-        sections = txt.split("\n")
+        sections = txt.split("\n")
+        sections = txt.split("\n")
         sections = [l for l in sections if l]
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
rag/app/naive.py CHANGED
@@ -52,7 +52,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
                 l = f.readline()
                 if not l:break
                 txt += l
-        sections = txt.split("\n")
+        sections = txt.split("\n")
         sections = [(l,"") for l in sections if l]
         callback(0.8, "Finish parsing.")
     else: raise NotImplementedError("file type not supported yet(docx, pdf, txt supported)")
rag/app/paper.py CHANGED
@@ -1,6 +1,9 @@
 import copy
 import re
 from collections import Counter
+
+from api.db import ParserType
+from rag.cv.ppdetection import PPDet
 from rag.parser import tokenize
 from rag.nlp import huqie
 from rag.parser.pdf_parser import HuParser
@@ -9,6 +12,10 @@ from rag.utils import num_tokens_from_string
 
 
 class Pdf(HuParser):
+    def __init__(self):
+        self.model_speciess = ParserType.PAPER.value
+        super().__init__()
+
     def __call__(self, filename, binary=None, from_page=0,
                  to_page=100000, zoomin=3, callback=None):
         self.__images__(
@@ -63,6 +70,15 @@ class Pdf(HuParser):
             "[0-9. 一、i]*(introduction|abstract|摘要|引言|keywords|key words|关键词|background|背景|目录|前言|contents)",
             txt.lower().strip())
 
+        if from_page > 0:
+            return {
+                "title":"",
+                "authors": "",
+                "abstract": "",
+                "lines": [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno", "")) for b in self.boxes[i:] if
+                          re.match(r"(text|title)", b.get("layoutno", "text"))],
+                "tables": tbls
+            }
         # get title and authors
         title = ""
         authors = []
@@ -115,18 +131,13 @@ class Pdf(HuParser):
 
 def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
     pdf_parser = None
-    paper = {}
-
     if re.search(r"\.pdf$", filename, re.IGNORECASE):
         pdf_parser = Pdf()
         paper = pdf_parser(filename if not binary else binary,
                            from_page=from_page, to_page=to_page, callback=callback)
     else: raise NotImplementedError("file type not supported yet(pdf supported)")
-    doc = {
-        "docnm_kwd": paper["title"] if paper["title"] else filename,
-        "authors_tks": paper["authors"]
-    }
-    doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
+    doc = {"docnm_kwd": filename, "authors_tks": paper["authors"],
+           "title_tks": huqie.qie(paper["title"] if paper["title"] else filename)}
     doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
     doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
     # is it English
rag/app/qa.py CHANGED
@@ -3,7 +3,7 @@ import re
 from io import BytesIO
 from nltk import word_tokenize
 from openpyxl import load_workbook
-from rag.parser import is_english
+from rag.parser import is_english, random_choices
 from rag.nlp import huqie, stemmer
 
 
@@ -33,9 +33,9 @@ class Excel(object):
                 if len(res) % 999 == 0:
                     callback(len(res)*0.6/total, ("Extract Q&A: {}".format(len(res)) + (f"{len(fails)} failure, line: %s..."%(",".join(fails[:3])) if fails else "")))
 
-        callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
+        callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
-        self.is_english = is_english([rmPrefix(q) for q, _ in random.choices(res, k=30) if len(q)>1])
+        self.is_english = is_english([rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q)>1])
         return res
 
rag/app/table.py ADDED
@@ -0,0 +1,170 @@
+import copy
+import random
+import re
+from io import BytesIO
+from xpinyin import Pinyin
+import numpy as np
+import pandas as pd
+from nltk import word_tokenize
+from openpyxl import load_workbook
+from dateutil.parser import parse as datetime_parse
+from rag.parser import is_english, tokenize
+from rag.nlp import huqie, stemmer
+
+
+class Excel(object):
+    def __call__(self, fnm, binary=None, callback=None):
+        if not binary:
+            wb = load_workbook(fnm)
+        else:
+            wb = load_workbook(BytesIO(binary))
+        total = 0
+        for sheetname in wb.sheetnames:
+            total += len(list(wb[sheetname].rows))
+
+        res, fails, done = [], [], 0
+        for sheetname in wb.sheetnames:
+            ws = wb[sheetname]
+            rows = list(ws.rows)
+            headers = [cell.value for cell in rows[0]]
+            missed = set([i for i,h in enumerate(headers) if h is None])
+            headers = [cell.value for i,cell in enumerate(rows[0]) if i not in missed]
+            data = []
+            for i, r in enumerate(rows[1:]):
+                row = [cell.value for ii,cell in enumerate(r) if ii not in missed]
+                if len(row) != len(headers):
+                    fails.append(str(i))
+                    continue
+                data.append(row)
+                done += 1
+                if done % 999 == 0:
+                    callback(done * 0.6/total, ("Extract records: {}".format(len(res)) + (f"{len(fails)} failure({sheetname}), line: %s..."%(",".join(fails[:3])) if fails else "")))
+            res.append(pd.DataFrame(np.array(data), columns=headers))
+
+        callback(0.6, ("Extract records: {}. ".format(done) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+        return res
+
+
+def trans_datatime(s):
+    try:
+        return datetime_parse(s.strip()).strftime("%Y-%m-%dT%H:%M:%S")
+    except Exception as e:
+        pass
+
+
+def trans_bool(s):
+    if re.match(r"(true|yes|是)$", str(s).strip(), flags=re.IGNORECASE): return ["yes", "是"]
+    if re.match(r"(false|no|否)$", str(s).strip(), flags=re.IGNORECASE): return ["no", "否"]
+
+
+def column_data_type(arr):
+    uni = len(set([a for a in arr if a is not None]))
+    counts = {"int": 0, "float": 0, "text": 0, "datetime": 0, "bool": 0}
+    trans = {t:f for f,t in [(int, "int"), (float, "float"), (trans_datatime, "datetime"), (trans_bool, "bool"), (str, "text")]}
+    for a in arr:
+        if a is None:continue
+        if re.match(r"[+-]?[0-9]+(\.0+)?$", str(a).replace("%%", "")):
+            counts["int"] += 1
+        elif re.match(r"[+-]?[0-9.]+$", str(a).replace("%%", "")):
+            counts["float"] += 1
+        elif re.match(r"(true|false|yes|no|是|否)$", str(a), flags=re.IGNORECASE):
+            counts["bool"] += 1
+        elif trans_datatime(str(a)):
+            counts["datetime"] += 1
+        else: counts["text"] += 1
+    counts = sorted(counts.items(), key=lambda x: x[1]*-1)
+    ty = counts[0][0]
+    for i in range(len(arr)):
+        if arr[i] is None:continue
+        try:
+            arr[i] = trans[ty](str(arr[i]))
+        except Exception as e:
+            arr[i] = None
+    if ty == "text":
+        if len(arr) > 128 and uni/len(arr) < 0.1:
+            ty = "keyword"
+    return arr, ty
+
+
+def chunk(filename, binary=None, callback=None, **kwargs):
+    dfs = []
+    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
+        callback(0.1, "Start to parse.")
+        excel_parser = Excel()
+        dfs = excel_parser(filename, binary, callback)
+    elif re.search(r"\.(txt|csv)$", filename, re.IGNORECASE):
+        callback(0.1, "Start to parse.")
+        txt = ""
+        if binary:
+            txt = binary.decode("utf-8")
+        else:
+            with open(filename, "r") as f:
+                while True:
+                    l = f.readline()
+                    if not l: break
+                    txt += l
+        lines = txt.split("\n")
+        fails = []
+        headers = lines[0].split(kwargs.get("delimiter", "\t"))
+        rows = []
+        for i, line in enumerate(lines[1:]):
+            row = [l for l in line.split(kwargs.get("delimiter", "\t"))]
+            if len(row) != len(headers):
+                fails.append(str(i))
+                continue
+            rows.append(row)
+            if len(rows) % 999 == 0:
+                callback(len(rows) * 0.6 / len(lines), ("Extract records: {}".format(len(rows)) + (
+                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+
+        callback(0.6, ("Extract records: {}".format(len(rows)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+
+        dfs = [pd.DataFrame(np.array(rows), columns=headers)]
+
+    else: raise NotImplementedError("file type not supported yet(excel, text, csv supported)")
+
+    res = []
+    PY = Pinyin()
+    fieds_map = {"text": "_tks", "int": "_int", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"}
+    for df in dfs:
+        for n in ["id", "_id", "index", "idx"]:
+            if n in df.columns:del df[n]
+        clmns = df.columns.values
+        txts = list(copy.deepcopy(clmns))
+        py_clmns = [PY.get_pinyins(n)[0].replace("-", "_") for n in clmns]
+        clmn_tys = []
+        for j in range(len(clmns)):
+            cln,ty = column_data_type(df[clmns[j]])
+            clmn_tys.append(ty)
+            df[clmns[j]] = cln
+            if ty == "text": txts.extend([str(c) for c in cln if c])
+        clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j]) for i in range(len(clmns))]
+        # TODO: set this column map to KB parser configuration
+
+        eng = is_english(txts)
+        for ii,row in df.iterrows():
+            d = {}
+            row_txt = []
+            for j in range(len(clmns)):
+                if row[clmns[j]] is None:continue
+                fld = clmns_map[j][0]
+                d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else huqie.qie(row[clmns[j]])
+                row_txt.append("{}:{}".format(clmns[j], row[clmns[j]]))
+            if not row_txt:continue
+            tokenize(d, "; ".join(row_txt), eng)
+            print(d)
+            res.append(d)
+        callback(0.6, "")
+
+    return res
+
+
+
+if __name__== "__main__":
+    import sys
+    def dummy(a, b):
+        pass
+    chunk(sys.argv[1], callback=dummy)
+
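
Note on the new 'table' method: each spreadsheet/CSV row becomes one retrievable chunk, with columns mapped to typed fields (pinyin of the header plus a suffix from fieds_map) and the whole row flattened into "header:value" text. A minimal driver, mirroring the __main__ block above; the file path and comma delimiter are illustrative assumptions:

from rag.app import table

def progress(prog, msg):
    # chunk() reports parsing progress through this callback
    pass

# "people.csv" is a hypothetical local file; the delimiter defaults to "\t"
chunks = table.chunk("people.csv", callback=progress, delimiter=",")
# Each chunk carries typed fields such as "<pinyin_of_header>_int" or
# "<pinyin_of_header>_tks", plus the flattened row text added by tokenize().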
rag/nlp/search.py CHANGED
@@ -67,7 +67,7 @@ class Dealer:
         ps = int(req.get("size", 1000))
         src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id",
                                  "image_id", "doc_id", "q_512_vec", "q_768_vec",
-                                 "q_1024_vec", "q_1536_vec", "available_int"])
+                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight"])
 
         s = s.query(bqry)[pg * ps:(pg + 1) * ps]
         s = s.highlight("content_ltks")
@@ -234,7 +234,7 @@ class Dealer:
             sres.field[i].get("q_%d_vec" % len(sres.query_vector), "\t".join(["0"] * len(sres.query_vector)))) for i in sres.ids]
         if not ins_embd:
             return [], [], []
-        ins_tw = [huqie.qie(sres.field[i][cfield]).split(" ")
+        ins_tw = [sres.field[i][cfield].split(" ")
                   for i in sres.ids]
         sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
                                                         ins_embd,
@@ -281,6 +281,7 @@ class Dealer:
             d = {
                 "chunk_id": id,
                 "content_ltks": sres.field[id]["content_ltks"],
+                "content_with_weight": sres.field[id]["content_with_weight"],
                 "doc_id": sres.field[id]["doc_id"],
                 "docnm_kwd": dnm,
                 "kb_id": sres.field[id]["kb_id"],
rag/parser/__init__.py CHANGED
@@ -1,4 +1,5 @@
 import copy
+import random
 
 from .pdf_parser import HuParser as PdfParser
 from .docx_parser import HuDocxParser as DocxParser
@@ -38,6 +39,9 @@ BULLET_PATTERN = [[
     ]
 ]
 
+def random_choices(arr, k):
+    k = min(len(arr), k)
+    return random.choices(arr, k=k)
 
 def bullets_category(sections):
     global BULLET_PATTERN
rag/parser/pdf_parser.py CHANGED
@@ -1,7 +1,10 @@
 # -*- coding: utf-8 -*-
+import os
 import random
+from functools import partial
 
 import fitz
+import requests
 import xgboost as xgb
 from io import BytesIO
 import torch
@@ -10,13 +13,14 @@ import pdfplumber
 import logging
 from PIL import Image
 import numpy as np
+
+from api.db import ParserType
 from rag.nlp import huqie
 from collections import Counter
 from copy import deepcopy
-from rag.cv.table_recognize import TableTransformer
-from rag.cv.ppdetection import PPDet
 from huggingface_hub import hf_hub_download
 
+
 logging.getLogger("pdfminer").setLevel(logging.WARNING)
 
@@ -25,8 +29,10 @@ class HuParser:
         from paddleocr import PaddleOCR
         logging.getLogger("ppocr").setLevel(logging.ERROR)
         self.ocr = PaddleOCR(use_angle_cls=False, lang="ch")
-        self.layouter = PPDet("/data/newpeak/medical-gpt/res/ppdet")
-        self.tbl_det = PPDet("/data/newpeak/medical-gpt/res/ppdet.tbl")
+        if not hasattr(self, "model_speciess"):
+            self.model_speciess = ParserType.GENERAL.value
+        self.layouter = partial(self.__remote_call, self.model_speciess)
+        self.tbl_det = partial(self.__remote_call, "table_component")
 
         self.updown_cnt_mdl = xgb.Booster()
         if torch.cuda.is_available():
@@ -45,6 +51,38 @@ class HuParser:
 
         """
 
+    def __remote_call(self, species, images, thr=0.7):
+        url = os.environ.get("INFINIFLOW_SERVER")
+        if not url: raise EnvironmentError("Please set environment variable: 'INFINIFLOW_SERVER'")
+        token = os.environ.get("INFINIFLOW_TOKEN")
+        if not token: raise EnvironmentError("Please set environment variable: 'INFINIFLOW_TOKEN'")
+
+        def convert_image_to_bytes(PILimage):
+            image = BytesIO()
+            PILimage.save(image, format='png')
+            image.seek(0)
+            return image.getvalue()
+
+        images = [convert_image_to_bytes(img) for img in images]
+
+        def remote_call():
+            nonlocal images, thr
+            res = requests.post(url + "/v1/layout/detect/" + species, files=[("image", img) for img in images],
+                                data={"threashold": thr}, headers={"Authorization": token}, timeout=len(images) * 10)
+            res = res.json()
+            if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
+            return res["data"]
+
+        for _ in range(3):
+            try:
+                return remote_call()
+            except RuntimeError as e:
+                raise e
+            except Exception as e:
+                logging.error("layout_predict:" + str(e))
+        return remote_call()
+
+
     def __char_width(self, c):
         return (c["x1"] - c["x0"]) // len(c["text"])
 
@@ -344,7 +382,7 @@
         return layouts
 
     def __table_paddle(self, images):
-        tbls = self.tbl_det([np.array(img) for img in images], thr=0.5)
+        tbls = self.tbl_det(images, thr=0.5)
         res = []
         # align left&right for rows, align top&bottom for columns
         for tbl in tbls:
@@ -522,7 +560,7 @@
         assert len(self.page_images) == len(self.boxes)
         # Tag layout type
         boxes = []
-        layouts = self.layouter([np.array(img) for img in self.page_images])
+        layouts = self.layouter(self.page_images)
         assert len(self.page_images) == len(layouts)
         for pn, lts in enumerate(layouts):
             bxs = self.boxes[pn]
@@ -1705,7 +1743,8 @@
             self.__ocr_paddle(i + 1, img, chars, zoomin)
 
         if not self.is_english and not any([c for c in self.page_chars]) and self.boxes:
-            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices([b for bxs in self.boxes for b in bxs], k=30)]))
+            bxes = [b for bxs in self.boxes for b in bxs]
+            self.is_english = re.search(r"[\na-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join([b["text"] for b in random.choices(bxes, k=min(30, len(bxes)))]))
 
         logging.info("Is it English:", self.is_english)
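Net effect of this file: layout and table detection no longer load local PPDet checkpoints but call a remote service addressed by INFINIFLOW_SERVER/INFINIFLOW_TOKEN, retrying transport failures up to three times while re-raising server-reported errors immediately, so a bad token fails fast while flaky networking gets another chance. A hedged standalone sketch of the same request, assuming `ParserType.GENERAL.value` is the string "general"; the multipart field name, the "threashold" spelling, and the retcode/retmsg/data response shape are taken verbatim from the diff:

import os
import requests

# Hypothetical one-page call mirroring __remote_call; not an official client.
url = os.environ["INFINIFLOW_SERVER"] + "/v1/layout/detect/general"
with open("page-1.png", "rb") as f:        # a rendered PDF page, for illustration
    png_bytes = f.read()
res = requests.post(url,
                    files=[("image", png_bytes)],
                    data={"threashold": 0.7},               # field name spelled as in the diff
                    headers={"Authorization": os.environ["INFINIFLOW_TOKEN"]},
                    timeout=10).json()
if res["retcode"] != 0:
    raise RuntimeError(res["retmsg"])
layouts = res["data"]                                       # one box list per submitted image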
 
rag/svr/task_broker.py CHANGED
@@ -134,5 +134,5 @@ if __name__ == "__main__":
 
     while True:
         dispatch()
-        time.sleep(3)
+        time.sleep(1)
         update_progress()
rag/svr/task_executor.py CHANGED
@@ -36,7 +36,7 @@ from rag.nlp import search
 from io import BytesIO
 import pandas as pd
 
-from rag.app import laws, paper, presentation, manual, qa
+from rag.app import laws, paper, presentation, manual, qa, table, book
 
 from api.db import LLMType, ParserType
 from api.db.services.document_service import DocumentService
@@ -49,10 +49,12 @@ BATCH_SIZE = 64
 FACTORY = {
     ParserType.GENERAL.value: laws,
     ParserType.PAPER.value: paper,
+    ParserType.BOOK.value: book,
     ParserType.PRESENTATION.value: presentation,
     ParserType.MANUAL.value: manual,
     ParserType.LAWS.value: laws,
     ParserType.QA.value: qa,
+    ParserType.TABLE.value: table,
 }
 
@@ -66,7 +68,7 @@ def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
     d = {"progress_msg": msg}
     if prog is not None: d["progress"] = prog
     try:
-        TaskService.update_by_id(task_id, d)
+        TaskService.update_progress(task_id, d)
     except Exception as e:
         cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
 
@@ -113,7 +115,7 @@ def build(row, cvmdl):
         return []
 
     callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
-    chunker = FACTORY[row["parser_id"]]
+    chunker = FACTORY[row["parser_id"].lower()]
     try:
         cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
         cks = chunker.chunk(row["name"], MINIO.get(row["kb_id"], row["location"]), row["from_page"], row["to_page"],
@@ -154,6 +156,7 @@
 
         MINIO.put(row["kb_id"], d["_id"], output_buffer.getvalue())
         d["img_id"] = "{}-{}".format(row["kb_id"], d["_id"])
+        del d["image"]
         docs.append(d)
 
     return docs
@@ -168,7 +171,7 @@ def init_kb(row):
 
 
 def embedding(docs, mdl):
-    tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs]
+    tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
     tk_count = 0
     if len(tts) == len(cnts):
         tts, c = mdl.encode(tts)
@@ -207,6 +210,7 @@ def main(comm, mod):
         cks = build(r, cv_mdl)
         if not cks:
             tmf.write(str(r["update_time"]) + "\n")
+            callback(1., "No chunk! Done!")
             continue
         # TODO: exception handler
         ## set_progress(r["did"], -1, "ERROR: ")
@@ -215,7 +219,6 @@ def main(comm, mod):
         except Exception as e:
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
-            continue
 
         callback(msg="Finished embedding! Start to build index!")
         init_kb(r)
@@ -227,6 +230,7 @@ def main(comm, mod):
         else:
             if TaskService.do_cancel(r["id"]):
                 ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
+                continue
             callback(1., "Done!")
             DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
             cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))