Kevin Hu committed on
Commit · 6c8312a
Parent(s): e547053
Tagging (#4426)
### What problem does this PR solve?
#4367
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- agent/component/retrieval.py +3 -1
- api/apps/api_app.py +3 -2
- api/apps/chunk_app.py +17 -8
- api/apps/conversation_app.py +6 -3
- api/apps/kb_app.py +76 -4
- api/apps/sdk/dataset.py +6 -2
- api/apps/sdk/dify_retrieval.py +3 -1
- api/apps/sdk/doc.py +3 -1
- api/db/__init__.py +1 -0
- api/db/init_data.py +2 -9
- api/db/services/dialog_service.py +86 -2
- api/db/services/file2document_service.py +3 -7
- api/db/services/file_service.py +1 -4
- api/db/services/knowledgebase_service.py +6 -1
- api/db/services/task_service.py +1 -0
- api/settings.py +1 -1
- api/utils/api_utils.py +41 -30
- conf/infinity_mapping.json +3 -1
- graphrag/utils.py +20 -1
- rag/app/qa.py +19 -14
- rag/app/tag.py +125 -0
- rag/nlp/query.py +29 -6
- rag/nlp/search.py +127 -41
- rag/settings.py +3 -0
- rag/svr/task_executor.py +56 -15
- rag/utils/__init__.py +2 -0
- rag/utils/doc_store_conn.py +12 -2
- rag/utils/es_conn.py +71 -38
- rag/utils/infinity_conn.py +9 -7
- sdk/python/test/test_sdk_api/t_dataset.py +1 -1
agent/component/retrieval.py
CHANGED
@@ -19,6 +19,7 @@ from abc import ABC
|
|
19 |
import pandas as pd
|
20 |
|
21 |
from api.db import LLMType
|
|
|
22 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
23 |
from api.db.services.llm_service import LLMBundle
|
24 |
from api import settings
|
@@ -70,7 +71,8 @@ class Retrieval(ComponentBase, ABC):
|
|
70 |
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
71 |
1, self._param.top_n,
|
72 |
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
73 |
-
aggs=False, rerank_mdl=rerank_mdl)
|
|
|
74 |
|
75 |
if not kbinfos["chunks"]:
|
76 |
df = Retrieval.be_output("")
|
|
|
19 |
import pandas as pd
|
20 |
|
21 |
from api.db import LLMType
|
22 |
+
from api.db.services.dialog_service import label_question
|
23 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
24 |
from api.db.services.llm_service import LLMBundle
|
25 |
from api import settings
|
|
|
71 |
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
72 |
1, self._param.top_n,
|
73 |
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
74 |
+
aggs=False, rerank_mdl=rerank_mdl,
|
75 |
+
rank_feature=label_question(query, kbs))
|
76 |
|
77 |
if not kbinfos["chunks"]:
|
78 |
df = Retrieval.be_output("")
|
api/apps/api_app.py
CHANGED
@@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
|
|
25 |
from api.db.db_models import APIToken, Task, File
|
26 |
from api.db.services import duplicate_name
|
27 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
28 |
-
from api.db.services.dialog_service import DialogService, chat, keyword_extraction
|
29 |
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
30 |
from api.db.services.file2document_service import File2DocumentService
|
31 |
from api.db.services.file_service import FileService
|
@@ -840,7 +840,8 @@ def retrieval():
|
|
840 |
question += keyword_extraction(chat_mdl, question)
|
841 |
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
842 |
similarity_threshold, vector_similarity_weight, top,
|
843 |
-
doc_ids, rerank_mdl=rerank_mdl)
|
|
|
844 |
for c in ranks["chunks"]:
|
845 |
c.pop("vector", None)
|
846 |
return get_json_result(data=ranks)
|
|
|
25 |
from api.db.db_models import APIToken, Task, File
|
26 |
from api.db.services import duplicate_name
|
27 |
from api.db.services.api_service import APITokenService, API4ConversationService
|
28 |
+
from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
|
29 |
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
30 |
from api.db.services.file2document_service import File2DocumentService
|
31 |
from api.db.services.file_service import FileService
|
|
|
840 |
question += keyword_extraction(chat_mdl, question)
|
841 |
ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
|
842 |
similarity_threshold, vector_similarity_weight, top,
|
843 |
+
doc_ids, rerank_mdl=rerank_mdl,
|
844 |
+
rank_feature=label_question(question, kbs))
|
845 |
for c in ranks["chunks"]:
|
846 |
c.pop("vector", None)
|
847 |
return get_json_result(data=ranks)
|
api/apps/chunk_app.py
CHANGED
@@ -19,9 +19,10 @@ import json
|
|
19 |
from flask import request
|
20 |
from flask_login import login_required, current_user
|
21 |
|
22 |
-
from api.db.services.dialog_service import keyword_extraction
|
23 |
from rag.app.qa import rmPrefix, beAdoc
|
24 |
from rag.nlp import search, rag_tokenizer
|
|
|
25 |
from rag.utils import rmSpace
|
26 |
from api.db import LLMType, ParserType
|
27 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
@@ -124,10 +125,14 @@ def set():
|
|
124 |
"content_with_weight": req["content_with_weight"]}
|
125 |
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
126 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
131 |
if "available_int" in req:
|
132 |
d["available_int"] = req["available_int"]
|
133 |
|
@@ -220,7 +225,7 @@ def create():
|
|
220 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
221 |
if not e:
|
222 |
return get_data_error_result(message="Document not found!")
|
223 |
-
d["kb_id"] = doc.kb_id
|
224 |
d["docnm_kwd"] = doc.name
|
225 |
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
226 |
d["doc_id"] = doc.id
|
@@ -233,7 +238,7 @@ def create():
|
|
233 |
if not e:
|
234 |
return get_data_error_result(message="Knowledgebase not found!")
|
235 |
if kb.pagerank:
|
236 |
-
d["pagerank_fea"] = kb.pagerank
|
237 |
|
238 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
239 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
@@ -294,12 +299,16 @@ def retrieval_test():
|
|
294 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
295 |
question += keyword_extraction(chat_mdl, question)
|
296 |
|
|
|
297 |
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
298 |
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
299 |
similarity_threshold, vector_similarity_weight, top,
|
300 |
-
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
|
|
|
|
|
301 |
for c in ranks["chunks"]:
|
302 |
c.pop("vector", None)
|
|
|
303 |
|
304 |
return get_json_result(data=ranks)
|
305 |
except Exception as e:
|
|
|
19 |
from flask import request
|
20 |
from flask_login import login_required, current_user
|
21 |
|
22 |
+
from api.db.services.dialog_service import keyword_extraction, label_question
|
23 |
from rag.app.qa import rmPrefix, beAdoc
|
24 |
from rag.nlp import search, rag_tokenizer
|
25 |
+
from rag.settings import PAGERANK_FLD
|
26 |
from rag.utils import rmSpace
|
27 |
from api.db import LLMType, ParserType
|
28 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
|
|
125 |
"content_with_weight": req["content_with_weight"]}
|
126 |
d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
|
127 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
128 |
+
if req.get("important_kwd"):
|
129 |
+
d["important_kwd"] = req["important_kwd"]
|
130 |
+
d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
|
131 |
+
if req.get("question_kwd"):
|
132 |
+
d["question_kwd"] = req["question_kwd"]
|
133 |
+
d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
|
134 |
+
if req.get("tag_kwd"):
|
135 |
+
d["tag_kwd"] = req["tag_kwd"]
|
136 |
if "available_int" in req:
|
137 |
d["available_int"] = req["available_int"]
|
138 |
|
|
|
225 |
e, doc = DocumentService.get_by_id(req["doc_id"])
|
226 |
if not e:
|
227 |
return get_data_error_result(message="Document not found!")
|
228 |
+
d["kb_id"] = [doc.kb_id]
|
229 |
d["docnm_kwd"] = doc.name
|
230 |
d["title_tks"] = rag_tokenizer.tokenize(doc.name)
|
231 |
d["doc_id"] = doc.id
|
|
|
238 |
if not e:
|
239 |
return get_data_error_result(message="Knowledgebase not found!")
|
240 |
if kb.pagerank:
|
241 |
+
d[PAGERANK_FLD] = kb.pagerank
|
242 |
|
243 |
embd_id = DocumentService.get_embd_id(req["doc_id"])
|
244 |
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
|
|
|
299 |
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
300 |
question += keyword_extraction(chat_mdl, question)
|
301 |
|
302 |
+
labels = label_question(question, [kb])
|
303 |
retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
|
304 |
ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
305 |
similarity_threshold, vector_similarity_weight, top,
|
306 |
+
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
307 |
+
rank_feature=labels
|
308 |
+
)
|
309 |
for c in ranks["chunks"]:
|
310 |
c.pop("vector", None)
|
311 |
+
ranks["labels"] = labels
|
312 |
|
313 |
return get_json_result(data=ranks)
|
314 |
except Exception as e:
|
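The updated `set()` handler above now carries `important_kwd`, `question_kwd`, and the new `tag_kwd` from the request body into the stored chunk (with their tokenized variants), while `create()` stores `kb_id` as a list and the knowledge base's pagerank under `PAGERANK_FLD`. A minimal client-side sketch of the new fields; the base URL, route path, id values, and auth header are placeholders rather than part of this diff:

```python
import requests

BASE_URL = "http://localhost:9380"                   # assumed deployment address
HEADERS = {"Authorization": "Bearer <api-token>"}    # assumed auth scheme

payload = {
    "doc_id": "<doc-id>",
    "chunk_id": "<chunk-id>",
    "content_with_weight": "How do I reset my password?",
    "important_kwd": ["password", "reset"],        # stored as important_kwd + important_tks
    "question_kwd": ["How to reset a password?"],  # stored as question_kwd + question_tks
    "tag_kwd": ["account", "security"],            # new: tag labels kept on the chunk
    "available_int": 1,
}
# The handler shown above is chunk_app.set(); the exact URL path is an assumption.
resp = requests.post(f"{BASE_URL}/v1/chunk/set", json=payload, headers=HEADERS)
print(resp.json())
```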
api/apps/conversation_app.py
CHANGED
@@ -25,7 +25,7 @@ from flask import request, Response
|
|
25 |
from flask_login import login_required, current_user
|
26 |
|
27 |
from api.db import LLMType
|
28 |
-
from api.db.services.dialog_service import DialogService, chat, ask
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
31 |
from api import settings
|
@@ -379,8 +379,11 @@ def mindmap():
|
|
379 |
embd_mdl = TenantLLMService.model_instance(
|
380 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
381 |
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
382 |
-
|
383 |
-
|
|
|
|
|
|
|
384 |
mindmap = MindMapExtractor(chat_mdl)
|
385 |
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
386 |
if "error" in mind_map:
|
|
|
25 |
from flask_login import login_required, current_user
|
26 |
|
27 |
from api.db import LLMType
|
28 |
+
from api.db.services.dialog_service import DialogService, chat, ask, label_question
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
|
31 |
from api import settings
|
|
|
379 |
embd_mdl = TenantLLMService.model_instance(
|
380 |
kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
381 |
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
|
382 |
+
question = req["question"]
|
383 |
+
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
|
384 |
+
0.3, 0.3, aggs=False,
|
385 |
+
rank_feature=label_question(question, [kb])
|
386 |
+
)
|
387 |
mindmap = MindMapExtractor(chat_mdl)
|
388 |
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
389 |
if "error" in mind_map:
|
api/apps/kb_app.py
CHANGED
@@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
|
|
30 |
from api import settings
|
31 |
from rag.nlp import search
|
32 |
from api.constants import DATASET_NAME_LIMIT
|
|
|
33 |
|
34 |
|
35 |
@manager.route('/create', methods=['post']) # noqa: F821
|
@@ -104,11 +105,11 @@ def update():
|
|
104 |
|
105 |
if kb.pagerank != req.get("pagerank", 0):
|
106 |
if req.get("pagerank", 0) > 0:
|
107 |
-
settings.docStoreConn.update({"kb_id": kb.id}, {
|
108 |
search.index_name(kb.tenant_id), kb.id)
|
109 |
else:
|
110 |
-
# Elasticsearch requires
|
111 |
-
settings.docStoreConn.update({"exist":
|
112 |
search.index_name(kb.tenant_id), kb.id)
|
113 |
|
114 |
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
@@ -150,12 +151,14 @@ def list_kbs():
|
|
150 |
keywords = request.args.get("keywords", "")
|
151 |
page_number = int(request.args.get("page", 1))
|
152 |
items_per_page = int(request.args.get("page_size", 150))
|
|
|
153 |
orderby = request.args.get("orderby", "create_time")
|
154 |
desc = request.args.get("desc", True)
|
155 |
try:
|
156 |
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
157 |
kbs, total = KnowledgebaseService.get_by_tenant_ids(
|
158 |
-
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
|
|
159 |
return get_json_result(data={"kbs": kbs, "total": total})
|
160 |
except Exception as e:
|
161 |
return server_error_response(e)
|
@@ -199,3 +202,72 @@ def rm():
|
|
199 |
return get_json_result(data=True)
|
200 |
except Exception as e:
|
201 |
return server_error_response(e)
|
30 |
from api import settings
|
31 |
from rag.nlp import search
|
32 |
from api.constants import DATASET_NAME_LIMIT
|
33 |
+
from rag.settings import PAGERANK_FLD
|
34 |
|
35 |
|
36 |
@manager.route('/create', methods=['post']) # noqa: F821
|
|
|
105 |
|
106 |
if kb.pagerank != req.get("pagerank", 0):
|
107 |
if req.get("pagerank", 0) > 0:
|
108 |
+
settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
|
109 |
search.index_name(kb.tenant_id), kb.id)
|
110 |
else:
|
111 |
+
# Elasticsearch requires PAGERANK_FLD be non-zero!
|
112 |
+
settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
|
113 |
search.index_name(kb.tenant_id), kb.id)
|
114 |
|
115 |
e, kb = KnowledgebaseService.get_by_id(kb.id)
|
|
|
151 |
keywords = request.args.get("keywords", "")
|
152 |
page_number = int(request.args.get("page", 1))
|
153 |
items_per_page = int(request.args.get("page_size", 150))
|
154 |
+
parser_id = request.args.get("parser_id")
|
155 |
orderby = request.args.get("orderby", "create_time")
|
156 |
desc = request.args.get("desc", True)
|
157 |
try:
|
158 |
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
159 |
kbs, total = KnowledgebaseService.get_by_tenant_ids(
|
160 |
+
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
161 |
+
items_per_page, orderby, desc, keywords, parser_id)
|
162 |
return get_json_result(data={"kbs": kbs, "total": total})
|
163 |
except Exception as e:
|
164 |
return server_error_response(e)
|
|
|
202 |
return get_json_result(data=True)
|
203 |
except Exception as e:
|
204 |
return server_error_response(e)
|
205 |
+
|
206 |
+
|
207 |
+
@manager.route('/<kb_id>/tags', methods=['GET']) # noqa: F821
|
208 |
+
@login_required
|
209 |
+
def list_tags(kb_id):
|
210 |
+
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
211 |
+
return get_json_result(
|
212 |
+
data=False,
|
213 |
+
message='No authorization.',
|
214 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
215 |
+
)
|
216 |
+
|
217 |
+
tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
|
218 |
+
return get_json_result(data=tags)
|
219 |
+
|
220 |
+
|
221 |
+
@manager.route('/tags', methods=['GET']) # noqa: F821
|
222 |
+
@login_required
|
223 |
+
def list_tags_from_kbs():
|
224 |
+
kb_ids = request.args.get("kb_ids", "").split(",")
|
225 |
+
for kb_id in kb_ids:
|
226 |
+
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
227 |
+
return get_json_result(
|
228 |
+
data=False,
|
229 |
+
message='No authorization.',
|
230 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
231 |
+
)
|
232 |
+
|
233 |
+
tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
|
234 |
+
return get_json_result(data=tags)
|
235 |
+
|
236 |
+
|
237 |
+
@manager.route('/<kb_id>/rm_tags', methods=['POST']) # noqa: F821
|
238 |
+
@login_required
|
239 |
+
def rm_tags(kb_id):
|
240 |
+
req = request.json
|
241 |
+
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
242 |
+
return get_json_result(
|
243 |
+
data=False,
|
244 |
+
message='No authorization.',
|
245 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
246 |
+
)
|
247 |
+
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
248 |
+
|
249 |
+
for t in req["tags"]:
|
250 |
+
settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
|
251 |
+
{"remove": {"tag_kwd": t}},
|
252 |
+
search.index_name(kb.tenant_id),
|
253 |
+
kb_id)
|
254 |
+
return get_json_result(data=True)
|
255 |
+
|
256 |
+
|
257 |
+
@manager.route('/<kb_id>/rename_tag', methods=['POST']) # noqa: F821
|
258 |
+
@login_required
|
259 |
+
def rename_tags(kb_id):
|
260 |
+
req = request.json
|
261 |
+
if not KnowledgebaseService.accessible(kb_id, current_user.id):
|
262 |
+
return get_json_result(
|
263 |
+
data=False,
|
264 |
+
message='No authorization.',
|
265 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
266 |
+
)
|
267 |
+
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
268 |
+
|
269 |
+
settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
|
270 |
+
{"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
|
271 |
+
search.index_name(kb.tenant_id),
|
272 |
+
kb_id)
|
273 |
+
return get_json_result(data=True)
|
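The four new handlers above give each knowledge base a small tag-maintenance API: `GET /<kb_id>/tags` and `GET /tags?kb_ids=...` list tags, `POST /<kb_id>/rm_tags` strips tags from every chunk, and `POST /<kb_id>/rename_tag` rewrites one tag in place. A hedged sketch of driving them with `requests`; the host and the `/v1/kb` prefix are assumptions, while the paths and JSON bodies follow the handlers in this diff:

```python
import requests

BASE = "http://localhost:9380/v1/kb"  # assumed prefix; adjust to your deployment
session = requests.Session()          # assumed to carry a valid login session
kb_id = "<kb-id>"

# Tags used in one knowledge base, and aggregated over several.
print(session.get(f"{BASE}/{kb_id}/tags").json())
print(session.get(f"{BASE}/tags", params={"kb_ids": f"{kb_id},<other-kb-id>"}).json())

# Remove a batch of tags from every chunk of the knowledge base.
session.post(f"{BASE}/{kb_id}/rm_tags", json={"tags": ["obsolete", "draft"]})

# Rename a tag across the knowledge base.
session.post(f"{BASE}/{kb_id}/rename_tag",
             json={"from_tag": "securty", "to_tag": "security"})
```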
api/apps/sdk/dataset.py
CHANGED
@@ -73,7 +73,8 @@ def create(tenant_id):
|
|
73 |
chunk_method:
|
74 |
type: string
|
75 |
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
76 |
-
"presentation", "picture", "one", "knowledge_graph", "email"
|
|
|
77 |
description: Chunking method.
|
78 |
parser_config:
|
79 |
type: object
|
@@ -108,6 +109,7 @@ def create(tenant_id):
|
|
108 |
"one",
|
109 |
"knowledge_graph",
|
110 |
"email",
|
|
|
111 |
]
|
112 |
check_validation = valid(
|
113 |
permission,
|
@@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
|
|
302 |
chunk_method:
|
303 |
type: string
|
304 |
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
305 |
-
"presentation", "picture", "one", "knowledge_graph", "email"
|
|
|
306 |
description: Updated chunking method.
|
307 |
parser_config:
|
308 |
type: object
|
@@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
|
|
339 |
"one",
|
340 |
"knowledge_graph",
|
341 |
"email",
|
|
|
342 |
]
|
343 |
check_validation = valid(
|
344 |
permission,
|
|
|
73 |
chunk_method:
|
74 |
type: string
|
75 |
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
76 |
+
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
|
77 |
+
]
|
78 |
description: Chunking method.
|
79 |
parser_config:
|
80 |
type: object
|
|
|
109 |
"one",
|
110 |
"knowledge_graph",
|
111 |
"email",
|
112 |
+
"tag"
|
113 |
]
|
114 |
check_validation = valid(
|
115 |
permission,
|
|
|
304 |
chunk_method:
|
305 |
type: string
|
306 |
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
307 |
+
"presentation", "picture", "one", "knowledge_graph", "email", "tag"
|
308 |
+
]
|
309 |
description: Updated chunking method.
|
310 |
parser_config:
|
311 |
type: object
|
|
|
342 |
"one",
|
343 |
"knowledge_graph",
|
344 |
"email",
|
345 |
+
"tag"
|
346 |
]
|
347 |
check_validation = valid(
|
348 |
permission,
|
api/apps/sdk/dify_retrieval.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
from flask import request, jsonify
|
17 |
|
18 |
from api.db import LLMType, ParserType
|
|
|
19 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
20 |
from api.db.services.llm_service import LLMBundle
|
21 |
from api import settings
|
@@ -54,7 +55,8 @@ def retrieval(tenant_id):
|
|
54 |
page_size=top,
|
55 |
similarity_threshold=similarity_threshold,
|
56 |
vector_similarity_weight=0.3,
|
57 |
-
top=top
|
|
|
58 |
)
|
59 |
records = []
|
60 |
for c in ranks["chunks"]:
|
|
|
16 |
from flask import request, jsonify
|
17 |
|
18 |
from api.db import LLMType, ParserType
|
19 |
+
from api.db.services.dialog_service import label_question
|
20 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
21 |
from api.db.services.llm_service import LLMBundle
|
22 |
from api import settings
|
|
|
55 |
page_size=top,
|
56 |
similarity_threshold=similarity_threshold,
|
57 |
vector_similarity_weight=0.3,
|
58 |
+
top=top,
|
59 |
+
rank_feature=label_question(question, [kb])
|
60 |
)
|
61 |
records = []
|
62 |
for c in ranks["chunks"]:
|
api/apps/sdk/doc.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
import pathlib
|
17 |
import datetime
|
18 |
|
19 |
-
from api.db.services.dialog_service import keyword_extraction
|
20 |
from rag.app.qa import rmPrefix, beAdoc
|
21 |
from rag.nlp import rag_tokenizer
|
22 |
from api.db import LLMType, ParserType
|
@@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|
276 |
"one",
|
277 |
"knowledge_graph",
|
278 |
"email",
|
|
|
279 |
}
|
280 |
if req.get("chunk_method") not in valid_chunk_method:
|
281 |
return get_error_data_result(
|
@@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
|
|
1355 |
doc_ids,
|
1356 |
rerank_mdl=rerank_mdl,
|
1357 |
highlight=highlight,
|
|
|
1358 |
)
|
1359 |
for c in ranks["chunks"]:
|
1360 |
c.pop("vector", None)
|
|
|
16 |
import pathlib
|
17 |
import datetime
|
18 |
|
19 |
+
from api.db.services.dialog_service import keyword_extraction, label_question
|
20 |
from rag.app.qa import rmPrefix, beAdoc
|
21 |
from rag.nlp import rag_tokenizer
|
22 |
from api.db import LLMType, ParserType
|
|
|
276 |
"one",
|
277 |
"knowledge_graph",
|
278 |
"email",
|
279 |
+
"tag"
|
280 |
}
|
281 |
if req.get("chunk_method") not in valid_chunk_method:
|
282 |
return get_error_data_result(
|
|
|
1356 |
doc_ids,
|
1357 |
rerank_mdl=rerank_mdl,
|
1358 |
highlight=highlight,
|
1359 |
+
rank_feature=label_question(question, kbs)
|
1360 |
)
|
1361 |
for c in ranks["chunks"]:
|
1362 |
c.pop("vector", None)
|
api/db/__init__.py
CHANGED
@@ -89,6 +89,7 @@ class ParserType(StrEnum):
|
|
89 |
AUDIO = "audio"
|
90 |
EMAIL = "email"
|
91 |
KG = "knowledge_graph"
|
|
|
92 |
|
93 |
|
94 |
class FileSource(StrEnum):
|
|
|
89 |
AUDIO = "audio"
|
90 |
EMAIL = "email"
|
91 |
KG = "knowledge_graph"
|
92 |
+
TAG = "tag"
|
93 |
|
94 |
|
95 |
class FileSource(StrEnum):
|
api/db/init_data.py
CHANGED
@@ -133,7 +133,7 @@ def init_llm_factory():
|
|
133 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
134 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
135 |
TenantService.filter_update([1 == 1], {
|
136 |
-
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
|
137 |
## insert openai two embedding models to the current openai user.
|
138 |
# print("Start to insert 2 OpenAI embedding models...")
|
139 |
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
@@ -153,14 +153,7 @@ def init_llm_factory():
|
|
153 |
break
|
154 |
for kb_id in KnowledgebaseService.get_all_ids():
|
155 |
KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
|
156 |
-
|
157 |
-
drop table llm;
|
158 |
-
drop table llm_factories;
|
159 |
-
update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
|
160 |
-
alter table knowledgebase modify avatar longtext;
|
161 |
-
alter table user modify avatar longtext;
|
162 |
-
alter table dialog modify icon longtext;
|
163 |
-
"""
|
164 |
|
165 |
|
166 |
def add_graph_templates():
|
|
|
133 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
|
134 |
TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
|
135 |
TenantService.filter_update([1 == 1], {
|
136 |
+
"parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
|
137 |
## insert openai two embedding models to the current openai user.
|
138 |
# print("Start to insert 2 OpenAI embedding models...")
|
139 |
tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
|
|
|
153 |
break
|
154 |
for kb_id in KnowledgebaseService.get_all_ids():
|
155 |
KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
|
156 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
|
159 |
def add_graph_templates():
|
api/db/services/dialog_service.py
CHANGED
@@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
|
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
31 |
from api import settings
|
|
|
32 |
from rag.app.resume import forbidden_select_fields4resume
|
33 |
from rag.nlp.search import index_name
|
|
|
34 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
35 |
from api.utils.file_utils import get_project_base_directory
|
36 |
|
@@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
|
|
135 |
return knowledges
|
136 |
|
137 |
|
|
138 |
def chat(dialog, messages, stream=True, **kwargs):
|
139 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
140 |
|
@@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
236 |
generate_keyword_ts = timer()
|
237 |
|
238 |
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
|
|
239 |
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
240 |
dialog.similarity_threshold,
|
241 |
dialog.vector_similarity_weight,
|
242 |
doc_ids=attachments,
|
243 |
-
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
|
|
|
|
|
244 |
|
245 |
retrieval_ts = timer()
|
246 |
|
@@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
|
|
650 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
651 |
max_tokens = chat_mdl.max_length
|
652 |
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
653 |
-
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
|
|
|
|
|
|
|
654 |
knowledges = kb_prompt(kbinfos, max_tokens)
|
655 |
prompt = """
|
656 |
Role: You're a smart assistant. Your name is Miss R.
|
@@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
|
|
700 |
answer = ans
|
701 |
yield {"answer": answer, "reference": {}}
|
702 |
yield decorate_answer(answer)
|
29 |
from api.db.services.knowledgebase_service import KnowledgebaseService
|
30 |
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
31 |
from api import settings
|
32 |
+
from graphrag.utils import get_tags_from_cache, set_tags_to_cache
|
33 |
from rag.app.resume import forbidden_select_fields4resume
|
34 |
from rag.nlp.search import index_name
|
35 |
+
from rag.settings import TAG_FLD
|
36 |
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
37 |
from api.utils.file_utils import get_project_base_directory
|
38 |
|
|
|
137 |
return knowledges
|
138 |
|
139 |
|
140 |
+
def label_question(question, kbs):
|
141 |
+
tags = None
|
142 |
+
tag_kb_ids = []
|
143 |
+
for kb in kbs:
|
144 |
+
if kb.parser_config.get("tag_kb_ids"):
|
145 |
+
tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
|
146 |
+
if tag_kb_ids:
|
147 |
+
all_tags = get_tags_from_cache(tag_kb_ids)
|
148 |
+
if not all_tags:
|
149 |
+
all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
|
150 |
+
set_tags_to_cache(all_tags, tag_kb_ids)
|
151 |
+
else:
|
152 |
+
all_tags = json.loads(all_tags)
|
153 |
+
tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
|
154 |
+
tags = settings.retrievaler.tag_query(question,
|
155 |
+
list(set([kb.tenant_id for kb in tag_kbs])),
|
156 |
+
tag_kb_ids,
|
157 |
+
all_tags,
|
158 |
+
kb.parser_config.get("topn_tags", 3)
|
159 |
+
)
|
160 |
+
return tags
|
161 |
+
|
162 |
+
|
163 |
def chat(dialog, messages, stream=True, **kwargs):
|
164 |
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
165 |
|
|
|
261 |
generate_keyword_ts = timer()
|
262 |
|
263 |
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
264 |
+
|
265 |
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
266 |
dialog.similarity_threshold,
|
267 |
dialog.vector_similarity_weight,
|
268 |
doc_ids=attachments,
|
269 |
+
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
270 |
+
rank_feature=label_question(" ".join(questions), kbs)
|
271 |
+
)
|
272 |
|
273 |
retrieval_ts = timer()
|
274 |
|
|
|
678 |
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
679 |
max_tokens = chat_mdl.max_length
|
680 |
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
681 |
+
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
|
682 |
+
1, 12, 0.1, 0.3, aggs=False,
|
683 |
+
rank_feature=label_question(question, kbs)
|
684 |
+
)
|
685 |
knowledges = kb_prompt(kbinfos, max_tokens)
|
686 |
prompt = """
|
687 |
Role: You're a smart assistant. Your name is Miss R.
|
|
|
731 |
answer = ans
|
732 |
yield {"answer": answer, "reference": {}}
|
733 |
yield decorate_answer(answer)
|
734 |
+
|
735 |
+
|
736 |
+
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
737 |
+
prompt = f"""
|
738 |
+
Role: You're a text analyzer.
|
739 |
+
|
740 |
+
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
|
741 |
+
|
742 |
+
Steps::
|
743 |
+
- Comprehend the tag/label set.
|
744 |
+
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
|
745 |
+
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
|
746 |
+
|
747 |
+
Requirements
|
748 |
+
- The tags MUST be from the tag set.
|
749 |
+
- The output MUST be in JSON format only, the key is tag and the value is its relevance score.
|
750 |
+
- The relevance score must be range from 1 to 10.
|
751 |
+
- Keywords ONLY in output.
|
752 |
+
|
753 |
+
# TAG SET
|
754 |
+
{", ".join(all_tags)}
|
755 |
+
|
756 |
+
"""
|
757 |
+
for i, ex in enumerate(examples):
|
758 |
+
prompt += """
|
759 |
+
# Examples {}
|
760 |
+
### Text Content
|
761 |
+
{}
|
762 |
+
|
763 |
+
Output:
|
764 |
+
{}
|
765 |
+
|
766 |
+
""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
|
767 |
+
|
768 |
+
prompt += f"""
|
769 |
+
# Real Data
|
770 |
+
### Text Content
|
771 |
+
{content}
|
772 |
+
|
773 |
+
"""
|
774 |
+
msg = [
|
775 |
+
{"role": "system", "content": prompt},
|
776 |
+
{"role": "user", "content": "Output: "}
|
777 |
+
]
|
778 |
+
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
779 |
+
kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
|
780 |
+
if isinstance(kwd, tuple):
|
781 |
+
kwd = kwd[0]
|
782 |
+
if kwd.find("**ERROR**") >= 0:
|
783 |
+
raise Exception(kwd)
|
784 |
+
|
785 |
+
kwd = re.sub(r".*?\{", "{", kwd)
|
786 |
+
return json.loads(kwd)
|
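Taken together, the two new helpers split the tagging work between query time and indexing time: `label_question` collects the tag vocabularies of any `tag_kb_ids` configured in a knowledge base's `parser_config` (with a Redis-backed cache), scores the question against them via `tag_query`, and the result is handed to retrieval as `rank_feature`; `content_tagging` prompts the chat model to assign the top-N tags to a chunk. A minimal hedged sketch of the query-time path, mirroring `ask()` above; the knowledge base rows, embedding model, and ids are assumed to exist in the caller:

```python
# Sketch of the query-time flow; kbs, embd_mdl, tenant_ids and kb_ids are
# assumed to be prepared exactly as in ask()/chat() in this file.
from api.db.services.dialog_service import label_question
from api import settings

question = "How do I rotate my API keys?"
rank_feature = label_question(question, kbs)   # {tag: relevance} or None

kbinfos = settings.retrievaler.retrieval(
    question, embd_mdl, tenant_ids, kb_ids, 1, 12,
    0.1, 0.3, aggs=False,
    rank_feature=rank_feature,  # boosts chunks whose tag_kwd matches the question's tags
)
```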
api/db/services/file2document_service.py
CHANGED
@@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
|
|
43 |
def insert(cls, obj):
|
44 |
if not cls.save(**obj):
|
45 |
raise RuntimeError("Database error (File)!")
|
46 |
-
|
47 |
-
if not e:
|
48 |
-
raise RuntimeError("Database error (File retrieval)!")
|
49 |
-
return obj
|
50 |
|
51 |
@classmethod
|
52 |
@DB.connection_context()
|
@@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
|
|
63 |
def update_by_file_id(cls, file_id, obj):
|
64 |
obj["update_time"] = current_timestamp()
|
65 |
obj["update_date"] = datetime_format(datetime.now())
|
66 |
-
|
67 |
-
|
68 |
-
return obj
|
69 |
|
70 |
@classmethod
|
71 |
@DB.connection_context()
|
|
|
43 |
def insert(cls, obj):
|
44 |
if not cls.save(**obj):
|
45 |
raise RuntimeError("Database error (File)!")
|
46 |
+
return File2Document(**obj)
|
|
|
|
|
|
|
47 |
|
48 |
@classmethod
|
49 |
@DB.connection_context()
|
|
|
60 |
def update_by_file_id(cls, file_id, obj):
|
61 |
obj["update_time"] = current_timestamp()
|
62 |
obj["update_date"] = datetime_format(datetime.now())
|
63 |
+
cls.model.update(obj).where(cls.model.id == file_id).execute()
|
64 |
+
return File2Document(**obj)
|
|
|
65 |
|
66 |
@classmethod
|
67 |
@DB.connection_context()
|
api/db/services/file_service.py
CHANGED
@@ -251,10 +251,7 @@ class FileService(CommonService):
|
|
251 |
def insert(cls, file):
|
252 |
if not cls.save(**file):
|
253 |
raise RuntimeError("Database error (File)!")
|
254 |
-
|
255 |
-
if not e:
|
256 |
-
raise RuntimeError("Database error (File retrieval)!")
|
257 |
-
return file
|
258 |
|
259 |
@classmethod
|
260 |
@DB.connection_context()
|
|
|
251 |
def insert(cls, file):
|
252 |
if not cls.save(**file):
|
253 |
raise RuntimeError("Database error (File)!")
|
254 |
+
return File(**file)
|
|
|
|
|
|
|
255 |
|
256 |
@classmethod
|
257 |
@DB.connection_context()
|
api/db/services/knowledgebase_service.py
CHANGED
@@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
|
|
35 |
@classmethod
|
36 |
@DB.connection_context()
|
37 |
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
38 |
-
page_number, items_per_page,
|
|
|
|
|
|
|
39 |
fields = [
|
40 |
cls.model.id,
|
41 |
cls.model.avatar,
|
@@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
|
|
67 |
cls.model.tenant_id == user_id))
|
68 |
& (cls.model.status == StatusEnum.VALID.value)
|
69 |
)
|
|
|
|
|
70 |
if desc:
|
71 |
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
72 |
else:
|
|
|
35 |
@classmethod
|
36 |
@DB.connection_context()
|
37 |
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
38 |
+
page_number, items_per_page,
|
39 |
+
orderby, desc, keywords,
|
40 |
+
parser_id=None
|
41 |
+
):
|
42 |
fields = [
|
43 |
cls.model.id,
|
44 |
cls.model.avatar,
|
|
|
70 |
cls.model.tenant_id == user_id))
|
71 |
& (cls.model.status == StatusEnum.VALID.value)
|
72 |
)
|
73 |
+
if parser_id:
|
74 |
+
kbs = kbs.where(cls.model.parser_id == parser_id)
|
75 |
if desc:
|
76 |
kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
|
77 |
else:
|
api/db/services/task_service.py
CHANGED
@@ -69,6 +69,7 @@ class TaskService(CommonService):
|
|
69 |
Knowledgebase.language,
|
70 |
Knowledgebase.embd_id,
|
71 |
Knowledgebase.pagerank,
|
|
|
72 |
Tenant.img2txt_id,
|
73 |
Tenant.asr_id,
|
74 |
Tenant.llm_id,
|
|
|
69 |
Knowledgebase.language,
|
70 |
Knowledgebase.embd_id,
|
71 |
Knowledgebase.pagerank,
|
72 |
+
Knowledgebase.parser_config.alias("kb_parser_config"),
|
73 |
Tenant.img2txt_id,
|
74 |
Tenant.asr_id,
|
75 |
Tenant.llm_id,
|
api/settings.py
CHANGED
@@ -140,7 +140,7 @@ def init_settings():
|
|
140 |
API_KEY = LLM.get("api_key", "")
|
141 |
PARSERS = LLM.get(
|
142 |
"parsers",
|
143 |
-
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
|
144 |
|
145 |
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
146 |
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
|
|
140 |
API_KEY = LLM.get("api_key", "")
|
141 |
PARSERS = LLM.get(
|
142 |
"parsers",
|
143 |
+
"naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")
|
144 |
|
145 |
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
|
146 |
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
|
api/utils/api_utils.py
CHANGED
@@ -173,6 +173,7 @@ def validate_request(*args, **kwargs):
|
|
173 |
|
174 |
return wrapper
|
175 |
|
|
|
176 |
def not_allowed_parameters(*params):
|
177 |
def decorator(f):
|
178 |
def wrapper(*args, **kwargs):
|
@@ -182,7 +183,9 @@ def not_allowed_parameters(*params):
|
|
182 |
return get_json_result(
|
183 |
code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
184 |
return f(*args, **kwargs)
|
|
|
185 |
return wrapper
|
|
|
186 |
return decorator
|
187 |
|
188 |
|
@@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None)
|
|
207 |
response = {"code": code, "message": message, "data": data}
|
208 |
return jsonify(response)
|
209 |
|
|
|
210 |
def apikey_required(func):
|
211 |
@wraps(func)
|
212 |
def decorated_function(*args, **kwargs):
|
@@ -282,17 +286,18 @@ def construct_error_response(e):
|
|
282 |
def token_required(func):
|
283 |
@wraps(func)
|
284 |
def decorated_function(*args, **kwargs):
|
285 |
-
authorization_str=flask_request.headers.get('Authorization')
|
286 |
if not authorization_str:
|
287 |
-
return get_json_result(data=False,message="`Authorization` can't be empty")
|
288 |
-
authorization_list=authorization_str.split()
|
289 |
if len(authorization_list) < 2:
|
290 |
-
return get_json_result(data=False,message="Please check your authorization format.")
|
291 |
token = authorization_list[1]
|
292 |
objs = APIToken.query(token=token)
|
293 |
if not objs:
|
294 |
return get_json_result(
|
295 |
-
data=False, message='Authentication error: API key is invalid!',
|
|
|
296 |
)
|
297 |
kwargs['tenant_id'] = objs[0].tenant_id
|
298 |
return func(*args, **kwargs)
|
@@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id):
|
|
330 |
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
331 |
|
332 |
|
333 |
-
def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
|
334 |
-
if valid_parameter(permission,valid_permission):
|
335 |
-
return valid_parameter(permission,valid_permission)
|
336 |
-
if valid_parameter(language,valid_language):
|
337 |
-
return valid_parameter(language,valid_language)
|
338 |
-
if valid_parameter(chunk_method,valid_chunk_method):
|
339 |
-
return valid_parameter(chunk_method,valid_chunk_method)
|
|
|
340 |
|
341 |
-
def valid_parameter(parameter,valid_values):
|
342 |
if parameter and parameter not in valid_values:
|
343 |
-
|
|
|
344 |
|
345 |
-
def get_parser_config(chunk_method,parser_config):
|
346 |
if parser_config:
|
347 |
return parser_config
|
348 |
if not chunk_method:
|
349 |
chunk_method = "naive"
|
350 |
-
key_mapping=
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
return wrapper
|
175 |
|
176 |
+
|
177 |
def not_allowed_parameters(*params):
|
178 |
def decorator(f):
|
179 |
def wrapper(*args, **kwargs):
|
|
|
183 |
return get_json_result(
|
184 |
code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
|
185 |
return f(*args, **kwargs)
|
186 |
+
|
187 |
return wrapper
|
188 |
+
|
189 |
return decorator
|
190 |
|
191 |
|
|
|
210 |
response = {"code": code, "message": message, "data": data}
|
211 |
return jsonify(response)
|
212 |
|
213 |
+
|
214 |
def apikey_required(func):
|
215 |
@wraps(func)
|
216 |
def decorated_function(*args, **kwargs):
|
|
|
286 |
def token_required(func):
|
287 |
@wraps(func)
|
288 |
def decorated_function(*args, **kwargs):
|
289 |
+
authorization_str = flask_request.headers.get('Authorization')
|
290 |
if not authorization_str:
|
291 |
+
return get_json_result(data=False, message="`Authorization` can't be empty")
|
292 |
+
authorization_list = authorization_str.split()
|
293 |
if len(authorization_list) < 2:
|
294 |
+
return get_json_result(data=False, message="Please check your authorization format.")
|
295 |
token = authorization_list[1]
|
296 |
objs = APIToken.query(token=token)
|
297 |
if not objs:
|
298 |
return get_json_result(
|
299 |
+
data=False, message='Authentication error: API key is invalid!',
|
300 |
+
code=settings.RetCode.AUTHENTICATION_ERROR
|
301 |
)
|
302 |
kwargs['tenant_id'] = objs[0].tenant_id
|
303 |
return func(*args, **kwargs)
|
|
|
335 |
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
336 |
|
337 |
|
338 |
+
def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
|
339 |
+
if valid_parameter(permission, valid_permission):
|
340 |
+
return valid_parameter(permission, valid_permission)
|
341 |
+
if valid_parameter(language, valid_language):
|
342 |
+
return valid_parameter(language, valid_language)
|
343 |
+
if valid_parameter(chunk_method, valid_chunk_method):
|
344 |
+
return valid_parameter(chunk_method, valid_chunk_method)
|
345 |
+
|
346 |
|
347 |
+
def valid_parameter(parameter, valid_values):
|
348 |
if parameter and parameter not in valid_values:
|
349 |
+
return get_error_data_result(f"'{parameter}' is not in {valid_values}")
|
350 |
+
|
351 |
|
352 |
+
def get_parser_config(chunk_method, parser_config):
|
353 |
if parser_config:
|
354 |
return parser_config
|
355 |
if not chunk_method:
|
356 |
chunk_method = "naive"
|
357 |
+
key_mapping = {
|
358 |
+
"naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
|
359 |
+
"raptor": {"use_raptor": False}},
|
360 |
+
"qa": {"raptor": {"use_raptor": False}},
|
361 |
+
"tag": None,
|
362 |
+
"resume": None,
|
363 |
+
"manual": {"raptor": {"use_raptor": False}},
|
364 |
+
"table": None,
|
365 |
+
"paper": {"raptor": {"use_raptor": False}},
|
366 |
+
"book": {"raptor": {"use_raptor": False}},
|
367 |
+
"laws": {"raptor": {"use_raptor": False}},
|
368 |
+
"presentation": {"raptor": {"use_raptor": False}},
|
369 |
+
"one": None,
|
370 |
+
"knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
|
371 |
+
"entity_types": ["organization", "person", "location", "event", "time"]},
|
372 |
+
"email": None,
|
373 |
+
"picture": None}
|
374 |
+
parser_config = key_mapping[chunk_method]
|
375 |
+
return parser_config
|
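One consequence of the rebuilt `key_mapping` above is that the new `tag` chunk method (like `resume`, `table`, `one`, `email`, and `picture`) carries no default parser configuration. A small hedged usage sketch of `get_parser_config`:

```python
from api.utils.api_utils import get_parser_config

# An explicit parser_config is returned unchanged.
assert get_parser_config("naive", {"chunk_token_num": 256}) == {"chunk_token_num": 256}

# Otherwise defaults are looked up per chunk method.
assert get_parser_config("tag", None) is None                    # new "tag" method
assert get_parser_config(None, None)["chunk_token_num"] == 128   # falls back to "naive"
```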
conf/infinity_mapping.json
CHANGED
@@ -10,6 +10,7 @@
|
|
10 |
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
11 |
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
12 |
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
|
|
13 |
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
14 |
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
15 |
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
@@ -27,5 +28,6 @@
|
|
27 |
"available_int": {"type": "integer", "default": 1},
|
28 |
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
29 |
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
30 |
-
"pagerank_fea": {"type": "integer", "default": 0}
|
|
|
31 |
}
|
|
|
10 |
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
11 |
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
12 |
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
13 |
+
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
14 |
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
15 |
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
16 |
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
|
|
28 |
"available_int": {"type": "integer", "default": 1},
|
29 |
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
30 |
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
31 |
+
"pagerank_fea": {"type": "integer", "default": 0},
|
32 |
+
"tag_fea": {"type": "integer", "default": 0}
|
33 |
}
|
graphrag/utils.py
CHANGED
@@ -111,4 +111,23 @@ def set_embed_cache(llmnm, txt, arr):
|
|
111 |
|
112 |
k = hasher.hexdigest()
|
113 |
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
|
114 |
-
REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
|
|
|
|
111 |
|
112 |
k = hasher.hexdigest()
|
113 |
arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
|
114 |
+
REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
|
115 |
+
|
116 |
+
|
117 |
+
def get_tags_from_cache(kb_ids):
|
118 |
+
hasher = xxhash.xxh64()
|
119 |
+
hasher.update(str(kb_ids).encode("utf-8"))
|
120 |
+
|
121 |
+
k = hasher.hexdigest()
|
122 |
+
bin = REDIS_CONN.get(k)
|
123 |
+
if not bin:
|
124 |
+
return
|
125 |
+
return bin
|
126 |
+
|
127 |
+
|
128 |
+
def set_tags_to_cache(kb_ids, tags):
|
129 |
+
hasher = xxhash.xxh64()
|
130 |
+
hasher.update(str(kb_ids).encode("utf-8"))
|
131 |
+
|
132 |
+
k = hasher.hexdigest()
|
133 |
+
REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
|
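The two helpers above keep the aggregated tag vocabulary in Redis for ten minutes, keyed by a hash of the first argument; note that `label_question` in `dialog_service.py` calls `set_tags_to_cache(all_tags, tag_kb_ids)` while the definition reads `set_tags_to_cache(kb_ids, tags)`, so the argument order is worth verifying against the caller. A hedged round-trip sketch that follows the definitions as written; Redis connectivity and the sample values are assumptions:

```python
import json
from graphrag.utils import get_tags_from_cache, set_tags_to_cache

kb_ids = ["kb-123", "kb-456"]
all_tags = {"security": 12, "billing": 7}  # tag -> frequency, illustrative only

set_tags_to_cache(kb_ids, all_tags)        # cached for 600 seconds

cached = get_tags_from_cache(kb_ids)       # raw JSON payload or None
tags = json.loads(cached) if cached else None
```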
rag/app/qa.py
CHANGED
@@ -26,6 +26,7 @@ from docx import Document
|
|
26 |
from PIL import Image
|
27 |
from markdown import markdown
|
28 |
|
|
|
29 |
class Excel(ExcelParser):
|
30 |
def __call__(self, fnm, binary=None, callback=None):
|
31 |
if not binary:
|
@@ -58,11 +59,11 @@ class Excel(ExcelParser):
|
|
58 |
if len(res) % 999 == 0:
|
59 |
callback(len(res) *
|
60 |
0.6 /
|
61 |
-
total, ("Extract
|
62 |
(f"{len(fails)} failure, line: %s..." %
|
63 |
(",".join(fails[:3])) if fails else "")))
|
64 |
|
65 |
-
callback(0.6, ("Extract
|
66 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
67 |
self.is_english = is_english(
|
68 |
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
|
@@ -269,7 +270,7 @@ def beAdocPdf(d, q, a, eng, image, poss):
|
|
269 |
return d
|
270 |
|
271 |
|
272 |
-
def beAdocDocx(d, q, a, eng, image):
|
273 |
qprefix = "Question: " if eng else "问题:"
|
274 |
aprefix = "Answer: " if eng else "回答:"
|
275 |
d["content_with_weight"] = "\t".join(
|
@@ -277,16 +278,20 @@ def beAdocDocx(d, q, a, eng, image):
|
|
277 |
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
278 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
279 |
d["image"] = image
|
|
|
|
|
280 |
return d
|
281 |
|
282 |
|
283 |
-
def beAdoc(d, q, a, eng):
|
284 |
qprefix = "Question: " if eng else "问题:"
|
285 |
aprefix = "Answer: " if eng else "回答:"
|
286 |
d["content_with_weight"] = "\t".join(
|
287 |
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
|
288 |
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
289 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
|
|
|
|
290 |
return d
|
291 |
|
292 |
|
@@ -316,8 +321,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
316 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
317 |
callback(0.1, "Start to parse.")
|
318 |
excel_parser = Excel()
|
319 |
-
for q, a in excel_parser(filename, binary, callback):
|
320 |
-
res.append(beAdoc(deepcopy(doc), q, a, eng))
|
321 |
return res
|
322 |
|
323 |
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
|
@@ -344,7 +349,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
344 |
fails.append(str(i+1))
|
345 |
elif len(arr) == 2:
|
346 |
if question and answer:
|
347 |
-
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
348 |
question, answer = arr
|
349 |
i += 1
|
350 |
if len(res) % 999 == 0:
|
@@ -352,7 +357,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
352 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
353 |
|
354 |
if question:
|
355 |
-
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
356 |
|
357 |
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
358 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
@@ -378,14 +383,14 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
378 |
fails.append(str(i + 1))
|
379 |
elif len(row) == 2:
|
380 |
if question and answer:
|
381 |
-
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
382 |
question, answer = row
|
383 |
if len(res) % 999 == 0:
|
384 |
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
|
385 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
386 |
|
387 |
if question:
|
388 |
-
res.append(beAdoc(deepcopy(doc), question, answer, eng))
|
389 |
|
390 |
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
391 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
@@ -420,7 +425,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
420 |
if last_answer.strip():
|
421 |
sum_question = '\n'.join(question_stack)
|
422 |
if sum_question:
|
423 |
-
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
|
424 |
last_answer = ''
|
425 |
|
426 |
i = question_level
|
@@ -432,7 +437,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
432 |
if last_answer.strip():
|
433 |
sum_question = '\n'.join(question_stack)
|
434 |
if sum_question:
|
435 |
-
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
|
436 |
return res
|
437 |
|
438 |
elif re.search(r"\.docx$", filename, re.IGNORECASE):
|
@@ -440,8 +445,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
|
440 |
qai_list, tbls = docx_parser(filename, binary,
|
441 |
from_page=0, to_page=10000, callback=callback)
|
442 |
res = tokenize_table(tbls, doc, eng)
|
443 |
-
for q, a, image in qai_list:
|
444 |
-
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image))
|
445 |
return res
|
446 |
|
447 |
raise NotImplementedError(
|
|
|
26 |
from PIL import Image
|
27 |
from markdown import markdown
|
28 |
|
29 |
+
|
30 |
class Excel(ExcelParser):
|
31 |
def __call__(self, fnm, binary=None, callback=None):
|
32 |
if not binary:
|
|
|
59 |
if len(res) % 999 == 0:
|
60 |
callback(len(res) *
|
61 |
0.6 /
|
62 |
+
total, ("Extract pairs: {}".format(len(res)) +
|
63 |
(f"{len(fails)} failure, line: %s..." %
|
64 |
(",".join(fails[:3])) if fails else "")))
|
65 |
|
66 |
+
callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
|
67 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
68 |
self.is_english = is_english(
|
69 |
[rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
|
|
|
270 |
return d
|
271 |
|
272 |
|
273 |
+
def beAdocDocx(d, q, a, eng, image, row_num=-1):
|
274 |
qprefix = "Question: " if eng else "问题:"
|
275 |
aprefix = "Answer: " if eng else "回答:"
|
276 |
d["content_with_weight"] = "\t".join(
|
|
|
278 |
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
279 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
280 |
d["image"] = image
|
281 |
+
if row_num >= 0:
|
282 |
+
d["top_int"] = [row_num]
|
283 |
return d
|
284 |
|
285 |
|
286 |
+
def beAdoc(d, q, a, eng, row_num=-1):
|
287 |
qprefix = "Question: " if eng else "问题:"
|
288 |
aprefix = "Answer: " if eng else "回答:"
|
289 |
d["content_with_weight"] = "\t".join(
|
290 |
[qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
|
291 |
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
292 |
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
293 |
+
if row_num >= 0:
|
294 |
+
d["top_int"] = [row_num]
|
295 |
return d
|
296 |
|
297 |
|
|
|
321 |
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
322 |
callback(0.1, "Start to parse.")
|
323 |
excel_parser = Excel()
|
324 |
+
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
|
325 |
+
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
|
326 |
return res
|
327 |
|
328 |
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
|
|
|
349 |
fails.append(str(i+1))
|
350 |
elif len(arr) == 2:
|
351 |
if question and answer:
|
352 |
+
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
|
353 |
question, answer = arr
|
354 |
i += 1
|
355 |
if len(res) % 999 == 0:
|
|
|
357 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
358 |
|
359 |
if question:
|
360 |
+
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
|
361 |
|
362 |
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
363 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
|
383 |
fails.append(str(i + 1))
|
384 |
elif len(row) == 2:
|
385 |
if question and answer:
|
386 |
+
res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
|
387 |
question, answer = row
|
388 |
if len(res) % 999 == 0:
|
389 |
callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
|
390 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
391 |
|
392 |
if question:
|
393 |
+
res.append(beAdoc(deepcopy(doc), question, answer, eng, len(reader)))
|
394 |
|
395 |
callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
|
396 |
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
|
|
425 |
if last_answer.strip():
|
426 |
sum_question = '\n'.join(question_stack)
|
427 |
if sum_question:
|
428 |
+
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
429 |
last_answer = ''
|
430 |
|
431 |
i = question_level
|
|
|
437 |
if last_answer.strip():
|
438 |
sum_question = '\n'.join(question_stack)
|
439 |
if sum_question:
|
440 |
+
res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
|
441 |
return res
|
442 |
|
443 |
elif re.search(r"\.docx$", filename, re.IGNORECASE):
|
|
|
445 |
qai_list, tbls = docx_parser(filename, binary,
|
446 |
from_page=0, to_page=10000, callback=callback)
|
447 |
res = tokenize_table(tbls, doc, eng)
|
448 |
+
for i, (q, a, image) in enumerate(qai_list):
|
449 |
+
res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
|
450 |
return res
|
451 |
|
452 |
raise NotImplementedError(
|
rag/app/tag.py
ADDED
@@ -0,0 +1,125 @@
|
|
1 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
2 |
+
# you may not use this file except in compliance with the License.
|
3 |
+
# You may obtain a copy of the License at
|
4 |
+
#
|
5 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6 |
+
#
|
7 |
+
# Unless required by applicable law or agreed to in writing, software
|
8 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
9 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
10 |
+
# See the License for the specific language governing permissions and
|
11 |
+
# limitations under the License.
|
12 |
+
#
|
13 |
+
import re
|
14 |
+
import csv
|
15 |
+
from copy import deepcopy
|
16 |
+
|
17 |
+
from deepdoc.parser.utils import get_text
|
18 |
+
from rag.app.qa import Excel
|
19 |
+
from rag.nlp import rag_tokenizer
|
20 |
+
|
21 |
+
|
22 |
+
def beAdoc(d, q, a, eng, row_num=-1):
|
23 |
+
d["content_with_weight"] = q
|
24 |
+
d["content_ltks"] = rag_tokenizer.tokenize(q)
|
25 |
+
d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
|
26 |
+
d["tag_kwd"] = [t.strip() for t in a.split(",") if t.strip()]
|
27 |
+
if row_num >= 0:
|
28 |
+
d["top_int"] = [row_num]
|
29 |
+
return d
|
30 |
+
|
31 |
+
|
32 |
+
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
|
33 |
+
"""
|
34 |
+
Excel and csv(txt) format files are supported.
|
35 |
+
If the file is in excel format, there should be 2 column content and tags without header.
|
36 |
+
And content column is ahead of tags column.
|
37 |
+
Multiple sheets are fine as long as the columns are composed correctly.
|
38 |
+
|
39 |
+
If it's in csv format, it should be UTF-8 encoded, with TAB as the delimiter between content and tags.
|
40 |
+
|
41 |
+
All malformed lines will be ignored.
|
42 |
+
Every pair will be treated as a chunk.
|
43 |
+
"""
|
44 |
+
eng = lang.lower() == "english"
|
45 |
+
res = []
|
46 |
+
doc = {
|
47 |
+
"docnm_kwd": filename,
|
48 |
+
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
49 |
+
}
|
50 |
+
if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
51 |
+
callback(0.1, "Start to parse.")
|
52 |
+
excel_parser = Excel()
|
53 |
+
for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
|
54 |
+
res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
|
55 |
+
return res
|
56 |
+
|
57 |
+
elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
|
58 |
+
callback(0.1, "Start to parse.")
|
59 |
+
txt = get_text(filename, binary)
|
60 |
+
lines = txt.split("\n")
|
61 |
+
comma, tab = 0, 0
|
62 |
+
for line in lines:
|
63 |
+
if len(line.split(",")) == 2:
|
64 |
+
comma += 1
|
65 |
+
if len(line.split("\t")) == 2:
|
66 |
+
tab += 1
|
67 |
+
delimiter = "\t" if tab >= comma else ","
|
68 |
+
|
69 |
+
fails = []
|
70 |
+
content = ""
|
71 |
+
i = 0
|
72 |
+
while i < len(lines):
|
73 |
+
arr = lines[i].split(delimiter)
|
74 |
+
if len(arr) != 2:
|
75 |
+
content += "\n" + lines[i]
|
76 |
+
elif len(arr) == 2:
|
77 |
+
content += "\n" + arr[0]
|
78 |
+
res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
|
79 |
+
content = ""
|
80 |
+
i += 1
|
81 |
+
if len(res) % 999 == 0:
|
82 |
+
callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
|
83 |
+
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
84 |
+
|
85 |
+
callback(0.6, ("Extract TAG: {}".format(len(res)) + (
|
86 |
+
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
87 |
+
|
88 |
+
return res
|
89 |
+
|
90 |
+
elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
|
91 |
+
callback(0.1, "Start to parse.")
|
92 |
+
txt = get_text(filename, binary)
|
93 |
+
lines = txt.split("\n")
|
94 |
+
delimiter = "\t" if any("\t" in line for line in lines) else ","
|
95 |
+
|
96 |
+
fails = []
|
97 |
+
content = ""
|
98 |
+
res = []
|
99 |
+
reader = csv.reader(lines, delimiter=delimiter)
|
100 |
+
|
101 |
+
for i, row in enumerate(reader):
|
102 |
+
if len(row) != 2:
|
103 |
+
content += "\n" + lines[i]
|
104 |
+
elif len(row) == 2:
|
105 |
+
content += "\n" + row[0]
|
106 |
+
res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
|
107 |
+
content = ""
|
108 |
+
if len(res) % 999 == 0:
|
109 |
+
callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
|
110 |
+
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
111 |
+
|
112 |
+
callback(0.6, ("Extract TAG : {}".format(len(res)) + (
|
113 |
+
f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
|
114 |
+
return res
|
115 |
+
|
116 |
+
raise NotImplementedError(
|
117 |
+
"Excel, csv(txt) format files are supported.")
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
import sys
|
122 |
+
|
123 |
+
def dummy(prog=None, msg=""):
|
124 |
+
pass
|
125 |
+
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
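For context, a minimal sketch of the input the new tag chunker expects and how it could be driven from Python; the file name, sample rows, and callback below are illustrative, not part of the commit.

    # tags_sample.txt (hypothetical, TAB-delimited: content<TAB>comma-separated tags)
    #   How do I reset my password?<TAB>account,login
    #   Refund took ten days to arrive<TAB>billing,refund
    from rag.app import tag

    def progress(prog=None, msg=""):
        # mirrors the dummy callback used in the __main__ block above
        print(prog, msg)

    # Each well-formed line becomes one chunk whose "tag_kwd" holds the split tag list.
    chunks = tag.chunk("tags_sample.txt", callback=progress)
    print(len(chunks), chunks[0]["tag_kwd"] if chunks else None)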
rag/nlp/query.py
CHANGED
@@ -59,13 +59,15 @@ class FulltextQueryer:
|
|
59 |
"",
|
60 |
),
|
61 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
62 |
-
(
|
|
|
|
|
63 |
]
|
64 |
for r, p in patts:
|
65 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
66 |
return txt
|
67 |
|
68 |
-
def question(self, txt, tbl="qa", min_match:float=0.6):
|
69 |
txt = re.sub(
|
70 |
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
71 |
" ",
|
@@ -90,7 +92,8 @@ class FulltextQueryer:
|
|
90 |
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
91 |
syns.append(" ".join(syn))
|
92 |
|
93 |
-
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
|
|
94 |
for i in range(1, len(tks_w)):
|
95 |
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
96 |
if not left or not right:
|
@@ -155,7 +158,7 @@ class FulltextQueryer:
|
|
155 |
if len(keywords) < 32:
|
156 |
keywords.extend([s for s in tk_syns if s])
|
157 |
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
158 |
-
tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
|
159 |
|
160 |
if len(keywords) >= 32:
|
161 |
break
|
@@ -174,8 +177,6 @@ class FulltextQueryer:
|
|
174 |
|
175 |
if len(twts) > 1:
|
176 |
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
177 |
-
if re.match(r"[0-9a-z ]+$", tt):
|
178 |
-
tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
|
179 |
|
180 |
syns = " OR ".join(
|
181 |
[
|
@@ -232,3 +233,25 @@ class FulltextQueryer:
|
|
232 |
for k, v in qtwt.items():
|
233 |
q += v
|
234 |
return s / q
|
|
59 |
"",
|
60 |
),
|
61 |
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
62 |
+
(
|
63 |
+
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
64 |
+
" ")
|
65 |
]
|
66 |
for r, p in patts:
|
67 |
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
68 |
return txt
|
69 |
|
70 |
+
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
71 |
txt = re.sub(
|
72 |
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
73 |
" ",
|
|
|
92 |
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
93 |
syns.append(" ".join(syn))
|
94 |
|
95 |
+
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
96 |
+
tk and not re.match(r"[.^+\(\)-]", tk)]
|
97 |
for i in range(1, len(tks_w)):
|
98 |
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
99 |
if not left or not right:
|
|
|
158 |
if len(keywords) < 32:
|
159 |
keywords.extend([s for s in tk_syns if s])
|
160 |
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
161 |
+
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
162 |
|
163 |
if len(keywords) >= 32:
|
164 |
break
|
|
|
177 |
|
178 |
if len(twts) > 1:
|
179 |
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
|
|
|
|
180 |
|
181 |
syns = " OR ".join(
|
182 |
[
|
|
|
233 |
for k, v in qtwt.items():
|
234 |
q += v
|
235 |
return s / q
|
236 |
+
|
237 |
+
def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
|
238 |
+
if isinstance(content_tks, str):
|
239 |
+
content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
|
240 |
+
tks_w = self.tw.weights(content_tks, preprocess=False)
|
241 |
+
|
242 |
+
keywords = [f'"{k.strip()}"' for k in keywords]
|
243 |
+
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
244 |
+
tk_syns = self.syn.lookup(tk)
|
245 |
+
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
246 |
+
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
247 |
+
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
248 |
+
tk = FulltextQueryer.subSpecialChar(tk)
|
249 |
+
if tk.find(" ") > 0:
|
250 |
+
tk = '"%s"' % tk
|
251 |
+
if tk_syns:
|
252 |
+
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
253 |
+
if tk:
|
254 |
+
keywords.append(f"{tk}^{w}")
|
255 |
+
|
256 |
+
return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
|
257 |
+
{"minimum_should_match": min(3, len(keywords) / 10)})
|
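To make the intent of the new paragraph() method concrete, the toy snippet below rebuilds the shape of the expression it assembles using plain Python; the tokens, weights, and synonyms are invented, and the real method additionally escapes special characters and caps minimum_should_match.

    # Illustrative only: weighted tokens plus quoted synonyms folded into one query string.
    tks_w = [("transformer", 0.62), ("attention", 0.41)]   # (token, weight) pairs
    syns = {"transformer": ['"seq2seq model"']}             # fine-grained, quoted synonyms

    clauses = ['"quick sort"']                              # caller-supplied keywords, already quoted
    for tk, w in tks_w:
        syn = syns.get(tk, [])
        body = f"({tk} OR ({' '.join(syn)})^0.2)" if syn else tk
        clauses.append(f"{body}^{w}")

    print(" ".join(clauses))
    # "quick sort" (transformer OR ("seq2seq model")^0.2)^0.62 attention^0.41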
rag/nlp/search.py
CHANGED
@@ -13,11 +13,11 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
-
|
17 |
import logging
|
18 |
import re
|
19 |
from dataclasses import dataclass
|
20 |
|
|
|
21 |
from rag.utils import rmSpace
|
22 |
from rag.nlp import rag_tokenizer, query
|
23 |
import numpy as np
|
@@ -47,7 +47,8 @@ class Dealer:
|
|
47 |
qv, _ = emb_mdl.encode_queries(txt)
|
48 |
shape = np.array(qv).shape
|
49 |
if len(shape) > 1:
|
50 |
-
raise Exception(
|
|
|
51 |
embedding_data = [float(v) for v in qv]
|
52 |
vector_column_name = f"q_{len(embedding_data)}_vec"
|
53 |
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
@@ -63,7 +64,12 @@ class Dealer:
|
|
63 |
condition[key] = req[key]
|
64 |
return condition
|
65 |
|
66 |
-
def search(self, req, idx_names: str | list[str],
|
67 |
filters = self.get_filters(req)
|
68 |
orderBy = OrderByExpr()
|
69 |
|
@@ -72,9 +78,11 @@ class Dealer:
|
|
72 |
ps = int(req.get("size", topk))
|
73 |
offset, limit = pg * ps, ps
|
74 |
|
75 |
-
src = req.get("fields",
|
76 |
-
|
77 |
-
|
|
|
|
|
78 |
kwds = set([])
|
79 |
|
80 |
qst = req.get("question", "")
|
@@ -85,15 +93,16 @@ class Dealer:
|
|
85 |
orderBy.asc("top_int")
|
86 |
orderBy.desc("create_timestamp_flt")
|
87 |
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
88 |
-
total=self.dataStore.getTotal(res)
|
89 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
90 |
else:
|
91 |
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
92 |
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
93 |
if emb_mdl is None:
|
94 |
matchExprs = [matchText]
|
95 |
-
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
96 |
-
|
|
|
97 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
98 |
else:
|
99 |
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
@@ -103,8 +112,9 @@ class Dealer:
|
|
103 |
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
104 |
matchExprs = [matchText, matchDense, fusionExpr]
|
105 |
|
106 |
-
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
107 |
-
|
|
|
108 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
109 |
|
110 |
# If result is empty, try again with lower min_match
|
@@ -112,8 +122,9 @@ class Dealer:
|
|
112 |
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
113 |
filters.pop("doc_ids", None)
|
114 |
matchDense.extra_options["similarity"] = 0.17
|
115 |
-
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
116 |
-
|
|
|
117 |
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
118 |
|
119 |
for k in keywords:
|
@@ -126,8 +137,8 @@ class Dealer:
|
|
126 |
kwds.add(kk)
|
127 |
|
128 |
logging.debug(f"TOTAL: {total}")
|
129 |
-
ids=self.dataStore.getChunkIds(res)
|
130 |
-
keywords=list(kwds)
|
131 |
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
132 |
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
133 |
return self.SearchResult(
|
@@ -188,13 +199,13 @@ class Dealer:
|
|
188 |
|
189 |
ans_v, _ = embd_mdl.encode(pieces_)
|
190 |
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
191 |
-
|
192 |
|
193 |
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
|
194 |
for ck in chunks]
|
195 |
cites = {}
|
196 |
thr = 0.63
|
197 |
-
while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
|
198 |
for i, a in enumerate(pieces_):
|
199 |
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
|
200 |
chunk_v,
|
@@ -228,20 +239,44 @@ class Dealer:
|
|
228 |
|
229 |
return res, seted
|
230 |
|
231 |
def rerank(self, sres, query, tkweight=0.3,
|
232 |
-
vtweight=0.7, cfield="content_ltks"
|
|
|
|
|
233 |
_, keywords = self.qryr.question(query)
|
234 |
vector_size = len(sres.query_vector)
|
235 |
vector_column = f"q_{vector_size}_vec"
|
236 |
zero_vector = [0.0] * vector_size
|
237 |
ins_embd = []
|
238 |
-
pageranks = []
|
239 |
for chunk_id in sres.ids:
|
240 |
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
241 |
if isinstance(vector, str):
|
242 |
vector = [float(v) for v in vector.split("\t")]
|
243 |
ins_embd.append(vector)
|
244 |
-
pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
|
245 |
if not ins_embd:
|
246 |
return [], [], []
|
247 |
|
@@ -254,18 +289,22 @@ class Dealer:
|
|
254 |
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
|
255 |
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
|
256 |
important_kwd = sres.field[i].get("important_kwd", [])
|
257 |
-
tks = content_ltks + title_tks*2 + important_kwd*5 + question_tks*6
|
258 |
ins_tw.append(tks)
|
259 |
|
260 |
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
|
261 |
ins_embd,
|
262 |
keywords,
|
263 |
ins_tw, tkweight, vtweight)
|
264 |
|
265 |
-
return sim+
|
266 |
|
267 |
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
268 |
-
|
|
|
269 |
_, keywords = self.qryr.question(query)
|
270 |
|
271 |
for i in sres.ids:
|
@@ -280,9 +319,11 @@ class Dealer:
|
|
280 |
ins_tw.append(tks)
|
281 |
|
282 |
tksim = self.qryr.token_similarity(keywords, ins_tw)
|
283 |
-
vtsim,_ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
|
|
|
|
|
284 |
|
285 |
-
return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
|
286 |
|
287 |
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
288 |
return self.qryr.hybrid_similarity(ans_embd,
|
@@ -291,13 +332,15 @@ class Dealer:
|
|
291 |
rag_tokenizer.tokenize(inst).split())
|
292 |
|
293 |
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
|
294 |
-
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
|
|
|
|
|
295 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
296 |
if not question:
|
297 |
return ranks
|
298 |
|
299 |
RERANK_PAGE_LIMIT = 3
|
300 |
-
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
|
301 |
"question": question, "vector": True, "topk": top,
|
302 |
"similarity": similarity_threshold,
|
303 |
"available_int": 1}
|
@@ -309,29 +352,30 @@ class Dealer:
|
|
309 |
if isinstance(tenant_ids, str):
|
310 |
tenant_ids = tenant_ids.split(",")
|
311 |
|
312 |
-
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
|
|
|
313 |
ranks["total"] = sres.total
|
314 |
|
315 |
if page <= RERANK_PAGE_LIMIT:
|
316 |
if rerank_mdl and sres.total > 0:
|
317 |
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
318 |
-
|
|
|
|
|
319 |
else:
|
320 |
sim, tsim, vsim = self.rerank(
|
321 |
-
sres, question, 1 - vector_similarity_weight, vector_similarity_weight
|
322 |
-
|
|
|
323 |
else:
|
324 |
-
sim = tsim = vsim = [1]*len(sres.ids)
|
325 |
idx = list(range(len(sres.ids)))
|
326 |
|
327 |
-
def floor_sim(score):
|
328 |
-
return (int(score * 100.)%100)/100.
|
329 |
-
|
330 |
dim = len(sres.query_vector)
|
331 |
vector_column = f"q_{dim}_vec"
|
332 |
zero_vector = [0.0] * dim
|
333 |
for i in idx:
|
334 |
-
if
|
335 |
break
|
336 |
if len(ranks["chunks"]) >= page_size:
|
337 |
if aggs:
|
@@ -369,8 +413,8 @@ class Dealer:
|
|
369 |
ranks["doc_aggs"] = [{"doc_name": k,
|
370 |
"doc_id": v["doc_id"],
|
371 |
"count": v["count"]} for k,
|
372 |
-
|
373 |
-
|
374 |
ranks["chunks"] = ranks["chunks"][:page_size]
|
375 |
|
376 |
return ranks
|
@@ -379,15 +423,57 @@ class Dealer:
|
|
379 |
tbl = self.dataStore.sql(sql, fetch_size, format)
|
380 |
return tbl
|
381 |
|
382 |
-
def chunk_list(self, doc_id: str, tenant_id: str,
|
383 |
condition = {"doc_id": doc_id}
|
384 |
res = []
|
385 |
bs = 128
|
386 |
-
for p in range(
|
387 |
-
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
|
|
388 |
dict_chunks = self.dataStore.getFields(es_res, fields)
|
389 |
if dict_chunks:
|
390 |
res.extend(dict_chunks.values())
|
391 |
if len(dict_chunks.values()) < bs:
|
392 |
break
|
393 |
return res
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import logging
|
17 |
import re
|
18 |
from dataclasses import dataclass
|
19 |
|
20 |
+
from rag.settings import TAG_FLD, PAGERANK_FLD
|
21 |
from rag.utils import rmSpace
|
22 |
from rag.nlp import rag_tokenizer, query
|
23 |
import numpy as np
|
|
|
47 |
qv, _ = emb_mdl.encode_queries(txt)
|
48 |
shape = np.array(qv).shape
|
49 |
if len(shape) > 1:
|
50 |
+
raise Exception(
|
51 |
+
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
52 |
embedding_data = [float(v) for v in qv]
|
53 |
vector_column_name = f"q_{len(embedding_data)}_vec"
|
54 |
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
|
|
64 |
condition[key] = req[key]
|
65 |
return condition
|
66 |
|
67 |
+
def search(self, req, idx_names: str | list[str],
|
68 |
+
kb_ids: list[str],
|
69 |
+
emb_mdl=None,
|
70 |
+
highlight=False,
|
71 |
+
rank_feature: dict | None = None
|
72 |
+
):
|
73 |
filters = self.get_filters(req)
|
74 |
orderBy = OrderByExpr()
|
75 |
|
|
|
78 |
ps = int(req.get("size", topk))
|
79 |
offset, limit = pg * ps, ps
|
80 |
|
81 |
+
src = req.get("fields",
|
82 |
+
["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
|
83 |
+
"doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
|
84 |
+
"question_kwd", "question_tks",
|
85 |
+
"available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
|
86 |
kwds = set([])
|
87 |
|
88 |
qst = req.get("question", "")
|
|
|
93 |
orderBy.asc("top_int")
|
94 |
orderBy.desc("create_timestamp_flt")
|
95 |
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
96 |
+
total = self.dataStore.getTotal(res)
|
97 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
98 |
else:
|
99 |
highlightFields = ["content_ltks", "title_tks"] if highlight else []
|
100 |
matchText, keywords = self.qryr.question(qst, min_match=0.3)
|
101 |
if emb_mdl is None:
|
102 |
matchExprs = [matchText]
|
103 |
+
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
104 |
+
idx_names, kb_ids, rank_feature=rank_feature)
|
105 |
+
total = self.dataStore.getTotal(res)
|
106 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
107 |
else:
|
108 |
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
|
|
112 |
fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
|
113 |
matchExprs = [matchText, matchDense, fusionExpr]
|
114 |
|
115 |
+
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
116 |
+
idx_names, kb_ids, rank_feature=rank_feature)
|
117 |
+
total = self.dataStore.getTotal(res)
|
118 |
logging.debug("Dealer.search TOTAL: {}".format(total))
|
119 |
|
120 |
# If result is empty, try again with lower min_match
|
|
|
122 |
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
123 |
filters.pop("doc_ids", None)
|
124 |
matchDense.extra_options["similarity"] = 0.17
|
125 |
+
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
126 |
+
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
127 |
+
total = self.dataStore.getTotal(res)
|
128 |
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
129 |
|
130 |
for k in keywords:
|
|
|
137 |
kwds.add(kk)
|
138 |
|
139 |
logging.debug(f"TOTAL: {total}")
|
140 |
+
ids = self.dataStore.getChunkIds(res)
|
141 |
+
keywords = list(kwds)
|
142 |
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
143 |
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
144 |
return self.SearchResult(
|
|
|
199 |
|
200 |
ans_v, _ = embd_mdl.encode(pieces_)
|
201 |
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
|
202 |
+
len(ans_v[0]), len(chunk_v[0]))
|
203 |
|
204 |
chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
|
205 |
for ck in chunks]
|
206 |
cites = {}
|
207 |
thr = 0.63
|
208 |
+
while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
|
209 |
for i, a in enumerate(pieces_):
|
210 |
sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
|
211 |
chunk_v,
|
|
|
239 |
|
240 |
return res, seted
|
241 |
|
242 |
+
def _rank_feature_scores(self, query_rfea, search_res):
|
243 |
+
## For rank feature(tag_fea) scores.
|
244 |
+
rank_fea = []
|
245 |
+
pageranks = []
|
246 |
+
for chunk_id in search_res.ids:
|
247 |
+
pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
|
248 |
+
pageranks = np.array(pageranks, dtype=float)
|
249 |
+
|
250 |
+
if not query_rfea:
|
251 |
+
return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
|
252 |
+
|
253 |
+
q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
|
254 |
+
for i in search_res.ids:
|
255 |
+
nor, denor = 0, 0
|
256 |
+
for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
|
257 |
+
if t in query_rfea:
|
258 |
+
nor += query_rfea[t] * sc
|
259 |
+
denor += sc * sc
|
260 |
+
if denor == 0:
|
261 |
+
rank_fea.append(0)
|
262 |
+
else:
|
263 |
+
rank_fea.append(nor/np.sqrt(denor)/q_denor)
|
264 |
+
return np.array(rank_fea)*10. + pageranks
|
265 |
+
|
266 |
def rerank(self, sres, query, tkweight=0.3,
|
267 |
+
vtweight=0.7, cfield="content_ltks",
|
268 |
+
rank_feature: dict | None = None
|
269 |
+
):
|
270 |
_, keywords = self.qryr.question(query)
|
271 |
vector_size = len(sres.query_vector)
|
272 |
vector_column = f"q_{vector_size}_vec"
|
273 |
zero_vector = [0.0] * vector_size
|
274 |
ins_embd = []
|
|
|
275 |
for chunk_id in sres.ids:
|
276 |
vector = sres.field[chunk_id].get(vector_column, zero_vector)
|
277 |
if isinstance(vector, str):
|
278 |
vector = [float(v) for v in vector.split("\t")]
|
279 |
ins_embd.append(vector)
|
|
|
280 |
if not ins_embd:
|
281 |
return [], [], []
|
282 |
|
|
|
289 |
title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
|
290 |
question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
|
291 |
important_kwd = sres.field[i].get("important_kwd", [])
|
292 |
+
tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
|
293 |
ins_tw.append(tks)
|
294 |
|
295 |
+
## For rank feature(tag_fea) scores.
|
296 |
+
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
297 |
+
|
298 |
sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
|
299 |
ins_embd,
|
300 |
keywords,
|
301 |
ins_tw, tkweight, vtweight)
|
302 |
|
303 |
+
return sim + rank_fea, tksim, vtsim
|
304 |
|
305 |
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
|
306 |
+
vtweight=0.7, cfield="content_ltks",
|
307 |
+
rank_feature: dict | None = None):
|
308 |
_, keywords = self.qryr.question(query)
|
309 |
|
310 |
for i in sres.ids:
|
|
|
319 |
ins_tw.append(tks)
|
320 |
|
321 |
tksim = self.qryr.token_similarity(keywords, ins_tw)
|
322 |
+
vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
|
323 |
+
## For rank feature(tag_fea) scores.
|
324 |
+
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
325 |
|
326 |
+
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
|
327 |
|
328 |
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
329 |
return self.qryr.hybrid_similarity(ans_embd,
|
|
|
332 |
rag_tokenizer.tokenize(inst).split())
|
333 |
|
334 |
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
|
335 |
+
vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
|
336 |
+
rerank_mdl=None, highlight=False,
|
337 |
+
rank_feature: dict | None = {PAGERANK_FLD: 10}):
|
338 |
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
|
339 |
if not question:
|
340 |
return ranks
|
341 |
|
342 |
RERANK_PAGE_LIMIT = 3
|
343 |
+
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
|
344 |
"question": question, "vector": True, "topk": top,
|
345 |
"similarity": similarity_threshold,
|
346 |
"available_int": 1}
|
|
|
352 |
if isinstance(tenant_ids, str):
|
353 |
tenant_ids = tenant_ids.split(",")
|
354 |
|
355 |
+
sres = self.search(req, [index_name(tid) for tid in tenant_ids],
|
356 |
+
kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
|
357 |
ranks["total"] = sres.total
|
358 |
|
359 |
if page <= RERANK_PAGE_LIMIT:
|
360 |
if rerank_mdl and sres.total > 0:
|
361 |
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
|
362 |
+
sres, question, 1 - vector_similarity_weight,
|
363 |
+
vector_similarity_weight,
|
364 |
+
rank_feature=rank_feature)
|
365 |
else:
|
366 |
sim, tsim, vsim = self.rerank(
|
367 |
+
sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
|
368 |
+
rank_feature=rank_feature)
|
369 |
+
idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
|
370 |
else:
|
371 |
+
sim = tsim = vsim = [1] * len(sres.ids)
|
372 |
idx = list(range(len(sres.ids)))
|
373 |
|
|
374 |
dim = len(sres.query_vector)
|
375 |
vector_column = f"q_{dim}_vec"
|
376 |
zero_vector = [0.0] * dim
|
377 |
for i in idx:
|
378 |
+
if sim[i] < similarity_threshold:
|
379 |
break
|
380 |
if len(ranks["chunks"]) >= page_size:
|
381 |
if aggs:
|
|
|
413 |
ranks["doc_aggs"] = [{"doc_name": k,
|
414 |
"doc_id": v["doc_id"],
|
415 |
"count": v["count"]} for k,
|
416 |
+
v in sorted(ranks["doc_aggs"].items(),
|
417 |
+
key=lambda x: x[1]["count"] * -1)]
|
418 |
ranks["chunks"] = ranks["chunks"][:page_size]
|
419 |
|
420 |
return ranks
|
|
|
423 |
tbl = self.dataStore.sql(sql, fetch_size, format)
|
424 |
return tbl
|
425 |
|
426 |
+
def chunk_list(self, doc_id: str, tenant_id: str,
|
427 |
+
kb_ids: list[str], max_count=1024,
|
428 |
+
offset=0,
|
429 |
+
fields=["docnm_kwd", "content_with_weight", "img_id"]):
|
430 |
condition = {"doc_id": doc_id}
|
431 |
res = []
|
432 |
bs = 128
|
433 |
+
for p in range(offset, max_count, bs):
|
434 |
+
es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
|
435 |
+
kb_ids)
|
436 |
dict_chunks = self.dataStore.getFields(es_res, fields)
|
437 |
if dict_chunks:
|
438 |
res.extend(dict_chunks.values())
|
439 |
if len(dict_chunks.values()) < bs:
|
440 |
break
|
441 |
return res
|
442 |
+
|
443 |
+
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
|
444 |
+
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
445 |
+
return self.dataStore.getAggregation(res, "tag_kwd")
|
446 |
+
|
447 |
+
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
|
448 |
+
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
449 |
+
res = self.dataStore.getAggregation(res, "tag_kwd")
|
450 |
+
total = np.sum([c for _, c in res])
|
451 |
+
return {t: (c + 1) / (total + S) for t, c in res}
|
452 |
+
|
453 |
+
def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
|
454 |
+
idx_nm = index_name(tenant_id)
|
455 |
+
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
|
456 |
+
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
457 |
+
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
458 |
+
if not aggs:
|
459 |
+
return False
|
460 |
+
cnt = np.sum([c for _, c in aggs])
|
461 |
+
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
|
462 |
+
key=lambda x: x[1] * -1)[:topn_tags]
|
463 |
+
doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
|
464 |
+
return True
|
465 |
+
|
466 |
+
def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
|
467 |
+
if isinstance(tenant_ids, str):
|
468 |
+
idx_nms = index_name(tenant_ids)
|
469 |
+
else:
|
470 |
+
idx_nms = [index_name(tid) for tid in tenant_ids]
|
471 |
+
match_txt, _ = self.qryr.question(question, min_match=0.0)
|
472 |
+
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
|
473 |
+
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
474 |
+
if not aggs:
|
475 |
+
return {}
|
476 |
+
cnt = np.sum([c for _, c in aggs])
|
477 |
+
tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
|
478 |
+
key=lambda x: x[1] * -1)[:topn_tags]
|
479 |
+
return {a: c for a, c in tag_fea if c > 0}
|
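Taken together, the new Dealer methods give retrieval a tag-aware boost roughly as sketched below; dealer, tenant_id, kb_ids, tag_kb_ids, question, and emb_mdl are assumed to already exist, and the call shapes follow the signatures added in this file.

    # Query-time flow (sketch): learn the tag vocabulary, tag the question, then retrieve with it.
    all_tags = dealer.all_tags_in_portion(tenant_id, tag_kb_ids)        # {tag: smoothed frequency}
    rank_feature = dealer.tag_query(question, tenant_id, tag_kb_ids,
                                    all_tags, topn_tags=3)              # {tag: boost weight}
    ranks = dealer.retrieval(question, emb_mdl, tenant_id, kb_ids,
                             page=1, page_size=10,
                             rank_feature=rank_feature)                 # chunks re-scored with tag features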
rag/settings.py
CHANGED
@@ -38,6 +38,9 @@ SVR_QUEUE_RETENTION = 60*60
|
|
38 |
SVR_QUEUE_MAX_LEN = 1024
|
39 |
SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
|
40 |
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
|
|
|
|
|
|
|
41 |
|
42 |
def print_rag_settings():
|
43 |
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
|
|
|
38 |
SVR_QUEUE_MAX_LEN = 1024
|
39 |
SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
|
40 |
SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
|
41 |
+
PAGERANK_FLD = "pagerank_fea"
|
42 |
+
TAG_FLD = "tag_feas"
|
43 |
+
|
44 |
|
45 |
def print_rag_settings():
|
46 |
logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
|
rag/svr/task_executor.py
CHANGED
@@ -16,10 +16,10 @@
|
|
16 |
# from beartype import BeartypeConf
|
17 |
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
18 |
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
19 |
-
|
20 |
import sys
|
21 |
from api.utils.log_utils import initRootLogger
|
22 |
-
from graphrag.utils import get_llm_cache, set_llm_cache
|
23 |
|
24 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
25 |
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
|
@@ -44,7 +44,7 @@ import numpy as np
|
|
44 |
from peewee import DoesNotExist
|
45 |
|
46 |
from api.db import LLMType, ParserType, TaskStatus
|
47 |
-
from api.db.services.dialog_service import keyword_extraction, question_proposal
|
48 |
from api.db.services.document_service import DocumentService
|
49 |
from api.db.services.llm_service import LLMBundle
|
50 |
from api.db.services.task_service import TaskService
|
@@ -53,10 +53,10 @@ from api import settings
|
|
53 |
from api.versions import get_ragflow_version
|
54 |
from api.db.db_models import close_connection
|
55 |
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
56 |
-
knowledge_graph, email
|
57 |
from rag.nlp import search, rag_tokenizer
|
58 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
59 |
-
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
|
60 |
from rag.utils import num_tokens_from_string
|
61 |
from rag.utils.redis_conn import REDIS_CONN, Payload
|
62 |
from rag.utils.storage_factory import STORAGE_IMPL
|
@@ -78,7 +78,8 @@ FACTORY = {
|
|
78 |
ParserType.ONE.value: one,
|
79 |
ParserType.AUDIO.value: audio,
|
80 |
ParserType.EMAIL.value: email,
|
81 |
-
ParserType.KG.value: knowledge_graph
|
|
|
82 |
}
|
83 |
|
84 |
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
@@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
|
|
199 |
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
200 |
except TimeoutError:
|
201 |
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
202 |
-
logging.exception(
|
|
|
203 |
raise
|
204 |
except Exception as e:
|
205 |
if re.search("(No such file|not found)", str(e)):
|
@@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
|
|
227 |
"kb_id": str(task["kb_id"])
|
228 |
}
|
229 |
if task["pagerank"]:
|
230 |
-
doc[
|
231 |
el = 0
|
232 |
for ck in cks:
|
233 |
d = copy.deepcopy(doc)
|
@@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
|
|
252 |
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
|
253 |
el += timer() - st
|
254 |
except Exception:
|
255 |
-
logging.exception(
|
|
|
256 |
raise
|
257 |
|
258 |
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
|
@@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
|
|
295 |
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
296 |
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
297 |
|
298 |
return docs
|
299 |
|
300 |
|
301 |
def init_kb(row, vector_size: int):
|
302 |
idxnm = search.index_name(row["tenant_id"])
|
303 |
-
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size)
|
304 |
|
305 |
|
306 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
@@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
|
|
381 |
"title_tks": rag_tokenizer.tokenize(row["name"])
|
382 |
}
|
383 |
if row["pagerank"]:
|
384 |
-
doc[
|
385 |
res = []
|
386 |
tk_count = 0
|
387 |
for content, vctr in chunks[original_length:]:
|
@@ -480,7 +514,8 @@ def do_handle_task(task):
|
|
480 |
doc_store_result = ""
|
481 |
es_bulk_size = 4
|
482 |
for b in range(0, len(chunks), es_bulk_size):
|
483 |
-
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
|
|
|
484 |
if b % 128 == 0:
|
485 |
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
486 |
if doc_store_result:
|
@@ -493,15 +528,21 @@ def do_handle_task(task):
|
|
493 |
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
494 |
except DoesNotExist:
|
495 |
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
496 |
-
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
|
|
|
497 |
return
|
498 |
-
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
|
|
|
|
499 |
|
500 |
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
501 |
|
502 |
time_cost = timer() - start_ts
|
503 |
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
|
504 |
-
logging.info(
|
505 |
|
506 |
|
507 |
def handle_task():
|
|
|
16 |
# from beartype import BeartypeConf
|
17 |
# from beartype.claw import beartype_all # <-- you didn't sign up for this
|
18 |
# beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
|
19 |
+
import random
|
20 |
import sys
|
21 |
from api.utils.log_utils import initRootLogger
|
22 |
+
from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
|
23 |
|
24 |
CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
|
25 |
CONSUMER_NAME = "task_executor_" + CONSUMER_NO
|
|
|
44 |
from peewee import DoesNotExist
|
45 |
|
46 |
from api.db import LLMType, ParserType, TaskStatus
|
47 |
+
from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
|
48 |
from api.db.services.document_service import DocumentService
|
49 |
from api.db.services.llm_service import LLMBundle
|
50 |
from api.db.services.task_service import TaskService
|
|
|
53 |
from api.versions import get_ragflow_version
|
54 |
from api.db.db_models import close_connection
|
55 |
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
56 |
+
knowledge_graph, email, tag
|
57 |
from rag.nlp import search, rag_tokenizer
|
58 |
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
59 |
+
from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
|
60 |
from rag.utils import num_tokens_from_string
|
61 |
from rag.utils.redis_conn import REDIS_CONN, Payload
|
62 |
from rag.utils.storage_factory import STORAGE_IMPL
|
|
|
78 |
ParserType.ONE.value: one,
|
79 |
ParserType.AUDIO.value: audio,
|
80 |
ParserType.EMAIL.value: email,
|
81 |
+
ParserType.KG.value: knowledge_graph,
|
82 |
+
ParserType.TAG.value: tag
|
83 |
}
|
84 |
|
85 |
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
|
|
200 |
logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
|
201 |
except TimeoutError:
|
202 |
progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
|
203 |
+
logging.exception(
|
204 |
+
"Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
|
205 |
raise
|
206 |
except Exception as e:
|
207 |
if re.search("(No such file|not found)", str(e)):
|
|
|
229 |
"kb_id": str(task["kb_id"])
|
230 |
}
|
231 |
if task["pagerank"]:
|
232 |
+
doc[PAGERANK_FLD] = int(task["pagerank"])
|
233 |
el = 0
|
234 |
for ck in cks:
|
235 |
d = copy.deepcopy(doc)
|
|
|
254 |
STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
|
255 |
el += timer() - st
|
256 |
except Exception:
|
257 |
+
logging.exception(
|
258 |
+
"Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
|
259 |
raise
|
260 |
|
261 |
d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
|
|
|
298 |
d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
|
299 |
progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
|
300 |
|
301 |
+
if task["kb_parser_config"].get("tag_kb_ids", []):
|
302 |
+
progress_callback(msg="Start to tag for every chunk ...")
|
303 |
+
kb_ids = task["kb_parser_config"]["tag_kb_ids"]
|
304 |
+
tenant_id = task["tenant_id"]
|
305 |
+
topn_tags = task["kb_parser_config"].get("topn_tags", 3)
|
306 |
+
S = 1000
|
307 |
+
st = timer()
|
308 |
+
examples = []
|
309 |
+
all_tags = get_tags_from_cache(kb_ids)
|
310 |
+
if not all_tags:
|
311 |
+
all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
|
312 |
+
set_tags_to_cache(kb_ids, all_tags)
|
313 |
+
else:
|
314 |
+
all_tags = json.loads(all_tags)
|
315 |
+
|
316 |
+
chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
|
317 |
+
for d in docs:
|
318 |
+
if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
|
319 |
+
examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
|
320 |
+
continue
|
321 |
+
cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
|
322 |
+
if not cached:
|
323 |
+
cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
|
324 |
+
random.choices(examples, k=2) if len(examples)>2 else examples,
|
325 |
+
topn=topn_tags)
|
326 |
+
if cached:
|
327 |
+
set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
|
328 |
+
d[TAG_FLD] = json.loads(cached)
|
329 |
+
|
330 |
+
progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
|
331 |
+
|
332 |
return docs
|
333 |
|
334 |
|
335 |
def init_kb(row, vector_size: int):
|
336 |
idxnm = search.index_name(row["tenant_id"])
|
337 |
+
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
338 |
|
339 |
|
340 |
def embedding(docs, mdl, parser_config=None, callback=None):
|
|
|
415 |
"title_tks": rag_tokenizer.tokenize(row["name"])
|
416 |
}
|
417 |
if row["pagerank"]:
|
418 |
+
doc[PAGERANK_FLD] = int(row["pagerank"])
|
419 |
res = []
|
420 |
tk_count = 0
|
421 |
for content, vctr in chunks[original_length:]:
|
|
|
514 |
doc_store_result = ""
|
515 |
es_bulk_size = 4
|
516 |
for b in range(0, len(chunks), es_bulk_size):
|
517 |
+
doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
|
518 |
+
task_dataset_id)
|
519 |
if b % 128 == 0:
|
520 |
progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
|
521 |
if doc_store_result:
|
|
|
528 |
TaskService.update_chunk_ids(task["id"], chunk_ids_str)
|
529 |
except DoesNotExist:
|
530 |
logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
|
531 |
+
doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
|
532 |
+
task_dataset_id)
|
533 |
return
|
534 |
+
logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
|
535 |
+
task_to_page, len(chunks),
|
536 |
+
timer() - start_ts))
|
537 |
|
538 |
DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
|
539 |
|
540 |
time_cost = timer() - start_ts
|
541 |
progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
|
542 |
+
logging.info(
|
543 |
+
"Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
|
544 |
+
task_to_page, len(chunks),
|
545 |
+
token_count, time_cost))
|
546 |
|
547 |
|
548 |
def handle_task():
|
rag/utils/__init__.py
CHANGED
@@ -71,11 +71,13 @@ def findMaxTm(fnm):
|
|
71 |
pass
|
72 |
return m
|
73 |
|
|
|
74 |
tiktoken_cache_dir = get_project_base_directory()
|
75 |
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
76 |
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
77 |
encoder = tiktoken.get_encoding("cl100k_base")
|
78 |
|
|
|
79 |
def num_tokens_from_string(string: str) -> int:
|
80 |
"""Returns the number of tokens in a text string."""
|
81 |
try:
|
|
|
71 |
pass
|
72 |
return m
|
73 |
|
74 |
+
|
75 |
tiktoken_cache_dir = get_project_base_directory()
|
76 |
os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
|
77 |
# encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
78 |
encoder = tiktoken.get_encoding("cl100k_base")
|
79 |
|
80 |
+
|
81 |
def num_tokens_from_string(string: str) -> int:
|
82 |
"""Returns the number of tokens in a text string."""
|
83 |
try:
|
rag/utils/doc_store_conn.py
CHANGED
@@ -176,7 +176,17 @@ class DocStoreConnection(ABC):
|
|
176 |
|
177 |
@abstractmethod
|
178 |
def search(
|
179 |
-
self, selectFields: list[str],
|
180 |
) -> list[dict] | pl.DataFrame:
|
181 |
"""
|
182 |
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
@@ -191,7 +201,7 @@ class DocStoreConnection(ABC):
|
|
191 |
raise NotImplementedError("Not implemented")
|
192 |
|
193 |
@abstractmethod
|
194 |
-
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
195 |
"""
|
196 |
Update or insert a bulk of rows
|
197 |
"""
|
|
|
176 |
|
177 |
@abstractmethod
|
178 |
def search(
|
179 |
+
self, selectFields: list[str],
|
180 |
+
highlightFields: list[str],
|
181 |
+
condition: dict,
|
182 |
+
matchExprs: list[MatchExpr],
|
183 |
+
orderBy: OrderByExpr,
|
184 |
+
offset: int,
|
185 |
+
limit: int,
|
186 |
+
indexNames: str|list[str],
|
187 |
+
knowledgebaseIds: list[str],
|
188 |
+
aggFields: list[str] = [],
|
189 |
+
rank_feature: dict | None = None
|
190 |
) -> list[dict] | pl.DataFrame:
|
191 |
"""
|
192 |
Search with given conjunctive equivalent filtering condition and return all fields of matched documents
|
|
|
201 |
raise NotImplementedError("Not implemented")
|
202 |
|
203 |
@abstractmethod
|
204 |
+
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
205 |
"""
|
206 |
Update or insert a bulk of rows
|
207 |
"""
|
rag/utils/es_conn.py
CHANGED
@@ -9,6 +9,7 @@ from elasticsearch import Elasticsearch, NotFoundError
|
|
9 |
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
10 |
from elastic_transport import ConnectionTimeout
|
11 |
from rag import settings
|
|
|
12 |
from rag.utils import singleton
|
13 |
from api.utils.file_utils import get_project_base_directory
|
14 |
import polars as pl
|
@@ -20,6 +21,7 @@ ATTEMPT_TIME = 2
|
|
20 |
|
21 |
logger = logging.getLogger('ragflow.es_conn')
|
22 |
|
|
|
23 |
@singleton
|
24 |
class ESConnection(DocStoreConnection):
|
25 |
def __init__(self):
|
@@ -111,9 +113,19 @@ class ESConnection(DocStoreConnection):
|
|
111 |
CRUD operations
|
112 |
"""
|
113 |
|
114 |
-
def search(
|
115 |
-
|
116 |
-
|
117 |
"""
|
118 |
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
119 |
"""
|
@@ -175,8 +187,13 @@ class ESConnection(DocStoreConnection):
|
|
175 |
similarity=similarity,
|
176 |
)
|
177 |
|
178 |
if bqry:
|
179 |
-
bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
|
180 |
s = s.query(bqry)
|
181 |
for field in highlightFields:
|
182 |
s = s.highlight(field)
|
@@ -187,7 +204,7 @@ class ESConnection(DocStoreConnection):
|
|
187 |
order = "asc" if order == 0 else "desc"
|
188 |
if field in ["page_num_int", "top_int"]:
|
189 |
order_info = {"order": order, "unmapped_type": "float",
|
190 |
-
|
191 |
elif field.endswith("_int") or field.endswith("_flt"):
|
192 |
order_info = {"order": order, "unmapped_type": "float"}
|
193 |
else:
|
@@ -195,8 +212,11 @@ class ESConnection(DocStoreConnection):
|
|
195 |
orders.append({field: order_info})
|
196 |
s = s.sort(*orders)
|
197 |
|
198 |
if limit > 0:
|
199 |
-
s = s[offset:offset+limit]
|
200 |
q = s.to_dict()
|
201 |
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
|
202 |
|
@@ -240,7 +260,7 @@ class ESConnection(DocStoreConnection):
|
|
240 |
logger.error("ESConnection.get timeout for 3 times!")
|
241 |
raise Exception("ESConnection.get timeout.")
|
242 |
|
243 |
-
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
|
244 |
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
245 |
operations = []
|
246 |
for d in documents:
|
@@ -292,44 +312,57 @@ class ESConnection(DocStoreConnection):
|
|
292 |
if str(e).find("Timeout") > 0:
|
293 |
continue
|
294 |
return False
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
continue
|
316 |
-
if (not isinstance(k, str) or not v) and k != "available_int":
|
317 |
-
continue
|
318 |
if isinstance(v, str):
|
319 |
-
scripts.append(f"ctx._source.
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
ubq = UpdateByQuery(
|
326 |
index=indexName).using(
|
327 |
self.es).query(bqry)
|
328 |
-
ubq = ubq.script(source="
|
329 |
ubq = ubq.params(refresh=True)
|
330 |
ubq = ubq.params(slices=5)
|
331 |
ubq = ubq.params(conflicts="proceed")
|
332 |
-
|
|
|
333 |
try:
|
334 |
_ = ubq.execute()
|
335 |
return True
|
|
|
9 |
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
10 |
from elastic_transport import ConnectionTimeout
|
11 |
from rag import settings
|
12 |
+
from rag.settings import TAG_FLD, PAGERANK_FLD
|
13 |
from rag.utils import singleton
|
14 |
from api.utils.file_utils import get_project_base_directory
|
15 |
import polars as pl
|
|
|
21 |
|
22 |
logger = logging.getLogger('ragflow.es_conn')
|
23 |
|
24 |
+
|
25 |
@singleton
|
26 |
class ESConnection(DocStoreConnection):
|
27 |
def __init__(self):
|
|
|
113 |
CRUD operations
|
114 |
"""
|
115 |
|
116 |
+
def search(
|
117 |
+
self, selectFields: list[str],
|
118 |
+
highlightFields: list[str],
|
119 |
+
condition: dict,
|
120 |
+
matchExprs: list[MatchExpr],
|
121 |
+
orderBy: OrderByExpr,
|
122 |
+
offset: int,
|
123 |
+
limit: int,
|
124 |
+
indexNames: str | list[str],
|
125 |
+
knowledgebaseIds: list[str],
|
126 |
+
aggFields: list[str] = [],
|
127 |
+
rank_feature: dict | None = None
|
128 |
+
) -> list[dict] | pl.DataFrame:
|
129 |
"""
|
130 |
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
131 |
"""
|
|
|
187 |
similarity=similarity,
|
188 |
)
|
189 |
|
190 |
+
if bqry and rank_feature:
|
191 |
+
for fld, sc in rank_feature.items():
|
192 |
+
if fld != PAGERANK_FLD:
|
193 |
+
fld = f"{TAG_FLD}.{fld}"
|
194 |
+
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
195 |
+
|
196 |
if bqry:
|
|
|
197 |
s = s.query(bqry)
|
198 |
for field in highlightFields:
|
199 |
s = s.highlight(field)
|
|
|
204 |
order = "asc" if order == 0 else "desc"
|
205 |
if field in ["page_num_int", "top_int"]:
|
206 |
order_info = {"order": order, "unmapped_type": "float",
|
207 |
+
"mode": "avg", "numeric_type": "double"}
|
208 |
elif field.endswith("_int") or field.endswith("_flt"):
|
209 |
order_info = {"order": order, "unmapped_type": "float"}
|
210 |
else:
|
|
|
212 |
orders.append({field: order_info})
|
213 |
s = s.sort(*orders)
|
214 |
|
215 |
+
for fld in aggFields:
|
216 |
+
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
217 |
+
|
218 |
if limit > 0:
|
219 |
+
s = s[offset:offset + limit]
|
220 |
q = s.to_dict()
|
221 |
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
|
222 |
|
|
|
260 |
logger.error("ESConnection.get timeout for 3 times!")
|
261 |
raise Exception("ESConnection.get timeout.")
|
262 |
|
263 |
+
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
264 |
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
265 |
operations = []
|
266 |
for d in documents:
|
|
|
312 |
if str(e).find("Timeout") > 0:
|
313 |
continue
|
314 |
return False
|
315 |
+
|
316 |
+
# update unspecific maybe-multiple documents
|
317 |
+
bqry = Q("bool")
|
318 |
+
for k, v in condition.items():
|
319 |
+
if not isinstance(k, str) or not v:
|
320 |
+
continue
|
321 |
+
if k == "exist":
|
322 |
+
bqry.filter.append(Q("exists", field=v))
|
323 |
+
continue
|
324 |
+
if isinstance(v, list):
|
325 |
+
bqry.filter.append(Q("terms", **{k: v}))
|
326 |
+
elif isinstance(v, str) or isinstance(v, int):
|
327 |
+
bqry.filter.append(Q("term", **{k: v}))
|
328 |
+
else:
|
329 |
+
raise Exception(
|
330 |
+
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
331 |
+
scripts = []
|
332 |
+
params = {}
|
333 |
+
for k, v in newValue.items():
|
334 |
+
if k == "remove":
|
335 |
if isinstance(v, str):
|
336 |
+
scripts.append(f"ctx._source.remove('{v}');")
|
337 |
+
if isinstance(v, dict):
|
338 |
+
for kk, vv in v.items():
|
339 |
+
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
|
340 |
+
params[f"p_{kk}"] = vv
|
341 |
+
continue
|
342 |
+
if k == "add":
|
343 |
+
if isinstance(v, dict):
|
344 |
+
for kk, vv in v.items():
|
345 |
+
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
|
346 |
+
params[f"pp_{kk}"] = vv.strip()
|
347 |
+
continue
|
348 |
+
if (not isinstance(k, str) or not v) and k != "available_int":
|
349 |
+
continue
|
350 |
+
if isinstance(v, str):
|
351 |
+
scripts.append(f"ctx._source.{k} = '{v}'")
|
352 |
+
elif isinstance(v, int):
|
353 |
+
scripts.append(f"ctx._source.{k} = {v}")
|
354 |
+
else:
|
355 |
+
raise Exception(
|
356 |
+
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
357 |
ubq = UpdateByQuery(
|
358 |
index=indexName).using(
|
359 |
self.es).query(bqry)
|
360 |
+
ubq = ubq.script(source="".join(scripts), params=params)
|
361 |
ubq = ubq.params(refresh=True)
|
362 |
ubq = ubq.params(slices=5)
|
363 |
ubq = ubq.params(conflicts="proceed")
|
364 |
+
|
365 |
+
for _ in range(ATTEMPT_TIME):
|
366 |
try:
|
367 |
_ = ubq.execute()
|
368 |
return True
|
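For reference, each rank_feature entry appended above serializes to an Elasticsearch clause of roughly this shape; the tag name and boost value are made up for illustration.

    # What Q("rank_feature", field=fld, linear={}, boost=sc) contributes to the bool query's "should":
    should_clause = {
        "rank_feature": {
            "field": "tag_feas.finance",   # TAG_FLD + "." + tag for tag features; "pagerank_fea" stays bare
            "linear": {},                  # linear scoring function on the stored feature value
            "boost": 5,
        }
    }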
rag/utils/infinity_conn.py
CHANGED
@@ -10,6 +10,7 @@ from infinity.index import IndexInfo, IndexType
|
|
10 |
from infinity.connection_pool import ConnectionPool
|
11 |
from infinity.errors import ErrorCode
|
12 |
from rag import settings
|
|
|
13 |
from rag.utils import singleton
|
14 |
import polars as pl
|
15 |
from polars.series.series import Series
|
@@ -231,8 +232,7 @@ class InfinityConnection(DocStoreConnection):
|
|
231 |
"""
|
232 |
|
233 |
def search(
|
234 |
-
self,
|
235 |
-
selectFields: list[str],
|
236 |
highlightFields: list[str],
|
237 |
condition: dict,
|
238 |
matchExprs: list[MatchExpr],
|
@@ -241,7 +241,9 @@ class InfinityConnection(DocStoreConnection):
|
|
241 |
limit: int,
|
242 |
indexNames: str | list[str],
|
243 |
knowledgebaseIds: list[str],
|
244 |
-
245 |
"""
|
246 |
TODO: Infinity doesn't provide highlight
|
247 |
"""
|
@@ -256,7 +258,7 @@ class InfinityConnection(DocStoreConnection):
|
|
256 |
if essential_field not in selectFields:
|
257 |
selectFields.append(essential_field)
|
258 |
if matchExprs:
|
259 |
-
for essential_field in ["score()",
|
260 |
selectFields.append(essential_field)
|
261 |
|
262 |
# Prepare expressions common to all tables
|
@@ -346,7 +348,7 @@ class InfinityConnection(DocStoreConnection):
|
|
346 |
self.connPool.release_conn(inf_conn)
|
347 |
res = concat_dataframes(df_list, selectFields)
|
348 |
if matchExprs:
|
349 |
-
res = res.sort(pl.col("SCORE") + pl.col(
|
350 |
res = res.limit(limit)
|
351 |
logger.debug(f"INFINITY search final result: {str(res)}")
|
352 |
return res, total_hits_count
|
@@ -378,7 +380,7 @@ class InfinityConnection(DocStoreConnection):
|
|
378 |
return res_fields.get(chunkId, None)
|
379 |
|
380 |
def insert(
|
381 |
-
self, documents: list[dict], indexName: str, knowledgebaseId: str
|
382 |
) -> list[str]:
|
383 |
inf_conn = self.connPool.get_conn()
|
384 |
db_instance = inf_conn.get_database(self.dbName)
|
@@ -456,7 +458,7 @@ class InfinityConnection(DocStoreConnection):
|
|
456 |
elif k in ["page_num_int", "top_int"]:
|
457 |
assert isinstance(v, list)
|
458 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
459 |
-
elif k == "remove" and v in [
|
460 |
del newValue[k]
|
461 |
newValue[v] = 0
|
462 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
|
|
10 |
from infinity.connection_pool import ConnectionPool
|
11 |
from infinity.errors import ErrorCode
|
12 |
from rag import settings
|
13 |
+
from rag.settings import PAGERANK_FLD
|
14 |
from rag.utils import singleton
|
15 |
import polars as pl
|
16 |
from polars.series.series import Series
|
|
|
232 |
"""
|
233 |
|
234 |
def search(
|
235 |
+
self, selectFields: list[str],
|
|
|
236 |
highlightFields: list[str],
|
237 |
condition: dict,
|
238 |
matchExprs: list[MatchExpr],
|
|
|
241 |
limit: int,
|
242 |
indexNames: str | list[str],
|
243 |
knowledgebaseIds: list[str],
|
244 |
+
aggFields: list[str] = [],
|
245 |
+
rank_feature: dict | None = None
|
246 |
+
) -> list[dict] | pl.DataFrame:
|
247 |
"""
|
248 |
TODO: Infinity doesn't provide highlight
|
249 |
"""
|
|
|
258 |
if essential_field not in selectFields:
|
259 |
selectFields.append(essential_field)
|
260 |
if matchExprs:
|
261 |
+
for essential_field in ["score()", PAGERANK_FLD]:
|
262 |
selectFields.append(essential_field)
|
263 |
|
264 |
# Prepare expressions common to all tables
|
|
|
348 |
self.connPool.release_conn(inf_conn)
|
349 |
res = concat_dataframes(df_list, selectFields)
|
350 |
if matchExprs:
|
351 |
+
res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
|
352 |
res = res.limit(limit)
|
353 |
logger.debug(f"INFINITY search final result: {str(res)}")
|
354 |
return res, total_hits_count
|
|
|
380 |
return res_fields.get(chunkId, None)
|
381 |
|
382 |
def insert(
|
383 |
+
self, documents: list[dict], indexName: str, knowledgebaseId: str = None
|
384 |
) -> list[str]:
|
385 |
inf_conn = self.connPool.get_conn()
|
386 |
db_instance = inf_conn.get_database(self.dbName)
|
|
|
458 |
elif k in ["page_num_int", "top_int"]:
|
459 |
assert isinstance(v, list)
|
460 |
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
461 |
+
elif k == "remove" and v in [PAGERANK_FLD]:
|
462 |
del newValue[k]
|
463 |
newValue[v] = 0
|
464 |
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
sdk/python/test/test_sdk_api/t_dataset.py
CHANGED
@@ -27,7 +27,7 @@ def test_create_dataset_with_invalid_parameter(get_api_key_fixture):
|
|
27 |
API_KEY = get_api_key_fixture
|
28 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
29 |
valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
|
30 |
-
"knowledge_graph", "email"]
|
31 |
chunk_method = "invalid_chunk_method"
|
32 |
with pytest.raises(Exception) as exc_info:
|
33 |
rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)
|
|
|
27 |
API_KEY = get_api_key_fixture
|
28 |
rag = RAGFlow(API_KEY, HOST_ADDRESS)
|
29 |
valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
|
30 |
+
"knowledge_graph", "email", "tag"]
|
31 |
chunk_method = "invalid_chunk_method"
|
32 |
with pytest.raises(Exception) as exc_info:
|
33 |
rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)
|