Kevin Hu committed
Commit 6c8312a · 1 Parent(s): e547053

Tagging (#4426)

### What problem does this PR solve?

#4367

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
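
At a glance, the change introduces a `tag` chunking method plus a `label_question()` helper, and threads its output into every retrieval call as a `rank_feature`. A minimal sketch of that flow (names come from this diff; the surrounding glue around `embd_mdl`, `tenant_ids` and `kb_ids` is assumed for illustration, not part of this PR):

```python
# Sketch only: the wiring below is illustrative, not code from this PR.
from api import settings
from api.db.services.dialog_service import label_question

def retrieve_with_tags(question, kbs, embd_mdl, tenant_ids, kb_ids):
    # label_question() maps the question to {tag: relevance} using the tag KBs
    # listed in each KB's parser_config["tag_kb_ids"]; it returns None if none are set.
    labels = label_question(question, kbs)
    # The mapping is forwarded as rank_feature so retrieval can boost chunks tagged
    # with matching labels (the tag_kwd/tag_fea fields added in this PR).
    return settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids,
                                          1, 12, 0.1, 0.3, aggs=False,
                                          rank_feature=labels)
```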

agent/component/retrieval.py CHANGED
@@ -19,6 +19,7 @@ from abc import ABC
 import pandas as pd
 
 from api.db import LLMType
+from api.db.services.dialog_service import label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
@@ -70,7 +71,8 @@ class Retrieval(ComponentBase, ABC):
         kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
                                                  1, self._param.top_n,
                                                  self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
-                                                 aggs=False, rerank_mdl=rerank_mdl)
+                                                 aggs=False, rerank_mdl=rerank_mdl,
+                                                 rank_feature=label_question(query, kbs))
 
         if not kbinfos["chunks"]:
             df = Retrieval.be_output("")
api/apps/api_app.py CHANGED
@@ -25,7 +25,7 @@ from api.db import FileType, LLMType, ParserType, FileSource
 from api.db.db_models import APIToken, Task, File
 from api.db.services import duplicate_name
 from api.db.services.api_service import APITokenService, API4ConversationService
-from api.db.services.dialog_service import DialogService, chat, keyword_extraction
+from api.db.services.dialog_service import DialogService, chat, keyword_extraction, label_question
 from api.db.services.document_service import DocumentService, doc_upload_and_parse
 from api.db.services.file2document_service import File2DocumentService
 from api.db.services.file_service import FileService
@@ -840,7 +840,8 @@ def retrieval():
             question += keyword_extraction(chat_mdl, question)
         ranks = settings.retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                                similarity_threshold, vector_similarity_weight, top,
-                                               doc_ids, rerank_mdl=rerank_mdl)
+                                               doc_ids, rerank_mdl=rerank_mdl,
+                                               rank_feature=label_question(question, kbs))
         for c in ranks["chunks"]:
             c.pop("vector", None)
         return get_json_result(data=ranks)
api/apps/chunk_app.py CHANGED
@@ -19,9 +19,10 @@ import json
 from flask import request
 from flask_login import login_required, current_user
 
-from api.db.services.dialog_service import keyword_extraction
+from api.db.services.dialog_service import keyword_extraction, label_question
 from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import search, rag_tokenizer
+from rag.settings import PAGERANK_FLD
 from rag.utils import rmSpace
 from api.db import LLMType, ParserType
 from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -124,10 +125,14 @@ def set():
          "content_with_weight": req["content_with_weight"]}
     d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"])
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
-    d["important_kwd"] = req["important_kwd"]
-    d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
-    d["question_kwd"] = req["question_kwd"]
-    d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
+    if req.get("important_kwd"):
+        d["important_kwd"] = req["important_kwd"]
+        d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"]))
+    if req.get("question_kwd"):
+        d["question_kwd"] = req["question_kwd"]
+        d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"]))
+    if req.get("tag_kwd"):
+        d["tag_kwd"] = req["tag_kwd"]
     if "available_int" in req:
         d["available_int"] = req["available_int"]
 
@@ -220,7 +225,7 @@ def create():
     e, doc = DocumentService.get_by_id(req["doc_id"])
     if not e:
         return get_data_error_result(message="Document not found!")
-    d["kb_id"] = doc.kb_id
+    d["kb_id"] = [doc.kb_id]
    d["docnm_kwd"] = doc.name
     d["title_tks"] = rag_tokenizer.tokenize(doc.name)
     d["doc_id"] = doc.id
@@ -233,7 +238,7 @@ def create():
     if not e:
         return get_data_error_result(message="Knowledgebase not found!")
     if kb.pagerank:
-        d["pagerank_fea"] = kb.pagerank
+        d[PAGERANK_FLD] = kb.pagerank
 
     embd_id = DocumentService.get_embd_id(req["doc_id"])
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id)
@@ -294,12 +299,16 @@ def retrieval_test():
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
             question += keyword_extraction(chat_mdl, question)
 
+        labels = label_question(question, [kb])
        retr = settings.retrievaler if kb.parser_id != ParserType.KG else settings.kg_retrievaler
         ranks = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
                                similarity_threshold, vector_similarity_weight, top,
-                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"))
+                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
+                               rank_feature=labels
+                               )
         for c in ranks["chunks"]:
             c.pop("vector", None)
+        ranks["labels"] = labels
 
         return get_json_result(data=ranks)
     except Exception as e:
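
The `retrieval_test` response now carries the predicted labels alongside the chunks; roughly the following shape (values are invented for illustration, keys mirror the handler above):

```python
# Invented payload; only "chunks" and "labels" are taken from the handler.
ranks = {
    "chunks": [{"content_with_weight": "...", "similarity": 0.71}],  # vectors already popped
    "labels": {"finance": 8, "policy": 5},  # output of label_question(question, [kb])
}
```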
api/apps/conversation_app.py CHANGED
@@ -25,7 +25,7 @@ from flask import request, Response
 from flask_login import login_required, current_user
 
 from api.db import LLMType
-from api.db.services.dialog_service import DialogService, chat, ask
+from api.db.services.dialog_service import DialogService, chat, ask, label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle, TenantService, TenantLLMService
 from api import settings
@@ -379,8 +379,11 @@ def mindmap():
     embd_mdl = TenantLLMService.model_instance(
         kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
     chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
-    ranks = settings.retrievaler.retrieval(req["question"], embd_mdl, kb.tenant_id, kb_ids, 1, 12,
-                                           0.3, 0.3, aggs=False)
+    question = req["question"]
+    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12,
+                                           0.3, 0.3, aggs=False,
+                                           rank_feature=label_question(question, [kb])
+                                           )
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
     if "error" in mind_map:
api/apps/kb_app.py CHANGED
@@ -30,6 +30,7 @@ from api.utils.api_utils import get_json_result
 from api import settings
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
+from rag.settings import PAGERANK_FLD
 
 
 @manager.route('/create', methods=['post'])  # noqa: F821
@@ -104,11 +105,11 @@ def update():
 
     if kb.pagerank != req.get("pagerank", 0):
         if req.get("pagerank", 0) > 0:
-            settings.docStoreConn.update({"kb_id": kb.id}, {"pagerank_fea": req["pagerank"]},
+            settings.docStoreConn.update({"kb_id": kb.id}, {PAGERANK_FLD: req["pagerank"]},
                                          search.index_name(kb.tenant_id), kb.id)
         else:
-            # Elasticsearch requires pagerank_fea be non-zero!
-            settings.docStoreConn.update({"exist": "pagerank_fea"}, {"remove": "pagerank_fea"},
+            # Elasticsearch requires PAGERANK_FLD be non-zero!
+            settings.docStoreConn.update({"exist": PAGERANK_FLD}, {"remove": PAGERANK_FLD},
                                          search.index_name(kb.tenant_id), kb.id)
 
     e, kb = KnowledgebaseService.get_by_id(kb.id)
@@ -150,12 +151,14 @@ def list_kbs():
     keywords = request.args.get("keywords", "")
     page_number = int(request.args.get("page", 1))
     items_per_page = int(request.args.get("page_size", 150))
+    parser_id = request.args.get("parser_id")
     orderby = request.args.get("orderby", "create_time")
     desc = request.args.get("desc", True)
     try:
         tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
         kbs, total = KnowledgebaseService.get_by_tenant_ids(
-            [m["tenant_id"] for m in tenants], current_user.id, page_number, items_per_page, orderby, desc, keywords)
+            [m["tenant_id"] for m in tenants], current_user.id, page_number,
+            items_per_page, orderby, desc, keywords, parser_id)
         return get_json_result(data={"kbs": kbs, "total": total})
     except Exception as e:
         return server_error_response(e)
@@ -199,3 +202,72 @@ def rm():
         return get_json_result(data=True)
     except Exception as e:
         return server_error_response(e)
+
+
+@manager.route('/<kb_id>/tags', methods=['GET'])  # noqa: F821
+@login_required
+def list_tags(kb_id):
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+
+    tags = settings.retrievaler.all_tags(current_user.id, [kb_id])
+    return get_json_result(data=tags)
+
+
+@manager.route('/tags', methods=['GET'])  # noqa: F821
+@login_required
+def list_tags_from_kbs():
+    kb_ids = request.args.get("kb_ids", "").split(",")
+    for kb_id in kb_ids:
+        if not KnowledgebaseService.accessible(kb_id, current_user.id):
+            return get_json_result(
+                data=False,
+                message='No authorization.',
+                code=settings.RetCode.AUTHENTICATION_ERROR
+            )
+
+    tags = settings.retrievaler.all_tags(current_user.id, kb_ids)
+    return get_json_result(data=tags)
+
+
+@manager.route('/<kb_id>/rm_tags', methods=['POST'])  # noqa: F821
+@login_required
+def rm_tags(kb_id):
+    req = request.json
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+
+    for t in req["tags"]:
+        settings.docStoreConn.update({"tag_kwd": t, "kb_id": [kb_id]},
+                                     {"remove": {"tag_kwd": t}},
+                                     search.index_name(kb.tenant_id),
+                                     kb_id)
+    return get_json_result(data=True)
+
+
+@manager.route('/<kb_id>/rename_tag', methods=['POST'])  # noqa: F821
+@login_required
+def rename_tags(kb_id):
+    req = request.json
+    if not KnowledgebaseService.accessible(kb_id, current_user.id):
+        return get_json_result(
+            data=False,
+            message='No authorization.',
+            code=settings.RetCode.AUTHENTICATION_ERROR
+        )
+    e, kb = KnowledgebaseService.get_by_id(kb_id)
+
+    settings.docStoreConn.update({"tag_kwd": req["from_tag"], "kb_id": [kb_id]},
+                                 {"remove": {"tag_kwd": req["from_tag"].strip()}, "add": {"tag_kwd": req["to_tag"]}},
+                                 search.index_name(kb.tenant_id),
+                                 kb_id)
+    return get_json_result(data=True)
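
The four new routes can be exercised like this (the base URL, the assumed `/v1/kb` prefix, session handling and the kb ids are placeholders; paths and payloads follow the handlers above):

```python
# Hypothetical client; assumes a logged-in session cookie and a local deployment.
import requests

BASE = "http://localhost:9380/v1/kb"   # assumed deployment address and route prefix
s = requests.Session()                 # assumed to already carry auth cookies
kb_id = "<kb_id>"

s.get(f"{BASE}/{kb_id}/tags")                                     # tags of one KB
s.get(f"{BASE}/tags", params={"kb_ids": f"{kb_id},<other_kb>"})   # tags across several KBs
s.post(f"{BASE}/{kb_id}/rm_tags", json={"tags": ["obsolete"]})    # drop a tag from all chunks
s.post(f"{BASE}/{kb_id}/rename_tag",
       json={"from_tag": "obsolete", "to_tag": "deprecated"})     # rename a tag in place
```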
api/apps/sdk/dataset.py CHANGED
@@ -73,7 +73,8 @@ def create(tenant_id):
             chunk_method:
               type: string
               enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
-                     "presentation", "picture", "one", "knowledge_graph", "email"]
+                     "presentation", "picture", "one", "knowledge_graph", "email", "tag"
+                     ]
               description: Chunking method.
             parser_config:
               type: object
@@ -108,6 +109,7 @@ def create(tenant_id):
         "one",
         "knowledge_graph",
         "email",
+        "tag"
     ]
     check_validation = valid(
         permission,
@@ -302,7 +304,8 @@ def update(tenant_id, dataset_id):
             chunk_method:
               type: string
               enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
-                     "presentation", "picture", "one", "knowledge_graph", "email"]
+                     "presentation", "picture", "one", "knowledge_graph", "email", "tag"
+                     ]
               description: Updated chunking method.
             parser_config:
               type: object
@@ -339,6 +342,7 @@ def update(tenant_id, dataset_id):
         "one",
         "knowledge_graph",
         "email",
+        "tag"
     ]
     check_validation = valid(
         permission,
api/apps/sdk/dify_retrieval.py CHANGED
@@ -16,6 +16,7 @@
 from flask import request, jsonify
 
 from api.db import LLMType, ParserType
+from api.db.services.dialog_service import label_question
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api import settings
@@ -54,7 +55,8 @@ def retrieval(tenant_id):
             page_size=top,
             similarity_threshold=similarity_threshold,
             vector_similarity_weight=0.3,
-            top=top
+            top=top,
+            rank_feature=label_question(question, [kb])
         )
         records = []
         for c in ranks["chunks"]:
api/apps/sdk/doc.py CHANGED
@@ -16,7 +16,7 @@
 import pathlib
 import datetime
 
-from api.db.services.dialog_service import keyword_extraction
+from api.db.services.dialog_service import keyword_extraction, label_question
 from rag.app.qa import rmPrefix, beAdoc
 from rag.nlp import rag_tokenizer
 from api.db import LLMType, ParserType
@@ -276,6 +276,7 @@ def update_doc(tenant_id, dataset_id, document_id):
         "one",
         "knowledge_graph",
         "email",
+        "tag"
     }
     if req.get("chunk_method") not in valid_chunk_method:
         return get_error_data_result(
@@ -1355,6 +1356,7 @@ def retrieval_test(tenant_id):
             doc_ids,
             rerank_mdl=rerank_mdl,
             highlight=highlight,
+            rank_feature=label_question(question, kbs)
         )
         for c in ranks["chunks"]:
             c.pop("vector", None)
api/db/__init__.py CHANGED
@@ -89,6 +89,7 @@ class ParserType(StrEnum):
     AUDIO = "audio"
     EMAIL = "email"
     KG = "knowledge_graph"
+    TAG = "tag"
 
 
 class FileSource(StrEnum):
api/db/init_data.py CHANGED
@@ -133,7 +133,7 @@ def init_llm_factory():
     TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "QAnything"], {"llm_factory": "Youdao"})
     TenantLLMService.filter_update([TenantLLMService.model.llm_factory == "cohere"], {"llm_factory": "Cohere"})
     TenantService.filter_update([1 == 1], {
-        "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email"})
+        "parser_ids": "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag"})
     ## insert openai two embedding models to the current openai user.
     # print("Start to insert 2 OpenAI embedding models...")
     tenant_ids = set([row["tenant_id"] for row in TenantLLMService.get_openai_models()])
@@ -153,14 +153,7 @@ def init_llm_factory():
             break
     for kb_id in KnowledgebaseService.get_all_ids():
         KnowledgebaseService.update_by_id(kb_id, {"doc_num": DocumentService.get_kb_doc_count(kb_id)})
-    """
-    drop table llm;
-    drop table llm_factories;
-    update tenant set parser_ids='naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph';
-    alter table knowledgebase modify avatar longtext;
-    alter table user modify avatar longtext;
-    alter table dialog modify icon longtext;
-    """
+
 
 
 def add_graph_templates():
api/db/services/dialog_service.py CHANGED
@@ -29,8 +29,10 @@ from api.db.services.common_service import CommonService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
 from api import settings
+from graphrag.utils import get_tags_from_cache, set_tags_to_cache
 from rag.app.resume import forbidden_select_fields4resume
 from rag.nlp.search import index_name
+from rag.settings import TAG_FLD
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
 
@@ -135,6 +137,29 @@ def kb_prompt(kbinfos, max_tokens):
     return knowledges
 
 
+def label_question(question, kbs):
+    tags = None
+    tag_kb_ids = []
+    for kb in kbs:
+        if kb.parser_config.get("tag_kb_ids"):
+            tag_kb_ids.extend(kb.parser_config["tag_kb_ids"])
+    if tag_kb_ids:
+        all_tags = get_tags_from_cache(tag_kb_ids)
+        if not all_tags:
+            all_tags = settings.retrievaler.all_tags_in_portion(kb.tenant_id, tag_kb_ids)
+            set_tags_to_cache(all_tags, tag_kb_ids)
+        else:
+            all_tags = json.loads(all_tags)
+        tag_kbs = KnowledgebaseService.get_by_ids(tag_kb_ids)
+        tags = settings.retrievaler.tag_query(question,
+                                              list(set([kb.tenant_id for kb in tag_kbs])),
+                                              tag_kb_ids,
+                                              all_tags,
+                                              kb.parser_config.get("topn_tags", 3)
+                                              )
+    return tags
+
+
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
 
@@ -236,11 +261,14 @@ def chat(dialog, messages, stream=True, **kwargs):
         generate_keyword_ts = timer()
 
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
+
     kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                   dialog.similarity_threshold,
                                   dialog.vector_similarity_weight,
                                   doc_ids=attachments,
-                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
+                                  top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
+                                  rank_feature=label_question(" ".join(questions), kbs)
+                                  )
 
     retrieval_ts = timer()
 
@@ -650,7 +678,10 @@ def ask(question, kb_ids, tenant_id):
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
+    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids,
+                                  1, 12, 0.1, 0.3, aggs=False,
+                                  rank_feature=label_question(question, kbs)
+                                  )
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
     Role: You're a smart assistant. Your name is Miss R.
@@ -700,3 +731,56 @@ def ask(question, kb_ids, tenant_id):
         answer = ans
         yield {"answer": answer, "reference": {}}
     yield decorate_answer(answer)
+
+
+def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
+    prompt = f"""
+Role: You're a text analyzer.
+
+Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
+
+Steps::
+  - Comprehend the tag/label set.
+  - Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
+  - Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
+
+Requirements
+  - The tags MUST be from the tag set.
+  - The output MUST be in JSON format only, the key is tag and the value is its relevance score.
+  - The relevance score must be range from 1 to 10.
+  - Keywords ONLY in output.
+
+# TAG SET
+{", ".join(all_tags)}
+
+"""
+    for i, ex in enumerate(examples):
+        prompt += """
+# Examples {}
+### Text Content
+{}
+
+Output:
+{}
+
+""".format(i, ex["content"], json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False))
+
+    prompt += f"""
+# Real Data
+### Text Content
+{content}
+
+"""
+    msg = [
+        {"role": "system", "content": prompt},
+        {"role": "user", "content": "Output: "}
+    ]
+    _, msg = message_fit_in(msg, chat_mdl.max_length)
+    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.5})
+    if isinstance(kwd, tuple):
+        kwd = kwd[0]
+    if kwd.find("**ERROR**") >= 0:
+        raise Exception(kwd)
+
+    kwd = re.sub(r".*?\{", "{", kwd)
+    return json.loads(kwd)
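
For reference, the payload shapes these helpers exchange look roughly like this (example values are invented):

```python
# Invented example data; shapes follow content_tagging() and label_question() above.
chunk_tags = {"finance": 9, "tax": 7, "policy": 3}   # content_tagging(): {tag: relevance, 1-10}
question_tags = {"finance": 8, "policy": 5}          # label_question(): same shape, or None
# Callers pass question_tags straight through as rank_feature= in retrieval().
```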
api/db/services/file2document_service.py CHANGED
@@ -43,10 +43,7 @@ class File2DocumentService(CommonService):
     def insert(cls, obj):
         if not cls.save(**obj):
             raise RuntimeError("Database error (File)!")
-        e, obj = cls.get_by_id(obj["id"])
-        if not e:
-            raise RuntimeError("Database error (File retrieval)!")
-        return obj
+        return File2Document(**obj)
 
     @classmethod
     @DB.connection_context()
@@ -63,9 +60,8 @@ class File2DocumentService(CommonService):
     def update_by_file_id(cls, file_id, obj):
         obj["update_time"] = current_timestamp()
         obj["update_date"] = datetime_format(datetime.now())
-        # num = cls.model.update(obj).where(cls.model.id == file_id).execute()
-        e, obj = cls.get_by_id(cls.model.id)
-        return obj
+        cls.model.update(obj).where(cls.model.id == file_id).execute()
+        return File2Document(**obj)
 
     @classmethod
     @DB.connection_context()
api/db/services/file_service.py CHANGED
@@ -251,10 +251,7 @@ class FileService(CommonService):
     def insert(cls, file):
         if not cls.save(**file):
             raise RuntimeError("Database error (File)!")
-        e, file = cls.get_by_id(file["id"])
-        if not e:
-            raise RuntimeError("Database error (File retrieval)!")
-        return file
+        return File(**file)
 
     @classmethod
     @DB.connection_context()
api/db/services/knowledgebase_service.py CHANGED
@@ -35,7 +35,10 @@ class KnowledgebaseService(CommonService):
35
  @classmethod
36
  @DB.connection_context()
37
  def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
38
- page_number, items_per_page, orderby, desc, keywords):
 
 
 
39
  fields = [
40
  cls.model.id,
41
  cls.model.avatar,
@@ -67,6 +70,8 @@ class KnowledgebaseService(CommonService):
67
  cls.model.tenant_id == user_id))
68
  & (cls.model.status == StatusEnum.VALID.value)
69
  )
 
 
70
  if desc:
71
  kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
72
  else:
 
35
  @classmethod
36
  @DB.connection_context()
37
  def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
38
+ page_number, items_per_page,
39
+ orderby, desc, keywords,
40
+ parser_id=None
41
+ ):
42
  fields = [
43
  cls.model.id,
44
  cls.model.avatar,
 
70
  cls.model.tenant_id == user_id))
71
  & (cls.model.status == StatusEnum.VALID.value)
72
  )
73
+ if parser_id:
74
+ kbs = kbs.where(cls.model.parser_id == parser_id)
75
  if desc:
76
  kbs = kbs.order_by(cls.model.getter_by(orderby).desc())
77
  else:
api/db/services/task_service.py CHANGED
@@ -69,6 +69,7 @@ class TaskService(CommonService):
             Knowledgebase.language,
             Knowledgebase.embd_id,
             Knowledgebase.pagerank,
+            Knowledgebase.parser_config.alias("kb_parser_config"),
             Tenant.img2txt_id,
             Tenant.asr_id,
             Tenant.llm_id,
api/settings.py CHANGED
@@ -140,7 +140,7 @@ def init_settings():
     API_KEY = LLM.get("api_key", "")
     PARSERS = LLM.get(
        "parsers",
-        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email")
+        "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,knowledge_graph:Knowledge Graph,email:Email,tag:Tag")
 
     HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
     HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
api/utils/api_utils.py CHANGED
@@ -173,6 +173,7 @@ def validate_request(*args, **kwargs):
 
     return wrapper
 
+
 def not_allowed_parameters(*params):
     def decorator(f):
         def wrapper(*args, **kwargs):
@@ -182,7 +183,9 @@ def not_allowed_parameters(*params):
                 return get_json_result(
                     code=settings.RetCode.ARGUMENT_ERROR, message=f"Parameter {param} isn't allowed")
             return f(*args, **kwargs)
+
         return wrapper
+
     return decorator
 
 
@@ -207,6 +210,7 @@ def get_json_result(code=settings.RetCode.SUCCESS, message='success', data=None):
     response = {"code": code, "message": message, "data": data}
     return jsonify(response)
 
+
 def apikey_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
@@ -282,17 +286,18 @@ def construct_error_response(e):
 def token_required(func):
     @wraps(func)
     def decorated_function(*args, **kwargs):
-        authorization_str=flask_request.headers.get('Authorization')
+        authorization_str = flask_request.headers.get('Authorization')
         if not authorization_str:
-            return get_json_result(data=False,message="`Authorization` can't be empty")
-        authorization_list=authorization_str.split()
+            return get_json_result(data=False, message="`Authorization` can't be empty")
+        authorization_list = authorization_str.split()
         if len(authorization_list) < 2:
-            return get_json_result(data=False,message="Please check your authorization format.")
+            return get_json_result(data=False, message="Please check your authorization format.")
         token = authorization_list[1]
        objs = APIToken.query(token=token)
         if not objs:
             return get_json_result(
-                data=False, message='Authentication error: API key is invalid!', code=settings.RetCode.AUTHENTICATION_ERROR
+                data=False, message='Authentication error: API key is invalid!',
+                code=settings.RetCode.AUTHENTICATION_ERROR
             )
         kwargs['tenant_id'] = objs[0].tenant_id
         return func(*args, **kwargs)
@@ -330,35 +335,41 @@ def generate_confirmation_token(tenent_id):
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
 
 
-def valid(permission,valid_permission,language,valid_language,chunk_method,valid_chunk_method):
-    if valid_parameter(permission,valid_permission):
-        return valid_parameter(permission,valid_permission)
-    if valid_parameter(language,valid_language):
-        return valid_parameter(language,valid_language)
-    if valid_parameter(chunk_method,valid_chunk_method):
-        return valid_parameter(chunk_method,valid_chunk_method)
+def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
+    if valid_parameter(permission, valid_permission):
+        return valid_parameter(permission, valid_permission)
+    if valid_parameter(language, valid_language):
+        return valid_parameter(language, valid_language)
+    if valid_parameter(chunk_method, valid_chunk_method):
+        return valid_parameter(chunk_method, valid_chunk_method)
+
 
-def valid_parameter(parameter,valid_values):
+def valid_parameter(parameter, valid_values):
     if parameter and parameter not in valid_values:
         return get_error_data_result(f"'{parameter}' is not in {valid_values}")
+
 
-def get_parser_config(chunk_method,parser_config):
+def get_parser_config(chunk_method, parser_config):
     if parser_config:
         return parser_config
     if not chunk_method:
         chunk_method = "naive"
-    key_mapping={"naive":{"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False,"layout_recognize": True, "raptor": {"use_raptor": False}},
-                 "qa":{"raptor":{"use_raptor":False}},
-                 "resume":None,
-                 "manual":{"raptor":{"use_raptor":False}},
-                 "table":None,
-                 "paper":{"raptor":{"use_raptor":False}},
-                 "book":{"raptor":{"use_raptor":False}},
-                 "laws":{"raptor":{"use_raptor":False}},
-                 "presentation":{"raptor":{"use_raptor":False}},
-                 "one":None,
-                 "knowledge_graph":{"chunk_token_num":8192,"delimiter":"\\n!?;。;!?","entity_types":["organization","person","location","event","time"]},
-                 "email":None,
-                 "picture":None}
-    parser_config=key_mapping[chunk_method]
-    return parser_config
+    key_mapping = {
+        "naive": {"chunk_token_num": 128, "delimiter": "\\n!?;。;!?", "html4excel": False, "layout_recognize": True,
+                  "raptor": {"use_raptor": False}},
+        "qa": {"raptor": {"use_raptor": False}},
+        "tag": None,
+        "resume": None,
+        "manual": {"raptor": {"use_raptor": False}},
+        "table": None,
+        "paper": {"raptor": {"use_raptor": False}},
+        "book": {"raptor": {"use_raptor": False}},
+        "laws": {"raptor": {"use_raptor": False}},
+        "presentation": {"raptor": {"use_raptor": False}},
+        "one": None,
+        "knowledge_graph": {"chunk_token_num": 8192, "delimiter": "\\n!?;。;!?",
+                            "entity_types": ["organization", "person", "location", "event", "time"]},
+        "email": None,
+        "picture": None}
+    parser_config = key_mapping[chunk_method]
+    return parser_config
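
For instance, with the new mapping (a hypothetical call; behavior follows `key_mapping` above):

```python
from api.utils.api_utils import get_parser_config

assert get_parser_config("tag", None) is None                       # "tag" gets no default config
assert get_parser_config("naive", None)["chunk_token_num"] == 128   # existing defaults unchanged
assert get_parser_config("qa", {"raptor": {"use_raptor": True}}) == {"raptor": {"use_raptor": True}}
```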
conf/infinity_mapping.json CHANGED
@@ -10,6 +10,7 @@
   "title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
+  "tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
@@ -27,5 +28,6 @@
   "available_int": {"type": "integer", "default": 1},
   "knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
   "entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
-  "pagerank_fea": {"type": "integer", "default": 0}
+  "pagerank_fea": {"type": "integer", "default": 0},
+  "tag_fea": {"type": "integer", "default": 0}
 }
graphrag/utils.py CHANGED
@@ -111,4 +111,23 @@ def set_embed_cache(llmnm, txt, arr):
 
     k = hasher.hexdigest()
     arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
-    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
+    REDIS_CONN.set(k, arr.encode("utf-8"), 24*3600)
+
+
+def get_tags_from_cache(kb_ids):
+    hasher = xxhash.xxh64()
+    hasher.update(str(kb_ids).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    bin = REDIS_CONN.get(k)
+    if not bin:
+        return
+    return bin
+
+
+def set_tags_to_cache(kb_ids, tags):
+    hasher = xxhash.xxh64()
+    hasher.update(str(kb_ids).encode("utf-8"))
+
+    k = hasher.hexdigest()
+    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)
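
A minimal sketch of how the new cache helpers are used (the kb ids and tag counts below are invented; the argument order of `set_tags_to_cache` mirrors the call in `label_question()`):

```python
# Invented data; the helpers come from graphrag/utils.py above.
import json
from graphrag.utils import get_tags_from_cache, set_tags_to_cache

tag_kb_ids = ["kb_finance", "kb_legal"]        # hypothetical tag-KB ids
cached = get_tags_from_cache(tag_kb_ids)       # raw Redis payload, or None on a miss
if cached:
    all_tags = json.loads(cached)
else:
    all_tags = {"finance": 120, "tax": 80}     # e.g. tag -> frequency, fetched elsewhere
    set_tags_to_cache(all_tags, tag_kb_ids)    # stored with a 600-second expiry
```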
rag/app/qa.py CHANGED
@@ -26,6 +26,7 @@ from docx import Document
 from PIL import Image
 from markdown import markdown
 
+
 class Excel(ExcelParser):
     def __call__(self, fnm, binary=None, callback=None):
         if not binary:
@@ -58,11 +59,11 @@ class Excel(ExcelParser):
                 if len(res) % 999 == 0:
                     callback(len(res) *
                              0.6 /
-                             total, ("Extract Q&A: {}".format(len(res)) +
+                             total, ("Extract pairs: {}".format(len(res)) +
                                      (f"{len(fails)} failure, line: %s..." %
                                       (",".join(fails[:3])) if fails else "")))
 
-        callback(0.6, ("Extract Q&A: {}. ".format(len(res)) + (
+        callback(0.6, ("Extract pairs: {}. ".format(len(res)) + (
            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
         self.is_english = is_english(
             [rmPrefix(q) for q, _ in random_choices(res, k=30) if len(q) > 1])
@@ -269,7 +270,7 @@ def beAdocPdf(d, q, a, eng, image, poss):
     return d
 
 
-def beAdocDocx(d, q, a, eng, image):
+def beAdocDocx(d, q, a, eng, image, row_num=-1):
     qprefix = "Question: " if eng else "问题:"
     aprefix = "Answer: " if eng else "回答:"
     d["content_with_weight"] = "\t".join(
@@ -277,16 +278,20 @@ def beAdocDocx(d, q, a, eng, image, row_num=-1):
     d["content_ltks"] = rag_tokenizer.tokenize(q)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
     d["image"] = image
+    if row_num >= 0:
+        d["top_int"] = [row_num]
     return d
 
 
-def beAdoc(d, q, a, eng):
+def beAdoc(d, q, a, eng, row_num=-1):
     qprefix = "Question: " if eng else "问题:"
     aprefix = "Answer: " if eng else "回答:"
     d["content_with_weight"] = "\t".join(
         [qprefix + rmPrefix(q), aprefix + rmPrefix(a)])
     d["content_ltks"] = rag_tokenizer.tokenize(q)
     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+    if row_num >= 0:
+        d["top_int"] = [row_num]
     return d
 
 
@@ -316,8 +321,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
     if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
         callback(0.1, "Start to parse.")
         excel_parser = Excel()
-        for q, a in excel_parser(filename, binary, callback):
-            res.append(beAdoc(deepcopy(doc), q, a, eng))
+        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
+            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
         return res
 
     elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
@@ -344,7 +349,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 fails.append(str(i+1))
             elif len(arr) == 2:
                 if question and answer:
-                    res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                 question, answer = arr
             i += 1
             if len(res) % 999 == 0:
@@ -352,7 +357,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
         if question:
-            res.append(beAdoc(deepcopy(doc), question, answer, eng))
+            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(lines)))
 
         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -378,14 +383,14 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 fails.append(str(i + 1))
             elif len(row) == 2:
                 if question and answer:
-                    res.append(beAdoc(deepcopy(doc), question, answer, eng))
+                    res.append(beAdoc(deepcopy(doc), question, answer, eng, i))
                 question, answer = row
             if len(res) % 999 == 0:
                 callback(len(res) * 0.6 / len(lines), ("Extract Q&A: {}".format(len(res)) + (
                     f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
 
         if question:
-            res.append(beAdoc(deepcopy(doc), question, answer, eng))
+            res.append(beAdoc(deepcopy(doc), question, answer, eng, len(reader)))
 
         callback(0.6, ("Extract Q&A: {}".format(len(res)) + (
             f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
@@ -420,7 +425,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
                 if last_answer.strip():
                     sum_question = '\n'.join(question_stack)
                     if sum_question:
-                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
+                        res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
                 last_answer = ''
 
                 i = question_level
@@ -432,7 +437,7 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         if last_answer.strip():
             sum_question = '\n'.join(question_stack)
             if sum_question:
-                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng))
+                res.append(beAdoc(deepcopy(doc), sum_question, markdown(last_answer, extensions=['markdown.extensions.tables']), eng, index))
         return res
 
     elif re.search(r"\.docx$", filename, re.IGNORECASE):
@@ -440,8 +445,8 @@ def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
         qai_list, tbls = docx_parser(filename, binary,
                                      from_page=0, to_page=10000, callback=callback)
         res = tokenize_table(tbls, doc, eng)
-        for q, a, image in qai_list:
-            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image))
+        for i, (q, a, image) in enumerate(qai_list):
+            res.append(beAdocDocx(deepcopy(doc), q, a, eng, image, i))
         return res
 
     raise NotImplementedError(
rag/app/tag.py ADDED
@@ -0,0 +1,125 @@
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+import re
+import csv
+from copy import deepcopy
+
+from deepdoc.parser.utils import get_text
+from rag.app.qa import Excel
+from rag.nlp import rag_tokenizer
+
+
+def beAdoc(d, q, a, eng, row_num=-1):
+    d["content_with_weight"] = q
+    d["content_ltks"] = rag_tokenizer.tokenize(q)
+    d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"])
+    d["tag_kwd"] = [t.strip() for t in a.split(",") if t.strip()]
+    if row_num >= 0:
+        d["top_int"] = [row_num]
+    return d
+
+
+def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
+    """
+        Excel and csv(txt) format files are supported.
+        If the file is in excel format, there should be 2 column content and tags without header.
+        And content column is ahead of tags column.
+        And it's O.K if it has multiple sheets as long as the columns are rightly composed.
+
+        If it's in csv format, it should be UTF-8 encoded. Use TAB as delimiter to separate content and tags.
+
+        All the deformed lines will be ignored.
+        Every pair will be treated as a chunk.
+    """
+    eng = lang.lower() == "english"
+    res = []
+    doc = {
+        "docnm_kwd": filename,
+        "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
+    }
+    if re.search(r"\.xlsx?$", filename, re.IGNORECASE):
+        callback(0.1, "Start to parse.")
+        excel_parser = Excel()
+        for ii, (q, a) in enumerate(excel_parser(filename, binary, callback)):
+            res.append(beAdoc(deepcopy(doc), q, a, eng, ii))
+        return res
+
+    elif re.search(r"\.(txt)$", filename, re.IGNORECASE):
+        callback(0.1, "Start to parse.")
+        txt = get_text(filename, binary)
+        lines = txt.split("\n")
+        comma, tab = 0, 0
+        for line in lines:
+            if len(line.split(",")) == 2:
+                comma += 1
+            if len(line.split("\t")) == 2:
+                tab += 1
+        delimiter = "\t" if tab >= comma else ","
+
+        fails = []
+        content = ""
+        i = 0
+        while i < len(lines):
+            arr = lines[i].split(delimiter)
+            if len(arr) != 2:
+                content += "\n" + lines[i]
+            elif len(arr) == 2:
+                content += "\n" + arr[0]
+                res.append(beAdoc(deepcopy(doc), content, arr[1], eng, i))
+                content = ""
+            i += 1
+            if len(res) % 999 == 0:
+                callback(len(res) * 0.6 / len(lines), ("Extract TAG: {}".format(len(res)) + (
+                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+
+        callback(0.6, ("Extract TAG: {}".format(len(res)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+
+        return res
+
+    elif re.search(r"\.(csv)$", filename, re.IGNORECASE):
+        callback(0.1, "Start to parse.")
+        txt = get_text(filename, binary)
+        lines = txt.split("\n")
+        delimiter = "\t" if any("\t" in line for line in lines) else ","
+
+        fails = []
+        content = ""
+        res = []
+        reader = csv.reader(lines, delimiter=delimiter)
+
+        for i, row in enumerate(reader):
+            if len(row) != 2:
+                content += "\n" + lines[i]
+            elif len(row) == 2:
+                content += "\n" + row[0]
+                res.append(beAdoc(deepcopy(doc), content, row[1], eng, i))
+                content = ""
+            if len(res) % 999 == 0:
+                callback(len(res) * 0.6 / len(lines), ("Extract Tags: {}".format(len(res)) + (
+                    f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+
+        callback(0.6, ("Extract TAG : {}".format(len(res)) + (
+            f"{len(fails)} failure, line: %s..." % (",".join(fails[:3])) if fails else "")))
+        return res
+
+    raise NotImplementedError(
+        "Excel, csv(txt) format files are supported.")
+
+
+if __name__ == "__main__":
+    import sys
+
+    def dummy(prog=None, msg=""):
+        pass
+    chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
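
An illustrative run of the new chunker (the file name and rows are invented; the expected layout follows the docstring above):

```python
# faq_tags.txt (TSV, no header): <content>\t<comma-separated tags>
#   How do I reset my password?\taccount, security
#   Refund took longer than 14 days.\tbilling, refund
from rag.app import tag as tag_chunker

def progress(prog=None, msg=""):
    print(prog, msg)

chunks = tag_chunker.chunk("faq_tags.txt", callback=progress)
# Each chunk keeps the raw text in content_with_weight and the parsed labels
# in tag_kwd, e.g. ["account", "security"].
```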
rag/nlp/query.py CHANGED
@@ -59,13 +59,15 @@ class FulltextQueryer:
59
  "",
60
  ),
61
  (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
62
- (r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ", " ")
 
 
63
  ]
64
  for r, p in patts:
65
  txt = re.sub(r, p, txt, flags=re.IGNORECASE)
66
  return txt
67
 
68
- def question(self, txt, tbl="qa", min_match:float=0.6):
69
  txt = re.sub(
70
  r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
71
  " ",
@@ -90,7 +92,8 @@ class FulltextQueryer:
90
  syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
91
  syns.append(" ".join(syn))
92
 
93
- q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if tk and not re.match(r"[.^+\(\)-]", tk)]
 
94
  for i in range(1, len(tks_w)):
95
  left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
96
  if not left or not right:
@@ -155,7 +158,7 @@ class FulltextQueryer:
155
  if len(keywords) < 32:
156
  keywords.extend([s for s in tk_syns if s])
157
  tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
158
- tk_syns = [f"\"{s}\"" if s.find(" ")>0 else s for s in tk_syns]
159
 
160
  if len(keywords) >= 32:
161
  break
@@ -174,8 +177,6 @@ class FulltextQueryer:
174
 
175
  if len(twts) > 1:
176
  tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
177
- if re.match(r"[0-9a-z ]+$", tt):
178
- tms = f'("{tt}" OR "%s")' % rag_tokenizer.tokenize(tt)
179
 
180
  syns = " OR ".join(
181
  [
@@ -232,3 +233,25 @@ class FulltextQueryer:
232
  for k, v in qtwt.items():
233
  q += v
234
  return s / q
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  "",
60
  ),
61
  (r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
62
+ (
63
+ r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
64
+ " ")
65
  ]
66
  for r, p in patts:
67
  txt = re.sub(r, p, txt, flags=re.IGNORECASE)
68
  return txt
69
 
70
+ def question(self, txt, tbl="qa", min_match: float = 0.6):
71
  txt = re.sub(
72
  r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
73
  " ",
 
92
  syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
93
  syns.append(" ".join(syn))
94
 
95
+ q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
96
+ tk and not re.match(r"[.^+\(\)-]", tk)]
97
  for i in range(1, len(tks_w)):
98
  left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
99
  if not left or not right:
 
158
  if len(keywords) < 32:
159
  keywords.extend([s for s in tk_syns if s])
160
  tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
161
+ tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
162
 
163
  if len(keywords) >= 32:
164
  break
 
177
 
178
  if len(twts) > 1:
179
  tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
 
 
180
 
181
  syns = " OR ".join(
182
  [
 
233
  for k, v in qtwt.items():
234
  q += v
235
  return s / q
236
+
237
+ def paragraph(self, content_tks: str, keywords: list = [], keywords_topn=30):
238
+ if isinstance(content_tks, str):
239
+ content_tks = [c.strip() for c in content_tks.strip() if c.strip()]
240
+ tks_w = self.tw.weights(content_tks, preprocess=False)
241
+
242
+ keywords = [f'"{k.strip()}"' for k in keywords]
243
+ for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
244
+ tk_syns = self.syn.lookup(tk)
245
+ tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
246
+ tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
247
+ tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
248
+ tk = FulltextQueryer.subSpecialChar(tk)
249
+ if tk.find(" ") > 0:
250
+ tk = '"%s"' % tk
251
+ if tk_syns:
252
+ tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
253
+ if tk:
254
+ keywords.append(f"{tk}^{w}")
255
+
256
+ return MatchTextExpr(self.query_fields, " ".join(keywords), 100,
257
+ {"minimum_should_match": min(3, len(keywords) / 10)})
rag/nlp/search.py CHANGED
@@ -13,11 +13,11 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
-
17
  import logging
18
  import re
19
  from dataclasses import dataclass
20
 
 
21
  from rag.utils import rmSpace
22
  from rag.nlp import rag_tokenizer, query
23
  import numpy as np
@@ -47,7 +47,8 @@ class Dealer:
47
  qv, _ = emb_mdl.encode_queries(txt)
48
  shape = np.array(qv).shape
49
  if len(shape) > 1:
50
- raise Exception(f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
 
51
  embedding_data = [float(v) for v in qv]
52
  vector_column_name = f"q_{len(embedding_data)}_vec"
53
  return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
@@ -63,7 +64,12 @@ class Dealer:
63
  condition[key] = req[key]
64
  return condition
65
 
66
- def search(self, req, idx_names: str | list[str], kb_ids: list[str], emb_mdl=None, highlight = False):
 
 
 
 
 
67
  filters = self.get_filters(req)
68
  orderBy = OrderByExpr()
69
 
@@ -72,9 +78,11 @@ class Dealer:
72
  ps = int(req.get("size", topk))
73
  offset, limit = pg * ps, ps
74
 
75
- src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
76
- "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks",
77
- "available_int", "content_with_weight", "pagerank_fea"])
 
 
78
  kwds = set([])
79
 
80
  qst = req.get("question", "")
@@ -85,15 +93,16 @@ class Dealer:
85
  orderBy.asc("top_int")
86
  orderBy.desc("create_timestamp_flt")
87
  res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
88
- total=self.dataStore.getTotal(res)
89
  logging.debug("Dealer.search TOTAL: {}".format(total))
90
  else:
91
  highlightFields = ["content_ltks", "title_tks"] if highlight else []
92
  matchText, keywords = self.qryr.question(qst, min_match=0.3)
93
  if emb_mdl is None:
94
  matchExprs = [matchText]
95
- res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
96
- total=self.dataStore.getTotal(res)
 
97
  logging.debug("Dealer.search TOTAL: {}".format(total))
98
  else:
99
  matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
@@ -103,8 +112,9 @@ class Dealer:
103
  fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
104
  matchExprs = [matchText, matchDense, fusionExpr]
105
 
106
- res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit, idx_names, kb_ids)
107
- total=self.dataStore.getTotal(res)
 
108
  logging.debug("Dealer.search TOTAL: {}".format(total))
109
 
110
  # If result is empty, try again with lower min_match
@@ -112,8 +122,9 @@ class Dealer:
112
  matchText, _ = self.qryr.question(qst, min_match=0.1)
113
  filters.pop("doc_ids", None)
114
  matchDense.extra_options["similarity"] = 0.17
115
- res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr], orderBy, offset, limit, idx_names, kb_ids)
116
- total=self.dataStore.getTotal(res)
 
117
  logging.debug("Dealer.search 2 TOTAL: {}".format(total))
118
 
119
  for k in keywords:
@@ -126,8 +137,8 @@ class Dealer:
126
  kwds.add(kk)
127
 
128
  logging.debug(f"TOTAL: {total}")
129
- ids=self.dataStore.getChunkIds(res)
130
- keywords=list(kwds)
131
  highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
132
  aggs = self.dataStore.getAggregation(res, "docnm_kwd")
133
  return self.SearchResult(
@@ -188,13 +199,13 @@ class Dealer:
188
 
189
  ans_v, _ = embd_mdl.encode(pieces_)
190
  assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
191
- len(ans_v[0]), len(chunk_v[0]))
192
 
193
  chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
194
  for ck in chunks]
195
  cites = {}
196
  thr = 0.63
197
- while thr>0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
198
  for i, a in enumerate(pieces_):
199
  sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
200
  chunk_v,
@@ -228,20 +239,44 @@ class Dealer:
228
 
229
  return res, seted
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  def rerank(self, sres, query, tkweight=0.3,
232
- vtweight=0.7, cfield="content_ltks"):
 
 
233
  _, keywords = self.qryr.question(query)
234
  vector_size = len(sres.query_vector)
235
  vector_column = f"q_{vector_size}_vec"
236
  zero_vector = [0.0] * vector_size
237
  ins_embd = []
238
- pageranks = []
239
  for chunk_id in sres.ids:
240
  vector = sres.field[chunk_id].get(vector_column, zero_vector)
241
  if isinstance(vector, str):
242
  vector = [float(v) for v in vector.split("\t")]
243
  ins_embd.append(vector)
244
- pageranks.append(sres.field[chunk_id].get("pagerank_fea", 0))
245
  if not ins_embd:
246
  return [], [], []
247
 
@@ -254,18 +289,22 @@ class Dealer:
254
  title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
255
  question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
256
  important_kwd = sres.field[i].get("important_kwd", [])
257
- tks = content_ltks + title_tks*2 + important_kwd*5 + question_tks*6
258
  ins_tw.append(tks)
259
 
 
 
 
260
  sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
261
  ins_embd,
262
  keywords,
263
  ins_tw, tkweight, vtweight)
264
 
265
- return sim+np.array(pageranks, dtype=float), tksim, vtsim
266
 
267
  def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
268
- vtweight=0.7, cfield="content_ltks"):
 
269
  _, keywords = self.qryr.question(query)
270
 
271
  for i in sres.ids:
@@ -280,9 +319,11 @@ class Dealer:
280
  ins_tw.append(tks)
281
 
282
  tksim = self.qryr.token_similarity(keywords, ins_tw)
283
- vtsim,_ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
 
 
284
 
285
- return tkweight*np.array(tksim) + vtweight*vtsim, tksim, vtsim
286
 
287
  def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
288
  return self.qryr.hybrid_similarity(ans_embd,
@@ -291,13 +332,15 @@ class Dealer:
291
  rag_tokenizer.tokenize(inst).split())
292
 
293
  def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
294
- vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True, rerank_mdl=None, highlight=False):
 
 
295
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
296
  if not question:
297
  return ranks
298
 
299
  RERANK_PAGE_LIMIT = 3
300
- req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size*RERANK_PAGE_LIMIT, 128),
301
  "question": question, "vector": True, "topk": top,
302
  "similarity": similarity_threshold,
303
  "available_int": 1}
@@ -309,29 +352,30 @@ class Dealer:
309
  if isinstance(tenant_ids, str):
310
  tenant_ids = tenant_ids.split(",")
311
 
312
- sres = self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight)
 
313
  ranks["total"] = sres.total
314
 
315
  if page <= RERANK_PAGE_LIMIT:
316
  if rerank_mdl and sres.total > 0:
317
  sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
318
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
 
 
319
  else:
320
  sim, tsim, vsim = self.rerank(
321
- sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
322
- idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size]
 
323
  else:
324
- sim = tsim = vsim = [1]*len(sres.ids)
325
  idx = list(range(len(sres.ids)))
326
 
327
- def floor_sim(score):
328
- return (int(score * 100.)%100)/100.
329
-
330
  dim = len(sres.query_vector)
331
  vector_column = f"q_{dim}_vec"
332
  zero_vector = [0.0] * dim
333
  for i in idx:
334
- if floor_sim(sim[i]) < similarity_threshold:
335
  break
336
  if len(ranks["chunks"]) >= page_size:
337
  if aggs:
@@ -369,8 +413,8 @@ class Dealer:
369
  ranks["doc_aggs"] = [{"doc_name": k,
370
  "doc_id": v["doc_id"],
371
  "count": v["count"]} for k,
372
- v in sorted(ranks["doc_aggs"].items(),
373
- key=lambda x:x[1]["count"] * -1)]
374
  ranks["chunks"] = ranks["chunks"][:page_size]
375
 
376
  return ranks
@@ -379,15 +423,57 @@ class Dealer:
379
  tbl = self.dataStore.sql(sql, fetch_size, format)
380
  return tbl
381
 
382
- def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, fields=["docnm_kwd", "content_with_weight", "img_id"]):
 
 
 
383
  condition = {"doc_id": doc_id}
384
  res = []
385
  bs = 128
386
- for p in range(0, max_count, bs):
387
- es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id), kb_ids)
 
388
  dict_chunks = self.dataStore.getFields(es_res, fields)
389
  if dict_chunks:
390
  res.extend(dict_chunks.values())
391
  if len(dict_chunks.values()) < bs:
392
  break
393
  return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import logging
17
  import re
18
  from dataclasses import dataclass
19
 
20
+ from rag.settings import TAG_FLD, PAGERANK_FLD
21
  from rag.utils import rmSpace
22
  from rag.nlp import rag_tokenizer, query
23
  import numpy as np
 
47
  qv, _ = emb_mdl.encode_queries(txt)
48
  shape = np.array(qv).shape
49
  if len(shape) > 1:
50
+ raise Exception(
51
+ f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
52
  embedding_data = [float(v) for v in qv]
53
  vector_column_name = f"q_{len(embedding_data)}_vec"
54
  return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
 
64
  condition[key] = req[key]
65
  return condition
66
 
67
+ def search(self, req, idx_names: str | list[str],
68
+ kb_ids: list[str],
69
+ emb_mdl=None,
70
+ highlight=False,
71
+ rank_feature: dict | None = None
72
+ ):
73
  filters = self.get_filters(req)
74
  orderBy = OrderByExpr()
75
 
 
78
  ps = int(req.get("size", topk))
79
  offset, limit = pg * ps, ps
80
 
81
+ src = req.get("fields",
82
+ ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
83
+ "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
84
+ "question_kwd", "question_tks",
85
+ "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
86
  kwds = set([])
87
 
88
  qst = req.get("question", "")
 
93
  orderBy.asc("top_int")
94
  orderBy.desc("create_timestamp_flt")
95
  res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
96
+ total = self.dataStore.getTotal(res)
97
  logging.debug("Dealer.search TOTAL: {}".format(total))
98
  else:
99
  highlightFields = ["content_ltks", "title_tks"] if highlight else []
100
  matchText, keywords = self.qryr.question(qst, min_match=0.3)
101
  if emb_mdl is None:
102
  matchExprs = [matchText]
103
+ res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
104
+ idx_names, kb_ids, rank_feature=rank_feature)
105
+ total = self.dataStore.getTotal(res)
106
  logging.debug("Dealer.search TOTAL: {}".format(total))
107
  else:
108
  matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
 
112
  fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05, 0.95"})
113
  matchExprs = [matchText, matchDense, fusionExpr]
114
 
115
+ res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
116
+ idx_names, kb_ids, rank_feature=rank_feature)
117
+ total = self.dataStore.getTotal(res)
118
  logging.debug("Dealer.search TOTAL: {}".format(total))
119
 
120
  # If result is empty, try again with lower min_match
 
122
  matchText, _ = self.qryr.question(qst, min_match=0.1)
123
  filters.pop("doc_ids", None)
124
  matchDense.extra_options["similarity"] = 0.17
125
+ res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
126
+ orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
127
+ total = self.dataStore.getTotal(res)
128
  logging.debug("Dealer.search 2 TOTAL: {}".format(total))
129
 
130
  for k in keywords:
 
137
  kwds.add(kk)
138
 
139
  logging.debug(f"TOTAL: {total}")
140
+ ids = self.dataStore.getChunkIds(res)
141
+ keywords = list(kwds)
142
  highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
143
  aggs = self.dataStore.getAggregation(res, "docnm_kwd")
144
  return self.SearchResult(
 
199
 
200
  ans_v, _ = embd_mdl.encode(pieces_)
201
  assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
202
+ len(ans_v[0]), len(chunk_v[0]))
203
 
204
  chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split()
205
  for ck in chunks]
206
  cites = {}
207
  thr = 0.63
208
+ while thr > 0.3 and len(cites.keys()) == 0 and pieces_ and chunks_tks:
209
  for i, a in enumerate(pieces_):
210
  sim, tksim, vtsim = self.qryr.hybrid_similarity(ans_v[i],
211
  chunk_v,
 
239
 
240
  return res, seted
241
 
242
+ def _rank_feature_scores(self, query_rfea, search_res):
243
+ ## For rank feature (tag_fea) scores.
244
+ rank_fea = []
245
+ pageranks = []
246
+ for chunk_id in search_res.ids:
247
+ pageranks.append(search_res.field[chunk_id].get(PAGERANK_FLD, 0))
248
+ pageranks = np.array(pageranks, dtype=float)
249
+
250
+ if not query_rfea:
251
+ return np.array([0 for _ in range(len(search_res.ids))]) + pageranks
252
+
253
+ q_denor = np.sqrt(np.sum([s*s for t,s in query_rfea.items() if t != PAGERANK_FLD]))
254
+ for i in search_res.ids:
255
+ nor, denor = 0, 0
256
+ for t, sc in eval(search_res.field[i].get(TAG_FLD, "{}")).items():
257
+ if t in query_rfea:
258
+ nor += query_rfea[t] * sc
259
+ denor += sc * sc
260
+ if denor == 0:
261
+ rank_fea.append(0)
262
+ else:
263
+ rank_fea.append(nor/np.sqrt(denor)/q_denor)
264
+ return np.array(rank_fea)*10. + pageranks
265
+
266
  def rerank(self, sres, query, tkweight=0.3,
267
+ vtweight=0.7, cfield="content_ltks",
268
+ rank_feature: dict | None = None
269
+ ):
270
  _, keywords = self.qryr.question(query)
271
  vector_size = len(sres.query_vector)
272
  vector_column = f"q_{vector_size}_vec"
273
  zero_vector = [0.0] * vector_size
274
  ins_embd = []
 
275
  for chunk_id in sres.ids:
276
  vector = sres.field[chunk_id].get(vector_column, zero_vector)
277
  if isinstance(vector, str):
278
  vector = [float(v) for v in vector.split("\t")]
279
  ins_embd.append(vector)
 
280
  if not ins_embd:
281
  return [], [], []
282
 
 
289
  title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
290
  question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
291
  important_kwd = sres.field[i].get("important_kwd", [])
292
+ tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6
293
  ins_tw.append(tks)
294
 
295
+ ## For rank feature(tag_fea) scores.
296
+ rank_fea = self._rank_feature_scores(rank_feature, sres)
297
+
298
  sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,
299
  ins_embd,
300
  keywords,
301
  ins_tw, tkweight, vtweight)
302
 
303
+ return sim + rank_fea, tksim, vtsim
304
 
305
  def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
306
+ vtweight=0.7, cfield="content_ltks",
307
+ rank_feature: dict | None = None):
308
  _, keywords = self.qryr.question(query)
309
 
310
  for i in sres.ids:
 
319
  ins_tw.append(tks)
320
 
321
  tksim = self.qryr.token_similarity(keywords, ins_tw)
322
+ vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])
323
+ ## For rank feature (tag_fea) scores.
324
+ rank_fea = self._rank_feature_scores(rank_feature, sres)
325
 
326
+ return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
327
 
328
  def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
329
  return self.qryr.hybrid_similarity(ans_embd,
 
332
  rag_tokenizer.tokenize(inst).split())
333
 
334
  def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
335
+ vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
336
+ rerank_mdl=None, highlight=False,
337
+ rank_feature: dict | None = {PAGERANK_FLD: 10}):
338
  ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
339
  if not question:
340
  return ranks
341
 
342
  RERANK_PAGE_LIMIT = 3
343
+ req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": max(page_size * RERANK_PAGE_LIMIT, 128),
344
  "question": question, "vector": True, "topk": top,
345
  "similarity": similarity_threshold,
346
  "available_int": 1}
 
352
  if isinstance(tenant_ids, str):
353
  tenant_ids = tenant_ids.split(",")
354
 
355
+ sres = self.search(req, [index_name(tid) for tid in tenant_ids],
356
+ kb_ids, embd_mdl, highlight, rank_feature=rank_feature)
357
  ranks["total"] = sres.total
358
 
359
  if page <= RERANK_PAGE_LIMIT:
360
  if rerank_mdl and sres.total > 0:
361
  sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
362
+ sres, question, 1 - vector_similarity_weight,
363
+ vector_similarity_weight,
364
+ rank_feature=rank_feature)
365
  else:
366
  sim, tsim, vsim = self.rerank(
367
+ sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
368
+ rank_feature=rank_feature)
369
+ idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]
370
  else:
371
+ sim = tsim = vsim = [1] * len(sres.ids)
372
  idx = list(range(len(sres.ids)))
373
 
 
 
 
374
  dim = len(sres.query_vector)
375
  vector_column = f"q_{dim}_vec"
376
  zero_vector = [0.0] * dim
377
  for i in idx:
378
+ if sim[i] < similarity_threshold:
379
  break
380
  if len(ranks["chunks"]) >= page_size:
381
  if aggs:
 
413
  ranks["doc_aggs"] = [{"doc_name": k,
414
  "doc_id": v["doc_id"],
415
  "count": v["count"]} for k,
416
+ v in sorted(ranks["doc_aggs"].items(),
417
+ key=lambda x: x[1]["count"] * -1)]
418
  ranks["chunks"] = ranks["chunks"][:page_size]
419
 
420
  return ranks
 
423
  tbl = self.dataStore.sql(sql, fetch_size, format)
424
  return tbl
425
 
426
+ def chunk_list(self, doc_id: str, tenant_id: str,
427
+ kb_ids: list[str], max_count=1024,
428
+ offset=0,
429
+ fields=["docnm_kwd", "content_with_weight", "img_id"]):
430
  condition = {"doc_id": doc_id}
431
  res = []
432
  bs = 128
433
+ for p in range(offset, max_count, bs):
434
+ es_res = self.dataStore.search(fields, [], condition, [], OrderByExpr(), p, bs, index_name(tenant_id),
435
+ kb_ids)
436
  dict_chunks = self.dataStore.getFields(es_res, fields)
437
  if dict_chunks:
438
  res.extend(dict_chunks.values())
439
  if len(dict_chunks.values()) < bs:
440
  break
441
  return res
442
+
443
+ def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
444
+ res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
445
+ return self.dataStore.getAggregation(res, "tag_kwd")
446
+
447
+ def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
448
+ res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
449
+ res = self.dataStore.getAggregation(res, "tag_kwd")
450
+ total = np.sum([c for _, c in res])
451
+ return {t: (c + 1) / (total + S) for t, c in res}
452
+
453
+ def tag_content(self, tenant_id: str, kb_ids: list[str], doc, all_tags, topn_tags=3, keywords_topn=30, S=1000):
454
+ idx_nm = index_name(tenant_id)
455
+ match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
456
+ res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
457
+ aggs = self.dataStore.getAggregation(res, "tag_kwd")
458
+ if not aggs:
459
+ return False
460
+ cnt = np.sum([c for _, c in aggs])
461
+ tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
462
+ key=lambda x: x[1] * -1)[:topn_tags]
463
+ doc[TAG_FLD] = {a: c for a, c in tag_fea if c > 0}
464
+ return True
465
+
466
+ def tag_query(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], all_tags, topn_tags=3, S=1000):
467
+ if isinstance(tenant_ids, str):
468
+ idx_nms = index_name(tenant_ids)
469
+ else:
470
+ idx_nms = [index_name(tid) for tid in tenant_ids]
471
+ match_txt, _ = self.qryr.question(question, min_match=0.0)
472
+ res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
473
+ aggs = self.dataStore.getAggregation(res, "tag_kwd")
474
+ if not aggs:
475
+ return {}
476
+ cnt = np.sum([c for _, c in aggs])
477
+ tag_fea = sorted([(a, round(0.1*(c + 1) / (cnt + S) / (all_tags.get(a, 0.0001)))) for a, c in aggs],
478
+ key=lambda x: x[1] * -1)[:topn_tags]
479
+ return {a: c for a, c in tag_fea if c > 0}
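To make the new _rank_feature_scores() weighting concrete, here is a standalone sketch with invented tag weights: each chunk gets a cosine-style similarity between the query's tag features (as produced by tag_query()) and its own stored tag_feas, scaled by 10 and added to its pagerank boost.

# Illustrative numbers only; mirrors the arithmetic of Dealer._rank_feature_scores().
import numpy as np

query_rfea = {"finance": 3, "tax": 1}                    # hypothetical output of tag_query()
chunk_tags = [{"finance": 2, "hr": 1}, {"travel": 4}]    # hypothetical TAG_FLD of two chunks
pageranks = np.array([0.0, 10.0])

q_denor = np.sqrt(sum(s * s for s in query_rfea.values()))
rank_fea = []
for tags in chunk_tags:
    nor, denor = 0, 0
    for t, sc in tags.items():
        if t in query_rfea:
            nor += query_rfea[t] * sc
            denor += sc * sc
    rank_fea.append(0 if denor == 0 else nor / np.sqrt(denor) / q_denor)

print(np.array(rank_fea) * 10. + pageranks)              # ~[9.49, 10.0]

The second chunk shares no tags with the query, so only its pagerank survives; the first is lifted by the overlap on "finance".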
rag/settings.py CHANGED
@@ -38,6 +38,9 @@ SVR_QUEUE_RETENTION = 60*60
38
  SVR_QUEUE_MAX_LEN = 1024
39
  SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
40
  SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
 
 
 
41
 
42
  def print_rag_settings():
43
  logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
 
38
  SVR_QUEUE_MAX_LEN = 1024
39
  SVR_CONSUMER_NAME = "rag_flow_svr_consumer"
40
  SVR_CONSUMER_GROUP_NAME = "rag_flow_svr_consumer_group"
41
+ PAGERANK_FLD = "pagerank_fea"
42
+ TAG_FLD = "tag_feas"
43
+
44
 
45
  def print_rag_settings():
46
  logging.info(f"MAX_CONTENT_LENGTH: {DOC_MAXIMUM_SIZE}")
rag/svr/task_executor.py CHANGED
@@ -16,10 +16,10 @@
16
  # from beartype import BeartypeConf
17
  # from beartype.claw import beartype_all # <-- you didn't sign up for this
18
  # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
19
-
20
  import sys
21
  from api.utils.log_utils import initRootLogger
22
- from graphrag.utils import get_llm_cache, set_llm_cache
23
 
24
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
25
  CONSUMER_NAME = "task_executor_" + CONSUMER_NO
@@ -44,7 +44,7 @@ import numpy as np
44
  from peewee import DoesNotExist
45
 
46
  from api.db import LLMType, ParserType, TaskStatus
47
- from api.db.services.dialog_service import keyword_extraction, question_proposal
48
  from api.db.services.document_service import DocumentService
49
  from api.db.services.llm_service import LLMBundle
50
  from api.db.services.task_service import TaskService
@@ -53,10 +53,10 @@ from api import settings
53
  from api.versions import get_ragflow_version
54
  from api.db.db_models import close_connection
55
  from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
56
- knowledge_graph, email
57
  from rag.nlp import search, rag_tokenizer
58
  from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
59
- from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings
60
  from rag.utils import num_tokens_from_string
61
  from rag.utils.redis_conn import REDIS_CONN, Payload
62
  from rag.utils.storage_factory import STORAGE_IMPL
@@ -78,7 +78,8 @@ FACTORY = {
78
  ParserType.ONE.value: one,
79
  ParserType.AUDIO.value: audio,
80
  ParserType.EMAIL.value: email,
81
- ParserType.KG.value: knowledge_graph
 
82
  }
83
 
84
  CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
@@ -199,7 +200,8 @@ def build_chunks(task, progress_callback):
199
  logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
200
  except TimeoutError:
201
  progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
202
- logging.exception("Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
 
203
  raise
204
  except Exception as e:
205
  if re.search("(No such file|not found)", str(e)):
@@ -227,7 +229,7 @@ def build_chunks(task, progress_callback):
227
  "kb_id": str(task["kb_id"])
228
  }
229
  if task["pagerank"]:
230
- doc["pagerank_fea"] = int(task["pagerank"])
231
  el = 0
232
  for ck in cks:
233
  d = copy.deepcopy(doc)
@@ -252,7 +254,8 @@ def build_chunks(task, progress_callback):
252
  STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
253
  el += timer() - st
254
  except Exception:
255
- logging.exception("Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
 
256
  raise
257
 
258
  d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
@@ -295,12 +298,43 @@ def build_chunks(task, progress_callback):
295
  d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
296
  progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  return docs
299
 
300
 
301
  def init_kb(row, vector_size: int):
302
  idxnm = search.index_name(row["tenant_id"])
303
- return settings.docStoreConn.createIdx(idxnm, row.get("kb_id",""), vector_size)
304
 
305
 
306
  def embedding(docs, mdl, parser_config=None, callback=None):
@@ -381,7 +415,7 @@ def run_raptor(row, chat_mdl, embd_mdl, callback=None):
381
  "title_tks": rag_tokenizer.tokenize(row["name"])
382
  }
383
  if row["pagerank"]:
384
- doc["pagerank_fea"] = int(row["pagerank"])
385
  res = []
386
  tk_count = 0
387
  for content, vctr in chunks[original_length:]:
@@ -480,7 +514,8 @@ def do_handle_task(task):
480
  doc_store_result = ""
481
  es_bulk_size = 4
482
  for b in range(0, len(chunks), es_bulk_size):
483
- doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id), task_dataset_id)
 
484
  if b % 128 == 0:
485
  progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
486
  if doc_store_result:
@@ -493,15 +528,21 @@ def do_handle_task(task):
493
  TaskService.update_chunk_ids(task["id"], chunk_ids_str)
494
  except DoesNotExist:
495
  logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
496
- doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id), task_dataset_id)
 
497
  return
498
- logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts))
 
 
499
 
500
  DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
501
 
502
  time_cost = timer() - start_ts
503
  progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
504
- logging.info("Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page, task_to_page, len(chunks), token_count, time_cost))
 
 
 
505
 
506
 
507
  def handle_task():
 
16
  # from beartype import BeartypeConf
17
  # from beartype.claw import beartype_all # <-- you didn't sign up for this
18
  # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code
19
+ import random
20
  import sys
21
  from api.utils.log_utils import initRootLogger
22
+ from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache
23
 
24
  CONSUMER_NO = "0" if len(sys.argv) < 2 else sys.argv[1]
25
  CONSUMER_NAME = "task_executor_" + CONSUMER_NO
 
44
  from peewee import DoesNotExist
45
 
46
  from api.db import LLMType, ParserType, TaskStatus
47
+ from api.db.services.dialog_service import keyword_extraction, question_proposal, content_tagging
48
  from api.db.services.document_service import DocumentService
49
  from api.db.services.llm_service import LLMBundle
50
  from api.db.services.task_service import TaskService
 
53
  from api.versions import get_ragflow_version
54
  from api.db.db_models import close_connection
55
  from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
56
+ knowledge_graph, email, tag
57
  from rag.nlp import search, rag_tokenizer
58
  from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
59
+ from rag.settings import DOC_MAXIMUM_SIZE, SVR_QUEUE_NAME, print_rag_settings, TAG_FLD, PAGERANK_FLD
60
  from rag.utils import num_tokens_from_string
61
  from rag.utils.redis_conn import REDIS_CONN, Payload
62
  from rag.utils.storage_factory import STORAGE_IMPL
 
78
  ParserType.ONE.value: one,
79
  ParserType.AUDIO.value: audio,
80
  ParserType.EMAIL.value: email,
81
+ ParserType.KG.value: knowledge_graph,
82
+ ParserType.TAG.value: tag
83
  }
84
 
85
  CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
 
200
  logging.info("From minio({}) {}/{}".format(timer() - st, task["location"], task["name"]))
201
  except TimeoutError:
202
  progress_callback(-1, "Internal server error: Fetch file from minio timeout. Could you try it again.")
203
+ logging.exception(
204
+ "Minio {}/{} got timeout: Fetch file from minio timeout.".format(task["location"], task["name"]))
205
  raise
206
  except Exception as e:
207
  if re.search("(No such file|not found)", str(e)):
 
229
  "kb_id": str(task["kb_id"])
230
  }
231
  if task["pagerank"]:
232
+ doc[PAGERANK_FLD] = int(task["pagerank"])
233
  el = 0
234
  for ck in cks:
235
  d = copy.deepcopy(doc)
 
254
  STORAGE_IMPL.put(task["kb_id"], d["id"], output_buffer.getvalue())
255
  el += timer() - st
256
  except Exception:
257
+ logging.exception(
258
+ "Saving image of chunk {}/{}/{} got exception".format(task["location"], task["name"], d["id"]))
259
  raise
260
 
261
  d["img_id"] = "{}-{}".format(task["kb_id"], d["id"])
 
298
  d["question_tks"] = rag_tokenizer.tokenize("\n".join(d["question_kwd"]))
299
  progress_callback(msg="Question generation completed in {:.2f}s".format(timer() - st))
300
 
301
+ if task["kb_parser_config"].get("tag_kb_ids", []):
302
+ progress_callback(msg="Start to tag for every chunk ...")
303
+ kb_ids = task["kb_parser_config"]["tag_kb_ids"]
304
+ tenant_id = task["tenant_id"]
305
+ topn_tags = task["kb_parser_config"].get("topn_tags", 3)
306
+ S = 1000
307
+ st = timer()
308
+ examples = []
309
+ all_tags = get_tags_from_cache(kb_ids)
310
+ if not all_tags:
311
+ all_tags = settings.retrievaler.all_tags_in_portion(tenant_id, kb_ids, S)
312
+ set_tags_to_cache(kb_ids, all_tags)
313
+ else:
314
+ all_tags = json.loads(all_tags)
315
+
316
+ chat_mdl = LLMBundle(task["tenant_id"], LLMType.CHAT, llm_name=task["llm_id"], lang=task["language"])
317
+ for d in docs:
318
+ if settings.retrievaler.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S):
319
+ examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
320
+ continue
321
+ cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], all_tags, {"topn": topn_tags})
322
+ if not cached:
323
+ cached = content_tagging(chat_mdl, d["content_with_weight"], all_tags,
324
+ random.choices(examples, k=2) if len(examples)>2 else examples,
325
+ topn=topn_tags)
326
+ if cached:
327
+ set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, all_tags, {"topn": topn_tags})
328
+ d[TAG_FLD] = json.loads(cached)
329
+
330
+ progress_callback(msg="Tagging completed in {:.2f}s".format(timer() - st))
331
+
332
  return docs
333
 
334
 
335
  def init_kb(row, vector_size: int):
336
  idxnm = search.index_name(row["tenant_id"])
337
+ return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
338
 
339
 
340
  def embedding(docs, mdl, parser_config=None, callback=None):
 
415
  "title_tks": rag_tokenizer.tokenize(row["name"])
416
  }
417
  if row["pagerank"]:
418
+ doc[PAGERANK_FLD] = int(row["pagerank"])
419
  res = []
420
  tk_count = 0
421
  for content, vctr in chunks[original_length:]:
 
514
  doc_store_result = ""
515
  es_bulk_size = 4
516
  for b in range(0, len(chunks), es_bulk_size):
517
+ doc_store_result = settings.docStoreConn.insert(chunks[b:b + es_bulk_size], search.index_name(task_tenant_id),
518
+ task_dataset_id)
519
  if b % 128 == 0:
520
  progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
521
  if doc_store_result:
 
528
  TaskService.update_chunk_ids(task["id"], chunk_ids_str)
529
  except DoesNotExist:
530
  logging.warning(f"do_handle_task update_chunk_ids failed since task {task['id']} is unknown.")
531
+ doc_store_result = settings.docStoreConn.delete({"id": chunk_ids}, search.index_name(task_tenant_id),
532
+ task_dataset_id)
533
  return
534
+ logging.info("Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(task_document_name, task_from_page,
535
+ task_to_page, len(chunks),
536
+ timer() - start_ts))
537
 
538
  DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0)
539
 
540
  time_cost = timer() - start_ts
541
  progress_callback(prog=1.0, msg="Done ({:.2f}s)".format(time_cost))
542
+ logging.info(
543
+ "Chunk doc({}), page({}-{}), chunks({}), token({}), elapsed:{:.2f}".format(task_document_name, task_from_page,
544
+ task_to_page, len(chunks),
545
+ token_count, time_cost))
546
 
547
 
548
  def handle_task():
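For context, the tagging step added to build_chunks() above ranks candidate tags by a smoothed lift: a tag's frequency among the chunks that match the current content, divided by its overall portion in the tag knowledge base (all_tags_in_portion()). A standalone sketch with invented counts, mirroring the scoring in tag_content()/tag_query():

# Illustrative numbers only.
S = 1000
all_tags = {"finance": 0.005, "hr": 0.001}   # smoothed global portions, as from all_tags_in_portion()
aggs = [("finance", 40), ("hr", 30)]         # tag_kwd aggregation over the chunks matching this content
cnt = sum(c for _, c in aggs)

tag_fea = sorted([(a, round(0.1 * (c + 1) / (cnt + S) / all_tags.get(a, 0.0001))) for a, c in aggs],
                 key=lambda x: x[1] * -1)[:3]
print({a: c for a, c in tag_fea if c > 0})   # {'hr': 3, 'finance': 1}: the rarer tag gets the bigger feature

The LLM fallback (content_tagging) is only consulted when this retrieval-based pass yields nothing for a chunk and no cached result exists.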
rag/utils/__init__.py CHANGED
@@ -71,11 +71,13 @@ def findMaxTm(fnm):
71
  pass
72
  return m
73
 
 
74
  tiktoken_cache_dir = get_project_base_directory()
75
  os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
76
  # encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
77
  encoder = tiktoken.get_encoding("cl100k_base")
78
 
 
79
  def num_tokens_from_string(string: str) -> int:
80
  """Returns the number of tokens in a text string."""
81
  try:
 
71
  pass
72
  return m
73
 
74
+
75
  tiktoken_cache_dir = get_project_base_directory()
76
  os.environ["TIKTOKEN_CACHE_DIR"] = tiktoken_cache_dir
77
  # encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
78
  encoder = tiktoken.get_encoding("cl100k_base")
79
 
80
+
81
  def num_tokens_from_string(string: str) -> int:
82
  """Returns the number of tokens in a text string."""
83
  try:
rag/utils/doc_store_conn.py CHANGED
@@ -176,7 +176,17 @@ class DocStoreConnection(ABC):
176
 
177
  @abstractmethod
178
  def search(
179
- self, selectFields: list[str], highlight: list[str], condition: dict, matchExprs: list[MatchExpr], orderBy: OrderByExpr, offset: int, limit: int, indexNames: str|list[str], knowledgebaseIds: list[str]
 
 
 
 
 
 
 
 
 
 
180
  ) -> list[dict] | pl.DataFrame:
181
  """
182
  Search with given conjunctive equivalent filtering condition and return all fields of matched documents
@@ -191,7 +201,7 @@ class DocStoreConnection(ABC):
191
  raise NotImplementedError("Not implemented")
192
 
193
  @abstractmethod
194
- def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
195
  """
196
  Update or insert a bulk of rows
197
  """
 
176
 
177
  @abstractmethod
178
  def search(
179
+ self, selectFields: list[str],
180
+ highlightFields: list[str],
181
+ condition: dict,
182
+ matchExprs: list[MatchExpr],
183
+ orderBy: OrderByExpr,
184
+ offset: int,
185
+ limit: int,
186
+ indexNames: str|list[str],
187
+ knowledgebaseIds: list[str],
188
+ aggFields: list[str] = [],
189
+ rank_feature: dict | None = None
190
  ) -> list[dict] | pl.DataFrame:
191
  """
192
  Search with given conjunctive equivalent filtering condition and return all fields of matched documents
 
201
  raise NotImplementedError("Not implemented")
202
 
203
  @abstractmethod
204
+ def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
205
  """
206
  Update or insert a bulk of rows
207
  """
rag/utils/es_conn.py CHANGED
@@ -9,6 +9,7 @@ from elasticsearch import Elasticsearch, NotFoundError
9
  from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
10
  from elastic_transport import ConnectionTimeout
11
  from rag import settings
 
12
  from rag.utils import singleton
13
  from api.utils.file_utils import get_project_base_directory
14
  import polars as pl
@@ -20,6 +21,7 @@ ATTEMPT_TIME = 2
20
 
21
  logger = logging.getLogger('ragflow.es_conn')
22
 
 
23
  @singleton
24
  class ESConnection(DocStoreConnection):
25
  def __init__(self):
@@ -111,9 +113,19 @@ class ESConnection(DocStoreConnection):
111
  CRUD operations
112
  """
113
 
114
- def search(self, selectFields: list[str], highlightFields: list[str], condition: dict, matchExprs: list[MatchExpr],
115
- orderBy: OrderByExpr, offset: int, limit: int, indexNames: str | list[str],
116
- knowledgebaseIds: list[str]) -> list[dict] | pl.DataFrame:
 
 
 
 
 
 
 
 
 
 
117
  """
118
  Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
119
  """
@@ -175,8 +187,13 @@ class ESConnection(DocStoreConnection):
175
  similarity=similarity,
176
  )
177
 
 
 
 
 
 
 
178
  if bqry:
179
- bqry.should.append(Q("rank_feature", field="pagerank_fea", linear={}, boost=10))
180
  s = s.query(bqry)
181
  for field in highlightFields:
182
  s = s.highlight(field)
@@ -187,7 +204,7 @@ class ESConnection(DocStoreConnection):
187
  order = "asc" if order == 0 else "desc"
188
  if field in ["page_num_int", "top_int"]:
189
  order_info = {"order": order, "unmapped_type": "float",
190
- "mode": "avg", "numeric_type": "double"}
191
  elif field.endswith("_int") or field.endswith("_flt"):
192
  order_info = {"order": order, "unmapped_type": "float"}
193
  else:
@@ -195,8 +212,11 @@ class ESConnection(DocStoreConnection):
195
  orders.append({field: order_info})
196
  s = s.sort(*orders)
197
 
 
 
 
198
  if limit > 0:
199
- s = s[offset:offset+limit]
200
  q = s.to_dict()
201
  logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
202
 
@@ -240,7 +260,7 @@ class ESConnection(DocStoreConnection):
240
  logger.error("ESConnection.get timeout for 3 times!")
241
  raise Exception("ESConnection.get timeout.")
242
 
243
- def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]:
244
  # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
245
  operations = []
246
  for d in documents:
@@ -292,44 +312,57 @@ class ESConnection(DocStoreConnection):
292
  if str(e).find("Timeout") > 0:
293
  continue
294
  return False
295
- else:
296
- # update unspecific maybe-multiple documents
297
- bqry = Q("bool")
298
- for k, v in condition.items():
299
- if not isinstance(k, str) or not v:
300
- continue
301
- if k == "exist":
302
- bqry.filter.append(Q("exists", field=v))
303
- continue
304
- if isinstance(v, list):
305
- bqry.filter.append(Q("terms", **{k: v}))
306
- elif isinstance(v, str) or isinstance(v, int):
307
- bqry.filter.append(Q("term", **{k: v}))
308
- else:
309
- raise Exception(
310
- f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
311
- scripts = []
312
- for k, v in newValue.items():
313
- if k == "remove":
314
- scripts.append(f"ctx._source.remove('{v}');")
315
- continue
316
- if (not isinstance(k, str) or not v) and k != "available_int":
317
- continue
318
  if isinstance(v, str):
319
- scripts.append(f"ctx._source.{k} = '{v}'")
320
- elif isinstance(v, int):
321
- scripts.append(f"ctx._source.{k} = {v}")
322
- else:
323
- raise Exception(
324
- f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  ubq = UpdateByQuery(
326
  index=indexName).using(
327
  self.es).query(bqry)
328
- ubq = ubq.script(source="; ".join(scripts))
329
  ubq = ubq.params(refresh=True)
330
  ubq = ubq.params(slices=5)
331
  ubq = ubq.params(conflicts="proceed")
332
- for i in range(3):
 
333
  try:
334
  _ = ubq.execute()
335
  return True
 
9
  from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
10
  from elastic_transport import ConnectionTimeout
11
  from rag import settings
12
+ from rag.settings import TAG_FLD, PAGERANK_FLD
13
  from rag.utils import singleton
14
  from api.utils.file_utils import get_project_base_directory
15
  import polars as pl
 
21
 
22
  logger = logging.getLogger('ragflow.es_conn')
23
 
24
+
25
  @singleton
26
  class ESConnection(DocStoreConnection):
27
  def __init__(self):
 
113
  CRUD operations
114
  """
115
 
116
+ def search(
117
+ self, selectFields: list[str],
118
+ highlightFields: list[str],
119
+ condition: dict,
120
+ matchExprs: list[MatchExpr],
121
+ orderBy: OrderByExpr,
122
+ offset: int,
123
+ limit: int,
124
+ indexNames: str | list[str],
125
+ knowledgebaseIds: list[str],
126
+ aggFields: list[str] = [],
127
+ rank_feature: dict | None = None
128
+ ) -> list[dict] | pl.DataFrame:
129
  """
130
  Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
131
  """
 
187
  similarity=similarity,
188
  )
189
 
190
+ if bqry and rank_feature:
191
+ for fld, sc in rank_feature.items():
192
+ if fld != PAGERANK_FLD:
193
+ fld = f"{TAG_FLD}.{fld}"
194
+ bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
195
+
196
  if bqry:
 
197
  s = s.query(bqry)
198
  for field in highlightFields:
199
  s = s.highlight(field)
 
204
  order = "asc" if order == 0 else "desc"
205
  if field in ["page_num_int", "top_int"]:
206
  order_info = {"order": order, "unmapped_type": "float",
207
+ "mode": "avg", "numeric_type": "double"}
208
  elif field.endswith("_int") or field.endswith("_flt"):
209
  order_info = {"order": order, "unmapped_type": "float"}
210
  else:
 
212
  orders.append({field: order_info})
213
  s = s.sort(*orders)
214
 
215
+ for fld in aggFields:
216
+ s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
217
+
218
  if limit > 0:
219
+ s = s[offset:offset + limit]
220
  q = s.to_dict()
221
  logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
222
 
 
260
  logger.error("ESConnection.get timeout for 3 times!")
261
  raise Exception("ESConnection.get timeout.")
262
 
263
+ def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
264
  # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
265
  operations = []
266
  for d in documents:
 
312
  if str(e).find("Timeout") > 0:
313
  continue
314
  return False
315
+
316
+ # update unspecific maybe-multiple documents
317
+ bqry = Q("bool")
318
+ for k, v in condition.items():
319
+ if not isinstance(k, str) or not v:
320
+ continue
321
+ if k == "exist":
322
+ bqry.filter.append(Q("exists", field=v))
323
+ continue
324
+ if isinstance(v, list):
325
+ bqry.filter.append(Q("terms", **{k: v}))
326
+ elif isinstance(v, str) or isinstance(v, int):
327
+ bqry.filter.append(Q("term", **{k: v}))
328
+ else:
329
+ raise Exception(
330
+ f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
331
+ scripts = []
332
+ params = {}
333
+ for k, v in newValue.items():
334
+ if k == "remove":
 
 
 
335
  if isinstance(v, str):
336
+ scripts.append(f"ctx._source.remove('{v}');")
337
+ if isinstance(v, dict):
338
+ for kk, vv in v.items():
339
+ scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
340
+ params[f"p_{kk}"] = vv
341
+ continue
342
+ if k == "add":
343
+ if isinstance(v, dict):
344
+ for kk, vv in v.items():
345
+ scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
346
+ params[f"pp_{kk}"] = vv.strip()
347
+ continue
348
+ if (not isinstance(k, str) or not v) and k != "available_int":
349
+ continue
350
+ if isinstance(v, str):
351
+ scripts.append(f"ctx._source.{k} = '{v}'")
352
+ elif isinstance(v, int):
353
+ scripts.append(f"ctx._source.{k} = {v}")
354
+ else:
355
+ raise Exception(
356
+ f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
357
  ubq = UpdateByQuery(
358
  index=indexName).using(
359
  self.es).query(bqry)
360
+ ubq = ubq.script(source="".join(scripts), params=params)
361
  ubq = ubq.params(refresh=True)
362
  ubq = ubq.params(slices=5)
363
  ubq = ubq.params(conflicts="proceed")
364
+
365
+ for _ in range(ATTEMPT_TIME):
366
  try:
367
  _ = ubq.execute()
368
  return True
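For reference, a hedged sketch of what the new rank_feature clauses serialize to: PAGERANK_FLD stays a top-level field, every other key is prefixed with TAG_FLD, and the boosts only take effect if the index mapping declares those fields as Elasticsearch rank_feature / rank_features types. The tag name and boosts below are invented.

# Illustrative only; same elasticsearch_dsl Q() call as in ESConnection.search() above.
from elasticsearch_dsl import Q

rank_feature = {"pagerank_fea": 10, "finance": 3}   # made-up boosts
clauses = []
for fld, sc in rank_feature.items():
    if fld != "pagerank_fea":                       # PAGERANK_FLD
        fld = f"tag_feas.{fld}"                     # TAG_FLD prefix
    clauses.append(Q("rank_feature", field=fld, linear={}, boost=sc))

print([q.to_dict() for q in clauses])
# e.g. [{'rank_feature': {'field': 'pagerank_fea', 'linear': {}, 'boost': 10}},
#       {'rank_feature': {'field': 'tag_feas.finance', 'linear': {}, 'boost': 3}}]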
rag/utils/infinity_conn.py CHANGED
@@ -10,6 +10,7 @@ from infinity.index import IndexInfo, IndexType
10
  from infinity.connection_pool import ConnectionPool
11
  from infinity.errors import ErrorCode
12
  from rag import settings
 
13
  from rag.utils import singleton
14
  import polars as pl
15
  from polars.series.series import Series
@@ -231,8 +232,7 @@ class InfinityConnection(DocStoreConnection):
231
  """
232
 
233
  def search(
234
- self,
235
- selectFields: list[str],
236
  highlightFields: list[str],
237
  condition: dict,
238
  matchExprs: list[MatchExpr],
@@ -241,7 +241,9 @@ class InfinityConnection(DocStoreConnection):
241
  limit: int,
242
  indexNames: str | list[str],
243
  knowledgebaseIds: list[str],
244
- ) -> tuple[pl.DataFrame, int]:
 
 
245
  """
246
  TODO: Infinity doesn't provide highlight
247
  """
@@ -256,7 +258,7 @@ class InfinityConnection(DocStoreConnection):
256
  if essential_field not in selectFields:
257
  selectFields.append(essential_field)
258
  if matchExprs:
259
- for essential_field in ["score()", "pagerank_fea"]:
260
  selectFields.append(essential_field)
261
 
262
  # Prepare expressions common to all tables
@@ -346,7 +348,7 @@ class InfinityConnection(DocStoreConnection):
346
  self.connPool.release_conn(inf_conn)
347
  res = concat_dataframes(df_list, selectFields)
348
  if matchExprs:
349
- res = res.sort(pl.col("SCORE") + pl.col("pagerank_fea"), descending=True, maintain_order=True)
350
  res = res.limit(limit)
351
  logger.debug(f"INFINITY search final result: {str(res)}")
352
  return res, total_hits_count
@@ -378,7 +380,7 @@ class InfinityConnection(DocStoreConnection):
378
  return res_fields.get(chunkId, None)
379
 
380
  def insert(
381
- self, documents: list[dict], indexName: str, knowledgebaseId: str
382
  ) -> list[str]:
383
  inf_conn = self.connPool.get_conn()
384
  db_instance = inf_conn.get_database(self.dbName)
@@ -456,7 +458,7 @@ class InfinityConnection(DocStoreConnection):
456
  elif k in ["page_num_int", "top_int"]:
457
  assert isinstance(v, list)
458
  newValue[k] = "_".join(f"{num:08x}" for num in v)
459
- elif k == "remove" and v in ["pagerank_fea"]:
460
  del newValue[k]
461
  newValue[v] = 0
462
  logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
 
10
  from infinity.connection_pool import ConnectionPool
11
  from infinity.errors import ErrorCode
12
  from rag import settings
13
+ from rag.settings import PAGERANK_FLD
14
  from rag.utils import singleton
15
  import polars as pl
16
  from polars.series.series import Series
 
232
  """
233
 
234
  def search(
235
+ self, selectFields: list[str],
 
236
  highlightFields: list[str],
237
  condition: dict,
238
  matchExprs: list[MatchExpr],
 
241
  limit: int,
242
  indexNames: str | list[str],
243
  knowledgebaseIds: list[str],
244
+ aggFields: list[str] = [],
245
+ rank_feature: dict | None = None
246
+ ) -> list[dict] | pl.DataFrame:
247
  """
248
  TODO: Infinity doesn't provide highlight
249
  """
 
258
  if essential_field not in selectFields:
259
  selectFields.append(essential_field)
260
  if matchExprs:
261
+ for essential_field in ["score()", PAGERANK_FLD]:
262
  selectFields.append(essential_field)
263
 
264
  # Prepare expressions common to all tables
 
348
  self.connPool.release_conn(inf_conn)
349
  res = concat_dataframes(df_list, selectFields)
350
  if matchExprs:
351
+ res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True)
352
  res = res.limit(limit)
353
  logger.debug(f"INFINITY search final result: {str(res)}")
354
  return res, total_hits_count
 
380
  return res_fields.get(chunkId, None)
381
 
382
  def insert(
383
+ self, documents: list[dict], indexName: str, knowledgebaseId: str = None
384
  ) -> list[str]:
385
  inf_conn = self.connPool.get_conn()
386
  db_instance = inf_conn.get_database(self.dbName)
 
458
  elif k in ["page_num_int", "top_int"]:
459
  assert isinstance(v, list)
460
  newValue[k] = "_".join(f"{num:08x}" for num in v)
461
+ elif k == "remove" and v in [PAGERANK_FLD]:
462
  del newValue[k]
463
  newValue[v] = 0
464
  logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
sdk/python/test/test_sdk_api/t_dataset.py CHANGED
@@ -27,7 +27,7 @@ def test_create_dataset_with_invalid_parameter(get_api_key_fixture):
27
  API_KEY = get_api_key_fixture
28
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
29
  valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
30
- "knowledge_graph", "email"]
31
  chunk_method = "invalid_chunk_method"
32
  with pytest.raises(Exception) as exc_info:
33
  rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)
 
27
  API_KEY = get_api_key_fixture
28
  rag = RAGFlow(API_KEY, HOST_ADDRESS)
29
  valid_chunk_methods = ["naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one",
30
+ "knowledge_graph", "email", "tag"]
31
  chunk_method = "invalid_chunk_method"
32
  with pytest.raises(Exception) as exc_info:
33
  rag.create_dataset("test_create_dataset_with_invalid_chunk_method",chunk_method=chunk_method)