Spaces:
Running
Running
| import os | |
| import jwt | |
| import hashlib | |
| from datetime import datetime, timedelta | |
| from fastapi import APIRouter, HTTPException, Depends, Header | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from pydantic import BaseModel | |
| from sqlalchemy.orm import Session | |
| from ..database import SessionLocal | |
| from ..config import settings | |
| from .. import crud | |
# Router object for this module's endpoints (route registration is not
# visible in this excerpt).
router = APIRouter()

# Bearer-token scheme; endpoints below receive it through Depends(security)
# and read the raw token from `.credentials`.
security = HTTPBearer()
| # Models | |
class AdminLoginRequest(BaseModel):
    """Request body for the admin login endpoint."""

    # Plain-text password submitted by the caller; checked in admin_login.
    password: str
class AdminLoginResponse(BaseModel):
    """Response body returned by a successful admin login."""

    # Signed JWT issued by create_admin_token.
    access_token: str
    # Scheme the client should use when sending the token back.
    token_type: str = "bearer"
    # ISO-8601 timestamp string computed by admin_login.
    expires_at: str
class AdminVerifyResponse(BaseModel):
    """Response body of the token verification endpoint."""

    # True when the presented token passed verify_admin_token.
    valid: bool
    # Human-readable explanation of the verification outcome.
    message: str
def get_db():
    """Yield a database session and close it once the consumer is done.

    Used as a dependency (``Depends(get_db)``) by the model-management
    endpoints in this module.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs even if the consumer raised, so the session is always closed.
        session.close()
def get_admin_password():
    """Return the admin password configured via ``ADMIN_PASSWORD``.

    Raises:
        HTTPException: 500 when the variable is unset or empty.
    """
    configured = os.getenv('ADMIN_PASSWORD')
    if configured:
        return configured

    # Unset and empty are treated the same: there is no usable password.
    raise HTTPException(
        status_code=500,
        detail="ADMIN_PASSWORD environment variable not set",
    )
def create_admin_token():
    """Create a signed JWT that identifies the bearer as an admin.

    The token carries ``role='admin'``, an issued-at time and an expiry
    24 hours later, signed with HS256.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    # SECURITY NOTE(review): when JWT_SECRET_KEY is unset this falls back to a
    # publicly known placeholder, which lets anyone forge admin tokens. The
    # fallback is kept for backward compatibility; deployments must set
    # JWT_SECRET_KEY.
    secret_key = os.getenv('JWT_SECRET_KEY', 'your-secret-key-change-in-production')

    # Take one timestamp so that `exp - iat` is exactly 24 hours instead of
    # being derived from two separate clock reads.
    issued_at = datetime.utcnow()
    payload = {
        'role': 'admin',
        'exp': issued_at + timedelta(hours=24),  # 24 hour expiry
        'iat': issued_at,
    }
    return jwt.encode(payload, secret_key, algorithm='HS256')
def verify_admin_token(token: str) -> bool:
    """Check that ``token`` is a valid, unexpired admin JWT.

    Args:
        token: Encoded JWT taken from the Authorization header.

    Returns:
        True when the signature is valid, the token has not expired and its
        ``role`` claim is ``'admin'``; False otherwise.
    """
    # SECURITY NOTE(review): same placeholder fallback as create_admin_token;
    # deployments must set JWT_SECRET_KEY.
    secret_key = os.getenv('JWT_SECRET_KEY', 'your-secret-key-change-in-production')
    try:
        # jwt.decode validates the signature and the `exp` claim itself and
        # raises ExpiredSignatureError for expired tokens. The previous manual
        # check compared UTC "now" with a local-time conversion of `exp`,
        # which rejected tokens early on hosts west of UTC.
        payload = jwt.decode(token, secret_key, algorithms=['HS256'])
    except jwt.InvalidTokenError:
        # Covers ExpiredSignatureError too (it is a subclass).
        return False
    return payload.get('role') == 'admin'
async def admin_login(request: AdminLoginRequest):
    """Admin login endpoint.

    Compares the submitted password with the configured admin password and,
    on success, issues a 24-hour admin token.

    Args:
        request: Login payload carrying the password.

    Returns:
        AdminLoginResponse with the token and an ISO-8601 expiry timestamp.

    Raises:
        HTTPException: 401 when the password does not match; 500 (from
            get_admin_password) when no admin password is configured.
    """
    import hmac  # local import: not part of the module's top-level imports

    admin_password = get_admin_password()

    # The password is stored and compared in plain text (no hashing).
    # compare_digest makes the comparison constant-time so response timing
    # does not leak how many leading characters matched. Encoding to bytes
    # avoids the TypeError compare_digest raises for non-ASCII str input.
    supplied = request.password.encode('utf-8')
    expected = admin_password.encode('utf-8')
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid admin password"
        )

    token = create_admin_token()
    # Computed independently of the token's own `exp` claim, so it may differ
    # from it by a few microseconds.
    expires_at = datetime.utcnow() + timedelta(hours=24)
    return AdminLoginResponse(
        access_token=token,
        expires_at=expires_at.isoformat()
    )
async def verify_admin(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Report whether the presented bearer token is a valid admin token.

    Always answers with an AdminVerifyResponse; an invalid token yields
    ``valid=False`` rather than an HTTP error.
    """
    if not verify_admin_token(credentials.credentials):
        return AdminVerifyResponse(
            valid=False,
            message="Token is invalid or expired",
        )

    return AdminVerifyResponse(
        valid=True,
        message="Token is valid",
    )
async def admin_status(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the admin session status for a valid bearer token.

    Raises:
        HTTPException: 401 when the token is invalid or expired.
    """
    is_admin = verify_admin_token(credentials.credentials)
    if is_admin:
        return {
            "status": "authenticated",
            "role": "admin",
            "timestamp": datetime.utcnow().isoformat(),
        }

    raise HTTPException(
        status_code=401,
        detail="Invalid or expired token",
    )
| # Model Management Endpoints | |
class ModelCreateRequest(BaseModel):
    """Request body for creating a model entry."""

    # Code used to look the model up (passed to crud.get_model).
    m_code: str
    label: str
    model_type: str
    provider: str
    model_id: str
    # New models are unavailable unless the caller says otherwise.
    is_available: bool = False
class ModelUpdateRequest(BaseModel):
    """Request body for a partial model update.

    Every field is optional; update_model treats ``None`` as
    "leave unchanged".
    """

    label: str | None = None
    model_type: str | None = None
    # provider and model_id are merged into the model's `config` by
    # update_model rather than written as separate fields.
    provider: str | None = None
    model_id: str | None = None
    is_available: bool | None = None
    is_fallback: bool | None = None
async def create_model(
    request: ModelCreateRequest,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
):
    """Create a new model (admin only).

    Args:
        request: Fields of the model to create.
        credentials: Bearer credentials carrying the admin token.
        db: Database session.

    Returns:
        A dict with a success message and the created model's fields.

    Raises:
        HTTPException: 401 for an invalid or expired token, 400 when a model
            with the same code already exists, 500 for any other failure.
    """
    token = credentials.credentials
    if not verify_admin_token(token):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token"
        )

    try:
        # Check if model already exists
        existing_model = crud.get_model(db, request.m_code)
        if existing_model:
            raise HTTPException(
                status_code=400,
                detail=f"Model with code '{request.m_code}' already exists"
            )

        # Create new model
        new_model = crud.create_model(
            db,
            m_code=request.m_code,
            label=request.label,
            model_type=request.model_type,
            provider=request.provider,
            model_id=request.model_id,
            is_available=request.is_available
        )
        return {
            "message": "Model created successfully",
            "model": {
                "m_code": new_model.m_code,
                "label": new_model.label,
                "model_type": new_model.model_type,
                "provider": new_model.provider,
                "model_id": new_model.model_id,
                "is_available": new_model.is_available
            }
        }
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to create model: {str(e)}"
        ) from e
async def update_model(
    model_code: str,
    request: ModelUpdateRequest,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
):
    """Update an existing model (admin only).

    Fields left as ``None`` in the request are not touched. ``provider`` and
    ``model_id`` are merged into the model's ``config``. Setting
    ``is_fallback=True`` goes through ``crud.set_fallback_model``; any other
    fields supplied in the same request are applied as well.

    Args:
        model_code: Code of the model to update.
        request: Partial update payload.
        credentials: Bearer credentials carrying the admin token.
        db: Database session.

    Returns:
        A dict with a success message and the updated model's fields.

    Raises:
        HTTPException: 401 for an invalid or expired token, 404 when the
            model does not exist, 500 for any other failure.
    """
    token = credentials.credentials
    if not verify_admin_token(token):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token"
        )

    try:
        # Check if model exists
        existing_model = crud.get_model(db, model_code)
        if not existing_model:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_code}' not found"
            )

        # Plain field updates; None means "not provided".
        update_data = {}
        if request.label is not None:
            update_data["label"] = request.label
        if request.model_type is not None:
            update_data["model_type"] = request.model_type
        if request.is_available is not None:
            update_data["is_available"] = request.is_available
        if request.is_fallback is False:
            # Explicitly clearing the flag is an ordinary field update.
            update_data["is_fallback"] = False

        # provider and model_id are stored inside the config column.
        config_updates = {}
        if request.provider is not None:
            config_updates["provider"] = request.provider
        if request.model_id is not None:
            config_updates["model_id"] = request.model_id
        if config_updates:
            # Merge into the current config (or an empty dict if there is none).
            current_config = existing_model.config or {}
            update_data["config"] = {**current_config, **config_updates}

        if request.is_fallback:
            # Set this model as fallback (will clear others)
            crud.set_fallback_model(db, model_code)

        if update_data or not request.is_fallback:
            # Previously the other fields were silently dropped whenever
            # is_fallback was True; they are now applied too.
            updated_model = crud.update_model(db, model_code, update_data)
        else:
            # Only the fallback flag changed: re-read the model for the response.
            updated_model = crud.get_model(db, model_code)

        return {
            "message": "Model updated successfully",
            "model": {
                "m_code": updated_model.m_code,
                "label": updated_model.label,
                "model_type": updated_model.model_type,
                "is_available": updated_model.is_available,
                "config": updated_model.config or {}
            }
        }
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update model: {str(e)}"
        ) from e
async def get_fallback_model(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
):
    """Return a summary of the current fallback model, if one is set.

    Returns:
        ``{"fallback_model": {...}}`` with the model's code, label and
        availability, or ``{"fallback_model": None}`` when no fallback is
        configured.

    Raises:
        HTTPException: 401 for an invalid or expired token, 500 when the
            lookup fails.
    """
    if not verify_admin_token(credentials.credentials):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token",
        )

    try:
        code = crud.get_fallback_model(db)
        if not code:
            return {"fallback_model": None}

        model = crud.get_model(db, code)
        summary = {
            "m_code": model.m_code,
            "label": model.label,
            "is_available": model.is_available,
        }
        return {"fallback_model": summary}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get fallback model: {str(e)}",
        )
async def delete_model(
    model_code: str,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: Session = Depends(get_db)
):
    """Delete a model (admin only).

    The model is removed only when no caption references it.

    Args:
        model_code: Code of the model to delete.
        credentials: Bearer credentials carrying the admin token.
        db: Database session.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 401 for an invalid or expired token, 404 when the
            model does not exist, 400 when captions still use it, 500 for
            any other failure.
    """
    token = credentials.credentials
    if not verify_admin_token(token):
        raise HTTPException(
            status_code=401,
            detail="Invalid or expired token"
        )

    try:
        # Check if model exists
        existing_model = crud.get_model(db, model_code)
        if not existing_model:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_code}' not found"
            )

        # Check if model is being used by any captions
        from ..models import Captions
        caption_count = db.query(Captions).filter(Captions.model == model_code).count()
        if caption_count > 0:
            raise HTTPException(
                status_code=400,
                detail=f"Cannot delete model '{model_code}' - it is used by {caption_count} caption(s)"
            )

        # Hard delete model (remove from database)
        crud.delete_model(db, model_code)
        return {
            "message": f"Model '{model_code}' deleted successfully from database"
        }
    except HTTPException:
        # Deliberate HTTP errors (the 404 / 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to delete model: {str(e)}"
        ) from e