FireBird-Tech commited on
Commit
15eda99
·
verified ·
1 Parent(s): 4757044

Update src/managers/ai_manager.py

Browse files
Files changed (1) hide show
  1. src/managers/ai_manager.py +86 -137
src/managers/ai_manager.py CHANGED
@@ -1,137 +1,86 @@
1
- import logging
2
- from typing import Optional, Dict, Any
3
- import time
4
- from src.db.schemas.models import ModelUsage
5
- from src.db.init_db import session_factory
6
- from datetime import datetime
7
- import tiktoken
8
- from src.routes.analytics_routes import handle_new_model_usage
9
- import asyncio
10
-
11
- from src.utils.logger import Logger
12
-
13
# Module-level logger shared by everything in this module (old version).
logger = Logger(name="ai_manager", see_time=True, console_log=True)
14
-
15
# Cost per 1K tokens for different models
# USD rates keyed by provider, then by exact model name; each entry has
# separate "input" (prompt) and "output" (completion) rates.
# NOTE(review): this table was removed in this commit in favor of
# src.utils.model_registry; rates below are as of the pre-commit version.
costs = {
    "openai": {
        "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
        "gpt-4.5-preview": {"input": 0.075, "output": 0.15},
        "gpt-4": {"input": 0.03, "output": 0.06},
        "gpt-4o": {"input": 0.0025, "output": 0.01},
        "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
        "o1": {"input": 0.015, "output": 0.06},
        "o1-mini": {"input": 0.00011, "output": 0.00044},
        "o3-mini": {"input": 0.00011, "output": 0.00044}
    },
    "anthropic": {
        "claude-3-opus-latest": {"input": 0.015, "output": 0.075},
        "claude-3-7-sonnet-latest": {"input": 0.003, "output": 0.015},
        "claude-3-5-sonnet-latest": {"input": 0.003, "output": 0.015},
        # NOTE(review): output rate (0.0004) is lower than input (0.0008)
        # here, unlike every other entry — looks transposed/mistyped; verify.
        "claude-3-5-haiku-latest": {"input": 0.0008, "output": 0.0004},
    },
    "groq": {
        "deepseek-r1-distill-qwen-32b": {"input": 0.00075, "output": 0.00099},
        "deepseek-r1-distill-llama-70b": {"input": 0.00075, "output": 0.00099},
        "llama-3.3-70b-versatile": {"input": 0.00059, "output": 0.00079},
        "llama-3.3-70b-specdec": {"input": 0.00059, "output": 0.00099},
        "llama2-70b-4096": {"input": 0.0007, "output": 0.0008},
        "llama3-8b-8192": {"input": 0.00005, "output": 0.00008},
        "llama-3.2-1b-preview": {"input": 0.00004, "output": 0.00004},
        "llama-3.2-3b-preview": {"input": 0.00006, "output": 0.00006},
        "llama-3.2-11b-text-preview": {"input": 0.00018, "output": 0.00018},
        "llama-3.2-11b-vision-preview": {"input": 0.00018, "output": 0.00018},
        "llama-3.2-90b-text-preview": {"input": 0.0009, "output": 0.0009},
        "llama-3.2-90b-vision-preview": {"input": 0.0009, "output": 0.0009},
        "llama3-70b-8192": {"input": 0.00059, "output": 0.00079},
        "llama-3.1-8b-instant": {"input": 0.00005, "output": 0.00008},
        "llama-3.1-70b-versatile": {"input": 0.00059, "output": 0.00079},
        "llama-3.1-405b-reasoning": {"input": 0.00059, "output": 0.00079},
        "mixtral-8x7b-32768": {"input": 0.00024, "output": 0.00024},
        "gemma-7b-it": {"input": 0.00007, "output": 0.00007},
        "gemma2-9b-it": {"input": 0.0002, "output": 0.0002},
        "llama3-groq-70b-8192-tool-use-preview": {"input": 0.00089, "output": 0.00089},
        "llama3-groq-8b-8192-tool-use-preview": {"input": 0.00019, "output": 0.00019},
        "qwen-2.5-coder-32b": {"input": 0.0015, "output": 0.003}
    }
}
58
-
59
-
60
class AI_Manager:
    """Manages AI model interactions and usage tracking.

    Pre-commit version: cost and provider lookup use the module-level
    `costs` table directly.
    """

    def __init__(self):
        self.tokenizer = None
        # Initialize tokenizer - could use tiktoken or another tokenizer.
        # Falls back to SimpleTokenizer when tiktoken is not installed.
        try:
            import tiktoken
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except ImportError:
            logger.log_message("Tiktoken not available, using simple tokenizer", level=logging.WARNING)
            self.tokenizer = SimpleTokenizer()

    def save_usage_to_db(self, user_id, chat_id, model_name, provider,
                         prompt_tokens, completion_tokens, total_tokens,
                         query_size, response_size, cost, request_time_ms,
                         is_streaming=False):
        """Save model usage data to the database.

        Errors are logged and rolled back rather than raised.
        """
        try:
            # NOTE(review): if session_factory() itself raises, `session` is
            # never bound and the `finally` below raises NameError, masking
            # the original error — session creation belongs before the try.
            session = session_factory()

            usage = ModelUsage(
                user_id=user_id,
                chat_id=chat_id,
                model_name=model_name,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                query_size=query_size,
                response_size=response_size,
                cost=cost,
                timestamp=datetime.utcnow(),
                is_streaming=is_streaming,
                request_time_ms=request_time_ms
            )

            session.add(usage)
            session.commit()
            # logger.info(f"Saved usage data to database for chat {chat_id}: {total_tokens} tokens, ${cost:.6f}")

            # Broadcast the event asynchronously
            # NOTE(review): asyncio.create_task requires a running event loop;
            # called from a plain-sync context this raises RuntimeError, which
            # is caught below and triggers a rollback after a successful commit.
            asyncio.create_task(handle_new_model_usage(usage))

        except Exception as e:
            session.rollback()
            logger.log_message(f"Error saving usage data to database for chat {chat_id}: {str(e)}", level=logging.ERROR)
        finally:
            session.close()

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        """Calculate the cost for using the model based on tokens.

        Returns 0 for a falsy model name; otherwise prices input/output
        tokens separately using the per-1K rates in `costs`.
        """
        if not model_name:
            return 0

        # Convert tokens to thousands (rates are per 1K tokens)
        input_tokens_in_thousands = input_tokens / 1000
        output_tokens_in_thousands = output_tokens / 1000

        # Default cost if model not found
        model_provider = self.get_provider_for_model(model_name)
        # logger.log_message(f"[> ] Model Name: {model_name}, Model Provider: {model_provider}")

        # NOTE(review): raises KeyError when the provider resolves to
        # "Unknown" or the exact model name is absent from the table —
        # the "default cost" comment above is not actually implemented.
        return input_tokens_in_thousands * costs[model_provider][model_name]["input"] + output_tokens_in_thousands * costs[model_provider][model_name]["output"]

    def get_provider_for_model(self, model_name):
        """Determine the provider based on model name.

        Substring match: returns the first provider whose table contains a
        model key that includes the lowercased name; "Unknown" otherwise.
        """
        if not model_name:
            return "Unknown"

        model_name = model_name.lower()
        return next((provider for provider, models in costs.items()
                    if any(model_name in model for model in models)), "Unknown")
133
-
134
class SimpleTokenizer:
    """A very simple tokenizer implementation for fallback"""
    def encode(self, text):
        # NOTE(review): returns a whitespace-token COUNT (int), whereas
        # tiktoken's encode returns a list of token ids — callers doing
        # len(tokenizer.encode(text)) break on this fallback path. Verify.
        return len(text.split())
 
1
+ import logging
2
+ from typing import Optional, Dict, Any
3
+ import time
4
+ from src.db.schemas.models import ModelUsage
5
+ from src.db.init_db import session_factory
6
+ from datetime import datetime
7
+ import tiktoken
8
+ from src.routes.analytics_routes import handle_new_model_usage
9
+ import asyncio
10
+
11
+ from src.utils.logger import Logger
12
+ from src.utils.model_registry import get_provider_for_model, calculate_cost
13
+
14
# Module-level logger shared by all AI_Manager operations.
logger = Logger(name="ai_manager", see_time=True, console_log=True)
15
+
16
class AI_Manager:
    """Manages AI model interactions and usage tracking.

    Holds a tokenizer (tiktoken's cl100k_base when available, a simple
    whitespace fallback otherwise), persists per-request usage rows via the
    project's SQLAlchemy session factory, and delegates cost/provider
    lookups to src.utils.model_registry.
    """

    def __init__(self):
        self.tokenizer = None
        # Initialize tokenizer - could use tiktoken or another tokenizer
        try:
            import tiktoken
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except ImportError:
            logger.log_message("Tiktoken not available, using simple tokenizer", level=logging.WARNING)
            self.tokenizer = SimpleTokenizer()

    def save_usage_to_db(self, user_id, chat_id, model_name, provider,
                         prompt_tokens, completion_tokens, total_tokens,
                         query_size, response_size, cost, request_time_ms,
                         is_streaming=False):
        """Save model usage data to the database.

        Commits one ModelUsage row and, when an event loop is running,
        schedules an analytics broadcast. Database errors are logged and
        rolled back — this method never raises.
        """
        # Create the session BEFORE the try block: if session_factory()
        # failed inside it, `session` would be unbound and the finally
        # clause would raise NameError, masking the real error.
        session = session_factory()
        try:
            usage = ModelUsage(
                user_id=user_id,
                chat_id=chat_id,
                model_name=model_name,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                query_size=query_size,
                response_size=response_size,
                # NOTE: utcnow() is naive UTC; kept for compatibility with
                # existing stored rows (switching to aware datetimes would
                # change what the DB receives).
                timestamp=datetime.utcnow(),
                is_streaming=is_streaming,
                request_time_ms=request_time_ms
            )

            session.add(usage)
            session.commit()

            # Broadcast the event asynchronously. create_task() requires a
            # running event loop; without this guard a sync caller would hit
            # RuntimeError, which the except below would treat as a DB
            # failure and roll back an already-committed transaction.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                logger.log_message(
                    f"No running event loop; skipping usage broadcast for chat {chat_id}",
                    level=logging.WARNING,
                )
            else:
                asyncio.create_task(handle_new_model_usage(usage))

        except Exception as e:
            session.rollback()
            logger.log_message(f"Error saving usage data to database for chat {chat_id}: {str(e)}", level=logging.ERROR)
        finally:
            session.close()

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        """Calculate the cost for using the model based on tokens.

        Returns 0 for a falsy model name; otherwise delegates to the
        centralized registry's calculate_cost.
        """
        if not model_name:
            return 0

        # Resolve provider purely for the diagnostic log line below.
        model_provider = get_provider_for_model(model_name)
        logger.log_message(f"[> ] Model Name: {model_name}, Model Provider: {model_provider}", level=logging.INFO)

        # Use the centralized calculate_cost function
        return calculate_cost(model_name, input_tokens, output_tokens)

    def get_provider_for_model(self, model_name):
        """Determine the provider based on model name.

        Thin wrapper kept for interface compatibility; delegates to the
        centralized get_provider_for_model.
        """
        return get_provider_for_model(model_name)
82
+
83
class SimpleTokenizer:
    """Whitespace-based fallback tokenizer used when tiktoken is unavailable.

    Mirrors the relevant part of tiktoken's Encoding interface: encode()
    returns a LIST of tokens, so callers doing len(tokenizer.encode(text))
    work identically with either implementation. (The previous version
    returned the count itself, which breaks len() on the result and
    diverges from the tiktoken contract.)
    """

    def encode(self, text):
        """Split *text* on whitespace and return the resulting token list."""
        return text.split()