Mark-Lasfar
		
	commited on
		
		
					Commit 
							
							·
						
						0e85bcd
	
1
								Parent(s):
							
							766ef88
								
Fix ChunkedIteratorResult in SQLAlchemyUserDatabase and toggleBtn null error
Browse files- api/auth.py +77 -8
- api/database.py +8 -1
- main.py +8 -1
- requirements.txt +6 -5
- static/js/scripts.js +8 -3
    	
        api/auth.py
    CHANGED
    
    | @@ -4,12 +4,13 @@ from fastapi_users.db import SQLAlchemyUserDatabase | |
| 4 | 
             
            from httpx_oauth.clients.google import GoogleOAuth2
         | 
| 5 | 
             
            from httpx_oauth.clients.github import GitHubOAuth2
         | 
| 6 | 
             
            from api.database import User, OAuthAccount, get_user_db
         | 
| 7 | 
            -
            from api.models import UserRead, UserCreate, UserUpdate | 
| 8 | 
             
            from fastapi_users.manager import BaseUserManager, IntegerIDMixin
         | 
| 9 | 
             
            from fastapi import Depends, Request, FastAPI
         | 
| 10 | 
             
            from sqlalchemy.ext.asyncio import AsyncSession
         | 
|  | |
| 11 | 
             
            from fastapi_users.models import UP
         | 
| 12 | 
            -
            from typing import Optional
         | 
| 13 | 
             
            import os
         | 
| 14 | 
             
            import logging
         | 
| 15 |  | 
| @@ -51,10 +52,73 @@ if not all([GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, GITHUB_CLIENT_ID, GITHUB_CLI | |
| 51 | 
             
            google_oauth_client = GoogleOAuth2(GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET)
         | 
| 52 | 
             
            github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
         | 
| 53 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 54 | 
             
            class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
         | 
| 55 | 
             
                reset_password_token_secret = SECRET
         | 
| 56 | 
             
                verification_token_secret = SECRET
         | 
| 57 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 58 | 
             
                async def oauth_callback(
         | 
| 59 | 
             
                    self,
         | 
| 60 | 
             
                    oauth_name: str,
         | 
| @@ -68,6 +132,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]): | |
| 68 | 
             
                    associate_by_email: bool = False,
         | 
| 69 | 
             
                    is_verified_by_default: bool = False,
         | 
| 70 | 
             
                ) -> UP:
         | 
|  | |
| 71 | 
             
                    oauth_account_dict = {
         | 
| 72 | 
             
                        "oauth_name": oauth_name,
         | 
| 73 | 
             
                        "access_token": access_token,
         | 
| @@ -77,15 +142,17 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]): | |
| 77 | 
             
                        "refresh_token": refresh_token,
         | 
| 78 | 
             
                    }
         | 
| 79 | 
             
                    oauth_account = OAuthAccount(**oauth_account_dict)
         | 
| 80 | 
            -
                    existing_oauth_account = await self. | 
| 81 | 
             
                    if existing_oauth_account is not None:
         | 
|  | |
| 82 | 
             
                        return await self.on_after_login(existing_oauth_account.user, request)
         | 
| 83 |  | 
| 84 | 
             
                    if associate_by_email:
         | 
| 85 | 
             
                        user = await self.user_db.get_by_email(account_email)
         | 
| 86 | 
             
                        if user is not None:
         | 
| 87 | 
             
                            oauth_account.user_id = user.id
         | 
| 88 | 
            -
                            await self. | 
|  | |
| 89 | 
             
                            return await self.on_after_login(user, request)
         | 
| 90 |  | 
| 91 | 
             
                    user_dict = {
         | 
| @@ -96,13 +163,15 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]): | |
| 96 | 
             
                    }
         | 
| 97 | 
             
                    user = await self.user_db.create(user_dict)
         | 
| 98 | 
             
                    oauth_account.user_id = user.id
         | 
| 99 | 
            -
                    await self. | 
|  | |
| 100 | 
             
                    return await self.on_after_login(user, request)
         | 
| 101 |  | 
| 102 | 
            -
            async def  | 
| 103 | 
            -
                yield  | 
| 104 |  | 
| 105 | 
            -
             | 
|  | |
| 106 |  | 
| 107 | 
             
            google_oauth_router = get_oauth_router(
         | 
| 108 | 
             
                google_oauth_client,
         | 
|  | |
| 4 | 
             
            from httpx_oauth.clients.google import GoogleOAuth2
         | 
| 5 | 
             
            from httpx_oauth.clients.github import GitHubOAuth2
         | 
| 6 | 
             
            from api.database import User, OAuthAccount, get_user_db
         | 
| 7 | 
            +
            from api.models import UserRead, UserCreate, UserUpdate
         | 
| 8 | 
             
            from fastapi_users.manager import BaseUserManager, IntegerIDMixin
         | 
| 9 | 
             
            from fastapi import Depends, Request, FastAPI
         | 
| 10 | 
             
            from sqlalchemy.ext.asyncio import AsyncSession
         | 
| 11 | 
            +
            from sqlalchemy import select
         | 
| 12 | 
             
            from fastapi_users.models import UP
         | 
| 13 | 
            +
            from typing import Optional, Dict, Any
         | 
| 14 | 
             
            import os
         | 
| 15 | 
             
            import logging
         | 
| 16 |  | 
|  | |
| 52 | 
             
            google_oauth_client = GoogleOAuth2(GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET)
         | 
| 53 | 
             
            github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
         | 
| 54 |  | 
| 55 | 
            +
            class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
         | 
| 56 | 
            +
                async def get_by_email(self, email: str) -> Optional[User]:
         | 
| 57 | 
            +
                    """Override to fix ChunkedIteratorResult issue for get_by_email"""
         | 
| 58 | 
            +
                    logger.info(f"Checking for user with email: {email}")
         | 
| 59 | 
            +
                    try:
         | 
| 60 | 
            +
                        statement = select(self.user_table).where(self.user_table.email == email)
         | 
| 61 | 
            +
                        result = await self.session.execute(statement)
         | 
| 62 | 
            +
                        user = result.scalar_one_or_none()
         | 
| 63 | 
            +
                        if user:
         | 
| 64 | 
            +
                            logger.info(f"Found user with email: {email}")
         | 
| 65 | 
            +
                        else:
         | 
| 66 | 
            +
                            logger.info(f"No user found with email: {email}")
         | 
| 67 | 
            +
                        return user
         | 
| 68 | 
            +
                    except Exception as e:
         | 
| 69 | 
            +
                        logger.error(f"Error in get_by_email: {e}")
         | 
| 70 | 
            +
                        raise
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                async def create(self, create_dict: Dict[str, Any]) -> User:
         | 
| 73 | 
            +
                    """Override to fix potential async issues in create"""
         | 
| 74 | 
            +
                    logger.info(f"Creating user with email: {create_dict.get('email')}")
         | 
| 75 | 
            +
                    try:
         | 
| 76 | 
            +
                        user = self.user_table(**create_dict)
         | 
| 77 | 
            +
                        self.session.add(user)
         | 
| 78 | 
            +
                        await self.session.commit()
         | 
| 79 | 
            +
                        await self.session.refresh(user)
         | 
| 80 | 
            +
                        logger.info(f"Created user with email: {create_dict.get('email')}")
         | 
| 81 | 
            +
                        return user
         | 
| 82 | 
            +
                    except Exception as e:
         | 
| 83 | 
            +
                        logger.error(f"Error creating user: {e}")
         | 
| 84 | 
            +
                        await self.session.rollback()
         | 
| 85 | 
            +
                        raise
         | 
| 86 | 
            +
             | 
| 87 | 
             
            class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
         | 
| 88 | 
             
                reset_password_token_secret = SECRET
         | 
| 89 | 
             
                verification_token_secret = SECRET
         | 
| 90 |  | 
| 91 | 
            +
                async def get_by_oauth_account(self, oauth_name: str, account_id: str):
         | 
| 92 | 
            +
                    """Override to fix ChunkedIteratorResult issue in SQLAlchemy 2.0+"""
         | 
| 93 | 
            +
                    logger.info(f"Checking for existing OAuth account: {oauth_name}/{account_id}")
         | 
| 94 | 
            +
                    try:
         | 
| 95 | 
            +
                        statement = select(OAuthAccount).where(
         | 
| 96 | 
            +
                            OAuthAccount.oauth_name == oauth_name, OAuthAccount.account_id == account_id
         | 
| 97 | 
            +
                        )
         | 
| 98 | 
            +
                        result = await self.session.execute(statement)
         | 
| 99 | 
            +
                        oauth_account = result.scalar_one_or_none()
         | 
| 100 | 
            +
                        if oauth_account:
         | 
| 101 | 
            +
                            logger.info(f"Found existing OAuth account for {account_id}")
         | 
| 102 | 
            +
                        else:
         | 
| 103 | 
            +
                            logger.info(f"No existing OAuth account found for {account_id}")
         | 
| 104 | 
            +
                        return oauth_account
         | 
| 105 | 
            +
                    except Exception as e:
         | 
| 106 | 
            +
                        logger.error(f"Error in get_by_oauth_account: {e}")
         | 
| 107 | 
            +
                        raise
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                async def add_oauth_account(self, oauth_account: OAuthAccount):
         | 
| 110 | 
            +
                    """Override to fix potential async issues"""
         | 
| 111 | 
            +
                    logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
         | 
| 112 | 
            +
                    try:
         | 
| 113 | 
            +
                        self.session.add(oauth_account)
         | 
| 114 | 
            +
                        await self.session.commit()
         | 
| 115 | 
            +
                        await self.session.refresh(oauth_account)
         | 
| 116 | 
            +
                        logger.info(f"Successfully added OAuth account for user {oauth_account.user_id}")
         | 
| 117 | 
            +
                    except Exception as e:
         | 
| 118 | 
            +
                        logger.error(f"Error adding OAuth account: {e}")
         | 
| 119 | 
            +
                        await self.session.rollback()
         | 
| 120 | 
            +
                        raise
         | 
| 121 | 
            +
             | 
| 122 | 
             
                async def oauth_callback(
         | 
| 123 | 
             
                    self,
         | 
| 124 | 
             
                    oauth_name: str,
         | 
|  | |
| 132 | 
             
                    associate_by_email: bool = False,
         | 
| 133 | 
             
                    is_verified_by_default: bool = False,
         | 
| 134 | 
             
                ) -> UP:
         | 
| 135 | 
            +
                    logger.info(f"OAuth callback for {oauth_name} with account_id {account_id}")
         | 
| 136 | 
             
                    oauth_account_dict = {
         | 
| 137 | 
             
                        "oauth_name": oauth_name,
         | 
| 138 | 
             
                        "access_token": access_token,
         | 
|  | |
| 142 | 
             
                        "refresh_token": refresh_token,
         | 
| 143 | 
             
                    }
         | 
| 144 | 
             
                    oauth_account = OAuthAccount(**oauth_account_dict)
         | 
| 145 | 
            +
                    existing_oauth_account = await self.get_by_oauth_account(oauth_name, account_id)
         | 
| 146 | 
             
                    if existing_oauth_account is not None:
         | 
| 147 | 
            +
                        logger.info(f"Existing account found, logging in user {existing_oauth_account.user.email}")
         | 
| 148 | 
             
                        return await self.on_after_login(existing_oauth_account.user, request)
         | 
| 149 |  | 
| 150 | 
             
                    if associate_by_email:
         | 
| 151 | 
             
                        user = await self.user_db.get_by_email(account_email)
         | 
| 152 | 
             
                        if user is not None:
         | 
| 153 | 
             
                            oauth_account.user_id = user.id
         | 
| 154 | 
            +
                            await self.add_oauth_account(oauth_account)
         | 
| 155 | 
            +
                            logger.info(f"Associated with existing user {user.email}")
         | 
| 156 | 
             
                            return await self.on_after_login(user, request)
         | 
| 157 |  | 
| 158 | 
             
                    user_dict = {
         | 
|  | |
| 163 | 
             
                    }
         | 
| 164 | 
             
                    user = await self.user_db.create(user_dict)
         | 
| 165 | 
             
                    oauth_account.user_id = user.id
         | 
| 166 | 
            +
                    await self.add_oauth_account(oauth_account)
         | 
| 167 | 
            +
                    logger.info(f"Created new user {user.email}")
         | 
| 168 | 
             
                    return await self.on_after_login(user, request)
         | 
| 169 |  | 
| 170 | 
            +
            async def get_user_db(session: AsyncSession = Depends(get_db)):
         | 
| 171 | 
            +
                yield CustomSQLAlchemyUserDatabase(session, User, OAuthAccount)
         | 
| 172 |  | 
| 173 | 
            +
            async def get_user_manager(user_db: CustomSQLAlchemyUserDatabase = Depends(get_user_db)):
         | 
| 174 | 
            +
                yield UserManager(user_db)
         | 
| 175 |  | 
| 176 | 
             
            google_oauth_router = get_oauth_router(
         | 
| 177 | 
             
                google_oauth_client,
         | 
    	
        api/database.py
    CHANGED
    
    | @@ -8,7 +8,9 @@ from sqlalchemy.orm import Session | |
| 8 | 
             
            from typing import AsyncGenerator
         | 
| 9 | 
             
            from fastapi import Depends
         | 
| 10 | 
             
            from datetime import datetime
         | 
|  | |
| 11 |  | 
|  | |
| 12 | 
             
            # جلب URL قاعدة البيانات من المتغيرات البيئية
         | 
| 13 | 
             
            SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
         | 
| 14 | 
             
            if not SQLALCHEMY_DATABASE_URL:
         | 
| @@ -84,4 +86,9 @@ async def get_user_db(session: Session = Depends(get_db)): | |
| 84 | 
             
                yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
         | 
| 85 |  | 
| 86 | 
             
            # إنشاء الجداول
         | 
| 87 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 8 | 
             
            from typing import AsyncGenerator
         | 
| 9 | 
             
            from fastapi import Depends
         | 
| 10 | 
             
            from datetime import datetime
         | 
| 11 | 
            +
            import logging
         | 
| 12 |  | 
| 13 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 14 | 
             
            # جلب URL قاعدة البيانات من المتغيرات البيئية
         | 
| 15 | 
             
            SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
         | 
| 16 | 
             
            if not SQLALCHEMY_DATABASE_URL:
         | 
|  | |
| 86 | 
             
                yield SQLAlchemyUserDatabase(session, User, OAuthAccount)
         | 
| 87 |  | 
| 88 | 
             
            # إنشاء الجداول
         | 
| 89 | 
            +
            try:
         | 
| 90 | 
            +
                Base.metadata.create_all(bind=engine)
         | 
| 91 | 
            +
                logger.info("Database tables created successfully")
         | 
| 92 | 
            +
            except Exception as e:
         | 
| 93 | 
            +
                logger.error(f"Error creating database tables: {e}")
         | 
| 94 | 
            +
                raise
         | 
    	
        main.py
    CHANGED
    
    | @@ -127,6 +127,7 @@ async def debug_routes(): | |
| 127 | 
             
                return "\n".join(sorted(routes))
         | 
| 128 |  | 
| 129 | 
             
            # Custom middleware for 404 and 500 errors
         | 
|  | |
| 130 | 
             
            class NotFoundMiddleware(BaseHTTPMiddleware):
         | 
| 131 | 
             
                async def dispatch(self, request: Request, call_next):
         | 
| 132 | 
             
                    try:
         | 
| @@ -137,7 +138,13 @@ class NotFoundMiddleware(BaseHTTPMiddleware): | |
| 137 | 
             
                        return response
         | 
| 138 | 
             
                    except Exception as e:
         | 
| 139 | 
             
                        logger.exception(f"Error processing request {request.url}: {e}")
         | 
| 140 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 141 |  | 
| 142 | 
             
            app.add_middleware(NotFoundMiddleware)
         | 
| 143 |  | 
|  | |
| 127 | 
             
                return "\n".join(sorted(routes))
         | 
| 128 |  | 
| 129 | 
             
            # Custom middleware for 404 and 500 errors
         | 
| 130 | 
            +
            # في main.py، استبدل NotFoundMiddleware ب:
         | 
| 131 | 
             
            class NotFoundMiddleware(BaseHTTPMiddleware):
         | 
| 132 | 
             
                async def dispatch(self, request: Request, call_next):
         | 
| 133 | 
             
                    try:
         | 
|  | |
| 138 | 
             
                        return response
         | 
| 139 | 
             
                    except Exception as e:
         | 
| 140 | 
             
                        logger.exception(f"Error processing request {request.url}: {e}")
         | 
| 141 | 
            +
                        if "ChunkedIteratorResult" in str(e):
         | 
| 142 | 
            +
                            logger.error("ChunkedIteratorResult error detected - check SQLAlchemy async execution")
         | 
| 143 | 
            +
                        return templates.TemplateResponse(
         | 
| 144 | 
            +
                            "500.html",
         | 
| 145 | 
            +
                            {"request": request, "error": str(e)},
         | 
| 146 | 
            +
                            status_code=500
         | 
| 147 | 
            +
                        )
         | 
| 148 |  | 
| 149 | 
             
            app.add_middleware(NotFoundMiddleware)
         | 
| 150 |  | 
    	
        requirements.txt
    CHANGED
    
    | @@ -2,13 +2,13 @@ fastapi==0.115.2 | |
| 2 | 
             
            packaging>=23.0
         | 
| 3 | 
             
            uvicorn==0.30.6
         | 
| 4 | 
             
            gradio>=4.44.1
         | 
| 5 | 
            -
            openai==1. | 
| 6 | 
             
            httpx==0.27.0
         | 
| 7 | 
             
            python-dotenv==1.0.1
         | 
| 8 | 
             
            pydocstyle==6.3.0
         | 
| 9 | 
             
            requests==2.32.5
         | 
| 10 | 
             
            beautifulsoup4==4.12.3
         | 
| 11 | 
            -
            tenacity== | 
| 12 | 
             
            selenium==4.25.0
         | 
| 13 | 
             
            webdriver-manager==4.0.2
         | 
| 14 | 
             
            jinja2==3.1.4
         | 
| @@ -25,15 +25,16 @@ Pillow==10.4.0 | |
| 25 | 
             
            urllib3==2.0.7
         | 
| 26 | 
             
            itsdangerous
         | 
| 27 | 
             
            protobuf==3.19.6
         | 
| 28 | 
            -
            fastapi-users[sqlalchemy,oauth]>=13.0.0
         | 
| 29 |  | 
|  | |
| 30 | 
             
            sqlalchemy>=2.0.0
         | 
| 31 | 
             
            python-jose[cryptography]>=3.3.0
         | 
| 32 | 
             
            passlib[bcrypt]>=1.7.4
         | 
| 33 | 
            -
            httpx-oauth
         | 
| 34 | 
            -
            python-multipart
         | 
| 35 | 
             
            aiofiles
         | 
| 36 | 
             
            motor
         | 
| 37 | 
             
            redis
         | 
| 38 | 
             
            markdown2
         | 
|  | |
| 39 |  | 
|  | |
| 2 | 
             
            packaging>=23.0
         | 
| 3 | 
             
            uvicorn==0.30.6
         | 
| 4 | 
             
            gradio>=4.44.1
         | 
| 5 | 
            +
            openai==1.51.2
         | 
| 6 | 
             
            httpx==0.27.0
         | 
| 7 | 
             
            python-dotenv==1.0.1
         | 
| 8 | 
             
            pydocstyle==6.3.0
         | 
| 9 | 
             
            requests==2.32.5
         | 
| 10 | 
             
            beautifulsoup4==4.12.3
         | 
| 11 | 
            +
            tenacity==9.0.0
         | 
| 12 | 
             
            selenium==4.25.0
         | 
| 13 | 
             
            webdriver-manager==4.0.2
         | 
| 14 | 
             
            jinja2==3.1.4
         | 
|  | |
| 25 | 
             
            urllib3==2.0.7
         | 
| 26 | 
             
            itsdangerous
         | 
| 27 | 
             
            protobuf==3.19.6
         | 
|  | |
| 28 |  | 
| 29 | 
            +
            fastapi-users[sqlalchemy,oauth2]==14.0.0
         | 
| 30 | 
             
            sqlalchemy>=2.0.0
         | 
| 31 | 
             
            python-jose[cryptography]>=3.3.0
         | 
| 32 | 
             
            passlib[bcrypt]>=1.7.4
         | 
| 33 | 
            +
            httpx-oauth==0.15.3
         | 
| 34 | 
            +
            python-multipart==0.0.12
         | 
| 35 | 
             
            aiofiles
         | 
| 36 | 
             
            motor
         | 
| 37 | 
             
            redis
         | 
| 38 | 
             
            markdown2
         | 
| 39 | 
            +
            pymongo==4.10.1
         | 
| 40 |  | 
    	
        static/js/scripts.js
    CHANGED
    
    | @@ -33,7 +33,12 @@ document.addEventListener('DOMContentLoaded', () => { | |
| 33 | 
             
                const sidebar = document.querySelector('.sidebar');
         | 
| 34 | 
             
                const toggleBtn = document.querySelector('.sidebar-toggle');
         | 
| 35 |  | 
| 36 | 
            -
                toggleBtn | 
| 37 | 
            -
                     | 
| 38 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 39 | 
             
            });
         | 
|  | |
| 33 | 
             
                const sidebar = document.querySelector('.sidebar');
         | 
| 34 | 
             
                const toggleBtn = document.querySelector('.sidebar-toggle');
         | 
| 35 |  | 
| 36 | 
            +
                if (toggleBtn && sidebar) {
         | 
| 37 | 
            +
                    toggleBtn.addEventListener('click', () => {
         | 
| 38 | 
            +
                        sidebar.classList.toggle('active');
         | 
| 39 | 
            +
                        console.log('Sidebar toggled'); // Debugging
         | 
| 40 | 
            +
                    });
         | 
| 41 | 
            +
                } else {
         | 
| 42 | 
            +
                    console.warn('Sidebar or toggle button not found');
         | 
| 43 | 
            +
                }
         | 
| 44 | 
             
            });
         | 
