"""Wipe or report on the storage backends configured for PrivateGPT.

Run with a single positional argument, ``wipe`` or ``stats``. The script looks
up ``settings().nodestore.database`` and ``settings().vectorstore.database``
and dispatches the command to the matching handler class below.
"""

import argparse
import os
import shutil
from typing import Any, ClassVar

from private_gpt.paths import local_data_path
from private_gpt.settings.settings import settings


def wipe_file(file: str) -> None:
    """Delete a single regular file if it exists and report the deletion.

    Args:
        file: Path of the file to delete. Paths that do not exist, or that are
            not regular files, are ignored.
    """
    if os.path.isfile(file):
        os.remove(file)
        print(f" - Deleted {file}")


def wipe_tree(path: str) -> None:
    """Delete every entry inside ``path`` except a top-level ``.gitignore``.

    The directory itself is kept. Entries whose removal raises
    ``PermissionError`` are reported and skipped rather than aborting the wipe.

    Args:
        path: Directory whose contents should be removed.
    """
    if not os.path.exists(path):
        print(f"Warning: Path not found {path}")
        return
    print(f"Wiping {path}...")
    all_files = os.listdir(path)

    files_to_remove = [file for file in all_files if file != ".gitignore"]
    for file_name in files_to_remove:
        file_path = os.path.join(path, file_name)
        try:
            # Only recurse into real directories; symlinks (including ones
            # pointing at directories, or dangling ones) are unlinked directly,
            # since shutil.rmtree refuses symlinks and a dangling link is
            # neither a file nor a directory.
            if os.path.isdir(file_path) and not os.path.islink(file_path):
                shutil.rmtree(file_path)
            else:
                os.remove(file_path)
            print(f" - Deleted {file_path}")
        except PermissionError:
            print(
                f"PermissionError: Unable to remove {file_path}. "
                "It is in use by another process."
            )
            continue


class Postgres:
    """Wipe or inspect the Postgres tables used by the node and vector stores."""

    # Tables belonging to each store type. Only the schema comes from settings.
    tables: ClassVar[dict[str, list[str]]] = {
        "nodestore": ["data_docstore", "data_indexstore"],
        "vectorstore": ["data_embeddings"],
    }

    def __init__(self) -> None:
        """Connect using ``settings().postgres``.

        Raises:
            ModuleNotFoundError: If ``psycopg2`` is not installed.
        """
        try:
            import psycopg2
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Postgres dependencies not found") from None

        connection = settings().postgres.model_dump(exclude_none=True)
        # schema_name is consumed here instead of being passed on to
        # psycopg2.connect() with the remaining connection settings.
        self.schema = connection.pop("schema_name")
        self.conn = psycopg2.connect(**connection)

    def _qualified(self, table: str) -> str:
        """Return ``table`` prefixed with the configured schema."""
        return f"{self.schema}.{table}"

    def wipe(self, storetype: str) -> None:
        """Drop every table belonging to ``storetype`` and commit.

        Args:
            storetype: A key of ``tables`` ("nodestore" or "vectorstore").
        """
        with self.conn.cursor() as cur:
            for table in self.tables[storetype]:
                name = self._qualified(table)
                cur.execute(f"DROP TABLE IF EXISTS {name}")
                print(f"Table {name} dropped.")
        self.conn.commit()

    def stats(self, store_type: str) -> None:
        """Print row count and total on-disk size for each table of ``store_type``.

        Tables are read from the configured schema, the same one ``wipe``
        drops from. If the query fails (e.g. the tables have not been created
        yet), the error is printed and the aborted transaction is rolled back.

        Args:
            store_type: A key of ``tables`` ("nodestore" or "vectorstore").
        """
        import psycopg2

        template = (
            "SELECT '{table}', COUNT(*), "
            "pg_size_pretty(pg_total_relation_size('{qualified}')) "
            "FROM {qualified}"
        )
        sql = " UNION ALL ".join(
            template.format(table=tbl, qualified=self._qualified(tbl))
            for tbl in self.tables[store_type]
        )

        print(f"Storage for Postgres {store_type}.")
        print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
        print("-" * 45)  # Print a line separator
        with self.conn.cursor() as cur:
            try:
                cur.execute(sql)
            except psycopg2.ProgrammingError as e:
                # A failed statement leaves the transaction aborted; roll back
                # so later queries on this connection still work.
                self.conn.rollback()
                print(f"\t- Unable to read Postgres {store_type} tables: {e}")
                return
            for row in cur.fetchall():
                formatted_row_count = f"{row[1]:,}"
                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")
        print()

    def __del__(self) -> None:
        # __init__ may have failed before self.conn was set.
        conn = getattr(self, "conn", None)
        if conn is not None and not conn.closed:
            conn.close()


class Simple:
    """Wipe the on-disk files of the "simple" (local file) node store."""

    def wipe(self, store_type: str) -> None:
        """Delete the persisted docstore and index-store files under local_data_path.

        Raises:
            ValueError: If ``store_type`` is not "nodestore".
        """
        if store_type != "nodestore":
            raise ValueError(
                f"Simple storage only handles 'nodestore', got '{store_type}'"
            )
        from llama_index.core.storage.docstore.types import (
            DEFAULT_PERSIST_FNAME as DOCSTORE,
        )
        from llama_index.core.storage.index_store.types import (
            DEFAULT_PERSIST_FNAME as INDEXSTORE,
        )

        for store in (DOCSTORE, INDEXSTORE):
            wipe_file(str((local_data_path / store).absolute()))


class Chroma:
    """Wipe the on-disk Chroma vector store."""

    def wipe(self, store_type: str) -> None:
        """Delete everything under ``local_data_path / "chroma_db"``.

        Raises:
            ValueError: If ``store_type`` is not "vectorstore".
        """
        if store_type != "vectorstore":
            raise ValueError(
                f"Chroma storage only handles 'vectorstore', got '{store_type}'"
            )
        wipe_tree(str((local_data_path / "chroma_db").absolute()))


class Qdrant:
    """Wipe or inspect the Qdrant collection used as the vector store."""

    # NOTE(review): hard-coded collection name; it presumably must match the
    # name used where the Qdrant vector store is created — confirm before
    # changing.
    COLLECTION = "make_this_parameterizable_per_api_call"

    def __init__(self) -> None:
        """Create a client from ``settings().qdrant``.

        Raises:
            ImportError: If ``qdrant_client`` is not installed.
        """
        try:
            from qdrant_client import QdrantClient  # type: ignore
        except ImportError:
            raise ImportError("Qdrant dependencies not found") from None
        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))

    @staticmethod
    def _format_count(value: int | None) -> str:
        """Format a count with thousands separators; Qdrant may report None."""
        return "n/a" if value is None else f"{value:,}"

    def wipe(self, store_type: str) -> None:
        """Delete the collection, printing any error instead of raising."""
        assert store_type == "vectorstore"
        try:
            self.client.delete_collection(self.COLLECTION)
            print("Collection dropped successfully.")
        except Exception as e:
            print("Error dropping collection:", e)

    def stats(self, store_type: str) -> None:
        """Print vector counts for the collection, or a notice if it is missing."""
        print(f"Storage for Qdrant {store_type}.")
        try:
            collection_data = self.client.get_collection(self.COLLECTION)
        except ValueError:
            collection_data = None

        if collection_data:
            print(f"\tVectors: {self._format_count(collection_data.vectors_count)}")
            print(
                "\tIndex Vectors: "
                f"{self._format_count(collection_data.indexed_vectors_count)}"
            )
            return
        print("\t- Qdrant collection not found or empty")


class Command:
    """Dispatch a command to the handler of each configured store."""

    # Maps the `database` setting of a store to its handler class.
    DB_HANDLERS: ClassVar[dict[str, Any]] = {
        "simple": Simple,  # node store
        "chroma": Chroma,  # vector store
        "postgres": Postgres,  # node, index and vector store
        "qdrant": Qdrant,  # vector store
    }

    def for_each_store(self, cmd: str) -> None:
        """Run ``cmd`` (a handler method name) for the node and vector stores.

        A handler is created once per database, so a database backing both
        stores (e.g. Postgres) is only connected to once.

        Args:
            cmd: Name of the handler method to call with the store type.
        """
        handlers: dict[str, Any] = {}
        for store_type in ("nodestore", "vectorstore"):
            database = getattr(settings(), store_type).database
            handler_class = self.DB_HANDLERS.get(database)
            if handler_class is None:
                print(f"No handler found for database '{database}'")
                continue

            if database not in handlers:
                handlers[database] = handler_class()
            handler_instance = handlers[database]

            # If the DB can handle this cmd dispatch it.
            func = getattr(handler_instance, cmd, None)
            if callable(func):
                func(store_type)
            else:
                print(
                    f"Unable to execute command '{cmd}' on '{store_type}' "
                    f"in database '{database}'"
                )

    def execute(self, cmd: str) -> None:
        """Run a supported command ("wipe" or "stats"); others are ignored."""
        if cmd in ("wipe", "stats"):
            self.for_each_store(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    args = parser.parse_args()
    Command().execute(args.mode.lower())