KevinHuSh committed
Commit 4c52eb9 · 1 Parent(s): 3772f42

refine admin initialization (#75)

api/apps/chunk_app.py CHANGED
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
 from elasticsearch_dsl import Q
 
 from rag.app.qa import rmPrefix, beAdoc
-from rag.nlp import search, huqie, retrievaler
+from rag.nlp import search, huqie
 from rag.utils import ELASTICSEARCH, rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import UserTenantService
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.db.services.document_service import DocumentService
-from api.settings import RetCode
+from api.settings import RetCode, retrievaler
 from api.utils.api_utils import get_json_result
 import hashlib
 import re

api/apps/conversation_app.py CHANGED
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
 from api.db import LLMType
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, LLMBundle
-from api.settings import access_logger, stat_logger
+from api.settings import access_logger, stat_logger, retrievaler
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
 from rag.app.resume import forbidden_select_fields4resume
-from rag.llm import ChatModel
-from rag.nlp import retrievaler
 from rag.nlp.search import index_name
 from rag.utils import num_tokens_from_string, encoder, rmSpace
 
api/db/init_data.py CHANGED
@@ -16,10 +16,12 @@
 import time
 import uuid
 
-from api.db import LLMType
+from api.db import LLMType, UserTenantRole
 from api.db.db_models import init_database_tables as init_web_db
 from api.db.services import UserService
-from api.db.services.llm_service import LLMFactoriesService, LLMService
+from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
+from api.db.services.user_service import TenantService, UserTenantService
+from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
 
 
 def init_superuser():
@@ -32,8 +34,44 @@ def init_superuser():
         "creator": "system",
         "status": "1",
     }
+    tenant = {
+        "id": user_info["id"],
+        "name": user_info["nickname"] + "‘s Kingdom",
+        "llm_id": CHAT_MDL,
+        "embd_id": EMBEDDING_MDL,
+        "asr_id": ASR_MDL,
+        "parser_ids": PARSERS,
+        "img2txt_id": IMAGE2TEXT_MDL
+    }
+    usr_tenant = {
+        "tenant_id": user_info["id"],
+        "user_id": user_info["id"],
+        "invited_by": user_info["id"],
+        "role": UserTenantRole.OWNER
+    }
+    tenant_llm = []
+    for llm in LLMService.query(fid=LLM_FACTORY):
+        tenant_llm.append(
+            {"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name,
+             "model_type": llm.model_type, "api_key": API_KEY})
+
+    if not UserService.save(**user_info):
+        print("【ERROR】can't init admin.")
+        return
+    TenantService.save(**tenant)
+    UserTenantService.save(**usr_tenant)
+    TenantLLMService.insert_many(tenant_llm)
     UserService.save(**user_info)
 
+    chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
+    msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
+    if msg.find("ERROR: ") == 0:
+        print("【ERROR】: '{}' doesn't work. {}".format(tenant["llm_id"], msg))
+    embd_mdl = LLMBundle(tenant["id"], LLMType.EMBEDDING, tenant["embd_id"])
+    v, c = embd_mdl.encode(["Hello!"])
+    if c == 0:
+        print("【ERROR】: '{}' doesn't work...".format(tenant["embd_id"]))
+
 
 def init_llm_factory():
     factory_infos = [{
@@ -171,10 +209,10 @@ def init_llm_factory():
 
 def init_web_data():
     start_time = time.time()
-    if not UserService.get_all().count():
-        init_superuser()
 
     if not LLMService.get_all().count():init_llm_factory()
+    if not UserService.get_all().count():
+        init_superuser()
 
     print("init web data success:{}".format(time.time() - start_time))
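
The smoke test added above reads LLMBundle.encode as returning a (vectors, token_count) pair, with a zero token count signalling a failed embedding call. A minimal, self-contained sketch of that check in isolation (the helper and the stub model are hypothetical, not part of the commit):

```python
import numpy as np

class StubEmbedder:
    """Hypothetical stand-in for LLMBundle in embedding mode."""
    def encode(self, texts):
        vecs = np.zeros((len(texts), 8))   # fake vectors
        return vecs, 3                      # 3 tokens consumed -> call succeeded

def embedding_works(embd_mdl):
    # Mirrors the check in init_superuser(): a zero token count means the
    # configured embedding model could not be reached or returned nothing.
    v, c = embd_mdl.encode(["Hello!"])
    return c > 0

print(embedding_works(StubEmbedder()))      # True
```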
 
api/settings.py CHANGED
@@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
 from api.utils.file_utils import get_project_base_directory
 from api.utils.log_utils import LoggerFactory, getLogger
 
+from rag.nlp import search
+from rag.utils import ELASTICSEARCH
+
 
-# Server
 API_VERSION = "v1"
 RAG_FLOW_SERVICE_NAME = "ragflow"
 SERVER_MODULE = "rag_flow_server.py"
@@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
 PRIVILEGE_COMMAND_WHITELIST = []
 CHECK_NODES_IDENTITY = False
 
+retrievaler = search.Dealer(ELASTICSEARCH)
+
 class CustomEnum(Enum):
     @classmethod
     def valid(cls, value):

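The import shuffling in chunk_app.py, conversation_app.py, rag/nlp/__init__.py and here means the Elasticsearch-backed Dealer is now built once, when api.settings is first imported, and shared by every consumer. A self-contained sketch of why a module-level instance behaves as a singleton (the module and attribute below are stand-ins, not real code):

```python
# Python caches imported modules in sys.modules, so the attribute is created
# once and every "from api.settings import retrievaler" sees the same object.
import sys, types

fake_settings = types.ModuleType("fake_settings")
fake_settings.retrievaler = object()         # stand-in for search.Dealer(ELASTICSEARCH)
sys.modules["fake_settings"] = fake_settings

import fake_settings as settings_a           # resolved from the module cache
import fake_settings as settings_b           # no re-construction on later imports
print(settings_a.retrievaler is settings_b.retrievaler)   # True
```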
deepdoc/parser/pdf_parser.py CHANGED
@@ -230,7 +230,7 @@ class HuParser:
                 b["H_right"] = headers[ii]["x1"]
                 b["H"] = ii
 
-            ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+            ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
             if ii is not None:
                 b["C"] = ii
                 b["C_left"] = clmns[ii]["x0"]

deepdoc/vision/layout_recognizer.py CHANGED
@@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
         super().__init__(self.labels, domain,
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
-    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
+    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
         def __is_garbage(b):
             patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                     r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",

deepdoc/vision/postprocess.py CHANGED
@@ -2,7 +2,6 @@ import copy
 
 import numpy as np
 import cv2
-import paddle
 from shapely.geometry import Polygon
 import pyclipper
 
@@ -215,7 +214,7 @@ class DBPostProcess(object):
 
     def __call__(self, outs_dict, shape_list):
        pred = outs_dict['maps']
-        if isinstance(pred, paddle.Tensor):
+        if not isinstance(pred, np.ndarray):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh
@@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            preds = preds[-1]
-        if isinstance(preds, paddle.Tensor):
+        if not isinstance(preds, np.ndarray):
            preds = preds.numpy()
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)

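Dropping `import paddle` means the guard can no longer test for `paddle.Tensor`; the new check simply converts anything that is not already a numpy array via its `.numpy()` method, which any framework tensor exposes. A minimal sketch of that duck-typed guard in isolation (the stub tensor class is hypothetical):

```python
import numpy as np

class FakeTensor:
    """Stand-in for a framework tensor that exposes .numpy()."""
    def __init__(self, arr): self._arr = arr
    def numpy(self): return self._arr

def to_ndarray(pred):
    # Same guard as the new postprocess code: convert only when needed.
    if not isinstance(pred, np.ndarray):
        pred = pred.numpy()
    return pred

print(to_ndarray(np.zeros((1, 1, 2, 2))).shape)              # already numpy -> unchanged
print(to_ndarray(FakeTensor(np.ones((1, 1, 2, 2)))).shape)   # tensor-like -> converted
```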
deepdoc/vision/recognizer.py CHANGED
@@ -259,6 +259,18 @@ class Recognizer(object):
 
         return max_overlaped_i
 
+    @staticmethod
+    def find_horizontally_tightest_fit(box, boxes):
+        if not boxes:
+            return
+        min_dis, min_i = 1000000, None
+        for i, b in enumerate(boxes):
+            dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"] + box["x1"] - b["x1"] - b["x0"]) / 2)
+            if dis < min_dis:
+                min_i = i
+                min_dis = dis
+        return min_i
+
     @staticmethod
     def find_overlapped_with_threashold(box, boxes, thr=0.3):
         if not boxes:

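The new helper replaces the overlap-threshold lookup used for column assignment in pdf_parser.py and t_recognizer.py: instead of returning nothing when a cell barely overlaps any column, it always picks the horizontally closest one. A small worked example of the same scoring, with made-up coordinates (the standalone function below mirrors the static method, it is not the real API):

```python
def tightest_fit(box, boxes):
    # Same scoring as Recognizer.find_horizontally_tightest_fit: the candidate
    # minimising the smallest of left-edge, right-edge and centre distances wins.
    best_i, best_d = None, float("inf")
    for i, b in enumerate(boxes):
        d = min(abs(box["x0"] - b["x0"]),
                abs(box["x1"] - b["x1"]),
                abs(box["x0"] + box["x1"] - b["x0"] - b["x1"]) / 2)
        if d < best_d:
            best_i, best_d = i, d
    return best_i

cell = {"x0": 105, "x1": 195}
columns = [{"x0": 0, "x1": 100}, {"x0": 100, "x1": 200}, {"x0": 200, "x1": 300}]
print(tightest_fit(cell, columns))   # 1 -> the cell snaps to the middle column
```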
deepdoc/vision/t_recognizer.py CHANGED
@@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
     clmns = sorted([r for r in tb_cpns if re.match(
         r"table column$", r["label"])], key=lambda x: x["x0"])
     clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
+
     for b in boxes:
         ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
         if ii is not None:
@@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_right"] = headers[ii]["x1"]
             b["H"] = ii
 
-        ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
+        ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
         if ii is not None:
             b["C"] = ii
             b["C_left"] = clmns[ii]["x0"]
@@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
             b["H_left"] = spans[ii]["x0"]
             b["H_right"] = spans[ii]["x1"]
             b["SP"] = ii
+
     html = """
     <html>
     <head>

deepdoc/vision/table_structure_recognizer.py CHANGED
@@ -14,7 +14,6 @@ import logging
 import os
 import re
 from collections import Counter
-from copy import deepcopy
 
 import numpy as np
 
@@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
         super().__init__(self.labels, "tsr",
                          os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
 
-    def __call__(self, images, thr=0.5):
+    def __call__(self, images, thr=0.2):
         tbls = super().__call__(images, thr)
         res = []
         # align left&right for rows, align top&bottom for columns
@@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
                 "row") > 0 or b["label"].find("header") > 0]
             if not left:
                 continue
-            left = np.median(left) if len(left) > 4 else np.min(left)
-            right = np.median(right) if len(right) > 4 else np.max(right)
+            left = np.mean(left) if len(left) > 4 else np.min(left)
+            right = np.mean(right) if len(right) > 4 else np.max(right)
             for b in lts:
                 if b["label"].find("row") > 0 or b["label"].find("header") > 0:
                     if b["x0"] > left:
@@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
         i = 0
         while i < len(boxes):
             if TableStructureRecognizer.is_caption(boxes[i]):
+                if is_english: cap += " "
                 cap += boxes[i]["text"]
                 boxes.pop(i)
                 i -= 1
@@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
             for i in range(clmno):
                 if not tbl[r][i]:
                     continue
-                txt = "".join([a["text"].strip() for a in tbl[r][i]])
+                txt = " ".join([a["text"].strip() for a in tbl[r][i]])
                 headers[r][i] = txt
                 hdrset.add(txt)
             if all([not t for t in headers[r]]):

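For the row-boundary alignment, the functional change is the statistic used when more than four row boxes are present. A small numeric illustration of how the two differ on a skewed set of left edges (the values are invented):

```python
import numpy as np

left_edges = [50.1, 49.8, 50.4, 50.2, 120.0]   # one row box detected far too narrow
print(float(np.median(left_edges)))            # 50.2 -> the outlier is ignored
print(float(np.mean(left_edges)))              # 64.1 -> the outlier pulls the boundary right
# The loop that follows snaps rows starting right of this value back to it, so the
# chosen statistic decides how much one bad detection shifts the shared edge.
```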
rag/llm/chat_model.py CHANGED
@@ -15,7 +15,7 @@
 #
 from abc import ABC
 from openai import OpenAI
-import os
+import openai
 
 
 class Base(ABC):
@@ -33,11 +33,14 @@ class GptTurbo(Base):
 
     def chat(self, system, history, gen_conf):
         if system: history.insert(0, {"role": "system", "content": system})
-        res = self.client.chat.completions.create(
-            model=self.model_name,
-            messages=history,
-            **gen_conf)
-        return res.choices[0].message.content.strip(), res.usage.completion_tokens
+        try:
+            res = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            return res.choices[0].message.content.strip(), res.usage.completion_tokens
+        except openai.APIError as e:
+            return "ERROR: " + str(e), 0
 
 
 from dashscope import Generation
@@ -58,7 +61,7 @@ class QWenChat(Base):
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.output_tokens
-        return response.message, 0
+        return "ERROR: " + response.message, 0
 
 
 from zhipuai import ZhipuAI
@@ -77,4 +80,4 @@ class ZhipuChat(Base):
         )
         if response.status_code == HTTPStatus.OK:
             return response.output.choices[0]['message']['content'], response.usage.completion_tokens
-        return response.message, 0
+        return "ERROR: " + response.message, 0

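The chat wrappers now report failure in-band: the text half of the returned (text, token_count) pair starts with "ERROR: " instead of an exception escaping. A minimal sketch of a caller honouring that convention, matching the check init_superuser() performs (the stub class and helper below are stand-ins, not part of the commit):

```python
class StubChat:
    """Hypothetical stand-in that fails the way the wrappers above do."""
    def chat(self, system, history, gen_conf):
        return "ERROR: connection refused", 0

def safe_chat(chat_mdl, question):
    txt, used_tokens = chat_mdl.chat(
        system="", history=[{"role": "user", "content": question}], gen_conf={})
    if txt.find("ERROR: ") == 0:        # same in-band failure test as init_superuser()
        raise RuntimeError(txt)
    return txt, used_tokens

try:
    safe_chat(StubChat(), "Hello!")
except RuntimeError as e:
    print(e)                            # ERROR: connection refused
```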
rag/nlp/__init__.py CHANGED
@@ -1,7 +1,4 @@
-from . import search
-from rag.utils import ELASTICSEARCH
 
-retrievaler = search.Dealer(ELASTICSEARCH)
 
 from nltk.stem import PorterStemmer
 stemmer = PorterStemmer()
@@ -39,10 +36,12 @@ BULLET_PATTERN = [[
     ]
 ]
 
+
 def random_choices(arr, k):
     k = min(len(arr), k)
     return random.choices(arr, k=k)
 
+
 def bullets_category(sections):
     global BULLET_PATTERN
     hits = [0] * len(BULLET_PATTERN)

rag/nlp/search.py CHANGED
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import json
 import re
-from elasticsearch_dsl import Q, Search, A
+from elasticsearch_dsl import Q, Search
 from typing import List, Optional, Dict, Union
 from dataclasses import dataclass
 
@@ -183,6 +183,7 @@ class Dealer:
 
     def insert_citations(self, answer, chunks, chunk_v,
                          embd_mdl, tkweight=0.3, vtweight=0.7):
+        assert len(chunks) == len(chunk_v)
         pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
         for i in range(1, len(pieces)):
             if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
@@ -216,7 +217,7 @@ class Dealer:
             if mx < 0.55:
                 continue
             cites[idx[i]] = list(
-                set([str(i) for i in range(len(chunk_v)) if sim[i] > mx]))[:4]
+                set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
 
         res = ""
         for i, p in enumerate(pieces):
@@ -225,6 +226,7 @@ class Dealer:
                 continue
             if i not in cites:
                 continue
+            assert all(int(c) < len(chunk_v) for c in cites[i])
             res += "##%s$$" % "$".join(cites[i])
 
         return res