jin committed on
Commit
ac222b0
·
1 Parent(s): bb3f996

Update oracle_impl.py

Browse files
Files changed (1) hide show
  1. lightrag/kg/oracle_impl.py +91 -94
lightrag/kg/oracle_impl.py CHANGED
@@ -114,7 +114,9 @@ class OracleDB:
114
 
115
  logger.info("Finished check all tables in Oracle database")
116
 
117
- async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
 
 
118
  async with self.pool.acquire() as connection:
119
  connection.inputtypehandler = self.input_type_handler
120
  connection.outputtypehandler = self.output_type_handler
@@ -174,9 +176,9 @@ class OracleKVStorage(BaseKVStorage):
174
  async def get_by_id(self, id: str) -> Union[dict, None]:
175
  """根据 id 获取 doc_full 数据."""
176
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
177
- params = {"workspace":self.db.workspace, "id":id}
178
  # print("get_by_id:"+SQL)
179
- res = await self.db.query(SQL,params)
180
  if res:
181
  data = res # {"data":res}
182
  # print (data)
@@ -187,11 +189,13 @@ class OracleKVStorage(BaseKVStorage):
187
  # Query by id
188
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
189
  """根据 id 获取 doc_chunks 数据"""
190
- SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
191
- params = {"workspace":self.db.workspace}
192
- #print("get_by_ids:"+SQL)
193
- #print(params)
194
- res = await self.db.query(SQL,params, multirows=True)
 
 
195
  if res:
196
  data = res # [{"data":i} for i in res]
197
  # print(data)
@@ -201,16 +205,17 @@ class OracleKVStorage(BaseKVStorage):
201
 
202
  async def filter_keys(self, keys: list[str]) -> set[str]:
203
  """过滤掉重复内容"""
204
- SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
205
- ids=",".join([f"'{id}'" for id in keys]))
206
- params = {"workspace":self.db.workspace}
 
207
  try:
208
  await self.db.query(SQL, params)
209
  except Exception as e:
210
  logger.error(f"Oracle database error: {e}")
211
  print(SQL)
212
  print(params)
213
- res = await self.db.query(SQL, params,multirows=True)
214
  data = None
215
  if res:
216
  exist_keys = [key["id"] for key in res]
@@ -248,15 +253,16 @@ class OracleKVStorage(BaseKVStorage):
248
  # print(list_data)
249
  for item in list_data:
250
  merge_sql = SQL_TEMPLATES["merge_chunk"]
251
- data = {"check_id":item["__id__"],
252
- "id":item["__id__"],
253
- "content":item["content"],
254
- "workspace":self.db.workspace,
255
- "tokens":item["tokens"],
256
- "chunk_order_index":item["chunk_order_index"],
257
- "full_doc_id":item["full_doc_id"],
258
- "content_vector":item["__vector__"]
259
- }
 
260
  # print(merge_sql)
261
  await self.db.execute(merge_sql, data)
262
 
@@ -265,11 +271,11 @@ class OracleKVStorage(BaseKVStorage):
265
  # values.clear()
266
  merge_sql = SQL_TEMPLATES["merge_doc_full"]
267
  data = {
268
- "check_id":k,
269
- "id":k,
270
- "content":v["content"],
271
- "workspace":self.db.workspace
272
- }
273
  # print(merge_sql)
274
  await self.db.execute(merge_sql, data)
275
  return left_data
@@ -301,7 +307,7 @@ class OracleVectorDBStorage(BaseVectorStorage):
301
  # 转换精度
302
  dtype = str(embedding.dtype).upper()
303
  dimension = embedding.shape[0]
304
- embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
305
 
306
  SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
307
  params = {
@@ -309,9 +315,9 @@ class OracleVectorDBStorage(BaseVectorStorage):
309
  "workspace": self.db.workspace,
310
  "top_k": top_k,
311
  "better_than_threshold": self.cosine_better_than_threshold,
312
- }
313
  # print(SQL)
314
- results = await self.db.query(SQL,params=params, multirows=True)
315
  # print("vector search result:",results)
316
  return results
317
 
@@ -346,18 +352,18 @@ class OracleGraphStorage(BaseGraphStorage):
346
  )
347
  embeddings = np.concatenate(embeddings_list)
348
  content_vector = embeddings[0]
349
- merge_sql = SQL_TEMPLATES["merge_node"]
350
  data = {
351
- "workspace":self.db.workspace,
352
- "name":entity_name,
353
- "entity_type":entity_type,
354
- "description":description,
355
- "source_chunk_id":source_id,
356
- "content":content,
357
- "content_vector":content_vector
358
- }
359
  # print(merge_sql)
360
- await self.db.execute(merge_sql,data)
361
  # self._graph.add_node(node_id, **node_data)
362
 
363
  async def upsert_edge(
@@ -371,7 +377,9 @@ class OracleGraphStorage(BaseGraphStorage):
371
  keywords = edge_data["keywords"]
372
  description = edge_data["description"]
373
  source_chunk_id = edge_data["source_id"]
374
- logger.debug(f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}")
 
 
375
 
376
  content = keywords + source_name + target_name + description
377
  contents = [content]
@@ -384,20 +392,20 @@ class OracleGraphStorage(BaseGraphStorage):
384
  )
385
  embeddings = np.concatenate(embeddings_list)
386
  content_vector = embeddings[0]
387
- merge_sql = SQL_TEMPLATES["merge_edge"]
388
  data = {
389
- "workspace":self.db.workspace,
390
- "source_name":source_name,
391
- "target_name":target_name,
392
- "weight":weight,
393
- "keywords":keywords,
394
- "description":description,
395
- "source_chunk_id":source_chunk_id,
396
- "content":content,
397
- "content_vector":content_vector
398
- }
399
  # print(merge_sql)
400
- await self.db.execute(merge_sql,data)
401
  # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
402
 
403
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -428,13 +436,10 @@ class OracleGraphStorage(BaseGraphStorage):
428
  async def has_node(self, node_id: str) -> bool:
429
  """根据节点id检查节点是否存在"""
430
  SQL = SQL_TEMPLATES["has_node"]
431
- params = {
432
- "workspace":self.db.workspace,
433
- "node_id":node_id
434
- }
435
  # print(SQL)
436
  # print(self.db.workspace, node_id)
437
- res = await self.db.query(SQL,params)
438
  if res:
439
  # print("Node exist!",res)
440
  return True
@@ -446,12 +451,12 @@ class OracleGraphStorage(BaseGraphStorage):
446
  """根据源和目标节点id检查边是否存在"""
447
  SQL = SQL_TEMPLATES["has_edge"]
448
  params = {
449
- "workspace":self.db.workspace,
450
- "source_node_id":source_node_id,
451
- "target_node_id":target_node_id
452
- }
453
  # print(SQL)
454
- res = await self.db.query(SQL,params)
455
  if res:
456
  # print("Edge exist!",res)
457
  return True
@@ -462,12 +467,9 @@ class OracleGraphStorage(BaseGraphStorage):
462
  async def node_degree(self, node_id: str) -> int:
463
  """根据节���id获取节点的度"""
464
  SQL = SQL_TEMPLATES["node_degree"]
465
- params = {
466
- "workspace":self.db.workspace,
467
- "node_id":node_id
468
- }
469
  # print(SQL)
470
- res = await self.db.query(SQL,params)
471
  if res:
472
  # print("Node degree",res["degree"])
473
  return res["degree"]
@@ -484,13 +486,10 @@ class OracleGraphStorage(BaseGraphStorage):
484
  async def get_node(self, node_id: str) -> Union[dict, None]:
485
  """根据节点id获取节点数据"""
486
  SQL = SQL_TEMPLATES["get_node"]
487
- params = {
488
- "workspace":self.db.workspace,
489
- "node_id":node_id
490
- }
491
  # print(self.db.workspace, node_id)
492
  # print(SQL)
493
- res = await self.db.query(SQL,params)
494
  if res:
495
  # print("Get node!",self.db.workspace, node_id,res)
496
  return res
@@ -504,11 +503,11 @@ class OracleGraphStorage(BaseGraphStorage):
504
  """根据源和目标节点id获取边"""
505
  SQL = SQL_TEMPLATES["get_edge"]
506
  params = {
507
- "workspace":self.db.workspace,
508
- "source_node_id":source_node_id,
509
- "target_node_id":target_node_id
510
- }
511
- res = await self.db.query(SQL,params)
512
  if res:
513
  # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
514
  return res
@@ -520,10 +519,7 @@ class OracleGraphStorage(BaseGraphStorage):
520
  """根据节点id获取节点的所有边"""
521
  if await self.has_node(source_node_id):
522
  SQL = SQL_TEMPLATES["get_node_edges"]
523
- params = {
524
- "workspace":self.db.workspace,
525
- "source_node_id":source_node_id
526
- }
527
  res = await self.db.query(sql=SQL, params=params, multirows=True)
528
  if res:
529
  data = [(i["source_name"], i["target_name"]) for i in res]
@@ -532,30 +528,31 @@ class OracleGraphStorage(BaseGraphStorage):
532
  else:
533
  # print("Node Edge not exist!",self.db.workspace, source_node_id)
534
  return []
535
-
536
  async def get_all_nodes(self, limit: int):
537
  """查询所有节点"""
538
  SQL = SQL_TEMPLATES["get_all_nodes"]
539
- params = {"workspace":self.db.workspace, "limit":str(limit)}
540
- res = await self.db.query(sql=SQL,params=params, multirows=True)
541
  if res:
542
  return res
543
 
544
  async def get_all_edges(self, limit: int):
545
  """查询所有边"""
546
  SQL = SQL_TEMPLATES["get_all_edges"]
547
- params = {"workspace":self.db.workspace, "limit":str(limit)}
548
- res = await self.db.query(sql=SQL,params=params, multirows=True)
549
  if res:
550
  return res
551
-
552
  async def get_statistics(self):
553
  SQL = SQL_TEMPLATES["get_statistics"]
554
- params = {"workspace":self.db.workspace}
555
- res = await self.db.query(sql=SQL,params=params, multirows=True)
556
  if res:
557
  return res
558
 
 
559
  N_T = {
560
  "full_docs": "LIGHTRAG_DOC_FULL",
561
  "text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -727,7 +724,7 @@ SQL_TEMPLATES = {
727
  WHEN NOT MATCHED THEN
728
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
729
  values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
730
- "get_all_nodes":"""WITH t0 AS (
731
  SELECT name AS id, entity_type AS label, entity_type, description,
732
  '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
733
  FROM lightrag_graph_nodes
@@ -743,20 +740,20 @@ SQL_TEMPLATES = {
743
  )
744
  SELECT t0.id, label, entity_type, description, t2.content
745
  FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
746
- "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
747
  t1.weight,t1.DESCRIPTION,t2.content
748
  FROM LIGHTRAG_GRAPH_EDGES t1
749
  LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
750
  WHERE t1.workspace=:workspace
751
  order by t1.CREATETIME DESC
752
  fetch first :limit rows only""",
753
- "get_statistics":"""select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
754
  count(distinct CASE WHEN type='edge' THEN id END) as edges_count
755
  FROM (
756
- select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
757
  MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
758
  UNION
759
- select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
760
  MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
761
  )""",
762
  }
 
114
 
115
  logger.info("Finished check all tables in Oracle database")
116
 
117
+ async def query(
118
+ self, sql: str, params: dict = None, multirows: bool = False
119
+ ) -> Union[dict, None]:
120
  async with self.pool.acquire() as connection:
121
  connection.inputtypehandler = self.input_type_handler
122
  connection.outputtypehandler = self.output_type_handler
 
176
  async def get_by_id(self, id: str) -> Union[dict, None]:
177
  """根据 id 获取 doc_full 数据."""
178
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
179
+ params = {"workspace": self.db.workspace, "id": id}
180
  # print("get_by_id:"+SQL)
181
+ res = await self.db.query(SQL, params)
182
  if res:
183
  data = res # {"data":res}
184
  # print (data)
 
189
  # Query by id
190
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
191
  """根据 id 获取 doc_chunks 数据"""
192
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
193
+ ids=",".join([f"'{id}'" for id in ids])
194
+ )
195
+ params = {"workspace": self.db.workspace}
196
+ # print("get_by_ids:"+SQL)
197
+ # print(params)
198
+ res = await self.db.query(SQL, params, multirows=True)
199
  if res:
200
  data = res # [{"data":i} for i in res]
201
  # print(data)
 
205
 
206
  async def filter_keys(self, keys: list[str]) -> set[str]:
207
  """过滤掉重复内容"""
208
+ SQL = SQL_TEMPLATES["filter_keys"].format(
209
+ table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
210
+ )
211
+ params = {"workspace": self.db.workspace}
212
  try:
213
  await self.db.query(SQL, params)
214
  except Exception as e:
215
  logger.error(f"Oracle database error: {e}")
216
  print(SQL)
217
  print(params)
218
+ res = await self.db.query(SQL, params, multirows=True)
219
  data = None
220
  if res:
221
  exist_keys = [key["id"] for key in res]
 
253
  # print(list_data)
254
  for item in list_data:
255
  merge_sql = SQL_TEMPLATES["merge_chunk"]
256
+ data = {
257
+ "check_id": item["__id__"],
258
+ "id": item["__id__"],
259
+ "content": item["content"],
260
+ "workspace": self.db.workspace,
261
+ "tokens": item["tokens"],
262
+ "chunk_order_index": item["chunk_order_index"],
263
+ "full_doc_id": item["full_doc_id"],
264
+ "content_vector": item["__vector__"],
265
+ }
266
  # print(merge_sql)
267
  await self.db.execute(merge_sql, data)
268
 
 
271
  # values.clear()
272
  merge_sql = SQL_TEMPLATES["merge_doc_full"]
273
  data = {
274
+ "check_id": k,
275
+ "id": k,
276
+ "content": v["content"],
277
+ "workspace": self.db.workspace,
278
+ }
279
  # print(merge_sql)
280
  await self.db.execute(merge_sql, data)
281
  return left_data
 
307
  # 转换精度
308
  dtype = str(embedding.dtype).upper()
309
  dimension = embedding.shape[0]
310
+ embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
311
 
312
  SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
313
  params = {
 
315
  "workspace": self.db.workspace,
316
  "top_k": top_k,
317
  "better_than_threshold": self.cosine_better_than_threshold,
318
+ }
319
  # print(SQL)
320
+ results = await self.db.query(SQL, params=params, multirows=True)
321
  # print("vector search result:",results)
322
  return results
323
 
 
352
  )
353
  embeddings = np.concatenate(embeddings_list)
354
  content_vector = embeddings[0]
355
+ merge_sql = SQL_TEMPLATES["merge_node"]
356
  data = {
357
+ "workspace": self.db.workspace,
358
+ "name": entity_name,
359
+ "entity_type": entity_type,
360
+ "description": description,
361
+ "source_chunk_id": source_id,
362
+ "content": content,
363
+ "content_vector": content_vector,
364
+ }
365
  # print(merge_sql)
366
+ await self.db.execute(merge_sql, data)
367
  # self._graph.add_node(node_id, **node_data)
368
 
369
  async def upsert_edge(
 
377
  keywords = edge_data["keywords"]
378
  description = edge_data["description"]
379
  source_chunk_id = edge_data["source_id"]
380
+ logger.debug(
381
+ f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
382
+ )
383
 
384
  content = keywords + source_name + target_name + description
385
  contents = [content]
 
392
  )
393
  embeddings = np.concatenate(embeddings_list)
394
  content_vector = embeddings[0]
395
+ merge_sql = SQL_TEMPLATES["merge_edge"]
396
  data = {
397
+ "workspace": self.db.workspace,
398
+ "source_name": source_name,
399
+ "target_name": target_name,
400
+ "weight": weight,
401
+ "keywords": keywords,
402
+ "description": description,
403
+ "source_chunk_id": source_chunk_id,
404
+ "content": content,
405
+ "content_vector": content_vector,
406
+ }
407
  # print(merge_sql)
408
+ await self.db.execute(merge_sql, data)
409
  # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
410
 
411
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
 
436
  async def has_node(self, node_id: str) -> bool:
437
  """根据节点id检查节点是否存在"""
438
  SQL = SQL_TEMPLATES["has_node"]
439
+ params = {"workspace": self.db.workspace, "node_id": node_id}
 
 
 
440
  # print(SQL)
441
  # print(self.db.workspace, node_id)
442
+ res = await self.db.query(SQL, params)
443
  if res:
444
  # print("Node exist!",res)
445
  return True
 
451
  """根据源和目标节点id检查边是否存在"""
452
  SQL = SQL_TEMPLATES["has_edge"]
453
  params = {
454
+ "workspace": self.db.workspace,
455
+ "source_node_id": source_node_id,
456
+ "target_node_id": target_node_id,
457
+ }
458
  # print(SQL)
459
+ res = await self.db.query(SQL, params)
460
  if res:
461
  # print("Edge exist!",res)
462
  return True
 
467
  async def node_degree(self, node_id: str) -> int:
468
  """根据节���id获取节点的度"""
469
  SQL = SQL_TEMPLATES["node_degree"]
470
+ params = {"workspace": self.db.workspace, "node_id": node_id}
 
 
 
471
  # print(SQL)
472
+ res = await self.db.query(SQL, params)
473
  if res:
474
  # print("Node degree",res["degree"])
475
  return res["degree"]
 
486
  async def get_node(self, node_id: str) -> Union[dict, None]:
487
  """根据节点id获取节点数据"""
488
  SQL = SQL_TEMPLATES["get_node"]
489
+ params = {"workspace": self.db.workspace, "node_id": node_id}
 
 
 
490
  # print(self.db.workspace, node_id)
491
  # print(SQL)
492
+ res = await self.db.query(SQL, params)
493
  if res:
494
  # print("Get node!",self.db.workspace, node_id,res)
495
  return res
 
503
  """根据源和目标节点id获取边"""
504
  SQL = SQL_TEMPLATES["get_edge"]
505
  params = {
506
+ "workspace": self.db.workspace,
507
+ "source_node_id": source_node_id,
508
+ "target_node_id": target_node_id,
509
+ }
510
+ res = await self.db.query(SQL, params)
511
  if res:
512
  # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
513
  return res
 
519
  """根据节点id获取节点的所有边"""
520
  if await self.has_node(source_node_id):
521
  SQL = SQL_TEMPLATES["get_node_edges"]
522
+ params = {"workspace": self.db.workspace, "source_node_id": source_node_id}
 
 
 
523
  res = await self.db.query(sql=SQL, params=params, multirows=True)
524
  if res:
525
  data = [(i["source_name"], i["target_name"]) for i in res]
 
528
  else:
529
  # print("Node Edge not exist!",self.db.workspace, source_node_id)
530
  return []
531
+
532
  async def get_all_nodes(self, limit: int):
533
  """查询所有节点"""
534
  SQL = SQL_TEMPLATES["get_all_nodes"]
535
+ params = {"workspace": self.db.workspace, "limit": str(limit)}
536
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
537
  if res:
538
  return res
539
 
540
  async def get_all_edges(self, limit: int):
541
  """查询所有边"""
542
  SQL = SQL_TEMPLATES["get_all_edges"]
543
+ params = {"workspace": self.db.workspace, "limit": str(limit)}
544
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
545
  if res:
546
  return res
547
+
548
  async def get_statistics(self):
549
  SQL = SQL_TEMPLATES["get_statistics"]
550
+ params = {"workspace": self.db.workspace}
551
+ res = await self.db.query(sql=SQL, params=params, multirows=True)
552
  if res:
553
  return res
554
 
555
+
556
  N_T = {
557
  "full_docs": "LIGHTRAG_DOC_FULL",
558
  "text_chunks": "LIGHTRAG_DOC_CHUNKS",
 
724
  WHEN NOT MATCHED THEN
725
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
726
  values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
727
+ "get_all_nodes": """WITH t0 AS (
728
  SELECT name AS id, entity_type AS label, entity_type, description,
729
  '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
730
  FROM lightrag_graph_nodes
 
740
  )
741
  SELECT t0.id, label, entity_type, description, t2.content
742
  FROM t0 LEFT JOIN t2 ON t0.id = t2.id""",
743
+ "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
744
  t1.weight,t1.DESCRIPTION,t2.content
745
  FROM LIGHTRAG_GRAPH_EDGES t1
746
  LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
747
  WHERE t1.workspace=:workspace
748
  order by t1.CREATETIME DESC
749
  fetch first :limit rows only""",
750
+ "get_statistics": """select count(distinct CASE WHEN type='node' THEN id END) as nodes_count,
751
  count(distinct CASE WHEN type='edge' THEN id END) as edges_count
752
  FROM (
753
+ select 'node' as type, id FROM GRAPH_TABLE (lightrag_graph
754
  MATCH (a) WHERE a.workspace=:workspace columns(a.name as id))
755
  UNION
756
+ select 'edge' as type, TO_CHAR(id) id FROM GRAPH_TABLE (lightrag_graph
757
  MATCH (a)-[e]->(b) WHERE e.workspace=:workspace columns(e.id))
758
  )""",
759
  }