from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import inspect
from fastapi import HTTPException
from typing import AsyncGenerator
import asyncio
import logging

from app.db.models import Base

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Use SQLite with aiosqlite and connection pooling
DATABASE_URL = "sqlite+aiosqlite:///./sql_app.db"

# Configure the engine with connection pooling and timeouts.
# Note: plain QueuePool cannot be used with an async engine; create_async_engine
# picks an async-compatible queue pool by default, so no poolclass override is needed.
engine = create_async_engine(
    DATABASE_URL,
    echo=True,
    pool_size=20,        # Maximum number of connections kept in the pool
    max_overflow=10,     # Extra connections allowed beyond pool_size
    pool_timeout=30,     # Seconds to wait for a connection from the pool
    pool_recycle=1800,   # Recycle connections after 30 minutes
    pool_pre_ping=True,  # Check connection health before handing it out
)
# Session factory for async database sessions
AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Semaphore to limit concurrent database operations
MAX_CONCURRENT_DB_OPS = 10
db_semaphore = asyncio.Semaphore(MAX_CONCURRENT_DB_OPS)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    async with db_semaphore:  # Limit concurrent database operations
        session = AsyncSessionLocal()
        try:
            yield session
        except HTTPException:
            # Let HTTP errors raised by route handlers propagate unchanged
            await session.rollback()
            raise
        except SQLAlchemyError as e:
            logger.error(f"Database error: {str(e)}")
            await session.rollback()
            raise HTTPException(
                status_code=503,
                detail="Database service temporarily unavailable. Please try again.",
            )
        except Exception as e:
            logger.error(f"Unexpected error: {str(e)}")
            await session.rollback()
            raise HTTPException(
                status_code=500,
                detail="An unexpected error occurred. Please try again.",
            )
        finally:
            await session.close()
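
# --- Example usage (illustrative sketch, not part of this module) ---
# A minimal route showing how get_db would be consumed via Depends().
# The "Invoice" model and the "/invoices/{invoice_id}" path are hypothetical
# names used only for this example; substitute your own models and routes.
#
# from fastapi import APIRouter, Depends
# from sqlalchemy import select
# from app.db.models import Invoice  # hypothetical model
#
# router = APIRouter()
#
# @router.get("/invoices/{invoice_id}")
# async def read_invoice(invoice_id: int, db: AsyncSession = Depends(get_db)):
#     result = await db.execute(select(Invoice).where(Invoice.id == invoice_id))
#     invoice = result.scalar_one_or_none()
#     if invoice is None:
#         raise HTTPException(status_code=404, detail="Invoice not found")
#     return invoice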
# Rate limiting configuration
from fastapi import Request
import time
from collections import defaultdict

class RateLimiter:
    def __init__(self, requests_per_minute=60):
        self.requests_per_minute = requests_per_minute
        self.requests = defaultdict(list)

    def is_allowed(self, client_ip: str) -> bool:
        now = time.time()
        minute_ago = now - 60
        # Drop requests older than one minute
        self.requests[client_ip] = [
            req_time for req_time in self.requests[client_ip] if req_time > minute_ago
        ]
        # Reject if the per-minute budget is already used up
        if len(self.requests[client_ip]) >= self.requests_per_minute:
            return False
        # Record the new request
        self.requests[client_ip].append(now)
        return True

rate_limiter = RateLimiter(requests_per_minute=60)
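
# --- Example usage (illustrative sketch, not part of this module) ---
# One way to apply the limiter is a FastAPI dependency that reads the client IP
# from the Request and rejects over-limit calls with HTTP 429. The dependency
# name "enforce_rate_limit" is hypothetical.
#
# from fastapi import Depends
#
# async def enforce_rate_limit(request: Request) -> None:
#     client_ip = request.client.host if request.client else "unknown"
#     if not rate_limiter.is_allowed(client_ip):
#         raise HTTPException(status_code=429, detail="Too many requests. Please slow down.")
#
# @router.get("/invoices", dependencies=[Depends(enforce_rate_limit)])
# async def list_invoices(db: AsyncSession = Depends(get_db)):
#     ...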
# Add this to your database initialization
async def init_db():
    try:
        async with engine.begin() as conn:
            # Check if the database already has tables
            if not await database_exists(conn):
                logger.info("Initializing database: Creating tables...")
                await conn.run_sync(Base.metadata.create_all)
                logger.info("Database initialization completed successfully")
            else:
                logger.info("Database already exists. Skipping initialization.")
    except Exception as e:
        logger.error(f"Error initializing database: {str(e)}")
        raise
async def database_exists(conn) -> bool:
    """Check whether the expected tables already exist in the database."""
    try:
        result = await conn.run_sync(
            lambda sync_conn: inspect(sync_conn).has_table("invoices")
        )
        return result
    except Exception as e:
        logger.error(f"Error checking database existence: {str(e)}")
        return False
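
# --- Example usage (illustrative sketch, not part of this module) ---
# init_db() is typically awaited once at application startup, e.g. from a
# FastAPI lifespan handler. The app object and handler below are hypothetical.
#
# from contextlib import asynccontextmanager
# from fastapi import FastAPI
#
# @asynccontextmanager
# async def lifespan(app: FastAPI):
#     await init_db()          # create tables on first run
#     yield
#     await engine.dispose()   # release pooled connections on shutdown
#
# app = FastAPI(lifespan=lifespan)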