zhichyu committed
Commit 68ba7f0 · 1 Parent(s): 792f830

Validate returned chunk at list_chunks and add_chunk (#4153)


### What problem does this PR solve?

Validate the chunk returned by `list_chunks` and `add_chunk` against a shared Pydantic `Chunk` model before the response is sent, so a malformed field (for example, a `positions` entry that is not a 5-element list) raises a validation error instead of silently reaching the client.

### Type of change

- [x] Refactoring

Files changed (2)

1. api/apps/sdk/doc.py +23 -14
2. rag/utils/infinity_conn.py +5 -3
api/apps/sdk/doc.py CHANGED
```diff
@@ -42,9 +42,30 @@ from rag.nlp import search
 from rag.utils import rmSpace
 from rag.utils.storage_factory import STORAGE_IMPL
 
+from pydantic import BaseModel, Field, validator
+
 MAXIMUM_OF_UPLOADING_FILES = 256
 
 
+class Chunk(BaseModel):
+    id: str = ""
+    content: str = ""
+    document_id: str = ""
+    docnm_kwd: str = ""
+    important_keywords: list = Field(default_factory=list)
+    questions: list = Field(default_factory=list)
+    question_tks: str = ""
+    image_id: str = ""
+    available: bool = True
+    positions: list[list[int]] = Field(default_factory=list)
+
+    @validator('positions')
+    def validate_positions(cls, value):
+        for sublist in value:
+            if len(sublist) != 5:
+                raise ValueError("Each sublist in positions must have a length of 5")
+        return value
+
 @manager.route("/datasets/<dataset_id>/documents", methods=["POST"])  # noqa: F821
 @token_required
 def upload(dataset_id, tenant_id):
@@ -848,20 +869,6 @@ def list_chunks(tenant_id, dataset_id, document_id):
             "available_int": sres.field[id].get("available_int", 1),
             "positions": sres.field[id].get("position_int", []),
         }
-        if len(d["positions"]) % 5 == 0:
-            poss = []
-            for i in range(0, len(d["positions"]), 5):
-                poss.append(
-                    [
-                        float(d["positions"][i]),
-                        float(d["positions"][i + 1]),
-                        float(d["positions"][i + 2]),
-                        float(d["positions"][i + 3]),
-                        float(d["positions"][i + 4]),
-                    ]
-                )
-            d["positions"] = poss
-
         origin_chunks.append(d)
         if req.get("id"):
             if req.get("id") == id:
@@ -892,6 +899,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
         if renamed_chunk["available"] == 1:
             renamed_chunk["available"] = True
         res["chunks"].append(renamed_chunk)
+        _ = Chunk(**renamed_chunk)  # validate the chunk
     return get_result(data=res)
 
 
@@ -1031,6 +1039,7 @@ def add_chunk(tenant_id, dataset_id, document_id):
         if key in key_mapping:
             new_key = key_mapping.get(key, key)
             renamed_chunk[new_key] = value
+    _ = Chunk(**renamed_chunk)  # validate the chunk
     return get_result(data={"chunk": renamed_chunk})
     # return get_result(data={"chunk_id": chunk_id})
 
```
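
The model gives both endpoints a single schema check: `_ = Chunk(**renamed_chunk)` throws if any renamed field has the wrong shape. The standalone sketch below mirrors the model (pydantic v1 API, matching the `validator` import above) to show what the `positions` validator accepts and rejects; the meaning of the five integers per box is an assumption here, not something this diff specifies.

```python
# Standalone sketch mirroring the Chunk model added above (pydantic v1 API).
from pydantic import BaseModel, Field, ValidationError, validator


class Chunk(BaseModel):
    id: str = ""
    available: bool = True
    # One 5-int sublist per box; assumed layout: [page, left, right, top, bottom].
    positions: list[list[int]] = Field(default_factory=list)

    @validator("positions")
    def validate_positions(cls, value):
        for sublist in value:
            if len(sublist) != 5:
                raise ValueError("Each sublist in positions must have a length of 5")
        return value


Chunk(id="ok", positions=[[1, 10, 20, 30, 40]])   # accepted: 5 ints per box

try:
    Chunk(id="bad", positions=[[1, 10, 20, 30]])  # rejected: only 4 ints
except ValidationError as e:
    print(e)  # "Each sublist in positions must have a length of 5"
```

This also puts the deleted regrouping loop in `list_chunks` in context: positions are now expected to arrive already grouped as 5-element sublists (presumably the doc store returns them that way), so the endpoint validates that shape instead of rebuilding it from a flat list.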
rag/utils/infinity_conn.py CHANGED
```diff
@@ -3,6 +3,7 @@ import os
 import re
 import json
 import time
+import copy
 import infinity
 from infinity.common import ConflictType, InfinityException, SortType
 from infinity.index import IndexInfo, IndexType
@@ -390,7 +391,8 @@ class InfinityConnection(DocStoreConnection):
             self.createIdx(indexName, knowledgebaseId, vector_size)
             table_instance = db_instance.get_table(table_name)
 
-        for d in documents:
+        docs = copy.deepcopy(documents)
+        for d in docs:
             assert "_id" not in d
             assert "id" in d
             for k, v in d.items():
@@ -407,14 +409,14 @@ class InfinityConnection(DocStoreConnection):
             elif k in ["page_num_int", "top_int"]:
                 assert isinstance(v, list)
                 d[k] = "_".join(f"{num:08x}" for num in v)
-        ids = ["'{}'".format(d["id"]) for d in documents]
+        ids = ["'{}'".format(d["id"]) for d in docs]
         str_ids = ", ".join(ids)
         str_filter = f"id IN ({str_ids})"
         table_instance.delete(str_filter)
         # for doc in documents:
         #     logger.info(f"insert position_int: {doc['position_int']}")
         # logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
-        table_instance.insert(documents)
+        table_instance.insert(docs)
         self.connPool.release_conn(inf_conn)
         logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
         return []
```
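
The new `copy.deepcopy(documents)` matters because `insert` rewrites fields in place before handing them to Infinity, hex-encoding `page_num_int`/`top_int` lists into strings. Below is a minimal sketch of the side effect the copy prevents; the helper name is made up for illustration, not RAGFlow code.

```python
import copy


def encode_for_infinity(docs):
    # Same in-place transformation the diff applies to page_num_int/top_int.
    for d in docs:
        for k, v in d.items():
            if k in ("page_num_int", "top_int") and isinstance(v, list):
                d[k] = "_".join(f"{num:08x}" for num in v)


documents = [{"id": "c1", "page_num_int": [3, 4]}]

# Without a copy (old behavior), the caller's dicts are mutated as a side effect:
encode_for_infinity(documents)
print(documents[0]["page_num_int"])  # '00000003_00000004'

# With a deep copy (new behavior), the caller's data survives intact:
documents = [{"id": "c1", "page_num_int": [3, 4]}]
docs = copy.deepcopy(documents)
encode_for_infinity(docs)
print(documents[0]["page_num_int"])  # [3, 4]
print(docs[0]["page_num_int"])       # '00000003_00000004'
```

Note that a shallow copy of the list would still share the inner dicts, so a deep copy (or at least a per-document dict copy) is needed for the fix to hold.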