Spaces:
Sleeping
Sleeping
| """ | |
| Database module for PostgreSQL integration (Neon/Vercel Postgres). | |
| This module provides a clean abstraction layer for database operations, | |
| supporting both in-memory storage (for development/testing) and PostgreSQL | |
| via Neon (for production). | |
| Environment Variables: | |
| DATABASE_URL, POSTGRES_URL or POSTGRES_URL_NON_POOLING: PostgreSQL connection string (first non-empty one wins) | |
| TESTING: Set to "true" to force in-memory database | |
| Usage: | |
| from database import db | |
| # Create user | |
| user = await db.create_user(email, password_hash, full_name) | |
| # Get user | |
| user = await db.get_user_by_email(email) | |
| """ | |
| import os | |
| import sys | |
| from datetime import datetime, timezone | |
| from typing import Optional | |
| from abc import ABC, abstractmethod | |
| from urllib.parse import urlparse, parse_qs | |
| import asyncio | |
| import threading | |
| import asyncio | |
| import threading | |
| import time | |
| import logging | |
# Configure logger
logger = logging.getLogger(__name__)

# pg8000 for PostgreSQL - pure Python, lightweight.
# The import is optional: when it fails, PG8000_AVAILABLE stays False and
# create_database() falls back to the in-memory backend instead of crashing
# at import time.
PG8000_AVAILABLE = False
pg8000 = None
try:
    import pg8000
    import pg8000.native  # also required; a failure here keeps PG8000_AVAILABLE False
    PG8000_AVAILABLE = True
except ImportError as e:
    logger.warning("pg8000 not available: %s", e)
except Exception as e:
    # Deliberately broad: any failure while importing the optional driver
    # should degrade to "unavailable" rather than break module import.
    logger.error("Error loading pg8000: %s", e)
# ============= Abstract Database Interface =============
class DatabaseInterface(ABC):
    """Abstract interface for database operations.

    Every operation is a coroutine, so callers ``await`` it regardless of the
    backend. All methods are abstract: a subclass must implement each one
    before it can be instantiated.
    """

    @abstractmethod
    async def init_tables(self) -> None:
        """Initialize database tables (backends without a schema may no-op)."""

    # User operations
    @abstractmethod
    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Create a user and return the stored user record as a dict."""

    @abstractmethod
    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user with ``email``, or None if there is none."""

    @abstractmethod
    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user with ``user_id``, or None if there is none."""

    @abstractmethod
    async def update_user_last_login(self, email: str) -> None:
        """Set the user's ``last_login`` to the current time."""

    @abstractmethod
    async def delete_user(self, user_id: int) -> bool:
        """Delete the user; return True if a user was actually removed."""

    @abstractmethod
    async def list_users(self) -> list[dict]:
        """Return all users as a list of dicts."""

    # Hotel operations
    @abstractmethod
    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Create a hotel owned by ``owner_id`` and return the stored record."""

    @abstractmethod
    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotels visible to ``owner_id``.

        With ``owner_id`` given, the result is that owner's hotels plus those
        with no owner. Without it, the result is every hotel.
        """

    @abstractmethod
    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Return the hotel with ``hotel_id``, or None if there is none."""

    @abstractmethod
    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete a hotel; return True if one was removed.

        When ``owner_id`` is given, only a hotel owned by that user is deleted.
        """
# ============= In-Memory Database (Development/Testing) =============
class InMemoryDatabase(DatabaseInterface):
    """In-memory database for development and testing.

    Users are stored in a dict keyed by email. Hotels are stored in a list in
    insertion order. IDs are assigned from per-table counters starting at 1.
    Nothing is persisted: all data lives only as long as this instance.
    """

    def __init__(self):
        self._users: dict[str, dict] = {}  # email -> user record
        self._hotels: list[dict] = []
        self._user_id_counter = 0
        self._hotel_id_counter = 0

    def _store_user(self, email: str, password_hash: str, full_name: str,
                    is_admin: bool, totp_secret: Optional[str]) -> dict:
        """Assign the next user id, store the record under ``email`` and return it.

        An existing user with the same email is silently replaced. This differs
        from the PostgreSQL backend, whose ``users.email`` column is UNIQUE.
        """
        self._user_id_counter += 1
        user = {
            "id": self._user_id_counter,
            "email": email,
            "password_hash": password_hash,
            "full_name": full_name,
            "is_admin": is_admin,
            "totp_secret": totp_secret,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "last_login": None,
        }
        self._users[email] = user
        return user

    async def init_tables(self) -> None:
        """No-op for in-memory database"""

    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Create a user and return the stored record (the same dict object)."""
        return self._store_user(email, password_hash, full_name, is_admin, totp_secret)

    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user stored under ``email``, or None."""
        return self._users.get(email)

    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user whose id equals ``user_id``, or None (linear scan)."""
        return next((user for user in self._users.values() if user["id"] == user_id), None)

    async def update_user_last_login(self, email: str) -> None:
        """Set ``last_login`` to the current UTC time; unknown emails are ignored."""
        user = self._users.get(email)
        if user is not None:
            user["last_login"] = datetime.now(timezone.utc).isoformat()

    async def delete_user(self, user_id: int) -> bool:
        """Delete the user with ``user_id``; return True if one was removed.

        Hotels owned by that user are removed as well. This mirrors the
        ``owner_id ... ON DELETE CASCADE`` foreign key in the PostgreSQL schema,
        so both backends behave the same.
        """
        email = next(
            (key for key, user in self._users.items() if user["id"] == user_id), None
        )
        if email is None:
            return False
        del self._users[email]
        # In-place slice assignment keeps the same list object.
        self._hotels[:] = [h for h in self._hotels if h.get("owner_id") != user_id]
        return True

    async def list_users(self) -> list[dict]:
        """Return all users in insertion order."""
        return list(self._users.values())

    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Create a hotel owned by ``owner_id`` and return the stored record."""
        self._hotel_id_counter += 1
        hotel = {
            "id": self._hotel_id_counter,
            "name": name,
            "owner_id": owner_id,
            "booking_url": booking_url,
            "website_url": website_url,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        self._hotels.append(hotel)
        return hotel

    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotels visible to ``owner_id``.

        With a truthy ``owner_id``, the result is that owner's hotels plus public
        hotels (``owner_id`` of None). Otherwise it is every hotel. A new list
        is always returned, so callers cannot reorder or shrink internal storage.
        """
        if owner_id:
            return [h for h in self._hotels
                    if h.get("owner_id") == owner_id or h.get("owner_id") is None]
        return list(self._hotels)

    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Return the hotel with ``hotel_id``, or None."""
        return next((hotel for hotel in self._hotels if hotel["id"] == hotel_id), None)

    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete a hotel; return True if it was removed.

        With a truthy ``owner_id``, the hotel is only deleted if it belongs to
        that owner. Public hotels (owner None) therefore cannot be deleted this way.
        """
        for index, hotel in enumerate(self._hotels):
            if hotel["id"] == hotel_id:
                if owner_id and hotel.get("owner_id") != owner_id:
                    return False
                self._hotels.pop(index)  # safe: we return immediately after mutating
                return True
        return False

    def seed_admin_users(self, admin_configs: list[dict]) -> None:
        """Seed admin users for development.

        Each config needs "email", "password_hash" and "full_name", and may
        include "totp_secret".
        """
        for config in admin_configs:
            self._store_user(config["email"], config["password_hash"],
                             config["full_name"], True, config.get("totp_secret"))

    def seed_regular_users(self, user_configs: list[dict]) -> None:
        """Seed regular (non-admin) users for development; config keys as for admins."""
        for config in user_configs:
            # Regular users are NOT admins
            self._store_user(config["email"], config["password_hash"],
                             config["full_name"], False, config.get("totp_secret"))

    def seed_demo_hotels(self, hotels: list[dict]) -> None:
        """Seed demo hotels for development (no owner, so visible to every user)."""
        for hotel in hotels:
            self._hotel_id_counter += 1
            self._hotels.append({
                "id": self._hotel_id_counter,
                "name": hotel["name"],
                "owner_id": None,  # Publicly visible demo hotel
                "booking_url": hotel.get("booking_url"),
                "website_url": hotel.get("website_url"),
                "created_at": datetime.now(timezone.utc).isoformat(),
            })
# ============= PostgreSQL Database (Production - Neon/Vercel) =============
class PostgresDatabase(DatabaseInterface):
    """PostgreSQL database using pg8000 (pure Python, works with Neon, Vercel Postgres, etc.)

    All queries share a single, lazily opened pg8000 connection in autocommit
    mode. ``self._lock`` serializes access to that connection. The blocking
    pg8000 calls run in the event loop's default executor (see
    ``_execute_query``).
    """

    def __init__(self, connection_url: str):
        """
        Initialize with a PostgreSQL connection URL.

        Format: postgresql://user:password@host:port/database?sslmode=require

        No connection is opened here; ``_get_connection`` opens it on first use.
        """
        self.connection_url = connection_url
        self._conn = None
        # Created up front. The previous lazy "hasattr then assign" pattern,
        # run from several executor threads, could race and hand two threads
        # different locks for the same shared connection.
        self._lock = threading.Lock()
        self._parse_connection_url()

    def _parse_connection_url(self):
        """Parse the connection URL into host, port, user, password, database and SSL flag.

        Missing parts default to localhost:5432, user "postgres", an empty
        password and database "postgres". Username and password are
        percent-decoded. SSL is enabled when ``sslmode`` is present and not
        "disable", and always for hosts containing "neon" or "vercel".
        """
        from urllib.parse import unquote
        parsed = urlparse(self.connection_url)
        self._host = parsed.hostname or 'localhost'
        self._port = parsed.port or 5432
        self._user = unquote(parsed.username or 'postgres')
        self._password = unquote(parsed.password or '')
        self._database = parsed.path.lstrip('/').split('?')[0] or 'postgres'
        # Parse query params for SSL
        query_params = parse_qs(parsed.query)
        self._ssl = 'sslmode' in query_params and query_params['sslmode'][0] != 'disable'
        # Neon and Vercel always require SSL
        if 'neon' in (self._host or '') or 'vercel' in (self._host or ''):
            self._ssl = True

    def _get_connection(self):
        """Return the cached connection, opening a new autocommit one if needed.

        The only caller, ``_run_query_in_thread``, holds ``self._lock``.
        Connection errors are logged and re-raised.
        """
        if self._conn is None:
            try:
                ssl_context = None
                if self._ssl:
                    import ssl
                    ssl_context = ssl.create_default_context()
                # pg8000.connect is the correct API
                conn = pg8000.connect(
                    host=self._host,
                    port=self._port,
                    user=self._user,
                    password=self._password,
                    database=self._database,
                    ssl_context=ssl_context,
                )
                conn.autocommit = True
            except Exception as e:
                logger.error("Failed to connect to database: %s", e)
                raise
            # Only cache a fully configured connection.
            self._conn = conn
        return self._conn

    def _discard_connection(self):
        """Best-effort close of the cached connection, then forget it."""
        if self._conn is not None:
            try:
                self._conn.close()
            except Exception:
                # The connection is already suspected broken; failing to close
                # it cleanly is expected and not actionable.
                pass
        self._conn = None

    def _execute_once(self, query: str, params: Optional[tuple]) -> list[dict]:
        """Execute ``query`` once on the current connection and fetch any rows.

        Returns one dict per row (column name -> value), or an empty list for
        statements without a result set.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            # pg8000 uses %s style parameters just like psycopg2
            if params is not None:
                cursor.execute(query, params)
            else:
                cursor.execute(query)
            # Only statements that produce a result set have a description.
            try:
                description = cursor.description
                if description:
                    columns = [col[0] for col in description]
                    return [dict(zip(columns, row)) for row in cursor.fetchall()]
            except (TypeError, AttributeError):
                # DDL statements don't return rows
                pass
            return []
        finally:
            cursor.close()

    def _run_query_in_thread(self, query: str, params: Optional[tuple] = None) -> list[dict]:
        """Run ``query`` synchronously under the lock, retrying once on failure.

        Meant to run in a worker thread (see ``_execute_query``). On a pg8000
        InterfaceError, DatabaseError or an AttributeError, the connection is
        discarded and the query is retried on a fresh connection after 0.5s.
        The error from the final attempt is re-raised.
        """
        max_retries = 2
        with self._lock:
            for attempt in range(1, max_retries + 1):
                try:
                    return self._execute_once(query, params)
                except (pg8000.InterfaceError, pg8000.DatabaseError, AttributeError) as e:
                    # Connection might be dead, close and retry
                    logger.warning("Database error (attempt %d/%d): %s",
                                   attempt, max_retries, e)
                    self._discard_connection()
                    if attempt == max_retries:
                        raise
                    time.sleep(0.5)
        # Not reached: each attempt either returns or (on the last one) raises.
        return []

    async def _execute_query(self, query: str, params: Optional[tuple] = None) -> list[dict]:
        """Execute a SQL query asynchronously using the default thread pool executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run_query_in_thread, query, params)

    async def init_tables(self) -> None:
        """Create the users, hotels and price_comparisons tables if they do not exist."""
        # Users table
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS users (
                id SERIAL PRIMARY KEY,
                email VARCHAR(255) UNIQUE NOT NULL,
                password_hash VARCHAR(512) NOT NULL,
                full_name VARCHAR(255) NOT NULL,
                is_admin BOOLEAN DEFAULT FALSE,
                totp_secret VARCHAR(64),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_login TIMESTAMP
            )
        """)
        # Hotels table
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS hotels (
                id SERIAL PRIMARY KEY,
                name VARCHAR(255) NOT NULL,
                booking_url TEXT,
                website_url TEXT,
                owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Add owner_id column if it doesn't exist (for migration)
        await self._execute_query("""
            ALTER TABLE hotels ADD COLUMN IF NOT EXISTS owner_id INTEGER REFERENCES users(id) ON DELETE CASCADE
        """)
        # Price comparisons table (for history)
        await self._execute_query("""
            CREATE TABLE IF NOT EXISTS price_comparisons (
                id SERIAL PRIMARY KEY,
                user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
                hotel_ids INTEGER[] NOT NULL,
                check_in DATE NOT NULL,
                check_out DATE NOT NULL,
                occupancy VARCHAR(50),
                results JSONB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Create index for faster email lookups
        await self._execute_query("""
            CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)
        """)

    async def create_user(self, email: str, password_hash: str, full_name: str,
                          is_admin: bool = False, totp_secret: Optional[str] = None) -> dict:
        """Insert a user and return the new row ({} if nothing was returned).

        Timestamps are returned as text. A duplicate email violates the UNIQUE
        constraint and raises from the driver.
        """
        rows = await self._execute_query(
            """
            INSERT INTO users (email, password_hash, full_name, is_admin, totp_secret)
            VALUES (%s, %s, %s, %s, %s)
            RETURNING id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
            """,
            (email, password_hash, full_name, is_admin, totp_secret)
        )
        return rows[0] if rows else {}

    async def get_user_by_email(self, email: str) -> Optional[dict]:
        """Return the user with ``email`` (timestamps as text), or None."""
        rows = await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users WHERE email = %s""",
            (email,)
        )
        return rows[0] if rows else None

    async def get_user_by_id(self, user_id: int) -> Optional[dict]:
        """Return the user with ``user_id`` (timestamps as text), or None."""
        rows = await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users WHERE id = %s""",
            (user_id,)
        )
        return rows[0] if rows else None

    async def update_user_last_login(self, email: str) -> None:
        """Set ``last_login`` to the database's CURRENT_TIMESTAMP."""
        await self._execute_query(
            "UPDATE users SET last_login = CURRENT_TIMESTAMP WHERE email = %s",
            (email,)
        )

    async def delete_user(self, user_id: int) -> bool:
        """Delete the user; owned hotels are removed by the schema's ON DELETE CASCADE."""
        rows = await self._execute_query(
            "DELETE FROM users WHERE id = %s RETURNING id",
            (user_id,)
        )
        return len(rows) > 0

    async def list_users(self) -> list[dict]:
        """Return all users, newest first."""
        return await self._execute_query(
            """SELECT id, email, password_hash, full_name, is_admin, totp_secret,
                      created_at::text, last_login::text
               FROM users ORDER BY created_at DESC"""
        )

    async def create_hotel(self, name: str, owner_id: int, booking_url: Optional[str] = None,
                           website_url: Optional[str] = None) -> dict:
        """Insert a hotel and return the new row ({} if nothing was returned)."""
        rows = await self._execute_query(
            """
            INSERT INTO hotels (name, owner_id, booking_url, website_url)
            VALUES (%s, %s, %s, %s)
            RETURNING id, name, owner_id, booking_url, website_url, created_at::text
            """,
            (name, owner_id, booking_url, website_url)
        )
        return rows[0] if rows else {}

    async def get_hotels(self, owner_id: Optional[int] = None) -> list[dict]:
        """Return hotels ordered by name.

        With a truthy ``owner_id``, the result is that owner's hotels plus
        ownerless ones. Otherwise it is every hotel.
        """
        if owner_id:
            return await self._execute_query(
                """SELECT id, name, owner_id, booking_url, website_url, created_at::text
                   FROM hotels WHERE owner_id = %s OR owner_id IS NULL ORDER BY name""",
                (owner_id,)
            )
        return await self._execute_query(
            "SELECT id, name, owner_id, booking_url, website_url, created_at::text FROM hotels ORDER BY name"
        )

    async def get_hotel(self, hotel_id: int) -> Optional[dict]:
        """Get a single hotel by ID, or None if it does not exist."""
        rows = await self._execute_query(
            """SELECT id, name, owner_id, booking_url, website_url, created_at::text
               FROM hotels WHERE id = %s""",
            (hotel_id,)
        )
        return rows[0] if rows else None

    async def delete_hotel(self, hotel_id: int, owner_id: Optional[int] = None) -> bool:
        """Delete a hotel, optionally verifying ownership; return True if a row was deleted."""
        if owner_id:
            rows = await self._execute_query(
                "DELETE FROM hotels WHERE id = %s AND owner_id = %s RETURNING id",
                (hotel_id, owner_id)
            )
        else:
            rows = await self._execute_query(
                "DELETE FROM hotels WHERE id = %s RETURNING id",
                (hotel_id,)
            )
        return len(rows) > 0

    def close(self):
        """Close the database connection, if one is open.

        The cached reference is cleared even if ``close`` raises, so a later
        query reconnects instead of reusing a half-closed connection.
        """
        if self._conn:
            try:
                self._conn.close()
            finally:
                self._conn = None
# ============= Database Factory =============
def create_database() -> DatabaseInterface:
    """
    Create the appropriate database instance based on environment.

    The connection URL is the first non-empty value of:
        - DATABASE_URL: Standard PostgreSQL connection string
        - POSTGRES_URL: Vercel/Neon Postgres connection string
        - POSTGRES_URL_NON_POOLING: fallback PostgreSQL connection string

    Returns InMemoryDatabase when TESTING is "true" (case-insensitive), when
    no connection URL is configured, or when pg8000 could not be imported
    (a warning is logged in that last case).
    Returns PostgresDatabase otherwise.
    """
    # Check for database URL (multiple common env var names)
    postgres_url = (
        os.environ.get("DATABASE_URL") or
        os.environ.get("POSTGRES_URL") or
        os.environ.get("POSTGRES_URL_NON_POOLING") or
        ""
    )
    testing = os.environ.get("TESTING", "false").lower() == "true"

    # Use in-memory database for testing or when no database URL is configured
    if testing or not postgres_url:
        return InMemoryDatabase()

    # Check if pg8000 is available
    if not PG8000_AVAILABLE:
        logger.warning("pg8000 not available, falling back to in-memory database")
        return InMemoryDatabase()

    return PostgresDatabase(postgres_url)


# Singleton database instance, created once at import time from the current
# environment (see create_database for how the backend is chosen).
db: DatabaseInterface = create_database()