Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update src/managers/ai_manager.py
Browse files- src/managers/ai_manager.py +86 -137
src/managers/ai_manager.py
CHANGED
@@ -1,137 +1,86 @@
|
|
1 |
-
import logging
|
2 |
-
from typing import Optional, Dict, Any
|
3 |
-
import time
|
4 |
-
from src.db.schemas.models import ModelUsage
|
5 |
-
from src.db.init_db import session_factory
|
6 |
-
from datetime import datetime
|
7 |
-
import tiktoken
|
8 |
-
from src.routes.analytics_routes import handle_new_model_usage
|
9 |
-
import asyncio
|
10 |
-
|
11 |
-
from src.utils.logger import Logger
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
completion_tokens=completion_tokens,
|
88 |
-
total_tokens=total_tokens,
|
89 |
-
query_size=query_size,
|
90 |
-
response_size=response_size,
|
91 |
-
cost=cost,
|
92 |
-
timestamp=datetime.utcnow(),
|
93 |
-
is_streaming=is_streaming,
|
94 |
-
request_time_ms=request_time_ms
|
95 |
-
)
|
96 |
-
|
97 |
-
session.add(usage)
|
98 |
-
session.commit()
|
99 |
-
# logger.info(f"Saved usage data to database for chat {chat_id}: {total_tokens} tokens, ${cost:.6f}")
|
100 |
-
|
101 |
-
# Broadcast the event asynchronously
|
102 |
-
asyncio.create_task(handle_new_model_usage(usage))
|
103 |
-
|
104 |
-
except Exception as e:
|
105 |
-
session.rollback()
|
106 |
-
logger.log_message(f"Error saving usage data to database for chat {chat_id}: {str(e)}", level=logging.ERROR)
|
107 |
-
finally:
|
108 |
-
session.close()
|
109 |
-
|
110 |
-
def calculate_cost(self, model_name, input_tokens, output_tokens):
|
111 |
-
"""Calculate the cost for using the model based on tokens"""
|
112 |
-
if not model_name:
|
113 |
-
return 0
|
114 |
-
|
115 |
-
# Convert tokens to thousands
|
116 |
-
input_tokens_in_thousands = input_tokens / 1000
|
117 |
-
output_tokens_in_thousands = output_tokens / 1000
|
118 |
-
|
119 |
-
# Default cost if model not found
|
120 |
-
model_provider = self.get_provider_for_model(model_name)
|
121 |
-
# logger.log_message(f"[> ] Model Name: {model_name}, Model Provider: {model_provider}")
|
122 |
-
|
123 |
-
return input_tokens_in_thousands * costs[model_provider][model_name]["input"] + output_tokens_in_thousands * costs[model_provider][model_name]["output"]
|
124 |
-
|
125 |
-
def get_provider_for_model(self, model_name):
|
126 |
-
"""Determine the provider based on model name"""
|
127 |
-
if not model_name:
|
128 |
-
return "Unknown"
|
129 |
-
|
130 |
-
model_name = model_name.lower()
|
131 |
-
return next((provider for provider, models in costs.items()
|
132 |
-
if any(model_name in model for model in models)), "Unknown")
|
133 |
-
|
134 |
-
class SimpleTokenizer:
|
135 |
-
"""A very simple tokenizer implementation for fallback"""
|
136 |
-
def encode(self, text):
|
137 |
-
return len(text.split())
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Optional, Dict, Any
|
3 |
+
import time
|
4 |
+
from src.db.schemas.models import ModelUsage
|
5 |
+
from src.db.init_db import session_factory
|
6 |
+
from datetime import datetime
|
7 |
+
import tiktoken
|
8 |
+
from src.routes.analytics_routes import handle_new_model_usage
|
9 |
+
import asyncio
|
10 |
+
|
11 |
+
from src.utils.logger import Logger
|
12 |
+
from src.utils.model_registry import get_provider_for_model, calculate_cost
|
13 |
+
|
14 |
+
# Module-level logger shared by this manager. NOTE(review): see_time/console_log
# presumably enable timestamped output and console echoing — confirm against
# src.utils.logger.Logger.
logger = Logger(name="ai_manager", see_time=True, console_log=True)
|
15 |
+
|
16 |
+
class AI_Manager:
    """Manages AI model interactions and usage tracking."""

    def __init__(self):
        # Prefer tiktoken's cl100k_base encoding; fall back to a crude
        # whitespace tokenizer when tiktoken is not installed.
        # NOTE(review): the module also imports tiktoken at top level, which
        # would already fail at import time if it is missing — confirm the
        # top-level import is optional/removable so this fallback can fire.
        self.tokenizer = None
        try:
            import tiktoken
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        except ImportError:
            logger.log_message("Tiktoken not available, using simple tokenizer", level=logging.WARNING)
            self.tokenizer = SimpleTokenizer()

    def save_usage_to_db(self, user_id, chat_id, model_name, provider,
                         prompt_tokens, completion_tokens, total_tokens,
                         query_size, response_size, cost, request_time_ms,
                         is_streaming=False):
        """Persist one ModelUsage row and broadcast it to analytics listeners.

        Parameters mirror the ModelUsage columns; the timestamp is stamped
        here with datetime.utcnow(). Failures are logged and swallowed so
        usage tracking never breaks the request path.
        """
        # BUG FIX: session must be pre-initialized. Previously it was first
        # assigned inside the try, so a session_factory() failure raised
        # NameError in both the except (rollback) and finally (close) paths,
        # masking the original error.
        session = None
        try:
            session = session_factory()

            usage = ModelUsage(
                user_id=user_id,
                chat_id=chat_id,
                model_name=model_name,
                provider=provider,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                query_size=query_size,
                response_size=response_size,
                cost=cost,
                timestamp=datetime.utcnow(),
                is_streaming=is_streaming,
                request_time_ms=request_time_ms
            )

            session.add(usage)
            session.commit()

            # Broadcast the event asynchronously. Guarded separately so that
            # a missing running event loop (RuntimeError from create_task when
            # called from a sync context) cannot trigger a rollback of a
            # transaction that already committed.
            try:
                asyncio.create_task(handle_new_model_usage(usage))
            except RuntimeError as broadcast_err:
                logger.log_message(
                    f"Could not broadcast usage event for chat {chat_id}: {str(broadcast_err)}",
                    level=logging.WARNING)

        except Exception as e:
            if session is not None:
                session.rollback()
            logger.log_message(f"Error saving usage data to database for chat {chat_id}: {str(e)}", level=logging.ERROR)
        finally:
            if session is not None:
                session.close()

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        """Calculate the cost for using the model based on tokens.

        Returns 0 for a falsy model name; otherwise delegates to the
        centralized registry's calculate_cost.
        """
        if not model_name:
            return 0

        # Resolve the provider purely for the log line below.
        model_provider = get_provider_for_model(model_name)
        logger.log_message(f"[> ] Model Name: {model_name}, Model Provider: {model_provider}", level=logging.INFO)

        # Use the centralized calculate_cost function
        return calculate_cost(model_name, input_tokens, output_tokens)

    def get_provider_for_model(self, model_name):
        """Determine the provider based on model name"""
        # Thin delegation to the centralized registry helper.
        return get_provider_for_model(model_name)
|
82 |
+
|
83 |
+
class SimpleTokenizer:
    """Fallback tokenizer used when tiktoken is unavailable."""

    def encode(self, text):
        """Approximate a token count by splitting *text* on whitespace.

        NOTE(review): unlike tiktoken's encode(), which returns a list of
        token ids, this returns an int — confirm callers only consume the
        count and never iterate/len() the result.
        """
        words = text.split()
        return len(words)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|