Mark-Lasfar commited on
Commit
493a4a6
·
1 Parent(s): 6ec5390

Fix ChunkedIteratorResult in SQLAlchemyUserDatabase and toggleBtn null error

Browse files
Files changed (6) hide show
  1. api/auth.py +4 -4
  2. api/database.py +2 -2
  3. api/endpoints.py +75 -56
  4. api/models.py +0 -1
  5. init_db.py +34 -25
  6. requirements.txt +2 -0
api/auth.py CHANGED
@@ -60,7 +60,7 @@ github_oauth_client = GitHubOAuth2(GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET)
60
 
61
  class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
62
  async def get_by_email(self, email: str) -> Optional[User]:
63
- """Override to fix ChunkedIteratorResult issue for get_by_email"""
64
  logger.info(f"Checking for user with email: {email}")
65
  try:
66
  statement = select(self.user_table).where(self.user_table.email == email)
@@ -76,7 +76,7 @@ class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
76
  raise
77
 
78
  async def create(self, create_dict: Dict[str, Any]) -> User:
79
- """Override to fix potential async issues in create"""
80
  logger.info(f"Creating user with email: {create_dict.get('email')}")
81
  try:
82
  user = self.user_table(**create_dict)
@@ -95,7 +95,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
95
  verification_token_secret = SECRET
96
 
97
  async def get_by_oauth_account(self, oauth_name: str, account_id: str):
98
- """Override to fix ChunkedIteratorResult issue in SQLAlchemy 2.0+"""
99
  logger.info(f"Checking for existing OAuth account: {oauth_name}/{account_id}")
100
  try:
101
  statement = select(OAuthAccount).where(
@@ -113,7 +113,7 @@ class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
113
  raise
114
 
115
  async def add_oauth_account(self, oauth_account: OAuthAccount):
116
- """Override to fix potential async issues"""
117
  logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
118
  try:
119
  self.session.add(oauth_account)
 
60
 
61
  class CustomSQLAlchemyUserDatabase(SQLAlchemyUserDatabase):
62
  async def get_by_email(self, email: str) -> Optional[User]:
63
+ """Override to fix ChunkedIteratorResult issue."""
64
  logger.info(f"Checking for user with email: {email}")
65
  try:
66
  statement = select(self.user_table).where(self.user_table.email == email)
 
76
  raise
77
 
78
  async def create(self, create_dict: Dict[str, Any]) -> User:
79
+ """Override to fix potential async issues."""
80
  logger.info(f"Creating user with email: {create_dict.get('email')}")
81
  try:
82
  user = self.user_table(**create_dict)
 
95
  verification_token_secret = SECRET
96
 
97
  async def get_by_oauth_account(self, oauth_name: str, account_id: str):
98
+ """Override to fix ChunkedIteratorResult issue."""
99
  logger.info(f"Checking for existing OAuth account: {oauth_name}/{account_id}")
100
  try:
101
  statement = select(OAuthAccount).where(
 
113
  raise
114
 
115
  async def add_oauth_account(self, oauth_account: OAuthAccount):
116
+ """Override to fix potential async issues."""
117
  logger.info(f"Adding OAuth account for user {oauth_account.user_id}")
118
  try:
119
  self.session.add(oauth_account)
api/database.py CHANGED
@@ -12,11 +12,11 @@ import logging
12
  logger = logging.getLogger(__name__)
13
 
14
  # جلب URL قاعدة البيانات من المتغيرات البيئية
15
- SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL")
16
  if not SQLALCHEMY_DATABASE_URL:
17
  raise ValueError("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
18
 
19
- # إنشاء محرك async (استخدم sqlite+aiosqlite للدعم async)
20
  async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=True)
21
 
22
  # إعداد جلسة async
 
12
  logger = logging.getLogger(__name__)
13
 
14
  # جلب URL قاعدة البيانات من المتغيرات البيئية
15
+ SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite+aiosqlite:///./data/mgzon_users.db")
16
  if not SQLALCHEMY_DATABASE_URL:
17
  raise ValueError("SQLALCHEMY_DATABASE_URL is not set in environment variables.")
18
 
19
+ # إنشاء محرك async
20
  async_engine = create_async_engine(SQLALCHEMY_DATABASE_URL, echo=True)
21
 
22
  # إعداد جلسة async
api/endpoints.py CHANGED
@@ -6,7 +6,8 @@ from api.database import User, Conversation, Message
6
  from api.models import QueryRequest, ConversationOut, ConversationCreate, UserUpdate
7
  from api.auth import current_active_user
8
  from api.database import get_db
9
- from sqlalchemy.orm import Session
 
10
  from utils.generation import request_generation, select_model, check_model_availability
11
  from utils.web_search import web_search
12
  import io
@@ -145,7 +146,7 @@ async def chat_endpoint(
145
  request: Request,
146
  req: QueryRequest,
147
  user: User = Depends(current_active_user),
148
- db: Session = Depends(get_db)
149
  ):
150
  logger.info(f"Received chat request: {req}")
151
 
@@ -155,7 +156,10 @@ async def chat_endpoint(
155
  conversation = None
156
  if user:
157
  title = req.title or (req.message[:50] + "..." if len(req.message) > 50 else req.message or "Untitled Conversation")
158
- conversation = db.query(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc()).first()
 
 
 
159
  if not conversation:
160
  conversation_id = str(uuid.uuid4())
161
  conversation = Conversation(
@@ -164,12 +168,12 @@ async def chat_endpoint(
164
  title=title
165
  )
166
  db.add(conversation)
167
- db.commit()
168
- db.refresh(conversation)
169
 
170
  user_msg = Message(role="user", content=req.message, conversation_id=conversation.id)
171
  db.add(user_msg)
172
- db.commit()
173
 
174
  preferred_model = user.preferred_model if user else None
175
  model_name, api_endpoint = select_model(req.message, input_type="text", preferred_model=preferred_model)
@@ -231,9 +235,9 @@ async def chat_endpoint(
231
  if user and conversation:
232
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
233
  db.add(assistant_msg)
234
- db.commit()
235
  conversation.updated_at = datetime.utcnow()
236
- db.commit()
237
  return {
238
  "response": response,
239
  "conversation_id": conversation.conversation_id,
@@ -248,7 +252,7 @@ async def audio_transcription_endpoint(
248
  request: Request,
249
  file: UploadFile = File(...),
250
  user: User = Depends(current_active_user),
251
- db: Session = Depends(get_db)
252
  ):
253
  logger.info(f"Received audio transcription request for file: {file.filename}")
254
 
@@ -258,7 +262,10 @@ async def audio_transcription_endpoint(
258
  conversation = None
259
  if user:
260
  title = "Audio Transcription"
261
- conversation = db.query(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc()).first()
 
 
 
262
  if not conversation:
263
  conversation_id = str(uuid.uuid4())
264
  conversation = Conversation(
@@ -267,12 +274,12 @@ async def audio_transcription_endpoint(
267
  title=title
268
  )
269
  db.add(conversation)
270
- db.commit()
271
- db.refresh(conversation)
272
 
273
  user_msg = Message(role="user", content="Audio message", conversation_id=conversation.id)
274
  db.add(user_msg)
275
- db.commit()
276
 
277
  model_name, api_endpoint = select_model("transcribe audio", input_type="audio")
278
 
@@ -313,9 +320,9 @@ async def audio_transcription_endpoint(
313
  if user and conversation:
314
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
315
  db.add(assistant_msg)
316
- db.commit()
317
  conversation.updated_at = datetime.utcnow()
318
- db.commit()
319
  return {
320
  "transcription": response,
321
  "conversation_id": conversation.conversation_id,
@@ -330,7 +337,7 @@ async def text_to_speech_endpoint(
330
  request: Request,
331
  req: dict,
332
  user: User = Depends(current_active_user),
333
- db: Session = Depends(get_db)
334
  ):
335
  if not user:
336
  await handle_session(request)
@@ -378,7 +385,7 @@ async def code_endpoint(
378
  request: Request,
379
  req: dict,
380
  user: User = Depends(current_active_user),
381
- db: Session = Depends(get_db)
382
  ):
383
  if not user:
384
  await handle_session(request)
@@ -453,7 +460,7 @@ async def analysis_endpoint(
453
  request: Request,
454
  req: dict,
455
  user: User = Depends(current_active_user),
456
- db: Session = Depends(get_db)
457
  ):
458
  if not user:
459
  await handle_session(request)
@@ -526,7 +533,7 @@ async def image_analysis_endpoint(
526
  file: UploadFile = File(...),
527
  output_format: str = "text",
528
  user: User = Depends(current_active_user),
529
- db: Session = Depends(get_db)
530
  ):
531
  if not user:
532
  await handle_session(request)
@@ -534,7 +541,10 @@ async def image_analysis_endpoint(
534
  conversation = None
535
  if user:
536
  title = "Image Analysis"
537
- conversation = db.query(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc()).first()
 
 
 
538
  if not conversation:
539
  conversation_id = str(uuid.uuid4())
540
  conversation = Conversation(
@@ -543,12 +553,12 @@ async def image_analysis_endpoint(
543
  title=title
544
  )
545
  db.add(conversation)
546
- db.commit()
547
- db.refresh(conversation)
548
 
549
  user_msg = Message(role="user", content="Image analysis request", conversation_id=conversation.id)
550
  db.add(user_msg)
551
- db.commit()
552
 
553
  preferred_model = user.preferred_model if user else None
554
  model_name, api_endpoint = select_model("analyze image", input_type="image", preferred_model=preferred_model)
@@ -608,9 +618,9 @@ async def image_analysis_endpoint(
608
  if user and conversation:
609
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
610
  db.add(assistant_msg)
611
- db.commit()
612
  conversation.updated_at = datetime.utcnow()
613
- db.commit()
614
  return {
615
  "image_analysis": response,
616
  "conversation_id": conversation.conversation_id,
@@ -646,7 +656,7 @@ async def test_model(model: str = MODEL_NAME, endpoint: str = API_ENDPOINT):
646
  async def create_conversation(
647
  req: ConversationCreate,
648
  user: User = Depends(current_active_user),
649
- db: Session = Depends(get_db)
650
  ):
651
  if not user:
652
  raise HTTPException(status_code=401, detail="Login required")
@@ -657,77 +667,87 @@ async def create_conversation(
657
  user_id=user.id
658
  )
659
  db.add(conversation)
660
- db.commit()
661
- db.refresh(conversation)
662
  return ConversationOut.from_orm(conversation)
663
 
664
  @router.get("/api/conversations/{conversation_id}", response_model=ConversationOut)
665
  async def get_conversation(
666
  conversation_id: str,
667
  user: User = Depends(current_active_user),
668
- db: Session = Depends(get_db)
669
  ):
670
  if not user:
671
  raise HTTPException(status_code=401, detail="Login required")
672
- conversation = db.query(Conversation).filter(
673
- Conversation.conversation_id == conversation_id,
674
- Conversation.user_id == user.id
675
- ).first()
 
 
 
676
  if not conversation:
677
  raise HTTPException(status_code=404, detail="Conversation not found")
678
- db.add(conversation)
679
- db.commit()
680
  return ConversationOut.from_orm(conversation)
681
 
682
  @router.get("/api/conversations", response_model=List[ConversationOut])
683
  async def list_conversations(
684
  user: User = Depends(current_active_user),
685
- db: Session = Depends(get_db)
686
  ):
687
  if not user:
688
  raise HTTPException(status_code=401, detail="Login required")
689
- conversations = db.query(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.created_at.desc()).all()
690
- return conversations
 
 
 
691
 
692
  @router.put("/api/conversations/{conversation_id}/title")
693
  async def update_conversation_title(
694
  conversation_id: str,
695
  title: str,
696
  user: User = Depends(current_active_user),
697
- db: Session = Depends(get_db)
698
  ):
699
  if not user:
700
  raise HTTPException(status_code=401, detail="Login required")
701
- conversation = db.query(Conversation).filter(
702
- Conversation.conversation_id == conversation_id,
703
- Conversation.user_id == user.id
704
- ).first()
 
 
 
705
  if not conversation:
706
  raise HTTPException(status_code=404, detail="Conversation not found")
707
 
708
  conversation.title = title
709
  conversation.updated_at = datetime.utcnow()
710
- db.commit()
711
  return {"message": "Conversation title updated", "title": conversation.title}
712
 
713
  @router.delete("/api/conversations/{conversation_id}")
714
  async def delete_conversation(
715
  conversation_id: str,
716
  user: User = Depends(current_active_user),
717
- db: Session = Depends(get_db)
718
  ):
719
  if not user:
720
  raise HTTPException(status_code=401, detail="Login required")
721
- conversation = db.query(Conversation).filter(
722
- Conversation.conversation_id == conversation_id,
723
- Conversation.user_id == user.id
724
- ).first()
 
 
 
725
  if not conversation:
726
  raise HTTPException(status_code=404, detail="Conversation not found")
727
 
728
- db.query(Message).filter(Message.conversation_id == conversation.id).delete()
729
- db.delete(conversation)
730
- db.commit()
731
  return {"message": "Conversation deleted successfully"}
732
 
733
  @router.get("/users/me")
@@ -752,7 +772,7 @@ async def get_user_settings(user: User = Depends(current_active_user)):
752
  async def update_user_settings(
753
  settings: UserUpdate,
754
  user: User = Depends(current_active_user),
755
- db: Session = Depends(get_db)
756
  ):
757
  if not user:
758
  raise HTTPException(status_code=401, detail="Login required")
@@ -775,8 +795,8 @@ async def update_user_settings(
775
  if settings.conversation_style is not None:
776
  user.conversation_style = settings.conversation_style
777
 
778
- db.commit()
779
- db.refresh(user)
780
  return {"message": "Settings updated successfully", "user": {
781
  "id": user.id,
782
  "email": user.email,
@@ -790,4 +810,3 @@ async def update_user_settings(
790
  "is_active": user.is_active,
791
  "is_superuser": user.is_superuser
792
  }}
793
-
 
6
  from api.models import QueryRequest, ConversationOut, ConversationCreate, UserUpdate
7
  from api.auth import current_active_user
8
  from api.database import get_db
9
+ from sqlalchemy.ext.asyncio import AsyncSession
10
+ from sqlalchemy import select, delete
11
  from utils.generation import request_generation, select_model, check_model_availability
12
  from utils.web_search import web_search
13
  import io
 
146
  request: Request,
147
  req: QueryRequest,
148
  user: User = Depends(current_active_user),
149
+ db: AsyncSession = Depends(get_db)
150
  ):
151
  logger.info(f"Received chat request: {req}")
152
 
 
156
  conversation = None
157
  if user:
158
  title = req.title or (req.message[:50] + "..." if len(req.message) > 50 else req.message or "Untitled Conversation")
159
+ result = await db.execute(
160
+ select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
161
+ )
162
+ conversation = result.scalar_one_or_none()
163
  if not conversation:
164
  conversation_id = str(uuid.uuid4())
165
  conversation = Conversation(
 
168
  title=title
169
  )
170
  db.add(conversation)
171
+ await db.commit()
172
+ await db.refresh(conversation)
173
 
174
  user_msg = Message(role="user", content=req.message, conversation_id=conversation.id)
175
  db.add(user_msg)
176
+ await db.commit()
177
 
178
  preferred_model = user.preferred_model if user else None
179
  model_name, api_endpoint = select_model(req.message, input_type="text", preferred_model=preferred_model)
 
235
  if user and conversation:
236
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
237
  db.add(assistant_msg)
238
+ await db.commit()
239
  conversation.updated_at = datetime.utcnow()
240
+ await db.commit()
241
  return {
242
  "response": response,
243
  "conversation_id": conversation.conversation_id,
 
252
  request: Request,
253
  file: UploadFile = File(...),
254
  user: User = Depends(current_active_user),
255
+ db: AsyncSession = Depends(get_db)
256
  ):
257
  logger.info(f"Received audio transcription request for file: {file.filename}")
258
 
 
262
  conversation = None
263
  if user:
264
  title = "Audio Transcription"
265
+ result = await db.execute(
266
+ select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
267
+ )
268
+ conversation = result.scalar_one_or_none()
269
  if not conversation:
270
  conversation_id = str(uuid.uuid4())
271
  conversation = Conversation(
 
274
  title=title
275
  )
276
  db.add(conversation)
277
+ await db.commit()
278
+ await db.refresh(conversation)
279
 
280
  user_msg = Message(role="user", content="Audio message", conversation_id=conversation.id)
281
  db.add(user_msg)
282
+ await db.commit()
283
 
284
  model_name, api_endpoint = select_model("transcribe audio", input_type="audio")
285
 
 
320
  if user and conversation:
321
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
322
  db.add(assistant_msg)
323
+ await db.commit()
324
  conversation.updated_at = datetime.utcnow()
325
+ await db.commit()
326
  return {
327
  "transcription": response,
328
  "conversation_id": conversation.conversation_id,
 
337
  request: Request,
338
  req: dict,
339
  user: User = Depends(current_active_user),
340
+ db: AsyncSession = Depends(get_db)
341
  ):
342
  if not user:
343
  await handle_session(request)
 
385
  request: Request,
386
  req: dict,
387
  user: User = Depends(current_active_user),
388
+ db: AsyncSession = Depends(get_db)
389
  ):
390
  if not user:
391
  await handle_session(request)
 
460
  request: Request,
461
  req: dict,
462
  user: User = Depends(current_active_user),
463
+ db: AsyncSession = Depends(get_db)
464
  ):
465
  if not user:
466
  await handle_session(request)
 
533
  file: UploadFile = File(...),
534
  output_format: str = "text",
535
  user: User = Depends(current_active_user),
536
+ db: AsyncSession = Depends(get_db)
537
  ):
538
  if not user:
539
  await handle_session(request)
 
541
  conversation = None
542
  if user:
543
  title = "Image Analysis"
544
+ result = await db.execute(
545
+ select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.updated_at.desc())
546
+ )
547
+ conversation = result.scalar_one_or_none()
548
  if not conversation:
549
  conversation_id = str(uuid.uuid4())
550
  conversation = Conversation(
 
553
  title=title
554
  )
555
  db.add(conversation)
556
+ await db.commit()
557
+ await db.refresh(conversation)
558
 
559
  user_msg = Message(role="user", content="Image analysis request", conversation_id=conversation.id)
560
  db.add(user_msg)
561
+ await db.commit()
562
 
563
  preferred_model = user.preferred_model if user else None
564
  model_name, api_endpoint = select_model("analyze image", input_type="image", preferred_model=preferred_model)
 
618
  if user and conversation:
619
  assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
620
  db.add(assistant_msg)
621
+ await db.commit()
622
  conversation.updated_at = datetime.utcnow()
623
+ await db.commit()
624
  return {
625
  "image_analysis": response,
626
  "conversation_id": conversation.conversation_id,
 
656
  async def create_conversation(
657
  req: ConversationCreate,
658
  user: User = Depends(current_active_user),
659
+ db: AsyncSession = Depends(get_db)
660
  ):
661
  if not user:
662
  raise HTTPException(status_code=401, detail="Login required")
 
667
  user_id=user.id
668
  )
669
  db.add(conversation)
670
+ await db.commit()
671
+ await db.refresh(conversation)
672
  return ConversationOut.from_orm(conversation)
673
 
674
  @router.get("/api/conversations/{conversation_id}", response_model=ConversationOut)
675
  async def get_conversation(
676
  conversation_id: str,
677
  user: User = Depends(current_active_user),
678
+ db: AsyncSession = Depends(get_db)
679
  ):
680
  if not user:
681
  raise HTTPException(status_code=401, detail="Login required")
682
+ result = await db.execute(
683
+ select(Conversation).filter(
684
+ Conversation.conversation_id == conversation_id,
685
+ Conversation.user_id == user.id
686
+ )
687
+ )
688
+ conversation = result.scalar_one_or_none()
689
  if not conversation:
690
  raise HTTPException(status_code=404, detail="Conversation not found")
 
 
691
  return ConversationOut.from_orm(conversation)
692
 
693
  @router.get("/api/conversations", response_model=List[ConversationOut])
694
  async def list_conversations(
695
  user: User = Depends(current_active_user),
696
+ db: AsyncSession = Depends(get_db)
697
  ):
698
  if not user:
699
  raise HTTPException(status_code=401, detail="Login required")
700
+ result = await db.execute(
701
+ select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.created_at.desc())
702
+ )
703
+ conversations = result.scalars().all()
704
+ return [ConversationOut.from_orm(conv) for conv in conversations]
705
 
706
  @router.put("/api/conversations/{conversation_id}/title")
707
  async def update_conversation_title(
708
  conversation_id: str,
709
  title: str,
710
  user: User = Depends(current_active_user),
711
+ db: AsyncSession = Depends(get_db)
712
  ):
713
  if not user:
714
  raise HTTPException(status_code=401, detail="Login required")
715
+ result = await db.execute(
716
+ select(Conversation).filter(
717
+ Conversation.conversation_id == conversation_id,
718
+ Conversation.user_id == user.id
719
+ )
720
+ )
721
+ conversation = result.scalar_one_or_none()
722
  if not conversation:
723
  raise HTTPException(status_code=404, detail="Conversation not found")
724
 
725
  conversation.title = title
726
  conversation.updated_at = datetime.utcnow()
727
+ await db.commit()
728
  return {"message": "Conversation title updated", "title": conversation.title}
729
 
730
  @router.delete("/api/conversations/{conversation_id}")
731
  async def delete_conversation(
732
  conversation_id: str,
733
  user: User = Depends(current_active_user),
734
+ db: AsyncSession = Depends(get_db)
735
  ):
736
  if not user:
737
  raise HTTPException(status_code=401, detail="Login required")
738
+ result = await db.execute(
739
+ select(Conversation).filter(
740
+ Conversation.conversation_id == conversation_id,
741
+ Conversation.user_id == user.id
742
+ )
743
+ )
744
+ conversation = result.scalar_one_or_none()
745
  if not conversation:
746
  raise HTTPException(status_code=404, detail="Conversation not found")
747
 
748
+ await db.execute(delete(Message).filter(Message.conversation_id == conversation.id))
749
+ await db.delete(conversation)
750
+ await db.commit()
751
  return {"message": "Conversation deleted successfully"}
752
 
753
  @router.get("/users/me")
 
772
  async def update_user_settings(
773
  settings: UserUpdate,
774
  user: User = Depends(current_active_user),
775
+ db: AsyncSession = Depends(get_db)
776
  ):
777
  if not user:
778
  raise HTTPException(status_code=401, detail="Login required")
 
795
  if settings.conversation_style is not None:
796
  user.conversation_style = settings.conversation_style
797
 
798
+ await db.commit()
799
+ await db.refresh(user)
800
  return {"message": "Settings updated successfully", "user": {
801
  "id": user.id,
802
  "email": user.email,
 
810
  "is_active": user.is_active,
811
  "is_superuser": user.is_superuser
812
  }}
 
api/models.py CHANGED
@@ -34,7 +34,6 @@ class UserCreate(schemas.BaseUserCreate):
34
 
35
  model_config = {"from_attributes": True}
36
 
37
- # Pydantic schema for updating user settings
38
  class UserUpdate(BaseModel):
39
  display_name: Optional[str] = None
40
  preferred_model: Optional[str] = None
 
34
 
35
  model_config = {"from_attributes": True}
36
 
 
37
  class UserUpdate(BaseModel):
38
  display_name: Optional[str] = None
39
  preferred_model: Optional[str] = None
init_db.py CHANGED
@@ -1,62 +1,68 @@
1
  import os
2
  import logging
 
 
 
3
  from api.database import async_engine, Base, User, OAuthAccount, Conversation, Message, AsyncSessionLocal
 
4
 
5
  # Setup logging
6
  logging.basicConfig(level=logging.INFO)
7
  logger = logging.getLogger(__name__)
8
 
9
- def init_db():
 
 
 
10
  logger.info("Starting database initialization...")
11
 
12
- # إنشاء الجداول (sync version for init_db.py)
13
  try:
14
- from sqlalchemy import create_engine
15
- sync_engine = create_engine(os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite:///./data/mgzon_users.db"))
16
- Base.metadata.create_all(bind=sync_engine)
17
  logger.info("Database tables created successfully.")
18
  except Exception as e:
19
  logger.error(f"Error creating database tables: {e}")
20
  raise
21
 
22
- # تنظيف البيانات غير المتسقة (sync for simplicity in init_db)
23
- try:
24
- from sqlalchemy import select, delete
25
- from sqlalchemy.orm import sessionmaker
26
- sync_engine = create_engine(os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite:///./data/mgzon_users.db"))
27
- SyncSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=sync_engine)
28
- with SyncSessionLocal() as session:
29
  # حذف سجلات oauth_accounts اللي مش مرتبطة بمستخدم موجود
30
  stmt = delete(OAuthAccount).where(
31
  OAuthAccount.user_id.notin_(select(User.id))
32
  )
33
- result = session.execute(stmt)
34
  deleted_count = result.rowcount
35
- session.commit()
36
  logger.info(f"Deleted {deleted_count} orphaned OAuth accounts.")
37
 
38
  # التأكد من إن كل المستخدمين ليهم is_active=True
39
- users = session.execute(select(User)).scalars().all()
40
  for user in users:
41
  if not user.is_active:
42
  user.is_active = True
43
  logger.info(f"Updated user {user.email} to is_active=True")
44
- session.commit()
45
 
46
  # اختبار إنشاء مستخدم ومحادثة (اختياري)
47
- test_user = session.query(User).filter_by(email="test@example.com").first()
 
 
48
  if not test_user:
49
  test_user = User(
50
  email="test@example.com",
51
- hashed_password="$2b$12$examplehashedpassword", # استبدل بكلمة مرور مشفرة حقيقية
52
  is_active=True,
53
  display_name="Test User"
54
  )
55
  session.add(test_user)
56
- session.commit()
57
  logger.info("Test user created successfully.")
58
 
59
- test_conversation = session.query(Conversation).filter_by(user_id=test_user.id).first()
 
 
60
  if not test_conversation:
61
  test_conversation = Conversation(
62
  conversation_id="test-conversation-1",
@@ -64,14 +70,17 @@ def init_db():
64
  title="Test Conversation"
65
  )
66
  session.add(test_conversation)
67
- session.commit()
68
  logger.info("Test conversation created successfully.")
69
 
70
- except Exception as e:
71
- logger.error(f"Error during initialization: {e}")
72
- raise
 
 
 
73
 
74
  logger.info("Database initialization completed.")
75
 
76
  if __name__ == "__main__":
77
- init_db()
 
1
  import os
2
  import logging
3
+ import asyncio
4
+ from sqlalchemy.ext.asyncio import AsyncSession
5
+ from sqlalchemy import select, delete
6
  from api.database import async_engine, Base, User, OAuthAccount, Conversation, Message, AsyncSessionLocal
7
+ from passlib.context import CryptContext
8
 
9
  # Setup logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
+ # إعداد تشفير كلمة المرور
14
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
15
+
16
+ async def init_db():
17
  logger.info("Starting database initialization...")
18
 
19
+ # إنشاء الجداول
20
  try:
21
+ async with async_engine.begin() as conn:
22
+ await conn.run_sync(Base.metadata.create_all)
 
23
  logger.info("Database tables created successfully.")
24
  except Exception as e:
25
  logger.error(f"Error creating database tables: {e}")
26
  raise
27
 
28
+ # تنظيف البيانات غير المتسقة
29
+ async with AsyncSessionLocal() as session:
30
+ try:
 
 
 
 
31
  # حذف سجلات oauth_accounts اللي مش مرتبطة بمستخدم موجود
32
  stmt = delete(OAuthAccount).where(
33
  OAuthAccount.user_id.notin_(select(User.id))
34
  )
35
+ result = await session.execute(stmt)
36
  deleted_count = result.rowcount
37
+ await session.commit()
38
  logger.info(f"Deleted {deleted_count} orphaned OAuth accounts.")
39
 
40
  # التأكد من إن كل المستخدمين ليهم is_active=True
41
+ users = (await session.execute(select(User))).scalars().all()
42
  for user in users:
43
  if not user.is_active:
44
  user.is_active = True
45
  logger.info(f"Updated user {user.email} to is_active=True")
46
+ await session.commit()
47
 
48
  # اختبار إنشاء مستخدم ومحادثة (اختياري)
49
+ test_user = (await session.execute(
50
+ select(User).filter_by(email="test@example.com")
51
+ )).scalar_one_or_none()
52
  if not test_user:
53
  test_user = User(
54
  email="test@example.com",
55
+ hashed_password=pwd_context.hash("testpassword123"), # تشفير كلمة المرور
56
  is_active=True,
57
  display_name="Test User"
58
  )
59
  session.add(test_user)
60
+ await session.commit()
61
  logger.info("Test user created successfully.")
62
 
63
+ test_conversation = (await session.execute(
64
+ select(Conversation).filter_by(user_id=test_user.id)
65
+ )).scalar_one_or_none()
66
  if not test_conversation:
67
  test_conversation = Conversation(
68
  conversation_id="test-conversation-1",
 
70
  title="Test Conversation"
71
  )
72
  session.add(test_conversation)
73
+ await session.commit()
74
  logger.info("Test conversation created successfully.")
75
 
76
+ except Exception as e:
77
+ await session.rollback()
78
+ logger.error(f"Error during initialization: {e}")
79
+ raise
80
+ finally:
81
+ await session.close()
82
 
83
  logger.info("Database initialization completed.")
84
 
85
  if __name__ == "__main__":
86
+ asyncio.run(init_db())
requirements.txt CHANGED
@@ -31,6 +31,8 @@ httpx-oauth==0.16.1
31
  python-multipart==0.0.17
32
  aiofiles==24.1.0
33
  motor==3.7.0
 
 
34
  redis==5.0.0
35
  markdown2==2.5.0
36
  pymongo==4.10.1
 
31
  python-multipart==0.0.17
32
  aiofiles==24.1.0
33
  motor==3.7.0
34
+ aiosqlite==0.20.0 # إضافة لدعم SQLite async
35
+ secrets==1.0.0 # للأمان في توليد كلمات سر عشوائية
36
  redis==5.0.0
37
  markdown2==2.5.0
38
  pymongo==4.10.1