Mark-Lasfar
commited on
Commit
·
493a4a6
1
Parent(s):
6ec5390
Fix ChunkedIteratorResult in SQLAlchemyUserDatabase and toggleBtn null error
Browse files- api/auth.py +4 -4
- api/database.py +2 -2
- api/endpoints.py +75 -56
- api/models.py +0 -1
- init_db.py +34 -25
- requirements.txt +2 -0
api/auth.py
CHANGED
|
@@ -60,7 +60,7 @@ github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
|
|
| 60 |
|
| 61 |
class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
|
| 62 |
async def get_by_email(self, email: str) -> Optional[User]:
|
| 63 |
-
"""Override to fix ChunkedIteratorResult issue
|
| 64 |
logger.info(f"Checking for user with email: {email}")
|
| 65 |
try:
|
| 66 |
statement = select(self.user_table).where(self.user_table.email == email)
|
|
@@ -76,7 +76,7 @@ class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
|
|
| 76 |
raise
|
| 77 |
|
| 78 |
async def create(self, create_dict: Dict[str, Any]) -> User:
|
| 79 |
-
"""Override to fix potential async issues
|
| 80 |
logger.info(f"Creating user with email: {create_dict.get('email')}")
|
| 81 |
try:
|
| 82 |
user = self.user_table(**create_dict)
|
|
@@ -95,7 +95,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
| 95 |
verification_token_secret = SECRET
|
| 96 |
|
| 97 |
async def get_by_oauth_account(self, oauth_name: str, account_id: str):
|
| 98 |
-
"""Override to fix ChunkedIteratorResult issue
|
| 99 |
logger.info(f"Checking for existing OAuth account: {oauth_name}/{account_id}")
|
| 100 |
try:
|
| 101 |
statement = select(OAuthAccount).where(
|
|
@@ -113,7 +113,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
|
|
| 113 |
raise
|
| 114 |
|
| 115 |
async def add_oauth_account(self, oauth_account: OAuthAccount):
|
| 116 |
-
"""Override to fix potential async issues"""
|
| 117 |
logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
|
| 118 |
try:
|
| 119 |
self.session.add(oauth_account)
|
|
|
|
| 60 |
|
| 61 |
class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
|
| 62 |
async def get_by_email(self, email: str) -> Optional[User]:
|
| 63 |
+
"""Override to fix ChunkedIteratorResult issue."""
|
| 64 |
logger.info(f"Checking for user with email: {email}")
|
| 65 |
try:
|
| 66 |
statement = select(self.user_table).where(self.user_table.email == email)
|
|
|
|
| 76 |
raise
|
| 77 |
|
| 78 |
async def create(self, create_dict: Dict[str, Any]) -> User:
|
| 79 |
+
"""Override to fix potential async issues."""
|
| 80 |
logger.info(f"Creating user with email: {create_dict.get('email')}")
|
| 81 |
try:
|
| 82 |
user = self.user_table(**create_dict)
|
|
|
|
| 95 |
verification_token_secret = SECRET
|
| 96 |
|
| 97 |
async def get_by_oauth_account(self, oauth_name: str, account_id: str):
|
| 98 |
+
"""Override to fix ChunkedIteratorResult issue."""
|
| 99 |
logger.info(f"Checking for existing OAuth account: {oauth_name}/{account_id}")
|
| 100 |
try:
|
| 101 |
statement = select(OAuthAccount).where(
|
|
|
|
| 113 |
raise
|
| 114 |
|
| 115 |
async def add_oauth_account(self, oauth_account: OAuthAccount):
|
| 116 |
+
"""Override to fix potential async issues."""
|
| 117 |
logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
|
| 118 |
try:
|
| 119 |
self.session.add(oauth_account)
|
api/database.py
CHANGED
|
@@ -12,11 +12,11 @@ import logging
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
# جلب URL قاعدة البيانات من المتغيرات البيئية
|
| 15 |
-
SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
|
| 16 |
if not SQLALCHEMY_DATABASE_URL:
|
| 17 |
raise ValueError("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
|
| 18 |
|
| 19 |
-
# إنشاء محرك async
|
| 20 |
async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=True)
|
| 21 |
|
| 22 |
# إعداد جلسة async
|
|
|
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
# جلب URL قاعدة البيانات من المتغيرات البيئية
|
| 15 |
+
SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite+aiosqlite:///./data/mgzon_users.db")
|
| 16 |
if not SQLALCHEMY_DATABASE_URL:
|
| 17 |
raise ValueError("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
|
| 18 |
|
| 19 |
+
# إنشاء محرك async
|
| 20 |
async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=True)
|
| 21 |
|
| 22 |
# إعداد جلسة async
|
api/endpoints.py
CHANGED
|
@@ -6,7 +6,8 @@ from api.database import User, Conversation, Message
|
|
| 6 |
from api.models import QueryRequest, ConversationOut, ConversationCreate, UserUpdate
|
| 7 |
from api.auth import current_active_user
|
| 8 |
from api.database import get_db
|
| 9 |
-
from sqlalchemy.
|
|
|
|
| 10 |
from utils.generation import request_generation, select_model, check_model_availability
|
| 11 |
from utils.web_search import web_search
|
| 12 |
import io
|
|
@@ -145,7 +146,7 @@ async def chat_endpoint(
|
|
| 145 |
request: Request,
|
| 146 |
req: QueryRequest,
|
| 147 |
user: User = Depends(current_active_user),
|
| 148 |
-
db:
|
| 149 |
):
|
| 150 |
logger.info(f"Received chat request: {req}")
|
| 151 |
|
|
@@ -155,7 +156,10 @@ async def chat_endpoint(
|
|
| 155 |
conversation = None
|
| 156 |
if user:
|
| 157 |
title = req.title or (req.message[:50] + "..." if len(req.message) > 50 else req.message or "Untitled Conversation")
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
| 159 |
if not conversation:
|
| 160 |
conversation_id = str(uuid.uuid4())
|
| 161 |
conversation = Conversation(
|
|
@@ -164,12 +168,12 @@ async def chat_endpoint(
|
|
| 164 |
title=title
|
| 165 |
)
|
| 166 |
db.add(conversation)
|
| 167 |
-
db.commit()
|
| 168 |
-
db.refresh(conversation)
|
| 169 |
|
| 170 |
user_msg = Message(role="user", content=req.message, conversation_id=conversation.id)
|
| 171 |
db.add(user_msg)
|
| 172 |
-
db.commit()
|
| 173 |
|
| 174 |
preferred_model = user.preferred_model if user else None
|
| 175 |
model_name, api_endpoint = select_model(req.message, input_type="text", preferred_model=preferred_model)
|
|
@@ -231,9 +235,9 @@ async def chat_endpoint(
|
|
| 231 |
if user and conversation:
|
| 232 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 233 |
db.add(assistant_msg)
|
| 234 |
-
db.commit()
|
| 235 |
conversation.updated_at = datetime.utcnow()
|
| 236 |
-
db.commit()
|
| 237 |
return {
|
| 238 |
"response": response,
|
| 239 |
"conversation_id": conversation.conversation_id,
|
|
@@ -248,7 +252,7 @@ async def audio_transcription_endpoint(
|
|
| 248 |
request: Request,
|
| 249 |
file: UploadFile = File(...),
|
| 250 |
user: User = Depends(current_active_user),
|
| 251 |
-
db:
|
| 252 |
):
|
| 253 |
logger.info(f"Received audio transcription request for file: {file.filename}")
|
| 254 |
|
|
@@ -258,7 +262,10 @@ async def audio_transcription_endpoint(
|
|
| 258 |
conversation = None
|
| 259 |
if user:
|
| 260 |
title = "Audio Transcription"
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
| 262 |
if not conversation:
|
| 263 |
conversation_id = str(uuid.uuid4())
|
| 264 |
conversation = Conversation(
|
|
@@ -267,12 +274,12 @@ async def audio_transcription_endpoint(
|
|
| 267 |
title=title
|
| 268 |
)
|
| 269 |
db.add(conversation)
|
| 270 |
-
db.commit()
|
| 271 |
-
db.refresh(conversation)
|
| 272 |
|
| 273 |
user_msg = Message(role="user", content="Audio message", conversation_id=conversation.id)
|
| 274 |
db.add(user_msg)
|
| 275 |
-
db.commit()
|
| 276 |
|
| 277 |
model_name, api_endpoint = select_model("transcribe audio", input_type="audio")
|
| 278 |
|
|
@@ -313,9 +320,9 @@ async def audio_transcription_endpoint(
|
|
| 313 |
if user and conversation:
|
| 314 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 315 |
db.add(assistant_msg)
|
| 316 |
-
db.commit()
|
| 317 |
conversation.updated_at = datetime.utcnow()
|
| 318 |
-
db.commit()
|
| 319 |
return {
|
| 320 |
"transcription": response,
|
| 321 |
"conversation_id": conversation.conversation_id,
|
|
@@ -330,7 +337,7 @@ async def text_to_speech_endpoint(
|
|
| 330 |
request: Request,
|
| 331 |
req: dict,
|
| 332 |
user: User = Depends(current_active_user),
|
| 333 |
-
db:
|
| 334 |
):
|
| 335 |
if not user:
|
| 336 |
await handle_session(request)
|
|
@@ -378,7 +385,7 @@ async def code_endpoint(
|
|
| 378 |
request: Request,
|
| 379 |
req: dict,
|
| 380 |
user: User = Depends(current_active_user),
|
| 381 |
-
db:
|
| 382 |
):
|
| 383 |
if not user:
|
| 384 |
await handle_session(request)
|
|
@@ -453,7 +460,7 @@ async def analysis_endpoint(
|
|
| 453 |
request: Request,
|
| 454 |
req: dict,
|
| 455 |
user: User = Depends(current_active_user),
|
| 456 |
-
db:
|
| 457 |
):
|
| 458 |
if not user:
|
| 459 |
await handle_session(request)
|
|
@@ -526,7 +533,7 @@ async def image_analysis_endpoint(
|
|
| 526 |
file: UploadFile = File(...),
|
| 527 |
output_format: str = "text",
|
| 528 |
user: User = Depends(current_active_user),
|
| 529 |
-
db:
|
| 530 |
):
|
| 531 |
if not user:
|
| 532 |
await handle_session(request)
|
|
@@ -534,7 +541,10 @@ async def image_analysis_endpoint(
|
|
| 534 |
conversation = None
|
| 535 |
if user:
|
| 536 |
title = "Image Analysis"
|
| 537 |
-
|
|
|
|
|
|
|
|
|
|
| 538 |
if not conversation:
|
| 539 |
conversation_id = str(uuid.uuid4())
|
| 540 |
conversation = Conversation(
|
|
@@ -543,12 +553,12 @@ async def image_analysis_endpoint(
|
|
| 543 |
title=title
|
| 544 |
)
|
| 545 |
db.add(conversation)
|
| 546 |
-
db.commit()
|
| 547 |
-
db.refresh(conversation)
|
| 548 |
|
| 549 |
user_msg = Message(role="user", content="Image analysis request", conversation_id=conversation.id)
|
| 550 |
db.add(user_msg)
|
| 551 |
-
db.commit()
|
| 552 |
|
| 553 |
preferred_model = user.preferred_model if user else None
|
| 554 |
model_name, api_endpoint = select_model("analyze image", input_type="image", preferred_model=preferred_model)
|
|
@@ -608,9 +618,9 @@ async def image_analysis_endpoint(
|
|
| 608 |
if user and conversation:
|
| 609 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 610 |
db.add(assistant_msg)
|
| 611 |
-
db.commit()
|
| 612 |
conversation.updated_at = datetime.utcnow()
|
| 613 |
-
db.commit()
|
| 614 |
return {
|
| 615 |
"image_analysis": response,
|
| 616 |
"conversation_id": conversation.conversation_id,
|
|
@@ -646,7 +656,7 @@ async def test_model(model: str = MODEL_NAME, endpoint: str = API_ENDPOINT):
|
|
| 646 |
async def create_conversation(
|
| 647 |
req: ConversationCreate,
|
| 648 |
user: User = Depends(current_active_user),
|
| 649 |
-
db:
|
| 650 |
):
|
| 651 |
if not user:
|
| 652 |
raise HTTPException(status_code=401, detail="Login required")
|
|
@@ -657,77 +667,87 @@ async def create_conversation(
|
|
| 657 |
user_id=user.id
|
| 658 |
)
|
| 659 |
db.add(conversation)
|
| 660 |
-
db.commit()
|
| 661 |
-
db.refresh(conversation)
|
| 662 |
return ConversationOut.from_orm(conversation)
|
| 663 |
|
| 664 |
@router.get("/api/conversations/{conversation_id}", response_model=ConversationOut)
|
| 665 |
async def get_conversation(
|
| 666 |
conversation_id: str,
|
| 667 |
user: User = Depends(current_active_user),
|
| 668 |
-
db:
|
| 669 |
):
|
| 670 |
if not user:
|
| 671 |
raise HTTPException(status_code=401, detail="Login required")
|
| 672 |
-
|
| 673 |
-
Conversation.
|
| 674 |
-
|
| 675 |
-
|
|
|
|
|
|
|
|
|
|
| 676 |
if not conversation:
|
| 677 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 678 |
-
db.add(conversation)
|
| 679 |
-
db.commit()
|
| 680 |
return ConversationOut.from_orm(conversation)
|
| 681 |
|
| 682 |
@router.get("/api/conversations", response_model=List[ConversationOut])
|
| 683 |
async def list_conversations(
|
| 684 |
user: User = Depends(current_active_user),
|
| 685 |
-
db:
|
| 686 |
):
|
| 687 |
if not user:
|
| 688 |
raise HTTPException(status_code=401, detail="Login required")
|
| 689 |
-
|
| 690 |
-
|
|
|
|
|
|
|
|
|
|
| 691 |
|
| 692 |
@router.put("/api/conversations/{conversation_id}/title")
|
| 693 |
async def update_conversation_title(
|
| 694 |
conversation_id: str,
|
| 695 |
title: str,
|
| 696 |
user: User = Depends(current_active_user),
|
| 697 |
-
db:
|
| 698 |
):
|
| 699 |
if not user:
|
| 700 |
raise HTTPException(status_code=401, detail="Login required")
|
| 701 |
-
|
| 702 |
-
Conversation.
|
| 703 |
-
|
| 704 |
-
|
|
|
|
|
|
|
|
|
|
| 705 |
if not conversation:
|
| 706 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 707 |
|
| 708 |
conversation.title = title
|
| 709 |
conversation.updated_at = datetime.utcnow()
|
| 710 |
-
db.commit()
|
| 711 |
return {"message": "Conversation title updated", "title": conversation.title}
|
| 712 |
|
| 713 |
@router.delete("/api/conversations/{conversation_id}")
|
| 714 |
async def delete_conversation(
|
| 715 |
conversation_id: str,
|
| 716 |
user: User = Depends(current_active_user),
|
| 717 |
-
db:
|
| 718 |
):
|
| 719 |
if not user:
|
| 720 |
raise HTTPException(status_code=401, detail="Login required")
|
| 721 |
-
|
| 722 |
-
Conversation.
|
| 723 |
-
|
| 724 |
-
|
|
|
|
|
|
|
|
|
|
| 725 |
if not conversation:
|
| 726 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 727 |
|
| 728 |
-
db.
|
| 729 |
-
db.delete(conversation)
|
| 730 |
-
db.commit()
|
| 731 |
return {"message": "Conversation deleted successfully"}
|
| 732 |
|
| 733 |
@router.get("/users/me")
|
|
@@ -752,7 +772,7 @@ async def get_user_settings(user: User = Depends(current_active_user)):
|
|
| 752 |
async def update_user_settings(
|
| 753 |
settings: UserUpdate,
|
| 754 |
user: User = Depends(current_active_user),
|
| 755 |
-
db:
|
| 756 |
):
|
| 757 |
if not user:
|
| 758 |
raise HTTPException(status_code=401, detail="Login required")
|
|
@@ -775,8 +795,8 @@ async def update_user_settings(
|
|
| 775 |
if settings.conversation_style is not None:
|
| 776 |
user.conversation_style = settings.conversation_style
|
| 777 |
|
| 778 |
-
db.commit()
|
| 779 |
-
db.refresh(user)
|
| 780 |
return {"message": "Settings updated successfully", "user": {
|
| 781 |
"id": user.id,
|
| 782 |
"email": user.email,
|
|
@@ -790,4 +810,3 @@ async def update_user_settings(
|
|
| 790 |
"is_active": user.is_active,
|
| 791 |
"is_superuser": user.is_superuser
|
| 792 |
}}
|
| 793 |
-
|
|
|
|
| 6 |
from api.models import QueryRequest, ConversationOut, ConversationCreate, UserUpdate
|
| 7 |
from api.auth import current_active_user
|
| 8 |
from api.database import get_db
|
| 9 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 10 |
+
from sqlalchemy import select, delete
|
| 11 |
from utils.generation import request_generation, select_model, check_model_availability
|
| 12 |
from utils.web_search import web_search
|
| 13 |
import io
|
|
|
|
| 146 |
request: Request,
|
| 147 |
req: QueryRequest,
|
| 148 |
user: User = Depends(current_active_user),
|
| 149 |
+
db: AsyncSession = Depends(get_db)
|
| 150 |
):
|
| 151 |
logger.info(f"Received chat request: {req}")
|
| 152 |
|
|
|
|
| 156 |
conversation = None
|
| 157 |
if user:
|
| 158 |
title = req.title or (req.message[:50] + "..." if len(req.message) > 50 else req.message or "Untitled Conversation")
|
| 159 |
+
result = await db.execute(
|
| 160 |
+
select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
|
| 161 |
+
)
|
| 162 |
+
conversation = result.scalar_one_or_none()
|
| 163 |
if not conversation:
|
| 164 |
conversation_id = str(uuid.uuid4())
|
| 165 |
conversation = Conversation(
|
|
|
|
| 168 |
title=title
|
| 169 |
)
|
| 170 |
db.add(conversation)
|
| 171 |
+
await db.commit()
|
| 172 |
+
await db.refresh(conversation)
|
| 173 |
|
| 174 |
user_msg = Message(role="user", content=req.message, conversation_id=conversation.id)
|
| 175 |
db.add(user_msg)
|
| 176 |
+
await db.commit()
|
| 177 |
|
| 178 |
preferred_model = user.preferred_model if user else None
|
| 179 |
model_name, api_endpoint = select_model(req.message, input_type="text", preferred_model=preferred_model)
|
|
|
|
| 235 |
if user and conversation:
|
| 236 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 237 |
db.add(assistant_msg)
|
| 238 |
+
await db.commit()
|
| 239 |
conversation.updated_at = datetime.utcnow()
|
| 240 |
+
await db.commit()
|
| 241 |
return {
|
| 242 |
"response": response,
|
| 243 |
"conversation_id": conversation.conversation_id,
|
|
|
|
| 252 |
request: Request,
|
| 253 |
file: UploadFile = File(...),
|
| 254 |
user: User = Depends(current_active_user),
|
| 255 |
+
db: AsyncSession = Depends(get_db)
|
| 256 |
):
|
| 257 |
logger.info(f"Received audio transcription request for file: {file.filename}")
|
| 258 |
|
|
|
|
| 262 |
conversation = None
|
| 263 |
if user:
|
| 264 |
title = "Audio Transcription"
|
| 265 |
+
result = await db.execute(
|
| 266 |
+
select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
|
| 267 |
+
)
|
| 268 |
+
conversation = result.scalar_one_or_none()
|
| 269 |
if not conversation:
|
| 270 |
conversation_id = str(uuid.uuid4())
|
| 271 |
conversation = Conversation(
|
|
|
|
| 274 |
title=title
|
| 275 |
)
|
| 276 |
db.add(conversation)
|
| 277 |
+
await db.commit()
|
| 278 |
+
await db.refresh(conversation)
|
| 279 |
|
| 280 |
user_msg = Message(role="user", content="Audio message", conversation_id=conversation.id)
|
| 281 |
db.add(user_msg)
|
| 282 |
+
await db.commit()
|
| 283 |
|
| 284 |
model_name, api_endpoint = select_model("transcribe audio", input_type="audio")
|
| 285 |
|
|
|
|
| 320 |
if user and conversation:
|
| 321 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 322 |
db.add(assistant_msg)
|
| 323 |
+
await db.commit()
|
| 324 |
conversation.updated_at = datetime.utcnow()
|
| 325 |
+
await db.commit()
|
| 326 |
return {
|
| 327 |
"transcription": response,
|
| 328 |
"conversation_id": conversation.conversation_id,
|
|
|
|
| 337 |
request: Request,
|
| 338 |
req: dict,
|
| 339 |
user: User = Depends(current_active_user),
|
| 340 |
+
db: AsyncSession = Depends(get_db)
|
| 341 |
):
|
| 342 |
if not user:
|
| 343 |
await handle_session(request)
|
|
|
|
| 385 |
request: Request,
|
| 386 |
req: dict,
|
| 387 |
user: User = Depends(current_active_user),
|
| 388 |
+
db: AsyncSession = Depends(get_db)
|
| 389 |
):
|
| 390 |
if not user:
|
| 391 |
await handle_session(request)
|
|
|
|
| 460 |
request: Request,
|
| 461 |
req: dict,
|
| 462 |
user: User = Depends(current_active_user),
|
| 463 |
+
db: AsyncSession = Depends(get_db)
|
| 464 |
):
|
| 465 |
if not user:
|
| 466 |
await handle_session(request)
|
|
|
|
| 533 |
file: UploadFile = File(...),
|
| 534 |
output_format: str = "text",
|
| 535 |
user: User = Depends(current_active_user),
|
| 536 |
+
db: AsyncSession = Depends(get_db)
|
| 537 |
):
|
| 538 |
if not user:
|
| 539 |
await handle_session(request)
|
|
|
|
| 541 |
conversation = None
|
| 542 |
if user:
|
| 543 |
title = "Image Analysis"
|
| 544 |
+
result = await db.execute(
|
| 545 |
+
select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
|
| 546 |
+
)
|
| 547 |
+
conversation = result.scalar_one_or_none()
|
| 548 |
if not conversation:
|
| 549 |
conversation_id = str(uuid.uuid4())
|
| 550 |
conversation = Conversation(
|
|
|
|
| 553 |
title=title
|
| 554 |
)
|
| 555 |
db.add(conversation)
|
| 556 |
+
await db.commit()
|
| 557 |
+
await db.refresh(conversation)
|
| 558 |
|
| 559 |
user_msg = Message(role="user", content="Image analysis request", conversation_id=conversation.id)
|
| 560 |
db.add(user_msg)
|
| 561 |
+
await db.commit()
|
| 562 |
|
| 563 |
preferred_model = user.preferred_model if user else None
|
| 564 |
model_name, api_endpoint = select_model("analyze image", input_type="image", preferred_model=preferred_model)
|
|
|
|
| 618 |
if user and conversation:
|
| 619 |
assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
|
| 620 |
db.add(assistant_msg)
|
| 621 |
+
await db.commit()
|
| 622 |
conversation.updated_at = datetime.utcnow()
|
| 623 |
+
await db.commit()
|
| 624 |
return {
|
| 625 |
"image_analysis": response,
|
| 626 |
"conversation_id": conversation.conversation_id,
|
|
|
|
| 656 |
async def create_conversation(
|
| 657 |
req: ConversationCreate,
|
| 658 |
user: User = Depends(current_active_user),
|
| 659 |
+
db: AsyncSession = Depends(get_db)
|
| 660 |
):
|
| 661 |
if not user:
|
| 662 |
raise HTTPException(status_code=401, detail="Login required")
|
|
|
|
| 667 |
user_id=user.id
|
| 668 |
)
|
| 669 |
db.add(conversation)
|
| 670 |
+
await db.commit()
|
| 671 |
+
await db.refresh(conversation)
|
| 672 |
return ConversationOut.from_orm(conversation)
|
| 673 |
|
| 674 |
@router.get("/api/conversations/{conversation_id}", response_model=ConversationOut)
|
| 675 |
async def get_conversation(
|
| 676 |
conversation_id: str,
|
| 677 |
user: User = Depends(current_active_user),
|
| 678 |
+
db: AsyncSession = Depends(get_db)
|
| 679 |
):
|
| 680 |
if not user:
|
| 681 |
raise HTTPException(status_code=401, detail="Login required")
|
| 682 |
+
result = await db.execute(
|
| 683 |
+
select(Conversation).filter(
|
| 684 |
+
Conversation.conversation_id == conversation_id,
|
| 685 |
+
Conversation.user_id == user.id
|
| 686 |
+
)
|
| 687 |
+
)
|
| 688 |
+
conversation = result.scalar_one_or_none()
|
| 689 |
if not conversation:
|
| 690 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
|
|
|
|
|
|
| 691 |
return ConversationOut.from_orm(conversation)
|
| 692 |
|
| 693 |
@router.get("/api/conversations", response_model=List[ConversationOut])
|
| 694 |
async def list_conversations(
|
| 695 |
user: User = Depends(current_active_user),
|
| 696 |
+
db: AsyncSession = Depends(get_db)
|
| 697 |
):
|
| 698 |
if not user:
|
| 699 |
raise HTTPException(status_code=401, detail="Login required")
|
| 700 |
+
result = await db.execute(
|
| 701 |
+
select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.created_at.desc())
|
| 702 |
+
)
|
| 703 |
+
conversations = result.scalars().all()
|
| 704 |
+
return [ConversationOut.from_orm(conv) for conv in conversations]
|
| 705 |
|
| 706 |
@router.put("/api/conversations/{conversation_id}/title")
|
| 707 |
async def update_conversation_title(
|
| 708 |
conversation_id: str,
|
| 709 |
title: str,
|
| 710 |
user: User = Depends(current_active_user),
|
| 711 |
+
db: AsyncSession = Depends(get_db)
|
| 712 |
):
|
| 713 |
if not user:
|
| 714 |
raise HTTPException(status_code=401, detail="Login required")
|
| 715 |
+
result = await db.execute(
|
| 716 |
+
select(Conversation).filter(
|
| 717 |
+
Conversation.conversation_id == conversation_id,
|
| 718 |
+
Conversation.user_id == user.id
|
| 719 |
+
)
|
| 720 |
+
)
|
| 721 |
+
conversation = result.scalar_one_or_none()
|
| 722 |
if not conversation:
|
| 723 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 724 |
|
| 725 |
conversation.title = title
|
| 726 |
conversation.updated_at = datetime.utcnow()
|
| 727 |
+
await db.commit()
|
| 728 |
return {"message": "Conversation title updated", "title": conversation.title}
|
| 729 |
|
| 730 |
@router.delete("/api/conversations/{conversation_id}")
|
| 731 |
async def delete_conversation(
|
| 732 |
conversation_id: str,
|
| 733 |
user: User = Depends(current_active_user),
|
| 734 |
+
db: AsyncSession = Depends(get_db)
|
| 735 |
):
|
| 736 |
if not user:
|
| 737 |
raise HTTPException(status_code=401, detail="Login required")
|
| 738 |
+
result = await db.execute(
|
| 739 |
+
select(Conversation).filter(
|
| 740 |
+
Conversation.conversation_id == conversation_id,
|
| 741 |
+
Conversation.user_id == user.id
|
| 742 |
+
)
|
| 743 |
+
)
|
| 744 |
+
conversation = result.scalar_one_or_none()
|
| 745 |
if not conversation:
|
| 746 |
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 747 |
|
| 748 |
+
await db.execute(delete(Message).filter(Message.conversation_id == conversation.id))
|
| 749 |
+
await db.delete(conversation)
|
| 750 |
+
await db.commit()
|
| 751 |
return {"message": "Conversation deleted successfully"}
|
| 752 |
|
| 753 |
@router.get("/users/me")
|
|
|
|
| 772 |
async def update_user_settings(
|
| 773 |
settings: UserUpdate,
|
| 774 |
user: User = Depends(current_active_user),
|
| 775 |
+
db: AsyncSession = Depends(get_db)
|
| 776 |
):
|
| 777 |
if not user:
|
| 778 |
raise HTTPException(status_code=401, detail="Login required")
|
|
|
|
| 795 |
if settings.conversation_style is not None:
|
| 796 |
user.conversation_style = settings.conversation_style
|
| 797 |
|
| 798 |
+
await db.commit()
|
| 799 |
+
await db.refresh(user)
|
| 800 |
return {"message": "Settings updated successfully", "user": {
|
| 801 |
"id": user.id,
|
| 802 |
"email": user.email,
|
|
|
|
| 810 |
"is_active": user.is_active,
|
| 811 |
"is_superuser": user.is_superuser
|
| 812 |
}}
|
|
|
api/models.py
CHANGED
|
@@ -34,7 +34,6 @@ class UserCreate(schemas.BaseUserCreate):
|
|
| 34 |
|
| 35 |
model_config = {"from_attributes": True}
|
| 36 |
|
| 37 |
-
# Pydantic schema for updating user settings
|
| 38 |
class UserUpdate(BaseModel):
|
| 39 |
display_name: Optional[str] = None
|
| 40 |
preferred_model: Optional[str] = None
|
|
|
|
| 34 |
|
| 35 |
model_config = {"from_attributes": True}
|
| 36 |
|
|
|
|
| 37 |
class UserUpdate(BaseModel):
|
| 38 |
display_name: Optional[str] = None
|
| 39 |
preferred_model: Optional[str] = None
|
init_db.py
CHANGED
|
@@ -1,62 +1,68 @@
|
|
| 1 |
import os
|
| 2 |
import logging
|
|
|
|
|
|
|
|
|
|
| 3 |
from api.database import async_engine, Base, User, OAuthAccount, Conversation, Message, AsyncSessionLocal
|
|
|
|
| 4 |
|
| 5 |
# Setup logging
|
| 6 |
logging.basicConfig(level=logging.INFO)
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
logger.info("Starting database initialization...")
|
| 11 |
|
| 12 |
-
# إنشاء الجداول
|
| 13 |
try:
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
Base.metadata.create_all(bind=sync_engine)
|
| 17 |
logger.info("Database tables created successfully.")
|
| 18 |
except Exception as e:
|
| 19 |
logger.error(f"Error creating database tables: {e}")
|
| 20 |
raise
|
| 21 |
|
| 22 |
-
# تنظيف البيانات غير المتسقة
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
from sqlalchemy.orm import sessionmaker
|
| 26 |
-
sync_engine = create_engine(os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite:///./data/mgzon_users.db"))
|
| 27 |
-
SyncSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=sync_engine)
|
| 28 |
-
with SyncSessionLocal() as session:
|
| 29 |
# حذف سجلات oauth_accounts اللي مش مرتبطة بمستخدم موجود
|
| 30 |
stmt = delete(OAuthAccount).where(
|
| 31 |
OAuthAccount.user_id.notin_(select(User.id))
|
| 32 |
)
|
| 33 |
-
result = session.execute(stmt)
|
| 34 |
deleted_count = result.rowcount
|
| 35 |
-
session.commit()
|
| 36 |
logger.info(f"Deleted {deleted_count} orphaned OAuth accounts.")
|
| 37 |
|
| 38 |
# التأكد من إن كل المستخدمين ليهم is_active=True
|
| 39 |
-
users = session.execute(select(User)).scalars().all()
|
| 40 |
for user in users:
|
| 41 |
if not user.is_active:
|
| 42 |
user.is_active = True
|
| 43 |
logger.info(f"Updated user {user.email} to is_active=True")
|
| 44 |
-
session.commit()
|
| 45 |
|
| 46 |
# اختبار إنشاء مستخدم ومحادثة (اختياري)
|
| 47 |
-
test_user = session.
|
|
|
|
|
|
|
| 48 |
if not test_user:
|
| 49 |
test_user = User(
|
| 50 |
email="test@example.com",
|
| 51 |
-
hashed_password="
|
| 52 |
is_active=True,
|
| 53 |
display_name="Test User"
|
| 54 |
)
|
| 55 |
session.add(test_user)
|
| 56 |
-
session.commit()
|
| 57 |
logger.info("Test user created successfully.")
|
| 58 |
|
| 59 |
-
test_conversation = session.
|
|
|
|
|
|
|
| 60 |
if not test_conversation:
|
| 61 |
test_conversation = Conversation(
|
| 62 |
conversation_id="test-conversation-1",
|
|
@@ -64,14 +70,17 @@ def init_db():
|
|
| 64 |
title="Test Conversation"
|
| 65 |
)
|
| 66 |
session.add(test_conversation)
|
| 67 |
-
session.commit()
|
| 68 |
logger.info("Test conversation created successfully.")
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
logger.info("Database initialization completed.")
|
| 75 |
|
| 76 |
if __name__ == "__main__":
|
| 77 |
-
init_db()
|
|
|
|
| 1 |
import os
|
| 2 |
import logging
|
| 3 |
+
import asyncio
|
| 4 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 5 |
+
from sqlalchemy import select, delete
|
| 6 |
from api.database import async_engine, Base, User, OAuthAccount, Conversation, Message, AsyncSessionLocal
|
| 7 |
+
from passlib.context import CryptContext
|
| 8 |
|
| 9 |
# Setup logging
|
| 10 |
logging.basicConfig(level=logging.INFO)
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
# إعداد تشفير كلمة المرور
|
| 14 |
+
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 15 |
+
|
| 16 |
+
async def init_db():
|
| 17 |
logger.info("Starting database initialization...")
|
| 18 |
|
| 19 |
+
# إنشاء الجداول
|
| 20 |
try:
|
| 21 |
+
async with async_engine.begin() as conn:
|
| 22 |
+
await conn.run_sync(Base.metadata.create_all)
|
|
|
|
| 23 |
logger.info("Database tables created successfully.")
|
| 24 |
except Exception as e:
|
| 25 |
logger.error(f"Error creating database tables: {e}")
|
| 26 |
raise
|
| 27 |
|
| 28 |
+
# تنظيف البيانات غير المتسقة
|
| 29 |
+
async with AsyncSessionLocal() as session:
|
| 30 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
# حذف سجلات oauth_accounts اللي مش مرتبطة بمستخدم موجود
|
| 32 |
stmt = delete(OAuthAccount).where(
|
| 33 |
OAuthAccount.user_id.notin_(select(User.id))
|
| 34 |
)
|
| 35 |
+
result = await session.execute(stmt)
|
| 36 |
deleted_count = result.rowcount
|
| 37 |
+
await session.commit()
|
| 38 |
logger.info(f"Deleted {deleted_count} orphaned OAuth accounts.")
|
| 39 |
|
| 40 |
# التأكد من إن كل المستخدمين ليهم is_active=True
|
| 41 |
+
users = (await session.execute(select(User))).scalars().all()
|
| 42 |
for user in users:
|
| 43 |
if not user.is_active:
|
| 44 |
user.is_active = True
|
| 45 |
logger.info(f"Updated user {user.email} to is_active=True")
|
| 46 |
+
await session.commit()
|
| 47 |
|
| 48 |
# اختبار إنشاء مستخدم ومحادثة (اختياري)
|
| 49 |
+
test_user = (await session.execute(
|
| 50 |
+
select(User).filter_by(email="test@example.com")
|
| 51 |
+
)).scalar_one_or_none()
|
| 52 |
if not test_user:
|
| 53 |
test_user = User(
|
| 54 |
email="test@example.com",
|
| 55 |
+
hashed_password=pwd_context.hash("testpassword123"), # تشفير كلمة المرور
|
| 56 |
is_active=True,
|
| 57 |
display_name="Test User"
|
| 58 |
)
|
| 59 |
session.add(test_user)
|
| 60 |
+
await session.commit()
|
| 61 |
logger.info("Test user created successfully.")
|
| 62 |
|
| 63 |
+
test_conversation = (await session.execute(
|
| 64 |
+
select(Conversation).filter_by(user_id=test_user.id)
|
| 65 |
+
)).scalar_one_or_none()
|
| 66 |
if not test_conversation:
|
| 67 |
test_conversation = Conversation(
|
| 68 |
conversation_id="test-conversation-1",
|
|
|
|
| 70 |
title="Test Conversation"
|
| 71 |
)
|
| 72 |
session.add(test_conversation)
|
| 73 |
+
await session.commit()
|
| 74 |
logger.info("Test conversation created successfully.")
|
| 75 |
|
| 76 |
+
except Exception as e:
|
| 77 |
+
await session.rollback()
|
| 78 |
+
logger.error(f"Error during initialization: {e}")
|
| 79 |
+
raise
|
| 80 |
+
finally:
|
| 81 |
+
await session.close()
|
| 82 |
|
| 83 |
logger.info("Database initialization completed.")
|
| 84 |
|
| 85 |
if __name__ == "__main__":
|
| 86 |
+
asyncio.run(init_db())
|
requirements.txt
CHANGED
|
@@ -31,6 +31,8 @@ httpx-oauth==0.16.1
|
|
| 31 |
python-multipart==0.0.17
|
| 32 |
aiofiles==24.1.0
|
| 33 |
motor==3.7.0
|
|
|
|
|
|
|
| 34 |
redis==5.0.0
|
| 35 |
markdown2==2.5.0
|
| 36 |
pymongo==4.10.1
|
|
|
|
| 31 |
python-multipart==0.0.17
|
| 32 |
aiofiles==24.1.0
|
| 33 |
motor==3.7.0
|
| 34 |
+
aiosqlite==0.20.0 # إضافة لدعم SQLite async
|
| 35 |
+
secrets==1.0.0 # للأمان في توليد كلمات سر عشوائية
|
| 36 |
redis==5.0.0
|
| 37 |
markdown2==2.5.0
|
| 38 |
pymongo==4.10.1
|