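"""Maintenance utilities for the PrivateGPT local data stores.

Exposes two CLI commands, `wipe` and `stats`, which are dispatched to the
configured nodestore and vectorstore backends (simple, chroma, postgres, qdrant).
"""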
import argparse
import os
import shutil
from typing import Any, ClassVar

from private_gpt.paths import local_data_path
from private_gpt.settings.settings import settings


def wipe_file(file: str) -> None:
    """Delete a single file if it exists."""
    if os.path.isfile(file):
        os.remove(file)
        print(f" - Deleted {file}")


def wipe_tree(path: str) -> None:
    """Delete everything under `path`, keeping only the .gitignore marker file."""
    if not os.path.exists(path):
        print(f"Warning: Path not found {path}")
        return
    print(f"Wiping {path}...")
    all_files = os.listdir(path)

    files_to_remove = [file for file in all_files if file != ".gitignore"]
    for file_name in files_to_remove:
        file_path = os.path.join(path, file_name)
        try:
            if os.path.isfile(file_path):
                os.remove(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
            print(f" - Deleted {file_path}")
        except PermissionError:
            print(
                f"PermissionError: Unable to remove {file_path}. It is in use by another process."
            )
            continue


class Postgres:
    """Handler for the Postgres-backed node, index and vector stores."""

    tables: ClassVar[dict[str, list[str]]] = {
        "nodestore": ["data_docstore", "data_indexstore"],
        "vectorstore": ["data_embeddings"],
    }

    def __init__(self) -> None:
        try:
            import psycopg2
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Postgres dependencies not found") from None

        connection = settings().postgres.model_dump(exclude_none=True)
        self.schema = connection.pop("schema_name")
        self.conn = psycopg2.connect(**connection)

    def wipe(self, store_type: str) -> None:
        """Drop the tables backing the given store type."""
        cur = self.conn.cursor()
        try:
            for table in self.tables[store_type]:
                sql = f"DROP TABLE IF EXISTS {self.schema}.{table}"
                cur.execute(sql)
                print(f"Table {self.schema}.{table} dropped.")
            self.conn.commit()
        finally:
            cur.close()

    def stats(self, store_type: str) -> None:
        """Print row counts and on-disk size for the tables of the given store type."""
        template = "SELECT '{table}', COUNT(*), pg_size_pretty(pg_total_relation_size('{table}')) FROM {table}"

        sql = " UNION ALL ".join(
            template.format(table=tbl) for tbl in self.tables[store_type]
        )

        cur = self.conn.cursor()
        try:
            print(f"Storage for Postgres {store_type}.")
            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
            print("-" * 45)  # Print a line separator
            cur.execute(sql)
            for row in cur.fetchall():
                formatted_row_count = f"{row[1]:,}"
                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")
            print()
        finally:
            cur.close()

    def __del__(self) -> None:
        if hasattr(self, "conn") and self.conn:
            self.conn.close()


class Simple:
    """Handler for the simple, file-based node store persisted under the local data path."""

    def wipe(self, store_type: str) -> None:
        assert store_type == "nodestore"
        from llama_index.core.storage.docstore.types import (
            DEFAULT_PERSIST_FNAME as DOCSTORE,
        )
        from llama_index.core.storage.index_store.types import (
            DEFAULT_PERSIST_FNAME as INDEXSTORE,
        )

        for store in (DOCSTORE, INDEXSTORE):
            wipe_file(str((local_data_path / store).absolute()))


class Chroma:
    """Handler for the Chroma vector store persisted under the local data path."""

    def wipe(self, store_type: str) -> None:
        assert store_type == "vectorstore"
        wipe_tree(str((local_data_path / "chroma_db").absolute()))


class Qdrant:
    """Handler for the Qdrant vector store."""

    COLLECTION = (
        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
    )

    def __init__(self) -> None:
        try:
            from qdrant_client import QdrantClient  # type: ignore
        except ImportError:
            raise ImportError("Qdrant dependencies not found") from None
        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))

    def wipe(self, store_type: str) -> None:
        assert store_type == "vectorstore"
        try:
            self.client.delete_collection(self.COLLECTION)
            print("Collection dropped successfully.")
        except Exception as e:
            print("Error dropping collection:", e)

    def stats(self, store_type: str) -> None:
        print(f"Storage for Qdrant {store_type}.")
        try:
            collection_data = self.client.get_collection(self.COLLECTION)
            if collection_data:
                # Collection Info
                # https://qdrant.tech/documentation/concepts/collections/
                print(f"\tPoints:        {collection_data.points_count:,}")
                print(f"\tVectors:       {collection_data.vectors_count:,}")
                print(f"\tIndex Vectors: {collection_data.indexed_vectors_count:,}")
                return
        except ValueError:
            pass
        print("\t- Qdrant collection not found or empty")


class Command:
    """Dispatches a CLI command to the handler of each configured store."""

    DB_HANDLERS: ClassVar[dict[str, Any]] = {
        "simple": Simple,  # node store
        "chroma": Chroma,  # vector store
        "postgres": Postgres,  # node, index and vector store
        "qdrant": Qdrant,  # vector store
    }

    def for_each_store(self, cmd: str) -> None:
        for store_type in ("nodestore", "vectorstore"):
            database = getattr(settings(), store_type).database
            handler_class = self.DB_HANDLERS.get(database)
            if handler_class is None:
                print(f"No handler found for database '{database}'")
                continue
            handler_instance = handler_class()  # Instantiate the handler
            # If the handler supports this command, dispatch it.
            if hasattr(handler_instance, cmd) and callable(
                func := getattr(handler_instance, cmd)
            ):
                func(store_type)
            else:
                print(
                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
                )

    def execute(self, cmd: str) -> None:
        if cmd in ("wipe", "stats"):
            self.for_each_store(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    args = parser.parse_args()
    Command().execute(args.mode.lower())
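

# Example invocations (a sketch; the exact path and settings profile depend on
# your PrivateGPT setup -- the script is assumed to live at scripts/utils.py):
#
#   python scripts/utils.py stats   # report sizes of the configured stores
#   python scripts/utils.py wipe    # drop all ingested data from the stores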