FireBird-Tech commited on
Commit
7b969cb
·
verified ·
1 Parent(s): 15eda99

Update src/managers/chat_manager.py

Browse files
Files changed (1) hide show
  1. src/managers/chat_manager.py +737 -1030
src/managers/chat_manager.py CHANGED
@@ -1,1030 +1,737 @@
1
- from sqlalchemy import create_engine, desc, func
2
- from sqlalchemy.orm import sessionmaker, scoped_session
3
- from sqlalchemy.exc import SQLAlchemyError
4
- from src.db.schemas.models import Base, User, Chat, Message, ModelUsage
5
- import logging
6
- import requests
7
- import json
8
- from typing import List, Dict, Optional, Tuple, Any
9
- from datetime import datetime
10
- import time
11
- import tiktoken
12
- from src.utils.logger import Logger
13
- import re
14
-
15
- logger = Logger("chat_manager", see_time=True, console_log=False)
16
-
17
-
18
- class ChatManager:
19
- """
20
- Manages chat operations including creating, storing, retrieving, and updating chats and messages.
21
- Provides an interface between the application and the database for chat-related operations.
22
- """
23
-
24
- def __init__(self, db_url: str = 'sqlite:///chat_database.db'):
25
- """
26
- Initialize the ChatManager with a database connection.
27
-
28
- Args:
29
- db_url: Database connection URL (defaults to SQLite)
30
- """
31
- self.engine = create_engine(db_url)
32
- Base.metadata.create_all(self.engine) # Ensure tables exist
33
- self.Session = scoped_session(sessionmaker(bind=self.engine))
34
-
35
- # Add price mappings for different models
36
- self.model_costs = {
37
- # OpenAI models (per 1M tokens)
38
- "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
39
- "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
40
- "gpt-4o": {"input": 0.01, "output": 0.03},
41
- "gpt-4o-mini": {"input": 0.0015, "output": 0.002},
42
- "gpt-4": {"input": 0.03, "output": 0.06},
43
- "gpt-4-32k": {"input": 0.06, "output": 0.12},
44
- # Anthropic models
45
- "claude-3-opus-20240229": {"input": 0.015, "output": 0.075},
46
- "claude-3-sonnet-20240229": {"input": 0.003, "output": 0.015},
47
- "claude-3-haiku-20240307": {"input": 0.00025, "output": 0.00125},
48
- "claude-3-5-sonnet-latest": {"input": 0.003, "output": 0.015},
49
- # Groq models
50
- "deepseek-r1-distill-qwen-32b": {"input": 0.00075, "output": 0.00099},
51
- "deepseek-r1-distill-llama-70b": {"input": 0.00075, "output": 0.00099},
52
- "llama-3.3-70b-versatile": {"input": 0.00059, "output": 0.00079},
53
- "llama-3.3-70b-specdec": {"input": 0.00059, "output": 0.00099},
54
- "llama2-70b-4096": {"input": 0.0007, "output": 0.0008},
55
- "llama3-8b-8192": {"input": 0.00005, "output": 0.00008},
56
- "llama-3.2-1b-preview": {"input": 0.00004, "output": 0.00004},
57
- "llama-3.2-3b-preview": {"input": 0.00006, "output": 0.00006},
58
- "llama-3.2-11b-text-preview": {"input": 0.00018, "output": 0.00018},
59
- "llama-3.2-11b-vision-preview": {"input": 0.00018, "output": 0.00018},
60
- "llama-3.2-90b-text-preview": {"input": 0.0009, "output": 0.0009},
61
- "llama-3.2-90b-vision-preview": {"input": 0.0009, "output": 0.0009},
62
- "llama3-70b-8192": {"input": 0.00059, "output": 0.00079},
63
- "llama-3.1-8b-instant": {"input": 0.00005, "output": 0.00008},
64
- "llama-3.1-70b-versatile": {"input": 0.00059, "output": 0.00079},
65
- "llama-3.1-405b-reasoning": {"input": 0.00059, "output": 0.00079},
66
- "mixtral-8x7b-32768": {"input": 0.00024, "output": 0.00024},
67
- "gemma-7b-it": {"input": 0.00007, "output": 0.00007},
68
- "gemma2-9b-it": {"input": 0.0002, "output": 0.0002},
69
- "llama3-groq-70b-8192-tool-use-preview": {"input": 0.00089, "output": 0.00089},
70
- "llama3-groq-8b-8192-tool-use-preview": {"input": 0.00019, "output": 0.00019},
71
- "qwen-2.5-coder-32b": {"input": 0.0015, "output": 0.003}
72
-
73
- }
74
-
75
- # Add model providers mapping
76
- self.model_providers = {
77
- "gpt-": "openai",
78
- "claude-": "anthropic",
79
- "llama-": "groq",
80
- "mixtral-": "groq",
81
- }
82
-
83
- def create_chat(self, user_id: Optional[int] = None) -> Dict[str, Any]:
84
- """
85
- Create a new chat session.
86
-
87
- Args:
88
- user_id: Optional user ID if authentication is enabled
89
-
90
- Returns:
91
- Dictionary containing chat information
92
- """
93
- session = self.Session()
94
- try:
95
- # Create a new chat
96
- chat = Chat(
97
- user_id=user_id,
98
- title='New Chat',
99
- created_at=datetime.utcnow()
100
- )
101
- session.add(chat)
102
- session.commit()
103
-
104
- logger.log_message(f"Created new chat {chat.chat_id} for user {user_id}", level=logging.INFO)
105
-
106
- return {
107
- "chat_id": chat.chat_id,
108
- "user_id": chat.user_id,
109
- "title": chat.title,
110
- "created_at": chat.created_at.isoformat()
111
- }
112
- except SQLAlchemyError as e:
113
- session.rollback()
114
- logger.log_message(f"Error creating chat: {str(e)}", level=logging.ERROR)
115
- raise
116
- finally:
117
- session.close()
118
-
119
- def add_message(self, chat_id: int, content: str, sender: str, user_id: Optional[int] = None) -> Dict[str, Any]:
120
- """
121
- Add a message to a chat.
122
-
123
- Args:
124
- chat_id: ID of the chat to add the message to
125
- content: Message content
126
- sender: Message sender ('user' or 'ai')
127
- user_id: Optional user ID to verify ownership
128
-
129
- Returns:
130
- Dictionary containing message information
131
- """
132
- session = self.Session()
133
- try:
134
- # Check if chat exists and belongs to the user if user_id is provided
135
- query = session.query(Chat).filter(Chat.chat_id == chat_id)
136
- if user_id is not None:
137
- query = query.filter((Chat.user_id == user_id) | (Chat.user_id.is_(None)))
138
-
139
- chat = query.first()
140
- if not chat:
141
- raise ValueError(f"Chat with ID {chat_id} not found or access denied")
142
-
143
- # Create a new message
144
- message = Message(
145
- chat_id=chat_id,
146
- content=content,
147
- sender=sender,
148
- timestamp=datetime.utcnow()
149
- )
150
- session.add(message)
151
-
152
- # If this is the first AI response and chat title is still default,
153
- # update the chat title based on the first user query
154
- if sender == 'ai':
155
- first_ai_message = session.query(Message).filter(
156
- Message.chat_id == chat_id,
157
- Message.sender == 'ai'
158
- ).first()
159
-
160
- if not first_ai_message and chat.title == 'New Chat':
161
- # Get the user's first message
162
- first_user_message = session.query(Message).filter(
163
- Message.chat_id == chat_id,
164
- Message.sender == 'user'
165
- ).order_by(Message.timestamp).first()
166
-
167
- if first_user_message:
168
- # Generate title from user query
169
- new_title = self.generate_title_from_query(first_user_message.content)
170
- chat.title = new_title
171
-
172
- session.commit()
173
-
174
- return {
175
- "message_id": message.message_id,
176
- "chat_id": message.chat_id,
177
- "content": message.content,
178
- "sender": message.sender,
179
- "timestamp": message.timestamp.isoformat()
180
- }
181
- except SQLAlchemyError as e:
182
- session.rollback()
183
- logger.log_message(f"Error adding message: {str(e)}", level=logging.ERROR)
184
- raise
185
- finally:
186
- session.close()
187
-
188
- def _update_chat_title(self, chat_id: int, first_response: str) -> None:
189
- """
190
- Update chat title based on the first bot response.
191
-
192
- Args:
193
- chat_id: ID of the chat to update
194
- first_response: First bot response content to generate title from
195
- """
196
- session = self.Session()
197
- try:
198
- # Get the user's query (the message before the bot response)
199
- user_query = session.query(Message).filter(
200
- Message.chat_id == chat_id,
201
- Message.sender == 'user'
202
- ).order_by(Message.timestamp.desc()).first()
203
-
204
- if not user_query:
205
- logger.warning(f"No user query found for chat {chat_id}")
206
- return
207
-
208
- # Call the chat_history_name endpoint to generate a title
209
- try:
210
- # This would typically be an internal API call
211
- # For demonstration, we're showing how it would be structured
212
- # In a real implementation, you might want to directly call the function
213
- # that generates the title rather than making an HTTP request
214
- response = requests.post(
215
- "http://localhost:8000/chat_history_name",
216
- json={"query": user_query.content},
217
- timeout=5
218
- )
219
-
220
- if response.status_code == 200:
221
- title_data = response.json()
222
- new_title = title_data.get("name", "Chat")
223
-
224
- # Update chat title
225
- chat = session.query(Chat).filter(Chat.chat_id == chat_id).first()
226
- if chat:
227
- chat.title = new_title
228
- session.commit()
229
- # logger.info(f"Updated chat {chat_id} title to '{new_title}'")
230
- else:
231
- logger.warning(f"Failed to generate title: {response.status_code}")
232
- except Exception as e:
233
- logger.log_message(f"Error calling chat_history_name endpoint: {str(e)}", level=logging.ERROR)
234
- # Continue execution even if title generation fails
235
- except SQLAlchemyError as e:
236
- session.rollback()
237
- logger.log_message(f"Error updating chat title: {str(e)}", level=logging.ERROR)
238
- finally:
239
- session.close()
240
-
241
- def get_chat(self, chat_id: int, user_id: Optional[int] = None) -> Dict[str, Any]:
242
- """
243
- Get a chat by ID with all its messages.
244
-
245
- Args:
246
- chat_id: ID of the chat to retrieve
247
- user_id: Optional user ID to verify ownership
248
-
249
- Returns:
250
- Dictionary containing chat information and messages
251
- """
252
- session = self.Session()
253
- try:
254
- # Get the chat
255
- query = session.query(Chat).filter(Chat.chat_id == chat_id)
256
-
257
- # If user_id is provided, ensure the chat belongs to this user
258
- if user_id is not None:
259
- query = query.filter(Chat.user_id == user_id)
260
-
261
- chat = query.first()
262
- if not chat:
263
- raise ValueError(f"Chat with ID {chat_id} not found or access denied")
264
-
265
- # Get the chat messages ordered by timestamp
266
- messages = session.query(Message).filter(
267
- Message.chat_id == chat_id
268
- ).order_by(Message.timestamp).all()
269
-
270
- return {
271
- "chat_id": chat.chat_id,
272
- "title": chat.title,
273
- "created_at": chat.created_at.isoformat(),
274
- "user_id": chat.user_id,
275
- "messages": [
276
- {
277
- "message_id": msg.message_id,
278
- "chat_id": msg.chat_id,
279
- "content": msg.content,
280
- "sender": msg.sender,
281
- "timestamp": msg.timestamp.isoformat()
282
- } for msg in messages
283
- ]
284
- }
285
- except SQLAlchemyError as e:
286
- logger.log_message(f"Error retrieving chat: {str(e)}", level=logging.ERROR)
287
- raise
288
- finally:
289
- session.close()
290
-
291
- def get_user_chats(self, user_id: Optional[int] = None, limit: int = 10, offset: int = 0) -> List[Dict[str, Any]]:
292
- """
293
- Get recent chats for a user, or all chats if no user_id is provided.
294
-
295
- Args:
296
- user_id: Optional user ID to filter chats
297
- limit: Maximum number of chats to return
298
- offset: Number of chats to skip (for pagination)
299
-
300
- Returns:
301
- List of dictionaries containing chat information
302
- """
303
- session = self.Session()
304
- try:
305
- query = session.query(Chat)
306
-
307
- # Filter by user_id if provided
308
- if user_id is not None:
309
- query = query.filter(Chat.user_id == user_id)
310
-
311
- chats = query.order_by(Chat.created_at.desc()).limit(limit).offset(offset).all()
312
-
313
- return [
314
- {
315
- "chat_id": chat.chat_id,
316
- "user_id": chat.user_id,
317
- "title": chat.title,
318
- "created_at": chat.created_at.isoformat()
319
- } for chat in chats
320
- ]
321
- except SQLAlchemyError as e:
322
- logger.log_message(f"Error retrieving chats: {str(e)}", level=logging.ERROR)
323
- return []
324
- finally:
325
- session.close()
326
-
327
- def _get_last_message(self, chat_id: int) -> Optional[Dict[str, Any]]:
328
- """
329
- Get the last message from a chat for preview purposes.
330
-
331
- Args:
332
- chat_id: ID of the chat
333
-
334
- Returns:
335
- Dictionary containing last message information or None
336
- """
337
- session = self.Session()
338
- try:
339
- last_message = session.query(Message).filter(
340
- Message.chat_id == chat_id
341
- ).order_by(desc(Message.timestamp)).first()
342
-
343
- if last_message:
344
- return {
345
- "content": last_message.content[:100] + "..." if len(last_message.content) > 100 else last_message.content,
346
- "sender": last_message.sender,
347
- "timestamp": last_message.timestamp.isoformat()
348
- }
349
- return None
350
- except SQLAlchemyError as e:
351
- logger.log_message(f"Error retrieving last message: {str(e)}", level=logging.ERROR)
352
- return None
353
- finally:
354
- session.close()
355
-
356
- def delete_chat(self, chat_id: int, user_id: Optional[int] = None) -> bool:
357
- """
358
- Delete a chat and all its messages.
359
-
360
- Args:
361
- chat_id: ID of the chat to delete
362
- user_id: Optional user ID to verify ownership
363
-
364
- Returns:
365
- True if deletion was successful, False otherwise
366
- """
367
- session = self.Session()
368
- try:
369
- # Check if chat exists and belongs to the user if user_id is provided
370
- if user_id is not None:
371
- chat = session.query(Chat).filter(
372
- Chat.chat_id == chat_id,
373
- Chat.user_id == user_id
374
- ).first()
375
- if not chat:
376
- return False # Chat not found or doesn't belong to the user
377
-
378
- # Delete all messages in the chat
379
- session.query(Message).filter(Message.chat_id == chat_id).delete()
380
-
381
- # Delete the chat (with user_id filter if provided)
382
- query = session.query(Chat).filter(Chat.chat_id == chat_id)
383
- if user_id is not None:
384
- query = query.filter(Chat.user_id == user_id)
385
-
386
- result = query.delete()
387
- session.commit()
388
-
389
- return result > 0
390
- except SQLAlchemyError as e:
391
- session.rollback()
392
- logger.log_message(f"Error deleting chat: {str(e)}", level=logging.ERROR)
393
- return False
394
- finally:
395
- session.close()
396
-
397
- def search_chats(self, query: str, user_id: Optional[int] = None, limit: int = 10) -> List[Dict[str, Any]]:
398
- """
399
- Search for chats containing the query string.
400
-
401
- Args:
402
- query: Search query string
403
- user_id: Optional user ID to filter chats
404
- limit: Maximum number of results to return
405
-
406
- Returns:
407
- List of dictionaries containing matching chat information
408
- """
409
- session = self.Session()
410
- try:
411
- # Build base query to find messages containing the search term
412
- message_query = session.query(Message.chat_id).filter(
413
- Message.content.ilike(f"%{query}%")
414
- ).distinct()
415
-
416
- # Apply user filter if provided
417
- chat_query = session.query(Chat).filter(Chat.chat_id.in_(message_query))
418
- if user_id is not None:
419
- chat_query = chat_query.filter(Chat.user_id == user_id)
420
-
421
- # Get matching chats
422
- chats = chat_query.order_by(desc(Chat.created_at)).limit(limit).all()
423
-
424
- # Format response
425
- return [
426
- {
427
- "chat_id": chat.chat_id,
428
- "title": chat.title,
429
- "created_at": chat.created_at.isoformat(),
430
- "user_id": chat.user_id,
431
- "matching_messages": self._get_matching_messages(chat.chat_id, query)
432
- } for chat in chats
433
- ]
434
- except SQLAlchemyError as e:
435
- logger.log_message(f"Error searching chats: {str(e)}", level=logging.ERROR)
436
- raise
437
- finally:
438
- session.close()
439
-
440
- def _get_matching_messages(self, chat_id: int, query: str, limit: int = 3) -> List[Dict[str, Any]]:
441
- """
442
- Get messages from a chat that match the search query.
443
-
444
- Args:
445
- chat_id: ID of the chat
446
- query: Search query string
447
- limit: Maximum number of matching messages to return
448
-
449
- Returns:
450
- List of dictionaries containing matching message information
451
- """
452
- session = self.Session()
453
- try:
454
- matching_messages = session.query(Message).filter(
455
- Message.chat_id == chat_id,
456
- Message.content.ilike(f"%{query}%")
457
- ).order_by(Message.timestamp).limit(limit).all()
458
-
459
- return [
460
- {
461
- "message_id": msg.message_id,
462
- "content": msg.content,
463
- "sender": msg.sender,
464
- "timestamp": msg.timestamp.isoformat()
465
- } for msg in matching_messages
466
- ]
467
- except SQLAlchemyError as e:
468
- logger.log_message(f"Error retrieving matching messages: {str(e)}", level=logging.ERROR)
469
- return []
470
- finally:
471
- session.close()
472
-
473
- def get_or_create_user(self, username: str, email: str) -> Dict[str, Any]:
474
- """
475
- Get an existing user by email or create a new one if not found.
476
-
477
- Args:
478
- username: User's display name
479
- email: User's email address
480
-
481
- Returns:
482
- Dictionary containing user information
483
- """
484
- session = self.Session()
485
- try:
486
- # Try to find existing user by email
487
- user = session.query(User).filter(User.email == email).first()
488
-
489
- if not user:
490
- # Create new user if not found
491
- user = User(username=username, email=email)
492
- session.add(user)
493
- session.commit()
494
- logger.log_message(f"Created new user: {username} ({email})", level=logging.INFO)
495
-
496
- return {
497
- "user_id": user.user_id,
498
- "username": user.username,
499
- "email": user.email,
500
- "created_at": user.created_at.isoformat()
501
- }
502
- except SQLAlchemyError as e:
503
- session.rollback()
504
- logger.log_message(f"Error getting/creating user: {str(e)}", level=logging.ERROR)
505
- raise
506
- finally:
507
- session.close()
508
-
509
- def update_chat(self, chat_id: int, title: Optional[str] = None, user_id: Optional[int] = None) -> Dict[str, Any]:
510
- """
511
- Update a chat's title or user_id.
512
-
513
- Args:
514
- chat_id: ID of the chat to update
515
- title: New title for the chat (optional)
516
- user_id: New user ID for the chat (optional)
517
-
518
- Returns:
519
- Dictionary containing updated chat information
520
- """
521
- session = self.Session()
522
- try:
523
- # Get the chat
524
- chat = session.query(Chat).filter(Chat.chat_id == chat_id).first()
525
- if not chat:
526
- raise ValueError(f"Chat with ID {chat_id} not found")
527
-
528
- # Update fields if provided
529
- if title is not None:
530
- chat.title = title
531
- if user_id is not None:
532
- chat.user_id = user_id
533
-
534
- session.commit()
535
-
536
- return {
537
- "chat_id": chat.chat_id,
538
- "title": chat.title,
539
- "created_at": chat.created_at.isoformat(),
540
- "user_id": chat.user_id
541
- }
542
- except SQLAlchemyError as e:
543
- session.rollback()
544
- logger.log_message(f"Error updating chat: {str(e)}", level=logging.ERROR)
545
- raise
546
- finally:
547
- session.close()
548
-
549
- def generate_title_from_query(self, query: str) -> str:
550
- """
551
- Generate a title for a chat based on the first query.
552
-
553
- Args:
554
- query: The user's first query in the chat
555
-
556
- Returns:
557
- A generated title string
558
- """
559
- try:
560
- # Simple title generation - take first few words
561
- words = query.split()
562
- if len(words) > 3:
563
- title = "Chat about " + " ".join(words[0:3]) + "..."
564
- else:
565
- title = "Chat about " + query
566
-
567
- # Limit title length
568
- if len(title) > 40:
569
- title = title[:37] + "..."
570
-
571
- return title
572
- except Exception as e:
573
- logger.log_message(f"Error generating title: {str(e)}", level=logging.ERROR)
574
- return "New Chat"
575
-
576
- def delete_empty_chats(self, user_id: Optional[int] = None, is_admin: bool = False) -> int:
577
- """
578
- Delete empty chats (chats with no messages) for a user.
579
-
580
- Args:
581
- user_id: ID of the user whose empty chats should be deleted
582
- is_admin: Whether this is an admin user
583
-
584
- Returns:
585
- Number of chats deleted
586
- """
587
- session = self.Session()
588
- try:
589
- # Get all chats for the user
590
- query = session.query(Chat)
591
- if user_id is not None:
592
- query = query.filter(Chat.user_id == user_id)
593
- elif not is_admin:
594
- return 0 # Don't delete anything if not a user or admin
595
-
596
- # For each chat, check if it has any messages
597
- chats_to_delete = []
598
- for chat in query.all():
599
- message_count = session.query(Message).filter(
600
- Message.chat_id == chat.chat_id
601
- ).count()
602
-
603
- if message_count == 0:
604
- chats_to_delete.append(chat.chat_id)
605
-
606
- # Delete the empty chats
607
- if chats_to_delete:
608
- deleted = session.query(Chat).filter(
609
- Chat.chat_id.in_(chats_to_delete)
610
- ).delete(synchronize_session=False)
611
-
612
- session.commit()
613
- return deleted
614
- return 0
615
- except SQLAlchemyError as e:
616
- session.rollback()
617
- logger.log_message(f"Error deleting empty chats: {str(e)}", level=logging.ERROR)
618
- return 0
619
- finally:
620
- session.close()
621
-
622
- def save_ai_response(self, chat_id: int, content: str, user_id: Optional[int] = None,
623
- model_name: str = "gpt-4o-mini", prompt: str = "",
624
- prompt_tokens: Optional[int] = None,
625
- completion_tokens: Optional[int] = None,
626
- start_time: Optional[float] = None):
627
- """
628
- Save an AI response to a chat and track model usage.
629
-
630
- Args:
631
- chat_id: ID of the chat to add the message to
632
- content: AI response content
633
- user_id: Optional user ID for tracking
634
- model_name: Model used to generate the response
635
- prompt: The prompt sent to the model
636
- prompt_tokens: Optional pre-counted prompt tokens
637
- completion_tokens: Optional pre-counted completion tokens
638
- start_time: Optional start time of the request
639
- """
640
- session = self.Session()
641
- try:
642
- # Create and save message
643
- new_message = Message(
644
- chat_id=chat_id,
645
- sender='ai',
646
- content=content,
647
- timestamp=datetime.utcnow()
648
- )
649
- session.add(new_message)
650
- session.commit()
651
-
652
- # Track model usage
653
- end_time = time.time()
654
- start_time = start_time or end_time - 1 # Default to 1 second if not provided
655
-
656
- self.track_model_usage(
657
- user_id=user_id,
658
- chat_id=chat_id,
659
- model_name=model_name,
660
- prompt=prompt,
661
- response=content,
662
- start_time=start_time,
663
- prompt_tokens=prompt_tokens,
664
- completion_tokens=completion_tokens
665
- )
666
-
667
- except SQLAlchemyError as e:
668
- logger.log_message(f"Error saving AI response: {str(e)}", level=logging.ERROR)
669
- session.rollback()
670
- finally:
671
- session.close()
672
-
673
- def track_model_usage(self, user_id: Optional[int], chat_id: int, model_name: str,
674
- prompt: str, response: str, start_time: float,
675
- is_streaming: bool = False,
676
- prompt_tokens: Optional[int] = None,
677
- completion_tokens: Optional[int] = None) -> Dict[str, Any]:
678
- """
679
- Track AI model usage for analytics and billing.
680
-
681
- Args:
682
- user_id: Optional user ID making the request
683
- chat_id: Chat ID associated with the request
684
- model_name: Name of the AI model used
685
- prompt: The prompt text sent to the model
686
- response: The response text from the model
687
- start_time: Start time of the request (time.time() value)
688
- is_streaming: Whether the response was streamed
689
- prompt_tokens: Optional pre-counted prompt tokens
690
- completion_tokens: Optional pre-counted completion tokens
691
-
692
- Returns:
693
- Dictionary with usage information
694
- """
695
- session = self.Session()
696
- try:
697
- # Determine model provider
698
- provider = "unknown"
699
- for prefix, prov in self.model_providers.items():
700
- if model_name.startswith(prefix):
701
- provider = prov
702
- break
703
-
704
- # Calculate tokens if not provided
705
- if prompt_tokens is None or completion_tokens is None:
706
- try:
707
- encoding = tiktoken.encoding_for_model(model_name) if provider == "openai" else tiktoken.get_encoding("cl100k_base")
708
- prompt_tokens = len(encoding.encode(prompt)) if prompt_tokens is None else prompt_tokens
709
- completion_tokens = len(encoding.encode(response)) if completion_tokens is None else completion_tokens
710
- except Exception as e:
711
- logger.log_message(f"Error calculating tokens: {str(e)}", level=logging.ERROR)
712
- # Fallback to character-based estimation
713
- prompt_tokens = len(prompt) // 4
714
- completion_tokens = len(response) // 4
715
-
716
- total_tokens = prompt_tokens + completion_tokens
717
-
718
- # Calculate cost
719
- cost = 0.0
720
- if model_name in self.model_costs:
721
- cost = (prompt_tokens * self.model_costs[model_name]["input"] / 1000000) + \
722
- (completion_tokens * self.model_costs[model_name]["output"] / 1000000)
723
-
724
- # Calculate request time
725
- request_time_ms = int((time.time() - start_time) * 1000)
726
-
727
- # Create usage record
728
- usage = ModelUsage(
729
- user_id=user_id,
730
- chat_id=chat_id,
731
- model_name=model_name,
732
- provider=provider,
733
- prompt_tokens=prompt_tokens,
734
- completion_tokens=completion_tokens,
735
- total_tokens=total_tokens,
736
- query_size=len(prompt),
737
- response_size=len(response),
738
- cost=cost,
739
- timestamp=datetime.utcnow(),
740
- is_streaming=is_streaming,
741
- request_time_ms=request_time_ms
742
- )
743
-
744
- session.add(usage)
745
- session.commit()
746
-
747
- return {
748
- "usage_id": usage.usage_id,
749
- "model_name": model_name,
750
- "provider": provider,
751
- "prompt_tokens": prompt_tokens,
752
- "completion_tokens": completion_tokens,
753
- "total_tokens": total_tokens,
754
- "cost": cost,
755
- "request_time_ms": request_time_ms
756
- }
757
-
758
- except SQLAlchemyError as e:
759
- session.rollback()
760
- logger.log_message(f"Error tracking model usage: {str(e)}", level=logging.ERROR)
761
- return {}
762
- finally:
763
- session.close()
764
-
765
- def get_model_usage_analytics(self, start_date: Optional[datetime] = None,
766
- end_date: Optional[datetime] = None,
767
- user_id: Optional[int] = None,
768
- model_name: Optional[str] = None,
769
- provider: Optional[str] = None,
770
- limit: int = 1000) -> List[Dict[str, Any]]:
771
- """
772
- Get model usage analytics with optional filtering.
773
-
774
- Args:
775
- start_date: Optional start date for the analytics period
776
- end_date: Optional end date for the analytics period
777
- user_id: Optional user ID to filter by
778
- model_name: Optional model name to filter by
779
- provider: Optional provider to filter by
780
- limit: Maximum number of records to return
781
-
782
- Returns:
783
- List of dictionaries containing usage records
784
- """
785
- session = self.Session()
786
- try:
787
- query = session.query(ModelUsage)
788
-
789
- # Apply filters
790
- if start_date:
791
- query = query.filter(ModelUsage.timestamp >= start_date)
792
- if end_date:
793
- query = query.filter(ModelUsage.timestamp <= end_date)
794
- if user_id:
795
- query = query.filter(ModelUsage.user_id == user_id)
796
- if model_name:
797
- query = query.filter(ModelUsage.model_name == model_name)
798
- if provider:
799
- query = query.filter(ModelUsage.provider == provider)
800
-
801
- # Order by timestamp descending
802
- query = query.order_by(ModelUsage.timestamp.desc()).limit(limit)
803
-
804
- usages = query.all()
805
-
806
- return [{
807
- "usage_id": usage.usage_id,
808
- "user_id": usage.user_id,
809
- "chat_id": usage.chat_id,
810
- "model_name": usage.model_name,
811
- "provider": usage.provider,
812
- "prompt_tokens": usage.prompt_tokens,
813
- "completion_tokens": usage.completion_tokens,
814
- "total_tokens": usage.total_tokens,
815
- "query_size": usage.query_size,
816
- "response_size": usage.response_size,
817
- "cost": usage.cost,
818
- "timestamp": usage.timestamp.isoformat(),
819
- "is_streaming": usage.is_streaming,
820
- "request_time_ms": usage.request_time_ms
821
- } for usage in usages]
822
-
823
- except SQLAlchemyError as e:
824
- logger.log_message(f"Error retrieving model usage analytics: {str(e)}", level=logging.ERROR)
825
- return []
826
- finally:
827
- session.close()
828
-
829
- def get_usage_summary(self, start_date: Optional[datetime] = None,
830
- end_date: Optional[datetime] = None) -> Dict[str, Any]:
831
- """
832
- Get a summary of model usage including total costs, tokens, and usage by model.
833
-
834
- Args:
835
- start_date: Optional start date for the summary period
836
- end_date: Optional end date for the summary period
837
-
838
- Returns:
839
- Dictionary containing usage summary
840
- """
841
- session = self.Session()
842
- try:
843
- query = session.query(
844
- func.sum(ModelUsage.cost).label("total_cost"),
845
- func.sum(ModelUsage.prompt_tokens).label("total_prompt_tokens"),
846
- func.sum(ModelUsage.completion_tokens).label("total_completion_tokens"),
847
- func.sum(ModelUsage.total_tokens).label("total_tokens"),
848
- func.count(ModelUsage.usage_id).label("request_count"),
849
- func.avg(ModelUsage.request_time_ms).label("avg_request_time")
850
- )
851
-
852
- # Apply date filters
853
- if start_date:
854
- query = query.filter(ModelUsage.timestamp >= start_date)
855
- if end_date:
856
- query = query.filter(ModelUsage.timestamp <= end_date)
857
-
858
- result = query.first()
859
-
860
- # Get usage breakdown by model
861
- model_query = session.query(
862
- ModelUsage.model_name,
863
- func.sum(ModelUsage.cost).label("model_cost"),
864
- func.sum(ModelUsage.total_tokens).label("model_tokens"),
865
- func.count(ModelUsage.usage_id).label("model_requests")
866
- )
867
-
868
- if start_date:
869
- model_query = model_query.filter(ModelUsage.timestamp >= start_date)
870
- if end_date:
871
- model_query = model_query.filter(ModelUsage.timestamp <= end_date)
872
-
873
- model_query = model_query.group_by(ModelUsage.model_name)
874
- model_breakdown = model_query.all()
875
-
876
- # Get usage breakdown by provider
877
- provider_query = session.query(
878
- ModelUsage.provider,
879
- func.sum(ModelUsage.cost).label("provider_cost"),
880
- func.sum(ModelUsage.total_tokens).label("provider_tokens"),
881
- func.count(ModelUsage.usage_id).label("provider_requests")
882
- )
883
-
884
- if start_date:
885
- provider_query = provider_query.filter(ModelUsage.timestamp >= start_date)
886
- if end_date:
887
- provider_query = provider_query.filter(ModelUsage.timestamp <= end_date)
888
-
889
- provider_query = provider_query.group_by(ModelUsage.provider)
890
- provider_breakdown = provider_query.all()
891
-
892
- # Get top users by cost
893
- user_query = session.query(
894
- ModelUsage.user_id,
895
- func.sum(ModelUsage.cost).label("user_cost"),
896
- func.sum(ModelUsage.total_tokens).label("user_tokens"),
897
- func.count(ModelUsage.usage_id).label("user_requests")
898
- )
899
-
900
- if start_date:
901
- user_query = user_query.filter(ModelUsage.timestamp >= start_date)
902
- if end_date:
903
- user_query = user_query.filter(ModelUsage.timestamp <= end_date)
904
-
905
- user_query = user_query.group_by(ModelUsage.user_id)
906
- user_query = user_query.order_by(func.sum(ModelUsage.cost).desc())
907
- user_query = user_query.limit(10)
908
- user_breakdown = user_query.all()
909
-
910
- return {
911
- "summary": {
912
- "total_cost": float(result.total_cost) if result.total_cost else 0.0,
913
- "total_prompt_tokens": int(result.total_prompt_tokens) if result.total_prompt_tokens else 0,
914
- "total_completion_tokens": int(result.total_completion_tokens) if result.total_completion_tokens else 0,
915
- "total_tokens": int(result.total_tokens) if result.total_tokens else 0,
916
- "request_count": int(result.request_count) if result.request_count else 0,
917
- "avg_request_time_ms": float(result.avg_request_time) if result.avg_request_time else 0.0
918
- },
919
- "model_breakdown": [
920
- {
921
- "model_name": model.model_name,
922
- "cost": float(model.model_cost) if model.model_cost else 0.0,
923
- "tokens": int(model.model_tokens) if model.model_tokens else 0,
924
- "requests": int(model.model_requests) if model.model_requests else 0
925
- } for model in model_breakdown
926
- ],
927
- "provider_breakdown": [
928
- {
929
- "provider": provider.provider,
930
- "cost": float(provider.provider_cost) if provider.provider_cost else 0.0,
931
- "tokens": int(provider.provider_tokens) if provider.provider_tokens else 0,
932
- "requests": int(provider.provider_requests) if provider.provider_requests else 0
933
- } for provider in provider_breakdown
934
- ],
935
- "top_users": [
936
- {
937
- "user_id": user.user_id,
938
- "cost": float(user.user_cost) if user.user_cost else 0.0,
939
- "tokens": int(user.user_tokens) if user.user_tokens else 0,
940
- "requests": int(user.user_requests) if user.user_requests else 0
941
- } for user in user_breakdown
942
- ]
943
- }
944
-
945
- except SQLAlchemyError as e:
946
- logger.log_message(f"Error retrieving usage summary: {str(e)}", level=logging.ERROR)
947
- return {
948
- "summary": {
949
- "total_cost": 0.0,
950
- "total_tokens": 0,
951
- "request_count": 0
952
- },
953
- "model_breakdown": [],
954
- "provider_breakdown": [],
955
- "top_users": []
956
- }
957
- finally:
958
- session.close()
959
-
960
- def get_recent_chat_history(self, chat_id: int, limit: int = 5) -> List[Dict[str, Any]]:
961
- """
962
- Get recent message history for a chat, limited to the last 'limit' messages.
963
-
964
- Args:
965
- chat_id: ID of the chat to get history for
966
- limit: Maximum number of recent messages to return
967
-
968
- Returns:
969
- List of dictionaries containing message information
970
- """
971
- session = self.Session()
972
- try:
973
- messages = session.query(Message).filter(
974
- Message.chat_id == chat_id
975
- ).order_by(Message.timestamp.desc()).limit(limit * 2).all() # Fetch more to get message pairs
976
-
977
- # Reverse to get chronological order
978
- messages.reverse()
979
-
980
- return [
981
- {
982
- "message_id": msg.message_id,
983
- "chat_id": msg.chat_id,
984
- "content": msg.content,
985
- "sender": msg.sender,
986
- "timestamp": msg.timestamp.isoformat()
987
- } for msg in messages
988
- ]
989
- except SQLAlchemyError as e:
990
- logger.log_message(f"Error retrieving chat history: {str(e)}", level=logging.ERROR)
991
- return []
992
- finally:
993
- session.close()
994
-
995
-
996
- def extract_response_history(self, messages: List[Dict[str, Any]]) -> str:
997
- """
998
- Extract response history from message history.
999
-
1000
- Args:
1001
- messages: List of message dictionaries
1002
-
1003
- Returns:
1004
- String containing combined response history in a structured format
1005
- """
1006
-
1007
- summaries = []
1008
- user_messages = []
1009
- for msg in messages:
1010
- # Get User Messages
1011
- if msg.get("sender") == "user":
1012
- user_messages.append(msg)
1013
- # Ensure content exists and is from AI before extracting summary
1014
- if msg.get("sender") == "ai" and "content" in msg:
1015
- content = msg["content"]
1016
- matches = re.findall(r"### Summary\n(.*?)(?=\n\n##|\Z)", content, re.DOTALL)
1017
- summaries.extend(match.strip() for match in matches)
1018
-
1019
- # Combine user messages with summaries in a structured format
1020
- combined_conversations = []
1021
- for user_msg, summary in zip(user_messages, summaries):
1022
- combined_conversations.append(f"Query: {user_msg['content']}\nSummary: {summary}")
1023
-
1024
- # Return the last 3 conversations to maintain context
1025
- formatted_context = "\n\n".join(combined_conversations[-3:])
1026
-
1027
- # Add a clear header to indicate this is past interaction history
1028
- if formatted_context:
1029
- return f"### Previous Interaction History:\n{formatted_context}"
1030
- return ""
 
1
+
2
+ from sqlalchemy import create_engine, desc, func, exists
3
+ from sqlalchemy.orm import sessionmaker, scoped_session
4
+ from sqlalchemy.exc import SQLAlchemyError
5
+ from src.db.schemas.models import Base, User, Chat, Message, ModelUsage
6
+ import logging
7
+ import requests
8
+ import json
9
+ from typing import List, Dict, Optional, Tuple, Any
10
+ from datetime import datetime
11
+ import time
12
+ import tiktoken
13
+ from src.utils.logger import Logger
14
+ import re
15
+
16
+ logger = Logger("chat_manager", see_time=True, console_log=False)
17
+
18
+
19
+ class ChatManager:
20
+ """
21
+ Manages chat operations including creating, storing, retrieving, and updating chats and messages.
22
+ Provides an interface between the application and the database for chat-related operations.
23
+ """
24
+
25
+ def __init__(self, db_url):
26
+ """
27
+ Initialize the ChatManager with a database connection.
28
+
29
+ Args:
30
+ db_url: Database connection URL (defaults to SQLite)
31
+ """
32
+ self.engine = create_engine(db_url)
33
+ Base.metadata.create_all(self.engine) # Ensure tables exist
34
+ self.Session = scoped_session(sessionmaker(bind=self.engine))
35
+
36
+ # Add price mappings for different models
37
+ self.model_costs = {
38
+ "openai": {
39
+ "gpt-4": {"input": 0.03, "output": 0.06},
40
+ "gpt-4o": {"input": 0.0025, "output": 0.01},
41
+ "gpt-4.5-preview": {"input": 0.075, "output": 0.15},
42
+ "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
43
+ "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
44
+ "o1": {"input": 0.015, "output": 0.06},
45
+ "o1-mini": {"input": 0.00011, "output": 0.00044},
46
+ "o3-mini": {"input": 0.00011, "output": 0.00044}
47
+ },
48
+ "anthropic": {
49
+ "claude-3-opus-latest": {"input": 0.015, "output": 0.075},
50
+ "claude-3-7-sonnet-latest": {"input": 0.003, "output": 0.015},
51
+ "claude-3-5-sonnet-latest": {"input": 0.003, "output": 0.015},
52
+ "claude-3-5-haiku-latest": {"input": 0.0008, "output": 0.0004},
53
+ },
54
+ "groq": {
55
+ "deepseek-r1-distill-llama-70b": {"input": 0.00075, "output": 0.00099},
56
+ "llama-3.3-70b-versatile": {"input": 0.00059, "output": 0.00079},
57
+ "llama3-8b-8192": {"input": 0.00005, "output": 0.00008},
58
+ "llama3-70b-8192": {"input": 0.00059, "output": 0.00079},
59
+ "llama-3.1-8b-instant": {"input": 0.00005, "output": 0.00008},
60
+ "mistral-saba-24b": {"input": 0.00079, "output": 0.00079},
61
+ "gemma2-9b-it": {"input": 0.0002, "output": 0.0002},
62
+ "qwen-qwq-32b": {"input": 0.00029, "output": 0.00039},
63
+ "meta-llama/llama-4-maverick-17b-128e-instruct": {"input": 0.0002, "output": 0.0006},
64
+ "meta-llama/llama-4-scout-17b-16e-instruct": {"input": 0.00011, "output": 0.00034},
65
+ },
66
+ "gemini": {
67
+ "gemini-2.5-pro-preview-03-25": {"input": 0.00015, "output": 0.001}
68
+ }
69
+ }
70
+
71
+
72
+ # Add model providers mapping
73
+ self.model_providers = {
74
+ "gpt-": "openai",
75
+ "claude-": "anthropic",
76
+ "llama-": "groq",
77
+ "mistral-": "groq",
78
+ }
79
+
80
+ def create_chat(self, user_id: Optional[int] = None) -> Dict[str, Any]:
81
+ """
82
+ Create a new chat session.
83
+
84
+ Args:
85
+ user_id: Optional user ID if authentication is enabled
86
+
87
+ Returns:
88
+ Dictionary containing chat information
89
+ """
90
+ session = self.Session()
91
+ try:
92
+ # Create a new chat
93
+ chat = Chat(
94
+ user_id=user_id,
95
+ title='New Chat',
96
+ created_at=datetime.utcnow()
97
+ )
98
+ session.add(chat)
99
+ session.flush() # Flush to get the ID before commit
100
+
101
+ chat_id = chat.chat_id # Get the ID now
102
+ session.commit()
103
+
104
+ logger.log_message(f"Created new chat {chat_id} for user {user_id}", level=logging.INFO)
105
+
106
+ return {
107
+ "chat_id": chat_id,
108
+ "user_id": chat.user_id,
109
+ "title": chat.title,
110
+ "created_at": chat.created_at.isoformat()
111
+ }
112
+ except SQLAlchemyError as e:
113
+ session.rollback()
114
+ logger.log_message(f"Error creating chat: {str(e)}", level=logging.ERROR)
115
+ raise
116
+ finally:
117
+ session.close()
118
+
119
+ def add_message(self, chat_id: int, content: str, sender: str, user_id: Optional[int] = None) -> Dict[str, Any]:
120
+ """
121
+ Add a message to a chat.
122
+
123
+ Args:
124
+ chat_id: ID of the chat to add the message to
125
+ content: Message content
126
+ sender: Message sender ('user' or 'ai')
127
+ user_id: Optional user ID to verify ownership
128
+
129
+ Returns:
130
+ Dictionary containing message information
131
+ """
132
+ session = self.Session()
133
+ try:
134
+ # Check if chat exists and belongs to the user if user_id is provided
135
+ query = session.query(Chat).filter(Chat.chat_id == chat_id)
136
+ if user_id is not None:
137
+ query = query.filter((Chat.user_id == user_id) | (Chat.user_id.is_(None)))
138
+
139
+ chat = query.first()
140
+ if not chat:
141
+ raise ValueError(f"Chat with ID {chat_id} not found or access denied")
142
+
143
+ ##! Ensure content length is reasonable for PostgreSQL
144
+ # max_content_length = 10000 # PostgreSQL can handle large text but let's be cautious
145
+ # if content and len(content) > max_content_length:
146
+ # logger.log_message(f"Truncating message content from {len(content)} to {max_content_length} characters",
147
+ # level=logging.WARNING)
148
+ # content = content[:max_content_length]
149
+
150
+ # Create a new message
151
+ message = Message(
152
+ chat_id=chat_id,
153
+ content=content,
154
+ sender=sender,
155
+ timestamp=datetime.utcnow()
156
+ )
157
+ session.add(message)
158
+ session.flush() # Flush to get the ID before commit
159
+
160
+ message_id = message.message_id # Get ID now
161
+
162
+ # If this is the first AI response and chat title is still default,
163
+ # update the chat title based on the first user query
164
+ if sender == 'ai':
165
+ first_ai_message = session.query(Message).filter(
166
+ Message.chat_id == chat_id,
167
+ Message.sender == 'ai'
168
+ ).first()
169
+
170
+ if not first_ai_message and chat.title == 'New Chat':
171
+ # Get the user's first message
172
+ first_user_message = session.query(Message).filter(
173
+ Message.chat_id == chat_id,
174
+ Message.sender == 'user'
175
+ ).order_by(Message.timestamp).first()
176
+
177
+ if first_user_message:
178
+ # Generate title from user query
179
+ new_title = self.generate_title_from_query(first_user_message.content)
180
+ chat.title = new_title
181
+
182
+ session.commit()
183
+
184
+ return {
185
+ "message_id": message_id,
186
+ "chat_id": message.chat_id,
187
+ "content": message.content,
188
+ "sender": message.sender,
189
+ "timestamp": message.timestamp.isoformat()
190
+ }
191
+ except SQLAlchemyError as e:
192
+ session.rollback()
193
+ logger.log_message(f"Error adding message: {str(e)}", level=logging.ERROR)
194
+ raise
195
+ finally:
196
+ session.close()
197
+
198
+
199
+ def get_chat(self, chat_id: int, user_id: Optional[int] = None) -> Dict[str, Any]:
200
+ """
201
+ Get a chat by ID with all its messages.
202
+
203
+ Args:
204
+ chat_id: ID of the chat to retrieve
205
+ user_id: Optional user ID to verify ownership
206
+
207
+ Returns:
208
+ Dictionary containing chat information and messages
209
+ """
210
+ session = self.Session()
211
+ try:
212
+ # Get the chat
213
+ query = session.query(Chat).filter(Chat.chat_id == chat_id)
214
+
215
+ # If user_id is provided, ensure the chat belongs to this user
216
+ if user_id is not None:
217
+ query = query.filter((Chat.user_id == user_id) | (Chat.user_id.is_(None)))
218
+
219
+ chat = query.first()
220
+ if not chat:
221
+ raise ValueError(f"Chat with ID {chat_id} not found or access denied")
222
+
223
+ # Get the chat messages ordered by timestamp
224
+ messages = session.query(Message).filter(
225
+ Message.chat_id == chat_id
226
+ ).order_by(Message.timestamp).all()
227
+
228
+ # Create a safe serializable dictionary
229
+ return {
230
+ "chat_id": chat.chat_id,
231
+ "title": chat.title,
232
+ "created_at": chat.created_at.isoformat() if chat.created_at else None,
233
+ "user_id": chat.user_id,
234
+ "messages": [
235
+ {
236
+ "message_id": msg.message_id,
237
+ "chat_id": msg.chat_id,
238
+ "content": msg.content,
239
+ "sender": msg.sender,
240
+ "timestamp": msg.timestamp.isoformat() if msg.timestamp else None
241
+ } for msg in messages
242
+ ]
243
+ }
244
+ except SQLAlchemyError as e:
245
+ logger.log_message(f"Error retrieving chat: {str(e)}", level=logging.ERROR)
246
+ raise
247
+ finally:
248
+ session.close()
249
+
250
+ def get_user_chats(self, user_id: Optional[int] = None, limit: int = 10, offset: int = 0) -> List[Dict[str, Any]]:
251
+ """
252
+ Get recent chats for a user, or all chats if no user_id is provided.
253
+
254
+ Args:
255
+ user_id: Optional user ID to filter chats
256
+ limit: Maximum number of chats to return
257
+ offset: Number of chats to skip (for pagination)
258
+
259
+ Returns:
260
+ List of dictionaries containing chat information
261
+ """
262
+ session = self.Session()
263
+ try:
264
+ query = session.query(Chat)
265
+
266
+ # Filter by user_id if provided
267
+ if user_id is not None:
268
+ query = query.filter(Chat.user_id == user_id)
269
+
270
+ # Apply safe limits for both SQLite and PostgreSQL
271
+ safe_limit = min(max(1, limit), 100) # Between 1 and 100
272
+ safe_offset = max(0, offset) # At least 0
273
+
274
+ chats = query.order_by(Chat.created_at.desc()).limit(safe_limit).offset(safe_offset).all()
275
+
276
+ return [
277
+ {
278
+ "chat_id": chat.chat_id,
279
+ "user_id": chat.user_id,
280
+ "title": chat.title,
281
+ "created_at": chat.created_at.isoformat() if chat.created_at else None
282
+ } for chat in chats
283
+ ]
284
+ except SQLAlchemyError as e:
285
+ logger.log_message(f"Error retrieving chats: {str(e)}", level=logging.ERROR)
286
+ return []
287
+ finally:
288
+ session.close()
289
+
290
+ def delete_chat(self, chat_id: int, user_id: Optional[int] = None) -> bool:
291
+ """
292
+ Delete a chat and all its messages while preserving model usage records.
293
+
294
+ Args:
295
+ chat_id: ID of the chat to delete
296
+ user_id: Optional user ID to verify ownership
297
+
298
+ Returns:
299
+ True if deletion was successful, False otherwise
300
+ """
301
+ session = self.Session()
302
+ try:
303
+ # Fetch chat with ownership check if user_id provided
304
+ query = session.query(Chat).filter(Chat.chat_id == chat_id)
305
+ if user_id is not None:
306
+ query = query.filter(Chat.user_id == user_id)
307
+
308
+ chat = query.first()
309
+ if not chat:
310
+ return False # Chat not found or ownership mismatch
311
+
312
+ # ORM-based deletion with model_usage preservation
313
+ # The SET NULL in the foreign key should handle this, but we ensure it explicitly for both
314
+ # SQLite and PostgreSQL compatibility
315
+
316
+ # For SQLite which might not respect ondelete="SET NULL" always:
317
+ # Update model_usage records to set chat_id to NULL
318
+ session.query(ModelUsage).filter(ModelUsage.chat_id == chat_id).update(
319
+ {"chat_id": None}, synchronize_session=False
320
+ )
321
+
322
+ # Now delete the chat - relationships will handle cascading to messages
323
+ session.delete(chat)
324
+ session.commit()
325
+ return True
326
+ except SQLAlchemyError as e:
327
+ session.rollback()
328
+ logger.log_message(f"Error deleting chat: {str(e)}", level=logging.ERROR)
329
+ return False
330
+ finally:
331
+ session.close()
332
+
333
+
334
+
335
+ def get_or_create_user(self, username: str, email: str) -> Dict[str, Any]:
336
+ """
337
+ Get an existing user by email or create a new one if not found.
338
+
339
+ Args:
340
+ username: User's display name
341
+ email: User's email address
342
+
343
+ Returns:
344
+ Dictionary containing user information
345
+ """
346
+ session = self.Session()
347
+ try:
348
+ # Validate and sanitize inputs
349
+ if not email or not isinstance(email, str):
350
+ raise ValueError("Valid email is required")
351
+
352
+ # Limit input length for PostgreSQL compatibility
353
+ max_length = 255 # Standard limit for varchar fields
354
+ if username and len(username) > max_length:
355
+ username = username[:max_length]
356
+ if email and len(email) > max_length:
357
+ email = email[:max_length]
358
+
359
+ # Try to find existing user by email
360
+ user = session.query(User).filter(User.email == email).first()
361
+
362
+ if not user:
363
+ # Create new user if not found
364
+ user = User(username=username, email=email)
365
+ session.add(user)
366
+ session.flush() # Get ID before committing
367
+ user_id = user.user_id
368
+ session.commit()
369
+ logger.log_message(f"Created new user: {username} ({email})", level=logging.INFO)
370
+ else:
371
+ user_id = user.user_id
372
+
373
+ return {
374
+ "user_id": user_id,
375
+ "username": user.username,
376
+ "email": user.email,
377
+ "created_at": user.created_at.isoformat() if user.created_at else None
378
+ }
379
+ except SQLAlchemyError as e:
380
+ session.rollback()
381
+ logger.log_message(f"Error getting/creating user: {str(e)}", level=logging.ERROR)
382
+ raise
383
+ finally:
384
+ session.close()
385
+
386
+ def update_chat(self, chat_id: int, title: Optional[str] = None, user_id: Optional[int] = None) -> Dict[str, Any]:
387
+ """
388
+ Update a chat's title or user_id.
389
+
390
+ Args:
391
+ chat_id: ID of the chat to update
392
+ title: New title for the chat (optional)
393
+ user_id: New user ID for the chat (optional)
394
+
395
+ Returns:
396
+ Dictionary containing updated chat information
397
+ """
398
+ session = self.Session()
399
+ try:
400
+ # Get the chat
401
+ chat = session.query(Chat).filter(Chat.chat_id == chat_id).first()
402
+ if not chat:
403
+ raise ValueError(f"Chat with ID {chat_id} not found")
404
+
405
+ # Update fields if provided
406
+ if title is not None:
407
+ # Limit title length for PostgreSQL compatibility
408
+ if len(title) > 255: # Assuming String column has a reasonable length
409
+ title = title[:255]
410
+ chat.title = title
411
+
412
+ if user_id is not None:
413
+ chat.user_id = user_id
414
+
415
+ session.commit()
416
+
417
+ return {
418
+ "chat_id": chat.chat_id,
419
+ "title": chat.title,
420
+ "created_at": chat.created_at.isoformat() if chat.created_at else None,
421
+ "user_id": chat.user_id
422
+ }
423
+ except SQLAlchemyError as e:
424
+ session.rollback()
425
+ logger.log_message(f"Error updating chat: {str(e)}", level=logging.ERROR)
426
+ raise
427
+ finally:
428
+ session.close()
429
+
430
+ def generate_title_from_query(self, query: str) -> str:
431
+ """
432
+ Generate a title for a chat based on the first query.
433
+
434
+ Args:
435
+ query: The user's first query in the chat
436
+
437
+ Returns:
438
+ A generated title string
439
+ """
440
+ try:
441
+ # Validate input
442
+ if not query or not isinstance(query, str):
443
+ return "New Chat"
444
+
445
+ # Simple title generation - take first few words
446
+ words = query.strip().split()
447
+ if len(words) > 3:
448
+ title = "Chat about " + " ".join(words[0:3]) + "..."
449
+ else:
450
+ title = "Chat about " + query.strip()
451
+
452
+ # Limit title length for PostgreSQL compatibility
453
+ max_title_length = 255
454
+ if len(title) > max_title_length:
455
+ title = title[:max_title_length-3] + "..."
456
+
457
+ return title
458
+ except Exception as e:
459
+ logger.log_message(f"Error generating title: {str(e)}", level=logging.ERROR)
460
+ return "New Chat"
461
+
462
+ def delete_empty_chats(self, user_id: Optional[int] = None, is_admin: bool = False) -> int:
463
+ """
464
+ Delete empty chats (chats with no messages) for a user.
465
+
466
+ Args:
467
+ user_id: ID of the user whose empty chats should be deleted
468
+ is_admin: Whether this is an admin user
469
+
470
+ Returns:
471
+ Number of chats deleted
472
+ """
473
+ session = self.Session()
474
+ try:
475
+ # Get all chats for the user
476
+ query = session.query(Chat)
477
+ if user_id is not None:
478
+ query = query.filter(Chat.user_id == user_id)
479
+ elif not is_admin:
480
+ return 0 # Don't delete anything if not a user or admin
481
+
482
+ # Get chats with no messages using a subquery - works in both SQLite and PostgreSQL
483
+ empty_chats = query.filter(
484
+ ~exists().where(Message.chat_id == Chat.chat_id)
485
+ ).all()
486
+
487
+ # Collect chat IDs to delete
488
+ chat_ids = [chat.chat_id for chat in empty_chats]
489
+
490
+ deleted_count = 0
491
+ if chat_ids:
492
+ # Update model_usage records to set chat_id to NULL for any associated usage records
493
+ session.query(ModelUsage).filter(ModelUsage.chat_id.in_(chat_ids)).update(
494
+ {"chat_id": None}, synchronize_session=False
495
+ )
496
+
497
+ # Delete the empty chats one by one to ensure proper relationship handling
498
+ for chat_id in chat_ids:
499
+ chat = session.query(Chat).filter(Chat.chat_id == chat_id).first()
500
+ if chat:
501
+ session.delete(chat)
502
+ deleted_count += 1
503
+
504
+ session.commit()
505
+
506
+ return deleted_count
507
+ except SQLAlchemyError as e:
508
+ session.rollback()
509
+ logger.log_message(f"Error deleting empty chats: {str(e)}", level=logging.ERROR)
510
+ return 0
511
+ finally:
512
+ session.close()
513
+
514
+ def get_usage_summary(self, start_date: Optional[datetime] = None,
515
+ end_date: Optional[datetime] = None) -> Dict[str, Any]:
516
+ """
517
+ Get a summary of model usage including total costs, tokens, and usage by model.
518
+
519
+ Args:
520
+ start_date: Optional start date for the summary period
521
+ end_date: Optional end date for the summary period
522
+
523
+ Returns:
524
+ Dictionary containing usage summary
525
+ """
526
+ session = self.Session()
527
+ try:
528
+ # Build base query - convert None values to default values for compatibility
529
+ base_query = session.query(ModelUsage)
530
+
531
+ # Apply date filters
532
+ if start_date:
533
+ base_query = base_query.filter(ModelUsage.timestamp >= start_date)
534
+ if end_date:
535
+ base_query = base_query.filter(ModelUsage.timestamp <= end_date)
536
+
537
+ # Get summary data using aggregate functions
538
+ summary_query = session.query(
539
+ func.coalesce(func.sum(ModelUsage.cost), 0.0).label("total_cost"),
540
+ func.coalesce(func.sum(ModelUsage.prompt_tokens), 0).label("total_prompt_tokens"),
541
+ func.coalesce(func.sum(ModelUsage.completion_tokens), 0).label("total_completion_tokens"),
542
+ func.coalesce(func.sum(ModelUsage.total_tokens), 0).label("total_tokens"),
543
+ func.count(ModelUsage.usage_id).label("request_count"),
544
+ func.coalesce(func.avg(ModelUsage.request_time_ms), 0.0).label("avg_request_time")
545
+ ).select_from(base_query.subquery())
546
+
547
+ result = summary_query.first()
548
+
549
+ # Get usage breakdown by model - using the same base query for consistency
550
+ model_query = session.query(
551
+ ModelUsage.model_name,
552
+ func.coalesce(func.sum(ModelUsage.cost), 0.0).label("model_cost"),
553
+ func.coalesce(func.sum(ModelUsage.total_tokens), 0).label("model_tokens"),
554
+ func.count(ModelUsage.usage_id).label("model_requests")
555
+ ).select_from(base_query.subquery()).group_by(ModelUsage.model_name)
556
+
557
+ model_breakdown = model_query.all()
558
+
559
+ # Get usage breakdown by provider using the same base query
560
+ provider_query = session.query(
561
+ ModelUsage.provider,
562
+ func.coalesce(func.sum(ModelUsage.cost), 0.0).label("provider_cost"),
563
+ func.coalesce(func.sum(ModelUsage.total_tokens), 0).label("provider_tokens"),
564
+ func.count(ModelUsage.usage_id).label("provider_requests")
565
+ ).select_from(base_query.subquery()).group_by(ModelUsage.provider)
566
+
567
+ provider_breakdown = provider_query.all()
568
+
569
+ # Get top users by cost
570
+ user_query = session.query(
571
+ ModelUsage.user_id,
572
+ func.coalesce(func.sum(ModelUsage.cost), 0.0).label("user_cost"),
573
+ func.coalesce(func.sum(ModelUsage.total_tokens), 0).label("user_tokens"),
574
+ func.count(ModelUsage.usage_id).label("user_requests")
575
+ ).select_from(base_query.subquery()).group_by(ModelUsage.user_id).order_by(
576
+ func.sum(ModelUsage.cost).desc()
577
+ ).limit(10)
578
+
579
+ user_breakdown = user_query.all()
580
+
581
+ # Handle the result data carefully to avoid None/NULL issues
582
+ return {
583
+ "summary": {
584
+ "total_cost": float(result.total_cost) if result and result.total_cost is not None else 0.0,
585
+ "total_prompt_tokens": int(result.total_prompt_tokens) if result and result.total_prompt_tokens is not None else 0,
586
+ "total_completion_tokens": int(result.total_completion_tokens) if result and result.total_completion_tokens is not None else 0,
587
+ "total_tokens": int(result.total_tokens) if result and result.total_tokens is not None else 0,
588
+ "request_count": int(result.request_count) if result and result.request_count is not None else 0,
589
+ "avg_request_time_ms": float(result.avg_request_time) if result and result.avg_request_time is not None else 0.0
590
+ },
591
+ "model_breakdown": [
592
+ {
593
+ "model_name": model.model_name,
594
+ "cost": float(model.model_cost) if model.model_cost is not None else 0.0,
595
+ "tokens": int(model.model_tokens) if model.model_tokens is not None else 0,
596
+ "requests": int(model.model_requests) if model.model_requests is not None else 0
597
+ } for model in model_breakdown
598
+ ],
599
+ "provider_breakdown": [
600
+ {
601
+ "provider": provider.provider,
602
+ "cost": float(provider.provider_cost) if provider.provider_cost is not None else 0.0,
603
+ "tokens": int(provider.provider_tokens) if provider.provider_tokens is not None else 0,
604
+ "requests": int(provider.provider_requests) if provider.provider_requests is not None else 0
605
+ } for provider in provider_breakdown
606
+ ],
607
+ "top_users": [
608
+ {
609
+ "user_id": user.user_id,
610
+ "cost": float(user.user_cost) if user.user_cost is not None else 0.0,
611
+ "tokens": int(user.user_tokens) if user.user_tokens is not None else 0,
612
+ "requests": int(user.user_requests) if user.user_requests is not None else 0
613
+ } for user in user_breakdown
614
+ ]
615
+ }
616
+
617
+ except SQLAlchemyError as e:
618
+ logger.log_message(f"Error retrieving usage summary: {str(e)}", level=logging.ERROR)
619
+ return {
620
+ "summary": {
621
+ "total_cost": 0.0,
622
+ "total_tokens": 0,
623
+ "request_count": 0
624
+ },
625
+ "model_breakdown": [],
626
+ "provider_breakdown": [],
627
+ "top_users": []
628
+ }
629
+ finally:
630
+ session.close()
631
+
632
+ def get_recent_chat_history(self, chat_id: int, limit: int = 5) -> List[Dict[str, Any]]:
633
+ """
634
+ Get recent message history for a chat, limited to the last 'limit' messages.
635
+
636
+ Args:
637
+ chat_id: ID of the chat to get history for
638
+ limit: Maximum number of recent messages to return
639
+
640
+ Returns:
641
+ List of dictionaries containing message information
642
+ """
643
+ session = self.Session()
644
+ try:
645
+ # Ensure safe limit for both databases
646
+ safe_limit = min(max(1, limit), 50) * 2 # Between 2 and 100 messages
647
+
648
+ # Use subquery for more efficient pagination in PostgreSQL
649
+ subquery = session.query(Message).filter(
650
+ Message.chat_id == chat_id
651
+ ).order_by(Message.timestamp.desc()).limit(safe_limit).subquery()
652
+
653
+ # Query from the subquery and sort in chronological order
654
+ messages = session.query(Message).from_statement(
655
+ session.query(subquery).order_by(subquery.c.timestamp).statement
656
+ ).all()
657
+
658
+ return [
659
+ {
660
+ "message_id": msg.message_id,
661
+ "chat_id": msg.chat_id,
662
+ "content": msg.content,
663
+ "sender": msg.sender,
664
+ "timestamp": msg.timestamp.isoformat() if msg.timestamp else None
665
+ } for msg in messages
666
+ ]
667
+ except SQLAlchemyError as e:
668
+ logger.log_message(f"Error retrieving chat history: {str(e)}", level=logging.ERROR)
669
+ return []
670
+ finally:
671
+ session.close()
672
+
673
+
674
+ def extract_response_history(self, messages: List[Dict[str, Any]]) -> str:
675
+ """
676
+ Extract response history from message history.
677
+
678
+ Args:
679
+ messages: List of message dictionaries
680
+
681
+ Returns:
682
+ String containing combined response history in a structured format
683
+ """
684
+
685
+ summaries = []
686
+ user_messages = []
687
+
688
+ # Input validation
689
+ if not messages or not isinstance(messages, list):
690
+ return ""
691
+
692
+ try:
693
+ for msg in messages:
694
+ # Skip invalid messages
695
+ if not isinstance(msg, dict):
696
+ continue
697
+
698
+ # Get User Messages
699
+ if msg.get("sender") == "user":
700
+ user_messages.append(msg)
701
+ # Ensure content exists and is from AI before extracting summary
702
+ if msg.get("sender") == "ai" and "content" in msg and msg["content"]:
703
+ content = msg["content"]
704
+ # Use a safer regex pattern with timeout protection
705
+ try:
706
+ matches = re.findall(r"### Summary\n(.*?)(?=\n\n##|\Z)", content, re.DOTALL)
707
+ summaries.extend(match.strip() for match in matches if match)
708
+ except Exception as e:
709
+ logger.log_message(f"Error extracting summaries: {str(e)}", level=logging.ERROR)
710
+
711
+ # Combine user messages with summaries in a structured format
712
+ combined_conversations = []
713
+ for i, user_msg in enumerate(user_messages):
714
+ if i < len(summaries):
715
+ # Ensure content exists and is not too long
716
+ user_content = user_msg.get('content', '')
717
+ if user_content and isinstance(user_content, str):
718
+ # Truncate if needed
719
+ if len(user_content) > 500:
720
+ user_content = user_content[:497] + "..."
721
+
722
+ summary = summaries[i]
723
+ if len(summary) > 500:
724
+ summary = summary[:497] + "..."
725
+
726
+ combined_conversations.append(f"Query: {user_content}\nSummary: {summary}")
727
+
728
+ # Return the last 3 conversations to maintain context
729
+ formatted_context = "\n\n".join(combined_conversations[-3:])
730
+
731
+ # Add a clear header to indicate this is past interaction history
732
+ if formatted_context:
733
+ return f"### Previous Interaction History:\n{formatted_context}"
734
+ return ""
735
+ except Exception as e:
736
+ logger.log_message(f"Error in extract_response_history: {str(e)}", level=logging.ERROR)
737
+ return ""