aelitta's picture
Upload folder using huggingface_hub
4bdb245 verified
raw
history blame
6.3 kB
import argparse
import os
import shutil
from typing import Any, ClassVar
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import settings
def wipe_file(file: str) -> None:
    """Delete ``file`` when it is an existing regular file and report it.

    Anything else (a missing path, a directory) is left untouched and
    nothing is printed.
    """
    if not os.path.isfile(file):
        return
    os.remove(file)
    print(f" - Deleted {file}")
def wipe_tree(path: str) -> None:
    """Remove every entry directly inside ``path`` except ``.gitignore``.

    Real directories are removed recursively; every other kind of entry
    (regular files, symlinks including dangling ones, special files) is
    unlinked. A missing ``path`` only produces a warning.

    Entries that cannot be removed because of a ``PermissionError`` are
    reported and skipped so the remaining entries are still processed.
    """
    if not os.path.exists(path):
        print(f"Warning: Path not found {path}")
        return
    print(f"Wiping {path}...")

    for entry_name in os.listdir(path):
        if entry_name == ".gitignore":
            continue
        entry_path = os.path.join(path, entry_name)
        try:
            # shutil.rmtree refuses symlinks, so a link to a directory must be
            # unlinked like a file; only genuine directories are recursed into.
            if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                shutil.rmtree(entry_path)
            else:
                # Covers files, symlinks (also dangling ones) and special
                # files, so "Deleted" is only printed after a real removal.
                os.remove(entry_path)
            print(f" - Deleted {entry_path}")
        except PermissionError:
            print(
                f"PermissionError: Unable to remove {entry_path}. It is in use by another process."
            )
            continue
class Postgres:
    """Handler that drops or reports on the Postgres tables behind each store."""

    # Table names handled for each store type; they are addressed inside
    # ``self.schema``.
    tables: ClassVar[dict[str, list[str]]] = {
        "nodestore": ["data_docstore", "data_indexstore"],
        "vectorstore": ["data_embeddings"],
    }

    def __init__(self) -> None:
        """Connect to Postgres using the ``postgres`` section of the settings.

        Raises:
            ModuleNotFoundError: If ``psycopg2`` is not installed.
        """
        try:
            import psycopg2
        except ModuleNotFoundError:
            raise ModuleNotFoundError("Postgres dependencies not found") from None

        connection = settings().postgres.model_dump(exclude_none=True)
        # The schema name is kept on the handler and removed from the keyword
        # arguments before they are handed to ``psycopg2.connect``.
        self.schema = connection.pop("schema_name")
        self.conn = psycopg2.connect(**connection)

    def wipe(self, storetype: str) -> None:
        """Drop every table of ``storetype`` and commit the change.

        Raises:
            KeyError: If ``storetype`` is not a key of ``tables``.
        """
        # The cursor context manager closes the cursor on exit, including
        # when a statement fails.
        with self.conn.cursor() as cur:
            for table in self.tables[storetype]:
                sql = f"DROP TABLE IF EXISTS {self.schema}.{table}"
                cur.execute(sql)
                print(f"Table {self.schema}.{table} dropped.")
            self.conn.commit()

    def stats(self, store_type: str) -> None:
        """Print the row count and on-disk size of every table of ``store_type``.

        Tables are addressed as ``<schema>.<table>``, the same way ``wipe``
        addresses them, so the configured schema is inspected regardless of
        the connection's search path. The first output column still shows
        the bare table name.

        Raises:
            KeyError: If ``store_type`` is not a key of ``tables``.
        """
        template = (
            "SELECT '{table}', COUNT(*), "
            "pg_size_pretty(pg_total_relation_size('{schema}.{table}')) "
            "FROM {schema}.{table}"
        )
        sql = " UNION ALL ".join(
            template.format(schema=self.schema, table=tbl)
            for tbl in self.tables[store_type]
        )

        with self.conn.cursor() as cur:
            print(f"Storage for Postgres {store_type}.")
            print("{:<15} | {:>15} | {:>9}".format("Table", "Rows", "Size"))
            print("-" * 45)  # Print a line separator

            cur.execute(sql)
            for row in cur.fetchall():
                formatted_row_count = f"{row[1]:,}"
                print(f"{row[0]:<15} | {formatted_row_count:>15} | {row[2]:>9}")

            print()

    def __del__(self) -> None:
        """Close the connection, if ``__init__`` got far enough to open one."""
        conn = getattr(self, "conn", None)
        if conn:
            conn.close()
class Simple:
    """Handler for the file-based ("simple") node store."""

    def wipe(self, store_type: str) -> None:
        """Delete the persisted docstore and index-store files.

        The file names are llama_index's ``DEFAULT_PERSIST_FNAME`` constants,
        resolved against ``local_data_path``.

        Raises:
            ValueError: If ``store_type`` is not ``"nodestore"``.
        """
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a wrong store type delete these files.
        if store_type != "nodestore":
            raise ValueError(
                f"Simple handler only supports 'nodestore', got '{store_type}'"
            )

        from llama_index.core.storage.docstore.types import (
            DEFAULT_PERSIST_FNAME as DOCSTORE,
        )
        from llama_index.core.storage.index_store.types import (
            DEFAULT_PERSIST_FNAME as INDEXSTORE,
        )

        for store in (DOCSTORE, INDEXSTORE):
            wipe_file(str((local_data_path / store).absolute()))
class Chroma:
    """Handler for the Chroma vector store kept under ``local_data_path``."""

    def wipe(self, store_type: str) -> None:
        """Empty the ``chroma_db`` directory below ``local_data_path``.

        Raises:
            ValueError: If ``store_type`` is not ``"vectorstore"``.
        """
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a wrong store type wipe this folder.
        if store_type != "vectorstore":
            raise ValueError(
                f"Chroma handler only supports 'vectorstore', got '{store_type}'"
            )
        wipe_tree(str((local_data_path / "chroma_db").absolute()))
class Qdrant:
    """Handler that drops or reports on the Qdrant vector-store collection."""

    # Name of the collection this handler operates on.
    COLLECTION = (
        "make_this_parameterizable_per_api_call"  # ?! see vector_store_component.py
    )

    def __init__(self) -> None:
        """Create a Qdrant client from the ``qdrant`` section of the settings.

        Raises:
            ImportError: If ``qdrant_client`` is not installed.
        """
        try:
            from qdrant_client import QdrantClient  # type: ignore
        except ImportError:
            raise ImportError("Qdrant dependencies not found") from None
        self.client = QdrantClient(**settings().qdrant.model_dump(exclude_none=True))

    def wipe(self, store_type: str) -> None:
        """Delete the collection; failures are printed instead of raised.

        Raises:
            ValueError: If ``store_type`` is not ``"vectorstore"``.
        """
        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, which would let a wrong store type drop the collection.
        if store_type != "vectorstore":
            raise ValueError(
                f"Qdrant handler only supports 'vectorstore', got '{store_type}'"
            )
        try:
            self.client.delete_collection(self.COLLECTION)
            print("Collection dropped successfully.")
        except Exception as e:
            # Deliberate best-effort: report the problem and carry on.
            print("Error dropping collection:", e)

    def stats(self, store_type: str) -> None:
        """Print point and vector counts for the collection.

        A count that is ``None`` or missing on the returned object is shown
        as ``n/a`` instead of failing while formatting. When the lookup
        raises ``ValueError`` or returns a falsy result, a "not found or
        empty" line is printed instead.
        """
        print(f"Storage for Qdrant {store_type}.")
        try:
            collection_data = self.client.get_collection(self.COLLECTION)
            if collection_data:
                # Collection Info
                # https://qdrant.tech/documentation/concepts/collections/
                for label, field in (
                    ("Points", "points_count"),
                    ("Vectors", "vectors_count"),
                    ("Index Vectors", "indexed_vectors_count"),
                ):
                    # NOTE(review): these counts appear to be optional in the
                    # client model — confirm against the installed version.
                    count = getattr(collection_data, field, None)
                    shown = "n/a" if count is None else f"{count:,}"
                    print(f"\t{label}: {shown}")
                return
        except ValueError:
            pass
        print("\t- Qdrant collection not found or empty")
class Command:
    """Run a maintenance command against the handler of each configured store."""

    # Maps a store's ``database`` setting to the class that manages it.
    DB_HANDLERS: ClassVar[dict[str, Any]] = {
        "simple": Simple,  # node store
        "chroma": Chroma,  # vector store
        "postgres": Postgres,  # node, index and vector store
        "qdrant": Qdrant,  # vector store
    }

    def for_each_store(self, cmd: str):
        """Invoke ``cmd`` on the handler of the node store, then the vector store.

        A store whose database has no registered handler, or whose handler
        has no callable attribute named ``cmd``, is reported and skipped.
        """
        for store_type in ("nodestore", "vectorstore"):
            database = getattr(settings(), store_type).database

            handler_factory = self.DB_HANDLERS.get(database)
            if handler_factory is None:
                print(f"No handler found for database '{database}'")
                continue

            handler = handler_factory()
            # A missing attribute resolves to None, which is not callable, so
            # one check covers both "absent" and "not a method".
            action = getattr(handler, cmd, None)
            if not callable(action):
                print(
                    f"Unable to execute command '{cmd}' on '{store_type}' in database '{database}'"
                )
                continue

            action(store_type)

    def execute(self, cmd: str) -> None:
        """Run ``cmd`` for every store; unknown commands are ignored."""
        if cmd not in ("wipe", "stats"):
            return
        self.for_each_store(cmd)
if __name__ == "__main__":
    # Script entry point: the single positional argument selects the command
    # that is run against every configured store.
    cli = argparse.ArgumentParser()
    cli.add_argument("mode", help="select a mode to run", choices=["wipe", "stats"])
    selected_mode = cli.parse_args().mode.lower()
    Command().execute(selected_mode)