KevinHuSh
committed on
Commit
·
4c52eb9
1
Parent(s):
3772f42
refine admin initialization (#75)
Browse files
- api/apps/chunk_app.py +2 -2
- api/apps/conversation_app.py +1 -3
- api/db/init_data.py +42 -4
- api/settings.py +5 -1
- deepdoc/parser/pdf_parser.py +1 -1
- deepdoc/vision/layout_recognizer.py +1 -1
- deepdoc/vision/postprocess.py +2 -3
- deepdoc/vision/recognizer.py +12 -0
- deepdoc/vision/t_recognizer.py +3 -1
- deepdoc/vision/table_structure_recognizer.py +5 -5
- rag/llm/chat_model.py +11 -8
- rag/nlp/__init__.py +2 -3
- rag/nlp/search.py +4 -2
api/apps/chunk_app.py
CHANGED
@@ -20,7 +20,7 @@ from flask_login import login_required, current_user
|
|
20 |
from elasticsearch_dsl import Q
|
21 |
|
22 |
from rag.app.qa import rmPrefix, beAdoc
|
23 |
-
from rag.nlp import search, huqie
|
24 |
from rag.utils import ELASTICSEARCH, rmSpace
|
25 |
from api.db import LLMType, ParserType
|
26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
@@ -28,7 +28,7 @@ from api.db.services.llm_service import TenantLLMService
|
|
28 |
from api.db.services.user_service import UserTenantService
|
29 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
30 |
from api.db.services.document_service import DocumentService
|
31 |
-
from api.settings import RetCode
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
import hashlib
|
34 |
import re
|
|
|
20 |
from elasticsearch_dsl import Q
|
21 |
|
22 |
from rag.app.qa import rmPrefix, beAdoc
|
23 |
+
from rag.nlp import search, huqie
|
24 |
from rag.utils import ELASTICSEARCH, rmSpace
|
25 |
from api.db import LLMType, ParserType
|
26 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
28 |
from api.db.services.user_service import UserTenantService
|
29 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
30 |
from api.db.services.document_service import DocumentService
|
31 |
+
from api.settings import RetCode, retrievaler
|
32 |
from api.utils.api_utils import get_json_result
|
33 |
import hashlib
|
34 |
import re
|
api/apps/conversation_app.py
CHANGED
@@ -21,13 +21,11 @@ from api.db.services.dialog_service import DialogService, ConversationService
|
|
21 |
from api.db import LLMType
|
22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
24 |
-
from api.settings import access_logger, stat_logger
|
25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
26 |
from api.utils import get_uuid
|
27 |
from api.utils.api_utils import get_json_result
|
28 |
from rag.app.resume import forbidden_select_fields4resume
|
29 |
-
from rag.llm import ChatModel
|
30 |
-
from rag.nlp import retrievaler
|
31 |
from rag.nlp.search import index_name
|
32 |
from rag.utils import num_tokens_from_string, encoder, rmSpace
|
33 |
|
|
|
21 |
from api.db import LLMType
|
22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
23 |
from api.db.services.llm_service import LLMService, LLMBundle
|
24 |
+
from api.settings import access_logger, stat_logger, retrievaler
|
25 |
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
26 |
from api.utils import get_uuid
|
27 |
from api.utils.api_utils import get_json_result
|
28 |
from rag.app.resume import forbidden_select_fields4resume
|
|
|
|
|
29 |
from rag.nlp.search import index_name
|
30 |
from rag.utils import num_tokens_from_string, encoder, rmSpace
|
31 |
|
api/db/init_data.py
CHANGED
@@ -16,10 +16,12 @@
|
|
16 |
import time
|
17 |
import uuid
|
18 |
|
19 |
-
from api.db import LLMType
|
20 |
from api.db.db_models import init_database_tables as init_web_db
|
21 |
from api.db.services import UserService
|
22 |
-
from api.db.services.llm_service import LLMFactoriesService, LLMService
|
|
|
|
|
23 |
|
24 |
|
25 |
def init_superuser():
|
@@ -32,8 +34,44 @@ def init_superuser():
|
|
32 |
"creator": "system",
|
33 |
"status": "1",
|
34 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
UserService.save(**user_info)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
def init_llm_factory():
|
39 |
factory_infos = [{
|
@@ -171,10 +209,10 @@ def init_llm_factory():
|
|
171 |
|
172 |
def init_web_data():
|
173 |
start_time = time.time()
|
174 |
-
if not UserService.get_all().count():
|
175 |
-
init_superuser()
|
176 |
|
177 |
if not LLMService.get_all().count():init_llm_factory()
|
|
|
|
|
178 |
|
179 |
print("init web data success:{}".format(time.time() - start_time))
|
180 |
|
|
|
16 |
import time
|
17 |
import uuid
|
18 |
|
19 |
+
from api.db import LLMType, UserTenantRole
|
20 |
from api.db.db_models import init_database_tables as init_web_db
|
21 |
from api.db.services import UserService
|
22 |
+
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
23 |
+
from api.db.services.user_service import TenantService, UserTenantService
|
24 |
+
from api.settings import CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, LLM_FACTORY, API_KEY
|
25 |
|
26 |
|
27 |
def init_superuser():
|
|
|
34 |
"creator": "system",
|
35 |
"status": "1",
|
36 |
}
|
37 |
+
tenant = {
|
38 |
+
"id": user_info["id"],
|
39 |
+
"name": user_info["nickname"] + "‘s Kingdom",
|
40 |
+
"llm_id": CHAT_MDL,
|
41 |
+
"embd_id": EMBEDDING_MDL,
|
42 |
+
"asr_id": ASR_MDL,
|
43 |
+
"parser_ids": PARSERS,
|
44 |
+
"img2txt_id": IMAGE2TEXT_MDL
|
45 |
+
}
|
46 |
+
usr_tenant = {
|
47 |
+
"tenant_id": user_info["id"],
|
48 |
+
"user_id": user_info["id"],
|
49 |
+
"invited_by": user_info["id"],
|
50 |
+
"role": UserTenantRole.OWNER
|
51 |
+
}
|
52 |
+
tenant_llm = []
|
53 |
+
for llm in LLMService.query(fid=LLM_FACTORY):
|
54 |
+
tenant_llm.append(
|
55 |
+
{"tenant_id": user_info["id"], "llm_factory": LLM_FACTORY, "llm_name": llm.llm_name, "model_type": llm.model_type,
|
56 |
+
"api_key": API_KEY})
|
57 |
+
|
58 |
+
if not UserService.save(**user_info):
|
59 |
+
print("【ERROR】can't init admin.")
|
60 |
+
return
|
61 |
+
TenantService.save(**tenant)
|
62 |
+
UserTenantService.save(**usr_tenant)
|
63 |
+
TenantLLMService.insert_many(tenant_llm)
|
64 |
UserService.save(**user_info)
|
65 |
|
66 |
+
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
67 |
+
msg = chat_mdl.chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})
|
68 |
+
if msg.find("ERROR: ") == 0:
|
69 |
+
print("【ERROR】: '{}' dosen't work. {}".format(tenant["llm_id"]), msg)
|
70 |
+
embd_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["embd_id"])
|
71 |
+
v,c = embd_mdl.encode(["Hello!"])
|
72 |
+
if c == 0:
|
73 |
+
print("【ERROR】: '{}' dosen't work...".format(tenant["embd_id"]))
|
74 |
+
|
75 |
|
76 |
def init_llm_factory():
|
77 |
factory_infos = [{
|
|
|
209 |
|
210 |
def init_web_data():
|
211 |
start_time = time.time()
|
|
|
|
|
212 |
|
213 |
if not LLMService.get_all().count():init_llm_factory()
|
214 |
+
if not UserService.get_all().count():
|
215 |
+
init_superuser()
|
216 |
|
217 |
print("init web data success:{}".format(time.time() - start_time))
|
218 |
|
api/settings.py
CHANGED
@@ -21,8 +21,10 @@ from api.utils import get_base_config,decrypt_database_config
|
|
21 |
from api.utils.file_utils import get_project_base_directory
|
22 |
from api.utils.log_utils import LoggerFactory, getLogger
|
23 |
|
|
|
|
|
|
|
24 |
|
25 |
-
# Server
|
26 |
API_VERSION = "v1"
|
27 |
RAG_FLOW_SERVICE_NAME = "ragflow"
|
28 |
SERVER_MODULE = "rag_flow_server.py"
|
@@ -116,6 +118,8 @@ AUTHENTICATION_DEFAULT_TIMEOUT = 30 * 24 * 60 * 60 # s
|
|
116 |
PRIVILEGE_COMMAND_WHITELIST = []
|
117 |
CHECK_NODES_IDENTITY = False
|
118 |
|
|
|
|
|
119 |
class CustomEnum(Enum):
|
120 |
@classmethod
|
121 |
def valid(cls, value):
|
|
|
21 |
from api.utils.file_utils import get_project_base_directory
|
22 |
from api.utils.log_utils import LoggerFactory, getLogger
|
23 |
|
24 |
+
from rag.nlp import search
|
25 |
+
from rag.utils import ELASTICSEARCH
|
26 |
+
|
27 |
|
|
|
28 |
API_VERSION = "v1"
|
29 |
RAG_FLOW_SERVICE_NAME = "ragflow"
|
30 |
SERVER_MODULE = "rag_flow_server.py"
|
|
|
118 |
PRIVILEGE_COMMAND_WHITELIST = []
|
119 |
CHECK_NODES_IDENTITY = False
|
120 |
|
121 |
+
retrievaler = search.Dealer(ELASTICSEARCH)
|
122 |
+
|
123 |
class CustomEnum(Enum):
|
124 |
@classmethod
|
125 |
def valid(cls, value):
|
deepdoc/parser/pdf_parser.py
CHANGED
@@ -230,7 +230,7 @@ class HuParser:
|
|
230 |
b["H_right"] = headers[ii]["x1"]
|
231 |
b["H"] = ii
|
232 |
|
233 |
-
ii = Recognizer.
|
234 |
if ii is not None:
|
235 |
b["C"] = ii
|
236 |
b["C_left"] = clmns[ii]["x0"]
|
|
|
230 |
b["H_right"] = headers[ii]["x1"]
|
231 |
b["H"] = ii
|
232 |
|
233 |
+
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
234 |
if ii is not None:
|
235 |
b["C"] = ii
|
236 |
b["C_left"] = clmns[ii]["x0"]
|
deepdoc/vision/layout_recognizer.py
CHANGED
@@ -37,7 +37,7 @@ class LayoutRecognizer(Recognizer):
|
|
37 |
super().__init__(self.labels, domain,
|
38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
39 |
|
40 |
-
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.
|
41 |
def __is_garbage(b):
|
42 |
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
43 |
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
|
|
37 |
super().__init__(self.labels, domain,
|
38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
39 |
|
40 |
+
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.2, batch_size=16):
|
41 |
def __is_garbage(b):
|
42 |
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
43 |
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
deepdoc/vision/postprocess.py
CHANGED
@@ -2,7 +2,6 @@ import copy
|
|
2 |
|
3 |
import numpy as np
|
4 |
import cv2
|
5 |
-
import paddle
|
6 |
from shapely.geometry import Polygon
|
7 |
import pyclipper
|
8 |
|
@@ -215,7 +214,7 @@ class DBPostProcess(object):
|
|
215 |
|
216 |
def __call__(self, outs_dict, shape_list):
|
217 |
pred = outs_dict['maps']
|
218 |
-
if isinstance(pred,
|
219 |
pred = pred.numpy()
|
220 |
pred = pred[:, 0, :, :]
|
221 |
segmentation = pred > self.thresh
|
@@ -339,7 +338,7 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|
339 |
def __call__(self, preds, label=None, *args, **kwargs):
|
340 |
if isinstance(preds, tuple) or isinstance(preds, list):
|
341 |
preds = preds[-1]
|
342 |
-
if isinstance(preds,
|
343 |
preds = preds.numpy()
|
344 |
preds_idx = preds.argmax(axis=2)
|
345 |
preds_prob = preds.max(axis=2)
|
|
|
2 |
|
3 |
import numpy as np
|
4 |
import cv2
|
|
|
5 |
from shapely.geometry import Polygon
|
6 |
import pyclipper
|
7 |
|
|
|
214 |
|
215 |
def __call__(self, outs_dict, shape_list):
|
216 |
pred = outs_dict['maps']
|
217 |
+
if not isinstance(pred, np.ndarray):
|
218 |
pred = pred.numpy()
|
219 |
pred = pred[:, 0, :, :]
|
220 |
segmentation = pred > self.thresh
|
|
|
338 |
def __call__(self, preds, label=None, *args, **kwargs):
|
339 |
if isinstance(preds, tuple) or isinstance(preds, list):
|
340 |
preds = preds[-1]
|
341 |
+
if not isinstance(preds, np.ndarray):
|
342 |
preds = preds.numpy()
|
343 |
preds_idx = preds.argmax(axis=2)
|
344 |
preds_prob = preds.max(axis=2)
|
deepdoc/vision/recognizer.py
CHANGED
@@ -259,6 +259,18 @@ class Recognizer(object):
|
|
259 |
|
260 |
return max_overlaped_i
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
@staticmethod
|
263 |
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
264 |
if not boxes:
|
|
|
259 |
|
260 |
return max_overlaped_i
|
261 |
|
262 |
+
@staticmethod
|
263 |
+
def find_horizontally_tightest_fit(box, boxes):
|
264 |
+
if not boxes:
|
265 |
+
return
|
266 |
+
min_dis, min_i = 1000000, None
|
267 |
+
for i,b in enumerate(boxes):
|
268 |
+
dis = min(abs(box["x0"] - b["x0"]), abs(box["x1"] - b["x1"]), abs(box["x0"]+box["x1"] - b["x1"] - b["x0"])/2)
|
269 |
+
if dis < min_dis:
|
270 |
+
min_i = i
|
271 |
+
min_dis = dis
|
272 |
+
return min_i
|
273 |
+
|
274 |
@staticmethod
|
275 |
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
276 |
if not boxes:
|
deepdoc/vision/t_recognizer.py
CHANGED
@@ -74,6 +74,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
74 |
clmns = sorted([r for r in tb_cpns if re.match(
|
75 |
r"table column$", r["label"])], key=lambda x: x["x0"])
|
76 |
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
|
|
77 |
for b in boxes:
|
78 |
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
|
79 |
if ii is not None:
|
@@ -89,7 +90,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
89 |
b["H_right"] = headers[ii]["x1"]
|
90 |
b["H"] = ii
|
91 |
|
92 |
-
ii = Recognizer.
|
93 |
if ii is not None:
|
94 |
b["C"] = ii
|
95 |
b["C_left"] = clmns[ii]["x0"]
|
@@ -102,6 +103,7 @@ def get_table_html(img, tb_cpns, ocr):
|
|
102 |
b["H_left"] = spans[ii]["x0"]
|
103 |
b["H_right"] = spans[ii]["x1"]
|
104 |
b["SP"] = ii
|
|
|
105 |
html = """
|
106 |
<html>
|
107 |
<head>
|
|
|
74 |
clmns = sorted([r for r in tb_cpns if re.match(
|
75 |
r"table column$", r["label"])], key=lambda x: x["x0"])
|
76 |
clmns = Recognizer.layouts_cleanup(boxes, clmns, 5, 0.5)
|
77 |
+
|
78 |
for b in boxes:
|
79 |
ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
|
80 |
if ii is not None:
|
|
|
90 |
b["H_right"] = headers[ii]["x1"]
|
91 |
b["H"] = ii
|
92 |
|
93 |
+
ii = Recognizer.find_horizontally_tightest_fit(b, clmns)
|
94 |
if ii is not None:
|
95 |
b["C"] = ii
|
96 |
b["C_left"] = clmns[ii]["x0"]
|
|
|
103 |
b["H_left"] = spans[ii]["x0"]
|
104 |
b["H_right"] = spans[ii]["x1"]
|
105 |
b["SP"] = ii
|
106 |
+
|
107 |
html = """
|
108 |
<html>
|
109 |
<head>
|
deepdoc/vision/table_structure_recognizer.py
CHANGED
@@ -14,7 +14,6 @@ import logging
|
|
14 |
import os
|
15 |
import re
|
16 |
from collections import Counter
|
17 |
-
from copy import deepcopy
|
18 |
|
19 |
import numpy as np
|
20 |
|
@@ -37,7 +36,7 @@ class TableStructureRecognizer(Recognizer):
|
|
37 |
super().__init__(self.labels, "tsr",
|
38 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
39 |
|
40 |
-
def __call__(self, images, thr=0.
|
41 |
tbls = super().__call__(images, thr)
|
42 |
res = []
|
43 |
# align left&right for rows, align top&bottom for columns
|
@@ -56,8 +55,8 @@ class TableStructureRecognizer(Recognizer):
|
|
56 |
"row") > 0 or b["label"].find("header") > 0]
|
57 |
if not left:
|
58 |
continue
|
59 |
-
left = np.
|
60 |
-
right = np.
|
61 |
for b in lts:
|
62 |
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
63 |
if b["x0"] > left:
|
@@ -129,6 +128,7 @@ class TableStructureRecognizer(Recognizer):
|
|
129 |
i = 0
|
130 |
while i < len(boxes):
|
131 |
if TableStructureRecognizer.is_caption(boxes[i]):
|
|
|
132 |
cap += boxes[i]["text"]
|
133 |
boxes.pop(i)
|
134 |
i -= 1
|
@@ -398,7 +398,7 @@ class TableStructureRecognizer(Recognizer):
|
|
398 |
for i in range(clmno):
|
399 |
if not tbl[r][i]:
|
400 |
continue
|
401 |
-
txt = "".join([a["text"].strip() for a in tbl[r][i]])
|
402 |
headers[r][i] = txt
|
403 |
hdrset.add(txt)
|
404 |
if all([not t for t in headers[r]]):
|
|
|
14 |
import os
|
15 |
import re
|
16 |
from collections import Counter
|
|
|
17 |
|
18 |
import numpy as np
|
19 |
|
|
|
36 |
super().__init__(self.labels, "tsr",
|
37 |
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
38 |
|
39 |
+
def __call__(self, images, thr=0.2):
|
40 |
tbls = super().__call__(images, thr)
|
41 |
res = []
|
42 |
# align left&right for rows, align top&bottom for columns
|
|
|
55 |
"row") > 0 or b["label"].find("header") > 0]
|
56 |
if not left:
|
57 |
continue
|
58 |
+
left = np.mean(left) if len(left) > 4 else np.min(left)
|
59 |
+
right = np.mean(right) if len(right) > 4 else np.max(right)
|
60 |
for b in lts:
|
61 |
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
62 |
if b["x0"] > left:
|
|
|
128 |
i = 0
|
129 |
while i < len(boxes):
|
130 |
if TableStructureRecognizer.is_caption(boxes[i]):
|
131 |
+
if is_english: cap + " "
|
132 |
cap += boxes[i]["text"]
|
133 |
boxes.pop(i)
|
134 |
i -= 1
|
|
|
398 |
for i in range(clmno):
|
399 |
if not tbl[r][i]:
|
400 |
continue
|
401 |
+
txt = " ".join([a["text"].strip() for a in tbl[r][i]])
|
402 |
headers[r][i] = txt
|
403 |
hdrset.add(txt)
|
404 |
if all([not t for t in headers[r]]):
|
rag/llm/chat_model.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
#
|
16 |
from abc import ABC
|
17 |
from openai import OpenAI
|
18 |
-
import
|
19 |
|
20 |
|
21 |
class Base(ABC):
|
@@ -33,11 +33,14 @@ class GptTurbo(Base):
|
|
33 |
|
34 |
def chat(self, system, history, gen_conf):
|
35 |
if system: history.insert(0, {"role": "system", "content": system})
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
from dashscope import Generation
|
@@ -58,7 +61,7 @@ class QWenChat(Base):
|
|
58 |
)
|
59 |
if response.status_code == HTTPStatus.OK:
|
60 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
61 |
-
return response.message, 0
|
62 |
|
63 |
|
64 |
from zhipuai import ZhipuAI
|
@@ -77,4 +80,4 @@ class ZhipuChat(Base):
|
|
77 |
)
|
78 |
if response.status_code == HTTPStatus.OK:
|
79 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
80 |
-
return response.message, 0
|
|
|
15 |
#
|
16 |
from abc import ABC
|
17 |
from openai import OpenAI
|
18 |
+
import openai
|
19 |
|
20 |
|
21 |
class Base(ABC):
|
|
|
33 |
|
34 |
def chat(self, system, history, gen_conf):
|
35 |
if system: history.insert(0, {"role": "system", "content": system})
|
36 |
+
try:
|
37 |
+
res = self.client.chat.completions.create(
|
38 |
+
model=self.model_name,
|
39 |
+
messages=history,
|
40 |
+
**gen_conf)
|
41 |
+
return res.choices[0].message.content.strip(), res.usage.completion_tokens
|
42 |
+
except openai.APIError as e:
|
43 |
+
return "ERROR: "+str(e), 0
|
44 |
|
45 |
|
46 |
from dashscope import Generation
|
|
|
61 |
)
|
62 |
if response.status_code == HTTPStatus.OK:
|
63 |
return response.output.choices[0]['message']['content'], response.usage.output_tokens
|
64 |
+
return "ERROR: " + response.message, 0
|
65 |
|
66 |
|
67 |
from zhipuai import ZhipuAI
|
|
|
80 |
)
|
81 |
if response.status_code == HTTPStatus.OK:
|
82 |
return response.output.choices[0]['message']['content'], response.usage.completion_tokens
|
83 |
+
return "ERROR: " + response.message, 0
|
rag/nlp/__init__.py
CHANGED
@@ -1,7 +1,4 @@
|
|
1 |
-
from . import search
|
2 |
-
from rag.utils import ELASTICSEARCH
|
3 |
|
4 |
-
retrievaler = search.Dealer(ELASTICSEARCH)
|
5 |
|
6 |
from nltk.stem import PorterStemmer
|
7 |
stemmer = PorterStemmer()
|
@@ -39,10 +36,12 @@ BULLET_PATTERN = [[
|
|
39 |
]
|
40 |
]
|
41 |
|
|
|
42 |
def random_choices(arr, k):
|
43 |
k = min(len(arr), k)
|
44 |
return random.choices(arr, k=k)
|
45 |
|
|
|
46 |
def bullets_category(sections):
|
47 |
global BULLET_PATTERN
|
48 |
hits = [0] * len(BULLET_PATTERN)
|
|
|
|
|
|
|
1 |
|
|
|
2 |
|
3 |
from nltk.stem import PorterStemmer
|
4 |
stemmer = PorterStemmer()
|
|
|
36 |
]
|
37 |
]
|
38 |
|
39 |
+
|
40 |
def random_choices(arr, k):
|
41 |
k = min(len(arr), k)
|
42 |
return random.choices(arr, k=k)
|
43 |
|
44 |
+
|
45 |
def bullets_category(sections):
|
46 |
global BULLET_PATTERN
|
47 |
hits = [0] * len(BULLET_PATTERN)
|
rag/nlp/search.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
import json
|
3 |
import re
|
4 |
-
from elasticsearch_dsl import Q, Search
|
5 |
from typing import List, Optional, Dict, Union
|
6 |
from dataclasses import dataclass
|
7 |
|
@@ -183,6 +183,7 @@ class Dealer:
|
|
183 |
|
184 |
def insert_citations(self, answer, chunks, chunk_v,
|
185 |
embd_mdl, tkweight=0.3, vtweight=0.7):
|
|
|
186 |
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
|
187 |
for i in range(1, len(pieces)):
|
188 |
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
|
@@ -216,7 +217,7 @@ class Dealer:
|
|
216 |
if mx < 0.55:
|
217 |
continue
|
218 |
cites[idx[i]] = list(
|
219 |
-
set([str(
|
220 |
|
221 |
res = ""
|
222 |
for i, p in enumerate(pieces):
|
@@ -225,6 +226,7 @@ class Dealer:
|
|
225 |
continue
|
226 |
if i not in cites:
|
227 |
continue
|
|
|
228 |
res += "##%s$$" % "$".join(cites[i])
|
229 |
|
230 |
return res
|
|
|
1 |
# -*- coding: utf-8 -*-
|
2 |
import json
|
3 |
import re
|
4 |
+
from elasticsearch_dsl import Q, Search
|
5 |
from typing import List, Optional, Dict, Union
|
6 |
from dataclasses import dataclass
|
7 |
|
|
|
183 |
|
184 |
def insert_citations(self, answer, chunks, chunk_v,
|
185 |
embd_mdl, tkweight=0.3, vtweight=0.7):
|
186 |
+
assert len(chunks) == len(chunk_v)
|
187 |
pieces = re.split(r"([;。?!!\n]|[a-z][.?;!][ \n])", answer)
|
188 |
for i in range(1, len(pieces)):
|
189 |
if re.match(r"[a-z][.?;!][ \n]", pieces[i]):
|
|
|
217 |
if mx < 0.55:
|
218 |
continue
|
219 |
cites[idx[i]] = list(
|
220 |
+
set([str(ii) for ii in range(len(chunk_v)) if sim[ii] > mx]))[:4]
|
221 |
|
222 |
res = ""
|
223 |
for i, p in enumerate(pieces):
|
|
|
226 |
continue
|
227 |
if i not in cites:
|
228 |
continue
|
229 |
+
assert int(cites[i]) < len(chunk_v)
|
230 |
res += "##%s$$" % "$".join(cites[i])
|
231 |
|
232 |
return res
|