yangdx committed on
Commit
fadbe07
·
1 Parent(s): 2386d58

refactor: move database connection pool initialization to lifespan of FastAPI

Browse files

- Add proper database connection lifecycle management
- Add connection pool cleanup in FastAPI lifespan

Files changed (2) hide show
  1. lightrag/api/lightrag_server.py +213 -17
  2. lightrag/lightrag.py +0 -171
lightrag/api/lightrag_server.py CHANGED
@@ -33,14 +33,39 @@ from contextlib import asynccontextmanager
33
  from starlette.status import HTTP_403_FORBIDDEN
34
  import pipmaster as pm
35
  from dotenv import load_dotenv
 
 
36
  from .ollama_api import (
37
  OllamaAPI,
38
  )
39
  from .ollama_api import ollama_server_infos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  # Load environment variables
42
  load_dotenv(override=True)
43
 
 
 
 
 
44
 
45
  class RAGStorageConfig:
46
  """存储配置类,支持通过环境变量和命令行参数修改默认值"""
@@ -714,25 +739,99 @@ def create_app(args):
714
  @asynccontextmanager
715
  async def lifespan(app: FastAPI):
716
  """Lifespan context manager for startup and shutdown events"""
717
- # Startup logic
718
- if args.auto_scan_at_startup:
719
- try:
720
- new_files = doc_manager.scan_directory_for_new_files()
721
- for file_path in new_files:
722
- try:
723
- await index_file(file_path)
724
- except Exception as e:
725
- trace_exception(e)
726
- logging.error(f"Error indexing file {file_path}: {str(e)}")
727
 
728
- ASCIIColors.info(
729
- f"Indexed {len(new_files)} documents from {args.input_dir}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  )
731
- except Exception as e:
732
- logging.error(f"Error during startup indexing: {str(e)}")
733
- yield
734
- # Cleanup logic (if needed)
735
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
736
 
737
  # Initialize FastAPI
738
  app = FastAPI(
@@ -755,6 +854,92 @@ def create_app(args):
755
  allow_headers=["*"],
756
  )
757
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
758
  # Create the optional API key dependency
759
  optional_api_key = get_api_key_dependency(api_key)
760
 
@@ -921,6 +1106,17 @@ def create_app(args):
921
  namespace_prefix=args.namespace_prefix,
922
  )
923
 
 
 
 
 
 
 
 
 
 
 
 
924
  async def index_file(file_path: Union[str, Path]) -> None:
925
  """Index all files inside the folder with support for multiple file formats
926
 
 
33
  from starlette.status import HTTP_403_FORBIDDEN
34
  import pipmaster as pm
35
  from dotenv import load_dotenv
36
+ import configparser
37
+ from lightrag.utils import logger
38
  from .ollama_api import (
39
  OllamaAPI,
40
  )
41
  from .ollama_api import ollama_server_infos
42
+ from ..kg.postgres_impl import (
43
+ PostgreSQLDB,
44
+ PGKVStorage,
45
+ PGVectorStorage,
46
+ PGGraphStorage,
47
+ PGDocStatusStorage,
48
+ )
49
+ from ..kg.oracle_impl import (
50
+ OracleDB,
51
+ OracleKVStorage,
52
+ OracleVectorDBStorage,
53
+ OracleGraphStorage,
54
+ )
55
+ from ..kg.tidb_impl import (
56
+ TiDB,
57
+ TiDBKVStorage,
58
+ TiDBVectorDBStorage,
59
+ TiDBGraphStorage,
60
+ )
61
 
62
  # Load environment variables
63
  load_dotenv(override=True)
64
 
65
+ # Initialize config parser
66
+ config = configparser.ConfigParser()
67
+ config.read("config.ini")
68
+
69
 
70
  class RAGStorageConfig:
71
  """存储配置类,支持通过环境变量和命令行参数修改默认值"""
 
739
  @asynccontextmanager
740
  async def lifespan(app: FastAPI):
741
  """Lifespan context manager for startup and shutdown events"""
742
+ # Initialize database connections
743
+ postgres_db = None
744
+ oracle_db = None
745
+ tidb_db = None
 
 
 
 
 
 
746
 
747
+ try:
748
+ # Check if PostgreSQL is needed
749
+ if any(
750
+ isinstance(
751
+ storage_instance,
752
+ (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
753
+ )
754
+ for _, storage_instance in storage_instances
755
+ ):
756
+ postgres_db = PostgreSQLDB(_get_postgres_config())
757
+ await postgres_db.initdb()
758
+ await postgres_db.check_tables()
759
+ for storage_name, storage_instance in storage_instances:
760
+ if isinstance(
761
+ storage_instance,
762
+ (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
763
+ ):
764
+ storage_instance.db = postgres_db
765
+ logger.info(f"Injected postgres_db to {storage_name}")
766
+
767
+ # Check if Oracle is needed
768
+ if any(
769
+ isinstance(
770
+ storage_instance,
771
+ (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
772
+ )
773
+ for _, storage_instance in storage_instances
774
+ ):
775
+ oracle_db = OracleDB(_get_oracle_config())
776
+ await oracle_db.check_tables()
777
+ for storage_name, storage_instance in storage_instances:
778
+ if isinstance(
779
+ storage_instance,
780
+ (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
781
+ ):
782
+ storage_instance.db = oracle_db
783
+ logger.info(f"Injected oracle_db to {storage_name}")
784
+
785
+ # Check if TiDB is needed
786
+ if any(
787
+ isinstance(
788
+ storage_instance,
789
+ (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
790
  )
791
+ for _, storage_instance in storage_instances
792
+ ):
793
+ tidb_db = TiDB(_get_tidb_config())
794
+ await tidb_db.check_tables()
795
+ for storage_name, storage_instance in storage_instances:
796
+ if isinstance(
797
+ storage_instance,
798
+ (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage),
799
+ ):
800
+ storage_instance.db = tidb_db
801
+ logger.info(f"Injected tidb_db to {storage_name}")
802
+
803
+ # Auto scan documents if enabled
804
+ if args.auto_scan_at_startup:
805
+ try:
806
+ new_files = doc_manager.scan_directory_for_new_files()
807
+ for file_path in new_files:
808
+ try:
809
+ await index_file(file_path)
810
+ except Exception as e:
811
+ trace_exception(e)
812
+ logging.error(f"Error indexing file {file_path}: {str(e)}")
813
+
814
+ ASCIIColors.info(
815
+ f"Indexed {len(new_files)} documents from {args.input_dir}"
816
+ )
817
+ except Exception as e:
818
+ logging.error(f"Error during startup indexing: {str(e)}")
819
+
820
+ yield
821
+
822
+ finally:
823
+ # Cleanup database connections
824
+ if postgres_db and hasattr(postgres_db, "pool"):
825
+ await postgres_db.pool.close()
826
+ logger.info("Closed PostgreSQL connection pool")
827
+
828
+ if oracle_db and hasattr(oracle_db, "pool"):
829
+ await oracle_db.pool.close()
830
+ logger.info("Closed Oracle connection pool")
831
+
832
+ if tidb_db and hasattr(tidb_db, "pool"):
833
+ await tidb_db.pool.close()
834
+ logger.info("Closed TiDB connection pool")
835
 
836
  # Initialize FastAPI
837
  app = FastAPI(
 
854
  allow_headers=["*"],
855
  )
856
 
857
+ # Database configuration functions
858
+ def _get_postgres_config():
859
+ return {
860
+ "host": os.environ.get(
861
+ "POSTGRES_HOST",
862
+ config.get("postgres", "host", fallback="localhost"),
863
+ ),
864
+ "port": os.environ.get(
865
+ "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
866
+ ),
867
+ "user": os.environ.get(
868
+ "POSTGRES_USER", config.get("postgres", "user", fallback=None)
869
+ ),
870
+ "password": os.environ.get(
871
+ "POSTGRES_PASSWORD",
872
+ config.get("postgres", "password", fallback=None),
873
+ ),
874
+ "database": os.environ.get(
875
+ "POSTGRES_DATABASE",
876
+ config.get("postgres", "database", fallback=None),
877
+ ),
878
+ "workspace": os.environ.get(
879
+ "POSTGRES_WORKSPACE",
880
+ config.get("postgres", "workspace", fallback="default"),
881
+ ),
882
+ }
883
+
884
+ def _get_oracle_config():
885
+ return {
886
+ "user": os.environ.get(
887
+ "ORACLE_USER",
888
+ config.get("oracle", "user", fallback=None),
889
+ ),
890
+ "password": os.environ.get(
891
+ "ORACLE_PASSWORD",
892
+ config.get("oracle", "password", fallback=None),
893
+ ),
894
+ "dsn": os.environ.get(
895
+ "ORACLE_DSN",
896
+ config.get("oracle", "dsn", fallback=None),
897
+ ),
898
+ "config_dir": os.environ.get(
899
+ "ORACLE_CONFIG_DIR",
900
+ config.get("oracle", "config_dir", fallback=None),
901
+ ),
902
+ "wallet_location": os.environ.get(
903
+ "ORACLE_WALLET_LOCATION",
904
+ config.get("oracle", "wallet_location", fallback=None),
905
+ ),
906
+ "wallet_password": os.environ.get(
907
+ "ORACLE_WALLET_PASSWORD",
908
+ config.get("oracle", "wallet_password", fallback=None),
909
+ ),
910
+ "workspace": os.environ.get(
911
+ "ORACLE_WORKSPACE",
912
+ config.get("oracle", "workspace", fallback="default"),
913
+ ),
914
+ }
915
+
916
+ def _get_tidb_config():
917
+ return {
918
+ "host": os.environ.get(
919
+ "TIDB_HOST",
920
+ config.get("tidb", "host", fallback="localhost"),
921
+ ),
922
+ "port": os.environ.get(
923
+ "TIDB_PORT", config.get("tidb", "port", fallback=4000)
924
+ ),
925
+ "user": os.environ.get(
926
+ "TIDB_USER",
927
+ config.get("tidb", "user", fallback=None),
928
+ ),
929
+ "password": os.environ.get(
930
+ "TIDB_PASSWORD",
931
+ config.get("tidb", "password", fallback=None),
932
+ ),
933
+ "database": os.environ.get(
934
+ "TIDB_DATABASE",
935
+ config.get("tidb", "database", fallback=None),
936
+ ),
937
+ "workspace": os.environ.get(
938
+ "TIDB_WORKSPACE",
939
+ config.get("tidb", "workspace", fallback="default"),
940
+ ),
941
+ }
942
+
943
  # Create the optional API key dependency
944
  optional_api_key = get_api_key_dependency(api_key)
945
 
 
1106
  namespace_prefix=args.namespace_prefix,
1107
  )
1108
 
1109
+ # Collect all storage instances
1110
+ storage_instances = [
1111
+ ("full_docs", rag.full_docs),
1112
+ ("text_chunks", rag.text_chunks),
1113
+ ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
1114
+ ("entities_vdb", rag.entities_vdb),
1115
+ ("relationships_vdb", rag.relationships_vdb),
1116
+ ("chunks_vdb", rag.chunks_vdb),
1117
+ ("doc_status", rag.doc_status),
1118
+ ]
1119
+
1120
  async def index_file(file_path: Union[str, Path]) -> None:
1121
  """Index all files inside the folder with support for multiple file formats
1122
 
lightrag/lightrag.py CHANGED
@@ -355,91 +355,6 @@ class LightRAG:
355
  list[dict[str, Any]],
356
  ] = chunking_by_token_size
357
 
358
- def _get_postgres_config(self):
359
- return {
360
- "host": os.environ.get(
361
- "POSTGRES_HOST",
362
- config.get("postgres", "host", fallback="localhost"),
363
- ),
364
- "port": os.environ.get(
365
- "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
366
- ),
367
- "user": os.environ.get(
368
- "POSTGRES_USER", config.get("postgres", "user", fallback=None)
369
- ),
370
- "password": os.environ.get(
371
- "POSTGRES_PASSWORD",
372
- config.get("postgres", "password", fallback=None),
373
- ),
374
- "database": os.environ.get(
375
- "POSTGRES_DATABASE",
376
- config.get("postgres", "database", fallback=None),
377
- ),
378
- "workspace": os.environ.get(
379
- "POSTGRES_WORKSPACE",
380
- config.get("postgres", "workspace", fallback="default"),
381
- ),
382
- }
383
-
384
- def _get_oracle_config(self):
385
- return {
386
- "user": os.environ.get(
387
- "ORACLE_USER",
388
- config.get("oracle", "user", fallback=None),
389
- ),
390
- "password": os.environ.get(
391
- "ORACLE_PASSWORD",
392
- config.get("oracle", "password", fallback=None),
393
- ),
394
- "dsn": os.environ.get(
395
- "ORACLE_DSN",
396
- config.get("oracle", "dsn", fallback=None),
397
- ),
398
- "config_dir": os.environ.get(
399
- "ORACLE_CONFIG_DIR",
400
- config.get("oracle", "config_dir", fallback=None),
401
- ),
402
- "wallet_location": os.environ.get(
403
- "ORACLE_WALLET_LOCATION",
404
- config.get("oracle", "wallet_location", fallback=None),
405
- ),
406
- "wallet_password": os.environ.get(
407
- "ORACLE_WALLET_PASSWORD",
408
- config.get("oracle", "wallet_password", fallback=None),
409
- ),
410
- "workspace": os.environ.get(
411
- "ORACLE_WORKSPACE",
412
- config.get("oracle", "workspace", fallback="default"),
413
- ),
414
- }
415
-
416
- def _get_tidb_config(self):
417
- return {
418
- "host": os.environ.get(
419
- "TIDB_HOST",
420
- config.get("tidb", "host", fallback="localhost"),
421
- ),
422
- "port": os.environ.get(
423
- "TIDB_PORT", config.get("tidb", "port", fallback=4000)
424
- ),
425
- "user": os.environ.get(
426
- "TIDB_USER",
427
- config.get("tidb", "user", fallback=None),
428
- ),
429
- "password": os.environ.get(
430
- "TIDB_PASSWORD",
431
- config.get("tidb", "password", fallback=None),
432
- ),
433
- "database": os.environ.get(
434
- "TIDB_DATABASE",
435
- config.get("tidb", "database", fallback=None),
436
- ),
437
- "workspace": os.environ.get(
438
- "TIDB_WORKSPACE",
439
- config.get("tidb", "workspace", fallback="default"),
440
- ),
441
- }
442
-
443
  def verify_storage_implementation(
444
  self, storage_type: str, storage_name: str
445
  ) -> None:
@@ -609,20 +524,6 @@ class LightRAG:
609
  )
610
 
611
 
612
- # Collect all storage instances with their names
613
- storage_instances = [
614
- ("full_docs", self.full_docs),
615
- ("text_chunks", self.text_chunks),
616
- ("chunk_entity_relation_graph", self.chunk_entity_relation_graph),
617
- ("entities_vdb", self.entities_vdb),
618
- ("relationships_vdb", self.relationships_vdb),
619
- ("chunks_vdb", self.chunks_vdb),
620
- ("doc_status", self.doc_status),
621
- ]
622
-
623
- # Initialize database connections if needed
624
- loop = always_get_an_event_loop()
625
- loop.run_until_complete(self._initialize_database_if_needed(storage_instances))
626
 
627
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
628
  partial(
@@ -646,78 +547,6 @@ class LightRAG:
646
  storage_class = lazy_external_import(import_path, storage_name)
647
  return storage_class
648
 
649
- async def _initialize_database_if_needed(self, storage_instances: list[tuple[str, Any]]):
650
- """Intialize database connection and inject it to storage implementation if needed"""
651
- from .kg.postgres_impl import PostgreSQLDB
652
- from .kg.oracle_impl import OracleDB
653
- from .kg.tidb_impl import TiDB
654
- from .kg.postgres_impl import (
655
- PGKVStorage,
656
- PGVectorStorage,
657
- PGGraphStorage,
658
- PGDocStatusStorage,
659
- )
660
- from .kg.oracle_impl import (
661
- OracleKVStorage,
662
- OracleVectorDBStorage,
663
- OracleGraphStorage,
664
- )
665
- from .kg.tidb_impl import (
666
- TiDBKVStorage,
667
- TiDBVectorDBStorage,
668
- TiDBGraphStorage)
669
-
670
- # Checking if PostgreSQL is needed
671
- if any(
672
- isinstance(
673
- storage_instance,
674
- (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
675
- )
676
- for _, storage_instance in storage_instances
677
- ):
678
- postgres_db = PostgreSQLDB(self._get_postgres_config())
679
- await postgres_db.initdb()
680
- await postgres_db.check_tables()
681
- for storage_name, storage_instance in storage_instances:
682
- if isinstance(
683
- storage_instance,
684
- (PGKVStorage, PGVectorStorage, PGGraphStorage, PGDocStatusStorage),
685
- ):
686
- storage_instance.db = postgres_db
687
- logger.info(f"Injected postgres_db to {storage_name}")
688
-
689
- # Checking if Oracle is needed
690
- if any(
691
- isinstance(
692
- storage_instance, (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage)
693
- )
694
- for _, storage_instance in storage_instances
695
- ):
696
- oracle_db = OracleDB(self._get_oracle_config())
697
- await oracle_db.check_tables()
698
- for storage_name, storage_instance in storage_instances:
699
- if isinstance(
700
- storage_instance,
701
- (OracleKVStorage, OracleVectorDBStorage, OracleGraphStorage),
702
- ):
703
- storage_instance.db = oracle_db
704
- logger.info(f"Injected oracle_db to {storage_name}")
705
-
706
- # Checking if TiDB is needed
707
- if any(
708
- isinstance(storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage))
709
- for _, storage_instance in storage_instances
710
- ):
711
- tidb_db = TiDB(self._get_tidb_config())
712
- await tidb_db.check_tables()
713
- # 注入db实例
714
- for storage_name, storage_instance in storage_instances:
715
- if isinstance(
716
- storage_instance, (TiDBKVStorage, TiDBVectorDBStorage, TiDBGraphStorage)
717
- ):
718
- storage_instance.db = tidb_db
719
- logger.info(f"Injected tidb_db to {storage_name}")
720
-
721
  def set_storage_client(self, db_client):
722
  # Inject db to storage implementation (only tested on Oracle Database
723
  # Deprecated, seting correct value to *_storage creating LightRAG insteaded
 
355
  list[dict[str, Any]],
356
  ] = chunking_by_token_size
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  def verify_storage_implementation(
359
  self, storage_type: str, storage_name: str
360
  ) -> None:
 
524
  )
525
 
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
 
528
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
529
  partial(
 
547
  storage_class = lazy_external_import(import_path, storage_name)
548
  return storage_class
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  def set_storage_client(self, db_client):
551
  # Inject db to storage implementation (only tested on Oracle Database
552
  # Deprecated, seting correct value to *_storage creating LightRAG insteaded