jinhai-2012 committed on
Commit
6807fac
·
1 Parent(s): b88edf5

Fix bugs (#3502)

Browse files

### What problem does this PR solve?

1. Remove unused code
2. Fix type mismatches in the NLP search and Infinity search interfaces
3. Fix the chunk list so that it gets all chunks of this user.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Signed-off-by: jinhai <haijin.chn@gmail.com>

agent/component/base.py CHANGED
@@ -17,13 +17,13 @@ from abc import ABC
17
  import builtins
18
  import json
19
  import os
 
20
  from functools import partial
21
  from typing import Tuple, Union
22
 
23
  import pandas as pd
24
 
25
  from agent import settings
26
- from agent.settings import flow_logger, DEBUG
27
 
28
  _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
29
  _DEPRECATED_PARAMS = "_deprecated_params"
@@ -480,7 +480,6 @@ class ComponentBase(ABC):
480
 
481
  upstream_outs = []
482
 
483
- if DEBUG: print(self.component_name, reversed_cpnts[::-1])
484
  for u in reversed_cpnts[::-1]:
485
  if self.get_component_name(u) in ["switch", "concentrator"]: continue
486
  if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
 
17
  import builtins
18
  import json
19
  import os
20
+ import logging
21
  from functools import partial
22
  from typing import Tuple, Union
23
 
24
  import pandas as pd
25
 
26
  from agent import settings
 
27
 
28
  _FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
29
  _DEPRECATED_PARAMS = "_deprecated_params"
 
480
 
481
  upstream_outs = []
482
 
 
483
  for u in reversed_cpnts[::-1]:
484
  if self.get_component_name(u) in ["switch", "concentrator"]: continue
485
  if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
graphrag/search.py CHANGED
@@ -23,7 +23,7 @@ from rag.nlp.search import Dealer
23
 
24
 
25
  class KGSearch(Dealer):
26
- def search(self, req, idxnm, kb_ids, emb_mdl, highlight=False):
27
  def merge_into_first(sres, title="") -> dict[str, str]:
28
  if not sres:
29
  return {}
 
23
 
24
 
25
  class KGSearch(Dealer):
26
+ def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
27
  def merge_into_first(sres, title="") -> dict[str, str]:
28
  if not sres:
29
  return {}
rag/utils/infinity_conn.py CHANGED
@@ -4,7 +4,7 @@ import re
4
  import json
5
  import time
6
  import infinity
7
- from infinity.common import ConflictType, InfinityException
8
  from infinity.index import IndexInfo, IndexType
9
  from infinity.connection_pool import ConnectionPool
10
  from rag import settings
@@ -22,6 +22,7 @@ from rag.utils.doc_store_conn import (
22
  OrderByExpr,
23
  )
24
 
 
25
  def equivalent_condition_to_str(condition: dict) -> str:
26
  assert "_id" not in condition
27
  cond = list()
@@ -65,7 +66,7 @@ class InfinityConnection(DocStoreConnection):
65
  self.connPool = connPool
66
  break
67
  except Exception as e:
68
- logging.warn(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
69
  time.sleep(5)
70
  if self.connPool is None:
71
  msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
@@ -168,7 +169,7 @@ class InfinityConnection(DocStoreConnection):
168
  self.connPool.release_conn(inf_conn)
169
  return True
170
  except Exception as e:
171
- logging.warn(f"INFINITY indexExist {str(e)}")
172
  return False
173
 
174
  """
@@ -176,16 +177,16 @@ class InfinityConnection(DocStoreConnection):
176
  """
177
 
178
  def search(
179
- self,
180
- selectFields: list[str],
181
- highlightFields: list[str],
182
- condition: dict,
183
- matchExprs: list[MatchExpr],
184
- orderBy: OrderByExpr,
185
- offset: int,
186
- limit: int,
187
- indexNames: str|list[str],
188
- knowledgebaseIds: list[str],
189
  ) -> list[dict] | pl.DataFrame:
190
  """
191
  TODO: Infinity doesn't provide highlight
@@ -219,8 +220,8 @@ class InfinityConnection(DocStoreConnection):
219
  minimum_should_match = "0%"
220
  if "minimum_should_match" in matchExpr.extra_options:
221
  minimum_should_match = (
222
- str(int(matchExpr.extra_options["minimum_should_match"] * 100))
223
- + "%"
224
  )
225
  matchExpr.extra_options.update(
226
  {"minimum_should_match": minimum_should_match}
@@ -234,10 +235,14 @@ class InfinityConnection(DocStoreConnection):
234
  for k, v in matchExpr.extra_options.items():
235
  if not isinstance(v, str):
236
  matchExpr.extra_options[k] = str(v)
 
 
237
  if orderBy.fields:
238
- order_by_expr_list = list()
239
  for order_field in orderBy.fields:
240
- order_by_expr_list.append((order_field[0], order_field[1] == 0))
 
 
 
241
 
242
  # Scatter search tables and gather the results
243
  for indexName in indexNames:
@@ -249,28 +254,32 @@ class InfinityConnection(DocStoreConnection):
249
  continue
250
  table_list.append(table_name)
251
  builder = table_instance.output(selectFields)
252
- for matchExpr in matchExprs:
253
- if isinstance(matchExpr, MatchTextExpr):
254
- fields = ",".join(matchExpr.fields)
255
- builder = builder.match_text(
256
- fields,
257
- matchExpr.matching_text,
258
- matchExpr.topn,
259
- matchExpr.extra_options,
260
- )
261
- elif isinstance(matchExpr, MatchDenseExpr):
262
- builder = builder.match_dense(
263
- matchExpr.vector_column_name,
264
- matchExpr.embedding_data,
265
- matchExpr.embedding_data_type,
266
- matchExpr.distance_type,
267
- matchExpr.topn,
268
- matchExpr.extra_options,
269
- )
270
- elif isinstance(matchExpr, FusionExpr):
271
- builder = builder.fusion(
272
- matchExpr.method, matchExpr.topn, matchExpr.fusion_params
273
- )
 
 
 
 
274
  if orderBy.fields:
275
  builder.sort(order_by_expr_list)
276
  builder.offset(offset).limit(limit)
@@ -282,7 +291,7 @@ class InfinityConnection(DocStoreConnection):
282
  return res
283
 
284
  def get(
285
- self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
286
  ) -> dict | None:
287
  inf_conn = self.connPool.get_conn()
288
  db_instance = inf_conn.get_database(self.dbName)
@@ -299,7 +308,7 @@ class InfinityConnection(DocStoreConnection):
299
  return res_fields.get(chunkId, None)
300
 
301
  def insert(
302
- self, documents: list[dict], indexName: str, knowledgebaseId: str
303
  ) -> list[str]:
304
  inf_conn = self.connPool.get_conn()
305
  db_instance = inf_conn.get_database(self.dbName)
@@ -341,7 +350,7 @@ class InfinityConnection(DocStoreConnection):
341
  return []
342
 
343
  def update(
344
- self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
345
  ) -> bool:
346
  # if 'position_list' in newValue:
347
  # logging.info(f"upsert position_list: {newValue['position_list']}")
@@ -430,7 +439,7 @@ class InfinityConnection(DocStoreConnection):
430
  flags=re.IGNORECASE | re.MULTILINE,
431
  )
432
  if not re.search(
433
- r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
434
  ):
435
  continue
436
  txts.append(t)
 
4
  import json
5
  import time
6
  import infinity
7
+ from infinity.common import ConflictType, InfinityException, SortType
8
  from infinity.index import IndexInfo, IndexType
9
  from infinity.connection_pool import ConnectionPool
10
  from rag import settings
 
22
  OrderByExpr,
23
  )
24
 
25
+
26
  def equivalent_condition_to_str(condition: dict) -> str:
27
  assert "_id" not in condition
28
  cond = list()
 
66
  self.connPool = connPool
67
  break
68
  except Exception as e:
69
+ logging.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
70
  time.sleep(5)
71
  if self.connPool is None:
72
  msg = f"Infinity {infinity_uri} didn't become healthy in 120s."
 
169
  self.connPool.release_conn(inf_conn)
170
  return True
171
  except Exception as e:
172
+ logging.warning(f"INFINITY indexExist {str(e)}")
173
  return False
174
 
175
  """
 
177
  """
178
 
179
  def search(
180
+ self,
181
+ selectFields: list[str],
182
+ highlightFields: list[str],
183
+ condition: dict,
184
+ matchExprs: list[MatchExpr],
185
+ orderBy: OrderByExpr,
186
+ offset: int,
187
+ limit: int,
188
+ indexNames: str | list[str],
189
+ knowledgebaseIds: list[str],
190
  ) -> list[dict] | pl.DataFrame:
191
  """
192
  TODO: Infinity doesn't provide highlight
 
220
  minimum_should_match = "0%"
221
  if "minimum_should_match" in matchExpr.extra_options:
222
  minimum_should_match = (
223
+ str(int(matchExpr.extra_options["minimum_should_match"] * 100))
224
+ + "%"
225
  )
226
  matchExpr.extra_options.update(
227
  {"minimum_should_match": minimum_should_match}
 
235
  for k, v in matchExpr.extra_options.items():
236
  if not isinstance(v, str):
237
  matchExpr.extra_options[k] = str(v)
238
+
239
+ order_by_expr_list = list()
240
  if orderBy.fields:
 
241
  for order_field in orderBy.fields:
242
+ if order_field[1] == 0:
243
+ order_by_expr_list.append((order_field[0], SortType.Asc))
244
+ else:
245
+ order_by_expr_list.append((order_field[0], SortType.Desc))
246
 
247
  # Scatter search tables and gather the results
248
  for indexName in indexNames:
 
254
  continue
255
  table_list.append(table_name)
256
  builder = table_instance.output(selectFields)
257
+ if len(matchExprs) > 0:
258
+ for matchExpr in matchExprs:
259
+ if isinstance(matchExpr, MatchTextExpr):
260
+ fields = ",".join(matchExpr.fields)
261
+ builder = builder.match_text(
262
+ fields,
263
+ matchExpr.matching_text,
264
+ matchExpr.topn,
265
+ matchExpr.extra_options,
266
+ )
267
+ elif isinstance(matchExpr, MatchDenseExpr):
268
+ builder = builder.match_dense(
269
+ matchExpr.vector_column_name,
270
+ matchExpr.embedding_data,
271
+ matchExpr.embedding_data_type,
272
+ matchExpr.distance_type,
273
+ matchExpr.topn,
274
+ matchExpr.extra_options,
275
+ )
276
+ elif isinstance(matchExpr, FusionExpr):
277
+ builder = builder.fusion(
278
+ matchExpr.method, matchExpr.topn, matchExpr.fusion_params
279
+ )
280
+ else:
281
+ if len(filter_cond) > 0:
282
+ builder.filter(filter_cond)
283
  if orderBy.fields:
284
  builder.sort(order_by_expr_list)
285
  builder.offset(offset).limit(limit)
 
291
  return res
292
 
293
  def get(
294
+ self, chunkId: str, indexName: str, knowledgebaseIds: list[str]
295
  ) -> dict | None:
296
  inf_conn = self.connPool.get_conn()
297
  db_instance = inf_conn.get_database(self.dbName)
 
308
  return res_fields.get(chunkId, None)
309
 
310
  def insert(
311
+ self, documents: list[dict], indexName: str, knowledgebaseId: str
312
  ) -> list[str]:
313
  inf_conn = self.connPool.get_conn()
314
  db_instance = inf_conn.get_database(self.dbName)
 
350
  return []
351
 
352
  def update(
353
+ self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str
354
  ) -> bool:
355
  # if 'position_list' in newValue:
356
  # logging.info(f"upsert position_list: {newValue['position_list']}")
 
439
  flags=re.IGNORECASE | re.MULTILINE,
440
  )
441
  if not re.search(
442
+ r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE
443
  ):
444
  continue
445
  txts.append(t)