jin commited on
Commit
bb3f996
·
1 Parent(s): 7ff41e3

Update oracle_impl.py

Browse files
Files changed (1) hide show
  1. lightrag/kg/oracle_impl.py +167 -110
lightrag/kg/oracle_impl.py CHANGED
@@ -114,9 +114,7 @@ class OracleDB:
114
 
115
  logger.info("Finished check all tables in Oracle database")
116
 
117
- async def query(
118
- self, sql: str, params: dict = None, multirows: bool = False
119
- ) -> Union[dict, None]:
120
  async with self.pool.acquire() as connection:
121
  connection.inputtypehandler = self.input_type_handler
122
  connection.outputtypehandler = self.output_type_handler
@@ -175,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
175
 
176
  async def get_by_id(self, id: str) -> Union[dict, None]:
177
  """根据 id 获取 doc_full 数据."""
178
- SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
179
- workspace=self.db.workspace, id=id
180
- )
181
  # print("get_by_id:"+SQL)
182
- res = await self.db.query(SQL)
183
  if res:
184
  data = res # {"data":res}
185
  # print (data)
@@ -190,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
190
  # Query by id
191
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
192
  """根据 id 获取 doc_chunks 数据"""
193
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
194
- workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
195
- )
196
- # print("get_by_ids:"+SQL)
197
- res = await self.db.query(SQL, multirows=True)
198
  if res:
199
  data = res # [{"data":i} for i in res]
200
  # print(data)
@@ -204,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
204
 
205
  async def filter_keys(self, keys: list[str]) -> set[str]:
206
  """过滤掉重复内容"""
207
- SQL = SQL_TEMPLATES["filter_keys"].format(
208
- table_name=N_T[self.namespace],
209
- workspace=self.db.workspace,
210
- ids=",".join([f"'{k}'" for k in keys]),
211
- )
212
- res = await self.db.query(SQL, multirows=True)
 
 
 
 
213
  data = None
214
  if res:
215
  exist_keys = [key["id"] for key in res]
@@ -246,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
246
  d["__vector__"] = embeddings[i]
247
  # print(list_data)
248
  for item in list_data:
249
- merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
250
-
251
- values = [
252
- item["__id__"],
253
- item["content"],
254
- self.db.workspace,
255
- item["tokens"],
256
- item["chunk_order_index"],
257
- item["full_doc_id"],
258
- item["__vector__"],
259
- ]
260
  # print(merge_sql)
261
- await self.db.execute(merge_sql, values)
262
 
263
  if self.namespace == "full_docs":
264
  for k, v in self._data.items():
265
  # values.clear()
266
- merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
267
- check_id=k,
268
- )
269
- values = [k, self._data[k]["content"], self.db.workspace]
 
 
 
270
  # print(merge_sql)
271
- await self.db.execute(merge_sql, values)
272
  return left_data
273
 
274
  async def index_done_callback(self):
@@ -298,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
298
  # 转换精度
299
  dtype = str(embedding.dtype).upper()
300
  dimension = embedding.shape[0]
301
- embedding_string = ", ".join(map(str, embedding.tolist()))
302
-
303
- SQL = SQL_TEMPLATES[self.namespace].format(
304
- embedding_string=embedding_string,
305
- dimension=dimension,
306
- dtype=dtype,
307
- workspace=self.db.workspace,
308
- top_k=top_k,
309
- better_than_threshold=self.cosine_better_than_threshold,
310
- )
311
  # print(SQL)
312
- results = await self.db.query(SQL, multirows=True)
313
  # print("vector search result:",results)
314
  return results
315
 
@@ -344,22 +346,18 @@ class OracleGraphStorage(BaseGraphStorage):
344
  )
345
  embeddings = np.concatenate(embeddings_list)
346
  content_vector = embeddings[0]
347
- merge_sql = SQL_TEMPLATES["merge_node"].format(
348
- workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
349
- )
 
 
 
 
 
 
 
350
  # print(merge_sql)
351
- await self.db.execute(
352
- merge_sql,
353
- [
354
- self.db.workspace,
355
- entity_name,
356
- entity_type,
357
- description,
358
- source_id,
359
- content,
360
- content_vector,
361
- ],
362
- )
363
  # self._graph.add_node(node_id, **node_data)
364
 
365
  async def upsert_edge(
@@ -373,6 +371,8 @@ class OracleGraphStorage(BaseGraphStorage):
373
  keywords = edge_data["keywords"]
374
  description = edge_data["description"]
375
  source_chunk_id = edge_data["source_id"]
 
 
376
  content = keywords + source_name + target_name + description
377
  contents = [content]
378
  batches = [
@@ -384,27 +384,20 @@ class OracleGraphStorage(BaseGraphStorage):
384
  )
385
  embeddings = np.concatenate(embeddings_list)
386
  content_vector = embeddings[0]
387
- merge_sql = SQL_TEMPLATES["merge_edge"].format(
388
- workspace=self.db.workspace,
389
- source_name=source_name,
390
- target_name=target_name,
391
- source_chunk_id=source_chunk_id,
392
- )
 
 
 
 
 
 
393
  # print(merge_sql)
394
- await self.db.execute(
395
- merge_sql,
396
- [
397
- self.db.workspace,
398
- source_name,
399
- target_name,
400
- weight,
401
- keywords,
402
- description,
403
- source_chunk_id,
404
- content,
405
- content_vector,
406
- ],
407
- )
408
  # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
409
 
410
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -434,12 +427,14 @@ class OracleGraphStorage(BaseGraphStorage):
434
  #################### query method #################
435
  async def has_node(self, node_id: str) -> bool:
436
  """根据节点id检查节点是否存在"""
437
- SQL = SQL_TEMPLATES["has_node"].format(
438
- workspace=self.db.workspace, node_id=node_id
439
- )
 
 
440
  # print(SQL)
441
  # print(self.db.workspace, node_id)
442
- res = await self.db.query(SQL)
443
  if res:
444
  # print("Node exist!",res)
445
  return True
@@ -449,13 +444,14 @@ class OracleGraphStorage(BaseGraphStorage):
449
 
450
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
451
  """根据源和��标节点id检查边是否存在"""
452
- SQL = SQL_TEMPLATES["has_edge"].format(
453
- workspace=self.db.workspace,
454
- source_node_id=source_node_id,
455
- target_node_id=target_node_id,
456
- )
 
457
  # print(SQL)
458
- res = await self.db.query(SQL)
459
  if res:
460
  # print("Edge exist!",res)
461
  return True
@@ -465,11 +461,13 @@ class OracleGraphStorage(BaseGraphStorage):
465
 
466
  async def node_degree(self, node_id: str) -> int:
467
  """根据节点id获取节点的度"""
468
- SQL = SQL_TEMPLATES["node_degree"].format(
469
- workspace=self.db.workspace, node_id=node_id
470
- )
 
 
471
  # print(SQL)
472
- res = await self.db.query(SQL)
473
  if res:
474
  # print("Node degree",res["degree"])
475
  return res["degree"]
@@ -485,12 +483,14 @@ class OracleGraphStorage(BaseGraphStorage):
485
 
486
  async def get_node(self, node_id: str) -> Union[dict, None]:
487
  """根据节点id获取节点数据"""
488
- SQL = SQL_TEMPLATES["get_node"].format(
489
- workspace=self.db.workspace, node_id=node_id
490
- )
 
 
491
  # print(self.db.workspace, node_id)
492
  # print(SQL)
493
- res = await self.db.query(SQL)
494
  if res:
495
  # print("Get node!",self.db.workspace, node_id,res)
496
  return res
@@ -502,12 +502,13 @@ class OracleGraphStorage(BaseGraphStorage):
502
  self, source_node_id: str, target_node_id: str
503
  ) -> Union[dict, None]:
504
  """根据源和目标节点id获取边"""
505
- SQL = SQL_TEMPLATES["get_edge"].format(
506
- workspace=self.db.workspace,
507
- source_node_id=source_node_id,
508
- target_node_id=target_node_id,
509
- )
510
- res = await self.db.query(SQL)
 
511
  if res:
512
  # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
513
  return res
@@ -518,10 +519,12 @@ class OracleGraphStorage(BaseGraphStorage):
518
  async def get_node_edges(self, source_node_id: str):
519
  """根据节点id获取节点的所有边"""
520
  if await self.has_node(source_node_id):
521
- SQL = SQL_TEMPLATES["get_node_edges"].format(
522
- workspace=self.db.workspace, source_node_id=source_node_id
523
- )
524
- res = await self.db.query(sql=SQL, multirows=True)
 
 
525
  if res:
526
  data = [(i["source_name"], i["target_name"]) for i in res]
527
  # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -529,7 +532,29 @@ class OracleGraphStorage(BaseGraphStorage):
529
  else:
530
  # print("Node Edge not exist!",self.db.workspace, source_node_id)
531
  return []
 
 
 
 
 
 
 
 
532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
 
534
  N_T = {
535
  "full_docs": "LIGHTRAG_DOC_FULL",
@@ -701,5 +726,37 @@ SQL_TEMPLATES = {
701
  ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
702
  WHEN NOT MATCHED THEN
703
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
704
- values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
  }
 
114
 
115
  logger.info("Finished check all tables in Oracle database")
116
 
117
+ async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
 
 
118
  async with self.pool.acquire() as connection:
119
  connection.inputtypehandler = self.input_type_handler
120
  connection.outputtypehandler = self.output_type_handler
 
173
 
174
  async def get_by_id(self, id: str) -> Union[dict, None]:
175
  """根据 id 获取 doc_full 数据."""
176
+ SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
177
+ params = {"workspace":self.db.workspace, "id":id}
 
178
  # print("get_by_id:"+SQL)
179
+ res = await self.db.query(SQL,params)
180
  if res:
181
  data = res # {"data":res}
182
  # print (data)
 
187
  # Query by id
188
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
189
  """根据 id 获取 doc_chunks 数据"""
190
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
191
+ params = {"workspace":self.db.workspace}
192
+ #print("get_by_ids:"+SQL)
193
+ #print(params)
194
+ res = await self.db.query(SQL,params, multirows=True)
195
  if res:
196
  data = res # [{"data":i} for i in res]
197
  # print(data)
 
201
 
202
  async def filter_keys(self, keys: list[str]) -> set[str]:
203
  """过滤掉重复内容"""
204
+ SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
205
+ ids=",".join([f"'{id}'" for id in keys]))
206
+ params = {"workspace":self.db.workspace}
207
+ try:
208
+ await self.db.query(SQL, params)
209
+ except Exception as e:
210
+ logger.error(f"Oracle database error: {e}")
211
+ print(SQL)
212
+ print(params)
213
+ res = await self.db.query(SQL, params,multirows=True)
214
  data = None
215
  if res:
216
  exist_keys = [key["id"] for key in res]
 
247
  d["__vector__"] = embeddings[i]
248
  # print(list_data)
249
  for item in list_data:
250
+ merge_sql = SQL_TEMPLATES["merge_chunk"]
251
+ data = {"check_id":item["__id__"],
252
+ "id":item["__id__"],
253
+ "content":item["content"],
254
+ "workspace":self.db.workspace,
255
+ "tokens":item["tokens"],
256
+ "chunk_order_index":item["chunk_order_index"],
257
+ "full_doc_id":item["full_doc_id"],
258
+ "content_vector":item["__vector__"]
259
+ }
 
260
  # print(merge_sql)
261
+ await self.db.execute(merge_sql, data)
262
 
263
  if self.namespace == "full_docs":
264
  for k, v in self._data.items():
265
  # values.clear()
266
+ merge_sql = SQL_TEMPLATES["merge_doc_full"]
267
+ data = {
268
+ "check_id":k,
269
+ "id":k,
270
+ "content":v["content"],
271
+ "workspace":self.db.workspace
272
+ }
273
  # print(merge_sql)
274
+ await self.db.execute(merge_sql, data)
275
  return left_data
276
 
277
  async def index_done_callback(self):
 
301
  # 转换精度
302
  dtype = str(embedding.dtype).upper()
303
  dimension = embedding.shape[0]
304
+ embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
305
+
306
+ SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
307
+ params = {
308
+ "embedding_string": embedding_string,
309
+ "workspace": self.db.workspace,
310
+ "top_k": top_k,
311
+ "better_than_threshold": self.cosine_better_than_threshold,
312
+ }
 
313
  # print(SQL)
314
+ results = await self.db.query(SQL,params=params, multirows=True)
315
  # print("vector search result:",results)
316
  return results
317
 
 
346
  )
347
  embeddings = np.concatenate(embeddings_list)
348
  content_vector = embeddings[0]
349
+ merge_sql = SQL_TEMPLATES["merge_node"]
350
+ data = {
351
+ "workspace":self.db.workspace,
352
+ "name":entity_name,
353
+ "entity_type":entity_type,
354
+ "description":description,
355
+ "source_chunk_id":source_id,
356
+ "content":content,
357
+ "content_vector":content_vector
358
+ }
359
  # print(merge_sql)
360
+ await self.db.execute(merge_sql,data)
 
 
 
 
 
 
 
 
 
 
 
361
  # self._graph.add_node(node_id, **node_data)
362
 
363
  async def upsert_edge(
 
371
  keywords = edge_data["keywords"]
372
  description = edge_data["description"]
373
  source_chunk_id = edge_data["source_id"]
374
+ logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
375
+
376
  content = keywords + source_name + target_name + description
377
  contents = [content]
378
  batches = [
 
384
  )
385
  embeddings = np.concatenate(embeddings_list)
386
  content_vector = embeddings[0]
387
+ merge_sql = SQL_TEMPLATES["merge_edge"]
388
+ data = {
389
+ "workspace":self.db.workspace,
390
+ "source_name":source_name,
391
+ "target_name":target_name,
392
+ "weight":weight,
393
+ "keywords":keywords,
394
+ "description":description,
395
+ "source_chunk_id":source_chunk_id,
396
+ "content":content,
397
+ "content_vector":content_vector
398
+ }
399
  # print(merge_sql)
400
+ await self.db.execute(merge_sql,data)
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
402
 
403
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
 
427
  #################### query method #################
428
  async def has_node(self, node_id: str) -> bool:
429
  """根据节点id检查节点是否存在"""
430
+ SQL = SQL_TEMPLATES["has_node"]
431
+ params = {
432
+ "workspace":self.db.workspace,
433
+ "node_id":node_id
434
+ }
435
  # print(SQL)
436
  # print(self.db.workspace, node_id)
437
+ res = await self.db.query(SQL,params)
438
  if res:
439
  # print("Node exist!",res)
440
  return True
 
444
 
445
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
446
  """根据源和��标节点id检查边是否存在"""
447
+ SQL = SQL_TEMPLATES["has_edge"]
448
+ params = {
449
+ "workspace":self.db.workspace,
450
+ "source_node_id":source_node_id,
451
+ "target_node_id":target_node_id
452
+ }
453
  # print(SQL)
454
+ res = await self.db.query(SQL,params)
455
  if res:
456
  # print("Edge exist!",res)
457
  return True
 
461
 
462
  async def node_degree(self, node_id: str) -> int:
463
  """根据节点id获取节点的度"""
464
+ SQL = SQL_TEMPLATES["node_degree"]
465
+ params = {
466
+ "workspace":self.db.workspace,
467
+ "node_id":node_id
468
+ }
469
  # print(SQL)
470
+ res = await self.db.query(SQL,params)
471
  if res:
472
  # print("Node degree",res["degree"])
473
  return res["degree"]
 
483
 
484
  async def get_node(self, node_id: str) -> Union[dict, None]:
485
  """根据节点id获取节点数据"""
486
+ SQL = SQL_TEMPLATES["get_node"]
487
+ params = {
488
+ "workspace":self.db.workspace,
489
+ "node_id":node_id
490
+ }
491
  # print(self.db.workspace, node_id)
492
  # print(SQL)
493
+ res = await self.db.query(SQL,params)
494
  if res:
495
  # print("Get node!",self.db.workspace, node_id,res)
496
  return res
 
502
  self, source_node_id: str, target_node_id: str
503
  ) -> Union[dict, None]:
504
  """根据源和目标节点id获取边"""
505
+ SQL = SQL_TEMPLATES["get_edge"]
506
+ params = {
507
+ "workspace":self.db.workspace,
508
+ "source_node_id":source_node_id,
509
+ "target_node_id":target_node_id
510
+ }
511
+ res = await self.db.query(SQL,params)
512
  if res:
513
  # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
514
  return res
 
519
  async def get_node_edges(self, source_node_id: str):
520
  """根据节点id获取节点的所有边"""
521
  if await self.has_node(source_node_id):
522
+ SQL = SQL_TEMPLATES["get_node_edges"]
523
+ params = {
524
+ "workspace":self.db.workspace,
525
+ "source_node_id":source_node_id
526
+ }
527
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
528
  if res:
529
  data = [(i["source_name"], i["target_name"]) for i in res]
530
  # print("Get node edge!",self.db.workspace, source_node_id,data)
 
532
  else:
533
  # print("Node Edge not exist!",self.db.workspace, source_node_id)
534
  return []
535
+
536
+ async def get_all_nodes(self, limit: int):
537
+ """查询所有节点"""
538
+ SQL = SQL_TEMPLATES["get_all_nodes"]
539
+ params = {"workspace":self.db.workspace, "limit":str(limit)}
540
+ res = await self.db.query(sql=SQL,params=params, multirows=True)
541
+ if res:
542
+ return res
543
 
544
+ async def get_all_edges(self, limit: int):
545
+ """查询所有边"""
546
+ SQL = SQL_TEMPLATES["get_all_edges"]
547
+ params = {"workspace":self.db.workspace, "limit":str(limit)}
548
+ res = await self.db.query(sql=SQL,params=params, multirows=True)
549
+ if res:
550
+ return res
551
+
552
+ async def get_statistics(self):
553
+ SQL = SQL_TEMPLATES["get_statistics"]
554
+ params = {"workspace":self.db.workspace}
555
+ res = await self.db.query(sql=SQL,params=params, multirows=True)
556
+ if res:
557
+ return res
558
 
559
  N_T = {
560
  "full_docs": "LIGHTRAG_DOC_FULL",
 
726
  ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
727
  WHEN NOT MATCHED THEN
728
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
729
+ values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
730
+ "get_all_nodes":"""WITH t0 AS (
731
+ SELECT name AS id, entity_type AS label, entity_type, description,
732
+ '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
733
+ FROM lightrag_graph_nodes
734
+ WHERE workspace = :workspace
735
+ ORDER BY createtime DESC fetch first :limit rows only
736
+ ), t1 AS (
737
+ SELECT t0.id, source_chunk_id
738
+ FROM t0, JSON_TABLE ( source_chunk_ids, '$[*]' COLUMNS ( source_chunk_id PATH '$' ) )
739
+ ), t2 AS (
740
+ SELECT t1.id, LISTAGG(t2.content, '\n') content
741
+ FROM t1 LEFT JOIN lightrag_doc_chunks t2 ON t1.source_chunk_id = t2.id
742
+ GROUP BY t1.id
743
+ )
744
+ SELECT t0.id, label, entity_type, description, t2.content
745
+ FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
746
+ "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
747
+ t1.weight,t1.DESCRIPTION,t2.content
748
+ FROM LIGHTRAG_GRAPH_EDGES t1
749
+ LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
750
+ WHERE t1.workspace=:workspace
751
+ order by t1.CREATETIME DESC
752
+ fetch first :limit rows only""",
753
+ "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
754
+ count(distinct CASE WHEN type='edge' THEN id END) as edges_count
755
+ FROM (
756
+ select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
757
+ MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
758
+ UNION
759
+ select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
760
+ MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
761
+ )""",
762
  }