"""
Supabase/PostgreSQL database utilities shared by all MCP tools.
This module provides:
1. Direct PostgreSQL connections (via psycopg2) for pgvector operations
2. A Supabase client for REST-style administrative needs
"""
from __future__ import annotations
import os
from typing import Optional, List, Dict, Any
import psycopg2
import psycopg2.extras
from dotenv import load_dotenv
from supabase import Client, create_client
# Load environment variables
load_dotenv()
# -----------------------------------
# Environment variables
# -----------------------------------
DATABASE_URL = os.getenv("POSTGRESQL_URL") # Direct PostgreSQL connection
SUPABASE_URL = os.getenv("SUPABASE_URL")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY") # MUST be service role key
# Global Supabase client instance
_supabase_client: Optional[Client] = None
# -----------------------------------
# PostgreSQL Connection (for pgvector)
# -----------------------------------
def get_connection():
"""
Establish a direct PostgreSQL connection for pgvector operations.
"""
if not DATABASE_URL:
raise ValueError(
"PostgreSQL connection string not configured. "
"Set POSTGRESQL_URL in your .env file."
)
return psycopg2.connect(DATABASE_URL)
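
# A minimal usage sketch for a hypothetical caller. Note that psycopg2
# connections support the context-manager protocol: `with conn` commits or
# rolls back the enclosed transaction on exit, but does not close the
# connection itself, so close it explicitly.
#
#     conn = get_connection()
#     try:
#         with conn, conn.cursor() as cur:
#             cur.execute("SELECT COUNT(*) FROM documents;")
#             (count,) = cur.fetchone()
#     finally:
#         conn.close()
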
# -----------------------------------
# Database Schema Initialization
# -----------------------------------
def initialize_database():
"""
Initialize the database schema:
- Enable pgvector extension
- Create documents table with vector support
"""
try:
conn = get_connection()
cur = conn.cursor()
# Enable pgvector extension
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
print("βœ… pgvector extension enabled")
# Create documents table
cur.execute(
"""
CREATE TABLE IF NOT EXISTS documents (
id BIGSERIAL PRIMARY KEY,
tenant_id TEXT NOT NULL,
chunk_text TEXT NOT NULL,
embedding vector(384) NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
"""
)
print("βœ… documents table created")
    # Create an IVFFlat index for approximate vector similarity search.
    # Per the pgvector docs, lists ~ rows/1000 is a reasonable starting
    # point for tables of up to ~1M rows.
cur.execute(
"""
CREATE INDEX IF NOT EXISTS documents_embedding_idx
ON documents
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
"""
)
print("βœ… vector index created")
# Create index for tenant_id for faster filtering
cur.execute(
"""
CREATE INDEX IF NOT EXISTS documents_tenant_id_idx
ON documents (tenant_id);
"""
)
print("βœ… tenant_id index created")
conn.commit()
cur.close()
conn.close()
print("βœ… Database schema initialized successfully")
except Exception as e:
print(f"❌ Database initialization error: {e}")
        # Tolerate "already exists" errors so repeated startups are
        # harmless; re-raise anything else.
if "already exists" not in str(e).lower():
raise
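
# Hypothetical startup wiring: this module never calls initialize_database()
# itself, so run it once from the application entry point before serving
# requests, e.g.:
#
#     if __name__ == "__main__":
#         initialize_database()
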
# -----------------------------------
# Document + Embedding Operations
# -----------------------------------
def insert_document_chunks(tenant_id: str, text: str, embedding: list):
"""
Insert document chunk + embedding.
"""
try:
# Normalize tenant_id to ensure consistency
tenant_id = tenant_id.strip()
conn = get_connection()
cur = conn.cursor()
cur.execute(
"""
INSERT INTO documents (tenant_id, chunk_text, embedding)
VALUES (%s, %s, %s);
""",
(tenant_id, text, embedding),
)
conn.commit()
cur.close()
conn.close()
except Exception as e:
print("DB INSERT ERROR:", e)
raise
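
# A minimal usage sketch with stand-in values: the embedding must be 384-dim
# to match the vector(384) column (e.g. the output of a MiniLM
# sentence-transformer model).
#
#     insert_document_chunks(
#         tenant_id="acme",
#         text="First chunk of the uploaded document.",
#         embedding=[0.1] * 384,  # placeholder; use a real embedding
#     )
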
def search_vectors(tenant_id: str, vector: list, limit: int = 5) -> List[Dict[str, Any]]:
"""
Perform semantic vector search using pgvector.
Results are filtered by tenant_id to ensure data isolation.
"""
try:
# Validate tenant_id
if not tenant_id or not tenant_id.strip():
print("DB SEARCH ERROR: tenant_id is empty")
return []
tenant_id = tenant_id.strip()
conn = get_connection()
cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
# Query with explicit tenant_id filtering
cur.execute(
"""
SELECT
chunk_text,
tenant_id,
1 - (embedding <=> %s::vector(384)) AS similarity
FROM documents
WHERE tenant_id = %s
ORDER BY embedding <=> %s::vector(384)
LIMIT %s;
""",
(vector, tenant_id, vector, limit),
)
rows = cur.fetchall()
        # Defense-in-depth: the WHERE clause already filters by tenant_id,
        # but double-check each row before returning it.
results: List[Dict[str, Any]] = []
for row in rows:
row_tenant_id = row.get("tenant_id", "")
if row_tenant_id != tenant_id:
print(
f"WARNING: Found document with tenant_id '{row_tenant_id}' when searching for '{tenant_id}' - skipping"
)
continue
results.append(
{
"text": row["chunk_text"],
"similarity": float(row.get("similarity", 0.0)),
}
)
cur.close()
conn.close()
return results
except Exception as e:
print(f"DB SEARCH ERROR (tenant_id={tenant_id}): {e}")
import traceback
traceback.print_exc()
return []
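
# A minimal usage sketch (hypothetical values): embed the query with the same
# model used at insert time, then rank by cosine similarity.
#
#     query_embedding = embed("What is the refund policy?")  # hypothetical helper
#     for hit in search_vectors("acme", query_embedding, limit=3):
#         print(f"{hit['similarity']:.3f}  {hit['text'][:80]}")
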
def list_all_documents(
tenant_id: str, limit: int = 1000, offset: int = 0
) -> Dict[str, Any]:
"""
    List all documents for a tenant with pagination.

    Normalizes tenant_id so documents stored with stray whitespace still match.
"""
try:
# Normalize tenant_id to ensure consistency
tenant_id_normalized = tenant_id.strip()
conn = get_connection()
cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        # Fetch all distinct tenant_ids, then match them after normalization
cur.execute("SELECT DISTINCT tenant_id FROM documents;")
all_tenant_ids = [row[0] for row in cur.fetchall()]
# Find tenant_ids that match when normalized
matching_tenant_ids = []
for stored_tenant_id in all_tenant_ids:
if stored_tenant_id and stored_tenant_id.strip() == tenant_id_normalized:
matching_tenant_ids.append(stored_tenant_id)
if not matching_tenant_ids:
# No matching tenant_ids found
cur.close()
conn.close()
return {"documents": [], "total": 0, "limit": limit, "offset": offset}
# Build query to match any of the normalized tenant_ids
placeholders = ','.join(['%s'] * len(matching_tenant_ids))
cur.execute(
f"""
SELECT
id,
chunk_text,
created_at
FROM documents
WHERE tenant_id IN ({placeholders})
ORDER BY created_at DESC
LIMIT %s OFFSET %s;
""",
tuple(matching_tenant_ids) + (limit, offset),
)
rows = cur.fetchall()
        # Get total count for all matching tenant_ids (same placeholders)
cur.execute(
f"""
SELECT COUNT(*) as total
FROM documents
WHERE tenant_id IN ({placeholders});
""",
tuple(matching_tenant_ids),
)
total_row = cur.fetchone()
total = total_row["total"] if total_row else 0
cur.close()
conn.close()
results: List[Dict[str, Any]] = []
for row in rows:
results.append(
{
"id": row["id"],
"text": row["chunk_text"],
"created_at": row["created_at"].isoformat()
if row["created_at"]
else None,
}
)
return {
"documents": results,
"total": total,
"limit": limit,
"offset": offset,
}
except Exception as e:
print("DB LIST ERROR:", e)
return {"documents": [], "total": 0, "limit": limit, "offset": offset}
def delete_document(tenant_id: str, document_id: int) -> bool:
"""
Delete a specific document by ID for a tenant.
Returns True if document was deleted, False otherwise.
"""
try:
# Normalize tenant_id to ensure consistency
tenant_id = tenant_id.strip()
conn = get_connection()
cur = conn.cursor()
# First, verify the document exists
cur.execute(
"""
SELECT id, tenant_id FROM documents
WHERE id = %s;
""",
(document_id,),
)
doc_row = cur.fetchone()
if doc_row is None:
print(f"DB DELETE: Document {document_id} does not exist")
cur.close()
conn.close()
return False
        doc_tenant_id = doc_row[1]  # SELECT always returns (id, tenant_id)
# Normalize both tenant_ids for comparison (handle existing data with whitespace)
doc_tenant_id_normalized = doc_tenant_id.strip() if doc_tenant_id else None
tenant_id_normalized = tenant_id.strip()
        # Compare tenant_ids after normalization; on a match, delete using the
        # stored (un-normalized) value so the WHERE clause hits the row as stored.
        if doc_tenant_id_normalized == tenant_id_normalized:
cur.execute(
"""
DELETE FROM documents
WHERE id = %s AND tenant_id = %s;
""",
(document_id, doc_tenant_id),
)
deleted = cur.rowcount > 0
else:
# Tenant IDs don't match - log the mismatch
print(f"DB DELETE: Document {document_id} belongs to tenant '{doc_tenant_id}' (normalized: '{doc_tenant_id_normalized}'), not '{tenant_id}' (normalized: '{tenant_id_normalized}')")
print(f"DB DELETE: Tenant ID lengths - stored: {len(doc_tenant_id) if doc_tenant_id else 0}, requested: {len(tenant_id)}")
print(f"DB DELETE: Tenant ID repr - stored: {repr(doc_tenant_id)}, requested: {repr(tenant_id)}")
deleted = False
if deleted:
print(f"DB DELETE: Successfully deleted document {document_id} for tenant '{tenant_id}'")
else:
print(f"DB DELETE: Failed to delete document {document_id} for tenant '{tenant_id}' (rowcount: {cur.rowcount})")
conn.commit()
cur.close()
conn.close()
return deleted
except Exception as e:
print(f"DB DELETE ERROR (document_id={document_id}, tenant_id={tenant_id}): {e}")
import traceback
traceback.print_exc()
return False
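
# A minimal usage sketch: the boolean return distinguishes "deleted" from
# "missing or owned by another tenant" (both of the latter return False).
#
#     if delete_document("acme", document_id=42):
#         print("deleted")
#     else:
#         print("not found for this tenant")
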
def delete_all_documents(tenant_id: str) -> int:
"""
    Delete all documents for a tenant.

    Returns the number of documents deleted. Normalizes tenant_id so
    documents stored with stray whitespace are matched as well.
"""
try:
        # Normalize tenant_id once
        tenant_id_normalized = tenant_id.strip()
conn = get_connection()
cur = conn.cursor()
        # Fetch all distinct tenant_ids, then match them after normalization
cur.execute(
"""
SELECT DISTINCT tenant_id FROM documents;
"""
)
all_tenant_ids = [row[0] for row in cur.fetchall()]
# Find tenant_ids that match when normalized
matching_tenant_ids = []
for stored_tenant_id in all_tenant_ids:
if stored_tenant_id and stored_tenant_id.strip() == tenant_id_normalized:
matching_tenant_ids.append(stored_tenant_id)
if not matching_tenant_ids:
print(f"DB DELETE ALL: No documents found for tenant '{tenant_id}' (normalized: '{tenant_id_normalized}')")
cur.close()
conn.close()
return 0
# Delete documents matching any of the normalized tenant_ids
deleted_count = 0
for matching_tenant_id in matching_tenant_ids:
cur.execute(
"""
DELETE FROM documents
WHERE tenant_id = %s;
""",
(matching_tenant_id,),
)
deleted_count += cur.rowcount
print(f"DB DELETE ALL: Deleted {deleted_count} document(s) for tenant '{tenant_id}' (matched {len(matching_tenant_ids)} tenant_id variant(s))")
conn.commit()
cur.close()
conn.close()
return deleted_count
except Exception as e:
print(f"DB DELETE ALL ERROR (tenant_id={tenant_id}): {e}")
import traceback
traceback.print_exc()
return 0
# -----------------------------------
# Supabase Client (for REST operations)
# -----------------------------------
def get_supabase_client() -> Client:
"""
    Get or lazily create the module-level cached Supabase client.
"""
global _supabase_client
if _supabase_client is None:
if not SUPABASE_URL or not SUPABASE_KEY:
raise ValueError(
"Supabase credentials missing. "
"Set SUPABASE_URL and SUPABASE_SERVICE_KEY."
)
_supabase_client = create_client(SUPABASE_URL, SUPABASE_KEY)
return _supabase_client
def reset_client():
    """Reset the cached Supabase client (e.g. after credentials change)."""
global _supabase_client
_supabase_client = None
# Table names (logical name -> physical Supabase table name)
TABLES = {
"tenants": "tenants",
"documents": "documents",
"embeddings": "tenant_embeddings",
"redflag_rules": "redflag_rules",
"analytics": "analytics_events",
"tool_usage": "tool_usage_stats",
}
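
# A hypothetical REST-side sketch: pair the cached client with the TABLES map
# so physical table names stay in one place (assumes a `tenants` table exists
# in the Supabase project).
#
#     client = get_supabase_client()
#     response = client.table(TABLES["tenants"]).select("*").limit(10).execute()
#     print(response.data)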