import logging
import os
from typing import Optional
from datetime import datetime, UTC

from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader

from src.db.init_db import get_session
from src.db.schemas.models import User as DBUser, AgentTemplate, UserTemplatePreference
from src.schemas.user_schema import User
from src.utils.logger import Logger

logger = Logger("user_manager", see_time=True, console_log=False)

# Define API key header for authentication
API_KEY_NAME = "X-API-Key"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# Template names of the built-in agents that are enabled for every new user.
DEFAULT_AGENT_NAMES = (
    "preprocessing_agent",
    "statistical_analytics_agent",
    "sk_learn_agent",
    "data_viz_agent",
)


def _to_user_schema(db_user) -> User:
    """Convert a database ``User`` row into the API-level ``User`` schema."""
    return User(
        user_id=db_user.user_id,
        username=db_user.username,
        email=db_user.email,
    )


def _lookup_user_by_api_key(session, api_key: str):
    """Resolve an API key to a database user row, or ``None`` if no row matches.

    A key that parses as an integer is matched against ``user_id``; any other
    key is matched against ``username``.
    """
    # NOTE(review): treating the API key as a user id / username is a
    # placeholder, not real authentication - anyone who knows a user's id or
    # username can act as that user. Replace with a lookup of a hashed secret.
    try:
        user_id = int(api_key)
    except ValueError:
        # Deliberately do not include the key in the log: it is a credential.
        logger.log_message(
            "API key is not numeric; falling back to username lookup",
            level=logging.INFO,
        )
        return session.query(DBUser).filter(DBUser.username == api_key).first()
    return session.query(DBUser).filter(DBUser.user_id == user_id).first()


async def get_current_user(
    request: Request, api_key: Optional[str] = Depends(api_key_header)
) -> Optional[User]:
    """Dependency that returns the currently authenticated user.

    The key is taken from the ``api_key`` argument when FastAPI resolved it,
    otherwise from the ``X-API-Key`` header, otherwise from the ``api_key``
    query parameter.

    Returns:
        The matching ``User``, or ``None`` when no key was supplied, no user
        matches the key, or any error occurs during the lookup (errors are
        logged, never raised).
    """
    # When this function is called directly rather than through FastAPI's
    # dependency injection (e.g. from the session manager), ``api_key`` still
    # holds the unresolved ``Depends`` placeholder instead of a string, so fall
    # back to reading the request ourselves. Header is preferred for
    # consistency with the dependency implementation.
    if not api_key or not isinstance(api_key, str):
        api_key = request.headers.get(API_KEY_NAME) or request.query_params.get("api_key")

    # No key anywhere: treat the caller as anonymous.
    if not api_key:
        return None

    # NOTE(review): the DB calls below are synchronous inside an async
    # function, so they block the event loop for their duration - confirm
    # this is acceptable or move them to a threadpool.
    try:
        session = get_session()
        try:
            db_user = _lookup_user_by_api_key(session, api_key)
            if not db_user:
                logger.log_message("User not found", level=logging.ERROR)
                return None
            return _to_user_schema(db_user)
        finally:
            session.close()
    except Exception as e:
        # Authentication is best-effort: any failure yields an anonymous caller.
        logger.log_message(f"Error authenticating user: {str(e)}", level=logging.ERROR)
        return None


# Function to create a new user
def create_user(username: str, email: str) -> User:
    """Create a new user, or return the existing user with the same email.

    For a newly created user, the default agents in ``DEFAULT_AGENT_NAMES``
    are also enabled.

    Raises:
        Exception: any database error is logged, the session is rolled back,
            and the exception is re-raised.
    """
    session = get_session()
    try:
        # Email acts as the identity key: an existing account is returned as-is
        # (its username is not updated).
        existing_user = session.query(DBUser).filter(DBUser.email == email).first()
        if existing_user:
            return _to_user_schema(existing_user)

        new_user = DBUser(username=username, email=email)
        session.add(new_user)
        session.commit()
        session.refresh(new_user)

        # NOTE(review): the user row is already committed at this point. If
        # enabling the default agents fails, the exception propagates to the
        # caller even though the user exists, and a retry will hit the
        # "existing user" branch above without enabling the agents.
        _enable_default_agents_for_user(new_user.user_id, session)

        return _to_user_schema(new_user)
    except Exception as e:
        session.rollback()
        logger.log_message(f"Error creating user: {str(e)}", level=logging.ERROR)
        raise
    finally:
        session.close()


def get_user_by_email(email: str) -> Optional[User]:
    """Return the user with the given email, or ``None``.

    ``None`` is returned both when no user matches and when the lookup fails;
    failures are logged.
    """
    session = get_session()
    try:
        user = session.query(DBUser).filter(DBUser.email == email).first()
        if user is None:
            return None
        return _to_user_schema(user)
    except Exception as e:
        logger.log_message(f"Error getting user by email: {str(e)}", level=logging.ERROR)
        return None
    finally:
        session.close()


def _enable_default_agents_for_user(user_id: int, session):
    """Create enabled preferences for the default agents for ``user_id``.

    Only active templates named in ``DEFAULT_AGENT_NAMES`` are considered.
    Templates that already have a preference for this user are left untouched.
    Commits on ``session``; on failure, rolls back, logs, and re-raises.
    """
    try:
        default_agents = session.query(AgentTemplate).filter(
            AgentTemplate.template_name.in_(DEFAULT_AGENT_NAMES),
            AgentTemplate.is_active == True,  # noqa: E712 - SQL expression, not a Python comparison
        ).all()

        created = 0
        if default_agents:
            # Fetch existing preferences in one query instead of one per agent.
            template_ids = [agent.template_id for agent in default_agents]
            already_configured = {
                row.template_id
                for row in session.query(UserTemplatePreference.template_id).filter(
                    UserTemplatePreference.user_id == user_id,
                    UserTemplatePreference.template_id.in_(template_ids),
                )
            }

            # One timestamp so created_at and updated_at are identical.
            now = datetime.now(UTC)
            for agent in default_agents:
                if agent.template_id in already_configured:
                    continue
                session.add(
                    UserTemplatePreference(
                        user_id=user_id,
                        template_id=agent.template_id,
                        is_enabled=True,  # Enable by default
                        usage_count=0,
                        created_at=now,
                        updated_at=now,
                    )
                )
                created += 1

        session.commit()
        # Report only the preferences actually created, not the templates found.
        logger.log_message(f"Enabled {created} default agents for user {user_id}", level=logging.INFO)
    except Exception as e:
        session.rollback()
        logger.log_message(f"Error enabling default agents for user {user_id}: {str(e)}", level=logging.ERROR)
        raise