Commit 6807fac
Parent(s): b88edf5

Fix bugs (#3502)

### What problem does this PR solve?
1. Remove unused code
2. Fix type mismatches in the nlp search and infinity search interfaces
3. Fix chunk list so that it returns all chunks of this user
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
---------
Signed-off-by: jinhai <haijin.chn@gmail.com>
- agent/component/base.py +1 -2
- graphrag/search.py +1 -1
- rag/utils/infinity_conn.py +52 -43
agent/component/base.py
CHANGED
@@ -17,13 +17,13 @@ from abc import ABC
 import builtins
 import json
 import os
+import logging
 from functools import partial
 from typing import Tuple, Union

 import pandas as pd

 from agent import settings
-from agent.settings import flow_logger, DEBUG

 _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
 _DEPRECATED_PARAMS = "_deprecated_params"

@@ -480,7 +480,6 @@ class ComponentBase(ABC):

         upstream_outs = []

-        if DEBUG: print(self.component_name, reversed_cpnts[::-1])
         for u in reversed_cpnts[::-1]:
             if self.get_component_name(u) in ["switch", "concentrator"]: continue
             if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
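With the DEBUG-gated print gone, the `flow_logger, DEBUG` import from `agent.settings` is no longer needed here, and `import logging` points at the standard library instead. A minimal sketch of the pattern this moves toward (the class below is illustrative, not the real `ComponentBase`): debug output goes through a logger whose verbosity is set by logging configuration rather than by a settings flag.

```python
import logging

class DemoComponent:
    """Illustrative stand-in for a component; not code from this repository."""
    component_name = "generate"

    def trace_upstream(self, reversed_cpnts: list[str]) -> None:
        # Instead of `if DEBUG: print(...)`, emit a debug record and let the
        # logging configuration decide whether it is shown.
        logging.debug("%s %s", self.component_name, reversed_cpnts[::-1])

logging.basicConfig(level=logging.DEBUG)
DemoComponent().trace_upstream(["retrieval", "switch"])
```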
graphrag/search.py
CHANGED
@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer


 class KGSearch(Dealer):
-    def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
+    def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
         def merge_into_first(sres, title="") -> dict[str, str]:
             if not sres:
                 return {}
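The annotations make explicit that `idxnm` may be a single index name or a list of names, and that `kb_ids` is a list. A hedged sketch (the helper and index names below are invented for illustration) of how a caller or the search layer can normalize the two forms:

```python
def normalize_index_names(idxnm: str | list[str]) -> list[str]:
    """Accept one index name or several and always return a list."""
    return [idxnm] if isinstance(idxnm, str) else list(idxnm)

assert normalize_index_names("ragflow_idx") == ["ragflow_idx"]
assert normalize_index_names(["idx_a", "idx_b"]) == ["idx_a", "idx_b"]
```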
rag/utils/infinity_conn.py
CHANGED
@@ -4,7 +4,7 @@ import re
 import json
 import time
 import infinity
-from infinity.common import ConflictType, InfinityException
+from infinity.common import ConflictType, InfinityException, SortType
 from infinity.index import IndexInfo, IndexType
 from infinity.connection_pool import ConnectionPool
 from rag import settings

@@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
     OrderByExpr,
 )

+
 def equivalent_condition_to_str(condition: dict) -> str:
     assert "_id" not in condition
     cond = list()
@@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
                     self.connPool = connPool
                     break
             except Exception as e:
-                logging.
+                logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
                 time.sleep(5)
         if self.connPool is None:
             msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
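The connection loop now logs why each attempt failed while it waits for the engine to come up; with 5-second sleeps, the 120-second figure in the error message corresponds to roughly 24 attempts. A generic sketch of that retry-and-log pattern, not the repository's code:

```python
import logging
import time

def wait_until_healthy(probe, retries: int = 24, delay: float = 5.0) -> bool:
    """Return True once probe() succeeds; log and sleep between attempts."""
    for _ in range(retries):
        try:
            if probe():
                return True
        except Exception as e:
            logging.warning(f"{e}. Waiting for the service to be healthy.")
        time.sleep(delay)
    return False
```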
@@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
             self.connPool.release_conn(inf_conn)
             return True
         except Exception as e:
-            logging.
+            logging.warning(f"INFINITY indexExist {str(e)}")
         return False

     """

@@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
     """

     def search(
+        self,
+        selectFields: list[str],
+        highlightFields: list[str],
+        condition: dict,
+        matchExprs: list[MatchExpr],
+        orderBy: OrderByExpr,
+        offset: int,
+        limit: int,
+        indexNames: str | list[str],
+        knowledgebaseIds: list[str],
     ) -> list[dict] | pl.DataFrame:
         """
         TODO: Infinity doesn't provide highlight
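`search` now spells out its parameters and declares both of its possible return types. A small, hedged sketch of how downstream code can cope with `list[dict] | pl.DataFrame` (the rows below are invented):

```python
import polars as pl

def as_rows(res: list[dict] | pl.DataFrame) -> list[dict]:
    """Normalize either return shape to a list of row dicts."""
    return res.to_dicts() if isinstance(res, pl.DataFrame) else res

print(as_rows(pl.DataFrame({"id": ["c1"], "score": [0.9]})))
print(as_rows([{"id": "c2", "score": 0.8}]))
```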
@@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
                 minimum_should_match = "0%"
                 if "minimum_should_match" in matchExpr.extra_options:
                     minimum_should_match = (
+                        str(int(matchExpr.extra_options["minimum_should_match"] * 100))
+                        + "%"
                     )
                 matchExpr.extra_options.update(
                     {"minimum_should_match": minimum_should_match}
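The caller supplies `minimum_should_match` as a fraction, and the rewritten expression turns it into the percentage string that is then stored back into the full-text options. A worked example of that conversion as a standalone sketch:

```python
def to_percent(minimum_should_match: float) -> str:
    """Mirror the conversion in the hunk above: 0.3 -> '30%'."""
    return str(int(minimum_should_match * 100)) + "%"

assert to_percent(0.3) == "30%"
assert to_percent(0.0) == "0%"
```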
@@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
                 for k, v in matchExpr.extra_options.items():
                     if not isinstance(v, str):
                         matchExpr.extra_options[k] = str(v)
+
+        order_by_expr_list = list()
         if orderBy.fields:
-            order_by_expr_list = list()
             for order_field in orderBy.fields:
+                if order_field[1] == 0:
+                    order_by_expr_list.append((order_field[0], SortType.Asc))
+                else:
+                    order_by_expr_list.append((order_field[0], SortType.Desc))

         # Scatter search tables and gather the results
         for indexName in indexNames:
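`OrderByExpr` carries `(field, direction)` tuples with 0 meaning ascending; the new code maps them onto Infinity's `SortType`, which is why `SortType` joined the imports at the top of the file. A standalone sketch of the same mapping (the enum below is only a stand-in for `infinity.common.SortType`, and the field names are illustrative):

```python
from enum import Enum

class SortType(Enum):
    """Stand-in for infinity.common.SortType, for illustration only."""
    Asc = "asc"
    Desc = "desc"

def build_order_by(fields: list[tuple[str, int]]) -> list[tuple[str, SortType]]:
    # 0 means ascending, anything else descending, as in the hunk above.
    return [(name, SortType.Asc if flag == 0 else SortType.Desc) for name, flag in fields]

print(build_order_by([("create_time", 1), ("page_num", 0)]))
```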
@@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
                     continue
                 table_list.append(table_name)
                 builder = table_instance.output(selectFields)
+                if len(matchExprs) > 0:
+                    for matchExpr in matchExprs:
+                        if isinstance(matchExpr, MatchTextExpr):
+                            fields = ",".join(matchExpr.fields)
+                            builder = builder.match_text(
+                                fields,
+                                matchExpr.matching_text,
+                                matchExpr.topn,
+                                matchExpr.extra_options,
+                            )
+                        elif isinstance(matchExpr, MatchDenseExpr):
+                            builder = builder.match_dense(
+                                matchExpr.vector_column_name,
+                                matchExpr.embedding_data,
+                                matchExpr.embedding_data_type,
+                                matchExpr.distance_type,
+                                matchExpr.topn,
+                                matchExpr.extra_options,
+                            )
+                        elif isinstance(matchExpr, FusionExpr):
+                            builder = builder.fusion(
+                                matchExpr.method, matchExpr.topn, matchExpr.fusion_params
+                            )
+                else:
+                    if len(filter_cond) > 0:
+                        builder.filter(filter_cond)
                 if orderBy.fields:
                     builder.sort(order_by_expr_list)
                 builder.offset(offset).limit(limit)
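Each per-table query is built fluently: full-text expressions go through `match_text`, dense-vector expressions through `match_dense`, fusion expressions through `fusion`, and when no match expression is given the query falls back to a plain filter. A compact sketch of that `isinstance` dispatch over stub expression classes (the stubs are not the real classes from `rag/utils/doc_store_conn.py`):

```python
from dataclasses import dataclass

@dataclass
class MatchTextExpr:  # stand-in for the real expression type
    fields: list[str]
    matching_text: str
    topn: int = 10

@dataclass
class MatchDenseExpr:  # stand-in for the real expression type
    vector_column_name: str
    topn: int = 10

def plan(match_exprs: list, filter_cond: str) -> list[str]:
    """Describe which builder calls the expressions would translate into."""
    steps: list[str] = []
    if match_exprs:
        for e in match_exprs:
            if isinstance(e, MatchTextExpr):
                steps.append(f"match_text({','.join(e.fields)}, topn={e.topn})")
            elif isinstance(e, MatchDenseExpr):
                steps.append(f"match_dense({e.vector_column_name}, topn={e.topn})")
    elif filter_cond:
        steps.append(f"filter({filter_cond})")
    return steps

print(plan([MatchTextExpr(["content_ltks"], "query text")], ""))
print(plan([], "kb_id = 'kb_demo'"))
```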
@@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
         return res

     def get(
+        self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
     ) -> dict | None:
         inf_conn = self.connPool.get_conn()
         db_instance = inf_conn.get_database(self.dbName)

@@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
         return res_fields.get(chunkId, None)

     def insert(
+        self, documents: list[dict], indexName: str, knowledgebaseId: str
     ) -> list[str]:
         inf_conn = self.connPool.get_conn()
         db_instance = inf_conn.get_database(self.dbName)

@@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
         return []

     def update(
+        self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
     ) -> bool:
         # if 'position_list' in newValue:
         #     logging.info(f"upsert position_list: {newValue['position_list']}")
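`get`, `insert`, and `update` now spell out their parameters the same way as `search`. A hedged sketch of the calling contract those annotations document, written as a `typing.Protocol` purely for illustration (the real base class is `DocStoreConnection` in `rag/utils/doc_store_conn.py`):

```python
from typing import Protocol

class DocStore(Protocol):
    """Illustrative protocol; mirrors the signatures shown in the hunks above."""

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: ...

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str) -> list[str]: ...

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: ...
```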
@@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
                         flags=re.IGNORECASE | re.MULTILINE,
                     )
                 if not re.search(
+                    r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
                 ):
                     continue
                 txts.append(t)
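The added argument list keeps a text fragment only if it contains at least one `<em>…</em>` highlight span. A standalone sketch of that check (the sample sentences are invented):

```python
import re

sentences = [
    "plain text without any keyword hit",
    "RAGFlow uses <em>Infinity</em> as one of its doc engines",
]
kept = [t for t in sentences if re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE)]
print(kept)  # only the highlighted sentence survives
```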